Add comprehensive godoc comments to all exported types, functions, and constants throughout the codebase. Create README.md documenting the project architecture, execution flow, database schema, and component relationships.
1633 lines
45 KiB
Go
1633 lines
45 KiB
Go
// Package database provides SQLite storage for BGP routing data including ASNs, prefixes, announcements and peerings.
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
_ "embed"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.eeqj.de/sneak/routewatch/internal/config"
|
|
"git.eeqj.de/sneak/routewatch/internal/logger"
|
|
"git.eeqj.de/sneak/routewatch/pkg/asinfo"
|
|
"github.com/google/uuid"
|
|
_ "github.com/mattn/go-sqlite3" // CGO SQLite driver
|
|
)
|
|
|
|
// IMPORTANT: NO schema changes are to be made outside of schema.sql
// We do NOT support migrations. All schema changes MUST be made in schema.sql only.
//
// dbSchema holds the contents of schema.sql, embedded into the binary at build
// time. Initialize executes it in full every time a Database is opened by New,
// so (presumably) every statement in schema.sql must be idempotent, e.g.
// CREATE TABLE IF NOT EXISTS — confirm against schema.sql.
//
//go:embed schema.sql
var dbSchema string
|
|
|
|
// Package-level constants for filesystem setup and IP address handling.
const (
	// dirPermissions is the mode used when creating the database directory
	// (rwxr-x---).
	dirPermissions = 0750

	// ipVersionV4 and ipVersionV6 are the values compared against the
	// IPVersion field of routes, deletions and prefixes.
	ipVersionV4 = 4
	ipVersionV6 = 6

	// ipv4Bits is the width of an IPv4 address in bits.
	ipv4Bits = 32

	// maxIPv4 is the largest 32-bit IPv4 address value (255.255.255.255).
	maxIPv4 = 1<<ipv4Bits - 1

	// ipv6Length is the byte length of a 16-byte (IPv6-form) address.
	ipv6Length = 16

	// ipv4Offset is the index at which the four IPv4 bytes begin inside a
	// 16-byte address (ipv6Length minus the 4-byte IPv4 length); presumably
	// used when extracting IPv4 bytes from a 16-byte net.IP.
	ipv4Offset = ipv6Length - ipv4Bits/8
)
|
|
|
|
// ErrInvalidIP is returned when an IP address is malformed.
var ErrInvalidIP = errors.New("invalid IP address")

