routewatch/internal/database/database.go
// Package database provides SQLite storage for BGP routing data including ASNs, prefixes, announcements and peerings.
package database
import (
"context"
"database/sql"
_ "embed"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"sync"
"time"
"git.eeqj.de/sneak/routewatch/internal/config"
"git.eeqj.de/sneak/routewatch/internal/logger"
"git.eeqj.de/sneak/routewatch/pkg/asinfo"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3" // CGO SQLite driver
)
// IMPORTANT: NO schema changes are to be made outside of schema.sql
// We do NOT support migrations. All schema changes MUST be made in schema.sql only.
//
//go:embed schema.sql
var dbSchema string
const (
dirPermissions = 0750 // rwxr-x---
ipVersionV4 = 4 // IP version identifier for IPv4
ipVersionV6 = 6 // IP version identifier for IPv6
ipv6Length = 16 // length of a net.IP in its canonical 16-byte form
ipv4Offset = 12 // offset of the IPv4 octets within a 16-byte net.IP
ipv4Bits = 32 // number of bits in an IPv4 address
maxIPv4 = 0xFFFFFFFF // largest IPv4 address as a uint32
)
// Common errors
var (
// ErrInvalidIP is returned when an IP address is malformed
ErrInvalidIP = errors.New("invalid IP address")
// ErrNoRoute is returned when no route is found for an IP
ErrNoRoute = errors.New("no route found")
)
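// Both sentinel errors are wrapped with %w by the lookup methods below, so callers
// can test for them with errors.Is, e.g. errors.Is(err, ErrNoRoute).
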
// Database manages the SQLite database connection and operations.
type Database struct {
db *sql.DB
logger *logger.Logger
path string // filesystem path of the SQLite database file
mu sync.Mutex // serializes write operations
lockedAt time.Time // when mu was last acquired (for debug logging)
lockedBy string // operation and caller currently holding mu (for debug logging)
}
// New creates a new database connection and initializes the schema.
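// A minimal wiring sketch (hedged: config.New and logger.New are assumed constructors
// provided elsewhere in this repository, not shown in this file):
//
//	cfg, err := config.New()
//	log := logger.New()
//	db, err := database.New(cfg, log)
//	if err != nil { /* handle error */ }
//	defer db.Close()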
func New(cfg *config.Config, logger *logger.Logger) (*Database, error) {
dbPath := filepath.Join(cfg.GetStateDir(), "db.sqlite")
// Log database path
logger.Info("Opening database", "path", dbPath)
// Ensure directory exists
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, dirPermissions); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
// Build the SQLite DSN; connection pragmas are applied in Initialize after opening
dsn := fmt.Sprintf(
"file:%s",
dbPath,
)
db, err := sql.Open("sqlite3", dsn)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
if err := db.Ping(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
// Set connection pool parameters
// Multiple connections allow concurrent reads while writes are serialized
const maxConns = 10
db.SetMaxOpenConns(maxConns)
db.SetMaxIdleConns(maxConns)
db.SetConnMaxLifetime(0)
database := &Database{db: db, logger: logger, path: dbPath}
if err := database.Initialize(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to initialize database: %w", err)
}
return database, nil
}
// Initialize applies SQLite pragmas, creates the schema if it doesn't exist, and runs VACUUM.
func (d *Database) Initialize() error {
// Set SQLite pragmas for performance
pragmas := []string{
"PRAGMA journal_mode=WAL", // Write-Ahead Logging
"PRAGMA synchronous=OFF", // Don't wait for disk writes
"PRAGMA cache_size=-3145728", // 3GB cache (upper limit for 2.4GB DB)
"PRAGMA temp_store=MEMORY", // Use memory for temp tables
"PRAGMA wal_checkpoint(TRUNCATE)", // Checkpoint and truncate WAL now
"PRAGMA busy_timeout=5000", // 5 second busy timeout
"PRAGMA analysis_limit=0", // Disable automatic ANALYZE
}
for _, pragma := range pragmas {
if err := d.exec(pragma); err != nil {
d.logger.Warn("Failed to set pragma", "pragma", pragma, "error", err)
}
}
err := d.exec(dbSchema)
if err != nil {
return err
}
// Run VACUUM on startup to optimize database
d.logger.Info("Running VACUUM to optimize database (this may take a moment)")
if err := d.exec("VACUUM"); err != nil {
d.logger.Warn("Failed to VACUUM database", "error", err)
}
return nil
}
// Close closes the database connection.
func (d *Database) Close() error {
return d.db.Close()
}
// lock acquires the database mutex and logs debug information
func (d *Database) lock(operation string) {
// Get caller information
_, file, line, _ := runtime.Caller(1)
caller := fmt.Sprintf("%s:%d", filepath.Base(file), line)
d.logger.Debug("Acquiring database lock", "operation", operation, "caller", caller)
d.mu.Lock()
d.lockedAt = time.Now()
d.lockedBy = fmt.Sprintf("%s (%s)", operation, caller)
d.logger.Debug("Database lock acquired", "operation", operation, "caller", caller)
}
// unlock releases the database mutex and logs debug information including hold duration
func (d *Database) unlock() {
holdDuration := time.Since(d.lockedAt)
lockedBy := d.lockedBy
d.lockedAt = time.Time{}
d.lockedBy = ""
d.mu.Unlock()
d.logger.Debug("Database lock released", "held_by", lockedBy, "duration_ms", holdDuration.Milliseconds())
}
// beginTx starts a new transaction with logging
func (d *Database) beginTx() (*loggingTx, error) {
tx, err := d.db.Begin()
if err != nil {
return nil, err
}
return &loggingTx{Tx: tx, logger: d.logger}, nil
}
// UpsertLiveRouteBatch inserts or updates multiple live routes in a single transaction
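// Example (a sketch: LiveRoute is defined elsewhere in this package; V4IPStart and V4IPEnd
// are assumed to be *uint32, as implied by CalculateIPv4Range and the ip_start/ip_end columns):
//
//	start, end, _ := CalculateIPv4Range("192.0.2.0/24")
//	route := &LiveRoute{
//		ID: generateUUID(), Prefix: "192.0.2.0/24", MaskLength: 24, IPVersion: 4,
//		OriginASN: 64496, PeerIP: "198.51.100.1", ASPath: []int{64500, 64496},
//		NextHop: "198.51.100.1", LastUpdated: time.Now(),
//		V4IPStart: &start, V4IPEnd: &end,
//	}
//	err := db.UpsertLiveRouteBatch([]*LiveRoute{route})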
func (d *Database) UpsertLiveRouteBatch(routes []*LiveRoute) error {
if len(routes) == 0 {
return nil
}
d.lock("UpsertLiveRouteBatch")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
// Prepare statements for both IPv4 and IPv6
queryV4 := `
INSERT INTO live_routes_v4 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
last_updated, ip_start, ip_end)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
mask_length = excluded.mask_length,
as_path = excluded.as_path,
next_hop = excluded.next_hop,
last_updated = excluded.last_updated,
ip_start = excluded.ip_start,
ip_end = excluded.ip_end
`
queryV6 := `
INSERT INTO live_routes_v6 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
last_updated)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
mask_length = excluded.mask_length,
as_path = excluded.as_path,
next_hop = excluded.next_hop,
last_updated = excluded.last_updated
`
stmtV4, err := tx.Prepare(queryV4)
if err != nil {
return fmt.Errorf("failed to prepare IPv4 statement: %w", err)
}
defer func() { _ = stmtV4.Close() }()
stmtV6, err := tx.Prepare(queryV6)
if err != nil {
return fmt.Errorf("failed to prepare IPv6 statement: %w", err)
}
defer func() { _ = stmtV6.Close() }()
for _, route := range routes {
// Encode AS path as JSON
pathJSON, err := json.Marshal(route.ASPath)
if err != nil {
return fmt.Errorf("failed to encode AS path: %w", err)
}
// Use appropriate statement based on IP version
if route.IPVersion == ipVersionV4 {
// IPv4 routes must have range values
if route.V4IPStart == nil || route.V4IPEnd == nil {
return fmt.Errorf("IPv4 route %s missing range values", route.Prefix)
}
_, err = stmtV4.Exec(
route.ID.String(),
route.Prefix,
route.MaskLength,
route.OriginASN,
route.PeerIP,
string(pathJSON),
route.NextHop,
route.LastUpdated,
*route.V4IPStart,
*route.V4IPEnd,
)
} else {
// IPv6 routes
_, err = stmtV6.Exec(
route.ID.String(),
route.Prefix,
route.MaskLength,
route.OriginASN,
route.PeerIP,
string(pathJSON),
route.NextHop,
route.LastUpdated,
)
}
if err != nil {
return fmt.Errorf("failed to upsert route %s: %w", route.Prefix, err)
}
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// DeleteLiveRouteBatch deletes multiple live routes in a single transaction
func (d *Database) DeleteLiveRouteBatch(deletions []LiveRouteDeletion) error {
if len(deletions) == 0 {
return nil
}
d.lock("DeleteLiveRouteBatch")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
// Prepare delete statements for both IPv4 and IPv6 tables, with and without an origin ASN match
stmtV4WithOrigin, err := tx.Prepare(`DELETE FROM live_routes_v4 WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`)
if err != nil {
return fmt.Errorf("failed to prepare IPv4 delete with origin statement: %w", err)
}
defer func() { _ = stmtV4WithOrigin.Close() }()
stmtV4WithoutOrigin, err := tx.Prepare(`DELETE FROM live_routes_v4 WHERE prefix = ? AND peer_ip = ?`)
if err != nil {
return fmt.Errorf("failed to prepare IPv4 delete without origin statement: %w", err)
}
defer func() { _ = stmtV4WithoutOrigin.Close() }()
stmtV6WithOrigin, err := tx.Prepare(`DELETE FROM live_routes_v6 WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`)
if err != nil {
return fmt.Errorf("failed to prepare IPv6 delete with origin statement: %w", err)
}
defer func() { _ = stmtV6WithOrigin.Close() }()
stmtV6WithoutOrigin, err := tx.Prepare(`DELETE FROM live_routes_v6 WHERE prefix = ? AND peer_ip = ?`)
if err != nil {
return fmt.Errorf("failed to prepare IPv6 delete without origin statement: %w", err)
}
defer func() { _ = stmtV6WithoutOrigin.Close() }()
// Process deletions
for _, del := range deletions {
var stmt *sql.Stmt
// Select appropriate statement based on IP version and whether we have origin ASN
//nolint:nestif // Clear logic for selecting the right statement
if del.IPVersion == ipVersionV4 {
if del.OriginASN == 0 {
stmt = stmtV4WithoutOrigin
} else {
stmt = stmtV4WithOrigin
}
} else {
if del.OriginASN == 0 {
stmt = stmtV6WithoutOrigin
} else {
stmt = stmtV6WithOrigin
}
}
// Execute deletion
if del.OriginASN == 0 {
_, err = stmt.Exec(del.Prefix, del.PeerIP)
} else {
_, err = stmt.Exec(del.Prefix, del.OriginASN, del.PeerIP)
}
if err != nil {
return fmt.Errorf("failed to delete route %s: %w", del.Prefix, err)
}
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// UpdatePrefixesBatch updates the last_seen time for multiple prefixes in a single transaction
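// Prefixes are routed to prefixes_v4 or prefixes_v6 based on detectIPVersion, and missing
// rows are created. Example:
//
//	err := db.UpdatePrefixesBatch(map[string]time.Time{"2001:db8::/32": time.Now()})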
func (d *Database) UpdatePrefixesBatch(prefixes map[string]time.Time) error {
if len(prefixes) == 0 {
return nil
}
d.lock("UpdatePrefixesBatch")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
// Prepare statements for both IPv4 and IPv6 tables
selectV4Stmt, err := tx.Prepare("SELECT id FROM prefixes_v4 WHERE prefix = ?")
if err != nil {
return fmt.Errorf("failed to prepare IPv4 select statement: %w", err)
}
defer func() { _ = selectV4Stmt.Close() }()
updateV4Stmt, err := tx.Prepare("UPDATE prefixes_v4 SET last_seen = ? WHERE prefix = ?")
if err != nil {
return fmt.Errorf("failed to prepare IPv4 update statement: %w", err)
}
defer func() { _ = updateV4Stmt.Close() }()
insertV4Stmt, err := tx.Prepare("INSERT INTO prefixes_v4 (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)")
if err != nil {
return fmt.Errorf("failed to prepare IPv4 insert statement: %w", err)
}
defer func() { _ = insertV4Stmt.Close() }()
selectV6Stmt, err := tx.Prepare("SELECT id FROM prefixes_v6 WHERE prefix = ?")
if err != nil {
return fmt.Errorf("failed to prepare IPv6 select statement: %w", err)
}
defer func() { _ = selectV6Stmt.Close() }()
updateV6Stmt, err := tx.Prepare("UPDATE prefixes_v6 SET last_seen = ? WHERE prefix = ?")
if err != nil {
return fmt.Errorf("failed to prepare IPv6 update statement: %w", err)
}
defer func() { _ = updateV6Stmt.Close() }()
insertV6Stmt, err := tx.Prepare("INSERT INTO prefixes_v6 (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)")
if err != nil {
return fmt.Errorf("failed to prepare IPv6 insert statement: %w", err)
}
defer func() { _ = insertV6Stmt.Close() }()
for prefix, timestamp := range prefixes {
ipVersion := detectIPVersion(prefix)
var selectStmt, updateStmt, insertStmt *sql.Stmt
if ipVersion == ipVersionV4 {
selectStmt, updateStmt, insertStmt = selectV4Stmt, updateV4Stmt, insertV4Stmt
} else {
selectStmt, updateStmt, insertStmt = selectV6Stmt, updateV6Stmt, insertV6Stmt
}
var id string
err = selectStmt.QueryRow(prefix).Scan(&id)
switch err {
case nil:
// Prefix exists, update last_seen
_, err = updateStmt.Exec(timestamp, prefix)
if err != nil {
return fmt.Errorf("failed to update prefix %s: %w", prefix, err)
}
case sql.ErrNoRows:
// Prefix doesn't exist, create it
_, err = insertStmt.Exec(generateUUID().String(), prefix, timestamp, timestamp)
if err != nil {
return fmt.Errorf("failed to insert prefix %s: %w", prefix, err)
}
default:
return fmt.Errorf("failed to query prefix %s: %w", prefix, err)
}
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// GetOrCreateASNBatch creates or updates multiple ASNs in a single transaction
func (d *Database) GetOrCreateASNBatch(asns map[int]time.Time) error {
if len(asns) == 0 {
return nil
}
d.lock("GetOrCreateASNBatch")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
// Prepare statements
selectStmt, err := tx.Prepare(
"SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?")
if err != nil {
return fmt.Errorf("failed to prepare select statement: %w", err)
}
defer func() { _ = selectStmt.Close() }()
updateStmt, err := tx.Prepare("UPDATE asns SET last_seen = ? WHERE asn = ?")
if err != nil {
return fmt.Errorf("failed to prepare update statement: %w", err)
}
defer func() { _ = updateStmt.Close() }()
insertStmt, err := tx.Prepare(
"INSERT INTO asns (asn, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?)")
if err != nil {
return fmt.Errorf("failed to prepare insert statement: %w", err)
}
defer func() { _ = insertStmt.Close() }()
for number, timestamp := range asns {
var asn ASN
var handle, description sql.NullString
err = selectStmt.QueryRow(number).Scan(&asn.ASN, &handle, &description, &asn.FirstSeen, &asn.LastSeen)
if err == nil {
// ASN exists, update last_seen
_, err = updateStmt.Exec(timestamp, number)
if err != nil {
return fmt.Errorf("failed to update ASN %d: %w", number, err)
}
continue
}
if err == sql.ErrNoRows {
// ASN doesn't exist, create it
asn = ASN{
ASN: number,
FirstSeen: timestamp,
LastSeen: timestamp,
}
// Look up ASN info
if info, ok := asinfo.Get(number); ok {
asn.Handle = info.Handle
asn.Description = info.Description
}
_, err = insertStmt.Exec(asn.ASN, asn.Handle, asn.Description, asn.FirstSeen, asn.LastSeen)
if err != nil {
return fmt.Errorf("failed to insert ASN %d: %w", number, err)
}
continue
}
if err != nil {
return fmt.Errorf("failed to query ASN %d: %w", number, err)
}
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// GetOrCreateASN retrieves an existing ASN or creates a new one if it doesn't exist.
func (d *Database) GetOrCreateASN(number int, timestamp time.Time) (*ASN, error) {
d.lock("GetOrCreateASN")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return nil, err
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
var asn ASN
var handle, description sql.NullString
err = tx.QueryRow("SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?", number).
Scan(&asn.ASN, &handle, &description, &asn.FirstSeen, &asn.LastSeen)
if err == nil {
// ASN exists, update last_seen
asn.Handle = handle.String
asn.Description = description.String
_, err = tx.Exec("UPDATE asns SET last_seen = ? WHERE asn = ?", timestamp, number)
if err != nil {
return nil, err
}
asn.LastSeen = timestamp
if err = tx.Commit(); err != nil {
d.logger.Error("Failed to commit transaction for ASN update", "asn", number, "error", err)
return nil, err
}
return &asn, nil
}
if err != sql.ErrNoRows {
return nil, err
}
// ASN doesn't exist, create it with ASN info lookup
asn = ASN{
ASN: number,
FirstSeen: timestamp,
LastSeen: timestamp,
}
// Look up ASN info
if info, ok := asinfo.Get(number); ok {
asn.Handle = info.Handle
asn.Description = info.Description
}
_, err = tx.Exec("INSERT INTO asns (asn, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?)",
asn.ASN, asn.Handle, asn.Description, asn.FirstSeen, asn.LastSeen)
if err != nil {
return nil, err
}
if err = tx.Commit(); err != nil {
d.logger.Error("Failed to commit transaction for ASN creation", "asn", number, "error", err)
return nil, err
}
return &asn, nil
}
// GetOrCreatePrefix retrieves an existing prefix or creates a new one if it doesn't exist.
func (d *Database) GetOrCreatePrefix(prefix string, timestamp time.Time) (*Prefix, error) {
d.lock("GetOrCreatePrefix")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return nil, err
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
// Determine table based on IP version
ipVersion := detectIPVersion(prefix)
tableName := "prefixes_v4"
if ipVersion == ipVersionV6 {
tableName = "prefixes_v6"
}
var p Prefix
var idStr string
query := fmt.Sprintf("SELECT id, prefix, first_seen, last_seen FROM %s WHERE prefix = ?", tableName)
err = tx.QueryRow(query, prefix).
Scan(&idStr, &p.Prefix, &p.FirstSeen, &p.LastSeen)
if err == nil {
// Prefix exists, update last_seen
p.ID, _ = uuid.Parse(idStr)
p.IPVersion = ipVersion
updateQuery := fmt.Sprintf("UPDATE %s SET last_seen = ? WHERE id = ?", tableName)
_, err = tx.Exec(updateQuery, timestamp, p.ID.String())
if err != nil {
return nil, err
}
p.LastSeen = timestamp
if err = tx.Commit(); err != nil {
d.logger.Error("Failed to commit transaction for prefix update", "prefix", prefix, "error", err)
return nil, err
}
return &p, nil
}
if err != sql.ErrNoRows {
return nil, err
}
// Prefix doesn't exist, create it
p = Prefix{
ID: generateUUID(),
Prefix: prefix,
IPVersion: ipVersion,
FirstSeen: timestamp,
LastSeen: timestamp,
}
insertQuery := fmt.Sprintf("INSERT INTO %s (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)", tableName)
_, err = tx.Exec(insertQuery, p.ID.String(), p.Prefix, p.FirstSeen, p.LastSeen)
if err != nil {
return nil, err
}
if err = tx.Commit(); err != nil {
d.logger.Error("Failed to commit transaction for prefix creation", "prefix", prefix, "error", err)
return nil, err
}
return &p, nil
}
// RecordAnnouncement inserts a new BGP announcement or withdrawal into the database.
func (d *Database) RecordAnnouncement(announcement *Announcement) error {
d.lock("RecordAnnouncement")
defer d.unlock()
err := d.exec(`
INSERT INTO announcements (id, prefix_id, peer_asn, origin_asn, path, next_hop, timestamp, is_withdrawal)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
announcement.ID.String(), announcement.PrefixID.String(),
announcement.PeerASN, announcement.OriginASN,
announcement.Path, announcement.NextHop, announcement.Timestamp, announcement.IsWithdrawal)
return err
}
// RecordPeering records a peering relationship between two ASNs.
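// The pair is normalized so that as_a < as_b before storage, so RecordPeering(64512, 64500, ts)
// and RecordPeering(64500, 64512, ts) update the same row.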
func (d *Database) RecordPeering(asA, asB int, timestamp time.Time) error {
// Validate ASNs
if asA <= 0 || asB <= 0 {
return fmt.Errorf("invalid ASN: asA=%d, asB=%d", asA, asB)
}
if asA == asB {
return fmt.Errorf("cannot create peering with same ASN: %d", asA)
}
// Normalize: ensure asA < asB
if asA > asB {
asA, asB = asB, asA
}
d.lock("RecordPeering")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return err
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
var exists bool
err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM peerings WHERE as_a = ? AND as_b = ?)",
asA, asB).Scan(&exists)
if err != nil {
return err
}
if exists {
_, err = tx.Exec("UPDATE peerings SET last_seen = ? WHERE as_a = ? AND as_b = ?",
timestamp, asA, asB)
} else {
_, err = tx.Exec(`
INSERT INTO peerings (id, as_a, as_b, first_seen, last_seen)
VALUES (?, ?, ?, ?, ?)`,
generateUUID().String(), asA, asB, timestamp, timestamp)
}
if err != nil {
return err
}
if err = tx.Commit(); err != nil {
d.logger.Error("Failed to commit transaction for peering",
"as_a", asA,
"as_b", asB,
"error", err,
)
return err
}
return nil
}
// UpdatePeerBatch updates or creates multiple BGP peer records in a single transaction
func (d *Database) UpdatePeerBatch(peers map[string]PeerUpdate) error {
if len(peers) == 0 {
return nil
}
d.lock("UpdatePeerBatch")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
// Prepare statements
checkStmt, err := tx.Prepare("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)")
if err != nil {
return fmt.Errorf("failed to prepare check statement: %w", err)
}
defer func() { _ = checkStmt.Close() }()
updateStmt, err := tx.Prepare(
"UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?")
if err != nil {
return fmt.Errorf("failed to prepare update statement: %w", err)
}
defer func() { _ = updateStmt.Close() }()
insertStmt, err := tx.Prepare(
"INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)")
if err != nil {
return fmt.Errorf("failed to prepare insert statement: %w", err)
}
defer func() { _ = insertStmt.Close() }()
for _, update := range peers {
var exists bool
err = checkStmt.QueryRow(update.PeerIP).Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check peer %s: %w", update.PeerIP, err)
}
if exists {
_, err = updateStmt.Exec(update.PeerASN, update.Timestamp, update.MessageType, update.PeerIP)
} else {
_, err = insertStmt.Exec(
generateUUID().String(), update.PeerIP, update.PeerASN,
update.Timestamp, update.Timestamp, update.MessageType)
}
if err != nil {
return fmt.Errorf("failed to update peer %s: %w", update.PeerIP, err)
}
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
// UpdatePeer updates or creates a BGP peer record
func (d *Database) UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error {
d.lock("UpdatePeer")
defer d.unlock()
tx, err := d.beginTx()
if err != nil {
return err
}
defer func() {
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
d.logger.Error("Failed to rollback transaction", "error", err)
}
}()
var exists bool
err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)", peerIP).Scan(&exists)
if err != nil {
return err
}
if exists {
_, err = tx.Exec(
"UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?",
peerASN, timestamp, messageType, peerIP,
)
} else {
_, err = tx.Exec(
"INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)",
generateUUID().String(), peerIP, peerASN, timestamp, timestamp, messageType,
)
}
if err != nil {
return err
}
if err = tx.Commit(); err != nil {
d.logger.Error("Failed to commit transaction for peer update",
"peer_ip", peerIP,
"peer_asn", peerASN,
"error", err,
)
return err
}
return nil
}
// GetStats returns database statistics
func (d *Database) GetStats() (Stats, error) {
return d.GetStatsContext(context.Background())
}
// GetStatsContext returns database statistics with context support
func (d *Database) GetStatsContext(ctx context.Context) (Stats, error) {
var stats Stats
// Count ASNs
err := d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
if err != nil {
return stats, err
}
// Count prefixes from both tables
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM prefixes_v4").Scan(&stats.IPv4Prefixes)
if err != nil {
return stats, err
}
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM prefixes_v6").Scan(&stats.IPv6Prefixes)
if err != nil {
return stats, err
}
stats.Prefixes = stats.IPv4Prefixes + stats.IPv6Prefixes
// Count peerings
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM peerings").Scan(&stats.Peerings)
if err != nil {
return stats, err
}
// Count peers
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM bgp_peers").Scan(&stats.Peers)
if err != nil {
return stats, err
}
// Get database file size
fileInfo, err := os.Stat(d.path)
if err != nil {
d.logger.Warn("Failed to get database file size", "error", err)
stats.FileSizeBytes = 0
} else {
stats.FileSizeBytes = fileInfo.Size()
}
// Get live routes count from both tables
var v4Count, v6Count int
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v4").Scan(&v4Count)
if err != nil {
return stats, fmt.Errorf("failed to count IPv4 routes: %w", err)
}
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v6").Scan(&v6Count)
if err != nil {
return stats, fmt.Errorf("failed to count IPv6 routes: %w", err)
}
stats.LiveRoutes = v4Count + v6Count
// Get prefix distribution
stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, err = d.GetPrefixDistributionContext(ctx)
if err != nil {
// Log but don't fail
d.logger.Warn("Failed to get prefix distribution", "error", err)
}
return stats, nil
}
// UpsertLiveRoute inserts or updates a live route
func (d *Database) UpsertLiveRoute(route *LiveRoute) error {
d.lock("UpsertLiveRoute")
defer d.unlock()
// Choose table based on IP version
tableName := "live_routes_v4"
if route.IPVersion == ipVersionV6 {
tableName = "live_routes_v6"
}
var query string
if route.IPVersion == ipVersionV4 {
query = fmt.Sprintf(`
INSERT INTO %s (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
last_updated, ip_start, ip_end)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
mask_length = excluded.mask_length,
as_path = excluded.as_path,
next_hop = excluded.next_hop,
last_updated = excluded.last_updated,
ip_start = excluded.ip_start,
ip_end = excluded.ip_end
`, tableName)
} else {
query = fmt.Sprintf(`
INSERT INTO %s (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
last_updated)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
mask_length = excluded.mask_length,
as_path = excluded.as_path,
next_hop = excluded.next_hop,
last_updated = excluded.last_updated
`, tableName)
}
// Encode AS path as JSON
pathJSON, err := json.Marshal(route.ASPath)
if err != nil {
return fmt.Errorf("failed to encode AS path: %w", err)
}
if route.IPVersion == ipVersionV4 {
// Convert v4_ip_start and v4_ip_end to interface{} for SQL NULL handling
var v4Start, v4End interface{}
if route.V4IPStart != nil {
v4Start = *route.V4IPStart
}
if route.V4IPEnd != nil {
v4End = *route.V4IPEnd
}
_, err = d.db.Exec(query,
route.ID.String(),
route.Prefix,
route.MaskLength,
route.OriginASN,
route.PeerIP,
string(pathJSON),
route.NextHop,
route.LastUpdated,
v4Start,
v4End,
)
} else {
_, err = d.db.Exec(query,
route.ID.String(),
route.Prefix,
route.MaskLength,
route.OriginASN,
route.PeerIP,
string(pathJSON),
route.NextHop,
route.LastUpdated,
)
}
return err
}
// DeleteLiveRoute deletes a live route
// If originASN is 0, deletes all routes for the prefix/peer combination
func (d *Database) DeleteLiveRoute(prefix string, originASN int, peerIP string) error {
d.lock("DeleteLiveRoute")
defer d.unlock()
// Determine table based on prefix IP version
_, ipnet, err := net.ParseCIDR(prefix)
if err != nil {
return fmt.Errorf("invalid prefix format: %w", err)
}
tableName := "live_routes_v4"
if ipnet.IP.To4() == nil {
tableName = "live_routes_v6"
}
var query string
if originASN == 0 {
// Delete all routes for this prefix from this peer
query = fmt.Sprintf(`DELETE FROM %s WHERE prefix = ? AND peer_ip = ?`, tableName)
_, err = d.db.Exec(query, prefix, peerIP)
} else {
// Delete specific route
query = fmt.Sprintf(`DELETE FROM %s WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`, tableName)
_, err = d.db.Exec(query, prefix, originASN, peerIP)
}
return err
}
// GetPrefixDistribution returns the distribution of unique prefixes by mask length
func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
return d.GetPrefixDistributionContext(context.Background())
}
// GetPrefixDistributionContext returns the distribution of unique prefixes by mask length with context support
func (d *Database) GetPrefixDistributionContext(ctx context.Context) (
ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
// IPv4 distribution - count unique prefixes from v4 table
query := `
SELECT mask_length, COUNT(DISTINCT prefix) as count
FROM live_routes_v4
GROUP BY mask_length
ORDER BY mask_length
`
rows4, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, nil, fmt.Errorf("failed to query IPv4 distribution: %w", err)
}
defer func() {
if rows4 != nil {
_ = rows4.Close()
}
}()
for rows4.Next() {
var dist PrefixDistribution
if err := rows4.Scan(&dist.MaskLength, &dist.Count); err != nil {
return nil, nil, fmt.Errorf("failed to scan IPv4 distribution: %w", err)
}
ipv4 = append(ipv4, dist)
}
// IPv6 distribution - count unique prefixes from v6 table
query = `
SELECT mask_length, COUNT(DISTINCT prefix) as count
FROM live_routes_v6
GROUP BY mask_length
ORDER BY mask_length
`
rows6, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, nil, fmt.Errorf("failed to query IPv6 distribution: %w", err)
}
defer func() {
if rows6 != nil {
_ = rows6.Close()
}
}()
for rows6.Next() {
var dist PrefixDistribution
if err := rows6.Scan(&dist.MaskLength, &dist.Count); err != nil {
return nil, nil, fmt.Errorf("failed to scan IPv6 distribution: %w", err)
}
ipv6 = append(ipv6, dist)
}
return ipv4, ipv6, nil
}
// GetLiveRouteCounts returns the count of IPv4 and IPv6 routes
func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
return d.GetLiveRouteCountsContext(context.Background())
}
// GetLiveRouteCountsContext returns the count of IPv4 and IPv6 routes with context support
func (d *Database) GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error) {
// Get IPv4 count from dedicated table
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v4").Scan(&ipv4Count)
if err != nil {
return 0, 0, fmt.Errorf("failed to count IPv4 routes: %w", err)
}
// Get IPv6 count from dedicated table
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v6").Scan(&ipv6Count)
if err != nil {
return 0, 0, fmt.Errorf("failed to count IPv6 routes: %w", err)
}
return ipv4Count, ipv6Count, nil
}
// GetASInfoForIP returns AS information for the given IP address
func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
return d.GetASInfoForIPContext(context.Background(), ip)
}
// GetASInfoForIPContext returns AS information for the given IP address with context support
func (d *Database) GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error) {
// Parse the IP to validate it
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return nil, fmt.Errorf("%w: %s", ErrInvalidIP, ip)
}
// Determine IP version
ipVersion := ipVersionV4
ipv4 := parsedIP.To4()
if ipv4 == nil {
ipVersion = ipVersionV6
}
// For IPv4, use optimized range query
if ipVersion == ipVersionV4 {
// Convert IP to 32-bit unsigned integer
ipUint := ipToUint32(ipv4)
query := `
SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
FROM live_routes_v4 lr
LEFT JOIN asns a ON a.asn = lr.origin_asn
WHERE lr.ip_start <= ? AND lr.ip_end >= ?
ORDER BY lr.mask_length DESC
LIMIT 1
`
var prefix string
var maskLength, originASN int
var lastUpdated time.Time
var handle, description sql.NullString
err := d.db.QueryRowContext(ctx, query, ipUint, ipUint).Scan(
&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
}
return nil, fmt.Errorf("failed to query routes: %w", err)
}
age := time.Since(lastUpdated).Round(time.Second).String()
return &ASInfo{
ASN: originASN,
Handle: handle.String,
Description: description.String,
Prefix: prefix,
LastUpdated: lastUpdated,
Age: age,
}, nil
}
// For IPv6 there is no range index, so scan all v6 routes and pick the most specific matching prefix
query := `
SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
FROM live_routes_v6 lr
LEFT JOIN asns a ON a.asn = lr.origin_asn
ORDER BY lr.mask_length DESC
`
rows, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query routes: %w", err)
}
defer func() { _ = rows.Close() }()
// Find the most specific matching prefix
var bestMatch struct {
prefix string
maskLength int
originASN int
lastUpdated time.Time
handle sql.NullString
description sql.NullString
}
bestMaskLength := -1
for rows.Next() {
var prefix string
var maskLength, originASN int
var lastUpdated time.Time
var handle, description sql.NullString
if err := rows.Scan(&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description); err != nil {
continue
}
// Parse the prefix CIDR
_, ipNet, err := net.ParseCIDR(prefix)
if err != nil {
continue
}
// Check if the IP is in this prefix
if ipNet.Contains(parsedIP) && maskLength > bestMaskLength {
bestMatch.prefix = prefix
bestMatch.maskLength = maskLength
bestMatch.originASN = originASN
bestMatch.lastUpdated = lastUpdated
bestMatch.handle = handle
bestMatch.description = description
bestMaskLength = maskLength
}
}
if bestMaskLength == -1 {
return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
}
age := time.Since(bestMatch.lastUpdated).Round(time.Second).String()
return &ASInfo{
ASN: bestMatch.originASN,
Handle: bestMatch.handle.String,
Description: bestMatch.description.String,
Prefix: bestMatch.prefix,
LastUpdated: bestMatch.lastUpdated,
Age: age,
}, nil
}
// ipToUint32 converts an IPv4 address to a 32-bit unsigned integer
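// For example, net.ParseIP("192.0.2.1") yields a 16-byte net.IP, so the last four octets
// are taken and the result is 0xC0000201 (3221225985).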
func ipToUint32(ip net.IP) uint32 {
if len(ip) == ipv6Length {
// Convert to 4-byte representation
ip = ip[ipv4Offset:ipv6Length]
}
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}
// CalculateIPv4Range calculates the start and end IP addresses for an IPv4 CIDR block
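// For example, CalculateIPv4Range("192.0.2.0/24") returns (3221225984, 3221226239), the
// inclusive bounds stored in live_routes_v4.ip_start and ip_end.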
func CalculateIPv4Range(cidr string) (start, end uint32, err error) {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return 0, 0, err
}
// Get the network address (start of range)
ip := ipNet.IP.To4()
if ip == nil {
return 0, 0, fmt.Errorf("%w: not an IPv4 CIDR: %s", ErrInvalidIP, cidr)
}
start = ipToUint32(ip)
// Calculate the end of the range
ones, bits := ipNet.Mask.Size()
hostBits := bits - ones
if hostBits >= ipv4Bits {
// Special case for /0 - entire IPv4 space
end = maxIPv4
} else {
// Safe to convert since we checked hostBits < 32
//nolint:gosec // hostBits is guaranteed to be < 32 from the check above
end = start | ((1 << uint(hostBits)) - 1)
}
return start, end, nil
}
// GetASDetails returns detailed information about an AS including prefixes
func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
return d.GetASDetailsContext(context.Background(), asn)
}
// GetASDetailsContext returns detailed information about an AS including prefixes with context support
func (d *Database) GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error) {
// Get AS information
var asnInfo ASN
var handle, description sql.NullString
err := d.db.QueryRowContext(ctx,
"SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?",
asn,
).Scan(&asnInfo.ASN, &handle, &description, &asnInfo.FirstSeen, &asnInfo.LastSeen)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, fmt.Errorf("%w: AS%d", ErrNoRoute, asn)
}
return nil, nil, fmt.Errorf("failed to query AS: %w", err)
}
asnInfo.Handle = handle.String
asnInfo.Description = description.String
// Get prefixes announced by this AS from both tables
var allPrefixes []LiveRoute
// Query IPv4 prefixes
queryV4 := `
SELECT prefix, mask_length, MAX(last_updated) as last_updated
FROM live_routes_v4
WHERE origin_asn = ?
GROUP BY prefix, mask_length
`
rows4, err := d.db.QueryContext(ctx, queryV4, asn)
if err != nil {
return &asnInfo, nil, fmt.Errorf("failed to query IPv4 prefixes: %w", err)
}
defer func() { _ = rows4.Close() }()
for rows4.Next() {
var route LiveRoute
var lastUpdatedStr string
err := rows4.Scan(&route.Prefix, &route.MaskLength, &lastUpdatedStr)
if err != nil {
d.logger.Error("Failed to scan IPv4 prefix row", "error", err, "asn", asn)
continue
}
// Parse the timestamp string
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05-07:00", lastUpdatedStr)
if err != nil {
// Try without timezone
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05", lastUpdatedStr)
if err != nil {
d.logger.Error("Failed to parse timestamp", "error", err, "timestamp", lastUpdatedStr)
continue
}
}
route.OriginASN = asn
route.IPVersion = ipVersionV4
allPrefixes = append(allPrefixes, route)
}
// Query IPv6 prefixes
queryV6 := `
SELECT prefix, mask_length, MAX(last_updated) as last_updated
FROM live_routes_v6
WHERE origin_asn = ?
GROUP BY prefix, mask_length
`
rows6, err := d.db.QueryContext(ctx, queryV6, asn)
if err != nil {
return &asnInfo, allPrefixes, fmt.Errorf("failed to query IPv6 prefixes: %w", err)
}
defer func() { _ = rows6.Close() }()
for rows6.Next() {
var route LiveRoute
var lastUpdatedStr string
err := rows6.Scan(&route.Prefix, &route.MaskLength, &lastUpdatedStr)
if err != nil {
d.logger.Error("Failed to scan IPv6 prefix row", "error", err, "asn", asn)
continue
}
// Parse the timestamp string
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05-07:00", lastUpdatedStr)
if err != nil {
// Try without timezone
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05", lastUpdatedStr)
if err != nil {
d.logger.Error("Failed to parse timestamp", "error", err, "timestamp", lastUpdatedStr)
continue
}
}
route.OriginASN = asn
route.IPVersion = ipVersionV6
allPrefixes = append(allPrefixes, route)
}
return &asnInfo, allPrefixes, nil
}
// ASPeer represents a peering relationship with another AS including handle, description, and timestamps.
type ASPeer struct {
ASN int `json:"asn"`
Handle string `json:"handle"`
Description string `json:"description"`
FirstSeen time.Time `json:"first_seen"`
LastSeen time.Time `json:"last_seen"`
}
// GetASPeers returns all ASes that peer with the given AS
func (d *Database) GetASPeers(asn int) ([]ASPeer, error) {
return d.GetASPeersContext(context.Background(), asn)
}
// GetASPeersContext returns all ASes that peer with the given AS with context support
func (d *Database) GetASPeersContext(ctx context.Context, asn int) ([]ASPeer, error) {
query := `
SELECT
CASE
WHEN p.as_a = ? THEN p.as_b
ELSE p.as_a
END as peer_asn,
COALESCE(a.handle, '') as handle,
COALESCE(a.description, '') as description,
p.first_seen,
p.last_seen
FROM peerings p
LEFT JOIN asns a ON a.asn = CASE
WHEN p.as_a = ? THEN p.as_b
ELSE p.as_a
END
WHERE p.as_a = ? OR p.as_b = ?
ORDER BY peer_asn
`
rows, err := d.db.QueryContext(ctx, query, asn, asn, asn, asn)
if err != nil {
return nil, fmt.Errorf("failed to query AS peers: %w", err)
}
defer func() { _ = rows.Close() }()
var peers []ASPeer
for rows.Next() {
var peer ASPeer
err := rows.Scan(&peer.ASN, &peer.Handle, &peer.Description, &peer.FirstSeen, &peer.LastSeen)
if err != nil {
d.logger.Error("Failed to scan peer row", "error", err, "asn", asn)
continue
}
peers = append(peers, peer)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating AS peers: %w", err)
}
return peers, nil
}
// GetPrefixDetails returns detailed information about a prefix
func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
return d.GetPrefixDetailsContext(context.Background(), prefix)
}
// GetPrefixDetailsContext returns detailed information about a prefix with context support
func (d *Database) GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error) {
// Determine if it's IPv4 or IPv6 by parsing the prefix
_, ipnet, err := net.ParseCIDR(prefix)
if err != nil {
return nil, fmt.Errorf("invalid prefix format: %w", err)
}
tableName := "live_routes_v4"
ipVersion := ipVersionV4
if ipnet.IP.To4() == nil {
tableName = "live_routes_v6"
ipVersion = ipVersionV6
}
//nolint:gosec // Table name is hardcoded based on IP version
query := fmt.Sprintf(`
SELECT lr.origin_asn, lr.peer_ip, lr.as_path, lr.next_hop, lr.last_updated,
a.handle, a.description
FROM %s lr
LEFT JOIN asns a ON a.asn = lr.origin_asn
WHERE lr.prefix = ?
ORDER BY lr.origin_asn, lr.peer_ip
`, tableName)
rows, err := d.db.QueryContext(ctx, query, prefix)
if err != nil {
return nil, fmt.Errorf("failed to query prefix details: %w", err)
}
defer func() { _ = rows.Close() }()
var routes []LiveRoute
for rows.Next() {
var route LiveRoute
var pathJSON string
var handle, description sql.NullString
err := rows.Scan(
&route.OriginASN, &route.PeerIP, &pathJSON, &route.NextHop,
&route.LastUpdated, &handle, &description,
)
if err != nil {
continue
}
// Decode AS path
if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
route.ASPath = []int{}
}
route.Prefix = prefix
route.IPVersion = ipVersion
route.MaskLength, _ = ipnet.Mask.Size()
routes = append(routes, route)
}
if len(routes) == 0 {
return nil, fmt.Errorf("%w: %s", ErrNoRoute, prefix)
}
return routes, nil
}
// GetRandomPrefixesByLength returns a random sample of prefixes with the specified mask length
func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error) {
return d.GetRandomPrefixesByLengthContext(context.Background(), maskLength, ipVersion, limit)
}
// GetRandomPrefixesByLengthContext returns a random sample of prefixes with context support
func (d *Database) GetRandomPrefixesByLengthContext(
ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error) {
// Select unique prefixes with their most recent route information
tableName := "live_routes_v4"
if ipVersion == ipVersionV6 {
tableName = "live_routes_v6"
}
//nolint:gosec // Table name is hardcoded based on IP version
query := fmt.Sprintf(`
WITH unique_prefixes AS (
SELECT prefix, MAX(last_updated) as max_updated
FROM %s
WHERE mask_length = ?
GROUP BY prefix
ORDER BY RANDOM()
LIMIT ?
)
SELECT lr.prefix, lr.mask_length, lr.origin_asn, lr.as_path,
lr.peer_ip, lr.last_updated
FROM %s lr
INNER JOIN unique_prefixes up ON lr.prefix = up.prefix AND lr.last_updated = up.max_updated
WHERE lr.mask_length = ?
`, tableName, tableName)
rows, err := d.db.QueryContext(ctx, query, maskLength, limit, maskLength)
if err != nil {
return nil, fmt.Errorf("failed to query random prefixes: %w", err)
}
defer func() {
_ = rows.Close()
}()
var routes []LiveRoute
for rows.Next() {
var route LiveRoute
var pathJSON string
err := rows.Scan(
&route.Prefix,
&route.MaskLength,
&route.OriginASN,
&pathJSON,
&route.PeerIP,
&route.LastUpdated,
)
// Set IP version based on which table we queried
route.IPVersion = ipVersion
if err != nil {
continue
}
// Decode AS path
if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
route.ASPath = []int{}
}
routes = append(routes, route)
}
return routes, nil
}