Fixed remaining references to a.number that should be a.asn after the column rename in the ASNs table.
1573 lines
43 KiB
Go
1573 lines
43 KiB
Go
// Package database provides SQLite storage for BGP routing data including ASNs, prefixes, announcements and peerings.
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
_ "embed"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.eeqj.de/sneak/routewatch/internal/config"
|
|
"git.eeqj.de/sneak/routewatch/internal/logger"
|
|
"git.eeqj.de/sneak/routewatch/pkg/asinfo"
|
|
"github.com/google/uuid"
|
|
_ "github.com/mattn/go-sqlite3" // CGO SQLite driver
|
|
)
|
|
|
|
// dbSchema is the complete SQLite schema, embedded from schema.sql at build
// time and executed by Initialize each time a Database is opened.
//
// IMPORTANT: NO schema changes are to be made outside of schema.sql
// We do NOT support migrations. All schema changes MUST be made in schema.sql only.
//
//go:embed schema.sql
var dbSchema string
|
|
|
|
const (
	// dirPermissions is the mode New uses when creating the state directory (rwxr-x---).
	dirPermissions = 0o750

	// IP version tags compared against the IPVersion field of routes,
	// prefixes and deletions.
	ipVersionV4 = 4
	ipVersionV6 = 6

	// Address-layout numbers, presumably for converting between net.IP and
	// numeric IPv4 values elsewhere in the package.
	ipv6Length = 16         // bytes in a 16-byte address
	ipv4Offset = 12         // offset of the IPv4 bytes inside a 16-byte address
	ipv4Bits   = 32         // bits in an IPv4 address
	maxIPv4    = 0xFFFFFFFF // largest IPv4 address as an unsigned 32-bit value
)
|
|
|
|
// Common errors
|
|
var (
|
|
// ErrInvalidIP is returned when an IP address is malformed
|
|
ErrInvalidIP = errors.New("invalid IP address")
|
|
// ErrNoRoute is returned when no route is found for an IP
|
|
ErrNoRoute = errors.New("no route found")
|
|
)
|
|
|
|
// Database manages the SQLite database connection and operations.
|
|
type Database struct {
|
|
db *sql.DB
|
|
logger *logger.Logger
|
|
path string
|
|
mu sync.Mutex
|
|
lockedAt time.Time
|
|
lockedBy string
|
|
}
|
|
|
|
// New opens the SQLite database at <state dir>/db.sqlite, creating the
// directory with dirPermissions if needed. It then configures the connection
// pool and applies the schema via Initialize.
//
// If pinging or initializing fails, the opened handle is closed before the
// error is returned, so its connections are not leaked.
func New(cfg *config.Config, logger *logger.Logger) (*Database, error) {
	dbPath := filepath.Join(cfg.GetStateDir(), "db.sqlite")

	// Log database path
	logger.Info("Opening database", "path", dbPath)

	// Ensure directory exists
	dir := filepath.Dir(dbPath)
	if err := os.MkdirAll(dir, dirPermissions); err != nil {
		return nil, fmt.Errorf("failed to create database directory: %w", err)
	}

	// A PRAGMA executed through *sql.DB reaches only the one pooled connection
	// that runs it. The per-connection settings Initialize relies on are
	// therefore also requested here; go-sqlite3 applies these DSN parameters
	// to every connection it opens.
	dsn := fmt.Sprintf(
		"file:%s?_journal_mode=WAL&_synchronous=OFF&_busy_timeout=5000",
		dbPath,
	)

	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	if err := db.Ping(); err != nil {
		// Best effort: the ping error is the one worth reporting.
		_ = db.Close()

		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	// Set connection pool parameters
	// Multiple connections allow concurrent reads while writes are serialized
	const maxConns = 10
	db.SetMaxOpenConns(maxConns)
	db.SetMaxIdleConns(maxConns)
	db.SetConnMaxLifetime(0) // connections are never closed because of their age

	database := &Database{db: db, logger: logger, path: dbPath}

	if err := database.Initialize(); err != nil {
		// Best effort: the initialization error is the one worth reporting.
		_ = db.Close()

		return nil, fmt.Errorf("failed to initialize database: %w", err)
	}

	return database, nil
}
|
|
|
|
// Initialize creates the database schema if it doesn't exist.
|
|
func (d *Database) Initialize() error {
|
|
// Set SQLite pragmas for performance
|
|
pragmas := []string{
|
|
"PRAGMA journal_mode=WAL", // Write-Ahead Logging
|
|
"PRAGMA synchronous=OFF", // Don't wait for disk writes
|
|
"PRAGMA cache_size=-3145728", // 3GB cache (upper limit for 2.4GB DB)
|
|
"PRAGMA temp_store=MEMORY", // Use memory for temp tables
|
|
"PRAGMA wal_checkpoint(TRUNCATE)", // Checkpoint and truncate WAL now
|
|
"PRAGMA busy_timeout=5000", // 5 second busy timeout
|
|
"PRAGMA analysis_limit=0", // Disable automatic ANALYZE
|
|
}
|
|
|
|
for _, pragma := range pragmas {
|
|
if err := d.exec(pragma); err != nil {
|
|
d.logger.Warn("Failed to set pragma", "pragma", pragma, "error", err)
|
|
}
|
|
}
|
|
|
|
err := d.exec(dbSchema)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Run VACUUM on startup to optimize database
|
|
d.logger.Info("Running VACUUM to optimize database (this may take a moment)")
|
|
if err := d.exec("VACUUM"); err != nil {
|
|
d.logger.Warn("Failed to VACUUM database", "error", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close closes the database connection.
|
|
func (d *Database) Close() error {
|
|
return d.db.Close()
|
|
}
|
|
|
|
// lock acquires d.mu on behalf of operation. It records the acquisition time
// and the caller's file:line so unlock can report who held the lock and for
// how long. A debug line is logged both before and after the mutex is
// obtained, which makes contention visible in debug logs.
func (d *Database) lock(operation string) {
	// Skip one frame so the reported site is the method that called lock.
	_, srcFile, srcLine, _ := runtime.Caller(1)
	site := fmt.Sprintf("%s:%d", filepath.Base(srcFile), srcLine)

	d.logger.Debug("Acquiring database lock", "operation", operation, "caller", site)

	d.mu.Lock()
	d.lockedAt = time.Now()
	d.lockedBy = fmt.Sprintf("%s (%s)", operation, site)

	d.logger.Debug("Database lock acquired", "operation", operation, "caller", site)
}
|
|
|
|
// unlock releases d.mu and logs who held it and for how long, in milliseconds.
// The holder bookkeeping is captured and reset while the mutex is still held,
// so the next holder cannot overwrite it before it is read.
func (d *Database) unlock() {
	heldFor := time.Since(d.lockedAt)
	holder := d.lockedBy

	d.lockedAt, d.lockedBy = time.Time{}, ""
	d.mu.Unlock()

	d.logger.Debug("Database lock released", "held_by", holder, "duration_ms", heldFor.Milliseconds())
}
|
|
|
|
// beginTx starts a new transaction with logging
|
|
func (d *Database) beginTx() (*loggingTx, error) {
|
|
tx, err := d.db.Begin()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &loggingTx{Tx: tx, logger: d.logger}, nil
|
|
}
|
|
|
|
// UpsertLiveRouteBatch inserts or updates multiple live routes in a single transaction
|
|
func (d *Database) UpsertLiveRouteBatch(routes []*LiveRoute) error {
|
|
if len(routes) == 0 {
|
|
return nil
|
|
}
|
|
|
|
d.lock("UpsertLiveRouteBatch")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// Prepare statements for both IPv4 and IPv6
|
|
queryV4 := `
|
|
INSERT INTO live_routes_v4 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
|
|
last_updated, ip_start, ip_end)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
|
|
mask_length = excluded.mask_length,
|
|
as_path = excluded.as_path,
|
|
next_hop = excluded.next_hop,
|
|
last_updated = excluded.last_updated,
|
|
ip_start = excluded.ip_start,
|
|
ip_end = excluded.ip_end
|
|
`
|
|
|
|
queryV6 := `
|
|
INSERT INTO live_routes_v6 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
|
|
last_updated)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
|
|
mask_length = excluded.mask_length,
|
|
as_path = excluded.as_path,
|
|
next_hop = excluded.next_hop,
|
|
last_updated = excluded.last_updated
|
|
`
|
|
|
|
stmtV4, err := tx.Prepare(queryV4)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV4.Close() }()
|
|
|
|
stmtV6, err := tx.Prepare(queryV6)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV6.Close() }()
|
|
|
|
for _, route := range routes {
|
|
// Encode AS path as JSON
|
|
pathJSON, err := json.Marshal(route.ASPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encode AS path: %w", err)
|
|
}
|
|
|
|
// Use appropriate statement based on IP version
|
|
if route.IPVersion == ipVersionV4 {
|
|
// IPv4 routes must have range values
|
|
if route.V4IPStart == nil || route.V4IPEnd == nil {
|
|
return fmt.Errorf("IPv4 route %s missing range values", route.Prefix)
|
|
}
|
|
|
|
_, err = stmtV4.Exec(
|
|
route.ID.String(),
|
|
route.Prefix,
|
|
route.MaskLength,
|
|
route.OriginASN,
|
|
route.PeerIP,
|
|
string(pathJSON),
|
|
route.NextHop,
|
|
route.LastUpdated,
|
|
*route.V4IPStart,
|
|
*route.V4IPEnd,
|
|
)
|
|
} else {
|
|
// IPv6 routes
|
|
_, err = stmtV6.Exec(
|
|
route.ID.String(),
|
|
route.Prefix,
|
|
route.MaskLength,
|
|
route.OriginASN,
|
|
route.PeerIP,
|
|
string(pathJSON),
|
|
route.NextHop,
|
|
route.LastUpdated,
|
|
)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to upsert route %s: %w", route.Prefix, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteLiveRouteBatch deletes multiple live routes in a single transaction
|
|
func (d *Database) DeleteLiveRouteBatch(deletions []LiveRouteDeletion) error {
|
|
if len(deletions) == 0 {
|
|
return nil
|
|
}
|
|
|
|
d.lock("DeleteLiveRouteBatch")
|
|
defer d.unlock()
|
|
|
|
tx, err := d.beginTx()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
|
|
d.logger.Error("Failed to rollback transaction", "error", err)
|
|
}
|
|
}()
|
|
|
|
// No longer need to separate deletions since we handle them in the loop below
|
|
|
|
// Prepare statements for both IPv4 and IPv6 tables
|
|
stmtV4WithOrigin, err := tx.Prepare(`DELETE FROM live_routes_v4 WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 delete with origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV4WithOrigin.Close() }()
|
|
|
|
stmtV4WithoutOrigin, err := tx.Prepare(`DELETE FROM live_routes_v4 WHERE prefix = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv4 delete without origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV4WithoutOrigin.Close() }()
|
|
|
|
stmtV6WithOrigin, err := tx.Prepare(`DELETE FROM live_routes_v6 WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 delete with origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV6WithOrigin.Close() }()
|
|
|
|
stmtV6WithoutOrigin, err := tx.Prepare(`DELETE FROM live_routes_v6 WHERE prefix = ? AND peer_ip = ?`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare IPv6 delete without origin statement: %w", err)
|
|
}
|
|
defer func() { _ = stmtV6WithoutOrigin.Close() }()
|
|
|
|
// Process deletions
|
|
for _, del := range deletions {
|
|
var stmt *sql.Stmt
|
|
|
|
// Select appropriate statement based on IP version and whether we have origin ASN
|
|
//nolint:nestif // Clear logic for selecting the right statement
|
|
if del.IPVersion == ipVersionV4 {
|
|
if del.OriginASN == 0 {
|
|
stmt = stmtV4WithoutOrigin
|
|
} else {
|
|
stmt = stmtV4WithOrigin
|
|
}
|
|
} else {
|
|
if del.OriginASN == 0 {
|
|
stmt = stmtV6WithoutOrigin
|
|
} else {
|
|
stmt = stmtV6WithOrigin
|
|
}
|
|
}
|
|
|
|
// Execute deletion
|
|
if del.OriginASN == 0 {
|
|
_, err = stmt.Exec(del.Prefix, del.PeerIP)
|
|
} else {
|
|
_, err = stmt.Exec(del.Prefix, del.OriginASN, del.PeerIP)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete route %s: %w", del.Prefix, err)
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdatePrefixesBatch records sightings of many prefixes inside one
// transaction. The map goes from prefix to sighting time.
//
// A prefix that already exists gets last_seen set to its timestamp. A new
// prefix is inserted with a fresh UUID and first_seen = last_seen = timestamp.
// The table is prefixes_v4 when detectIPVersion reports IPv4, and prefixes_v6
// otherwise. Any error aborts and rolls back the whole batch.
func (d *Database) UpdatePrefixesBatch(prefixes map[string]time.Time) error {
	if len(prefixes) == 0 {
		return nil
	}

	d.lock("UpdatePrefixesBatch")
	defer d.unlock()

	tx, err := d.beginTx()
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; that is expected.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			d.logger.Error("Failed to rollback transaction", "error", rbErr)
		}
	}()

	// Every successfully prepared statement is closed (newest first) on return.
	var opened []*sql.Stmt
	defer func() {
		for i := len(opened) - 1; i >= 0; i-- {
			_ = opened[i].Close()
		}
	}()

	prepare := func(query, desc string) (*sql.Stmt, error) {
		stmt, prepErr := tx.Prepare(query)
		if prepErr != nil {
			return nil, fmt.Errorf("failed to prepare %s statement: %w", desc, prepErr)
		}
		opened = append(opened, stmt)

		return stmt, nil
	}

	type stmtSet struct{ sel, upd, ins *sql.Stmt }
	var v4, v6 stmtSet

	if v4.sel, err = prepare("SELECT id FROM prefixes_v4 WHERE prefix = ?", "IPv4 select"); err != nil {
		return err
	}
	if v4.upd, err = prepare("UPDATE prefixes_v4 SET last_seen = ? WHERE prefix = ?", "IPv4 update"); err != nil {
		return err
	}
	if v4.ins, err = prepare("INSERT INTO prefixes_v4 (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)", "IPv4 insert"); err != nil {
		return err
	}
	if v6.sel, err = prepare("SELECT id FROM prefixes_v6 WHERE prefix = ?", "IPv6 select"); err != nil {
		return err
	}
	if v6.upd, err = prepare("UPDATE prefixes_v6 SET last_seen = ? WHERE prefix = ?", "IPv6 update"); err != nil {
		return err
	}
	if v6.ins, err = prepare("INSERT INTO prefixes_v6 (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)", "IPv6 insert"); err != nil {
		return err
	}

	for prefix, seenAt := range prefixes {
		set := v6
		if detectIPVersion(prefix) == ipVersionV4 {
			set = v4
		}

		var id string
		switch scanErr := set.sel.QueryRow(prefix).Scan(&id); scanErr {
		case nil:
			if _, execErr := set.upd.Exec(seenAt, prefix); execErr != nil {
				return fmt.Errorf("failed to update prefix %s: %w", prefix, execErr)
			}
		case sql.ErrNoRows:
			if _, execErr := set.ins.Exec(generateUUID().String(), prefix, seenAt, seenAt); execErr != nil {
				return fmt.Errorf("failed to insert prefix %s: %w", prefix, execErr)
			}
		default:
			return fmt.Errorf("failed to query prefix %s: %w", prefix, scanErr)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	return nil
}
|
|
|
|
// GetOrCreateASNBatch records sightings of many ASNs inside one transaction.
// The map goes from ASN number to sighting time.
//
// An ASN that already exists gets last_seen set to its timestamp. A new ASN
// is inserted with first_seen = last_seen = timestamp and, when asinfo knows
// the number, its handle and description. Any error aborts and rolls back
// the whole batch.
func (d *Database) GetOrCreateASNBatch(asns map[int]time.Time) error {
	if len(asns) == 0 {
		return nil
	}

	d.lock("GetOrCreateASNBatch")
	defer d.unlock()

	tx, err := d.beginTx()
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; that is expected.
	defer func() {
		if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
			d.logger.Error("Failed to rollback transaction", "error", err)
		}
	}()

	// Only existence matters here. Fetching the row's columns would just be
	// discarded, and scanning them could fail on unexpected stored values.
	existsStmt, err := tx.Prepare("SELECT EXISTS(SELECT 1 FROM asns WHERE asn = ?)")
	if err != nil {
		return fmt.Errorf("failed to prepare select statement: %w", err)
	}
	defer func() { _ = existsStmt.Close() }()

	updateStmt, err := tx.Prepare("UPDATE asns SET last_seen = ? WHERE asn = ?")
	if err != nil {
		return fmt.Errorf("failed to prepare update statement: %w", err)
	}
	defer func() { _ = updateStmt.Close() }()

	insertStmt, err := tx.Prepare(
		"INSERT INTO asns (asn, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?)")
	if err != nil {
		return fmt.Errorf("failed to prepare insert statement: %w", err)
	}
	defer func() { _ = insertStmt.Close() }()

	for number, timestamp := range asns {
		var exists bool
		if err := existsStmt.QueryRow(number).Scan(&exists); err != nil {
			return fmt.Errorf("failed to query ASN %d: %w", number, err)
		}

		if exists {
			if _, err := updateStmt.Exec(timestamp, number); err != nil {
				return fmt.Errorf("failed to update ASN %d: %w", number, err)
			}

			continue
		}

		// New ASN: enrich it from the bundled ASN info when available.
		var handle, description string
		if info, ok := asinfo.Get(number); ok {
			handle, description = info.Handle, info.Description
		}

		if _, err := insertStmt.Exec(number, handle, description, timestamp, timestamp); err != nil {
			return fmt.Errorf("failed to insert ASN %d: %w", number, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	return nil
}
|
|
|
|
// GetOrCreateASN returns the ASN row for number and sets its last_seen to
// timestamp.
//
// If no row exists, one is inserted with first_seen = last_seen = timestamp
// and, when asinfo knows the number, its handle and description. The lookup
// and the write run in one transaction while the database lock is held.
func (d *Database) GetOrCreateASN(number int, timestamp time.Time) (*ASN, error) {
	d.lock("GetOrCreateASN")
	defer d.unlock()

	tx, err := d.beginTx()
	if err != nil {
		return nil, err
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; that is expected.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			d.logger.Error("Failed to rollback transaction", "error", rbErr)
		}
	}()

	var (
		found               ASN
		handle, description sql.NullString
	)
	err = tx.QueryRow("SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?", number).
		Scan(&found.ASN, &handle, &description, &found.FirstSeen, &found.LastSeen)

	switch err {
	case nil:
		// Existing ASN: refresh last_seen and return the stored row.
		found.Handle = handle.String
		found.Description = description.String

		if _, err = tx.Exec("UPDATE asns SET last_seen = ? WHERE asn = ?", timestamp, number); err != nil {
			return nil, err
		}
		found.LastSeen = timestamp

		if err = tx.Commit(); err != nil {
			d.logger.Error("Failed to commit transaction for ASN update", "asn", number, "error", err)

			return nil, err
		}

		return &found, nil
	case sql.ErrNoRows:
		// Not stored yet; created below.
	default:
		return nil, err
	}

	created := ASN{ASN: number, FirstSeen: timestamp, LastSeen: timestamp}
	if info, ok := asinfo.Get(number); ok {
		created.Handle, created.Description = info.Handle, info.Description
	}

	_, err = tx.Exec("INSERT INTO asns (asn, handle, description, first_seen, last_seen) VALUES (?, ?, ?, ?, ?)",
		created.ASN, created.Handle, created.Description, created.FirstSeen, created.LastSeen)
	if err != nil {
		return nil, err
	}

	if err = tx.Commit(); err != nil {
		d.logger.Error("Failed to commit transaction for ASN creation", "asn", number, "error", err)

		return nil, err
	}

	return &created, nil
}
|
|
|
|
// GetOrCreatePrefix returns the stored row for prefix and sets its last_seen
// to timestamp.
//
// The row lives in prefixes_v6 when detectIPVersion reports IPv6, and in
// prefixes_v4 otherwise. A missing row is inserted with a fresh UUID and
// first_seen = last_seen = timestamp. If the stored id cannot be parsed as a
// UUID, an error is returned instead of a Prefix with a zero ID. The lookup
// and the write run in one transaction while the database lock is held.
func (d *Database) GetOrCreatePrefix(prefix string, timestamp time.Time) (*Prefix, error) {
	d.lock("GetOrCreatePrefix")
	defer d.unlock()

	tx, err := d.beginTx()
	if err != nil {
		return nil, err
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; that is expected.
	defer func() {
		if err := tx.Rollback(); err != nil && err != sql.ErrTxDone {
			d.logger.Error("Failed to rollback transaction", "error", err)
		}
	}()

	// Determine table based on IP version
	ipVersion := detectIPVersion(prefix)
	tableName := "prefixes_v4"
	if ipVersion == ipVersionV6 {
		tableName = "prefixes_v6"
	}

	var p Prefix
	var idStr string
	query := fmt.Sprintf("SELECT id, prefix, first_seen, last_seen FROM %s WHERE prefix = ?", tableName)
	err = tx.QueryRow(query, prefix).
		Scan(&idStr, &p.Prefix, &p.FirstSeen, &p.LastSeen)

	if err == nil {
		// Don't swallow a malformed stored id: a zero UUID would be handed to
		// callers as if it were the prefix's real identity.
		id, parseErr := uuid.Parse(idStr)
		if parseErr != nil {
			return nil, fmt.Errorf("invalid id %q for prefix %s: %w", idStr, prefix, parseErr)
		}
		p.ID = id
		p.IPVersion = ipVersion

		// Key the update on the id text exactly as stored. uuid.Parse also
		// accepts forms (upper-case, braces, urn:uuid:) whose canonical
		// String() would not match the stored value, so the row would be missed.
		updateQuery := fmt.Sprintf("UPDATE %s SET last_seen = ? WHERE id = ?", tableName)
		if _, err = tx.Exec(updateQuery, timestamp, idStr); err != nil {
			return nil, err
		}
		p.LastSeen = timestamp

		if err = tx.Commit(); err != nil {
			d.logger.Error("Failed to commit transaction for prefix update", "prefix", prefix, "error", err)

			return nil, err
		}

		return &p, nil
	}

	if err != sql.ErrNoRows {
		return nil, err
	}

	// Prefix doesn't exist, create it
	p = Prefix{
		ID:        generateUUID(),
		Prefix:    prefix,
		IPVersion: ipVersion,
		FirstSeen: timestamp,
		LastSeen:  timestamp,
	}
	insertQuery := fmt.Sprintf("INSERT INTO %s (id, prefix, first_seen, last_seen) VALUES (?, ?, ?, ?)", tableName)
	if _, err = tx.Exec(insertQuery, p.ID.String(), p.Prefix, p.FirstSeen, p.LastSeen); err != nil {
		return nil, err
	}

	if err = tx.Commit(); err != nil {
		d.logger.Error("Failed to commit transaction for prefix creation", "prefix", prefix, "error", err)

		return nil, err
	}

	return &p, nil
}
|
|
|
|
// RecordAnnouncement inserts a new BGP announcement or withdrawal into the database.
|
|
func (d *Database) RecordAnnouncement(announcement *Announcement) error {
|
|
d.lock("RecordAnnouncement")
|
|
defer d.unlock()
|
|
|
|
err := d.exec(`
|
|
INSERT INTO announcements (id, prefix_id, peer_asn, origin_asn, path, next_hop, timestamp, is_withdrawal)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
announcement.ID.String(), announcement.PrefixID.String(),
|
|
announcement.PeerASN, announcement.OriginASN,
|
|
announcement.Path, announcement.NextHop, announcement.Timestamp, announcement.IsWithdrawal)
|
|
|
|
return err
|
|
}
|
|
|
|
// RecordPeering records that asA and asB were seen peering at timestamp.
//
// Non-positive or identical ASNs are rejected before the database is
// touched. The pair is stored normalized, with the smaller ASN as as_a. An
// existing row has last_seen refreshed; otherwise a new row is inserted with
// first_seen = last_seen = timestamp.
func (d *Database) RecordPeering(asA, asB int, timestamp time.Time) error {
	switch {
	case asA <= 0 || asB <= 0:
		return fmt.Errorf("invalid ASN: asA=%d, asB=%d", asA, asB)
	case asA == asB:
		return fmt.Errorf("cannot create peering with same ASN: %d", asA)
	}

	lo, hi := asA, asB
	if lo > hi {
		lo, hi = hi, lo
	}

	d.lock("RecordPeering")
	defer d.unlock()

	tx, err := d.beginTx()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; that is expected.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			d.logger.Error("Failed to rollback transaction", "error", rbErr)
		}
	}()

	var exists bool
	err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM peerings WHERE as_a = ? AND as_b = ?)",
		lo, hi).Scan(&exists)
	if err != nil {
		return err
	}

	if exists {
		_, err = tx.Exec("UPDATE peerings SET last_seen = ? WHERE as_a = ? AND as_b = ?",
			timestamp, lo, hi)
	} else {
		_, err = tx.Exec(`
INSERT INTO peerings (id, as_a, as_b, first_seen, last_seen)
VALUES (?, ?, ?, ?, ?)`,
			generateUUID().String(), lo, hi, timestamp, timestamp)
	}
	if err != nil {
		return err
	}

	if err = tx.Commit(); err != nil {
		d.logger.Error("Failed to commit transaction for peering",
			"as_a", lo,
			"as_b", hi,
			"error", err,
		)

		return err
	}

	return nil
}
|
|
|
|
// UpdatePeerBatch creates or refreshes one BGP peer record per map entry
// inside a single transaction.
//
// A peer already present (matched by peer_ip) has peer_asn, last_seen and
// last_message_type overwritten. A new peer is inserted with a fresh UUID and
// first_seen = last_seen = the update's Timestamp. Any error aborts and rolls
// back the whole batch.
func (d *Database) UpdatePeerBatch(peers map[string]PeerUpdate) error {
	if len(peers) == 0 {
		return nil
	}

	d.lock("UpdatePeerBatch")
	defer d.unlock()

	tx, err := d.beginTx()
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; that is expected.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			d.logger.Error("Failed to rollback transaction", "error", rbErr)
		}
	}()

	existsStmt, err := tx.Prepare("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)")
	if err != nil {
		return fmt.Errorf("failed to prepare check statement: %w", err)
	}
	defer func() { _ = existsStmt.Close() }()

	refreshStmt, err := tx.Prepare(
		"UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?")
	if err != nil {
		return fmt.Errorf("failed to prepare update statement: %w", err)
	}
	defer func() { _ = refreshStmt.Close() }()

	addStmt, err := tx.Prepare(
		"INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)")
	if err != nil {
		return fmt.Errorf("failed to prepare insert statement: %w", err)
	}
	defer func() { _ = addStmt.Close() }()

	for _, u := range peers {
		var known bool
		if scanErr := existsStmt.QueryRow(u.PeerIP).Scan(&known); scanErr != nil {
			return fmt.Errorf("failed to check peer %s: %w", u.PeerIP, scanErr)
		}

		var execErr error
		if known {
			_, execErr = refreshStmt.Exec(u.PeerASN, u.Timestamp, u.MessageType, u.PeerIP)
		} else {
			_, execErr = addStmt.Exec(
				generateUUID().String(), u.PeerIP, u.PeerASN,
				u.Timestamp, u.Timestamp, u.MessageType)
		}
		if execErr != nil {
			return fmt.Errorf("failed to update peer %s: %w", u.PeerIP, execErr)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	return nil
}
|
|
|
|
// UpdatePeer creates or refreshes the BGP peer record for peerIP inside one
// transaction.
//
// An existing record has peer_asn, last_seen and last_message_type
// overwritten. A new one is inserted with a fresh UUID and
// first_seen = last_seen = timestamp. Commit failures are logged and returned.
func (d *Database) UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error {
	d.lock("UpdatePeer")
	defer d.unlock()

	tx, err := d.beginTx()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; that is expected.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			d.logger.Error("Failed to rollback transaction", "error", rbErr)
		}
	}()

	var known bool
	if err = tx.QueryRow("SELECT EXISTS(SELECT 1 FROM bgp_peers WHERE peer_ip = ?)", peerIP).Scan(&known); err != nil {
		return err
	}

	if known {
		_, err = tx.Exec(
			"UPDATE bgp_peers SET peer_asn = ?, last_seen = ?, last_message_type = ? WHERE peer_ip = ?",
			peerASN, timestamp, messageType, peerIP,
		)
	} else {
		_, err = tx.Exec(
			"INSERT INTO bgp_peers (id, peer_ip, peer_asn, first_seen, last_seen, last_message_type) VALUES (?, ?, ?, ?, ?, ?)",
			generateUUID().String(), peerIP, peerASN, timestamp, timestamp, messageType,
		)
	}
	if err != nil {
		return err
	}

	if err = tx.Commit(); err != nil {
		d.logger.Error("Failed to commit transaction for peer update",
			"peer_ip", peerIP,
			"peer_asn", peerASN,
			"error", err,
		)

		return err
	}

	return nil
}
|
|
|
|
// GetStats returns database statistics. It is GetStatsContext called with
// context.Background(), so it has no deadline or cancellation.
func (d *Database) GetStats() (Stats, error) {
	ctx := context.Background()

	return d.GetStatsContext(ctx)
}
|
|
|
|
// GetStatsContext gathers row counts, the size of the main database file,
// and the live-prefix mask-length distribution, all using ctx.
//
// A failed count query returns the error together with the partially filled
// Stats. Failures to stat the file or to compute the distribution are only
// logged; in those cases FileSizeBytes is 0 and the distribution fields are
// whatever GetPrefixDistributionContext returned.
func (d *Database) GetStatsContext(ctx context.Context) (Stats, error) {
	var stats Stats

	countRows := func(table string, dest interface{}) error {
		return d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table).Scan(dest)
	}

	type counter struct {
		table string
		dest  interface{}
	}

	// ASNs, then prefixes from both address-family tables.
	for _, c := range []counter{
		{"asns", &stats.ASNs},
		{"prefixes_v4", &stats.IPv4Prefixes},
		{"prefixes_v6", &stats.IPv6Prefixes},
	} {
		if err := countRows(c.table, c.dest); err != nil {
			return stats, err
		}
	}
	stats.Prefixes = stats.IPv4Prefixes + stats.IPv6Prefixes

	if err := countRows("peerings", &stats.Peerings); err != nil {
		return stats, err
	}
	if err := countRows("bgp_peers", &stats.Peers); err != nil {
		return stats, err
	}

	// File size is best effort. Only the main file is measured, not any -wal/-shm files.
	if info, err := os.Stat(d.path); err != nil {
		d.logger.Warn("Failed to get database file size", "error", err)
		stats.FileSizeBytes = 0
	} else {
		stats.FileSizeBytes = info.Size()
	}

	var v4Routes, v6Routes int
	if err := countRows("live_routes_v4", &v4Routes); err != nil {
		return stats, fmt.Errorf("failed to count IPv4 routes: %w", err)
	}
	if err := countRows("live_routes_v6", &v6Routes); err != nil {
		return stats, fmt.Errorf("failed to count IPv6 routes: %w", err)
	}
	stats.LiveRoutes = v4Routes + v6Routes

	var distErr error
	stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, distErr = d.GetPrefixDistributionContext(ctx)
	if distErr != nil {
		// Log but don't fail
		d.logger.Warn("Failed to get prefix distribution", "error", distErr)
	}

	return stats, nil
}
|
|
|
|
// UpsertLiveRoute inserts route, or updates the existing row with the same
// (prefix, origin_asn, peer_ip). It runs as a single statement outside any
// explicit transaction, while the database lock is held.
//
// Routes with IPVersion == 4 go to live_routes_v4 together with their numeric
// range; a nil V4IPStart or V4IPEnd is stored as SQL NULL. Every other route
// goes to live_routes_v6, matching UpsertLiveRouteBatch. The table and the
// column list are chosen by the same condition, so they can never disagree.
func (d *Database) UpsertLiveRoute(route *LiveRoute) error {
	d.lock("UpsertLiveRoute")
	defer d.unlock()

	// The AS path is stored as a JSON array.
	pathJSON, err := json.Marshal(route.ASPath)
	if err != nil {
		return fmt.Errorf("failed to encode AS path: %w", err)
	}

	if route.IPVersion != ipVersionV4 {
		_, err = d.db.Exec(`
INSERT INTO live_routes_v6 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
last_updated)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
mask_length = excluded.mask_length,
as_path = excluded.as_path,
next_hop = excluded.next_hop,
last_updated = excluded.last_updated
`,
			route.ID.String(),
			route.Prefix,
			route.MaskLength,
			route.OriginASN,
			route.PeerIP,
			string(pathJSON),
			route.NextHop,
			route.LastUpdated,
		)

		return err
	}

	// Leave the interface values nil (SQL NULL) when the range is unknown.
	var v4Start, v4End interface{}
	if route.V4IPStart != nil {
		v4Start = *route.V4IPStart
	}
	if route.V4IPEnd != nil {
		v4End = *route.V4IPEnd
	}

	_, err = d.db.Exec(`
INSERT INTO live_routes_v4 (id, prefix, mask_length, origin_asn, peer_ip, as_path, next_hop,
last_updated, ip_start, ip_end)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET
mask_length = excluded.mask_length,
as_path = excluded.as_path,
next_hop = excluded.next_hop,
last_updated = excluded.last_updated,
ip_start = excluded.ip_start,
ip_end = excluded.ip_end
`,
		route.ID.String(),
		route.Prefix,
		route.MaskLength,
		route.OriginASN,
		route.PeerIP,
		string(pathJSON),
		route.NextHop,
		route.LastUpdated,
		v4Start,
		v4End,
	)

	return err
}
|
|
|
|
// DeleteLiveRoute deletes a live route
|
|
// If originASN is 0, deletes all routes for the prefix/peer combination
|
|
func (d *Database) DeleteLiveRoute(prefix string, originASN int, peerIP string) error {
|
|
d.lock("DeleteLiveRoute")
|
|
defer d.unlock()
|
|
|
|
// Determine table based on prefix IP version
|
|
_, ipnet, err := net.ParseCIDR(prefix)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid prefix format: %w", err)
|
|
}
|
|
|
|
tableName := "live_routes_v4"
|
|
if ipnet.IP.To4() == nil {
|
|
tableName = "live_routes_v6"
|
|
}
|
|
|
|
var query string
|
|
if originASN == 0 {
|
|
// Delete all routes for this prefix from this peer
|
|
query = fmt.Sprintf(`DELETE FROM %s WHERE prefix = ? AND peer_ip = ?`, tableName)
|
|
_, err = d.db.Exec(query, prefix, peerIP)
|
|
} else {
|
|
// Delete specific route
|
|
query = fmt.Sprintf(`DELETE FROM %s WHERE prefix = ? AND origin_asn = ? AND peer_ip = ?`, tableName)
|
|
_, err = d.db.Exec(query, prefix, originASN, peerIP)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// GetPrefixDistribution returns the distribution of unique prefixes by mask length
|
|
func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
|
return d.GetPrefixDistributionContext(context.Background())
|
|
}
|
|
|
|
// GetPrefixDistributionContext returns the distribution of unique prefixes by mask length with context support
|
|
func (d *Database) GetPrefixDistributionContext(ctx context.Context) (
|
|
ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
|
// IPv4 distribution - count unique prefixes from v4 table
|
|
query := `
|
|
SELECT mask_length, COUNT(DISTINCT prefix) as count
|
|
FROM live_routes_v4
|
|
GROUP BY mask_length
|
|
ORDER BY mask_length
|
|
`
|
|
rows4, err := d.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to query IPv4 distribution: %w", err)
|
|
}
|
|
defer func() {
|
|
if rows4 != nil {
|
|
_ = rows4.Close()
|
|
}
|
|
}()
|
|
|
|
for rows4.Next() {
|
|
var dist PrefixDistribution
|
|
if err := rows4.Scan(&dist.MaskLength, &dist.Count); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to scan IPv4 distribution: %w", err)
|
|
}
|
|
ipv4 = append(ipv4, dist)
|
|
}
|
|
|
|
// IPv6 distribution - count unique prefixes from v6 table
|
|
query = `
|
|
SELECT mask_length, COUNT(DISTINCT prefix) as count
|
|
FROM live_routes_v6
|
|
GROUP BY mask_length
|
|
ORDER BY mask_length
|
|
`
|
|
rows6, err := d.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to query IPv6 distribution: %w", err)
|
|
}
|
|
defer func() {
|
|
if rows6 != nil {
|
|
_ = rows6.Close()
|
|
}
|
|
}()
|
|
|
|
for rows6.Next() {
|
|
var dist PrefixDistribution
|
|
if err := rows6.Scan(&dist.MaskLength, &dist.Count); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to scan IPv6 distribution: %w", err)
|
|
}
|
|
ipv6 = append(ipv6, dist)
|
|
}
|
|
|
|
return ipv4, ipv6, nil
|
|
}
|
|
|
|
// GetLiveRouteCounts returns the count of IPv4 and IPv6 routes
|
|
func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
|
return d.GetLiveRouteCountsContext(context.Background())
|
|
}
|
|
|
|
// GetLiveRouteCountsContext returns the count of IPv4 and IPv6 routes with context support
|
|
func (d *Database) GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error) {
|
|
// Get IPv4 count from dedicated table
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v4").Scan(&ipv4Count)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("failed to count IPv4 routes: %w", err)
|
|
}
|
|
|
|
// Get IPv6 count from dedicated table
|
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes_v6").Scan(&ipv6Count)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("failed to count IPv6 routes: %w", err)
|
|
}
|
|
|
|
return ipv4Count, ipv6Count, nil
|
|
}
|
|
|
|
// GetASInfoForIP returns AS information for the given IP address
|
|
func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
|
return d.GetASInfoForIPContext(context.Background(), ip)
|
|
}
|
|
|
|
// GetASInfoForIPContext returns AS information for the given IP address with context support
|
|
func (d *Database) GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error) {
|
|
// Parse the IP to validate it
|
|
parsedIP := net.ParseIP(ip)
|
|
if parsedIP == nil {
|
|
return nil, fmt.Errorf("%w: %s", ErrInvalidIP, ip)
|
|
}
|
|
|
|
// Determine IP version
|
|
ipVersion := ipVersionV4
|
|
ipv4 := parsedIP.To4()
|
|
if ipv4 == nil {
|
|
ipVersion = ipVersionV6
|
|
}
|
|
|
|
// For IPv4, use optimized range query
|
|
if ipVersion == ipVersionV4 {
|
|
// Convert IP to 32-bit unsigned integer
|
|
ipUint := ipToUint32(ipv4)
|
|
|
|
query := `
|
|
SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
|
|
FROM live_routes_v4 lr
|
|
LEFT JOIN asns a ON a.asn = lr.origin_asn
|
|
WHERE lr.ip_start <= ? AND lr.ip_end >= ?
|
|
ORDER BY lr.mask_length DESC
|
|
LIMIT 1
|
|
`
|
|
|
|
var prefix string
|
|
var maskLength, originASN int
|
|
var lastUpdated time.Time
|
|
var handle, description sql.NullString
|
|
|
|
err := d.db.QueryRowContext(ctx, query, ipUint, ipUint).Scan(
|
|
&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to query routes: %w", err)
|
|
}
|
|
|
|
age := time.Since(lastUpdated).Round(time.Second).String()
|
|
|
|
return &ASInfo{
|
|
ASN: originASN,
|
|
Handle: handle.String,
|
|
Description: description.String,
|
|
Prefix: prefix,
|
|
LastUpdated: lastUpdated,
|
|
Age: age,
|
|
}, nil
|
|
}
|
|
|
|
// For IPv6, use the original method since we don't have range optimization
|
|
query := `
|
|
SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, lr.last_updated, a.handle, a.description
|
|
FROM live_routes_v6 lr
|
|
LEFT JOIN asns a ON a.asn = lr.origin_asn
|
|
ORDER BY lr.mask_length DESC
|
|
`
|
|
|
|
rows, err := d.db.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query routes: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
// Find the most specific matching prefix
|
|
var bestMatch struct {
|
|
prefix string
|
|
maskLength int
|
|
originASN int
|
|
lastUpdated time.Time
|
|
handle sql.NullString
|
|
description sql.NullString
|
|
}
|
|
bestMaskLength := -1
|
|
|
|
for rows.Next() {
|
|
var prefix string
|
|
var maskLength, originASN int
|
|
var lastUpdated time.Time
|
|
var handle, description sql.NullString
|
|
|
|
if err := rows.Scan(&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Parse the prefix CIDR
|
|
_, ipNet, err := net.ParseCIDR(prefix)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Check if the IP is in this prefix
|
|
if ipNet.Contains(parsedIP) && maskLength > bestMaskLength {
|
|
bestMatch.prefix = prefix
|
|
bestMatch.maskLength = maskLength
|
|
bestMatch.originASN = originASN
|
|
bestMatch.lastUpdated = lastUpdated
|
|
bestMatch.handle = handle
|
|
bestMatch.description = description
|
|
bestMaskLength = maskLength
|
|
}
|
|
}
|
|
|
|
if bestMaskLength == -1 {
|
|
return nil, fmt.Errorf("%w for IP %s", ErrNoRoute, ip)
|
|
}
|
|
|
|
age := time.Since(bestMatch.lastUpdated).Round(time.Second).String()
|
|
|
|
return &ASInfo{
|
|
ASN: bestMatch.originASN,
|
|
Handle: bestMatch.handle.String,
|
|
Description: bestMatch.description.String,
|
|
Prefix: bestMatch.prefix,
|
|
LastUpdated: bestMatch.lastUpdated,
|
|
Age: age,
|
|
}, nil
|
|
}
|
|
|
|
// ipToUint32 converts an IPv4 address to a 32-bit unsigned integer
|
|
func ipToUint32(ip net.IP) uint32 {
|
|
if len(ip) == ipv6Length {
|
|
// Convert to 4-byte representation
|
|
ip = ip[ipv4Offset:ipv6Length]
|
|
}
|
|
|
|
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
|
}
|
|
|
|
// CalculateIPv4Range calculates the start and end IP addresses for an IPv4 CIDR block
|
|
func CalculateIPv4Range(cidr string) (start, end uint32, err error) {
|
|
_, ipNet, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
|
|
// Get the network address (start of range)
|
|
ip := ipNet.IP.To4()
|
|
if ip == nil {
|
|
return 0, 0, fmt.Errorf("not an IPv4 address")
|
|
}
|
|
|
|
start = ipToUint32(ip)
|
|
|
|
// Calculate the end of the range
|
|
ones, bits := ipNet.Mask.Size()
|
|
hostBits := bits - ones
|
|
if hostBits >= ipv4Bits {
|
|
// Special case for /0 - entire IPv4 space
|
|
end = maxIPv4
|
|
} else {
|
|
// Safe to convert since we checked hostBits < 32
|
|
//nolint:gosec // hostBits is guaranteed to be < 32 from the check above
|
|
end = start | ((1 << uint(hostBits)) - 1)
|
|
}
|
|
|
|
return start, end, nil
|
|
}
|
|
|
|
// GetASDetails returns detailed information about an AS including prefixes
|
|
func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
|
return d.GetASDetailsContext(context.Background(), asn)
|
|
}
|
|
|
|
// GetASDetailsContext returns detailed information about an AS including prefixes with context support
|
|
func (d *Database) GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error) {
|
|
// Get AS information
|
|
var asnInfo ASN
|
|
var handle, description sql.NullString
|
|
err := d.db.QueryRowContext(ctx,
|
|
"SELECT asn, handle, description, first_seen, last_seen FROM asns WHERE asn = ?",
|
|
asn,
|
|
).Scan(&asnInfo.ASN, &handle, &description, &asnInfo.FirstSeen, &asnInfo.LastSeen)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil, fmt.Errorf("%w: AS%d", ErrNoRoute, asn)
|
|
}
|
|
|
|
return nil, nil, fmt.Errorf("failed to query AS: %w", err)
|
|
}
|
|
|
|
asnInfo.Handle = handle.String
|
|
asnInfo.Description = description.String
|
|
|
|
// Get prefixes announced by this AS from both tables
|
|
var allPrefixes []LiveRoute
|
|
|
|
// Query IPv4 prefixes
|
|
queryV4 := `
|
|
SELECT prefix, mask_length, MAX(last_updated) as last_updated
|
|
FROM live_routes_v4
|
|
WHERE origin_asn = ?
|
|
GROUP BY prefix, mask_length
|
|
`
|
|
|
|
rows4, err := d.db.QueryContext(ctx, queryV4, asn)
|
|
if err != nil {
|
|
return &asnInfo, nil, fmt.Errorf("failed to query IPv4 prefixes: %w", err)
|
|
}
|
|
defer func() { _ = rows4.Close() }()
|
|
|
|
for rows4.Next() {
|
|
var route LiveRoute
|
|
var lastUpdatedStr string
|
|
err := rows4.Scan(&route.Prefix, &route.MaskLength, &lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to scan IPv4 prefix row", "error", err, "asn", asn)
|
|
|
|
continue
|
|
}
|
|
// Parse the timestamp string
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05-07:00", lastUpdatedStr)
|
|
if err != nil {
|
|
// Try without timezone
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05", lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to parse timestamp", "error", err, "timestamp", lastUpdatedStr)
|
|
|
|
continue
|
|
}
|
|
}
|
|
route.OriginASN = asn
|
|
route.IPVersion = ipVersionV4
|
|
allPrefixes = append(allPrefixes, route)
|
|
}
|
|
|
|
// Query IPv6 prefixes
|
|
queryV6 := `
|
|
SELECT prefix, mask_length, MAX(last_updated) as last_updated
|
|
FROM live_routes_v6
|
|
WHERE origin_asn = ?
|
|
GROUP BY prefix, mask_length
|
|
`
|
|
|
|
rows6, err := d.db.QueryContext(ctx, queryV6, asn)
|
|
if err != nil {
|
|
return &asnInfo, allPrefixes, fmt.Errorf("failed to query IPv6 prefixes: %w", err)
|
|
}
|
|
defer func() { _ = rows6.Close() }()
|
|
|
|
for rows6.Next() {
|
|
var route LiveRoute
|
|
var lastUpdatedStr string
|
|
err := rows6.Scan(&route.Prefix, &route.MaskLength, &lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to scan IPv6 prefix row", "error", err, "asn", asn)
|
|
|
|
continue
|
|
}
|
|
// Parse the timestamp string
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05-07:00", lastUpdatedStr)
|
|
if err != nil {
|
|
// Try without timezone
|
|
route.LastUpdated, err = time.Parse("2006-01-02 15:04:05", lastUpdatedStr)
|
|
if err != nil {
|
|
d.logger.Error("Failed to parse timestamp", "error", err, "timestamp", lastUpdatedStr)
|
|
|
|
continue
|
|
}
|
|
}
|
|
route.OriginASN = asn
|
|
route.IPVersion = ipVersionV6
|
|
allPrefixes = append(allPrefixes, route)
|
|
}
|
|
|
|
return &asnInfo, allPrefixes, nil
|
|
}
|
|
|
|
// GetPrefixDetails returns detailed information about a prefix
|
|
func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
|
return d.GetPrefixDetailsContext(context.Background(), prefix)
|
|
}
|
|
|
|
// GetPrefixDetailsContext returns detailed information about a prefix with context support
|
|
func (d *Database) GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error) {
|
|
// Determine if it's IPv4 or IPv6 by parsing the prefix
|
|
_, ipnet, err := net.ParseCIDR(prefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid prefix format: %w", err)
|
|
}
|
|
|
|
tableName := "live_routes_v4"
|
|
ipVersion := ipVersionV4
|
|
if ipnet.IP.To4() == nil {
|
|
tableName = "live_routes_v6"
|
|
ipVersion = ipVersionV6
|
|
}
|
|
|
|
//nolint:gosec // Table name is hardcoded based on IP version
|
|
query := fmt.Sprintf(`
|
|
SELECT lr.origin_asn, lr.peer_ip, lr.as_path, lr.next_hop, lr.last_updated,
|
|
a.handle, a.description
|
|
FROM %s lr
|
|
LEFT JOIN asns a ON a.asn = lr.origin_asn
|
|
WHERE lr.prefix = ?
|
|
ORDER BY lr.origin_asn, lr.peer_ip
|
|
`, tableName)
|
|
|
|
rows, err := d.db.QueryContext(ctx, query, prefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query prefix details: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var routes []LiveRoute
|
|
for rows.Next() {
|
|
var route LiveRoute
|
|
var pathJSON string
|
|
var handle, description sql.NullString
|
|
|
|
err := rows.Scan(
|
|
&route.OriginASN, &route.PeerIP, &pathJSON, &route.NextHop,
|
|
&route.LastUpdated, &handle, &description,
|
|
)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Decode AS path
|
|
if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
|
|
route.ASPath = []int{}
|
|
}
|
|
|
|
route.Prefix = prefix
|
|
route.IPVersion = ipVersion
|
|
route.MaskLength, _ = ipnet.Mask.Size()
|
|
routes = append(routes, route)
|
|
}
|
|
|
|
if len(routes) == 0 {
|
|
return nil, fmt.Errorf("%w: %s", ErrNoRoute, prefix)
|
|
}
|
|
|
|
return routes, nil
|
|
}
|
|
|
|
// GetRandomPrefixesByLength returns a random sample of prefixes with the specified mask length
|
|
func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
|
return d.GetRandomPrefixesByLengthContext(context.Background(), maskLength, ipVersion, limit)
|
|
}
|
|
|
|
// GetRandomPrefixesByLengthContext returns a random sample of prefixes with context support
|
|
func (d *Database) GetRandomPrefixesByLengthContext(
|
|
ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
|
// Select unique prefixes with their most recent route information
|
|
tableName := "live_routes_v4"
|
|
if ipVersion == ipVersionV6 {
|
|
tableName = "live_routes_v6"
|
|
}
|
|
|
|
//nolint:gosec // Table name is hardcoded based on IP version
|
|
query := fmt.Sprintf(`
|
|
WITH unique_prefixes AS (
|
|
SELECT prefix, MAX(last_updated) as max_updated
|
|
FROM %s
|
|
WHERE mask_length = ?
|
|
GROUP BY prefix
|
|
ORDER BY RANDOM()
|
|
LIMIT ?
|
|
)
|
|
SELECT lr.prefix, lr.mask_length, lr.origin_asn, lr.as_path,
|
|
lr.peer_ip, lr.last_updated
|
|
FROM %s lr
|
|
INNER JOIN unique_prefixes up ON lr.prefix = up.prefix AND lr.last_updated = up.max_updated
|
|
WHERE lr.mask_length = ?
|
|
`, tableName, tableName)
|
|
|
|
rows, err := d.db.QueryContext(ctx, query, maskLength, limit, maskLength)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query random prefixes: %w", err)
|
|
}
|
|
defer func() {
|
|
_ = rows.Close()
|
|
}()
|
|
|
|
var routes []LiveRoute
|
|
for rows.Next() {
|
|
var route LiveRoute
|
|
var pathJSON string
|
|
err := rows.Scan(
|
|
&route.Prefix,
|
|
&route.MaskLength,
|
|
&route.OriginASN,
|
|
&pathJSON,
|
|
&route.PeerIP,
|
|
&route.LastUpdated,
|
|
)
|
|
// Set IP version based on which table we queried
|
|
route.IPVersion = ipVersion
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Decode AS path
|
|
if err := json.Unmarshal([]byte(pathJSON), &route.ASPath); err != nil {
|
|
route.ASPath = []int{}
|
|
}
|
|
|
|
routes = append(routes, route)
|
|
}
|
|
|
|
return routes, nil
|
|
}
|