Inject Config as dependency for database, routing table, and snapshotter
- Remove old database Config struct and related functions - Update database.New() to accept config.Config parameter - Update routingtable.New() to accept config.Config parameter - Update snapshotter.New() to accept config.Config parameter - Simplify fx module providers in app.go - Fix truthiness check for environment variables - Handle empty state directory gracefully in routing table and snapshotter - Update all tests to use empty state directory for testing
This commit is contained in:
parent
1a0622efaa
commit
d15a5e91b9
91
internal/config/config.go
Normal file
91
internal/config/config.go
Normal file
@ -0,0 +1,91 @@
|
||||
// Package config provides centralized configuration management for RouteWatch
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// AppIdentifier is the reverse domain name identifier for the app
|
||||
AppIdentifier = "berlin.sneak.app.routewatch"
|
||||
|
||||
// dirPermissions for creating directories
|
||||
dirPermissions = 0750 // rwxr-x---
|
||||
)
|
||||
|
||||
// Config holds configuration for the entire application
|
||||
type Config struct {
|
||||
// StateDir is the directory for all application state (database, snapshots)
|
||||
StateDir string
|
||||
|
||||
// MaxRuntime is the maximum runtime (0 = run forever)
|
||||
MaxRuntime time.Duration
|
||||
}
|
||||
|
||||
// New creates a new Config with default paths based on the OS
|
||||
func New() (*Config, error) {
|
||||
stateDir, err := getStateDirectory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine state directory: %w", err)
|
||||
}
|
||||
|
||||
return &Config{
|
||||
StateDir: stateDir,
|
||||
MaxRuntime: 0, // Run forever by default
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStateDir returns the state directory path
|
||||
func (c *Config) GetStateDir() string {
|
||||
return c.StateDir
|
||||
}
|
||||
|
||||
// getStateDirectory returns the appropriate state directory based on the OS
|
||||
func getStateDirectory() (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// macOS: ~/Library/Application Support/berlin.sneak.app.routewatch
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, "Library", "Application Support", AppIdentifier), nil
|
||||
|
||||
case "linux", "freebsd", "openbsd", "netbsd":
|
||||
// Unix-like: /var/lib/berlin.sneak.app.routewatch if root, else XDG_DATA_HOME
|
||||
if os.Geteuid() == 0 {
|
||||
return filepath.Join("/var/lib", AppIdentifier), nil
|
||||
}
|
||||
|
||||
// Check XDG_DATA_HOME first
|
||||
if xdgData := os.Getenv("XDG_DATA_HOME"); xdgData != "" {
|
||||
return filepath.Join(xdgData, AppIdentifier), nil
|
||||
}
|
||||
|
||||
// Fall back to ~/.local/share/berlin.sneak.app.routewatch
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".local", "share", AppIdentifier), nil
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureDirectories creates all necessary directories if they don't exist
|
||||
func (c *Config) EnsureDirectories() error {
|
||||
// Ensure state directory exists
|
||||
if err := os.MkdirAll(c.StateDir, dirPermissions); err != nil {
|
||||
return fmt.Errorf("failed to create state directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -8,9 +8,9 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/routewatch/internal/config"
|
||||
"git.eeqj.de/sneak/routewatch/pkg/asinfo"
|
||||
"github.com/google/uuid"
|
||||
_ "github.com/mattn/go-sqlite3" // CGO SQLite driver
|
||||
@ -28,77 +28,22 @@ type Database struct {
|
||||
path string
|
||||
}
|
||||
|
||||
// Config holds database configuration
|
||||
type Config struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
// getDefaultDatabasePath returns the appropriate database path for the OS
|
||||
func getDefaultDatabasePath() string {
|
||||
const dbFilename = "db.sqlite"
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// macOS: ~/Library/Application Support/berlin.sneak.app.routewatch/db.sqlite
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return dbFilename
|
||||
}
|
||||
appSupport := filepath.Join(home, "Library", "Application Support", "berlin.sneak.app.routewatch")
|
||||
if err := os.MkdirAll(appSupport, dirPermissions); err != nil {
|
||||
return dbFilename
|
||||
}
|
||||
|
||||
return filepath.Join(appSupport, dbFilename)
|
||||
default:
|
||||
// Linux and others: /var/lib/routewatch/db.sqlite
|
||||
dbDir := "/var/lib/routewatch"
|
||||
if err := os.MkdirAll(dbDir, dirPermissions); err != nil {
|
||||
// Fall back to user's home directory if can't create system directory
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return dbFilename
|
||||
}
|
||||
userDir := filepath.Join(home, ".local", "share", "routewatch")
|
||||
if err := os.MkdirAll(userDir, dirPermissions); err != nil {
|
||||
return dbFilename
|
||||
}
|
||||
|
||||
return filepath.Join(userDir, dbFilename)
|
||||
}
|
||||
|
||||
return filepath.Join(dbDir, dbFilename)
|
||||
}
|
||||
}
|
||||
|
||||
// NewConfig provides default database configuration
|
||||
func NewConfig() Config {
|
||||
return Config{
|
||||
Path: getDefaultDatabasePath(),
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new database connection and initializes the schema.
|
||||
func New(logger *slog.Logger) (*Database, error) {
|
||||
config := NewConfig()
|
||||
func New(cfg *config.Config, logger *slog.Logger) (*Database, error) {
|
||||
dbPath := filepath.Join(cfg.GetStateDir(), "db.sqlite")
|
||||
|
||||
return NewWithConfig(config, logger)
|
||||
}
|
||||
|
||||
// NewWithConfig creates a new database connection with custom configuration
|
||||
func NewWithConfig(config Config, logger *slog.Logger) (*Database, error) {
|
||||
// Log database path
|
||||
logger.Info("Opening database", "path", config.Path)
|
||||
logger.Info("Opening database", "path", dbPath)
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(config.Path)
|
||||
dir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dir, dirPermissions); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
|
||||
// Add connection parameters for go-sqlite3
|
||||
// Enable WAL mode and other performance optimizations
|
||||
dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_journal_mode=WAL&_synchronous=NORMAL&cache=shared", config.Path)
|
||||
dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_journal_mode=WAL&_synchronous=NORMAL&cache=shared", dbPath)
|
||||
db, err := sql.Open("sqlite3", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
@ -113,7 +58,7 @@ func NewWithConfig(config Config, logger *slog.Logger) (*Database, error) {
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
database := &Database{db: db, logger: logger, path: config.Path}
|
||||
database := &Database{db: db, logger: logger, path: dbPath}
|
||||
|
||||
if err := database.Initialize(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/routewatch/internal/config"
|
||||
"git.eeqj.de/sneak/routewatch/internal/database"
|
||||
"git.eeqj.de/sneak/routewatch/internal/metrics"
|
||||
"git.eeqj.de/sneak/routewatch/internal/routingtable"
|
||||
@ -21,23 +22,11 @@ import (
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// Config contains runtime configuration for RouteWatch
|
||||
type Config struct {
|
||||
MaxRuntime time.Duration // Maximum runtime (0 = run forever)
|
||||
}
|
||||
|
||||
const (
|
||||
// routingTableStatsInterval is how often we log routing table statistics
|
||||
routingTableStatsInterval = 15 * time.Second
|
||||
)
|
||||
|
||||
// NewConfig provides default configuration
|
||||
func NewConfig() Config {
|
||||
return Config{
|
||||
MaxRuntime: 0, // Run forever by default
|
||||
}
|
||||
}
|
||||
|
||||
// Dependencies contains all dependencies for RouteWatch
|
||||
type Dependencies struct {
|
||||
fx.In
|
||||
@ -47,7 +36,7 @@ type Dependencies struct {
|
||||
Streamer *streamer.Streamer
|
||||
Server *server.Server
|
||||
Logger *slog.Logger
|
||||
Config Config `optional:"true"`
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
// RouteWatch represents the main application instance
|
||||
@ -63,6 +52,17 @@ type RouteWatch struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// isTruthy returns true if the value is considered truthy
|
||||
// Empty string, "0", and "false" are considered falsy, everything else is truthy
|
||||
func isTruthy(value string) bool {
|
||||
return value != "" && value != "0" && value != "false"
|
||||
}
|
||||
|
||||
// isSnapshotterEnabled checks if the snapshotter should be enabled based on environment variable
|
||||
func isSnapshotterEnabled() bool {
|
||||
return !isTruthy(os.Getenv("ROUTEWATCH_DISABLE_SNAPSHOTTER"))
|
||||
}
|
||||
|
||||
// New creates a new RouteWatch instance
|
||||
func New(deps Dependencies) *RouteWatch {
|
||||
rw := &RouteWatch{
|
||||
@ -74,9 +74,9 @@ func New(deps Dependencies) *RouteWatch {
|
||||
maxRuntime: deps.Config.MaxRuntime,
|
||||
}
|
||||
|
||||
// Create snapshotter unless disabled (for tests)
|
||||
if os.Getenv("ROUTEWATCH_DISABLE_SNAPSHOTTER") != "1" {
|
||||
snap, err := snapshotter.New(deps.RoutingTable, deps.Logger)
|
||||
// Create snapshotter if enabled
|
||||
if isSnapshotterEnabled() {
|
||||
snap, err := snapshotter.New(deps.RoutingTable, deps.Config, deps.Logger)
|
||||
if err != nil {
|
||||
deps.Logger.Error("Failed to create snapshotter", "error", err)
|
||||
// Continue without snapshotter
|
||||
@ -235,13 +235,10 @@ func getModule() fx.Option {
|
||||
return fx.Options(
|
||||
fx.Provide(
|
||||
NewLogger,
|
||||
NewConfig,
|
||||
config.New,
|
||||
metrics.New,
|
||||
database.New,
|
||||
fx.Annotate(
|
||||
func(db *database.Database) database.Store {
|
||||
return db
|
||||
},
|
||||
database.New,
|
||||
fx.As(new(database.Store)),
|
||||
),
|
||||
routingtable.New,
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/routewatch/internal/config"
|
||||
"git.eeqj.de/sneak/routewatch/internal/database"
|
||||
"git.eeqj.de/sneak/routewatch/internal/metrics"
|
||||
"git.eeqj.de/sneak/routewatch/internal/routingtable"
|
||||
@ -171,8 +172,14 @@ func TestRouteWatchLiveFeed(t *testing.T) {
|
||||
// Create streamer
|
||||
s := streamer.New(logger, metricsTracker)
|
||||
|
||||
// Create test config with empty state dir (no snapshot loading)
|
||||
cfg := &config.Config{
|
||||
StateDir: "",
|
||||
MaxRuntime: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Create routing table
|
||||
rt := routingtable.New(logger)
|
||||
rt := routingtable.New(cfg, logger)
|
||||
|
||||
// Create server
|
||||
srv := server.New(mockDB, rt, s, logger)
|
||||
@ -184,9 +191,7 @@ func TestRouteWatchLiveFeed(t *testing.T) {
|
||||
Streamer: s,
|
||||
Server: srv,
|
||||
Logger: logger,
|
||||
Config: Config{
|
||||
MaxRuntime: 5 * time.Second,
|
||||
},
|
||||
Config: cfg,
|
||||
}
|
||||
rw := New(deps)
|
||||
|
||||
|
@ -8,12 +8,12 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/routewatch/internal/config"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@ -23,7 +23,7 @@ const (
|
||||
routeStalenessThreshold = 30 * time.Minute
|
||||
|
||||
// snapshotFilename is the name of the snapshot file
|
||||
snapshotFilename = "routewatch-snapshot.json.gz"
|
||||
snapshotFilename = "routingtable.json.gz"
|
||||
)
|
||||
|
||||
// Route represents a single route entry in the routing table
|
||||
@ -62,16 +62,20 @@ type RoutingTable struct {
|
||||
ipv4Updates uint64 // Updates counter for rate calculation
|
||||
ipv6Updates uint64 // Updates counter for rate calculation
|
||||
lastMetricsReset time.Time
|
||||
|
||||
// Configuration
|
||||
snapshotDir string
|
||||
}
|
||||
|
||||
// New creates a new routing table, loading from snapshot if available
|
||||
func New(logger *slog.Logger) *RoutingTable {
|
||||
func New(cfg *config.Config, logger *slog.Logger) *RoutingTable {
|
||||
rt := &RoutingTable{
|
||||
routes: make(map[RouteKey]*Route),
|
||||
byPrefix: make(map[uuid.UUID]map[RouteKey]*Route),
|
||||
byOriginASN: make(map[uuid.UUID]map[RouteKey]*Route),
|
||||
byPeerASN: make(map[int]map[RouteKey]*Route),
|
||||
lastMetricsReset: time.Now(),
|
||||
snapshotDir: cfg.GetStateDir(),
|
||||
}
|
||||
|
||||
// Try to load from snapshot
|
||||
@ -447,51 +451,18 @@ func isIPv6(prefix string) bool {
|
||||
return strings.Contains(prefix, ":")
|
||||
}
|
||||
|
||||
// getStateDirectory returns the appropriate state directory based on the OS
|
||||
func getStateDirectory() (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// macOS: Use ~/Library/Application Support/routewatch
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, "Library", "Application Support", "routewatch"), nil
|
||||
case "linux", "freebsd", "openbsd", "netbsd":
|
||||
// Unix-like: Use /var/lib/routewatch if running as root, otherwise use XDG_STATE_HOME
|
||||
if os.Geteuid() == 0 {
|
||||
return "/var/lib/routewatch", nil
|
||||
}
|
||||
// Check XDG_STATE_HOME first
|
||||
if xdgState := os.Getenv("XDG_STATE_HOME"); xdgState != "" {
|
||||
return filepath.Join(xdgState, "routewatch"), nil
|
||||
}
|
||||
// Fall back to ~/.local/state/routewatch
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".local", "state", "routewatch"), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// loadFromSnapshot attempts to load the routing table from a snapshot file
|
||||
func (rt *RoutingTable) loadFromSnapshot(logger *slog.Logger) error {
|
||||
stateDir, err := getStateDirectory()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to determine state directory: %w", err)
|
||||
// If no snapshot directory specified, nothing to load
|
||||
if rt.snapshotDir == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
snapshotPath := filepath.Join(stateDir, snapshotFilename)
|
||||
snapshotPath := filepath.Join(rt.snapshotDir, snapshotFilename)
|
||||
|
||||
// Check if snapshot file exists
|
||||
if _, err := os.Stat(snapshotPath); os.IsNotExist(err) {
|
||||
logger.Info("No snapshot file found, starting with empty routing table")
|
||||
|
||||
// No snapshot file exists, this is normal - start with empty routing table
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -6,13 +6,20 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/routewatch/internal/config"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestRoutingTable(t *testing.T) {
|
||||
// Create a test logger
|
||||
logger := slog.Default()
|
||||
rt := New(logger)
|
||||
|
||||
// Create test config with empty state dir (no snapshot loading)
|
||||
cfg := &config.Config{
|
||||
StateDir: "",
|
||||
}
|
||||
|
||||
rt := New(cfg, logger)
|
||||
|
||||
// Test data
|
||||
prefixID1 := uuid.New()
|
||||
@ -123,7 +130,13 @@ func TestRoutingTable(t *testing.T) {
|
||||
func TestRoutingTableConcurrency(t *testing.T) {
|
||||
// Create a test logger
|
||||
logger := slog.Default()
|
||||
rt := New(logger)
|
||||
|
||||
// Create test config with empty state dir (no snapshot loading)
|
||||
cfg := &config.Config{
|
||||
StateDir: "",
|
||||
}
|
||||
|
||||
rt := New(cfg, logger)
|
||||
|
||||
// Test concurrent access
|
||||
var wg sync.WaitGroup
|
||||
@ -177,7 +190,13 @@ func TestRoutingTableConcurrency(t *testing.T) {
|
||||
func TestRouteUpdate(t *testing.T) {
|
||||
// Create a test logger
|
||||
logger := slog.Default()
|
||||
rt := New(logger)
|
||||
|
||||
// Create test config with empty state dir (no snapshot loading)
|
||||
cfg := &config.Config{
|
||||
StateDir: "",
|
||||
}
|
||||
|
||||
rt := New(cfg, logger)
|
||||
|
||||
prefixID := uuid.New()
|
||||
originASNID := uuid.New()
|
||||
|
@ -10,16 +10,16 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/routewatch/internal/config"
|
||||
"git.eeqj.de/sneak/routewatch/internal/routingtable"
|
||||
)
|
||||
|
||||
const (
|
||||
snapshotInterval = 10 * time.Minute
|
||||
snapshotFilename = "routewatch-snapshot.json.gz"
|
||||
snapshotFilename = "routingtable.json.gz"
|
||||
tempFileSuffix = ".tmp"
|
||||
)
|
||||
|
||||
@ -36,16 +36,15 @@ type Snapshotter struct {
|
||||
}
|
||||
|
||||
// New creates a new Snapshotter instance
|
||||
func New(rt *routingtable.RoutingTable, logger *slog.Logger) (*Snapshotter, error) {
|
||||
stateDir, err := getStateDirectory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine state directory: %w", err)
|
||||
}
|
||||
func New(rt *routingtable.RoutingTable, cfg *config.Config, logger *slog.Logger) (*Snapshotter, error) {
|
||||
stateDir := cfg.GetStateDir()
|
||||
|
||||
// Ensure state directory exists
|
||||
const stateDirPerms = 0750
|
||||
if err := os.MkdirAll(stateDir, stateDirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create state directory: %w", err)
|
||||
// If state directory is specified, ensure it exists
|
||||
if stateDir != "" {
|
||||
const stateDirPerms = 0750
|
||||
if err := os.MkdirAll(stateDir, stateDirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create snapshot directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s := &Snapshotter{
|
||||
@ -76,38 +75,6 @@ func (s *Snapshotter) Start(ctx context.Context) {
|
||||
go s.periodicSnapshot()
|
||||
}
|
||||
|
||||
// getStateDirectory returns the appropriate state directory based on the OS
|
||||
func getStateDirectory() (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// macOS: Use ~/Library/Application Support/routewatch
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, "Library", "Application Support", "routewatch"), nil
|
||||
case "linux", "freebsd", "openbsd", "netbsd":
|
||||
// Unix-like: Use /var/lib/routewatch if running as root, otherwise use XDG_STATE_HOME
|
||||
if os.Geteuid() == 0 {
|
||||
return "/var/lib/routewatch", nil
|
||||
}
|
||||
// Check XDG_STATE_HOME first
|
||||
if xdgState := os.Getenv("XDG_STATE_HOME"); xdgState != "" {
|
||||
return filepath.Join(xdgState, "routewatch"), nil
|
||||
}
|
||||
// Fall back to ~/.local/state/routewatch
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".local", "state", "routewatch"), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// periodicSnapshot runs periodic snapshots
|
||||
func (s *Snapshotter) periodicSnapshot() {
|
||||
defer s.wg.Done()
|
||||
@ -130,6 +97,11 @@ func (s *Snapshotter) periodicSnapshot() {
|
||||
|
||||
// TakeSnapshot creates a snapshot of the current routing table state
|
||||
func (s *Snapshotter) TakeSnapshot() error {
|
||||
// Can't take snapshot without a state directory
|
||||
if s.stateDir == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure only one snapshot runs at a time
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
Loading…
Reference in New Issue
Block a user