Add context cancellation support to database operations
- Add context-aware versions of all read operations in the database - Update handlers to use context from HTTP requests - Allows database queries to be cancelled when HTTP requests timeout - Prevents database operations from continuing after client disconnects
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
@@ -747,41 +748,48 @@ func (d *Database) UpdatePeer(peerIP string, peerASN int, messageType string, ti
|
||||
|
||||
// GetStats returns database statistics
|
||||
func (d *Database) GetStats() (Stats, error) {
|
||||
return d.GetStatsContext(context.Background())
|
||||
}
|
||||
|
||||
// GetStatsContext returns database statistics with context support
|
||||
func (d *Database) GetStatsContext(ctx context.Context) (Stats, error) {
|
||||
var stats Stats
|
||||
|
||||
// Count ASNs
|
||||
err := d.queryRow("SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
|
||||
err := d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM asns").Scan(&stats.ASNs)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
|
||||
// Count prefixes
|
||||
err = d.queryRow("SELECT COUNT(*) FROM prefixes").Scan(&stats.Prefixes)
|
||||
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM prefixes").Scan(&stats.Prefixes)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
|
||||
// Count IPv4 and IPv6 prefixes
|
||||
const ipVersionV4 = 4
|
||||
err = d.queryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV4).Scan(&stats.IPv4Prefixes)
|
||||
err = d.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV4).Scan(&stats.IPv4Prefixes)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
|
||||
const ipVersionV6 = 6
|
||||
err = d.queryRow("SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV6).Scan(&stats.IPv6Prefixes)
|
||||
err = d.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM prefixes WHERE ip_version = ?", ipVersionV6).Scan(&stats.IPv6Prefixes)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
|
||||
// Count peerings
|
||||
err = d.queryRow("SELECT COUNT(*) FROM peerings").Scan(&stats.Peerings)
|
||||
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM peerings").Scan(&stats.Peerings)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
|
||||
// Count peers
|
||||
err = d.queryRow("SELECT COUNT(*) FROM bgp_peers").Scan(&stats.Peers)
|
||||
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM bgp_peers").Scan(&stats.Peers)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
@@ -796,13 +804,13 @@ func (d *Database) GetStats() (Stats, error) {
|
||||
}
|
||||
|
||||
// Get live routes count
|
||||
err = d.db.QueryRow("SELECT COUNT(*) FROM live_routes").Scan(&stats.LiveRoutes)
|
||||
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes").Scan(&stats.LiveRoutes)
|
||||
if err != nil {
|
||||
return stats, fmt.Errorf("failed to count live routes: %w", err)
|
||||
}
|
||||
|
||||
// Get prefix distribution
|
||||
stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, err = d.GetPrefixDistribution()
|
||||
stats.IPv4PrefixDistribution, stats.IPv6PrefixDistribution, err = d.GetPrefixDistributionContext(ctx)
|
||||
if err != nil {
|
||||
// Log but don't fail
|
||||
d.logger.Warn("Failed to get prefix distribution", "error", err)
|
||||
@@ -886,6 +894,12 @@ func (d *Database) DeleteLiveRoute(prefix string, originASN int, peerIP string)
|
||||
|
||||
// GetPrefixDistribution returns the distribution of unique prefixes by mask length
|
||||
func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
||||
return d.GetPrefixDistributionContext(context.Background())
|
||||
}
|
||||
|
||||
// GetPrefixDistributionContext returns the distribution of unique prefixes by mask length with context support
|
||||
func (d *Database) GetPrefixDistributionContext(ctx context.Context) (
|
||||
ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error) {
|
||||
// IPv4 distribution - count unique prefixes, not routes
|
||||
query := `
|
||||
SELECT mask_length, COUNT(DISTINCT prefix) as count
|
||||
@@ -894,7 +908,7 @@ func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []Pr
|
||||
GROUP BY mask_length
|
||||
ORDER BY mask_length
|
||||
`
|
||||
rows, err := d.db.Query(query)
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to query IPv4 distribution: %w", err)
|
||||
}
|
||||
@@ -916,7 +930,7 @@ func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []Pr
|
||||
GROUP BY mask_length
|
||||
ORDER BY mask_length
|
||||
`
|
||||
rows, err = d.db.Query(query)
|
||||
rows, err = d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to query IPv6 distribution: %w", err)
|
||||
}
|
||||
@@ -935,14 +949,19 @@ func (d *Database) GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []Pr
|
||||
|
||||
// GetLiveRouteCounts returns the count of IPv4 and IPv6 routes
|
||||
func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
||||
return d.GetLiveRouteCountsContext(context.Background())
|
||||
}
|
||||
|
||||
// GetLiveRouteCountsContext returns the count of IPv4 and IPv6 routes with context support
|
||||
func (d *Database) GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error) {
|
||||
// Get IPv4 count
|
||||
err = d.db.QueryRow("SELECT COUNT(*) FROM live_routes WHERE ip_version = 4").Scan(&ipv4Count)
|
||||
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes WHERE ip_version = 4").Scan(&ipv4Count)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to count IPv4 routes: %w", err)
|
||||
}
|
||||
|
||||
// Get IPv6 count
|
||||
err = d.db.QueryRow("SELECT COUNT(*) FROM live_routes WHERE ip_version = 6").Scan(&ipv6Count)
|
||||
err = d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM live_routes WHERE ip_version = 6").Scan(&ipv6Count)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to count IPv6 routes: %w", err)
|
||||
}
|
||||
@@ -952,6 +971,11 @@ func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) {
|
||||
|
||||
// GetASInfoForIP returns AS information for the given IP address
|
||||
func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
||||
return d.GetASInfoForIPContext(context.Background(), ip)
|
||||
}
|
||||
|
||||
// GetASInfoForIPContext returns AS information for the given IP address with context support
|
||||
func (d *Database) GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error) {
|
||||
// Parse the IP to validate it
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
@@ -984,7 +1008,7 @@ func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
||||
var lastUpdated time.Time
|
||||
var handle, description sql.NullString
|
||||
|
||||
err := d.db.QueryRow(query, ipVersionV4, ipUint, ipUint).Scan(
|
||||
err := d.db.QueryRowContext(ctx, query, ipVersionV4, ipUint, ipUint).Scan(
|
||||
&prefix, &maskLength, &originASN, &lastUpdated, &handle, &description)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -1015,7 +1039,7 @@ func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) {
|
||||
ORDER BY lr.mask_length DESC
|
||||
`
|
||||
|
||||
rows, err := d.db.Query(query, ipVersionV6)
|
||||
rows, err := d.db.QueryContext(ctx, query, ipVersionV6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query routes: %w", err)
|
||||
}
|
||||
@@ -1118,11 +1142,16 @@ func CalculateIPv4Range(cidr string) (start, end uint32, err error) {
|
||||
|
||||
// GetASDetails returns detailed information about an AS including prefixes
|
||||
func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
||||
return d.GetASDetailsContext(context.Background(), asn)
|
||||
}
|
||||
|
||||
// GetASDetailsContext returns detailed information about an AS including prefixes with context support
|
||||
func (d *Database) GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error) {
|
||||
// Get AS information
|
||||
var asnInfo ASN
|
||||
var idStr string
|
||||
var handle, description sql.NullString
|
||||
err := d.db.QueryRow(
|
||||
err := d.db.QueryRowContext(ctx,
|
||||
"SELECT id, number, handle, description, first_seen, last_seen FROM asns WHERE number = ?",
|
||||
asn,
|
||||
).Scan(&idStr, &asnInfo.Number, &handle, &description, &asnInfo.FirstSeen, &asnInfo.LastSeen)
|
||||
@@ -1147,7 +1176,7 @@ func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
||||
GROUP BY prefix, mask_length, ip_version
|
||||
`
|
||||
|
||||
rows, err := d.db.Query(query, asn)
|
||||
rows, err := d.db.QueryContext(ctx, query, asn)
|
||||
if err != nil {
|
||||
return &asnInfo, nil, fmt.Errorf("failed to query prefixes: %w", err)
|
||||
}
|
||||
@@ -1183,6 +1212,11 @@ func (d *Database) GetASDetails(asn int) (*ASN, []LiveRoute, error) {
|
||||
|
||||
// GetPrefixDetails returns detailed information about a prefix
|
||||
func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
||||
return d.GetPrefixDetailsContext(context.Background(), prefix)
|
||||
}
|
||||
|
||||
// GetPrefixDetailsContext returns detailed information about a prefix with context support
|
||||
func (d *Database) GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error) {
|
||||
query := `
|
||||
SELECT lr.origin_asn, lr.peer_ip, lr.as_path, lr.next_hop, lr.last_updated,
|
||||
a.handle, a.description
|
||||
@@ -1192,7 +1226,7 @@ func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
||||
ORDER BY lr.origin_asn, lr.peer_ip
|
||||
`
|
||||
|
||||
rows, err := d.db.Query(query, prefix)
|
||||
rows, err := d.db.QueryContext(ctx, query, prefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query prefix details: %w", err)
|
||||
}
|
||||
@@ -1230,6 +1264,12 @@ func (d *Database) GetPrefixDetails(prefix string) ([]LiveRoute, error) {
|
||||
|
||||
// GetRandomPrefixesByLength returns a random sample of prefixes with the specified mask length
|
||||
func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
||||
return d.GetRandomPrefixesByLengthContext(context.Background(), maskLength, ipVersion, limit)
|
||||
}
|
||||
|
||||
// GetRandomPrefixesByLengthContext returns a random sample of prefixes with context support
|
||||
func (d *Database) GetRandomPrefixesByLengthContext(
|
||||
ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error) {
|
||||
// Select unique prefixes with their most recent route information
|
||||
query := `
|
||||
WITH unique_prefixes AS (
|
||||
@@ -1247,7 +1287,7 @@ func (d *Database) GetRandomPrefixesByLength(maskLength, ipVersion, limit int) (
|
||||
WHERE lr.mask_length = ? AND lr.ip_version = ?
|
||||
`
|
||||
|
||||
rows, err := d.db.Query(query, maskLength, ipVersion, limit, maskLength, ipVersion)
|
||||
rows, err := d.db.QueryContext(ctx, query, maskLength, ipVersion, limit, maskLength, ipVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query random prefixes: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -35,6 +36,7 @@ type Store interface {
|
||||
|
||||
// Statistics
|
||||
GetStats() (Stats, error)
|
||||
GetStatsContext(ctx context.Context) (Stats, error)
|
||||
|
||||
// Peer operations
|
||||
UpdatePeer(peerIP string, peerASN int, messageType string, timestamp time.Time) error
|
||||
@@ -46,15 +48,21 @@ type Store interface {
|
||||
DeleteLiveRoute(prefix string, originASN int, peerIP string) error
|
||||
DeleteLiveRouteBatch(deletions []LiveRouteDeletion) error
|
||||
GetPrefixDistribution() (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error)
|
||||
GetPrefixDistributionContext(ctx context.Context) (ipv4 []PrefixDistribution, ipv6 []PrefixDistribution, err error)
|
||||
GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error)
|
||||
GetLiveRouteCountsContext(ctx context.Context) (ipv4Count, ipv6Count int, err error)
|
||||
|
||||
// IP lookup operations
|
||||
GetASInfoForIP(ip string) (*ASInfo, error)
|
||||
GetASInfoForIPContext(ctx context.Context, ip string) (*ASInfo, error)
|
||||
|
||||
// AS and prefix detail operations
|
||||
GetASDetails(asn int) (*ASN, []LiveRoute, error)
|
||||
GetASDetailsContext(ctx context.Context, asn int) (*ASN, []LiveRoute, error)
|
||||
GetPrefixDetails(prefix string) ([]LiveRoute, error)
|
||||
GetPrefixDetailsContext(ctx context.Context, prefix string) ([]LiveRoute, error)
|
||||
GetRandomPrefixesByLength(maskLength, ipVersion, limit int) ([]LiveRoute, error)
|
||||
GetRandomPrefixesByLengthContext(ctx context.Context, maskLength, ipVersion, limit int) ([]LiveRoute, error)
|
||||
|
||||
// Lifecycle
|
||||
Close() error
|
||||
|
||||
@@ -19,6 +19,7 @@ func logSlowQuery(logger *logger.Logger, query string, start time.Time) {
|
||||
}
|
||||
|
||||
// queryRow wraps QueryRow with slow query logging
|
||||
// nolint:unused // kept for consistency with other query wrappers
|
||||
func (d *Database) queryRow(query string, args ...interface{}) *sql.Row {
|
||||
start := time.Now()
|
||||
defer logSlowQuery(d.logger, query, start)
|
||||
|
||||
Reference in New Issue
Block a user