Add context cancellation support to database operations
- Add context-aware versions of all read operations in the database - Update handlers to use context from HTTP requests - Allows database queries to be cancelled when HTTP requests timeout - Prevents database operations from continuing after client disconnects
This commit is contained in:
parent
0196251906
commit
e0a4c8642e
@ -2,6 +2,7 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -747,41 +748,48 @@ func (d *Database) UpdatePeer(peerIP string, peerASN int, messageType string, ti
|
|||||||
|
|
||||||
// GetStats returns database statistics
|
// GetStats returns database statistics
|
||||||
func (d *Database) GetStats() (Stats, error) {
|
func (d *Database) GetStats() (Stats, error) {
|
||||||
|
return d.GetStatsContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStatsContext returns database statistics with context support
|
||||||
|
func (d *Database) GetStatsContext(ctx context.Context) (Stats, error) {
|
||||||
var stats Stats
|
var stats Stats
|
||||||
|
|
||||||
// Count ASNs
|
// Count ASNs
|
||||||
err := d.queryRow("SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
|
err := d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count prefixes
|
// Count prefixes
|
||||||
err = d.queryRow("SELECT COUNT(*) FROM prefixes").Scan(&stats.Prefixes)
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM prefixes").Scan(&stats.Prefixes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count IPv4 and IPv6 prefixes
|
// Count IPv4 and IPv6 prefixes
|
||||||
const ipVersionV4 = 4
|
const ipVersionV4 = 4
|
||||||
err = d.queryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV4).Scan(&stats.IPv4Prefixes)
|
err = d.db.QueryRowContext(ctx,
|
||||||
|
"SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV4).Scan(&stats.IPv4Prefixes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const ipVersionV6 = 6
|
const ipVersionV6 = 6
|
||||||
err = d.queryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV6).Scan(&stats.IPv6Prefixes)
|
err = d.db.QueryRowContext(ctx,
|
||||||
|
"SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV6).Scan(&stats.IPv6Prefixes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count peerings
|
// Count peerings
|
||||||
err = d.queryRow("SELECT COUNT(*) FROM peerings").Scan(&stats.Peerings)
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM peerings").Scan(&stats.Peerings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count peers
|
// Count peers
|
||||||
err = d.queryRow("SELECT COUNT(*) FROM bgp_peers").Scan(&stats.Peers)
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM bgp_peers").Scan(&stats.Peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, err
|
return stats, err
|
||||||
}
|
}
|
||||||
@ -796,13 +804,13 @@ func (d *Database) GetStats() (Stats, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get live routes count
|
// Get live routes count
|
||||||
err = d.db.QueryRow("SELECT COUNT(*) FROM live_routes").Scan(&stats.LiveRoutes)
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes").Scan(&stats.LiveRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return stats, fmt.Errorf("failed to count live routes: %w", err)
|
return stats, fmt.Errorf("failed to count live routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get prefix distribution
|
// Get prefix distribution
|
||||||
stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, err = d.GetPrefixDistribution()
|
stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, err = d.GetPrefixDistributionContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log but don't fail
|
// Log but don't fail
|
||||||
d.logger.Warn("Failed to get prefix distribution", "error", err)
|
d.logger.Warn("Failed to get prefix distribution", "error", err)
|
||||||
@ -886,6 +894,12 @@ func (d *Database) DeleteLiveRoute(prefix string, originASN int, peerIP string)
|
|||||||
|
|
||||||
// GetPrefixDistribution returns the distribution of unique prefixes by mask length
|
// GetPrefixDistribution returns the distribution of unique prefixes by mask length
|
||||||
func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
||||||
|
return d.GetPrefixDistributionContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrefixDistributionContext returns the distribution of unique prefixes by mask length with context support
|
||||||
|
func (d *Database) GetPrefixDistributionContext(ctx context.Context) (
|
||||||
|
ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
||||||
// IPv4 distribution - count unique prefixes, not routes
|
// IPv4 distribution - count unique prefixes, not routes
|
||||||
query := `
|
query := `
|
||||||
SELECT mask_length, COUNT(DISTINCT prefix) as count
|
SELECT mask_length, COUNT(DISTINCT prefix) as count
|
||||||
@ -894,7 +908,7 @@ func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []Pr
|
|||||||
GROUP BY mask_length
|
GROUP BY mask_length
|
||||||
ORDER BY mask_length
|
ORDER BY mask_length
|
||||||
`
|
`
|
||||||
rows, err := d.db.Query(query)
|
rows, err := d.db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to query IPv4 distribution: %w", err)
|
return nil, nil, fmt.Errorf("failed to query IPv4 distribution: %w", err)
|
||||||
}
|
}
|
||||||
@ -916,7 +930,7 @@ func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []Pr
|
|||||||
GROUP BY mask_length
|
GROUP BY mask_length
|
||||||
ORDER BY mask_length
|
ORDER BY mask_length
|
||||||
`
|
`
|
||||||
rows, err = d.db.Query(query)
|
rows, err = d.db.QueryContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to query IPv6 distribution: %w", err)
|
return nil, nil, fmt.Errorf("failed to query IPv6 distribution: %w", err)
|
||||||
}
|
}
|
||||||
@ -935,14 +949,19 @@ func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []Pr
|
|||||||
|
|
||||||
// GetLiveRouteCounts returns the count of IPv4 and IPv6 routes
|
// GetLiveRouteCounts returns the count of IPv4 and IPv6 routes
|
||||||
func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
||||||
|
return d.GetLiveRouteCountsContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLiveRouteCountsContext returns the count of IPv4 and IPv6 routes with context support
|
||||||
|
func (d *Database) GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error) {
|
||||||
// Get IPv4 count
|
// Get IPv4 count
|
||||||
err = d.db.QueryRow("SELECT COUNT(*) FROM live_routes WHERE ip_version = 4").Scan(&ipv4Count)
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes WHERE ip_version = 4").Scan(&ipv4Count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, fmt.Errorf("failed to count IPv4 routes: %w", err)
|
return 0, 0, fmt.Errorf("failed to count IPv4 routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get IPv6 count
|
// Get IPv6 count
|
||||||
err = d.db.QueryRow("SELECT COUNT(*) FROM live_routes WHERE ip_version = 6").Scan(&ipv6Count)
|
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes WHERE ip_version = 6").Scan(&ipv6Count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, fmt.Errorf("failed to count IPv6 routes: %w", err)
|
return 0, 0, fmt.Errorf("failed to count IPv6 routes: %w", err)
|
||||||
}
|
}
|
||||||
@ -952,6 +971,11 @@ func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
|||||||
|
|
||||||
// GetASInfoForIP returns AS information for the given IP address
|
// GetASInfoForIP returns AS information for the given IP address
|
||||||
func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
||||||
|
return d.GetASInfoForIPContext(context.Background(), ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetASInfoForIPContext returns AS information for the given IP address with context support
|
||||||
|
func (d *Database) GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error) {
|
||||||
// Parse the IP to validate it
|
// Parse the IP to validate it
|
||||||
parsedIP := net.ParseIP(ip)
|
parsedIP := net.ParseIP(ip)
|
||||||
if parsedIP == nil {
|
if parsedIP == nil {
|
||||||
@ -984,7 +1008,7 @@ func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
|||||||
var lastUpdated time.Time
|
var lastUpdated time.Time
|
||||||
var handle, description sql.NullString
|
var handle, description sql.NullString
|
||||||
|
|
||||||
err := d.db.QueryRow(query, ipVersionV4, ipUint, ipUint).Scan(
|
err := d.db.QueryRowContext(ctx, query, ipVersionV4, ipUint, ipUint).Scan(
|
||||||
&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description)
|
&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
@ -1015,7 +1039,7 @@ func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
|||||||
ORDER BY lr.mask_length DESC
|
ORDER BY lr.mask_length DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := d.db.Query(query, ipVersionV6)
|
rows, err := d.db.QueryContext(ctx, query, ipVersionV6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query routes: %w", err)
|
return nil, fmt.Errorf("failed to query routes: %w", err)
|
||||||
}
|
}
|
||||||
@ -1118,11 +1142,16 @@ func CalculateIPv4Range(cidr string) (start, end uint32, err error) {
|
|||||||
|
|
||||||
// GetASDetails returns detailed information about an AS including prefixes
|
// GetASDetails returns detailed information about an AS including prefixes
|
||||||
func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
||||||
|
return d.GetASDetailsContext(context.Background(), asn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetASDetailsContext returns detailed information about an AS including prefixes with context support
|
||||||
|
func (d *Database) GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error) {
|
||||||
// Get AS information
|
// Get AS information
|
||||||
var asnInfo ASN
|
var asnInfo ASN
|
||||||
var idStr string
|
var idStr string
|
||||||
var handle, description sql.NullString
|
var handle, description sql.NullString
|
||||||
err := d.db.QueryRow(
|
err := d.db.QueryRowContext(ctx,
|
||||||
"SELECT id, number, handle, description, first_seen, last_seen FROM asns WHERE number = ?",
|
"SELECT id, number, handle, description, first_seen, last_seen FROM asns WHERE number = ?",
|
||||||
asn,
|
asn,
|
||||||
).Scan(&idStr, &asnInfo.Number, &handle, &description, &asnInfo.FirstSeen, &asnInfo.LastSeen)
|
).Scan(&idStr, &asnInfo.Number, &handle, &description, &asnInfo.FirstSeen, &asnInfo.LastSeen)
|
||||||
@ -1147,7 +1176,7 @@ func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
|||||||
GROUP BY prefix, mask_length, ip_version
|
GROUP BY prefix, mask_length, ip_version
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := d.db.Query(query, asn)
|
rows, err := d.db.QueryContext(ctx, query, asn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &asnInfo, nil, fmt.Errorf("failed to query prefixes: %w", err)
|
return &asnInfo, nil, fmt.Errorf("failed to query prefixes: %w", err)
|
||||||
}
|
}
|
||||||
@ -1183,6 +1212,11 @@ func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
|||||||
|
|
||||||
// GetPrefixDetails returns detailed information about a prefix
|
// GetPrefixDetails returns detailed information about a prefix
|
||||||
func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
||||||
|
return d.GetPrefixDetailsContext(context.Background(), prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrefixDetailsContext returns detailed information about a prefix with context support
|
||||||
|
func (d *Database) GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT lr.origin_asn, lr.peer_ip, lr.as_path, lr.next_hop, lr.last_updated,
|
SELECT lr.origin_asn, lr.peer_ip, lr.as_path, lr.next_hop, lr.last_updated,
|
||||||
a.handle, a.description
|
a.handle, a.description
|
||||||
@ -1192,7 +1226,7 @@ func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
|||||||
ORDER BY lr.origin_asn, lr.peer_ip
|
ORDER BY lr.origin_asn, lr.peer_ip
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := d.db.Query(query, prefix)
|
rows, err := d.db.QueryContext(ctx, query, prefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query prefix details: %w", err)
|
return nil, fmt.Errorf("failed to query prefix details: %w", err)
|
||||||
}
|
}
|
||||||
@ -1230,6 +1264,12 @@ func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
|||||||
|
|
||||||
// GetRandomPrefixesByLength returns a random sample of prefixes with the specified mask length
|
// GetRandomPrefixesByLength returns a random sample of prefixes with the specified mask length
|
||||||
func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
||||||
|
return d.GetRandomPrefixesByLengthContext(context.Background(), maskLength, ipVersion, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRandomPrefixesByLengthContext returns a random sample of prefixes with context support
|
||||||
|
func (d *Database) GetRandomPrefixesByLengthContext(
|
||||||
|
ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
||||||
// Select unique prefixes with their most recent route information
|
// Select unique prefixes with their most recent route information
|
||||||
query := `
|
query := `
|
||||||
WITH unique_prefixes AS (
|
WITH unique_prefixes AS (
|
||||||
@ -1247,7 +1287,7 @@ func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) (
|
|||||||
WHERE lr.mask_length = ? AND lr.ip_version = ?
|
WHERE lr.mask_length = ? AND lr.ip_version = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := d.db.Query(query, maskLength, ipVersion, limit, maskLength, ipVersion)
|
rows, err := d.db.QueryContext(ctx, query, maskLength, ipVersion, limit, maskLength, ipVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to query random prefixes: %w", err)
|
return nil, fmt.Errorf("failed to query random prefixes: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -35,6 +36,7 @@ type Store interface {
|
|||||||
|
|
||||||
// Statistics
|
// Statistics
|
||||||
GetStats() (Stats, error)
|
GetStats() (Stats, error)
|
||||||
|
GetStatsContext(ctx context.Context) (Stats, error)
|
||||||
|
|
||||||
// Peer operations
|
// Peer operations
|
||||||
UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error
|
UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error
|
||||||
@ -46,15 +48,21 @@ type Store interface {
|
|||||||
DeleteLiveRoute(prefix string, originASN int, peerIP string) error
|
DeleteLiveRoute(prefix string, originASN int, peerIP string) error
|
||||||
DeleteLiveRouteBatch(deletions []LiveRouteDeletion) error
|
DeleteLiveRouteBatch(deletions []LiveRouteDeletion) error
|
||||||
GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error)
|
GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error)
|
||||||
|
GetPrefixDistributionContext(ctx context.Context) (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error)
|
||||||
GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error)
|
GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error)
|
||||||
|
GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error)
|
||||||
|
|
||||||
// IP lookup operations
|
// IP lookup operations
|
||||||
GetASInfoForIP(ip string) (*ASInfo, error)
|
GetASInfoForIP(ip string) (*ASInfo, error)
|
||||||
|
GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error)
|
||||||
|
|
||||||
// AS and prefix detail operations
|
// AS and prefix detail operations
|
||||||
GetASDetails(asn int) (*ASN, []LiveRoute, error)
|
GetASDetails(asn int) (*ASN, []LiveRoute, error)
|
||||||
|
GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error)
|
||||||
GetPrefixDetails(prefix string) ([]LiveRoute, error)
|
GetPrefixDetails(prefix string) ([]LiveRoute, error)
|
||||||
|
GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error)
|
||||||
GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error)
|
GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error)
|
||||||
|
GetRandomPrefixesByLengthContext(ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error)
|
||||||
|
|
||||||
// Lifecycle
|
// Lifecycle
|
||||||
Close() error
|
Close() error
|
||||||
|
@ -19,6 +19,7 @@ func logSlowQuery(logger *logger.Logger, query string, start time.Time) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// queryRow wraps QueryRow with slow query logging
|
// queryRow wraps QueryRow with slow query logging
|
||||||
|
// nolint:unused // kept for consistency with other query wrappers
|
||||||
func (d *Database) queryRow(query string, args ...interface{}) *sql.Row {
|
func (d *Database) queryRow(query string, args ...interface{}) *sql.Row {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer logSlowQuery(d.logger, query, start)
|
defer logSlowQuery(d.logger, query, start)
|
||||||
|
@ -163,6 +163,11 @@ func (m *mockStore) GetStats() (database.Stats, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStatsContext returns statistics about the mock store with context support
|
||||||
|
func (m *mockStore) GetStatsContext(ctx context.Context) (database.Stats, error) {
|
||||||
|
return m.GetStats()
|
||||||
|
}
|
||||||
|
|
||||||
// UpsertLiveRoute mock implementation
|
// UpsertLiveRoute mock implementation
|
||||||
func (m *mockStore) UpsertLiveRoute(route *database.LiveRoute) error {
|
func (m *mockStore) UpsertLiveRoute(route *database.LiveRoute) error {
|
||||||
// Simple mock - just return nil
|
// Simple mock - just return nil
|
||||||
@ -181,12 +186,22 @@ func (m *mockStore) GetPrefixDistribution() (ipv4 []database.PrefixDistribution,
|
|||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrefixDistributionContext mock implementation with context support
|
||||||
|
func (m *mockStore) GetPrefixDistributionContext(ctx context.Context) (ipv4 []database.PrefixDistribution, ipv6 []database.PrefixDistribution, err error) {
|
||||||
|
return m.GetPrefixDistribution()
|
||||||
|
}
|
||||||
|
|
||||||
// GetLiveRouteCounts mock implementation
|
// GetLiveRouteCounts mock implementation
|
||||||
func (m *mockStore) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
func (m *mockStore) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
||||||
// Return mock counts
|
// Return mock counts
|
||||||
return m.RouteCount / 2, m.RouteCount / 2, nil
|
return m.RouteCount / 2, m.RouteCount / 2, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLiveRouteCountsContext mock implementation with context support
|
||||||
|
func (m *mockStore) GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error) {
|
||||||
|
return m.GetLiveRouteCounts()
|
||||||
|
}
|
||||||
|
|
||||||
// GetASInfoForIP mock implementation
|
// GetASInfoForIP mock implementation
|
||||||
func (m *mockStore) GetASInfoForIP(ip string) (*database.ASInfo, error) {
|
func (m *mockStore) GetASInfoForIP(ip string) (*database.ASInfo, error) {
|
||||||
// Simple mock - return a test AS
|
// Simple mock - return a test AS
|
||||||
@ -201,6 +216,11 @@ func (m *mockStore) GetASInfoForIP(ip string) (*database.ASInfo, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetASInfoForIPContext mock implementation with context support
|
||||||
|
func (m *mockStore) GetASInfoForIPContext(ctx context.Context, ip string) (*database.ASInfo, error) {
|
||||||
|
return m.GetASInfoForIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
// GetASDetails mock implementation
|
// GetASDetails mock implementation
|
||||||
func (m *mockStore) GetASDetails(asn int) (*database.ASN, []database.LiveRoute, error) {
|
func (m *mockStore) GetASDetails(asn int) (*database.ASN, []database.LiveRoute, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@ -215,17 +235,32 @@ func (m *mockStore) GetASDetails(asn int) (*database.ASN, []database.LiveRoute,
|
|||||||
return nil, nil, database.ErrNoRoute
|
return nil, nil, database.ErrNoRoute
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetASDetailsContext mock implementation with context support
|
||||||
|
func (m *mockStore) GetASDetailsContext(ctx context.Context, asn int) (*database.ASN, []database.LiveRoute, error) {
|
||||||
|
return m.GetASDetails(asn)
|
||||||
|
}
|
||||||
|
|
||||||
// GetPrefixDetails mock implementation
|
// GetPrefixDetails mock implementation
|
||||||
func (m *mockStore) GetPrefixDetails(prefix string) ([]database.LiveRoute, error) {
|
func (m *mockStore) GetPrefixDetails(prefix string) ([]database.LiveRoute, error) {
|
||||||
// Return empty routes for now
|
// Return empty routes for now
|
||||||
return []database.LiveRoute{}, nil
|
return []database.LiveRoute{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPrefixDetailsContext mock implementation with context support
|
||||||
|
func (m *mockStore) GetPrefixDetailsContext(ctx context.Context, prefix string) ([]database.LiveRoute, error) {
|
||||||
|
return m.GetPrefixDetails(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockStore) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]database.LiveRoute, error) {
|
func (m *mockStore) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]database.LiveRoute, error) {
|
||||||
// Return empty routes for now
|
// Return empty routes for now
|
||||||
return []database.LiveRoute{}, nil
|
return []database.LiveRoute{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRandomPrefixesByLengthContext mock implementation with context support
|
||||||
|
func (m *mockStore) GetRandomPrefixesByLengthContext(ctx context.Context, maskLength, ipVersion, limit int) ([]database.LiveRoute, error) {
|
||||||
|
return m.GetRandomPrefixesByLength(maskLength, ipVersion, limit)
|
||||||
|
}
|
||||||
|
|
||||||
// UpsertLiveRouteBatch mock implementation
|
// UpsertLiveRouteBatch mock implementation
|
||||||
func (m *mockStore) UpsertLiveRouteBatch(routes []*database.LiveRoute) error {
|
func (m *mockStore) UpsertLiveRouteBatch(routes []*database.LiveRoute) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
|
@ -92,7 +92,7 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
|
|||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
dbStats, err := s.db.GetStats()
|
dbStats, err := s.db.GetStatsContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Debug("Database stats query failed", "error", err)
|
s.logger.Debug("Database stats query failed", "error", err)
|
||||||
errChan <- err
|
errChan <- err
|
||||||
@ -126,7 +126,7 @@ func (s *Server) handleStatusJSON() http.HandlerFunc {
|
|||||||
const bitsPerMegabit = 1000000.0
|
const bitsPerMegabit = 1000000.0
|
||||||
|
|
||||||
// Get route counts from database
|
// Get route counts from database
|
||||||
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCounts()
|
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCountsContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Warn("Failed to get live route counts", "error", err)
|
s.logger.Warn("Failed to get live route counts", "error", err)
|
||||||
// Continue with zero counts
|
// Continue with zero counts
|
||||||
@ -234,7 +234,7 @@ func (s *Server) handleStats() http.HandlerFunc {
|
|||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
dbStats, err := s.db.GetStats()
|
dbStats, err := s.db.GetStatsContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Debug("Database stats query failed", "error", err)
|
s.logger.Debug("Database stats query failed", "error", err)
|
||||||
errChan <- err
|
errChan <- err
|
||||||
@ -267,7 +267,7 @@ func (s *Server) handleStats() http.HandlerFunc {
|
|||||||
const bitsPerMegabit = 1000000.0
|
const bitsPerMegabit = 1000000.0
|
||||||
|
|
||||||
// Get route counts from database
|
// Get route counts from database
|
||||||
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCounts()
|
ipv4Routes, ipv6Routes, err := s.db.GetLiveRouteCountsContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Warn("Failed to get live route counts", "error", err)
|
s.logger.Warn("Failed to get live route counts", "error", err)
|
||||||
// Continue with zero counts
|
// Continue with zero counts
|
||||||
@ -354,7 +354,7 @@ func (s *Server) handleIPLookup() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Look up AS information for the IP
|
// Look up AS information for the IP
|
||||||
asInfo, err := s.db.GetASInfoForIP(ip)
|
asInfo, err := s.db.GetASInfoForIPContext(r.Context(), ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check if it's an invalid IP error
|
// Check if it's an invalid IP error
|
||||||
if errors.Is(err, database.ErrInvalidIP) {
|
if errors.Is(err, database.ErrInvalidIP) {
|
||||||
@ -385,7 +385,7 @@ func (s *Server) handleASDetailJSON() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
asInfo, prefixes, err := s.db.GetASDetails(asn)
|
asInfo, prefixes, err := s.db.GetASDetailsContext(r.Context(), asn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, database.ErrNoRoute) {
|
if errors.Is(err, database.ErrNoRoute) {
|
||||||
writeJSONError(w, http.StatusNotFound, err.Error())
|
writeJSONError(w, http.StatusNotFound, err.Error())
|
||||||
@ -438,7 +438,7 @@ func (s *Server) handlePrefixDetailJSON() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := s.db.GetPrefixDetails(prefix)
|
routes, err := s.db.GetPrefixDetailsContext(r.Context(), prefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, database.ErrNoRoute) {
|
if errors.Is(err, database.ErrNoRoute) {
|
||||||
writeJSONError(w, http.StatusNotFound, err.Error())
|
writeJSONError(w, http.StatusNotFound, err.Error())
|
||||||
@ -480,7 +480,7 @@ func (s *Server) handleASDetail() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
asInfo, prefixes, err := s.db.GetASDetails(asn)
|
asInfo, prefixes, err := s.db.GetASDetailsContext(r.Context(), asn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, database.ErrNoRoute) {
|
if errors.Is(err, database.ErrNoRoute) {
|
||||||
http.Error(w, "AS not found", http.StatusNotFound)
|
http.Error(w, "AS not found", http.StatusNotFound)
|
||||||
@ -584,7 +584,7 @@ func (s *Server) handlePrefixDetail() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := s.db.GetPrefixDetails(prefix)
|
routes, err := s.db.GetPrefixDetailsContext(r.Context(), prefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, database.ErrNoRoute) {
|
if errors.Is(err, database.ErrNoRoute) {
|
||||||
http.Error(w, "Prefix not found", http.StatusNotFound)
|
http.Error(w, "Prefix not found", http.StatusNotFound)
|
||||||
@ -607,7 +607,7 @@ func (s *Server) handlePrefixDetail() http.HandlerFunc {
|
|||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if _, exists := originMap[route.OriginASN]; !exists {
|
if _, exists := originMap[route.OriginASN]; !exists {
|
||||||
// Get AS info from database
|
// Get AS info from database
|
||||||
asInfo, _, _ := s.db.GetASDetails(route.OriginASN)
|
asInfo, _, _ := s.db.GetASDetailsContext(r.Context(), route.OriginASN)
|
||||||
handle := ""
|
handle := ""
|
||||||
description := ""
|
description := ""
|
||||||
if asInfo != nil {
|
if asInfo != nil {
|
||||||
@ -768,7 +768,7 @@ func (s *Server) handlePrefixLength() http.HandlerFunc {
|
|||||||
|
|
||||||
// Get random sample of prefixes
|
// Get random sample of prefixes
|
||||||
const maxPrefixes = 500
|
const maxPrefixes = 500
|
||||||
prefixes, err := s.db.GetRandomPrefixesByLength(maskLength, ipVersion, maxPrefixes)
|
prefixes, err := s.db.GetRandomPrefixesByLengthContext(r.Context(), maskLength, ipVersion, maxPrefixes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("Failed to get prefixes by length", "error", err)
|
s.logger.Error("Failed to get prefixes by length", "error", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -20,9 +19,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
risLiveURL = "https://ris-live.ripe.net/v1/stream/?format=json"
|
risLiveURL = "https://ris-live.ripe.net/v1/stream/?format=json&" +
|
||||||
|
"client=https%3A%2F%2Fgit.eeqj.de%2Fsneak%2Froutewatch"
|
||||||
metricsWindowSize = 60 // seconds for rolling average
|
metricsWindowSize = 60 // seconds for rolling average
|
||||||
metricsUpdateRate = time.Second
|
metricsUpdateRate = time.Second
|
||||||
|
minBackoffDelay = 5 * time.Second
|
||||||
|
maxBackoffDelay = 320 * time.Second
|
||||||
metricsLogInterval = 10 * time.Second
|
metricsLogInterval = 10 * time.Second
|
||||||
bytesPerKB = 1024
|
bytesPerKB = 1024
|
||||||
bytesPerMB = 1024 * 1024
|
bytesPerMB = 1024 * 1024
|
||||||
@ -135,9 +137,7 @@ func (s *Streamer) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.stream(ctx); err != nil {
|
s.streamWithReconnect(ctx)
|
||||||
s.logger.Error("Streaming error", "error", err)
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.running = false
|
s.running = false
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
@ -324,6 +324,72 @@ func (s *Streamer) updateMetrics(messageBytes int) {
|
|||||||
s.metrics.RecordMessage(int64(messageBytes))
|
s.metrics.RecordMessage(int64(messageBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// streamWithReconnect handles streaming with automatic reconnection and exponential backoff
|
||||||
|
func (s *Streamer) streamWithReconnect(ctx context.Context) {
|
||||||
|
backoffDelay := minBackoffDelay
|
||||||
|
consecutiveFailures := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
s.logger.Info("Stream context cancelled, stopping reconnection attempts")
|
||||||
|
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to stream
|
||||||
|
startTime := time.Now()
|
||||||
|
err := s.stream(ctx)
|
||||||
|
streamDuration := time.Since(startTime)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// Clean exit (context cancelled)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log the error
|
||||||
|
s.logger.Error("Stream disconnected",
|
||||||
|
"error", err,
|
||||||
|
"consecutive_failures", consecutiveFailures+1,
|
||||||
|
"stream_duration", streamDuration)
|
||||||
|
s.metrics.SetConnected(false)
|
||||||
|
|
||||||
|
// Check if context is cancelled
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we streamed for more than 30 seconds, reset the backoff
|
||||||
|
// This indicates we had a successful connection that received data
|
||||||
|
if streamDuration > 30*time.Second {
|
||||||
|
s.logger.Info("Resetting backoff delay due to successful connection",
|
||||||
|
"stream_duration", streamDuration)
|
||||||
|
backoffDelay = minBackoffDelay
|
||||||
|
consecutiveFailures = 0
|
||||||
|
} else {
|
||||||
|
// Increment consecutive failures
|
||||||
|
consecutiveFailures++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait with exponential backoff
|
||||||
|
s.logger.Info("Waiting before reconnection attempt",
|
||||||
|
"delay_seconds", backoffDelay.Seconds(),
|
||||||
|
"consecutive_failures", consecutiveFailures)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(backoffDelay):
|
||||||
|
// Double the backoff delay for next time, up to max
|
||||||
|
backoffDelay *= 2
|
||||||
|
if backoffDelay > maxBackoffDelay {
|
||||||
|
backoffDelay = maxBackoffDelay
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Streamer) stream(ctx context.Context) error {
|
func (s *Streamer) stream(ctx context.Context) error {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", risLiveURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", risLiveURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -394,10 +460,13 @@ func (s *Streamer) stream(ctx context.Context) error {
|
|||||||
// Parse the message first
|
// Parse the message first
|
||||||
var wrapper ristypes.RISLiveMessage
|
var wrapper ristypes.RISLiveMessage
|
||||||
if err := json.Unmarshal(line, &wrapper); err != nil {
|
if err := json.Unmarshal(line, &wrapper); err != nil {
|
||||||
// Output the raw line and panic on parse failure
|
// Log the error and return to trigger reconnection
|
||||||
fmt.Fprintf(os.Stderr, "Failed to parse JSON: %v\n", err)
|
s.logger.Error("Failed to parse JSON",
|
||||||
fmt.Fprintf(os.Stderr, "Raw line: %s\n", string(line))
|
"error", err,
|
||||||
panic(fmt.Sprintf("JSON parse error: %v", err))
|
"line", string(line),
|
||||||
|
"line_length", len(line))
|
||||||
|
|
||||||
|
return fmt.Errorf("JSON parse error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if it's a ris_message wrapper
|
// Check if it's a ris_message wrapper
|
||||||
@ -447,18 +516,11 @@ func (s *Streamer) stream(ctx context.Context) error {
|
|||||||
// Peer state changes - silently ignore
|
// Peer state changes - silently ignore
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(
|
s.logger.Error("Unknown message type",
|
||||||
os.Stderr,
|
"type", msg.Type,
|
||||||
"UNKNOWN MESSAGE TYPE: %s\nRAW MESSAGE: %s\n",
|
"line", string(line),
|
||||||
msg.Type,
|
|
||||||
string(line),
|
|
||||||
)
|
|
||||||
panic(
|
|
||||||
fmt.Sprintf(
|
|
||||||
"Unknown RIS message type: %s",
|
|
||||||
msg.Type,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
panic(fmt.Sprintf("Unknown RIS message type: %s", msg.Type))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dispatch to interested handlers
|
// Dispatch to interested handlers
|
||||||
|
Loading…
Reference in New Issue
Block a user