Implement routing table snapshotter with automatic loading on startup
- Create snapshotter package with periodic (10 min) and on-demand snapshots
- Add JSON serialization with gzip compression and atomic file writes
- Update routing table to track AddedAt time for each route
- Load snapshots on startup, filtering out stale routes (>30 minutes old)
- Add ROUTEWATCH_DISABLE_SNAPSHOTTER env var for tests
- Use OS-appropriate state directories (macOS: ~/Library/Application Support, Linux: /var/lib or XDG_STATE_HOME)
This commit is contained in:
parent
283f2ddbf2
commit
ae2ef2ae0c
@ -14,6 +14,7 @@ import (
|
|||||||
"git.eeqj.de/sneak/routewatch/internal/metrics"
|
"git.eeqj.de/sneak/routewatch/internal/metrics"
|
||||||
"git.eeqj.de/sneak/routewatch/internal/routingtable"
|
"git.eeqj.de/sneak/routewatch/internal/routingtable"
|
||||||
"git.eeqj.de/sneak/routewatch/internal/server"
|
"git.eeqj.de/sneak/routewatch/internal/server"
|
||||||
|
"git.eeqj.de/sneak/routewatch/internal/snapshotter"
|
||||||
"git.eeqj.de/sneak/routewatch/internal/streamer"
|
"git.eeqj.de/sneak/routewatch/internal/streamer"
|
||||||
|
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
@ -54,13 +55,14 @@ type RouteWatch struct {
|
|||||||
routingTable *routingtable.RoutingTable
|
routingTable *routingtable.RoutingTable
|
||||||
streamer *streamer.Streamer
|
streamer *streamer.Streamer
|
||||||
server *server.Server
|
server *server.Server
|
||||||
|
snapshotter *snapshotter.Snapshotter
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
maxRuntime time.Duration
|
maxRuntime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new RouteWatch instance
|
// New creates a new RouteWatch instance
|
||||||
func New(deps Dependencies) *RouteWatch {
|
func New(deps Dependencies) *RouteWatch {
|
||||||
return &RouteWatch{
|
rw := &RouteWatch{
|
||||||
db: deps.DB,
|
db: deps.DB,
|
||||||
routingTable: deps.RoutingTable,
|
routingTable: deps.RoutingTable,
|
||||||
streamer: deps.Streamer,
|
streamer: deps.Streamer,
|
||||||
@ -68,6 +70,19 @@ func New(deps Dependencies) *RouteWatch {
|
|||||||
logger: deps.Logger,
|
logger: deps.Logger,
|
||||||
maxRuntime: deps.Config.MaxRuntime,
|
maxRuntime: deps.Config.MaxRuntime,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create snapshotter unless disabled (for tests)
|
||||||
|
if os.Getenv("ROUTEWATCH_DISABLE_SNAPSHOTTER") != "1" {
|
||||||
|
snapshotter, err := snapshotter.New(deps.RoutingTable, deps.Logger)
|
||||||
|
if err != nil {
|
||||||
|
deps.Logger.Error("Failed to create snapshotter", "error", err)
|
||||||
|
// Continue without snapshotter
|
||||||
|
} else {
|
||||||
|
rw.snapshotter = snapshotter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rw
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run starts the RouteWatch application
|
// Run starts the RouteWatch application
|
||||||
@ -97,6 +112,11 @@ func (rw *RouteWatch) Run(ctx context.Context) error {
|
|||||||
// Start periodic routing table stats logging
|
// Start periodic routing table stats logging
|
||||||
go rw.logRoutingTableStats(ctx)
|
go rw.logRoutingTableStats(ctx)
|
||||||
|
|
||||||
|
// Start snapshotter if available
|
||||||
|
if rw.snapshotter != nil {
|
||||||
|
rw.snapshotter.Start(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// Start streaming
|
// Start streaming
|
||||||
if err := rw.streamer.Start(); err != nil {
|
if err := rw.streamer.Start(); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -131,6 +151,13 @@ func (rw *RouteWatch) Run(ctx context.Context) error {
|
|||||||
"duration", time.Since(metrics.ConnectedSince),
|
"duration", time.Since(metrics.ConnectedSince),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Take final snapshot before shutdown if snapshotter is available
|
||||||
|
if rw.snapshotter != nil {
|
||||||
|
if err := rw.snapshotter.Shutdown(); err != nil {
|
||||||
|
rw.logger.Error("Failed to shutdown snapshotter", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,10 +226,7 @@ func getModule() fx.Option {
|
|||||||
),
|
),
|
||||||
routingtable.New,
|
routingtable.New,
|
||||||
streamer.New,
|
streamer.New,
|
||||||
fx.Annotate(
|
server.New,
|
||||||
server.New,
|
|
||||||
fx.ParamTags(``, ``, ``, ``),
|
|
||||||
),
|
|
||||||
New,
|
New,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -156,6 +156,9 @@ func (m *mockStore) GetStats() (database.Stats, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteWatchLiveFeed(t *testing.T) {
|
func TestRouteWatchLiveFeed(t *testing.T) {
|
||||||
|
// Disable snapshotter for tests
|
||||||
|
t.Setenv("ROUTEWATCH_DISABLE_SNAPSHOTTER", "1")
|
||||||
|
|
||||||
// Create mock database
|
// Create mock database
|
||||||
mockDB := newMockStore()
|
mockDB := newMockStore()
|
||||||
defer mockDB.Close()
|
defer mockDB.Close()
|
||||||
@ -169,7 +172,7 @@ func TestRouteWatchLiveFeed(t *testing.T) {
|
|||||||
s := streamer.New(logger, metricsTracker)
|
s := streamer.New(logger, metricsTracker)
|
||||||
|
|
||||||
// Create routing table
|
// Create routing table
|
||||||
rt := routingtable.New()
|
rt := routingtable.New(logger)
|
||||||
|
|
||||||
// Create server
|
// Create server
|
||||||
srv := server.New(mockDB, rt, s, logger)
|
srv := server.New(mockDB, rt, s, logger)
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// databaseHandlerQueueSize is the queue capacity for database operations
|
// databaseHandlerQueueSize is the queue capacity for database operations
|
||||||
databaseHandlerQueueSize = 100
|
databaseHandlerQueueSize = 200
|
||||||
)
|
)
|
||||||
|
|
||||||
// DatabaseHandler handles BGP messages and stores them in the database
|
// DatabaseHandler handles BGP messages and stores them in the database
|
||||||
@ -19,7 +19,10 @@ type DatabaseHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabaseHandler creates a new database handler
|
// NewDatabaseHandler creates a new database handler
|
||||||
func NewDatabaseHandler(db database.Store, logger *slog.Logger) *DatabaseHandler {
|
func NewDatabaseHandler(
|
||||||
|
db database.Store,
|
||||||
|
logger *slog.Logger,
|
||||||
|
) *DatabaseHandler {
|
||||||
return &DatabaseHandler{
|
return &DatabaseHandler{
|
||||||
db: db,
|
db: db,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
@ -55,7 +58,13 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) {
|
|||||||
// Get or create prefix
|
// Get or create prefix
|
||||||
_, err := h.db.GetOrCreatePrefix(prefix, timestamp)
|
_, err := h.db.GetOrCreatePrefix(prefix, timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("Failed to get/create prefix", "prefix", prefix, "error", err)
|
h.logger.Error(
|
||||||
|
"Failed to get/create prefix",
|
||||||
|
"prefix",
|
||||||
|
prefix,
|
||||||
|
"error",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -63,7 +72,13 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) {
|
|||||||
// Get or create origin ASN
|
// Get or create origin ASN
|
||||||
_, err = h.db.GetOrCreateASN(originASN, timestamp)
|
_, err = h.db.GetOrCreateASN(originASN, timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("Failed to get/create ASN", "asn", originASN, "error", err)
|
h.logger.Error(
|
||||||
|
"Failed to get/create ASN",
|
||||||
|
"asn",
|
||||||
|
originASN,
|
||||||
|
"error",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -78,20 +93,36 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) {
|
|||||||
// Get or create both ASNs
|
// Get or create both ASNs
|
||||||
fromAS, err := h.db.GetOrCreateASN(fromASN, timestamp)
|
fromAS, err := h.db.GetOrCreateASN(fromASN, timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("Failed to get/create from ASN", "asn", fromASN, "error", err)
|
h.logger.Error(
|
||||||
|
"Failed to get/create from ASN",
|
||||||
|
"asn",
|
||||||
|
fromASN,
|
||||||
|
"error",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
toAS, err := h.db.GetOrCreateASN(toASN, timestamp)
|
toAS, err := h.db.GetOrCreateASN(toASN, timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("Failed to get/create to ASN", "asn", toASN, "error", err)
|
h.logger.Error(
|
||||||
|
"Failed to get/create to ASN",
|
||||||
|
"asn",
|
||||||
|
toASN,
|
||||||
|
"error",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record the peering
|
// Record the peering
|
||||||
err = h.db.RecordPeering(fromAS.ID.String(), toAS.ID.String(), timestamp)
|
err = h.db.RecordPeering(
|
||||||
|
fromAS.ID.String(),
|
||||||
|
toAS.ID.String(),
|
||||||
|
timestamp,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("Failed to record peering",
|
h.logger.Error("Failed to record peering",
|
||||||
"from_asn", fromASN,
|
"from_asn", fromASN,
|
||||||
@ -109,7 +140,13 @@ func (h *DatabaseHandler) HandleMessage(msg *ristypes.RISMessage) {
|
|||||||
// Get prefix
|
// Get prefix
|
||||||
_, err := h.db.GetOrCreatePrefix(prefix, timestamp)
|
_, err := h.db.GetOrCreatePrefix(prefix, timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("Failed to get prefix for withdrawal", "prefix", prefix, "error", err)
|
h.logger.Error(
|
||||||
|
"Failed to get prefix for withdrawal",
|
||||||
|
"prefix",
|
||||||
|
prefix,
|
||||||
|
"error",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,13 @@
|
|||||||
package routingtable
|
package routingtable
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -11,6 +17,15 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// routeStalenessThreshold is how old a route can be before we consider it stale
|
||||||
|
// Using 30 minutes as a conservative value for snapshot loading
|
||||||
|
routeStalenessThreshold = 30 * time.Minute
|
||||||
|
|
||||||
|
// snapshotFilename is the name of the snapshot file
|
||||||
|
snapshotFilename = "routewatch-snapshot.json.gz"
|
||||||
|
)
|
||||||
|
|
||||||
// Route represents a single route entry in the routing table
|
// Route represents a single route entry in the routing table
|
||||||
type Route struct {
|
type Route struct {
|
||||||
PrefixID uuid.UUID `json:"prefix_id"`
|
PrefixID uuid.UUID `json:"prefix_id"`
|
||||||
@ -21,6 +36,7 @@ type Route struct {
|
|||||||
ASPath []int `json:"as_path"` // Full AS path
|
ASPath []int `json:"as_path"` // Full AS path
|
||||||
NextHop string `json:"next_hop"`
|
NextHop string `json:"next_hop"`
|
||||||
AnnouncedAt time.Time `json:"announced_at"`
|
AnnouncedAt time.Time `json:"announced_at"`
|
||||||
|
AddedAt time.Time `json:"added_at"` // When we added this route to our table
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteKey uniquely identifies a route in the table
|
// RouteKey uniquely identifies a route in the table
|
||||||
@ -48,15 +64,22 @@ type RoutingTable struct {
|
|||||||
lastMetricsReset time.Time
|
lastMetricsReset time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new empty routing table
|
// New creates a new routing table, loading from snapshot if available
|
||||||
func New() *RoutingTable {
|
func New(logger *slog.Logger) *RoutingTable {
|
||||||
return &RoutingTable{
|
rt := &RoutingTable{
|
||||||
routes: make(map[RouteKey]*Route),
|
routes: make(map[RouteKey]*Route),
|
||||||
byPrefix: make(map[uuid.UUID]map[RouteKey]*Route),
|
byPrefix: make(map[uuid.UUID]map[RouteKey]*Route),
|
||||||
byOriginASN: make(map[uuid.UUID]map[RouteKey]*Route),
|
byOriginASN: make(map[uuid.UUID]map[RouteKey]*Route),
|
||||||
byPeerASN: make(map[int]map[RouteKey]*Route),
|
byPeerASN: make(map[int]map[RouteKey]*Route),
|
||||||
lastMetricsReset: time.Now(),
|
lastMetricsReset: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to load from snapshot
|
||||||
|
if err := rt.loadFromSnapshot(logger); err != nil {
|
||||||
|
logger.Warn("Failed to load routing table from snapshot", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rt
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRoute adds or updates a route in the routing table
|
// AddRoute adds or updates a route in the routing table
|
||||||
@ -81,6 +104,11 @@ func (rt *RoutingTable) AddRoute(route *Route) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set AddedAt if not already set
|
||||||
|
if route.AddedAt.IsZero() {
|
||||||
|
route.AddedAt = time.Now().UTC()
|
||||||
|
}
|
||||||
|
|
||||||
// Add to main map
|
// Add to main map
|
||||||
rt.routes[key] = route
|
rt.routes[key] = route
|
||||||
|
|
||||||
@ -357,6 +385,7 @@ func (rt *RoutingTable) GetAllRoutesUnsafe() []*Route {
|
|||||||
for _, route := range rt.routes {
|
for _, route := range rt.routes {
|
||||||
routes = append(routes, route)
|
routes = append(routes, route)
|
||||||
}
|
}
|
||||||
|
|
||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -417,3 +446,108 @@ func (k RouteKey) String() string {
|
|||||||
func isIPv6(prefix string) bool {
|
func isIPv6(prefix string) bool {
|
||||||
return strings.Contains(prefix, ":")
|
return strings.Contains(prefix, ":")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getStateDirectory returns the directory in which routewatch keeps its
// persistent state for the current operating system:
//
//   - darwin: ~/Library/Application Support/routewatch
//   - linux, freebsd, openbsd, netbsd: /var/lib/routewatch when the effective
//     UID is 0; otherwise $XDG_STATE_HOME/routewatch when XDG_STATE_HOME holds
//     an absolute path, falling back to ~/.local/state/routewatch
//
// Any other operating system yields an error, as does a failure to resolve
// the user's home directory. The directory itself is not created here.
func getStateDirectory() (string, error) {
	switch runtime.GOOS {
	case "darwin":
		// macOS: Use ~/Library/Application Support/routewatch
		home, err := os.UserHomeDir()
		if err != nil {
			return "", err
		}

		return filepath.Join(home, "Library", "Application Support", "routewatch"), nil
	case "linux", "freebsd", "openbsd", "netbsd":
		// Unix-like: Use /var/lib/routewatch if running as root
		if os.Geteuid() == 0 {
			return "/var/lib/routewatch", nil
		}

		// The XDG Base Directory Specification requires the value to be an
		// absolute path; an empty or relative value is invalid and must be
		// ignored, otherwise state would land relative to the process's
		// current working directory.
		if xdgState := os.Getenv("XDG_STATE_HOME"); filepath.IsAbs(xdgState) {
			return filepath.Join(xdgState, "routewatch"), nil
		}

		// Fall back to ~/.local/state/routewatch
		home, err := os.UserHomeDir()
		if err != nil {
			return "", err
		}

		return filepath.Join(home, ".local", "state", "routewatch"), nil
	default:
		return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
	}
}
|
||||||
|
|
||||||
|
// loadFromSnapshot attempts to load the routing table from a snapshot file
|
||||||
|
func (rt *RoutingTable) loadFromSnapshot(logger *slog.Logger) error {
|
||||||
|
stateDir, err := getStateDirectory()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to determine state directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshotPath := filepath.Join(stateDir, snapshotFilename)
|
||||||
|
|
||||||
|
// Check if snapshot file exists
|
||||||
|
if _, err := os.Stat(snapshotPath); os.IsNotExist(err) {
|
||||||
|
logger.Info("No snapshot file found, starting with empty routing table")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open the snapshot file
|
||||||
|
file, err := os.Open(filepath.Clean(snapshotPath))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open snapshot file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = file.Close() }()
|
||||||
|
|
||||||
|
// Create gzip reader
|
||||||
|
gzReader, err := gzip.NewReader(file)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = gzReader.Close() }()
|
||||||
|
|
||||||
|
// Decode the snapshot
|
||||||
|
var snapshot struct {
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Stats DetailedStats `json:"stats"`
|
||||||
|
Routes []*Route `json:"routes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder := json.NewDecoder(gzReader)
|
||||||
|
if err := decoder.Decode(&snapshot); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode snapshot: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate staleness cutoff time
|
||||||
|
now := time.Now().UTC()
|
||||||
|
cutoffTime := now.Add(-routeStalenessThreshold)
|
||||||
|
|
||||||
|
// Load non-stale routes
|
||||||
|
loadedCount := 0
|
||||||
|
staleCount := 0
|
||||||
|
|
||||||
|
for _, route := range snapshot.Routes {
|
||||||
|
// Check if route is stale based on AddedAt time
|
||||||
|
if route.AddedAt.Before(cutoffTime) {
|
||||||
|
staleCount++
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the route (this will update counters and indexes)
|
||||||
|
rt.AddRoute(route)
|
||||||
|
loadedCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Loaded routing table from snapshot",
|
||||||
|
"snapshot_time", snapshot.Timestamp,
|
||||||
|
"loaded_routes", loadedCount,
|
||||||
|
"stale_routes", staleCount,
|
||||||
|
"total_routes_in_snapshot", len(snapshot.Routes),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package routingtable
|
package routingtable
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -9,7 +10,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRoutingTable(t *testing.T) {
|
func TestRoutingTable(t *testing.T) {
|
||||||
rt := New()
|
// Create a test logger
|
||||||
|
logger := slog.Default()
|
||||||
|
rt := New(logger)
|
||||||
|
|
||||||
// Test data
|
// Test data
|
||||||
prefixID1 := uuid.New()
|
prefixID1 := uuid.New()
|
||||||
@ -118,7 +121,9 @@ func TestRoutingTable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutingTableConcurrency(t *testing.T) {
|
func TestRoutingTableConcurrency(t *testing.T) {
|
||||||
rt := New()
|
// Create a test logger
|
||||||
|
logger := slog.Default()
|
||||||
|
rt := New(logger)
|
||||||
|
|
||||||
// Test concurrent access
|
// Test concurrent access
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@ -170,7 +175,9 @@ func TestRoutingTableConcurrency(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteUpdate(t *testing.T) {
|
func TestRouteUpdate(t *testing.T) {
|
||||||
rt := New()
|
// Create a test logger
|
||||||
|
logger := slog.Default()
|
||||||
|
rt := New(logger)
|
||||||
|
|
||||||
prefixID := uuid.New()
|
prefixID := uuid.New()
|
||||||
originASNID := uuid.New()
|
originASNID := uuid.New()
|
||||||
|
@ -36,32 +36,44 @@ type Snapshotter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Snapshotter instance
|
// New creates a new Snapshotter instance
|
||||||
func New(ctx context.Context, rt *routingtable.RoutingTable, logger *slog.Logger) (*Snapshotter, error) {
|
func New(rt *routingtable.RoutingTable, logger *slog.Logger) (*Snapshotter, error) {
|
||||||
stateDir, err := getStateDirectory()
|
stateDir, err := getStateDirectory()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to determine state directory: %w", err)
|
return nil, fmt.Errorf("failed to determine state directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure state directory exists
|
// Ensure state directory exists
|
||||||
if err := os.MkdirAll(stateDir, 0755); err != nil {
|
const stateDirPerms = 0750
|
||||||
|
if err := os.MkdirAll(stateDir, stateDirPerms); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create state directory: %w", err)
|
return nil, fmt.Errorf("failed to create state directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
s := &Snapshotter{
|
s := &Snapshotter{
|
||||||
rt: rt,
|
rt: rt,
|
||||||
stateDir: stateDir,
|
stateDir: stateDir,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the periodic snapshot process
|
||||||
|
func (s *Snapshotter) Start(ctx context.Context) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.ctx != nil {
|
||||||
|
// Already started
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
s.ctx = ctx
|
||||||
|
s.cancel = cancel
|
||||||
|
|
||||||
// Start periodic snapshot goroutine
|
// Start periodic snapshot goroutine
|
||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
go s.periodicSnapshot()
|
go s.periodicSnapshot()
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getStateDirectory returns the appropriate state directory based on the OS
|
// getStateDirectory returns the appropriate state directory based on the OS
|
||||||
@ -73,6 +85,7 @@ func getStateDirectory() (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return filepath.Join(home, "Library", "Application Support", "routewatch"), nil
|
return filepath.Join(home, "Library", "Application Support", "routewatch"), nil
|
||||||
case "linux", "freebsd", "openbsd", "netbsd":
|
case "linux", "freebsd", "openbsd", "netbsd":
|
||||||
// Unix-like: Use /var/lib/routewatch if running as root, otherwise use XDG_STATE_HOME
|
// Unix-like: Use /var/lib/routewatch if running as root, otherwise use XDG_STATE_HOME
|
||||||
@ -88,6 +101,7 @@ func getStateDirectory() (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return filepath.Join(home, ".local", "state", "routewatch"), nil
|
return filepath.Join(home, ".local", "state", "routewatch"), nil
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||||
@ -154,19 +168,23 @@ func (s *Snapshotter) TakeSnapshot() error {
|
|||||||
tempPath := filepath.Join(s.stateDir, snapshotFilename+tempFileSuffix)
|
tempPath := filepath.Join(s.stateDir, snapshotFilename+tempFileSuffix)
|
||||||
finalPath := filepath.Join(s.stateDir, snapshotFilename)
|
finalPath := filepath.Join(s.stateDir, snapshotFilename)
|
||||||
|
|
||||||
|
// Clean the paths to avoid any path traversal issues
|
||||||
|
tempPath = filepath.Clean(tempPath)
|
||||||
|
finalPath = filepath.Clean(finalPath)
|
||||||
|
|
||||||
tempFile, err := os.Create(tempPath)
|
tempFile, err := os.Create(tempPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create temporary file: %w", err)
|
return fmt.Errorf("failed to create temporary file: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
tempFile.Close()
|
_ = tempFile.Close()
|
||||||
// Clean up temp file if it still exists
|
// Clean up temp file if it still exists
|
||||||
os.Remove(tempPath)
|
_ = os.Remove(tempPath)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Create gzip writer
|
// Create gzip writer
|
||||||
gzipWriter := gzip.NewWriter(tempFile)
|
gzipWriter := gzip.NewWriter(tempFile)
|
||||||
gzipWriter.Header.Comment = fmt.Sprintf("RouteWatch snapshot taken at %s", snapshot.Timestamp.Format(time.RFC3339))
|
gzipWriter.Comment = fmt.Sprintf("RouteWatch snapshot taken at %s", snapshot.Timestamp.Format(time.RFC3339))
|
||||||
|
|
||||||
// Write compressed data
|
// Write compressed data
|
||||||
if _, err := gzipWriter.Write(jsonData); err != nil {
|
if _, err := gzipWriter.Write(jsonData); err != nil {
|
||||||
@ -213,7 +231,9 @@ func (s *Snapshotter) Shutdown() error {
|
|||||||
s.logger.Info("Shutting down snapshotter")
|
s.logger.Info("Shutting down snapshotter")
|
||||||
|
|
||||||
// Cancel context to stop periodic snapshots
|
// Cancel context to stop periodic snapshots
|
||||||
s.cancel()
|
if s.cancel != nil {
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for periodic snapshot goroutine to finish
|
// Wait for periodic snapshot goroutine to finish
|
||||||
s.wg.Wait()
|
s.wg.Wait()
|
||||||
@ -230,6 +250,7 @@ func (s *Snapshotter) Shutdown() error {
|
|||||||
func (s *Snapshotter) GetLastSnapshotTime() time.Time {
|
func (s *Snapshotter) GetLastSnapshotTime() time.Time {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
return s.lastSnapshot
|
return s.lastSnapshot
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user