diff --git a/internal/database/database.go b/internal/database/database.go index bf6f4c3..d66f427 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -6,6 +6,7 @@ import ( _ "embed" "encoding/json" "fmt" + "net" "os" "path/filepath" "time" @@ -20,7 +21,15 @@ import ( //go:embed schema.sql var dbSchema string -const dirPermissions = 0750 // rwxr-x--- +const ( + dirPermissions = 0750 // rwxr-x--- + ipVersionV4 = 4 + ipVersionV6 = 6 + ipv6Length = 16 + ipv4Offset = 12 + ipv4Bits = 32 + maxIPv4 = 0xFFFFFFFF +) // Database manages the SQLite database connection and operations. type Database struct { @@ -430,14 +439,17 @@ func (d *Database) GetStats() (Stats, error) { // UpsertLiveRoute inserts or updates a live route func (d *Database) UpsertLiveRoute(route *LiveRoute) error { query := ` - INSERT INTO live_routes (id, prefix, mask_length, ip_version, origin_asn, peer_ip, as_path, next_hop, last_updated) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO live_routes (id, prefix, mask_length, ip_version, origin_asn, peer_ip, as_path, next_hop, + last_updated, v4_ip_start, v4_ip_end) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(prefix, origin_asn, peer_ip) DO UPDATE SET mask_length = excluded.mask_length, ip_version = excluded.ip_version, as_path = excluded.as_path, next_hop = excluded.next_hop, - last_updated = excluded.last_updated + last_updated = excluded.last_updated, + v4_ip_start = excluded.v4_ip_start, + v4_ip_end = excluded.v4_ip_end ` // Encode AS path as JSON @@ -446,6 +458,15 @@ func (d *Database) UpsertLiveRoute(route *LiveRoute) error { return fmt.Errorf("failed to encode AS path: %w", err) } + // Convert v4_ip_start and v4_ip_end to interface{} for SQL NULL handling + var v4Start, v4End interface{} + if route.V4IPStart != nil { + v4Start = *route.V4IPStart + } + if route.V4IPEnd != nil { + v4End = *route.V4IPEnd + } + _, err = d.db.Exec(query, route.ID.String(), route.Prefix, @@ -456,6 +477,8 @@ func (d *Database) UpsertLiveRoute(route *LiveRoute) error { string(pathJSON), route.NextHop, route.LastUpdated, + v4Start, + v4End, ) return err @@ -545,3 +568,156 @@ func (d *Database) GetLiveRouteCounts() (ipv4Count, ipv6Count int, err error) { return ipv4Count, ipv6Count, nil } + +// GetASInfoForIP returns AS information for the given IP address +func (d *Database) GetASInfoForIP(ip string) (*ASInfo, error) { + // Parse the IP to validate it + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return nil, fmt.Errorf("invalid IP address: %s", ip) + } + + // Determine IP version + ipVersion := ipVersionV4 + ipv4 := parsedIP.To4() + if ipv4 == nil { + ipVersion = ipVersionV6 + } + + // For IPv4, use optimized range query + if ipVersion == ipVersionV4 { + // Convert IP to 32-bit unsigned integer + ipUint := ipToUint32(ipv4) + + query := ` + SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, a.handle, a.description + FROM live_routes lr + LEFT JOIN asns a ON a.number = lr.origin_asn + WHERE lr.ip_version = ? AND lr.v4_ip_start <= ? AND lr.v4_ip_end >= ? + ORDER BY lr.mask_length DESC + LIMIT 1 + ` + + var prefix string + var maskLength, originASN int + var handle, description sql.NullString + + err := d.db.QueryRow(query, ipVersionV4, ipUint, ipUint).Scan(&prefix, &maskLength, &originASN, &handle, &description) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("no route found for IP %s", ip) + } + + return nil, fmt.Errorf("failed to query routes: %w", err) + } + + return &ASInfo{ + ASN: originASN, + Handle: handle.String, + Description: description.String, + Prefix: prefix, + }, nil + } + + // For IPv6, use the original method since we don't have range optimization + query := ` + SELECT DISTINCT lr.prefix, lr.mask_length, lr.origin_asn, a.handle, a.description + FROM live_routes lr + LEFT JOIN asns a ON a.number = lr.origin_asn + WHERE lr.ip_version = ? + ORDER BY lr.mask_length DESC + ` + + rows, err := d.db.Query(query, ipVersionV6) + if err != nil { + return nil, fmt.Errorf("failed to query routes: %w", err) + } + defer func() { _ = rows.Close() }() + + // Find the most specific matching prefix + var bestMatch struct { + prefix string + maskLength int + originASN int + handle sql.NullString + description sql.NullString + } + bestMaskLength := -1 + + for rows.Next() { + var prefix string + var maskLength, originASN int + var handle, description sql.NullString + + if err := rows.Scan(&prefix, &maskLength, &originASN, &handle, &description); err != nil { + continue + } + + // Parse the prefix CIDR + _, ipNet, err := net.ParseCIDR(prefix) + if err != nil { + continue + } + + // Check if the IP is in this prefix + if ipNet.Contains(parsedIP) && maskLength > bestMaskLength { + bestMatch.prefix = prefix + bestMatch.maskLength = maskLength + bestMatch.originASN = originASN + bestMatch.handle = handle + bestMatch.description = description + bestMaskLength = maskLength + } + } + + if bestMaskLength == -1 { + return nil, fmt.Errorf("no route found for IP %s", ip) + } + + return &ASInfo{ + ASN: bestMatch.originASN, + Handle: bestMatch.handle.String, + Description: bestMatch.description.String, + Prefix: bestMatch.prefix, + }, nil +} + +// ipToUint32 converts an IPv4 address to a 32-bit unsigned integer +func ipToUint32(ip net.IP) uint32 { + if len(ip) == ipv6Length { + // Convert to 4-byte representation + ip = ip[ipv4Offset:ipv6Length] + } + + return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) +} + +// CalculateIPv4Range calculates the start and end IP addresses for an IPv4 CIDR block +func CalculateIPv4Range(cidr string) (start, end uint32, err error) { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + return 0, 0, err + } + + // Get the network address (start of range) + ip := ipNet.IP.To4() + if ip == nil { + return 0, 0, fmt.Errorf("not an IPv4 address") + } + + start = ipToUint32(ip) + + // Calculate the end of the range + ones, bits := ipNet.Mask.Size() + hostBits := bits - ones + if hostBits >= ipv4Bits { + // Special case for /0 - entire IPv4 space + end = maxIPv4 + } else { + // Safe to convert since we checked hostBits < 32 + //nolint:gosec // hostBits is guaranteed to be < 32 from the check above + end = start | ((1 << uint(hostBits)) - 1) + } + + return start, end, nil +} diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 0000000..6636871 --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,301 @@ +package database + +import ( + "net" + "testing" +) + +func TestIPToUint32(t *testing.T) { + tests := []struct { + name string + ip string + expected uint32 + }{ + { + name: "Simple IP", + ip: "192.168.1.1", + expected: 3232235777, // 192<<24 + 168<<16 + 1<<8 + 1 + }, + { + name: "Minimum IP", + ip: "0.0.0.0", + expected: 0, + }, + { + name: "Maximum IP", + ip: "255.255.255.255", + expected: 4294967295, + }, + { + name: "10.0.0.0", + ip: "10.0.0.0", + expected: 167772160, + }, + { + name: "172.16.0.0", + ip: "172.16.0.0", + expected: 2886729728, + }, + { + name: "8.8.8.8", + ip: "8.8.8.8", + expected: 134744072, + }, + { + name: "1.2.3.4", + ip: "1.2.3.4", + expected: 16909060, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("Failed to parse IP: %s", tt.ip) + } + + result := ipToUint32(ip) + if result != tt.expected { + t.Errorf("ipToUint32(%s) = %d, want %d", tt.ip, result, tt.expected) + } + + // Test with IPv4-mapped IPv6 address + ip6 := net.ParseIP(tt.ip).To16() + if ip6 != nil { + result6 := ipToUint32(ip6) + if result6 != tt.expected { + t.Errorf("ipToUint32(%s as IPv6) = %d, want %d", tt.ip, result6, tt.expected) + } + } + }) + } +} + +func TestCalculateIPv4Range(t *testing.T) { + tests := []struct { + name string + cidr string + wantStart uint32 + wantEnd uint32 + wantErr bool + }{ + { + name: "Single IP /32", + cidr: "192.168.1.1/32", + wantStart: 3232235777, + wantEnd: 3232235777, + }, + { + name: "Class C /24", + cidr: "192.168.1.0/24", + wantStart: 3232235776, // 192.168.1.0 + wantEnd: 3232236031, // 192.168.1.255 + }, + { + name: "Class B /16", + cidr: "192.168.0.0/16", + wantStart: 3232235520, // 192.168.0.0 + wantEnd: 3232301055, // 192.168.255.255 + }, + { + name: "Class A /8", + cidr: "10.0.0.0/8", + wantStart: 167772160, // 10.0.0.0 + wantEnd: 184549375, // 10.255.255.255 + }, + { + name: "Entire IPv4 space /0", + cidr: "0.0.0.0/0", + wantStart: 0, + wantEnd: 4294967295, + }, + { + name: "Small subnet /30", + cidr: "192.168.1.0/30", + wantStart: 3232235776, // 192.168.1.0 + wantEnd: 3232235779, // 192.168.1.3 + }, + { + name: "Medium subnet /20", + cidr: "172.16.0.0/20", + wantStart: 2886729728, // 172.16.0.0 + wantEnd: 2886733823, // 172.16.15.255 + }, + { + name: "Private range 172.16/12", + cidr: "172.16.0.0/12", + wantStart: 2886729728, // 172.16.0.0 + wantEnd: 2887778303, // 172.31.255.255 + }, + { + name: "Google DNS /29", + cidr: "8.8.8.8/29", + wantStart: 134744072, // 8.8.8.8 (network is actually 8.8.8.8 with /29) + wantEnd: 134744079, // 8.8.8.15 + }, + { + name: "Non-zero host bits", + cidr: "192.168.1.5/24", + wantStart: 3232235776, // 192.168.1.0 (network address) + wantEnd: 3232236031, // 192.168.1.255 + }, + { + name: "Invalid CIDR", + cidr: "192.168.1.1/33", + wantErr: true, + }, + { + name: "Invalid IP", + cidr: "256.256.256.256/24", + wantErr: true, + }, + { + name: "IPv6 CIDR", + cidr: "2001:db8::/32", + wantErr: true, + }, + { + name: "Empty CIDR", + cidr: "", + wantErr: true, + }, + { + name: "Missing mask", + cidr: "192.168.1.1", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start, end, err := CalculateIPv4Range(tt.cidr) + + if tt.wantErr { + if err == nil { + t.Errorf("CalculateIPv4Range(%s) expected error, got nil", tt.cidr) + } + return + } + + if err != nil { + t.Errorf("CalculateIPv4Range(%s) unexpected error: %v", tt.cidr, err) + return + } + + if start != tt.wantStart { + t.Errorf("CalculateIPv4Range(%s) start = %d, want %d", tt.cidr, start, tt.wantStart) + } + + if end != tt.wantEnd { + t.Errorf("CalculateIPv4Range(%s) end = %d, want %d", tt.cidr, end, tt.wantEnd) + } + + // Verify that start <= end + if start > end { + t.Errorf("CalculateIPv4Range(%s) start (%d) > end (%d)", tt.cidr, start, end) + } + + // Verify the range size matches the CIDR mask + if !tt.wantErr && tt.cidr != "" { + _, ipNet, _ := net.ParseCIDR(tt.cidr) + if ipNet != nil { + ones, bits := ipNet.Mask.Size() + expectedSize := uint32(1) << uint(bits-ones) + actualSize := end - start + 1 + if actualSize != expectedSize { + t.Errorf("CalculateIPv4Range(%s) range size = %d, want %d", tt.cidr, actualSize, expectedSize) + } + } + } + }) + } +} + +func TestIPv4RangeIntegration(t *testing.T) { + // Test that our functions work correctly together + tests := []struct { + name string + cidr string + testIPs []string + shouldContain []bool + }{ + { + name: "192.168.1.0/24", + cidr: "192.168.1.0/24", + testIPs: []string{ + "192.168.1.0", + "192.168.1.1", + "192.168.1.255", + "192.168.0.255", + "192.168.2.0", + }, + shouldContain: []bool{true, true, true, false, false}, + }, + { + name: "10.0.0.0/8", + cidr: "10.0.0.0/8", + testIPs: []string{ + "10.0.0.0", + "10.255.255.255", + "10.1.2.3", + "9.255.255.255", + "11.0.0.0", + }, + shouldContain: []bool{true, true, true, false, false}, + }, + { + name: "172.16.0.0/12", + cidr: "172.16.0.0/12", + testIPs: []string{ + "172.16.0.0", + "172.31.255.255", + "172.20.1.1", + "172.15.255.255", + "172.32.0.0", + }, + shouldContain: []bool{true, true, true, false, false}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start, end, err := CalculateIPv4Range(tt.cidr) + if err != nil { + t.Fatalf("Failed to calculate range for %s: %v", tt.cidr, err) + } + + for i, testIP := range tt.testIPs { + ip := net.ParseIP(testIP) + if ip == nil { + t.Fatalf("Failed to parse test IP: %s", testIP) + } + + ipUint := ipToUint32(ip) + contained := ipUint >= start && ipUint <= end + + if contained != tt.shouldContain[i] { + t.Errorf("IP %s in range %s: got %v, want %v", testIP, tt.cidr, contained, tt.shouldContain[i]) + } + } + }) + } +} + +func BenchmarkIPToUint32(b *testing.B) { + ip := net.ParseIP("192.168.1.1") + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = ipToUint32(ip) + } +} + +func BenchmarkCalculateIPv4Range(b *testing.B) { + cidr := "192.168.0.0/16" + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _, _ = CalculateIPv4Range(cidr) + } +} diff --git a/internal/database/models.go b/internal/database/models.go index f57e0b2..ff04f60 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -57,6 +57,9 @@ type LiveRoute struct { ASPath []int `json:"as_path"` NextHop string `json:"next_hop"` LastUpdated time.Time `json:"last_updated"` + // IPv4 range fields for fast lookups (nil for IPv6) + V4IPStart *uint32 `json:"v4_ip_start,omitempty"` + V4IPEnd *uint32 `json:"v4_ip_end,omitempty"` } // PrefixDistribution represents the distribution of prefixes by mask length @@ -64,3 +67,11 @@ type PrefixDistribution struct { MaskLength int `json:"mask_length"` Count int `json:"count"` } + +// ASInfo represents AS information for an IP lookup +type ASInfo struct { + ASN int `json:"asn"` + Handle string `json:"handle"` + Description string `json:"description"` + Prefix string `json:"prefix"` +} diff --git a/internal/database/schema.sql b/internal/database/schema.sql index cc0b8af..f02e187 100644 --- a/internal/database/schema.sql +++ b/internal/database/schema.sql @@ -79,6 +79,9 @@ CREATE TABLE IF NOT EXISTS live_routes ( as_path TEXT NOT NULL, -- JSON array next_hop TEXT NOT NULL, last_updated DATETIME NOT NULL, + -- IPv4 range columns for fast lookups (NULL for IPv6) + v4_ip_start INTEGER, -- Start of IPv4 range as 32-bit unsigned int + v4_ip_end INTEGER, -- End of IPv4 range as 32-bit unsigned int UNIQUE(prefix, origin_asn, peer_ip) ); @@ -86,4 +89,6 @@ CREATE TABLE IF NOT EXISTS live_routes ( CREATE INDEX IF NOT EXISTS idx_live_routes_prefix ON live_routes(prefix); CREATE INDEX IF NOT EXISTS idx_live_routes_mask_length ON live_routes(mask_length); CREATE INDEX IF NOT EXISTS idx_live_routes_ip_version_mask ON live_routes(ip_version, mask_length); -CREATE INDEX IF NOT EXISTS idx_live_routes_last_updated ON live_routes(last_updated); \ No newline at end of file +CREATE INDEX IF NOT EXISTS idx_live_routes_last_updated ON live_routes(last_updated); +-- Indexes for IPv4 range queries +CREATE INDEX IF NOT EXISTS idx_live_routes_ipv4_range ON live_routes(v4_ip_start, v4_ip_end) WHERE ip_version = 4; \ No newline at end of file diff --git a/internal/database/utils.go b/internal/database/utils.go index f09a3b8..cf1ff34 100644 --- a/internal/database/utils.go +++ b/internal/database/utils.go @@ -10,11 +10,6 @@ func generateUUID() uuid.UUID { return uuid.New() } -const ( - ipVersionV4 = 4 - ipVersionV6 = 6 -) - // detectIPVersion determines if a prefix is IPv4 (returns 4) or IPv6 (returns 6) func detectIPVersion(prefix string) int { if strings.Contains(prefix, ":") { diff --git a/internal/routewatch/prefixhandler.go b/internal/routewatch/prefixhandler.go index 40d3b91..b88f216 100644 --- a/internal/routewatch/prefixhandler.go +++ b/internal/routewatch/prefixhandler.go @@ -249,6 +249,20 @@ func (h *PrefixHandler) processAnnouncement(_ *database.Prefix, update prefixUpd LastUpdated: update.timestamp, } + // For IPv4, calculate the IP range + if ipVersion == ipv4Version { + start, end, err := database.CalculateIPv4Range(update.prefix) + if err == nil { + liveRoute.V4IPStart = &start + liveRoute.V4IPEnd = &end + } else { + h.logger.Error("Failed to calculate IPv4 range", + "prefix", update.prefix, + "error", err, + ) + } + } + if err := h.db.UpsertLiveRoute(liveRoute); err != nil { h.logger.Error("Failed to upsert live route", "prefix", update.prefix,