// Package database provides SQLite storage for BGP routing data including
// ASNs, prefixes, announcements and peerings.
package database

import (
    "context"
    "database/sql"
    _ "embed"
    "encoding/json"
    "errors"
    "fmt"
    "net"
    "os"
    "path/filepath"
    "runtime"
    "sync"
    "time"

    "git.eeqj.de/sneak/routewatch/internal/config"
    "git.eeqj.de/sneak/routewatch/internal/logger"
    "git.eeqj.de/sneak/routewatch/pkg/asinfo"

    "github.com/google/uuid"
    _ "github.com/mattn/go-sqlite3" // CGO SQLite driver
)

//go:embed schema.sql
var dbSchema string

const (
    dirPermissions = 0750 // rwxr-x---
    ipVersionV4    = 4
    ipVersionV6    = 6
    ipv6Length     = 16
    ipv4Offset     = 12
    ipv4Bits       = 32
    maxIPv4        = 0xFFFFFFFF
)

// Common errors
var (
    // ErrInvalidIP is returned when an IP address is malformed
    ErrInvalidIP = errors.New("invalid IP address")
    // ErrNoRoute is returned when no route is found for an IP
    ErrNoRoute = errors.New("no route found")
)

// Database manages the SQLite database connection and operations.
type Database struct {
    db       *sql.DB
    logger   *logger.Logger
    path     string
    mu       sync.Mutex
    lockedAt time.Time
    lockedBy string
}

// New creates a new database connection and initializes the schema.
func New(cfg *config.Config, logger *logger.Logger) (*Database, error) {
    dbPath := filepath.Join(cfg.GetStateDir(), "db.sqlite")

    logger.Info("Opening database", "path", dbPath)

    // Ensure the database directory exists
    dir := filepath.Dir(dbPath)
    if err := os.MkdirAll(dir, dirPermissions); err != nil {
        return nil, fmt.Errorf("failed to create database directory: %w", err)
    }

    // Build the go-sqlite3 DSN; connection behavior is tuned via PRAGMAs in Initialize
    dsn := fmt.Sprintf("file:%s", dbPath)

    db, err := sql.Open("sqlite3", dsn)
    if err != nil {
        return nil, fmt.Errorf("failed to open database: %w", err)
    }

    if err := db.Ping(); err != nil {
        return nil, fmt.Errorf("failed to ping database: %w", err)
    }

    // Set connection pool parameters.
    // Multiple connections allow concurrent reads while writes are serialized.
    const maxConns = 10
    db.SetMaxOpenConns(maxConns)
    db.SetMaxIdleConns(maxConns)
    db.SetConnMaxLifetime(0)

    database := &Database{db: db, logger: logger, path: dbPath}

    if err := database.Initialize(); err != nil {
        return nil, fmt.Errorf("failed to initialize database: %w", err)
    }

    return database, nil
}

// Initialize creates the database schema if it doesn't exist.
func (d *Database) Initialize() error {
    // Set SQLite pragmas for performance
    pragmas := []string{
        "PRAGMA journal_mode=WAL",         // Write-Ahead Logging
        "PRAGMA synchronous=OFF",          // Don't wait for disk writes
        "PRAGMA cache_size=-3145728",      // 3GB cache (upper limit for 2.4GB DB)
        "PRAGMA temp_store=MEMORY",        // Use memory for temp tables
        "PRAGMA wal_checkpoint(TRUNCATE)", // Checkpoint and truncate WAL now
        "PRAGMA busy_timeout=5000",        // 5 second busy timeout
        "PRAGMA analysis_limit=0",         // Disable automatic ANALYZE
    }

    for _, pragma := range pragmas {
        if err := d.exec(pragma); err != nil {
            d.logger.Warn("Failed to set pragma", "pragma", pragma, "error", err)
        }
    }

    err := d.exec(dbSchema)
    if err != nil {
        return err
    }

    // Run VACUUM on startup to optimize the database file
    d.logger.Info("Running VACUUM to optimize database (this may take a moment)")
    if err := d.exec("VACUUM"); err != nil {
        d.logger.Warn("Failed to VACUUM database", "error", err)
    }

    return nil
}
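// exampleOpenAndClose is an illustrative sketch added purely as documentation;
// it is not part of the original package and is not called anywhere. It shows
// the intended lifecycle of a Database value, assuming the caller already has
// a *config.Config and *logger.Logger from the sibling internal packages.
func exampleOpenAndClose(cfg *config.Config, log *logger.Logger) error {
    db, err := New(cfg, log)
    if err != nil {
        return err
    }
    defer func() { _ = db.Close() }()

    // ... use db: UpsertLiveRouteBatch, GetStats, GetASInfoForIP, etc. ...
    return nil
}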
// Close closes the database connection.
func (d *Database) Close() error {
    return d.db.Close()
}

// lock acquires the database mutex and logs debug information
func (d *Database) lock(operation string) {
    // Get caller information
    _, file, line, _ := runtime.Caller(1)
    caller := fmt.Sprintf("%s:%d", filepath.Base(file), line)

    d.logger.Debug("Acquiring database lock", "operation", operation, "caller", caller)
    d.mu.Lock()
    d.lockedAt = time.Now()
    d.lockedBy = fmt.Sprintf("%s (%s)", operation, caller)
    d.logger.Debug("Database lock acquired", "operation", operation, "caller", caller)
}

// unlock releases the database mutex and logs debug information including hold duration
func (d *Database) unlock() {
    holdDuration := time.Since(d.lockedAt)
    lockedBy := d.lockedBy
    d.lockedAt = time.Time{}
    d.lockedBy = ""
    d.mu.Unlock()
    d.logger.Debug("Database lock released", "held_by", lockedBy, "duration_ms", holdDuration.Milliseconds())
}

// beginTx starts a new transaction with logging
func (d *Database) beginTx() (*loggingTx, error) {
    tx, err := d.db.Begin()
    if err != nil {
        return nil, err
    }

    return &loggingTx{Tx: tx, logger: d.logger}, nil
}

// UpsertLiveRouteBatch inserts or updates multiple live routes in a single transaction
func (d *Database) UpsertLiveRouteBatch(routes []*LiveRoute) error {
    if len(routes) == 0 {
        return nil
    }

    d.lock("UpsertLiveRouteBatch")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return fmt.Errorf("failed to begin transaction: %w", err)
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    // Use a prepared statement for better performance
    query := `
        INSERT INTO live_routes (id, prefix, mask_length, ip_version, origin_asn, peer_ip, as_path, next_hop, last_updated, v4_ip_start, v4_ip_end)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
            mask_length = excluded.mask_length,
            ip_version = excluded.ip_version,
            as_path = excluded.as_path,
            next_hop = excluded.next_hop,
            last_updated = excluded.last_updated,
            v4_ip_start = excluded.v4_ip_start,
            v4_ip_end = excluded.v4_ip_end
    `

    stmt, err := tx.Prepare(query)
    if err != nil {
        return fmt.Errorf("failed to prepare statement: %w", err)
    }
    defer func() { _ = stmt.Close() }()

    for _, route := range routes {
        // Encode the AS path as JSON
        pathJSON, err := json.Marshal(route.ASPath)
        if err != nil {
            return fmt.Errorf("failed to encode AS path: %w", err)
        }

        // Convert v4_ip_start and v4_ip_end to interface{} for SQL NULL handling
        var v4Start, v4End interface{}
        if route.V4IPStart != nil {
            v4Start = *route.V4IPStart
        }
        if route.V4IPEnd != nil {
            v4End = *route.V4IPEnd
        }

        _, err = stmt.Exec(
            route.ID.String(),
            route.Prefix,
            route.MaskLength,
            route.IPVersion,
            route.OriginASN,
            route.PeerIP,
            string(pathJSON),
            route.NextHop,
            route.LastUpdated,
            v4Start,
            v4End,
        )
        if err != nil {
            return fmt.Errorf("failed to upsert route %s: %w", route.Prefix, err)
        }
    }

    if err = tx.Commit(); err != nil {
        return fmt.Errorf("failed to commit transaction: %w", err)
    }

    return nil
}

// DeleteLiveRouteBatch deletes multiple live routes in a single transaction
func (d *Database) DeleteLiveRouteBatch(deletions []LiveRouteDeletion) error {
    if len(deletions) == 0 {
        return nil
    }

    d.lock("DeleteLiveRouteBatch")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return fmt.Errorf("failed to begin transaction: %w", err)
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    // Separate deletions by type and use prepared statements
    var withOrigin []LiveRouteDeletion
    var withoutOrigin []LiveRouteDeletion
    for _, del := range deletions {
        if del.OriginASN == 0 {
            withoutOrigin = append(withoutOrigin, del)
        } else {
            withOrigin = append(withOrigin, del)
        }
    }

    // Process deletions with origin ASN
    if len(withOrigin) > 0 {
        stmt, err := tx.Prepare(`DELETE FROM live_routes WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`)
        if err != nil {
            return fmt.Errorf("failed to prepare delete statement: %w", err)
        }
        defer func() { _ = stmt.Close() }()

        for _, del := range withOrigin {
            _, err = stmt.Exec(del.Prefix, del.OriginASN, del.PeerIP)
            if err != nil {
                return fmt.Errorf("failed to delete route %s: %w", del.Prefix, err)
            }
        }
    }

    // Process deletions without origin ASN
    if len(withoutOrigin) > 0 {
        stmt, err := tx.Prepare(`DELETE FROM live_routes WHERE prefix = ? AND peer_ip = ?`)
        if err != nil {
            return fmt.Errorf("failed to prepare delete statement: %w", err)
        }
        defer func() { _ = stmt.Close() }()

        for _, del := range withoutOrigin {
            _, err = stmt.Exec(del.Prefix, del.PeerIP)
            if err != nil {
                return fmt.Errorf("failed to delete route %s: %w", del.Prefix, err)
            }
        }
    }

    if err = tx.Commit(); err != nil {
        return fmt.Errorf("failed to commit transaction: %w", err)
    }

    return nil
}

// GetOrCreateASNBatch creates or updates multiple ASNs in a single transaction
func (d *Database) GetOrCreateASNBatch(asns map[int]time.Time) error {
    if len(asns) == 0 {
        return nil
    }

    d.lock("GetOrCreateASNBatch")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return fmt.Errorf("failed to begin transaction: %w", err)
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    // Prepare statements
    selectStmt, err := tx.Prepare(
        "SELECT id, number, handle, description, first_seen, last_seen FROM asns WHERE number = ?")
    if err != nil {
        return fmt.Errorf("failed to prepare select statement: %w", err)
    }
    defer func() { _ = selectStmt.Close() }()

    updateStmt, err := tx.Prepare("UPDATE asns SET last_seen = ? WHERE id = ?")
    if err != nil {
        return fmt.Errorf("failed to prepare update statement: %w", err)
    }
    defer func() { _ = updateStmt.Close() }()

    insertStmt, err := tx.Prepare(
        "INSERT INTO asns (id, number, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?, ?)")
    if err != nil {
        return fmt.Errorf("failed to prepare insert statement: %w", err)
    }
    defer func() { _ = insertStmt.Close() }()

    for number, timestamp := range asns {
        var asn ASN
        var idStr string
        var handle, description sql.NullString

        err = selectStmt.QueryRow(number).Scan(&idStr, &asn.Number, &handle, &description, &asn.FirstSeen, &asn.LastSeen)
        if err == nil {
            // ASN exists, update last_seen
            asn.ID, _ = uuid.Parse(idStr)
            _, err = updateStmt.Exec(timestamp, asn.ID.String())
            if err != nil {
                return fmt.Errorf("failed to update ASN %d: %w", number, err)
            }

            continue
        }

        if err == sql.ErrNoRows {
            // ASN doesn't exist, create it
            asn = ASN{
                ID:        generateUUID(),
                Number:    number,
                FirstSeen: timestamp,
                LastSeen:  timestamp,
            }

            // Look up ASN info
            if info, ok := asinfo.Get(number); ok {
                asn.Handle = info.Handle
                asn.Description = info.Description
            }

            _, err = insertStmt.Exec(asn.ID.String(), asn.Number, asn.Handle, asn.Description, asn.FirstSeen, asn.LastSeen)
            if err != nil {
                return fmt.Errorf("failed to insert ASN %d: %w", number, err)
            }

            continue
        }

        if err != nil {
            return fmt.Errorf("failed to query ASN %d: %w", number, err)
        }
    }

    if err = tx.Commit(); err != nil {
        return fmt.Errorf("failed to commit transaction: %w", err)
    }

    return nil
}

// GetOrCreateASN retrieves an existing ASN or creates a new one if it doesn't exist.
func (d *Database) GetOrCreateASN(number int, timestamp time.Time) (*ASN, error) {
    d.lock("GetOrCreateASN")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return nil, err
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    var asn ASN
    var idStr string
    var handle, description sql.NullString
    err = tx.QueryRow("SELECT id, number, handle, description, first_seen, last_seen FROM asns WHERE number = ?", number).
        Scan(&idStr, &asn.Number, &handle, &description, &asn.FirstSeen, &asn.LastSeen)
    if err == nil {
        // ASN exists, update last_seen
        asn.ID, _ = uuid.Parse(idStr)
        asn.Handle = handle.String
        asn.Description = description.String

        _, err = tx.Exec("UPDATE asns SET last_seen = ? WHERE id = ?", timestamp, asn.ID.String())
        if err != nil {
            return nil, err
        }
        asn.LastSeen = timestamp

        if err = tx.Commit(); err != nil {
            d.logger.Error("Failed to commit transaction for ASN update", "asn", number, "error", err)
            return nil, err
        }

        return &asn, nil
    }

    if err != sql.ErrNoRows {
        return nil, err
    }

    // ASN doesn't exist, create it with ASN info lookup
    asn = ASN{
        ID:        generateUUID(),
        Number:    number,
        FirstSeen: timestamp,
        LastSeen:  timestamp,
    }

    // Look up ASN info
    if info, ok := asinfo.Get(number); ok {
        asn.Handle = info.Handle
        asn.Description = info.Description
    }

    _, err = tx.Exec("INSERT INTO asns (id, number, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?, ?)",
        asn.ID.String(), asn.Number, asn.Handle, asn.Description, asn.FirstSeen, asn.LastSeen)
    if err != nil {
        return nil, err
    }

    if err = tx.Commit(); err != nil {
        d.logger.Error("Failed to commit transaction for ASN creation", "asn", number, "error", err)
        return nil, err
    }

    return &asn, nil
}

// GetOrCreatePrefix retrieves an existing prefix or creates a new one if it doesn't exist.
func (d *Database) GetOrCreatePrefix(prefix string, timestamp time.Time) (*Prefix, error) {
    d.lock("GetOrCreatePrefix")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return nil, err
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    var p Prefix
    var idStr string
    err = tx.QueryRow("SELECT id, prefix, ip_version, first_seen, last_seen FROM prefixes WHERE prefix = ?", prefix).
        Scan(&idStr, &p.Prefix, &p.IPVersion, &p.FirstSeen, &p.LastSeen)
    if err == nil {
        // Prefix exists, update last_seen
        p.ID, _ = uuid.Parse(idStr)
        _, err = tx.Exec("UPDATE prefixes SET last_seen = ? WHERE id = ?", timestamp, p.ID.String())
        if err != nil {
            return nil, err
        }
        p.LastSeen = timestamp

        if err = tx.Commit(); err != nil {
            d.logger.Error("Failed to commit transaction for prefix update", "prefix", prefix, "error", err)
            return nil, err
        }

        return &p, nil
    }

    if err != sql.ErrNoRows {
        return nil, err
    }

    // Prefix doesn't exist, create it
    p = Prefix{
        ID:        generateUUID(),
        Prefix:    prefix,
        IPVersion: detectIPVersion(prefix),
        FirstSeen: timestamp,
        LastSeen:  timestamp,
    }

    _, err = tx.Exec("INSERT INTO prefixes (id, prefix, ip_version, first_seen, last_seen) VALUES (?, ?, ?, ?, ?)",
        p.ID.String(), p.Prefix, p.IPVersion, p.FirstSeen, p.LastSeen)
    if err != nil {
        return nil, err
    }

    if err = tx.Commit(); err != nil {
        d.logger.Error("Failed to commit transaction for prefix creation", "prefix", prefix, "error", err)
        return nil, err
    }

    return &p, nil
}

// RecordAnnouncement inserts a new BGP announcement or withdrawal into the database.
func (d *Database) RecordAnnouncement(announcement *Announcement) error {
    d.lock("RecordAnnouncement")
    defer d.unlock()

    err := d.exec(`
        INSERT INTO announcements (id, prefix_id, asn_id, origin_asn_id, path, next_hop, timestamp, is_withdrawal)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
        announcement.ID.String(), announcement.PrefixID.String(), announcement.ASNID.String(),
        announcement.OriginASNID.String(), announcement.Path, announcement.NextHop,
        announcement.Timestamp, announcement.IsWithdrawal)

    return err
}

// RecordPeering records a peering relationship between two ASNs.
func (d *Database) RecordPeering(asA, asB int, timestamp time.Time) error {
    // Validate ASNs
    if asA <= 0 || asB <= 0 {
        return fmt.Errorf("invalid ASN: asA=%d, asB=%d", asA, asB)
    }
    if asA == asB {
        return fmt.Errorf("cannot create peering with same ASN: %d", asA)
    }

    // Normalize: ensure asA < asB
    if asA > asB {
        asA, asB = asB, asA
    }

    d.lock("RecordPeering")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return err
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    var exists bool
    err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM peerings WHERE as_a = ? AND as_b = ?)", asA, asB).Scan(&exists)
    if err != nil {
        return err
    }

    if exists {
        _, err = tx.Exec("UPDATE peerings SET last_seen = ? WHERE as_a = ? AND as_b = ?", timestamp, asA, asB)
    } else {
        _, err = tx.Exec(`
            INSERT INTO peerings (id, as_a, as_b, first_seen, last_seen)
            VALUES (?, ?, ?, ?, ?)`,
            generateUUID().String(), asA, asB, timestamp, timestamp)
    }
    if err != nil {
        return err
    }

    if err = tx.Commit(); err != nil {
        d.logger.Error("Failed to commit transaction for peering",
            "as_a", asA,
            "as_b", asB,
            "error", err,
        )
        return err
    }

    return nil
}

// UpdatePeerBatch updates or creates multiple BGP peer records in a single transaction
func (d *Database) UpdatePeerBatch(peers map[string]PeerUpdate) error {
    if len(peers) == 0 {
        return nil
    }

    d.lock("UpdatePeerBatch")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return fmt.Errorf("failed to begin transaction: %w", err)
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    // Prepare statements
    checkStmt, err := tx.Prepare("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)")
    if err != nil {
        return fmt.Errorf("failed to prepare check statement: %w", err)
    }
    defer func() { _ = checkStmt.Close() }()

    updateStmt, err := tx.Prepare(
        "UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?")
    if err != nil {
        return fmt.Errorf("failed to prepare update statement: %w", err)
    }
    defer func() { _ = updateStmt.Close() }()

    insertStmt, err := tx.Prepare(
        "INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)")
    if err != nil {
        return fmt.Errorf("failed to prepare insert statement: %w", err)
    }
    defer func() { _ = insertStmt.Close() }()

    for _, update := range peers {
        var exists bool
        err = checkStmt.QueryRow(update.PeerIP).Scan(&exists)
        if err != nil {
            return fmt.Errorf("failed to check peer %s: %w", update.PeerIP, err)
        }

        if exists {
            _, err = updateStmt.Exec(update.PeerASN, update.Timestamp, update.MessageType, update.PeerIP)
        } else {
            _, err = insertStmt.Exec(
                generateUUID().String(), update.PeerIP, update.PeerASN, update.Timestamp, update.Timestamp, update.MessageType)
        }
        if err != nil {
            return fmt.Errorf("failed to update peer %s: %w", update.PeerIP, err)
        }
    }

    if err = tx.Commit(); err != nil {
        return fmt.Errorf("failed to commit transaction: %w", err)
    }

    return nil
}

// UpdatePeer updates or creates a BGP peer record
func (d *Database) UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error {
    d.lock("UpdatePeer")
    defer d.unlock()

    tx, err := d.beginTx()
    if err != nil {
        return err
    }
    defer func() {
        if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
            d.logger.Error("Failed to rollback transaction", "error", err)
        }
    }()

    var exists bool
    err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)", peerIP).Scan(&exists)
    if err != nil {
        return err
    }

    if exists {
        _, err = tx.Exec(
            "UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?",
            peerASN, timestamp, messageType, peerIP,
        )
    } else {
        _, err = tx.Exec(
            "INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)",
            generateUUID().String(), peerIP, peerASN, timestamp, timestamp, messageType,
        )
    }
    if err != nil {
        return err
    }

    if err = tx.Commit(); err != nil {
        d.logger.Error("Failed to commit transaction for peer update",
            "peer_ip", peerIP,
            "peer_asn", peerASN,
            "error", err,
        )
        return err
    }

    return nil
}

// GetStats returns database statistics
func (d *Database) GetStats() (Stats, error) {
    return d.GetStatsContext(context.Background())
}

// GetStatsContext returns database statistics with context support
func (d *Database) GetStatsContext(ctx context.Context) (Stats, error) {
    var stats Stats

    // Count ASNs
    err := d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
    if err != nil {
        return stats, err
    }

    // Count prefixes
    err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM prefixes").Scan(&stats.Prefixes)
    if err != nil {
        return stats, err
    }

    // Count IPv4 and IPv6 prefixes using the package-level IP version constants
    err = d.db.QueryRowContext(ctx,
        "SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV4).Scan(&stats.IPv4Prefixes)
    if err != nil {
        return stats, err
    }

    err = d.db.QueryRowContext(ctx,
        "SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV6).Scan(&stats.IPv6Prefixes)
    if err != nil {
        return stats, err
    }

    // Count peerings
    err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM peerings").Scan(&stats.Peerings)
    if err != nil {
        return stats, err
    }

    // Count peers
    err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM bgp_peers").Scan(&stats.Peers)
    if err != nil {
        return stats, err
    }

    // Get database file size
    fileInfo, err := os.Stat(d.path)
    if err != nil {
        d.logger.Warn(
            "Failed to get database file size", "error", err)
        stats.FileSizeBytes = 0
    } else {
        stats.FileSizeBytes = fileInfo.Size()
    }

    // Get live routes count
    err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes").Scan(&stats.LiveRoutes)
    if err != nil {
        return stats, fmt.Errorf("failed to count live routes: %w", err)
    }

    // Get prefix distribution
    stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, err = d.GetPrefixDistributionContext(ctx)
    if err != nil {
        // Log but don't fail
        d.logger.Warn("Failed to get prefix distribution", "error", err)
    }

    return stats, nil
}

// UpsertLiveRoute inserts or updates a live route
func (d *Database) UpsertLiveRoute(route *LiveRoute) error {
    d.lock("UpsertLiveRoute")
    defer d.unlock()

    query := `
        INSERT INTO live_routes (id, prefix, mask_length, ip_version, origin_asn, peer_ip, as_path, next_hop, last_updated, v4_ip_start, v4_ip_end)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
            mask_length = excluded.mask_length,
            ip_version = excluded.ip_version,
            as_path = excluded.as_path,
            next_hop = excluded.next_hop,
            last_updated = excluded.last_updated,
            v4_ip_start = excluded.v4_ip_start,
            v4_ip_end = excluded.v4_ip_end
    `

    // Encode the AS path as JSON
    pathJSON, err := json.Marshal(route.ASPath)
    if err != nil {
        return fmt.Errorf("failed to encode AS path: %w", err)
    }

    // Convert v4_ip_start and v4_ip_end to interface{} for SQL NULL handling
    var v4Start, v4End interface{}
    if route.V4IPStart != nil {
        v4Start = *route.V4IPStart
    }
    if route.V4IPEnd != nil {
        v4End = *route.V4IPEnd
    }

    _, err = d.db.Exec(query,
        route.ID.String(),
        route.Prefix,
        route.MaskLength,
        route.IPVersion,
        route.OriginASN,
        route.PeerIP,
        string(pathJSON),
        route.NextHop,
        route.LastUpdated,
        v4Start,
        v4End,
    )

    return err
}

// DeleteLiveRoute deletes a live route.
// If originASN is 0, it deletes all routes for the prefix/peer combination.
func (d *Database) DeleteLiveRoute(prefix string, originASN int, peerIP string) error {
    d.lock("DeleteLiveRoute")
    defer d.unlock()

    var query string
    var err error
    if originASN == 0 {
        // Delete all routes for this prefix from this peer
        query = `DELETE FROM live_routes WHERE prefix = ? AND peer_ip = ?`
        _, err = d.db.Exec(query, prefix, peerIP)
    } else {
        // Delete the specific route
        query = `DELETE FROM live_routes WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`
        _, err = d.db.Exec(query, prefix, originASN, peerIP)
    }

    return err
}
// GetPrefixDistribution returns the distribution of unique prefixes by mask length
func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
    return d.GetPrefixDistributionContext(context.Background())
}

// GetPrefixDistributionContext returns the distribution of unique prefixes by mask length with context support
func (d *Database) GetPrefixDistributionContext(ctx context.Context) (
    ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
    // IPv4 distribution - count unique prefixes, not routes
    query := `
        SELECT mask_length, COUNT(DISTINCT prefix) as count
        FROM live_routes
        WHERE ip_version = 4
        GROUP BY mask_length
        ORDER BY mask_length
    `
    rows, err := d.db.QueryContext(ctx, query)
    if err != nil {
        return nil, nil, fmt.Errorf("failed to query IPv4 distribution: %w", err)
    }
    defer func() { _ = rows.Close() }()

    for rows.Next() {
        var dist PrefixDistribution
        if err := rows.Scan(&dist.MaskLength, &dist.Count); err != nil {
            return nil, nil, fmt.Errorf("failed to scan IPv4 distribution: %w", err)
        }
        ipv4 = append(ipv4, dist)
    }

    // IPv6 distribution - count unique prefixes, not routes
    query = `
        SELECT mask_length, COUNT(DISTINCT prefix) as count
        FROM live_routes
        WHERE ip_version = 6
        GROUP BY mask_length
        ORDER BY mask_length
    `
    rows, err = d.db.QueryContext(ctx, query)
    if err != nil {
        return nil, nil, fmt.Errorf("failed to query IPv6 distribution: %w", err)
    }
    defer func() { _ = rows.Close() }()

    for rows.Next() {
        var dist PrefixDistribution
        if err := rows.Scan(&dist.MaskLength, &dist.Count); err != nil {
            return nil, nil, fmt.Errorf("failed to scan IPv6 distribution: %w", err)
        }
        ipv6 = append(ipv6, dist)
    }

    return ipv4, ipv6, nil
}

// GetLiveRouteCounts returns the count of IPv4 and IPv6 routes
func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
    return d.GetLiveRouteCountsContext(context.Background())
}

// GetLiveRouteCountsContext returns the count of IPv4 and IPv6 routes with context support
func (d *Database) GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error) {
    // Get IPv4 count
    err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes WHERE ip_version = 4").Scan(&ipv4Count)
    if err != nil {
        return 0, 0, fmt.Errorf("failed to count IPv4 routes: %w", err)
    }

    // Get IPv6 count
    err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes WHERE ip_version = 6").Scan(&ipv6Count)
    if err != nil {
        return 0, 0, fmt.Errorf("failed to count IPv6 routes: %w", err)
    }

    return ipv4Count, ipv6Count, nil
}

// GetASInfoForIP returns AS information for the given IP address
func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
    return d.GetASInfoForIPContext(context.Background(), ip)
}

// GetASInfoForIPContext returns AS information for the given IP address with context support
func (d *Database) GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error) {
    // Parse the IP to validate it
    parsedIP := net.ParseIP(ip)
    if parsedIP == nil {
        return nil, fmt.Errorf("%w: %s", ErrInvalidIP, ip)
    }

    // Determine IP version
    ipVersion := ipVersionV4
    ipv4 := parsedIP.To4()
    if ipv4 == nil {
        ipVersion = ipVersionV6
    }

    // For IPv4, use the optimized range query
    if ipVersion == ipVersionV4 {
        // Convert the IP to a 32-bit unsigned integer
        ipUint := ipToUint32(ipv4)

        query := `
            SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
            FROM live_routes lr
            LEFT JOIN asns a ON
                a.number = lr.origin_asn
            WHERE lr.ip_version = ?
                AND lr.v4_ip_start <= ?
                AND lr.v4_ip_end >= ?
            ORDER BY lr.mask_length DESC
            LIMIT 1
        `

        var prefix string
        var maskLength, originASN int
        var lastUpdated time.Time
        var handle, description sql.NullString

        err := d.db.QueryRowContext(ctx, query, ipVersionV4, ipUint, ipUint).Scan(
            &prefix, &maskLength, &originASN, &lastUpdated, &handle, &description)
        if err != nil {
            if err == sql.ErrNoRows {
                return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
            }
            return nil, fmt.Errorf("failed to query routes: %w", err)
        }

        age := time.Since(lastUpdated).Round(time.Second).String()

        return &ASInfo{
            ASN:         originASN,
            Handle:      handle.String,
            Description: description.String,
            Prefix:      prefix,
            LastUpdated: lastUpdated,
            Age:         age,
        }, nil
    }

    // For IPv6, use the original method since we don't have a range optimization
    query := `
        SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
        FROM live_routes lr
        LEFT JOIN asns a ON a.number = lr.origin_asn
        WHERE lr.ip_version = ?
        ORDER BY lr.mask_length DESC
    `
    rows, err := d.db.QueryContext(ctx, query, ipVersionV6)
    if err != nil {
        return nil, fmt.Errorf("failed to query routes: %w", err)
    }
    defer func() { _ = rows.Close() }()

    // Find the most specific matching prefix
    var bestMatch struct {
        prefix      string
        maskLength  int
        originASN   int
        lastUpdated time.Time
        handle      sql.NullString
        description sql.NullString
    }
    bestMaskLength := -1

    for rows.Next() {
        var prefix string
        var maskLength, originASN int
        var lastUpdated time.Time
        var handle, description sql.NullString

        if err := rows.Scan(&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description); err != nil {
            continue
        }

        // Parse the prefix CIDR
        _, ipNet, err := net.ParseCIDR(prefix)
        if err != nil {
            continue
        }

        // Check if the IP is in this prefix and more specific than the current best match
        if ipNet.Contains(parsedIP) && maskLength > bestMaskLength {
            bestMatch.prefix = prefix
            bestMatch.maskLength = maskLength
            bestMatch.originASN = originASN
            bestMatch.lastUpdated = lastUpdated
            bestMatch.handle = handle
            bestMatch.description = description
            bestMaskLength = maskLength
        }
    }

    if bestMaskLength == -1 {
        return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
    }

    age := time.Since(bestMatch.lastUpdated).Round(time.Second).String()

    return &ASInfo{
        ASN:         bestMatch.originASN,
        Handle:      bestMatch.handle.String,
        Description: bestMatch.description.String,
        Prefix:      bestMatch.prefix,
        LastUpdated: bestMatch.lastUpdated,
        Age:         age,
    }, nil
}

// ipToUint32 converts an IPv4 address to a 32-bit unsigned integer
func ipToUint32(ip net.IP) uint32 {
    if len(ip) == ipv6Length {
        // Convert to the 4-byte representation
        ip = ip[ipv4Offset:ipv6Length]
    }

    return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}

// CalculateIPv4Range calculates the start and end IP addresses for an IPv4 CIDR block
func CalculateIPv4Range(cidr string) (start, end uint32, err error) {
    _, ipNet, err := net.ParseCIDR(cidr)
    if err != nil {
        return 0, 0, err
    }

    // Get the network address (start of range)
    ip := ipNet.IP.To4()
    if ip == nil {
        return 0, 0, fmt.Errorf("not an IPv4 address")
    }
    start = ipToUint32(ip)

    // Calculate the end of the range
    ones, bits := ipNet.Mask.Size()
    hostBits := bits - ones
    if hostBits >= ipv4Bits {
        // Special case for /0 - the entire IPv4 space
        end = maxIPv4
    } else {
        // Safe to convert since we checked hostBits < 32
        //nolint:gosec // hostBits is guaranteed to be < 32 from the check above
        end = start | ((1 << uint(hostBits)) - 1)
    }

    return start, end, nil
}
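// exampleCalculateIPv4RangeUsage is an illustrative sketch added purely as
// documentation; it is not part of the original package and is not called
// anywhere. It shows the expected CalculateIPv4Range output for a /8, with
// the numeric values derived from standard IPv4 arithmetic rather than
// captured from a run of this code.
func exampleCalculateIPv4RangeUsage() {
    start, end, err := CalculateIPv4Range("10.0.0.0/8")
    if err != nil {
        panic(err)
    }
    // 10.0.0.0 -> 167772160, 10.255.255.255 -> 184549375
    fmt.Println(start, end)
}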
// GetASDetails returns detailed information about an AS including prefixes
func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
    return d.GetASDetailsContext(context.Background(), asn)
}

// GetASDetailsContext returns detailed information about an AS including prefixes with context support
func (d *Database) GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error) {
    // Get AS information
    var asnInfo ASN
    var idStr string
    var handle, description sql.NullString
    err := d.db.QueryRowContext(ctx,
        "SELECT id, number, handle, description, first_seen, last_seen FROM asns WHERE number = ?", asn,
    ).Scan(&idStr, &asnInfo.Number, &handle, &description, &asnInfo.FirstSeen, &asnInfo.LastSeen)
    if err != nil {
        if err == sql.ErrNoRows {
            return nil, nil, fmt.Errorf("%w: AS%d", ErrNoRoute, asn)
        }
        return nil, nil, fmt.Errorf("failed to query AS: %w", err)
    }
    asnInfo.ID, _ = uuid.Parse(idStr)
    asnInfo.Handle = handle.String
    asnInfo.Description = description.String

    // Get prefixes announced by this AS (unique prefixes with most recent update)
    query := `
        SELECT prefix, mask_length, ip_version, MAX(last_updated) as last_updated
        FROM live_routes
        WHERE origin_asn = ?
        GROUP BY prefix, mask_length, ip_version
    `
    rows, err := d.db.QueryContext(ctx, query, asn)
    if err != nil {
        return &asnInfo, nil, fmt.Errorf("failed to query prefixes: %w", err)
    }
    defer func() { _ = rows.Close() }()

    var prefixes []LiveRoute
    for rows.Next() {
        var route LiveRoute
        var lastUpdatedStr string
        err := rows.Scan(&route.Prefix, &route.MaskLength, &route.IPVersion, &lastUpdatedStr)
        if err != nil {
            d.logger.Error("Failed to scan prefix row", "error", err, "asn", asn)
            continue
        }

        // Parse the timestamp string
        route.LastUpdated, err = time.Parse("2006-01-02 15:04:05-07:00", lastUpdatedStr)
        if err != nil {
            // Try without timezone
            route.LastUpdated, err = time.Parse("2006-01-02 15:04:05", lastUpdatedStr)
            if err != nil {
                d.logger.Error("Failed to parse timestamp", "error", err, "timestamp", lastUpdatedStr)
                continue
            }
        }

        route.OriginASN = asn
        prefixes = append(prefixes, route)
    }

    return &asnInfo, prefixes, nil
}

// GetPrefixDetails returns detailed information about a prefix
func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
    return d.GetPrefixDetailsContext(context.Background(), prefix)
}

// GetPrefixDetailsContext returns detailed information about a prefix with context support
func (d *Database) GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error) {
    query := `
        SELECT lr.origin_asn, lr.peer_ip, lr.as_path, lr.next_hop, lr.last_updated, a.handle, a.description
        FROM live_routes lr
        LEFT JOIN asns a ON a.number = lr.origin_asn
        WHERE lr.prefix = ?
        ORDER BY lr.origin_asn, lr.peer_ip
    `
    rows, err := d.db.QueryContext(ctx, query, prefix)
    if err != nil {
        return nil, fmt.Errorf("failed to query prefix details: %w", err)
    }
    defer func() { _ = rows.Close() }()

    var routes []LiveRoute
    for rows.Next() {
        var route LiveRoute
        var pathJSON string
        var handle, description sql.NullString
        err := rows.Scan(
            &route.OriginASN,
            &route.PeerIP,
            &pathJSON,
            &route.NextHop,
            &route.LastUpdated,
            &handle,
            &description,
        )
        if err != nil {
            continue
        }

        // Decode the AS path
        if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
            route.ASPath = []int{}
        }

        route.Prefix = prefix
        routes = append(routes, route)
    }

    if len(routes) == 0 {
        return nil, fmt.Errorf("%w: %s", ErrNoRoute, prefix)
    }

    return routes, nil
}

// GetRandomPrefixesByLength returns a random sample of prefixes with the specified mask length
func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error) {
    return d.GetRandomPrefixesByLengthContext(context.Background(), maskLength, ipVersion, limit)
}

// GetRandomPrefixesByLengthContext returns a random sample of prefixes with context support
func (d *Database) GetRandomPrefixesByLengthContext(
    ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error) {
    // Select unique prefixes with their most recent route information
    query := `
        WITH unique_prefixes AS (
            SELECT prefix, MAX(last_updated) as max_updated
            FROM live_routes
            WHERE mask_length = ? AND ip_version = ?
            GROUP BY prefix
            ORDER BY RANDOM()
            LIMIT ?
        )
        SELECT lr.prefix, lr.mask_length, lr.ip_version, lr.origin_asn, lr.as_path, lr.peer_ip, lr.last_updated
        FROM live_routes lr
        INNER JOIN unique_prefixes up ON lr.prefix = up.prefix AND lr.last_updated = up.max_updated
        WHERE lr.mask_length = ? AND lr.ip_version = ?
    `
    rows, err := d.db.QueryContext(ctx, query, maskLength, ipVersion, limit, maskLength, ipVersion)
    if err != nil {
        return nil, fmt.Errorf("failed to query random prefixes: %w", err)
    }
    defer func() { _ = rows.Close() }()

    var routes []LiveRoute
    for rows.Next() {
        var route LiveRoute
        var pathJSON string
        err := rows.Scan(
            &route.Prefix,
            &route.MaskLength,
            &route.IPVersion,
            &route.OriginASN,
            &pathJSON,
            &route.PeerIP,
            &route.LastUpdated,
        )
        if err != nil {
            continue
        }

        // Decode the AS path
        if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
            route.ASPath = []int{}
        }

        routes = append(routes, route)
    }

    return routes, nil
}
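// The following caller-side sketch is illustrative documentation only, not
// code from this repository. The LiveRoute field names and types are inferred
// from how the struct is used above; in particular V4IPStart/V4IPEnd are
// assumed to be *uint32 holding the range produced by CalculateIPv4Range, and
// the ASN/IP values are documentation examples. Check the actual struct
// definition elsewhere in this package before relying on it.
//
//	start, end, err := CalculateIPv4Range("192.0.2.0/24")
//	if err != nil {
//		return err
//	}
//	route := &LiveRoute{
//		ID:          generateUUID(),
//		Prefix:      "192.0.2.0/24",
//		MaskLength:  24,
//		IPVersion:   4,
//		OriginASN:   64500,
//		PeerIP:      "203.0.113.1",
//		ASPath:      []int{64496, 64500},
//		NextHop:     "203.0.113.1",
//		LastUpdated: time.Now(),
//		V4IPStart:   &start,
//		V4IPEnd:     &end,
//	}
//	if err := db.UpsertLiveRouteBatch([]*LiveRoute{route}); err != nil {
//		return err
//	}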