refactor: use pinned golangci-lint Docker image for linting
All checks were successful
check / check (push) Successful in 1m37s
All checks were successful
check / check (push) Successful in 1m37s
Refactor Dockerfile to use a separate lint stage with a pinned golangci-lint v2.11.3 Docker image instead of installing golangci-lint via curl in the builder stage. This follows the pattern used by sneak/pixa. Changes: - Dockerfile: separate lint stage using golangci/golangci-lint:v2.11.3 (Debian-based, pinned by sha256) with COPY --from=lint dependency - Bump Go from 1.24 to 1.26.1 (golang:1.26.1-bookworm, pinned) - Bump golangci-lint from v1.64.8 to v2.11.3 - Migrate .golangci.yml from v1 to v2 format (same linters, format only) - All Docker images pinned by sha256 digest - Fix all lint issues from the v2 linter upgrade: - Add package comments to all packages - Add doc comments to all exported types, functions, and methods - Fix unchecked errors (errcheck) - Fix unused parameters (revive) - Fix gosec warnings (MaxBytesReader for form parsing) - Fix staticcheck suggestions (fmt.Fprintf instead of WriteString) - Rename DeliveryTask to Task to avoid stutter (delivery.Task) - Rename shadowed builtin 'max' parameter - Update README.md version requirements
This commit is contained in:
@@ -11,15 +11,16 @@ import (
|
||||
// This replaces gorm.Model but uses UUID instead of uint for ID
|
||||
type BaseModel struct {
|
||||
ID string `gorm:"type:uuid;primary_key" json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deletedAt,omitzero"`
|
||||
}
|
||||
|
||||
// BeforeCreate hook to set UUID before creating a record
|
||||
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
|
||||
// BeforeCreate hook to set UUID before creating a record.
|
||||
func (b *BaseModel) BeforeCreate(_ *gorm.DB) error {
|
||||
if b.ID == "" {
|
||||
b.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package database provides SQLite persistence for webhooks, events, and users.
|
||||
package database
|
||||
|
||||
import (
|
||||
@@ -19,30 +20,42 @@ import (
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
// nolint:revive // DatabaseParams is a standard fx naming convention
|
||||
const (
|
||||
dataDirPerm = 0750
|
||||
randomPasswordLen = 16
|
||||
sessionKeyLen = 32
|
||||
)
|
||||
|
||||
//nolint:revive // DatabaseParams is a standard fx naming convention.
|
||||
type DatabaseParams struct {
|
||||
fx.In
|
||||
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// Database manages the main SQLite connection and schema migrations.
|
||||
type Database struct {
|
||||
db *gorm.DB
|
||||
log *slog.Logger
|
||||
params *DatabaseParams
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
||||
// New creates a Database that connects on fx start and disconnects on stop.
|
||||
func New(
|
||||
lc fx.Lifecycle,
|
||||
params DatabaseParams,
|
||||
) (*Database, error) {
|
||||
d := &Database{
|
||||
params: ¶ms,
|
||||
log: params.Logger.Get(),
|
||||
}
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
OnStart: func(_ context.Context) error {
|
||||
return d.connect()
|
||||
},
|
||||
OnStop: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
OnStop: func(_ context.Context) error {
|
||||
return d.close()
|
||||
},
|
||||
})
|
||||
@@ -50,21 +63,92 @@ func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// DB returns the underlying GORM database handle.
|
||||
func (d *Database) DB() *gorm.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
// GetOrCreateSessionKey retrieves the session encryption key from the
|
||||
// settings table. If no key exists, a cryptographically secure random
|
||||
// 32-byte key is generated, base64-encoded, and stored for future use.
|
||||
func (d *Database) GetOrCreateSessionKey() (string, error) {
|
||||
var setting Setting
|
||||
|
||||
result := d.db.Where(
|
||||
&Setting{Key: "session_key"},
|
||||
).First(&setting)
|
||||
if result.Error == nil {
|
||||
return setting.Value, nil
|
||||
}
|
||||
|
||||
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", fmt.Errorf(
|
||||
"failed to query session key: %w",
|
||||
result.Error,
|
||||
)
|
||||
}
|
||||
|
||||
// Generate a new cryptographically secure 32-byte key
|
||||
keyBytes := make([]byte, sessionKeyLen)
|
||||
|
||||
_, err := rand.Read(keyBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf(
|
||||
"failed to generate session key: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(keyBytes)
|
||||
|
||||
setting = Setting{
|
||||
Key: "session_key",
|
||||
Value: encoded,
|
||||
}
|
||||
|
||||
err = d.db.Create(&setting).Error
|
||||
if err != nil {
|
||||
return "", fmt.Errorf(
|
||||
"failed to store session key: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
d.log.Info(
|
||||
"generated new session key and stored in database",
|
||||
)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (d *Database) connect() error {
|
||||
// Ensure the data directory exists before opening the database.
|
||||
dataDir := d.params.Config.DataDir
|
||||
if err := os.MkdirAll(dataDir, 0750); err != nil {
|
||||
return fmt.Errorf("creating data directory %s: %w", dataDir, err)
|
||||
|
||||
err := os.MkdirAll(dataDir, dataDirPerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"creating data directory %s: %w",
|
||||
dataDir,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
// Construct the main application database path inside DATA_DIR.
|
||||
dbPath := filepath.Join(dataDir, "webhooker.db")
|
||||
dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", dbPath)
|
||||
dbURL := fmt.Sprintf(
|
||||
"file:%s?cache=shared&mode=rwc",
|
||||
dbPath,
|
||||
)
|
||||
|
||||
// Open the database with the pure Go SQLite driver
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
d.log.Error("failed to open database", "error", err)
|
||||
d.log.Error(
|
||||
"failed to open database",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -73,7 +157,11 @@ func (d *Database) connect() error {
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
d.log.Error("failed to connect to database", "error", err)
|
||||
d.log.Error(
|
||||
"failed to connect to database",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -86,101 +174,100 @@ func (d *Database) connect() error {
|
||||
|
||||
func (d *Database) migrate() error {
|
||||
// Run GORM auto-migrations
|
||||
if err := d.Migrate(); err != nil {
|
||||
d.log.Error("failed to run database migrations", "error", err)
|
||||
err := d.Migrate()
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to run database migrations",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Info("database migrations completed")
|
||||
|
||||
// Check if admin user exists
|
||||
var userCount int64
|
||||
if err := d.db.Model(&User{}).Count(&userCount).Error; err != nil {
|
||||
d.log.Error("failed to count users", "error", err)
|
||||
|
||||
err = d.db.Model(&User{}).Count(&userCount).Error
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to count users",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if userCount == 0 {
|
||||
// Create admin user
|
||||
d.log.Info("no users found, creating admin user")
|
||||
|
||||
// Generate random password
|
||||
password, err := GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
d.log.Error("failed to generate random password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
hashedPassword, err := HashPassword(password)
|
||||
if err != nil {
|
||||
d.log.Error("failed to hash password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create admin user
|
||||
adminUser := &User{
|
||||
Username: "admin",
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
if err := d.db.Create(adminUser).Error; err != nil {
|
||||
d.log.Error("failed to create admin user", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Info("admin user created",
|
||||
"username", "admin",
|
||||
"password", password,
|
||||
"message", "SAVE THIS PASSWORD - it will not be shown again!",
|
||||
)
|
||||
return d.createAdminUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) createAdminUser() error {
|
||||
d.log.Info("no users found, creating admin user")
|
||||
|
||||
// Generate random password
|
||||
password, err := GenerateRandomPassword(
|
||||
randomPasswordLen,
|
||||
)
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to generate random password",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
hashedPassword, err := HashPassword(password)
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to hash password",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Create admin user
|
||||
adminUser := &User{
|
||||
Username: "admin",
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
err = d.db.Create(adminUser).Error
|
||||
if err != nil {
|
||||
d.log.Error(
|
||||
"failed to create admin user",
|
||||
"error", err,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
d.log.Info("admin user created",
|
||||
"username", "admin",
|
||||
"password", password,
|
||||
"message",
|
||||
"SAVE THIS PASSWORD - it will not be shown again!",
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) close() error {
|
||||
if d.db != nil {
|
||||
sqlDB, err := d.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) DB() *gorm.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
// GetOrCreateSessionKey retrieves the session encryption key from the
|
||||
// settings table. If no key exists, a cryptographically secure random
|
||||
// 32-byte key is generated, base64-encoded, and stored for future use.
|
||||
func (d *Database) GetOrCreateSessionKey() (string, error) {
|
||||
var setting Setting
|
||||
result := d.db.Where(&Setting{Key: "session_key"}).First(&setting)
|
||||
if result.Error == nil {
|
||||
return setting.Value, nil
|
||||
}
|
||||
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", fmt.Errorf("failed to query session key: %w", result.Error)
|
||||
}
|
||||
|
||||
// Generate a new cryptographically secure 32-byte key
|
||||
keyBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(keyBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate session key: %w", err)
|
||||
}
|
||||
encoded := base64.StdEncoding.EncodeToString(keyBytes)
|
||||
|
||||
setting = Setting{
|
||||
Key: "session_key",
|
||||
Value: encoded,
|
||||
}
|
||||
if err := d.db.Create(&setting).Error; err != nil {
|
||||
return "", fmt.Errorf("failed to store session key: %w", err)
|
||||
}
|
||||
|
||||
d.log.Info("generated new session key and stored in database")
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package database
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -6,37 +6,37 @@ import (
|
||||
|
||||
"go.uber.org/fx/fxtest"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
func TestDatabaseConnection(t *testing.T) {
|
||||
// Set up test dependencies
|
||||
func setupTestDB(
|
||||
t *testing.T,
|
||||
) (*database.Database, *fxtest.Lifecycle) {
|
||||
t.Helper()
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
// Create globals
|
||||
globals.Appname = "webhooker-test"
|
||||
globals.Version = "test"
|
||||
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create globals: %v", err)
|
||||
g := &globals.Globals{
|
||||
Appname: "webhooker-test",
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
// Create logger
|
||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
||||
l, err := logger.New(
|
||||
lc,
|
||||
logger.LoggerParams{Globals: g},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create logger: %v", err)
|
||||
}
|
||||
|
||||
// Create config with DataDir pointing to a temp directory
|
||||
c := &config.Config{
|
||||
DataDir: t.TempDir(),
|
||||
Environment: "dev",
|
||||
}
|
||||
|
||||
// Create database
|
||||
db, err := New(lc, DatabaseParams{
|
||||
db, err := database.New(lc, database.DatabaseParams{
|
||||
Config: c,
|
||||
Logger: l,
|
||||
})
|
||||
@@ -44,31 +44,45 @@ func TestDatabaseConnection(t *testing.T) {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
|
||||
// Start lifecycle (this will trigger the connection)
|
||||
return db, lc
|
||||
}
|
||||
|
||||
func TestDatabaseConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, lc := setupTestDB(t)
|
||||
ctx := context.Background()
|
||||
err = lc.Start(ctx)
|
||||
|
||||
err := lc.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if stopErr := lc.Stop(ctx); stopErr != nil {
|
||||
t.Errorf("Failed to stop lifecycle: %v", stopErr)
|
||||
stopErr := lc.Stop(ctx)
|
||||
if stopErr != nil {
|
||||
t.Errorf(
|
||||
"Failed to stop lifecycle: %v",
|
||||
stopErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
// Verify we can get the DB instance
|
||||
if db.DB() == nil {
|
||||
t.Error("Expected non-nil database connection")
|
||||
}
|
||||
|
||||
// Test that we can perform a simple query
|
||||
var result int
|
||||
|
||||
err = db.DB().Raw("SELECT 1").Scan(&result).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute test query: %v", err)
|
||||
}
|
||||
|
||||
if result != 1 {
|
||||
t.Errorf("Expected query result to be 1, got %d", result)
|
||||
t.Errorf(
|
||||
"Expected query result to be 1, got %d",
|
||||
result,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,11 +6,11 @@ import "time"
|
||||
type APIKey struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
UserID string `gorm:"type:uuid;not null" json:"userId"`
|
||||
Key string `gorm:"uniqueIndex;not null" json:"key"`
|
||||
Description string `json:"description"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
LastUsedAt *time.Time `json:"lastUsedAt,omitempty"`
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitempty"`
|
||||
User User `json:"user,omitzero"`
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package database
|
||||
// DeliveryStatus represents the status of a delivery
|
||||
type DeliveryStatus string
|
||||
|
||||
// Delivery status values.
|
||||
const (
|
||||
DeliveryStatusPending DeliveryStatus = "pending"
|
||||
DeliveryStatusDelivered DeliveryStatus = "delivered"
|
||||
@@ -14,12 +15,12 @@ const (
|
||||
type Delivery struct {
|
||||
BaseModel
|
||||
|
||||
EventID string `gorm:"type:uuid;not null" json:"event_id"`
|
||||
TargetID string `gorm:"type:uuid;not null" json:"target_id"`
|
||||
EventID string `gorm:"type:uuid;not null" json:"eventId"`
|
||||
TargetID string `gorm:"type:uuid;not null" json:"targetId"`
|
||||
Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"`
|
||||
|
||||
// Relations
|
||||
Event Event `json:"event,omitempty"`
|
||||
Target Target `json:"target,omitempty"`
|
||||
DeliveryResults []DeliveryResult `json:"delivery_results,omitempty"`
|
||||
Event Event `json:"event,omitzero"`
|
||||
Target Target `json:"target,omitzero"`
|
||||
DeliveryResults []DeliveryResult `json:"deliveryResults,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,14 +4,14 @@ package database
|
||||
type DeliveryResult struct {
|
||||
BaseModel
|
||||
|
||||
DeliveryID string `gorm:"type:uuid;not null" json:"delivery_id"`
|
||||
AttemptNum int `gorm:"not null" json:"attempt_num"`
|
||||
DeliveryID string `gorm:"type:uuid;not null" json:"deliveryId"`
|
||||
AttemptNum int `gorm:"not null" json:"attemptNum"`
|
||||
Success bool `json:"success"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ResponseBody string `gorm:"type:text" json:"response_body,omitempty"`
|
||||
StatusCode int `json:"statusCode,omitempty"`
|
||||
ResponseBody string `gorm:"type:text" json:"responseBody,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Duration int64 `json:"duration_ms"` // Duration in milliseconds
|
||||
Duration int64 `json:"durationMs"` // Duration in milliseconds
|
||||
|
||||
// Relations
|
||||
Delivery Delivery `json:"delivery,omitempty"`
|
||||
Delivery Delivery `json:"delivery,omitzero"`
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ package database
|
||||
type Entrypoint struct {
|
||||
BaseModel
|
||||
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||
Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this entrypoint
|
||||
Description string `json:"description"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
|
||||
// Relations
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
Webhook Webhook `json:"webhook,omitzero"`
|
||||
}
|
||||
|
||||
@@ -4,17 +4,17 @@ package database
|
||||
type Event struct {
|
||||
BaseModel
|
||||
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
EntrypointID string `gorm:"type:uuid;not null" json:"entrypoint_id"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||
EntrypointID string `gorm:"type:uuid;not null" json:"entrypointId"`
|
||||
|
||||
// Request data
|
||||
Method string `gorm:"not null" json:"method"`
|
||||
Headers string `gorm:"type:text" json:"headers"` // JSON
|
||||
Body string `gorm:"type:text" json:"body"`
|
||||
ContentType string `json:"content_type"`
|
||||
Method string `gorm:"not null" json:"method"`
|
||||
Headers string `gorm:"type:text" json:"headers"` // JSON
|
||||
Body string `gorm:"type:text" json:"body"`
|
||||
ContentType string `json:"contentType"`
|
||||
|
||||
// Relations
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
Entrypoint Entrypoint `json:"entrypoint,omitempty"`
|
||||
Webhook Webhook `json:"webhook,omitzero"`
|
||||
Entrypoint Entrypoint `json:"entrypoint,omitzero"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
|
||||
@@ -3,6 +3,6 @@ package database
|
||||
// Setting stores application-level key-value configuration.
|
||||
// Used for auto-generated values like the session encryption key.
|
||||
type Setting struct {
|
||||
Key string `gorm:"primaryKey" json:"key"`
|
||||
Key string `gorm:"primaryKey" json:"key"`
|
||||
Value string `gorm:"type:text;not null" json:"value"`
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package database
|
||||
// TargetType represents the type of delivery target
|
||||
type TargetType string
|
||||
|
||||
// Target type values.
|
||||
const (
|
||||
TargetTypeHTTP TargetType = "http"
|
||||
TargetTypeDatabase TargetType = "database"
|
||||
@@ -14,19 +15,19 @@ const (
|
||||
type Target struct {
|
||||
BaseModel
|
||||
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Type TargetType `gorm:"not null" json:"type"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhookId"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Type TargetType `gorm:"not null" json:"type"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
|
||||
// Configuration fields (JSON stored based on type)
|
||||
Config string `gorm:"type:text" json:"config"` // JSON configuration
|
||||
|
||||
// For HTTP targets (max_retries=0 means fire-and-forget, >0 enables retries with backoff)
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
MaxQueueSize int `json:"max_queue_size,omitempty"`
|
||||
MaxRetries int `json:"maxRetries,omitempty"`
|
||||
MaxQueueSize int `json:"maxQueueSize,omitempty"`
|
||||
|
||||
// Relations
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
Webhook Webhook `json:"webhook,omitzero"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ type User struct {
|
||||
BaseModel
|
||||
|
||||
Username string `gorm:"uniqueIndex;not null" json:"username"`
|
||||
Password string `gorm:"not null" json:"-"` // Argon2 hashed
|
||||
Password string `gorm:"not null" json:"-"` // Argon2 hashed
|
||||
|
||||
// Relations
|
||||
Webhooks []Webhook `json:"webhooks,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
APIKeys []APIKey `json:"apiKeys,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ package database
|
||||
type Webhook struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
UserID string `gorm:"type:uuid;not null" json:"userId"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Description string `json:"description"`
|
||||
RetentionDays int `gorm:"default:30" json:"retention_days"` // Days to retain events
|
||||
RetentionDays int `gorm:"default:30" json:"retentionDays"` // Days to retain events
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitempty"`
|
||||
User User `json:"user,omitzero"`
|
||||
Entrypoints []Entrypoint `json:"entrypoints,omitempty"`
|
||||
Targets []Target `json:"targets,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
@@ -20,6 +21,23 @@ const (
|
||||
argon2SaltLen = 16
|
||||
)
|
||||
|
||||
// hashParts is the expected number of $-separated segments
|
||||
// in an encoded Argon2id hash string.
|
||||
const hashParts = 6
|
||||
|
||||
// minPasswordComplexityLen is the minimum password length that
|
||||
// triggers per-character-class complexity enforcement.
|
||||
const minPasswordComplexityLen = 4
|
||||
|
||||
// Sentinel errors returned by decodeHash.
|
||||
var (
|
||||
errInvalidHashFormat = errors.New("invalid hash format")
|
||||
errInvalidAlgorithm = errors.New("invalid algorithm")
|
||||
errIncompatibleVersion = errors.New("incompatible argon2 version")
|
||||
errSaltLengthOutOfRange = errors.New("salt length out of range")
|
||||
errHashLengthOutOfRange = errors.New("hash length out of range")
|
||||
)
|
||||
|
||||
// PasswordConfig holds Argon2 configuration
|
||||
type PasswordConfig struct {
|
||||
Time uint32
|
||||
@@ -46,26 +64,44 @@ func HashPassword(password string) (string, error) {
|
||||
|
||||
// Generate a salt
|
||||
salt := make([]byte, config.SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
|
||||
_, err := rand.Read(salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate the hash
|
||||
hash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
hash := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
config.Time,
|
||||
config.Memory,
|
||||
config.Threads,
|
||||
config.KeyLen,
|
||||
)
|
||||
|
||||
// Encode the hash and parameters
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash
|
||||
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, config.Memory, config.Time, config.Threads, b64Salt, b64Hash)
|
||||
encoded := fmt.Sprintf(
|
||||
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version,
|
||||
config.Memory,
|
||||
config.Time,
|
||||
config.Threads,
|
||||
b64Salt,
|
||||
b64Hash,
|
||||
)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks if the provided password matches the hash
|
||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
func VerifyPassword(
|
||||
password, encodedHash string,
|
||||
) (bool, error) {
|
||||
// Extract parameters and hash from encoded string
|
||||
config, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
@@ -73,60 +109,119 @@ func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
}
|
||||
|
||||
// Generate hash of the provided password
|
||||
otherHash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
otherHash := argon2.IDKey(
|
||||
[]byte(password),
|
||||
salt,
|
||||
config.Time,
|
||||
config.Memory,
|
||||
config.Threads,
|
||||
config.KeyLen,
|
||||
)
|
||||
|
||||
// Compare hashes using constant time comparison
|
||||
return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil
|
||||
}
|
||||
|
||||
// decodeHash extracts parameters, salt, and hash from an encoded hash string
|
||||
func decodeHash(encodedHash string) (*PasswordConfig, []byte, []byte, error) {
|
||||
// decodeHash extracts parameters, salt, and hash from an
|
||||
// encoded hash string.
|
||||
func decodeHash(
|
||||
encodedHash string,
|
||||
) (*PasswordConfig, []byte, []byte, error) {
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return nil, nil, nil, fmt.Errorf("invalid hash format")
|
||||
if len(parts) != hashParts {
|
||||
return nil, nil, nil, errInvalidHashFormat
|
||||
}
|
||||
|
||||
if parts[1] != "argon2id" {
|
||||
return nil, nil, nil, fmt.Errorf("invalid algorithm")
|
||||
return nil, nil, nil, errInvalidAlgorithm
|
||||
}
|
||||
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
version, err := parseVersion(parts[2])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, fmt.Errorf("incompatible argon2 version")
|
||||
return nil, nil, nil, errIncompatibleVersion
|
||||
}
|
||||
|
||||
config := &PasswordConfig{}
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &config.Memory, &config.Time, &config.Threads); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
config, err := parseParams(parts[3])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
saltLen := len(salt)
|
||||
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("salt length out of range")
|
||||
}
|
||||
config.SaltLen = uint32(saltLen) // nolint:gosec // checked above
|
||||
|
||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
salt, err := decodeSalt(parts[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
hashLen := len(hash)
|
||||
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("hash length out of range")
|
||||
|
||||
config.SaltLen = uint32(len(salt)) //nolint:gosec // validated in decodeSalt
|
||||
|
||||
hash, err := decodeHashBytes(parts[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
config.KeyLen = uint32(hashLen) // nolint:gosec // checked above
|
||||
|
||||
config.KeyLen = uint32(len(hash)) //nolint:gosec // validated in decodeHashBytes
|
||||
|
||||
return config, salt, hash, nil
|
||||
}
|
||||
|
||||
// GenerateRandomPassword generates a cryptographically secure random password
|
||||
func parseVersion(s string) (int, error) {
|
||||
var version int
|
||||
|
||||
_, err := fmt.Sscanf(s, "v=%d", &version)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing version: %w", err)
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
func parseParams(s string) (*PasswordConfig, error) {
|
||||
config := &PasswordConfig{}
|
||||
|
||||
_, err := fmt.Sscanf(
|
||||
s, "m=%d,t=%d,p=%d",
|
||||
&config.Memory, &config.Time, &config.Threads,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing params: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func decodeSalt(s string) ([]byte, error) {
|
||||
salt, err := base64.RawStdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding salt: %w", err)
|
||||
}
|
||||
|
||||
saltLen := len(salt)
|
||||
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
||||
return nil, errSaltLengthOutOfRange
|
||||
}
|
||||
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func decodeHashBytes(s string) ([]byte, error) {
|
||||
hash, err := base64.RawStdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding hash: %w", err)
|
||||
}
|
||||
|
||||
hashLen := len(hash)
|
||||
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
||||
return nil, errHashLengthOutOfRange
|
||||
}
|
||||
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// GenerateRandomPassword generates a cryptographically secure
|
||||
// random password.
|
||||
func GenerateRandomPassword(length int) (string, error) {
|
||||
const (
|
||||
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
@@ -141,27 +236,27 @@ func GenerateRandomPassword(length int) (string, error) {
|
||||
// Create password slice
|
||||
password := make([]byte, length)
|
||||
|
||||
// Ensure at least one character from each set for password complexity
|
||||
if length >= 4 {
|
||||
// Get one character from each set
|
||||
// Ensure at least one character from each set
|
||||
if length >= minPasswordComplexityLen {
|
||||
password[0] = uppercase[cryptoRandInt(len(uppercase))]
|
||||
password[1] = lowercase[cryptoRandInt(len(lowercase))]
|
||||
password[2] = digits[cryptoRandInt(len(digits))]
|
||||
password[3] = special[cryptoRandInt(len(special))]
|
||||
|
||||
// Fill the rest randomly from all characters
|
||||
for i := 4; i < length; i++ {
|
||||
for i := minPasswordComplexityLen; i < length; i++ {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
|
||||
// Shuffle the password to avoid predictable pattern
|
||||
for i := len(password) - 1; i > 0; i-- {
|
||||
j := cryptoRandInt(i + 1)
|
||||
password[i], password[j] = password[j], password[i]
|
||||
for i := range len(password) - 1 {
|
||||
j := cryptoRandInt(len(password) - i)
|
||||
idx := len(password) - 1 - i
|
||||
password[idx], password[j] = password[j], password[idx]
|
||||
}
|
||||
} else {
|
||||
// For very short passwords, just use all characters
|
||||
for i := 0; i < length; i++ {
|
||||
for i := range length {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
}
|
||||
@@ -169,16 +264,17 @@ func GenerateRandomPassword(length int) (string, error) {
|
||||
return string(password), nil
|
||||
}
|
||||
|
||||
// cryptoRandInt generates a cryptographically secure random integer in [0, max)
|
||||
func cryptoRandInt(max int) int {
|
||||
if max <= 0 {
|
||||
panic("max must be positive")
|
||||
// cryptoRandInt generates a cryptographically secure random
|
||||
// integer in [0, upperBound).
|
||||
func cryptoRandInt(upperBound int) int {
|
||||
if upperBound <= 0 {
|
||||
panic("upperBound must be positive")
|
||||
}
|
||||
|
||||
// Calculate the maximum valid value to avoid modulo bias
|
||||
// For example, if max=200 and we have 256 possible values,
|
||||
// we only accept values 0-199 (reject 200-255)
|
||||
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
|
||||
nBig, err := rand.Int(
|
||||
rand.Reader,
|
||||
big.NewInt(int64(upperBound)),
|
||||
)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand error: %v", err))
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package database
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
)
|
||||
|
||||
func TestGenerateRandomPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
@@ -18,109 +22,172 @@ func TestGenerateRandomPassword(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
password, err := GenerateRandomPassword(tt.length)
|
||||
t.Parallel()
|
||||
|
||||
password, err := database.GenerateRandomPassword(
|
||||
tt.length,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
t.Fatalf(
|
||||
"GenerateRandomPassword() error = %v",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if len(password) != tt.length {
|
||||
t.Errorf("Password length = %v, want %v", len(password), tt.length)
|
||||
t.Errorf(
|
||||
"Password length = %v, want %v",
|
||||
len(password), tt.length,
|
||||
)
|
||||
}
|
||||
|
||||
// For passwords >= 4 chars, check complexity
|
||||
if tt.length >= 4 {
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
hasSpecial := false
|
||||
|
||||
for _, char := range password {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
|
||||
t.Errorf("Password lacks required complexity: upper=%v, lower=%v, digit=%v, special=%v",
|
||||
hasUpper, hasLower, hasDigit, hasSpecial)
|
||||
}
|
||||
}
|
||||
checkPasswordComplexity(
|
||||
t, password, tt.length,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkPasswordComplexity(
|
||||
t *testing.T,
|
||||
password string,
|
||||
length int,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
// For passwords >= 4 chars, check complexity
|
||||
if length < 4 {
|
||||
return
|
||||
}
|
||||
|
||||
flags := classifyChars(password)
|
||||
|
||||
if !flags[0] || !flags[1] || !flags[2] || !flags[3] {
|
||||
t.Errorf(
|
||||
"Password lacks required complexity: "+
|
||||
"upper=%v, lower=%v, digit=%v, special=%v",
|
||||
flags[0], flags[1], flags[2], flags[3],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func classifyChars(s string) [4]bool {
|
||||
var flags [4]bool // upper, lower, digit, special
|
||||
|
||||
for _, char := range s {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
flags[0] = true
|
||||
case char >= 'a' && char <= 'z':
|
||||
flags[1] = true
|
||||
case char >= '0' && char <= '9':
|
||||
flags[2] = true
|
||||
case strings.ContainsRune(
|
||||
"!@#$%^&*()_+-=[]{}|;:,.<>?",
|
||||
char,
|
||||
):
|
||||
flags[3] = true
|
||||
}
|
||||
}
|
||||
|
||||
return flags
|
||||
}
|
||||
|
||||
func TestGenerateRandomPasswordUniqueness(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate multiple passwords and ensure they're different
|
||||
passwords := make(map[string]bool)
|
||||
|
||||
const numPasswords = 100
|
||||
|
||||
for i := 0; i < numPasswords; i++ {
|
||||
password, err := GenerateRandomPassword(16)
|
||||
for range numPasswords {
|
||||
password, err := database.GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
t.Fatalf(
|
||||
"GenerateRandomPassword() error = %v",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if passwords[password] {
|
||||
t.Errorf("Duplicate password generated: %s", password)
|
||||
t.Errorf(
|
||||
"Duplicate password generated: %s",
|
||||
password,
|
||||
)
|
||||
}
|
||||
|
||||
passwords[password] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
password := "testPassword123!"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
hash, err := database.HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
// Check that hash has correct format
|
||||
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||
t.Errorf("Hash doesn't have correct prefix: %s", hash)
|
||||
t.Errorf(
|
||||
"Hash doesn't have correct prefix: %s",
|
||||
hash,
|
||||
)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
valid, err := VerifyPassword(password, hash)
|
||||
valid, err := database.VerifyPassword(password, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("VerifyPassword() returned false for correct password")
|
||||
t.Error(
|
||||
"VerifyPassword() returned false " +
|
||||
"for correct password",
|
||||
)
|
||||
}
|
||||
|
||||
// Verify wrong password fails
|
||||
valid, err = VerifyPassword("wrongPassword", hash)
|
||||
valid, err = database.VerifyPassword(
|
||||
"wrongPassword", hash,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("VerifyPassword() returned true for wrong password")
|
||||
t.Error(
|
||||
"VerifyPassword() returned true " +
|
||||
"for wrong password",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPasswordUniqueness(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
password := "testPassword123!"
|
||||
|
||||
// Same password should produce different hashes due to salt
|
||||
hash1, err := HashPassword(password)
|
||||
// Same password should produce different hashes
|
||||
hash1, err := database.HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
hash2, err := HashPassword(password)
|
||||
hash2, err := database.HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error("Same password produced identical hashes (salt not working)")
|
||||
t.Error(
|
||||
"Same password produced identical hashes " +
|
||||
"(salt not working)",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package database
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
@@ -16,87 +17,82 @@ import (
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
// nolint:revive // WebhookDBManagerParams is a standard fx naming convention
|
||||
// WebhookDBManagerParams holds the fx dependencies for
|
||||
// WebhookDBManager.
|
||||
type WebhookDBManagerParams struct {
|
||||
fx.In
|
||||
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// WebhookDBManager manages per-webhook SQLite database files for event storage.
|
||||
// Each webhook gets its own dedicated database containing Events, Deliveries,
|
||||
// and DeliveryResults. Database connections are opened lazily and cached.
|
||||
// errInvalidCachedDBType indicates a type assertion failure
|
||||
// when retrieving a cached database connection.
|
||||
var errInvalidCachedDBType = errors.New(
|
||||
"invalid cached database type",
|
||||
)
|
||||
|
||||
// WebhookDBManager manages per-webhook SQLite database files
|
||||
// for event storage. Each webhook gets its own dedicated
|
||||
// database containing Events, Deliveries, and DeliveryResults.
|
||||
// Database connections are opened lazily and cached.
|
||||
type WebhookDBManager struct {
|
||||
dataDir string
|
||||
dbs sync.Map // map[webhookID]*gorm.DB
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// NewWebhookDBManager creates a new WebhookDBManager and registers lifecycle hooks.
|
||||
func NewWebhookDBManager(lc fx.Lifecycle, params WebhookDBManagerParams) (*WebhookDBManager, error) {
|
||||
// NewWebhookDBManager creates a new WebhookDBManager and
|
||||
// registers lifecycle hooks.
|
||||
func NewWebhookDBManager(
|
||||
lc fx.Lifecycle,
|
||||
params WebhookDBManagerParams,
|
||||
) (*WebhookDBManager, error) {
|
||||
m := &WebhookDBManager{
|
||||
dataDir: params.Config.DataDir,
|
||||
log: params.Logger.Get(),
|
||||
}
|
||||
|
||||
// Create data directory if it doesn't exist
|
||||
if err := os.MkdirAll(m.dataDir, 0750); err != nil {
|
||||
return nil, fmt.Errorf("creating data directory %s: %w", m.dataDir, err)
|
||||
err := os.MkdirAll(m.dataDir, dataDirPerm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"creating data directory %s: %w",
|
||||
m.dataDir,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStop: func(_ context.Context) error { //nolint:revive // ctx unused but required by fx
|
||||
OnStop: func(_ context.Context) error {
|
||||
return m.CloseAll()
|
||||
},
|
||||
})
|
||||
|
||||
m.log.Info("webhook database manager initialized", "data_dir", m.dataDir)
|
||||
m.log.Info(
|
||||
"webhook database manager initialized",
|
||||
"data_dir", m.dataDir,
|
||||
)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// dbPath returns the filesystem path for a webhook's database file.
|
||||
func (m *WebhookDBManager) dbPath(webhookID string) string {
|
||||
return filepath.Join(m.dataDir, fmt.Sprintf("events-%s.db", webhookID))
|
||||
}
|
||||
|
||||
// openDB opens (or creates) a per-webhook SQLite database and runs migrations.
|
||||
func (m *WebhookDBManager) openDB(webhookID string) (*gorm.DB, error) {
|
||||
path := m.dbPath(webhookID)
|
||||
dbURL := fmt.Sprintf("file:%s?cache=shared&mode=rwc", path)
|
||||
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening webhook database %s: %w", webhookID, err)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("connecting to webhook database %s: %w", webhookID, err)
|
||||
}
|
||||
|
||||
// Run migrations for event-tier models only
|
||||
if err := db.AutoMigrate(&Event{}, &Delivery{}, &DeliveryResult{}); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("migrating webhook database %s: %w", webhookID, err)
|
||||
}
|
||||
|
||||
m.log.Info("opened per-webhook database", "webhook_id", webhookID, "path", path)
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// GetDB returns the database connection for a webhook, creating the database
|
||||
// file lazily if it doesn't exist. This handles both new webhooks and existing
|
||||
// webhooks that were created before per-webhook databases were introduced.
|
||||
func (m *WebhookDBManager) GetDB(webhookID string) (*gorm.DB, error) {
|
||||
// GetDB returns the database connection for a webhook,
|
||||
// creating the database file lazily if it doesn't exist.
|
||||
func (m *WebhookDBManager) GetDB(
|
||||
webhookID string,
|
||||
) (*gorm.DB, error) {
|
||||
// Fast path: already open
|
||||
if val, ok := m.dbs.Load(webhookID); ok {
|
||||
cachedDB, castOK := val.(*gorm.DB)
|
||||
if !castOK {
|
||||
return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
|
||||
return nil, fmt.Errorf(
|
||||
"%w for webhook %s",
|
||||
errInvalidCachedDBType,
|
||||
webhookID,
|
||||
)
|
||||
}
|
||||
|
||||
return cachedDB, nil
|
||||
}
|
||||
|
||||
@@ -106,44 +102,61 @@ func (m *WebhookDBManager) GetDB(webhookID string) (*gorm.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store it; if another goroutine beat us, close ours and use theirs
|
||||
// Store it; if another goroutine beat us, close ours
|
||||
actual, loaded := m.dbs.LoadOrStore(webhookID, db)
|
||||
if loaded {
|
||||
// Another goroutine created it first; close our duplicate
|
||||
if sqlDB, closeErr := db.DB(); closeErr == nil {
|
||||
sqlDB.Close()
|
||||
sqlDB, closeErr := db.DB()
|
||||
if closeErr == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
|
||||
existingDB, castOK := actual.(*gorm.DB)
|
||||
if !castOK {
|
||||
return nil, fmt.Errorf("invalid cached database type for webhook %s", webhookID)
|
||||
return nil, fmt.Errorf(
|
||||
"%w for webhook %s",
|
||||
errInvalidCachedDBType,
|
||||
webhookID,
|
||||
)
|
||||
}
|
||||
|
||||
return existingDB, nil
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// CreateDB explicitly creates a new per-webhook database file and runs migrations.
|
||||
// This is called when a new webhook is created.
|
||||
func (m *WebhookDBManager) CreateDB(webhookID string) error {
|
||||
// CreateDB explicitly creates a new per-webhook database file
|
||||
// and runs migrations.
|
||||
func (m *WebhookDBManager) CreateDB(
|
||||
webhookID string,
|
||||
) error {
|
||||
_, err := m.GetDB(webhookID)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DBExists checks if a per-webhook database file exists on disk.
|
||||
func (m *WebhookDBManager) DBExists(webhookID string) bool {
|
||||
// DBExists checks if a per-webhook database file exists on
|
||||
// disk.
|
||||
func (m *WebhookDBManager) DBExists(
|
||||
webhookID string,
|
||||
) bool {
|
||||
_, err := os.Stat(m.dbPath(webhookID))
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// DeleteDB closes the connection and deletes the database file for a webhook.
|
||||
// This performs a hard delete — the file is permanently removed.
|
||||
func (m *WebhookDBManager) DeleteDB(webhookID string) error {
|
||||
// DeleteDB closes the connection and deletes the database file
|
||||
// for a webhook. The file is permanently removed.
|
||||
func (m *WebhookDBManager) DeleteDB(
|
||||
webhookID string,
|
||||
) error {
|
||||
// Close and remove from cache
|
||||
if val, ok := m.dbs.LoadAndDelete(webhookID); ok {
|
||||
if gormDB, castOK := val.(*gorm.DB); castOK {
|
||||
if sqlDB, err := gormDB.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
sqlDB, err := gormDB.DB()
|
||||
if err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -151,12 +164,20 @@ func (m *WebhookDBManager) DeleteDB(webhookID string) error {
|
||||
// Delete the main DB file and WAL/SHM files
|
||||
path := m.dbPath(webhookID)
|
||||
for _, suffix := range []string{"", "-wal", "-shm"} {
|
||||
if err := os.Remove(path + suffix); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("deleting webhook database file %s%s: %w", path, suffix, err)
|
||||
err := os.Remove(path + suffix)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf(
|
||||
"deleting webhook database file %s%s: %w",
|
||||
path, suffix, err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
m.log.Info("deleted per-webhook database", "webhook_id", webhookID)
|
||||
m.log.Info(
|
||||
"deleted per-webhook database",
|
||||
"webhook_id", webhookID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -164,20 +185,97 @@ func (m *WebhookDBManager) DeleteDB(webhookID string) error {
|
||||
// Called during application shutdown.
|
||||
func (m *WebhookDBManager) CloseAll() error {
|
||||
var lastErr error
|
||||
m.dbs.Range(func(key, value interface{}) bool {
|
||||
|
||||
m.dbs.Range(func(key, value any) bool {
|
||||
if gormDB, castOK := value.(*gorm.DB); castOK {
|
||||
if sqlDB, err := gormDB.DB(); err == nil {
|
||||
if closeErr := sqlDB.Close(); closeErr != nil {
|
||||
sqlDB, err := gormDB.DB()
|
||||
if err == nil {
|
||||
closeErr := sqlDB.Close()
|
||||
if closeErr != nil {
|
||||
lastErr = closeErr
|
||||
m.log.Error("failed to close webhook database",
|
||||
m.log.Error(
|
||||
"failed to close webhook database",
|
||||
"webhook_id", key,
|
||||
"error", closeErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.dbs.Delete(key)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// DBPath returns the filesystem path for a webhook's database
|
||||
// file.
|
||||
func (m *WebhookDBManager) DBPath(
|
||||
webhookID string,
|
||||
) string {
|
||||
return m.dbPath(webhookID)
|
||||
}
|
||||
|
||||
func (m *WebhookDBManager) dbPath(
|
||||
webhookID string,
|
||||
) string {
|
||||
return filepath.Join(
|
||||
m.dataDir,
|
||||
fmt.Sprintf("events-%s.db", webhookID),
|
||||
)
|
||||
}
|
||||
|
||||
// openDB opens (or creates) a per-webhook SQLite database and
|
||||
// runs migrations.
|
||||
func (m *WebhookDBManager) openDB(
|
||||
webhookID string,
|
||||
) (*gorm.DB, error) {
|
||||
path := m.dbPath(webhookID)
|
||||
dbURL := fmt.Sprintf(
|
||||
"file:%s?cache=shared&mode=rwc",
|
||||
path,
|
||||
)
|
||||
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"opening webhook database %s: %w",
|
||||
webhookID, err,
|
||||
)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
_ = sqlDB.Close()
|
||||
|
||||
return nil, fmt.Errorf(
|
||||
"connecting to webhook database %s: %w",
|
||||
webhookID, err,
|
||||
)
|
||||
}
|
||||
|
||||
// Run migrations for event-tier models only
|
||||
err = db.AutoMigrate(
|
||||
&Event{}, &Delivery{}, &DeliveryResult{},
|
||||
)
|
||||
if err != nil {
|
||||
_ = sqlDB.Close()
|
||||
|
||||
return nil, fmt.Errorf(
|
||||
"migrating webhook database %s: %w",
|
||||
webhookID, err,
|
||||
)
|
||||
}
|
||||
|
||||
m.log.Info(
|
||||
"opened per-webhook database",
|
||||
"webhook_id", webhookID,
|
||||
"path", path,
|
||||
)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package database
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -10,23 +10,29 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/fx/fxtest"
|
||||
"gorm.io/gorm"
|
||||
"sneak.berlin/go/webhooker/internal/config"
|
||||
"sneak.berlin/go/webhooker/internal/database"
|
||||
"sneak.berlin/go/webhooker/internal/globals"
|
||||
"sneak.berlin/go/webhooker/internal/logger"
|
||||
)
|
||||
|
||||
func setupTestWebhookDBManager(t *testing.T) (*WebhookDBManager, *fxtest.Lifecycle) {
|
||||
func setupTestWebhookDBManager(
|
||||
t *testing.T,
|
||||
) (*database.WebhookDBManager, *fxtest.Lifecycle) {
|
||||
t.Helper()
|
||||
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
globals.Appname = "webhooker-test"
|
||||
globals.Version = "test"
|
||||
g := &globals.Globals{
|
||||
Appname: "webhooker-test",
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
g, err := globals.New(lc)
|
||||
require.NoError(t, err)
|
||||
|
||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
||||
l, err := logger.New(
|
||||
lc,
|
||||
logger.LoggerParams{Globals: g},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
dataDir := filepath.Join(t.TempDir(), "events")
|
||||
@@ -35,19 +41,25 @@ func setupTestWebhookDBManager(t *testing.T) (*WebhookDBManager, *fxtest.Lifecyc
|
||||
DataDir: dataDir,
|
||||
}
|
||||
|
||||
mgr, err := NewWebhookDBManager(lc, WebhookDBManagerParams{
|
||||
Config: cfg,
|
||||
Logger: l,
|
||||
})
|
||||
mgr, err := database.NewWebhookDBManager(
|
||||
lc,
|
||||
database.WebhookDBManagerParams{
|
||||
Config: cfg,
|
||||
Logger: l,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return mgr, lc
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
@@ -68,7 +80,7 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
||||
require.NotNil(t, db)
|
||||
|
||||
// Verify we can write an event
|
||||
event := &Event{
|
||||
event := &database.Event{
|
||||
WebhookID: webhookID,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
@@ -80,27 +92,35 @@ func TestWebhookDBManager_CreateAndGetDB(t *testing.T) {
|
||||
assert.NotEmpty(t, event.ID)
|
||||
|
||||
// Verify we can read it back
|
||||
var readEvent Event
|
||||
require.NoError(t, db.First(&readEvent, "id = ?", event.ID).Error)
|
||||
var readEvent database.Event
|
||||
|
||||
require.NoError(
|
||||
t,
|
||||
db.First(&readEvent, "id = ?", event.ID).Error,
|
||||
)
|
||||
assert.Equal(t, webhookID, readEvent.WebhookID)
|
||||
assert.Equal(t, "POST", readEvent.Method)
|
||||
assert.Equal(t, `{"test": true}`, readEvent.Body)
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_DeleteDB(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
|
||||
// Create the DB and write some data
|
||||
require.NoError(t, mgr.CreateDB(webhookID))
|
||||
|
||||
db, err := mgr.GetDB(webhookID)
|
||||
require.NoError(t, err)
|
||||
|
||||
event := &Event{
|
||||
event := &database.Event{
|
||||
WebhookID: webhookID,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
@@ -116,15 +136,19 @@ func TestWebhookDBManager_DeleteDB(t *testing.T) {
|
||||
assert.False(t, mgr.DBExists(webhookID))
|
||||
|
||||
// Verify the file is actually gone from disk
|
||||
dbPath := mgr.dbPath(webhookID)
|
||||
dbPath := mgr.DBPath(webhookID)
|
||||
|
||||
_, err = os.Stat(dbPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_LazyCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
@@ -139,9 +163,12 @@ func TestWebhookDBManager_LazyCreation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhookID := uuid.New().String()
|
||||
@@ -150,8 +177,23 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
||||
db, err := mgr.GetDB(webhookID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an event
|
||||
event := &Event{
|
||||
event, delivery := seedDeliveryWorkflow(
|
||||
t, db, webhookID, targetID,
|
||||
)
|
||||
|
||||
verifyPendingDeliveries(t, db, event)
|
||||
completeDelivery(t, db, delivery)
|
||||
verifyNoPending(t, db)
|
||||
}
|
||||
|
||||
func seedDeliveryWorkflow(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
webhookID, targetID string,
|
||||
) (*database.Event, *database.Delivery) {
|
||||
t.Helper()
|
||||
|
||||
event := &database.Event{
|
||||
WebhookID: webhookID,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
@@ -161,25 +203,45 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, db.Create(event).Error)
|
||||
|
||||
// Create a delivery
|
||||
delivery := &Delivery{
|
||||
delivery := &database.Delivery{
|
||||
EventID: event.ID,
|
||||
TargetID: targetID,
|
||||
Status: DeliveryStatusPending,
|
||||
Status: database.DeliveryStatusPending,
|
||||
}
|
||||
require.NoError(t, db.Create(delivery).Error)
|
||||
|
||||
// Query pending deliveries
|
||||
var pending []Delivery
|
||||
require.NoError(t, db.Where("status = ?", DeliveryStatusPending).
|
||||
Preload("Event").
|
||||
Find(&pending).Error)
|
||||
return event, delivery
|
||||
}
|
||||
|
||||
func verifyPendingDeliveries(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
event *database.Event,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
var pending []database.Delivery
|
||||
|
||||
require.NoError(
|
||||
t,
|
||||
db.Where(
|
||||
"status = ?",
|
||||
database.DeliveryStatusPending,
|
||||
).Preload("Event").Find(&pending).Error,
|
||||
)
|
||||
require.Len(t, pending, 1)
|
||||
assert.Equal(t, event.ID, pending[0].EventID)
|
||||
assert.Equal(t, "POST", pending[0].Event.Method)
|
||||
}
|
||||
|
||||
// Create a delivery result
|
||||
result := &DeliveryResult{
|
||||
func completeDelivery(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
delivery *database.Delivery,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
result := &database.DeliveryResult{
|
||||
DeliveryID: delivery.ID,
|
||||
AttemptNum: 1,
|
||||
Success: true,
|
||||
@@ -188,19 +250,40 @@ func TestWebhookDBManager_DeliveryWorkflow(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, db.Create(result).Error)
|
||||
|
||||
// Update delivery status
|
||||
require.NoError(t, db.Model(delivery).Update("status", DeliveryStatusDelivered).Error)
|
||||
require.NoError(
|
||||
t,
|
||||
db.Model(delivery).Update(
|
||||
"status",
|
||||
database.DeliveryStatusDelivered,
|
||||
).Error,
|
||||
)
|
||||
}
|
||||
|
||||
// Verify no more pending deliveries
|
||||
var stillPending []Delivery
|
||||
require.NoError(t, db.Where("status = ?", DeliveryStatusPending).Find(&stillPending).Error)
|
||||
func verifyNoPending(
|
||||
t *testing.T,
|
||||
db *gorm.DB,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
var stillPending []database.Delivery
|
||||
|
||||
require.NoError(
|
||||
t,
|
||||
db.Where(
|
||||
"status = ?",
|
||||
database.DeliveryStatusPending,
|
||||
).Find(&stillPending).Error,
|
||||
)
|
||||
assert.Empty(t, stillPending)
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
defer func() { require.NoError(t, lc.Stop(ctx)) }()
|
||||
|
||||
webhook1 := uuid.New().String()
|
||||
@@ -212,34 +295,38 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
||||
|
||||
db1, err := mgr.GetDB(webhook1)
|
||||
require.NoError(t, err)
|
||||
|
||||
db2, err := mgr.GetDB(webhook2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write events to each webhook's DB
|
||||
event1 := &Event{
|
||||
event1 := &database.Event{
|
||||
WebhookID: webhook1,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "POST",
|
||||
Body: `{"webhook": 1}`,
|
||||
ContentType: "application/json",
|
||||
}
|
||||
event2 := &Event{
|
||||
event2 := &database.Event{
|
||||
WebhookID: webhook2,
|
||||
EntrypointID: uuid.New().String(),
|
||||
Method: "PUT",
|
||||
Body: `{"webhook": 2}`,
|
||||
ContentType: "application/json",
|
||||
}
|
||||
|
||||
require.NoError(t, db1.Create(event1).Error)
|
||||
require.NoError(t, db2.Create(event2).Error)
|
||||
|
||||
// Verify isolation: each DB only has its own events
|
||||
var count1 int64
|
||||
db1.Model(&Event{}).Count(&count1)
|
||||
|
||||
db1.Model(&database.Event{}).Count(&count1)
|
||||
assert.Equal(t, int64(1), count1)
|
||||
|
||||
var count2 int64
|
||||
db2.Model(&Event{}).Count(&count2)
|
||||
|
||||
db2.Model(&database.Event{}).Count(&count2)
|
||||
assert.Equal(t, int64(1), count2)
|
||||
|
||||
// Delete webhook1's DB, webhook2 should be unaffected
|
||||
@@ -248,25 +335,31 @@ func TestWebhookDBManager_MultipleWebhooks(t *testing.T) {
|
||||
assert.True(t, mgr.DBExists(webhook2))
|
||||
|
||||
// webhook2's data should still be accessible
|
||||
var events []Event
|
||||
var events []database.Event
|
||||
|
||||
require.NoError(t, db2.Find(&events).Error)
|
||||
assert.Len(t, events, 1)
|
||||
assert.Equal(t, "PUT", events[0].Method)
|
||||
}
|
||||
|
||||
func TestWebhookDBManager_CloseAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr, lc := setupTestWebhookDBManager(t)
|
||||
ctx := context.Background()
|
||||
require.NoError(t, lc.Start(ctx))
|
||||
|
||||
// Create a few DBs
|
||||
for i := 0; i < 3; i++ {
|
||||
require.NoError(t, mgr.CreateDB(uuid.New().String()))
|
||||
for range 3 {
|
||||
require.NoError(
|
||||
t,
|
||||
mgr.CreateDB(uuid.New().String()),
|
||||
)
|
||||
}
|
||||
|
||||
// CloseAll should close all connections without error
|
||||
require.NoError(t, mgr.CloseAll())
|
||||
|
||||
// Stop lifecycle (CloseAll already called, but shouldn't panic)
|
||||
// Stop lifecycle (CloseAll already called)
|
||||
require.NoError(t, lc.Stop(ctx))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user