From ae2ef2ae0c2232a0615a86455a89f9a96ed16b9e Mon Sep 17 00:00:00 2001 From: sneak Date: Mon, 28 Jul 2025 00:03:19 +0200 Subject: [PATCH] Implement routing table snapshotter with automatic loading on startup - Create snapshotter package with periodic (10 min) and on-demand snapshots - Add JSON serialization with gzip compression and atomic file writes - Update routing table to track AddedAt time for each route - Load snapshots on startup, filtering out stale routes (>30 minutes old) - Add ROUTEWATCH_DISABLE_SNAPSHOTTER env var for tests - Use OS-appropriate state directories (macOS: ~/Library/Application Support, Linux: /var/lib or XDG_STATE_HOME) --- internal/routewatch/app.go | 34 ++++- internal/routewatch/app_integration_test.go | 5 +- internal/routewatch/dbhandler.go | 53 ++++++-- internal/routingtable/routingtable.go | 140 +++++++++++++++++++- internal/routingtable/routingtable_test.go | 13 +- internal/snapshotter/snapshotter.go | 45 +++++-- 6 files changed, 258 insertions(+), 32 deletions(-) diff --git a/internal/routewatch/app.go b/internal/routewatch/app.go index 3a0b626..e17791c 100644 --- a/internal/routewatch/app.go +++ b/internal/routewatch/app.go @@ -14,6 +14,7 @@ import ( "git.eeqj.de/sneak/routewatch/internal/metrics" "git.eeqj.de/sneak/routewatch/internal/routingtable" "git.eeqj.de/sneak/routewatch/internal/server" + "git.eeqj.de/sneak/routewatch/internal/snapshotter" "git.eeqj.de/sneak/routewatch/internal/streamer" "go.uber.org/fx" @@ -54,13 +55,14 @@ type RouteWatch struct { routingTable *routingtable.RoutingTable streamer *streamer.Streamer server *server.Server + snapshotter *snapshotter.Snapshotter logger *slog.Logger maxRuntime time.Duration } // New creates a new RouteWatch instance func New(deps Dependencies) *RouteWatch { - return &RouteWatch{ + rw := &RouteWatch{ db: deps.DB, routingTable: deps.RoutingTable, streamer: deps.Streamer, @@ -68,6 +70,19 @@ func New(deps Dependencies) *RouteWatch { logger: deps.Logger, maxRuntime: deps.Config.MaxRuntime, } + + // Create snapshotter unless disabled (for tests) + if os.Getenv("ROUTEWATCH_DISABLE_SNAPSHOTTER") != "1" { + snapshotter, err := snapshotter.New(deps.RoutingTable, deps.Logger) + if err != nil { + deps.Logger.Error("Failed to create snapshotter", "error", err) + // Continue without snapshotter + } else { + rw.snapshotter = snapshotter + } + } + + return rw } // Run starts the RouteWatch application @@ -97,6 +112,11 @@ func (rw *RouteWatch) Run(ctx context.Context) error { // Start periodic routing table stats logging go rw.logRoutingTableStats(ctx) + // Start snapshotter if available + if rw.snapshotter != nil { + rw.snapshotter.Start(ctx) + } + // Start streaming if err := rw.streamer.Start(); err != nil { return err @@ -131,6 +151,13 @@ func (rw *RouteWatch) Run(ctx context.Context) error { "duration", time.Since(metrics.ConnectedSince), ) + // Take final snapshot before shutdown if snapshotter is available + if rw.snapshotter != nil { + if err := rw.snapshotter.Shutdown(); err != nil { + rw.logger.Error("Failed to shutdown snapshotter", "error", err) + } + } + return nil } @@ -199,10 +226,7 @@ func getModule() fx.Option { ), routingtable.New, streamer.New, - fx.Annotate( - server.New, - fx.ParamTags(``, ``, ``, ``), - ), + server.New, New, ), ) diff --git a/internal/routewatch/app_integration_test.go b/internal/routewatch/app_integration_test.go index fe2c1d8..90379dd 100644 --- a/internal/routewatch/app_integration_test.go +++ b/internal/routewatch/app_integration_test.go @@ -156,6 +156,9 @@ func (m *mockStore) GetStats() (database.Stats, error) { } func TestRouteWatchLiveFeed(t *testing.T) { + // Disable snapshotter for tests + t.Setenv("ROUTEWATCH_DISABLE_SNAPSHOTTER", "1") + // Create mock database mockDB := newMockStore() defer mockDB.Close() @@ -169,7 +172,7 @@ func TestRouteWatchLiveFeed(t *testing.T) { s := streamer.New(logger, metricsTracker) // Create routing table - rt := routingtable.New() + rt := routingtable.New(logger) // Create server srv := server.New(mockDB, rt, s, logger) diff --git a/internal/routewatch/dbhandler.go b/internal/routewatch/dbhandler.go index ca41a71..cc5c566 100644 --- a/internal/routewatch/dbhandler.go +++ b/internal/routewatch/dbhandler.go @@ -9,7 +9,7 @@ import ( const ( // databaseHandlerQueueSize is the queue capacity for database operations - databaseHandlerQueueSize = 100 + databaseHandlerQueueSize = 200 ) // DatabaseHandler handles BGP messages and stores them in the database @@ -19,7 +19,10 @@ type DatabaseHandler struct { } // NewDatabaseHandler creates a new database handler -func NewDatabaseHandler(db database.Store, logger *slog.Logger) *DatabaseHandler { +func NewDatabaseHandler( + db database.Store, + logger *slog.Logger, +) *DatabaseHandler { return &DatabaseHandler{ db: db, logger: logger, @@ -55,7 +58,13 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) { // Get or create prefix _, err := h.db.GetOrCreatePrefix(prefix, timestamp) if err != nil { - h.logger.Error("Failed to get/create prefix", "prefix", prefix, "error", err) + h.logger.Error( + "Failed to get/create prefix", + "prefix", + prefix, + "error", + err, + ) continue } @@ -63,7 +72,13 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) { // Get or create origin ASN _, err = h.db.GetOrCreateASN(originASN, timestamp) if err != nil { - h.logger.Error("Failed to get/create ASN", "asn", originASN, "error", err) + h.logger.Error( + "Failed to get/create ASN", + "asn", + originASN, + "error", + err, + ) continue } @@ -78,20 +93,36 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) { // Get or create both ASNs fromAS, err := h.db.GetOrCreateASN(fromASN, timestamp) if err != nil { - h.logger.Error("Failed to get/create from ASN", "asn", fromASN, "error", err) + h.logger.Error( + "Failed to get/create from ASN", + "asn", + fromASN, + "error", + err, + ) continue } toAS, err := h.db.GetOrCreateASN(toASN, timestamp) if err != nil { - h.logger.Error("Failed to get/create to ASN", "asn", toASN, "error", err) + h.logger.Error( + "Failed to get/create to ASN", + "asn", + toASN, + "error", + err, + ) continue } // Record the peering - err = h.db.RecordPeering(fromAS.ID.String(), toAS.ID.String(), timestamp) + err = h.db.RecordPeering( + fromAS.ID.String(), + toAS.ID.String(), + timestamp, + ) if err != nil { h.logger.Error("Failed to record peering", "from_asn", fromASN, @@ -109,7 +140,13 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) { // Get prefix _, err := h.db.GetOrCreatePrefix(prefix, timestamp) if err != nil { - h.logger.Error("Failed to get prefix for withdrawal", "prefix", prefix, "error", err) + h.logger.Error( + "Failed to get prefix for withdrawal", + "prefix", + prefix, + "error", + err, + ) continue } diff --git a/internal/routingtable/routingtable.go b/internal/routingtable/routingtable.go index ccf9007..c76a114 100644 --- a/internal/routingtable/routingtable.go +++ b/internal/routingtable/routingtable.go @@ -2,7 +2,13 @@ package routingtable import ( + "compress/gzip" + "encoding/json" "fmt" + "log/slog" + "os" + "path/filepath" + "runtime" "strings" "sync" "sync/atomic" @@ -11,6 +17,15 @@ import ( "github.com/google/uuid" ) +const ( + // routeStalenessThreshold is how old a route can be before we consider it stale + // Using 30 minutes as a conservative value for snapshot loading + routeStalenessThreshold = 30 * time.Minute + + // snapshotFilename is the name of the snapshot file + snapshotFilename = "routewatch-snapshot.json.gz" +) + // Route represents a single route entry in the routing table type Route struct { PrefixID uuid.UUID `json:"prefix_id"` @@ -21,6 +36,7 @@ type Route struct { ASPath []int `json:"as_path"` // Full AS path NextHop string `json:"next_hop"` AnnouncedAt time.Time `json:"announced_at"` + AddedAt time.Time `json:"added_at"` // When we added this route to our table } // RouteKey uniquely identifies a route in the table @@ -48,15 +64,22 @@ type RoutingTable struct { lastMetricsReset time.Time } -// New creates a new empty routing table -func New() *RoutingTable { - return &RoutingTable{ +// New creates a new routing table, loading from snapshot if available +func New(logger *slog.Logger) *RoutingTable { + rt := &RoutingTable{ routes: make(map[RouteKey]*Route), byPrefix: make(map[uuid.UUID]map[RouteKey]*Route), byOriginASN: make(map[uuid.UUID]map[RouteKey]*Route), byPeerASN: make(map[int]map[RouteKey]*Route), lastMetricsReset: time.Now(), } + + // Try to load from snapshot + if err := rt.loadFromSnapshot(logger); err != nil { + logger.Warn("Failed to load routing table from snapshot", "error", err) + } + + return rt } // AddRoute adds or updates a route in the routing table @@ -81,6 +104,11 @@ func (rt *RoutingTable) AddRoute(route *Route) { } } + // Set AddedAt if not already set + if route.AddedAt.IsZero() { + route.AddedAt = time.Now().UTC() + } + // Add to main map rt.routes[key] = route @@ -357,6 +385,7 @@ func (rt *RoutingTable) GetAllRoutesUnsafe() []*Route { for _, route := range rt.routes { routes = append(routes, route) } + return routes } @@ -417,3 +446,108 @@ func (k RouteKey) String() string { func isIPv6(prefix string) bool { return strings.Contains(prefix, ":") } + +// getStateDirectory returns the appropriate state directory based on the OS +func getStateDirectory() (string, error) { + switch runtime.GOOS { + case "darwin": + // macOS: Use ~/Library/Application Support/routewatch + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + return filepath.Join(home, "Library", "Application Support", "routewatch"), nil + case "linux", "freebsd", "openbsd", "netbsd": + // Unix-like: Use /var/lib/routewatch if running as root, otherwise use XDG_STATE_HOME + if os.Geteuid() == 0 { + return "/var/lib/routewatch", nil + } + // Check XDG_STATE_HOME first + if xdgState := os.Getenv("XDG_STATE_HOME"); xdgState != "" { + return filepath.Join(xdgState, "routewatch"), nil + } + // Fall back to ~/.local/state/routewatch + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + return filepath.Join(home, ".local", "state", "routewatch"), nil + default: + return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +// loadFromSnapshot attempts to load the routing table from a snapshot file +func (rt *RoutingTable) loadFromSnapshot(logger *slog.Logger) error { + stateDir, err := getStateDirectory() + if err != nil { + return fmt.Errorf("failed to determine state directory: %w", err) + } + + snapshotPath := filepath.Join(stateDir, snapshotFilename) + + // Check if snapshot file exists + if _, err := os.Stat(snapshotPath); os.IsNotExist(err) { + logger.Info("No snapshot file found, starting with empty routing table") + + return nil + } + + // Open the snapshot file + file, err := os.Open(filepath.Clean(snapshotPath)) + if err != nil { + return fmt.Errorf("failed to open snapshot file: %w", err) + } + defer func() { _ = file.Close() }() + + // Create gzip reader + gzReader, err := gzip.NewReader(file) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer func() { _ = gzReader.Close() }() + + // Decode the snapshot + var snapshot struct { + Timestamp time.Time `json:"timestamp"` + Stats DetailedStats `json:"stats"` + Routes []*Route `json:"routes"` + } + + decoder := json.NewDecoder(gzReader) + if err := decoder.Decode(&snapshot); err != nil { + return fmt.Errorf("failed to decode snapshot: %w", err) + } + + // Calculate staleness cutoff time + now := time.Now().UTC() + cutoffTime := now.Add(-routeStalenessThreshold) + + // Load non-stale routes + loadedCount := 0 + staleCount := 0 + + for _, route := range snapshot.Routes { + // Check if route is stale based on AddedAt time + if route.AddedAt.Before(cutoffTime) { + staleCount++ + + continue + } + + // Add the route (this will update counters and indexes) + rt.AddRoute(route) + loadedCount++ + } + + logger.Info("Loaded routing table from snapshot", + "snapshot_time", snapshot.Timestamp, + "loaded_routes", loadedCount, + "stale_routes", staleCount, + "total_routes_in_snapshot", len(snapshot.Routes), + ) + + return nil +} diff --git a/internal/routingtable/routingtable_test.go b/internal/routingtable/routingtable_test.go index a3bb48e..240cc4a 100644 --- a/internal/routingtable/routingtable_test.go +++ b/internal/routingtable/routingtable_test.go @@ -1,6 +1,7 @@ package routingtable import ( + "log/slog" "sync" "testing" "time" @@ -9,7 +10,9 @@ import ( ) func TestRoutingTable(t *testing.T) { - rt := New() + // Create a test logger + logger := slog.Default() + rt := New(logger) // Test data prefixID1 := uuid.New() @@ -118,7 +121,9 @@ func TestRoutingTable(t *testing.T) { } func TestRoutingTableConcurrency(t *testing.T) { - rt := New() + // Create a test logger + logger := slog.Default() + rt := New(logger) // Test concurrent access var wg sync.WaitGroup @@ -170,7 +175,9 @@ func TestRoutingTableConcurrency(t *testing.T) { } func TestRouteUpdate(t *testing.T) { - rt := New() + // Create a test logger + logger := slog.Default() + rt := New(logger) prefixID := uuid.New() originASNID := uuid.New() diff --git a/internal/snapshotter/snapshotter.go b/internal/snapshotter/snapshotter.go index feffb65..4604aa5 100644 --- a/internal/snapshotter/snapshotter.go +++ b/internal/snapshotter/snapshotter.go @@ -36,32 +36,44 @@ type Snapshotter struct { } // New creates a new Snapshotter instance -func New(ctx context.Context, rt *routingtable.RoutingTable, logger *slog.Logger) (*Snapshotter, error) { +func New(rt *routingtable.RoutingTable, logger *slog.Logger) (*Snapshotter, error) { stateDir, err := getStateDirectory() if err != nil { return nil, fmt.Errorf("failed to determine state directory: %w", err) } // Ensure state directory exists - if err := os.MkdirAll(stateDir, 0755); err != nil { + const stateDirPerms = 0750 + if err := os.MkdirAll(stateDir, stateDirPerms); err != nil { return nil, fmt.Errorf("failed to create state directory: %w", err) } - ctx, cancel := context.WithCancel(ctx) - s := &Snapshotter{ rt: rt, stateDir: stateDir, logger: logger, - ctx: ctx, - cancel: cancel, } + return s, nil +} + +// Start begins the periodic snapshot process +func (s *Snapshotter) Start(ctx context.Context) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.ctx != nil { + // Already started + return + } + + ctx, cancel := context.WithCancel(ctx) + s.ctx = ctx + s.cancel = cancel + // Start periodic snapshot goroutine s.wg.Add(1) go s.periodicSnapshot() - - return s, nil } // getStateDirectory returns the appropriate state directory based on the OS @@ -73,6 +85,7 @@ func getStateDirectory() (string, error) { if err != nil { return "", err } + return filepath.Join(home, "Library", "Application Support", "routewatch"), nil case "linux", "freebsd", "openbsd", "netbsd": // Unix-like: Use /var/lib/routewatch if running as root, otherwise use XDG_STATE_HOME @@ -88,6 +101,7 @@ func getStateDirectory() (string, error) { if err != nil { return "", err } + return filepath.Join(home, ".local", "state", "routewatch"), nil default: return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS) @@ -154,19 +168,23 @@ func (s *Snapshotter) TakeSnapshot() error { tempPath := filepath.Join(s.stateDir, snapshotFilename+tempFileSuffix) finalPath := filepath.Join(s.stateDir, snapshotFilename) + // Clean the paths to avoid any path traversal issues + tempPath = filepath.Clean(tempPath) + finalPath = filepath.Clean(finalPath) + tempFile, err := os.Create(tempPath) if err != nil { return fmt.Errorf("failed to create temporary file: %w", err) } defer func() { - tempFile.Close() + _ = tempFile.Close() // Clean up temp file if it still exists - os.Remove(tempPath) + _ = os.Remove(tempPath) }() // Create gzip writer gzipWriter := gzip.NewWriter(tempFile) - gzipWriter.Header.Comment = fmt.Sprintf("RouteWatch snapshot taken at %s", snapshot.Timestamp.Format(time.RFC3339)) + gzipWriter.Comment = fmt.Sprintf("RouteWatch snapshot taken at %s", snapshot.Timestamp.Format(time.RFC3339)) // Write compressed data if _, err := gzipWriter.Write(jsonData); err != nil { @@ -213,7 +231,9 @@ func (s *Snapshotter) Shutdown() error { s.logger.Info("Shutting down snapshotter") // Cancel context to stop periodic snapshots - s.cancel() + if s.cancel != nil { + s.cancel() + } // Wait for periodic snapshot goroutine to finish s.wg.Wait() @@ -230,6 +250,7 @@ func (s *Snapshotter) Shutdown() error { func (s *Snapshotter) GetLastSnapshotTime() time.Time { s.mu.Lock() defer s.mu.Unlock() + return s.lastSnapshot }