// ErrNoRoute is returned when no route is found for an IP address.
var ErrNoRoute = errors.New("no route found")
|
|
|
|
// Database manages the SQLite database connection and operations.
|
|
type Database struct {
|
|
db *sql.DB
|
|
logger *logger.Logger
|
|
path string
|
|
mu sync.Mutex
|
|
lockedAt time.Time
|
|
lockedBy string
|
|
}
|
|
|
|
// New creates a new database connection and initializes the schema.
|
|
func New(cfg *config.Config, logger *logger.Logger) (*Database, error) {
|
|
dbPath := filepath.Join(cfg.GetStateDir(), "db.sqlite")
|
|
|
|
// Log database path
|
|
logger.Info("Opening database", "path", dbPath)
|
|
|
|
// Ensure directory exists
|
|
dir := filepath.Dir(dbPath)
|
|
if err := os.MkdirAll(dir, dirPermissions); err != nil {
|
|
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
|
}
|
|
|
|
// Add connection parameters for go-sqlite3
|
|
// Configure SQLite connection parameters
|
|
dsn := fmt.Sprintf(
|
|
"file:%s",
|
|
dbPath,
|
|
)
|
|
db, err := sql.Open("sqlite3", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
// Set connection pool parameters
|
|
// Multiple connections allow concurrent reads while writes are serialized
|
|
const maxConns = 10
|
|
db.SetMaxOpenConns(maxConns)
|
|
db.SetMaxIdleConns(maxConns)
|
|
db.SetConnMaxLifetime(0)
|
|
|
|
database := &Database{db: db, logger: logger, path: dbPath}
|
|
|
|
if err := database.Initialize(); err != nil {
|
|
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
|
}
|
|
|
|
return database, nil
|
|
}
|
|
|
|
// Initialize creates the database schema if it doesn't exist.
|
|
func (d *Database) Initialize() error {
|
|
// Set SQLite pragmas for performance
|
|
pragmas := []string{
|
|
"PRAGMA journal_mode=WAL", // Write-Ahead Logging
|
|
"PRAGMA synchronous=OFF", // Don't wait for disk writes
|
|
"PRAGMA cache_size=-3145728", // 3GB cache (upper limit for 2.4GB DB)
|
|
"PRAGMA temp_store=MEMORY", // Use memory for temp tables
|
|
"PRAGMA wal_checkpoint(TRUNCATE)", // Checkpoint and truncate WAL now
|
|
"PRAGMA busy_timeout=5000", // 5 second busy timeout
|
|
"PRAGMA analysis_limit=0", // Disable automatic ANALYZE
|
|
}
|
|
|
|
for _, pragma := range pragmas {
|
|
if err := d.exec(pragma); err != nil {
|
|
d.logger.Warn("Failed to set pragma", "pragma", pragma, "error", err)
|
|
}
|
|
}
|
|
|
|
err := d.exec(dbSchema)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Run VACUUM on startup to optimize database
|
|
d.logger.Info("Running VACUUM to optimize database (this may take a moment)")
|
|
if err := d.exec("VACUUM"); err != nil {
|
|
d.logger.Warn("Failed to VACUUM database", "error", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close closes the database connection.
|
|
func (d *Database) Close() error {
|
|
return d.db.Close()
|
|
}
|
|
|
|
// lock acquires the database mutex and logs debug information
|
|
func (d *Database) lock(operation string) {
|
|
// Get caller information
|
|
_, file, line, _ := runtime.Caller(1)
|
|
caller := fmt.Sprintf("%s:%d", filepath.Base(file), line)
|
|
|
|
d.logger.Debug("Acquiring database lock", "operation", operation, "caller", caller)
|
|
|
|
d.mu.Lock()
|
|
d.lockedAt = time.Now()
|
|
d.lockedBy = fmt.Sprintf("%s (%s)", operation, caller)
|
|
|
|
d.logger.Debug("Database lock acquired", "operation", operation, "caller", caller)
|
|
}
|
|
|
|
// unlock releases the database mutex and logs debug information including hold duration
|
|
func (d *Database) unlock() {
|
|
holdDuration := time.Since(d.lockedAt)
|
|
lockedBy := d.lockedBy
|
|
|
|
d.lockedAt = time.Time{}
|
|
d.lockedBy = ""
|
|
d.mu.Unlock()
|
|
|
|
d.logger.Debug("Database lock released", "held_by", lockedBy, "duration_ms", holdDuration.Milliseconds())
|
|
}
|
|
|
|
// beginTx starts a new transaction with logging
|
|
func (d *Database) beginTx() (*loggingTx, error) {
|
|
tx, err := d.db.Begin()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &loggingTx{Tx: tx, logger: d.logger}, nil
|
|
}
|
|
|
|
// UpsertLiveRouteBatch inserts or updates multiple live routes in a single transaction
|
|
func (d *Database) UpsertLiveRouteBatch(routes []*LiveRoute) error {
|
|
if len(routes) == 0 {
|
|
return nil
|
|
}
|
|
|
|
d.lock("UpsertLiveRouteBatch")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// Prepare statements for both IPv4 and IPv6
|
|
queryV4 := `
|
|
INSERT INTO live_routes_v4 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
|
|
last_updated, ip_start, ip_end)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
|
|
mask_length = excluded.mask_length,
|
|
as_path = excluded.as_path,
|
|
next_hop = excluded.next_hop,
|
|
last_updated = excluded.last_updated,
|
|
ip_start = excluded.ip_start,
|
|
ip_end = excluded.ip_end
|
|
`
|
|
|
|
queryV6 := `
|
|
INSERT INTO live_routes_v6 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
|
|
last_updated)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
|
|
mask_length = excluded.mask_length,
|
|
as_path = excluded.as_path,
|
|
next_hop = excluded.next_hop,
|
|
last_updated = excluded.last_updated
|
|
`
|
|
|
|
stmtV4, err := tx.Prepare(queryV4)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV4.Close() }()
|
|
|
|
stmtV6, err := tx.Prepare(queryV6)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV6.Close() }()
|
|
|
|
for _, route := range routes {
|
|
// Encode AS path as JSON
|
|
pathJSON, err := json.Marshal(route.ASPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encode AS path: %w", err)
|
|
}
|
|
|
|
// Use appropriate statement based on IP version
|
|
if route.IPVersion == ipVersionV4 {
|
|
// IPv4 routes must have range values
|
|
if route.V4IPStart == nil || route.V4IPEnd == nil {
|
|
return fmt.Errorf("IPv4 route %s missing range values", route.Prefix)
|
|
}
|
|
|
|
_, err = stmtV4.Exec(
|
|
route.ID.String(),
|
|
route.Prefix,
|
|
route.MaskLength,
|
|
route.OriginASN,
|
|
route.PeerIP,
|
|
string(pathJSON),
|
|
route.NextHop,
|
|
route.LastUpdated,
|
|
*route.V4IPStart,
|
|
*route.V4IPEnd,
|
|
)
|
|
} else {
|
|
// IPv6 routes
|
|
_, err = stmtV6.Exec(
|
|
route.ID.String(),
|
|
route.Prefix,
|
|
route.MaskLength,
|
|
route.OriginASN,
|
|
route.PeerIP,
|
|
string(pathJSON),
|
|
route.NextHop,
|
|
route.LastUpdated,
|
|
)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to upsert route %s: %w", route.Prefix, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteLiveRouteBatch deletes multiple live routes in a single transaction
|
|
func (d *Database) DeleteLiveRouteBatch(deletions []LiveRouteDeletion) error {
|
|
if len(deletions) == 0 {
|
|
return nil
|
|
}
|
|
|
|
d.lock("DeleteLiveRouteBatch")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// No longer need to separate deletions since we handle them in the loop below
|
|
|
|
// Prepare statements for both IPv4 and IPv6 tables
|
|
stmtV4WithOrigin, err := tx.Prepare(`DELETE FROM live_routes_v4 WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 delete with origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV4WithOrigin.Close() }()
|
|
|
|
stmtV4WithoutOrigin, err := tx.Prepare(`DELETE FROM live_routes_v4 WHERE prefix = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 delete without origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV4WithoutOrigin.Close() }()
|
|
|
|
stmtV6WithOrigin, err := tx.Prepare(`DELETE FROM live_routes_v6 WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 delete with origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV6WithOrigin.Close() }()
|
|
|
|
stmtV6WithoutOrigin, err := tx.Prepare(`DELETE FROM live_routes_v6 WHERE prefix = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 delete without origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV6WithoutOrigin.Close() }()
|
|
|
|
// Process deletions
|
|
for _, del := range deletions {
|
|
var stmt *sql.Stmt
|
|
|
|
// Select appropriate statement based on IP version and whether we have origin ASN
|
|
//nolint:nestif // Clear logic for selecting the right statement
|
|
if del.IPVersion == ipVersionV4 {
|
|
if del.OriginASN == 0 {
|
|
stmt = stmtV4WithoutOrigin
|
|
} else {
|
|
stmt = stmtV4WithOrigin
|
|
}
|
|
} else {
|
|
if del.OriginASN == 0 {
|
|
stmt = stmtV6WithoutOrigin
|
|
} else {
|
|
stmt = stmtV6WithOrigin
|
|
}
|
|
}
|
|
|
|
// Execute deletion
|
|
if del.OriginASN == 0 {
|
|
_, err = stmt.Exec(del.Prefix, del.PeerIP)
|
|
} else {
|
|
_, err = stmt.Exec(del.Prefix, del.OriginASN, del.PeerIP)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete route %s: %w", del.Prefix, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdatePrefixesBatch updates the last_seen time for multiple prefixes in a single transaction
|
|
func (d *Database) UpdatePrefixesBatch(prefixes map[string]time.Time) error {
|
|
if len(prefixes) == 0 {
|
|
return nil
|
|
}
|
|
|
|
d.lock("UpdatePrefixesBatch")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// Prepare statements for both IPv4 and IPv6 tables
|
|
selectV4Stmt, err := tx.Prepare("SELECT id FROM prefixes_v4 WHERE prefix = ?")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 select statement: %w", err)
|
|
}
|
|
defer func() { _ = selectV4Stmt.Close() }()
|
|
|
|
updateV4Stmt, err := tx.Prepare("UPDATE prefixes_v4 SET last_seen = ? WHERE prefix = ?")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 update statement: %w", err)
|
|
}
|
|
defer func() { _ = updateV4Stmt.Close() }()
|
|
|
|
insertV4Stmt, err := tx.Prepare("INSERT INTO prefixes_v4 (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 insert statement: %w", err)
|
|
}
|
|
defer func() { _ = insertV4Stmt.Close() }()
|
|
|
|
selectV6Stmt, err := tx.Prepare("SELECT id FROM prefixes_v6 WHERE prefix = ?")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 select statement: %w", err)
|
|
}
|
|
defer func() { _ = selectV6Stmt.Close() }()
|
|
|
|
updateV6Stmt, err := tx.Prepare("UPDATE prefixes_v6 SET last_seen = ? WHERE prefix = ?")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 update statement: %w", err)
|
|
}
|
|
defer func() { _ = updateV6Stmt.Close() }()
|
|
|
|
insertV6Stmt, err := tx.Prepare("INSERT INTO prefixes_v6 (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 insert statement: %w", err)
|
|
}
|
|
defer func() { _ = insertV6Stmt.Close() }()
|
|
|
|
for prefix, timestamp := range prefixes {
|
|
ipVersion := detectIPVersion(prefix)
|
|
|
|
var selectStmt, updateStmt, insertStmt *sql.Stmt
|
|
if ipVersion == ipVersionV4 {
|
|
selectStmt, updateStmt, insertStmt = selectV4Stmt, updateV4Stmt, insertV4Stmt
|
|
} else {
|
|
selectStmt, updateStmt, insertStmt = selectV6Stmt, updateV6Stmt, insertV6Stmt
|
|
}
|
|
|
|
var id string
|
|
err = selectStmt.QueryRow(prefix).Scan(&id)
|
|
|
|
switch err {
|
|
case nil:
|
|
// Prefix exists, update last_seen
|
|
_, err = updateStmt.Exec(timestamp, prefix)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update prefix %s: %w", prefix, err)
|
|
}
|
|
case sql.ErrNoRows:
|
|
// Prefix doesn't exist, create it
|
|
_, err = insertStmt.Exec(generateUUID().String(), prefix, timestamp, timestamp)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to insert prefix %s: %w", prefix, err)
|
|
}
|
|
default:
|
|
return fmt.Errorf("failed to query prefix %s: %w", prefix, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetOrCreateASNBatch creates or updates multiple ASNs in a single transaction
|
|
func (d *Database) GetOrCreateASNBatch(asns map[int]time.Time) error {
|
|
if len(asns) == 0 {
|
|
return nil
|
|
}
|
|
|
|
d.lock("GetOrCreateASNBatch")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// Prepare statements
|
|
selectStmt, err := tx.Prepare(
|
|
"SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare select statement: %w", err)
|
|
}
|
|
defer func() { _ = selectStmt.Close() }()
|
|
|
|
updateStmt, err := tx.Prepare("UPDATE asns SET last_seen = ? WHERE asn = ?")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare update statement: %w", err)
|
|
}
|
|
defer func() { _ = updateStmt.Close() }()
|
|
|
|
insertStmt, err := tx.Prepare(
|
|
"INSERT INTO asns (asn, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?)")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare insert statement: %w", err)
|
|
}
|
|
defer func() { _ = insertStmt.Close() }()
|
|
|
|
for number, timestamp := range asns {
|
|
var asn ASN
|
|
var handle, description sql.NullString
|
|
|
|
err = selectStmt.QueryRow(number).Scan(&asn.ASN, &handle, &description, &asn.FirstSeen, &asn.LastSeen)
|
|
|
|
if err == nil {
|
|
// ASN exists, update last_seen
|
|
_, err = updateStmt.Exec(timestamp, number)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update ASN %d: %w", number, err)
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
if err == sql.ErrNoRows {
|
|
// ASN doesn't exist, create it
|
|
asn = ASN{
|
|
ASN: number,
|
|
FirstSeen: timestamp,
|
|
LastSeen: timestamp,
|
|
}
|
|
|
|
// Look up ASN info
|
|
if info, ok := asinfo.Get(number); ok {
|
|
asn.Handle = info.Handle
|
|
asn.Description = info.Description
|
|
}
|
|
|
|
_, err = insertStmt.Exec(asn.ASN, asn.Handle, asn.Description, asn.FirstSeen, asn.LastSeen)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to insert ASN %d: %w", number, err)
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query ASN %d: %w", number, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetOrCreateASN retrieves an existing ASN or creates a new one if it doesn't exist.
|
|
func (d *Database) GetOrCreateASN(number int, timestamp time.Time) (*ASN, error) {
|
|
d.lock("GetOrCreateASN")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
var asn ASN
|
|
var handle, description sql.NullString
|
|
err = tx.QueryRow("SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?", number).
|
|
Scan(&asn.ASN, &handle, &description, &asn.FirstSeen, &asn.LastSeen)
|
|
|
|
if err == nil {
|
|
// ASN exists, update last_seen
|
|
asn.Handle = handle.String
|
|
asn.Description = description.String
|
|
_, err = tx.Exec("UPDATE asns SET last_seen = ? WHERE asn = ?", timestamp, number)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
asn.LastSeen = timestamp
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
d.logger.Error("Failed to commit transaction for ASN update", "asn", number, "error", err)
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return &asn, nil
|
|
}
|
|
|
|
if err != sql.ErrNoRows {
|
|
return nil, err
|
|
}
|
|
|
|
// ASN doesn't exist, create it with ASN info lookup
|
|
asn = ASN{
|
|
ASN: number,
|
|
FirstSeen: timestamp,
|
|
LastSeen: timestamp,
|
|
}
|
|
|
|
// Look up ASN info
|
|
if info, ok := asinfo.Get(number); ok {
|
|
asn.Handle = info.Handle
|
|
asn.Description = info.Description
|
|
}
|
|
|
|
_, err = tx.Exec("INSERT INTO asns (asn, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?)",
|
|
asn.ASN, asn.Handle, asn.Description, asn.FirstSeen, asn.LastSeen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
d.logger.Error("Failed to commit transaction for ASN creation", "asn", number, "error", err)
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return &asn, nil
|
|
}
|
|
|
|
// GetOrCreatePrefix retrieves an existing prefix or creates a new one if it doesn't exist.
|
|
func (d *Database) GetOrCreatePrefix(prefix string, timestamp time.Time) (*Prefix, error) {
|
|
d.lock("GetOrCreatePrefix")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// Determine table based on IP version
|
|
ipVersion := detectIPVersion(prefix)
|
|
tableName := "prefixes_v4"
|
|
if ipVersion == ipVersionV6 {
|
|
tableName = "prefixes_v6"
|
|
}
|
|
|
|
var p Prefix
|
|
var idStr string
|
|
query := fmt.Sprintf("SELECT id, prefix, first_seen, last_seen FROM %s WHERE prefix = ?", tableName)
|
|
err = tx.QueryRow(query, prefix).
|
|
Scan(&idStr, &p.Prefix, &p.FirstSeen, &p.LastSeen)
|
|
|
|
if err == nil {
|
|
// Prefix exists, update last_seen
|
|
p.ID, _ = uuid.Parse(idStr)
|
|
p.IPVersion = ipVersion
|
|
updateQuery := fmt.Sprintf("UPDATE %s SET last_seen = ? WHERE id = ?", tableName)
|
|
_, err = tx.Exec(updateQuery, timestamp, p.ID.String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.LastSeen = timestamp
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
d.logger.Error("Failed to commit transaction for prefix update", "prefix", prefix, "error", err)
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return &p, nil
|
|
}
|
|
|
|
if err != sql.ErrNoRows {
|
|
return nil, err
|
|
}
|
|
|
|
// Prefix doesn't exist, create it
|
|
p = Prefix{
|
|
ID: generateUUID(),
|
|
Prefix: prefix,
|
|
IPVersion: ipVersion,
|
|
FirstSeen: timestamp,
|
|
LastSeen: timestamp,
|
|
}
|
|
insertQuery := fmt.Sprintf("INSERT INTO %s (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)", tableName)
|
|
_, err = tx.Exec(insertQuery, p.ID.String(), p.Prefix, p.FirstSeen, p.LastSeen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
d.logger.Error("Failed to commit transaction for prefix creation", "prefix", prefix, "error", err)
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return &p, nil
|
|
}
|
|
|
|
// RecordAnnouncement inserts a new BGP announcement or withdrawal into the database.
|
|
func (d *Database) RecordAnnouncement(announcement *Announcement) error {
|
|
d.lock("RecordAnnouncement")
|
|
defer d.unlock()
|
|
|
|
err := d.exec(`
|
|
INSERT INTO announcements (id, prefix_id, peer_asn, origin_asn, path, next_hop, timestamp, is_withdrawal)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
announcement.ID.String(), announcement.PrefixID.String(),
|
|
announcement.PeerASN, announcement.OriginASN,
|
|
announcement.Path, announcement.NextHop, announcement.Timestamp, announcement.IsWithdrawal)
|
|
|
|
return err
|
|
}
|
|
|
|
// RecordPeering records a peering relationship between two ASNs.
|
|
func (d *Database) RecordPeering(asA, asB int, timestamp time.Time) error {
|
|
// Validate ASNs
|
|
if asA <= 0 || asB <= 0 {
|
|
return fmt.Errorf("invalid ASN: asA=%d, asB=%d", asA, asB)
|
|
}
|
|
if asA == asB {
|
|
return fmt.Errorf("cannot create peering with same ASN: %d", asA)
|
|
}
|
|
|
|
// Normalize: ensure asA < asB
|
|
if asA > asB {
|
|
asA, asB = asB, asA
|
|
}
|
|
|
|
d.lock("RecordPeering")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
var exists bool
|
|
err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM peerings WHERE as_a = ? AND as_b = ?)",
|
|
asA, asB).Scan(&exists)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if exists {
|
|
_, err = tx.Exec("UPDATE peerings SET last_seen = ? WHERE as_a = ? AND as_b = ?",
|
|
timestamp, asA, asB)
|
|
} else {
|
|
_, err = tx.Exec(`
|
|
INSERT INTO peerings (id, as_a, as_b, first_seen, last_seen)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
generateUUID().String(), asA, asB, timestamp, timestamp)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
d.logger.Error("Failed to commit transaction for peering",
|
|
"as_a", asA,
|
|
"as_b", asB,
|
|
"error", err,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdatePeerBatch updates or creates multiple BGP peer records in a single transaction
|
|
func (d *Database) UpdatePeerBatch(peers map[string]PeerUpdate) error {
|
|
if len(peers) == 0 {
|
|
return nil
|
|
}
|
|
|
|
d.lock("UpdatePeerBatch")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// Prepare statements
|
|
checkStmt, err := tx.Prepare("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare check statement: %w", err)
|
|
}
|
|
defer func() { _ = checkStmt.Close() }()
|
|
|
|
updateStmt, err := tx.Prepare(
|
|
"UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare update statement: %w", err)
|
|
}
|
|
defer func() { _ = updateStmt.Close() }()
|
|
|
|
insertStmt, err := tx.Prepare(
|
|
"INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare insert statement: %w", err)
|
|
}
|
|
defer func() { _ = insertStmt.Close() }()
|
|
|
|
for _, update := range peers {
|
|
var exists bool
|
|
err = checkStmt.QueryRow(update.PeerIP).Scan(&exists)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check peer %s: %w", update.PeerIP, err)
|
|
}
|
|
|
|
if exists {
|
|
_, err = updateStmt.Exec(update.PeerASN, update.Timestamp, update.MessageType, update.PeerIP)
|
|
} else {
|
|
_, err = insertStmt.Exec(
|
|
generateUUID().String(), update.PeerIP, update.PeerASN,
|
|
update.Timestamp, update.Timestamp, update.MessageType)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update peer %s: %w", update.PeerIP, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdatePeer updates or creates a BGP peer record
|
|
func (d *Database) UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error {
|
|
d.lock("UpdatePeer")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
var exists bool
|
|
err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)", peerIP).Scan(&exists)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if exists {
|
|
_, err = tx.Exec(
|
|
"UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?",
|
|
peerASN, timestamp, messageType, peerIP,
|
|
)
|
|
} else {
|
|
_, err = tx.Exec(
|
|
"INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)",
|
|
generateUUID().String(), peerIP, peerASN, timestamp, timestamp, messageType,
|
|
)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
d.logger.Error("Failed to commit transaction for peer update",
|
|
"peer_ip", peerIP,
|
|
"peer_asn", peerASN,
|
|
"error", err,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetStats returns database statistics
|
|
func (d *Database) GetStats() (Stats, error) {
|
|
return d.GetStatsContext(context.Background())
|
|
}
|
|
|
|
// GetStatsContext returns database statistics with context support
|
|
func (d *Database) GetStatsContext(ctx context.Context) (Stats, error) {
|
|
var stats Stats
|
|
|
|
// Count ASNs
|
|
err := d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
|
|
if err != nil {
|
|
return stats, err
|
|
}
|
|
|
|
// Count prefixes from both tables
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM prefixes_v4").Scan(&stats.IPv4Prefixes)
|
|
if err != nil {
|
|
return stats, err
|
|
}
|
|
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM prefixes_v6").Scan(&stats.IPv6Prefixes)
|
|
if err != nil {
|
|
return stats, err
|
|
}
|
|
|
|
stats.Prefixes = stats.IPv4Prefixes + stats.IPv6Prefixes
|
|
|
|
// Count peerings
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM peerings").Scan(&stats.Peerings)
|
|
if err != nil {
|
|
return stats, err
|
|
}
|
|
|
|
// Count peers
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM bgp_peers").Scan(&stats.Peers)
|
|
if err != nil {
|
|
return stats, err
|
|
}
|
|
|
|
// Get database file size
|
|
fileInfo, err := os.Stat(d.path)
|
|
if err != nil {
|
|
d.logger.Warn("Failed to get database file size", "error", err)
|
|
stats.FileSizeBytes = 0
|
|
} else {
|
|
stats.FileSizeBytes = fileInfo.Size()
|
|
}
|
|
|
|
// Get live routes count from both tables
|
|
var v4Count, v6Count int
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v4").Scan(&v4Count)
|
|
if err != nil {
|
|
return stats, fmt.Errorf("failed to count IPv4 routes: %w", err)
|
|
}
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v6").Scan(&v6Count)
|
|
if err != nil {
|
|
return stats, fmt.Errorf("failed to count IPv6 routes: %w", err)
|
|
}
|
|
stats.LiveRoutes = v4Count + v6Count
|
|
|
|
// Get prefix distribution
|
|
stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, err = d.GetPrefixDistributionContext(ctx)
|
|
if err != nil {
|
|
// Log but don't fail
|
|
d.logger.Warn("Failed to get prefix distribution", "error", err)
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// UpsertLiveRoute inserts or updates a live route
|
|
func (d *Database) UpsertLiveRoute(route *LiveRoute) error {
|
|
d.lock("UpsertLiveRoute")
|
|
defer d.unlock()
|
|
|
|
// Choose table based on IP version
|
|
tableName := "live_routes_v4"
|
|
if route.IPVersion == ipVersionV6 {
|
|
tableName = "live_routes_v6"
|
|
}
|
|
|
|
var query string
|
|
if route.IPVersion == ipVersionV4 {
|
|
query = fmt.Sprintf(`
|
|
INSERT INTO %s (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
|
|
last_updated, ip_start, ip_end)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
|
|
mask_length = excluded.mask_length,
|
|
as_path = excluded.as_path,
|
|
next_hop = excluded.next_hop,
|
|
last_updated = excluded.last_updated,
|
|
ip_start = excluded.ip_start,
|
|
ip_end = excluded.ip_end
|
|
`, tableName)
|
|
} else {
|
|
query = fmt.Sprintf(`
|
|
INSERT INTO %s (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
|
|
last_updated)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
|
|
mask_length = excluded.mask_length,
|
|
as_path = excluded.as_path,
|
|
next_hop = excluded.next_hop,
|
|
last_updated = excluded.last_updated
|
|
`, tableName)
|
|
}
|
|
|
|
// Encode AS path as JSON
|
|
pathJSON, err := json.Marshal(route.ASPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encode AS path: %w", err)
|
|
}
|
|
|
|
if route.IPVersion == ipVersionV4 {
|
|
// Convert v4_ip_start and v4_ip_end to interface{} for SQL NULL handling
|
|
var v4Start, v4End interface{}
|
|
if route.V4IPStart != nil {
|
|
v4Start = *route.V4IPStart
|
|
}
|
|
if route.V4IPEnd != nil {
|
|
v4End = *route.V4IPEnd
|
|
}
|
|
|
|
_, err = d.db.Exec(query,
|
|
route.ID.String(),
|
|
route.Prefix,
|
|
route.MaskLength,
|
|
route.OriginASN,
|
|
route.PeerIP,
|
|
string(pathJSON),
|
|
route.NextHop,
|
|
route.LastUpdated,
|
|
v4Start,
|
|
v4End,
|
|
)
|
|
} else {
|
|
_, err = d.db.Exec(query,
|
|
route.ID.String(),
|
|
route.Prefix,
|
|
route.MaskLength,
|
|
route.OriginASN,
|
|
route.PeerIP,
|
|
string(pathJSON),
|
|
route.NextHop,
|
|
route.LastUpdated,
|
|
)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// DeleteLiveRoute deletes a live route
|
|
// If originASN is 0, deletes all routes for the prefix/peer combination
|
|
func (d *Database) DeleteLiveRoute(prefix string, originASN int, peerIP string) error {
|
|
d.lock("DeleteLiveRoute")
|
|
defer d.unlock()
|
|
|
|
// Determine table based on prefix IP version
|
|
_, ipnet, err := net.ParseCIDR(prefix)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid prefix format: %w", err)
|
|
}
|
|
|
|
tableName := "live_routes_v4"
|
|
if ipnet.IP.To4() == nil {
|
|
tableName = "live_routes_v6"
|
|
}
|
|
|
|
var query string
|
|
if originASN == 0 {
|
|
// Delete all routes for this prefix from this peer
|
|
query = fmt.Sprintf(`DELETE FROM %s WHERE prefix = ? AND peer_ip = ?`, tableName)
|
|
_, err = d.db.Exec(query, prefix, peerIP)
|
|
} else {
|
|
// Delete specific route
|
|
query = fmt.Sprintf(`DELETE FROM %s WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`, tableName)
|
|
_, err = d.db.Exec(query, prefix, originASN, peerIP)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// GetPrefixDistribution returns the distribution of unique prefixes by mask length
|
|
func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
|
return d.GetPrefixDistributionContext(context.Background())
|
|
}
|
|
|
|
// GetPrefixDistributionContext returns the distribution of unique prefixes by mask length with context support
|
|
func (d *Database) GetPrefixDistributionContext(ctx context.Context) (
|
|
ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
|
// IPv4 distribution - count unique prefixes from v4 table
|
|
query := `
|
|
SELECT mask_length, COUNT(DISTINCT prefix) as count
|
|
FROM live_routes_v4
|
|
GROUP BY mask_length
|
|
ORDER BY mask_length
|
|
`
|
|
rows4, err := d.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to query IPv4 distribution: %w", err)
|
|
}
|
|
defer func() {
|
|
if rows4 != nil {
|
|
_ = rows4.Close()
|
|
}
|
|
}()
|
|
|
|
for rows4.Next() {
|
|
var dist PrefixDistribution
|
|
if err := rows4.Scan(&dist.MaskLength, &dist.Count); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to scan IPv4 distribution: %w", err)
|
|
}
|
|
ipv4 = append(ipv4, dist)
|
|
}
|
|
|
|
// IPv6 distribution - count unique prefixes from v6 table
|
|
query = `
|
|
SELECT mask_length, COUNT(DISTINCT prefix) as count
|
|
FROM live_routes_v6
|
|
GROUP BY mask_length
|
|
ORDER BY mask_length
|
|
`
|
|
rows6, err := d.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to query IPv6 distribution: %w", err)
|
|
}
|
|
defer func() {
|
|
if rows6 != nil {
|
|
_ = rows6.Close()
|
|
}
|
|
}()
|
|
|
|
for rows6.Next() {
|
|
var dist PrefixDistribution
|
|
if err := rows6.Scan(&dist.MaskLength, &dist.Count); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to scan IPv6 distribution: %w", err)
|
|
}
|
|
ipv6 = append(ipv6, dist)
|
|
}
|
|
|
|
return ipv4, ipv6, nil
|
|
}
|
|
|
|
// GetLiveRouteCounts returns the count of IPv4 and IPv6 routes
|
|
func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
|
return d.GetLiveRouteCountsContext(context.Background())
|
|
}
|
|
|
|
// GetLiveRouteCountsContext returns the count of IPv4 and IPv6 routes with context support
|
|
func (d *Database) GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error) {
|
|
// Get IPv4 count from dedicated table
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v4").Scan(&ipv4Count)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("failed to count IPv4 routes: %w", err)
|
|
}
|
|
|
|
// Get IPv6 count from dedicated table
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v6").Scan(&ipv6Count)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("failed to count IPv6 routes: %w", err)
|
|
}
|
|
|
|
return ipv4Count, ipv6Count, nil
|
|
}
|
|
|
|
// GetASInfoForIP returns AS information for the given IP address
|
|
func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
|
return d.GetASInfoForIPContext(context.Background(), ip)
|
|
}
|
|
|
|
// GetASInfoForIPContext returns AS information for the given IP address with context support
|
|
func (d *Database) GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error) {
|
|
// Parse the IP to validate it
|
|
parsedIP := net.ParseIP(ip)
|
|
if parsedIP == nil {
|
|
return nil, fmt.Errorf("%w: %s", ErrInvalidIP, ip)
|
|
}
|
|
|
|
// Determine IP version
|
|
ipVersion := ipVersionV4
|
|
ipv4 := parsedIP.To4()
|
|
if ipv4 == nil {
|
|
ipVersion = ipVersionV6
|
|
}
|
|
|
|
// For IPv4, use optimized range query
|
|
if ipVersion == ipVersionV4 {
|
|
// Convert IP to 32-bit unsigned integer
|
|
ipUint := ipToUint32(ipv4)
|
|
|
|
query := `
|
|
SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
|
|
FROM live_routes_v4 lr
|
|
LEFT JOIN asns a ON a.asn = lr.origin_asn
|
|
WHERE lr.ip_start <= ? AND lr.ip_end >= ?
|
|
ORDER BY lr.mask_length DESC
|
|
LIMIT 1
|
|
`
|
|
|
|
var prefix string
|
|
var maskLength, originASN int
|
|
var lastUpdated time.Time
|
|
var handle, description sql.NullString
|
|
|
|
err := d.db.QueryRowContext(ctx, query, ipUint, ipUint).Scan(
|
|
&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to query routes: %w", err)
|
|
}
|
|
|
|
age := time.Since(lastUpdated).Round(time.Second).String()
|
|
|
|
return &ASInfo{
|
|
ASN: originASN,
|
|
Handle: handle.String,
|
|
Description: description.String,
|
|
Prefix: prefix,
|
|
LastUpdated: lastUpdated,
|
|
Age: age,
|
|
}, nil
|
|
}
|
|
|
|
// For IPv6, use the original method since we don't have range optimization
|
|
query := `
|
|
SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
|
|
FROM live_routes_v6 lr
|
|
LEFT JOIN asns a ON a.asn = lr.origin_asn
|
|
ORDER BY lr.mask_length DESC
|
|
`
|
|
|
|
rows, err := d.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query routes: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
// Find the most specific matching prefix
|
|
var bestMatch struct {
|
|
prefix string
|
|
maskLength int
|
|
originASN int
|
|
lastUpdated time.Time
|
|
handle sql.NullString
|
|
description sql.NullString
|
|
}
|
|
bestMaskLength := -1
|
|
|
|
for rows.Next() {
|
|
var prefix string
|
|
var maskLength, originASN int
|
|
var lastUpdated time.Time
|
|
var handle, description sql.NullString
|
|
|
|
if err := rows.Scan(&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Parse the prefix CIDR
|
|
_, ipNet, err := net.ParseCIDR(prefix)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Check if the IP is in this prefix
|
|
if ipNet.Contains(parsedIP) && maskLength > bestMaskLength {
|
|
bestMatch.prefix = prefix
|
|
bestMatch.maskLength = maskLength
|
|
bestMatch.originASN = originASN
|
|
bestMatch.lastUpdated = lastUpdated
|
|
bestMatch.handle = handle
|
|
bestMatch.description = description
|
|
bestMaskLength = maskLength
|
|
}
|
|
}
|
|
|
|
if bestMaskLength == -1 {
|
|
return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
|
|
}
|
|
|
|
age := time.Since(bestMatch.lastUpdated).Round(time.Second).String()
|
|
|
|
return &ASInfo{
|
|
ASN: bestMatch.originASN,
|
|
Handle: bestMatch.handle.String,
|
|
Description: bestMatch.description.String,
|
|
Prefix: bestMatch.prefix,
|
|
LastUpdated: bestMatch.lastUpdated,
|
|
Age: age,
|
|
}, nil
|
|
}
|
|
|
|
// ipToUint32 converts an IPv4 address to a 32-bit unsigned integer
|
|
func ipToUint32(ip net.IP) uint32 {
|
|
if len(ip) == ipv6Length {
|
|
// Convert to 4-byte representation
|
|
ip = ip[ipv4Offset:ipv6Length]
|
|
}
|
|
|
|
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
|
}
|
|
|
|
// CalculateIPv4Range calculates the start and end IP addresses for an IPv4 CIDR block
|
|
func CalculateIPv4Range(cidr string) (start, end uint32, err error) {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
|
|
// Get the network address (start of range)
|
|
ip := ipNet.IP.To4()
|
|
if ip == nil {
|
|
return 0, 0, fmt.Errorf("not an IPv4 address")
|
|
}
|
|
|
|
start = ipToUint32(ip)
|
|
|
|
// Calculate the end of the range
|
|
ones, bits := ipNet.Mask.Size()
|
|
hostBits := bits - ones
|
|
if hostBits >= ipv4Bits {
|
|
// Special case for /0 - entire IPv4 space
|
|
end = maxIPv4
|
|
} else {
|
|
// Safe to convert since we checked hostBits < 32
|
|
//nolint:gosec // hostBits is guaranteed to be < 32 from the check above
|
|
end = start | ((1 << uint(hostBits)) - 1)
|
|
}
|
|
|
|
return start, end, nil
|
|
}
|
|
|
|
// GetASDetails returns detailed information about an AS including prefixes
|
|
func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
|
return d.GetASDetailsContext(context.Background(), asn)
|
|
}
|
|
|
|
// GetASDetailsContext returns detailed information about an AS including prefixes with context support
|
|
func (d *Database) GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error) {
|
|
// Get AS information
|
|
var asnInfo ASN
|
|
var handle, description sql.NullString
|
|
err := d.db.QueryRowContext(ctx,
|
|
"SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?",
|
|
asn,
|
|
).Scan(&asnInfo.ASN, &handle, &description, &asnInfo.FirstSeen, &asnInfo.LastSeen)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, fmt.Errorf("%w: AS%d", ErrNoRoute, asn)
|
|
}
|
|
|
|
return nil, nil, fmt.Errorf("failed to query AS: %w", err)
|
|
}
|
|
|
|
asnInfo.Handle = handle.String
|
|
asnInfo.Description = description.String
|
|
|
|
// Get prefixes announced by this AS from both tables
|
|
var allPrefixes []LiveRoute
|
|
|
|
// Query IPv4 prefixes
|
|
queryV4 := `
|
|
SELECT prefix, mask_length, MAX(last_updated) as last_updated
|
|
FROM live_routes_v4
|
|
WHERE origin_asn = ?
|
|
GROUP BY prefix, mask_length
|
|
`
|
|
|
|
rows4, err := d.db.QueryContext(ctx, queryV4, asn)
|
|
if err != nil {
|
|
return &asnInfo, nil, fmt.Errorf("failed to query IPv4 prefixes: %w", err)
|
|
}
|
|
defer func() { _ = rows4.Close() }()
|
|
|
|
for rows4.Next() {
|
|
var route LiveRoute
|
|
var lastUpdatedStr string
|
|
err := rows4.Scan(&route.Prefix, &route.MaskLength, &lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to scan IPv4 prefix row", "error", err, "asn", asn)
|
|
|
|
continue
|
|
}
|
|
// Parse the timestamp string
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05-07:00", lastUpdatedStr)
|
|
if err != nil {
|
|
// Try without timezone
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05", lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to parse timestamp", "error", err, "timestamp", lastUpdatedStr)
|
|
|
|
continue
|
|
}
|
|
}
|
|
route.OriginASN = asn
|
|
route.IPVersion = ipVersionV4
|
|
allPrefixes = append(allPrefixes, route)
|
|
}
|
|
|
|
// Query IPv6 prefixes
|
|
queryV6 := `
|
|
SELECT prefix, mask_length, MAX(last_updated) as last_updated
|
|
FROM live_routes_v6
|
|
WHERE origin_asn = ?
|
|
GROUP BY prefix, mask_length
|
|
`
|
|
|
|
rows6, err := d.db.QueryContext(ctx, queryV6, asn)
|
|
if err != nil {
|
|
return &asnInfo, allPrefixes, fmt.Errorf("failed to query IPv6 prefixes: %w", err)
|
|
}
|
|
defer func() { _ = rows6.Close() }()
|
|
|
|
for rows6.Next() {
|
|
var route LiveRoute
|
|
var lastUpdatedStr string
|
|
err := rows6.Scan(&route.Prefix, &route.MaskLength, &lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to scan IPv6 prefix row", "error", err, "asn", asn)
|
|
|
|
continue
|
|
}
|
|
// Parse the timestamp string
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05-07:00", lastUpdatedStr)
|
|
if err != nil {
|
|
// Try without timezone
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05", lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to parse timestamp", "error", err, "timestamp", lastUpdatedStr)
|
|
|
|
continue
|
|
}
|
|
}
|
|
route.OriginASN = asn
|
|
route.IPVersion = ipVersionV6
|
|
allPrefixes = append(allPrefixes, route)
|
|
}
|
|
|
|
return &asnInfo, allPrefixes, nil
|
|
}
|
|
|
|
// ASPeer describes one autonomous system seen peering with another AS, as
// returned by GetASPeersContext. Field order and JSON tags define the
// serialized shape and must stay stable.
type ASPeer struct {
	// ASN is the peer's autonomous system number.
	ASN int `json:"asn"`
	// Handle is the peer's short name from the asns table, or "" if the
	// peer has no asns record.
	Handle string `json:"handle"`
	// Description is the peer's description from the asns table, or "" if
	// the peer has no asns record.
	Description string `json:"description"`
	// FirstSeen is the first_seen timestamp of the peering record.
	FirstSeen time.Time `json:"first_seen"`
	// LastSeen is the last_seen timestamp of the peering record.
	LastSeen time.Time `json:"last_seen"`
}
|
|
|
|
// GetASPeers returns all ASes that peer with the given AS
|
|
func (d *Database) GetASPeers(asn int) ([]ASPeer, error) {
|
|
return d.GetASPeersContext(context.Background(), asn)
|
|
}
|
|
|
|
// GetASPeersContext returns all ASes that peer with the given AS with context support
|
|
func (d *Database) GetASPeersContext(ctx context.Context, asn int) ([]ASPeer, error) {
|
|
query := `
|
|
SELECT
|
|
CASE
|
|
WHEN p.as_a = ? THEN p.as_b
|
|
ELSE p.as_a
|
|
END as peer_asn,
|
|
COALESCE(a.handle, '') as handle,
|
|
COALESCE(a.description, '') as description,
|
|
p.first_seen,
|
|
p.last_seen
|
|
FROM peerings p
|
|
LEFT JOIN asns a ON a.asn = CASE
|
|
WHEN p.as_a = ? THEN p.as_b
|
|
ELSE p.as_a
|
|
END
|
|
WHERE p.as_a = ? OR p.as_b = ?
|
|
ORDER BY peer_asn
|
|
`
|
|
|
|
rows, err := d.db.QueryContext(ctx, query, asn, asn, asn, asn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query AS peers: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var peers []ASPeer
|
|
for rows.Next() {
|
|
var peer ASPeer
|
|
err := rows.Scan(&peer.ASN, &peer.Handle, &peer.Description, &peer.FirstSeen, &peer.LastSeen)
|
|
if err != nil {
|
|
d.logger.Error("Failed to scan peer row", "error", err, "asn", asn)
|
|
|
|
continue
|
|
}
|
|
peers = append(peers, peer)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating AS peers: %w", err)
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// GetPrefixDetails returns detailed information about a prefix
|
|
func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
|
return d.GetPrefixDetailsContext(context.Background(), prefix)
|
|
}
|
|
|
|
// GetPrefixDetailsContext returns detailed information about a prefix with context support
|
|
func (d *Database) GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error) {
|
|
// Determine if it's IPv4 or IPv6 by parsing the prefix
|
|
_, ipnet, err := net.ParseCIDR(prefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid prefix format: %w", err)
|
|
}
|
|
|
|
tableName := "live_routes_v4"
|
|
ipVersion := ipVersionV4
|
|
if ipnet.IP.To4() == nil {
|
|
tableName = "live_routes_v6"
|
|
ipVersion = ipVersionV6
|
|
}
|
|
|
|
//nolint:gosec // Table name is hardcoded based on IP version
|
|
query := fmt.Sprintf(`
|
|
SELECT lr.origin_asn, lr.peer_ip, lr.as_path, lr.next_hop, lr.last_updated,
|
|
a.handle, a.description
|
|
FROM %s lr
|
|
LEFT JOIN asns a ON a.asn = lr.origin_asn
|
|
WHERE lr.prefix = ?
|
|
ORDER BY lr.origin_asn, lr.peer_ip
|
|
`, tableName)
|
|
|
|
rows, err := d.db.QueryContext(ctx, query, prefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query prefix details: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var routes []LiveRoute
|
|
for rows.Next() {
|
|
var route LiveRoute
|
|
var pathJSON string
|
|
var handle, description sql.NullString
|
|
|
|
err := rows.Scan(
|
|
&route.OriginASN, &route.PeerIP, &pathJSON, &route.NextHop,
|
|
&route.LastUpdated, &handle, &description,
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Decode AS path
|
|
if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
|
|
route.ASPath = []int{}
|
|
}
|
|
|
|
route.Prefix = prefix
|
|
route.IPVersion = ipVersion
|
|
route.MaskLength, _ = ipnet.Mask.Size()
|
|
routes = append(routes, route)
|
|
}
|
|
|
|
if len(routes) == 0 {
|
|
return nil, fmt.Errorf("%w: %s", ErrNoRoute, prefix)
|
|
}
|
|
|
|
return routes, nil
|
|
}
|
|
|
|
// GetRandomPrefixesByLength returns a random sample of prefixes with the specified mask length
|
|
func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
|
return d.GetRandomPrefixesByLengthContext(context.Background(), maskLength, ipVersion, limit)
|
|
}
|
|
|
|
// GetRandomPrefixesByLengthContext returns a random sample of prefixes with context support
|
|
func (d *Database) GetRandomPrefixesByLengthContext(
|
|
ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
|
// Select unique prefixes with their most recent route information
|
|
tableName := "live_routes_v4"
|
|
if ipVersion == ipVersionV6 {
|
|
tableName = "live_routes_v6"
|
|
}
|
|
|
|
//nolint:gosec // Table name is hardcoded based on IP version
|
|
query := fmt.Sprintf(`
|
|
WITH unique_prefixes AS (
|
|
SELECT prefix, MAX(last_updated) as max_updated
|
|
FROM %s
|
|
WHERE mask_length = ?
|
|
GROUP BY prefix
|
|
ORDER BY RANDOM()
|
|
LIMIT ?
|
|
)
|
|
SELECT lr.prefix, lr.mask_length, lr.origin_asn, lr.as_path,
|
|
lr.peer_ip, lr.last_updated
|
|
FROM %s lr
|
|
INNER JOIN unique_prefixes up ON lr.prefix = up.prefix AND lr.last_updated = up.max_updated
|
|
WHERE lr.mask_length = ?
|
|
`, tableName, tableName)
|
|
|
|
rows, err := d.db.QueryContext(ctx, query, maskLength, limit, maskLength)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query random prefixes: %w", err)
|
|
}
|
|
defer func() {
|
|
_ = rows.Close()
|
|
}()
|
|
|
|
var routes []LiveRoute
|
|
for rows.Next() {
|
|
var route LiveRoute
|
|
var pathJSON string
|
|
err := rows.Scan(
|
|
&route.Prefix,
|
|
&route.MaskLength,
|
|
&route.OriginASN,
|
|
&pathJSON,
|
|
&route.PeerIP,
|
|
&route.LastUpdated,
|
|
)
|
|
// Set IP version based on which table we queried
|
|
route.IPVersion = ipVersion
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Decode AS path
|
|
if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
|
|
route.ASPath = []int{}
|
|
}
|
|
|
|
routes = append(routes, route)
|
|
}
|
|
|
|
return routes, nil
|
|
}
|