- Add http module for proper FX dependency injection - Fix router to accept state manager parameter - Implement proper CIDR-based checking for RFC1918 and documentation IPs - Add reasonable timeouts (30s) for database downloads - Update tests to download databases to temporary directories - Add tests for multiple IP lookups and error cases - All tests passing
346 lines
9.6 KiB
Go
346 lines
9.6 KiB
Go
// Package http provides the HTTP server and routing functionality.
|
|
package http
|
|
|
|
import (
|
|
"encoding/json"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.eeqj.de/sneak/ipapi/internal/database"
|
|
"git.eeqj.de/sneak/ipapi/internal/state"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
)
|
|
|
|
// CIDR tables used by the address classifiers below. They are built once
// during package initialization and only read afterwards.
var (
	// rfc1918Networks holds the RFC 1918 private IPv4 ranges.
	//nolint:gochecknoglobals // Pre-parsed networks for performance
	rfc1918Networks = mustParseCIDRs("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")

	// documentationIPv4Networks holds the RFC 5737 IPv4 documentation ranges
	// (TEST-NET-1, TEST-NET-2, TEST-NET-3).
	//nolint:gochecknoglobals // Pre-parsed networks for performance
	documentationIPv4Networks = mustParseCIDRs("192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24")

	// documentationIPv6Networks holds the RFC 3849 IPv6 documentation prefix.
	//nolint:gochecknoglobals // Pre-parsed networks for performance
	documentationIPv6Networks = mustParseCIDRs("2001:db8::/32")
)

// mustParseCIDRs parses every CIDR string and returns the networks in order.
//
// It panics on malformed input. The arguments are compile-time constants, so a
// parse failure is a programming error. Ignoring it would leave a nil *IPNet in
// the table that only blows up later inside a request handler.
func mustParseCIDRs(cidrs ...string) []*net.IPNet {
	networks := make([]*net.IPNet, 0, len(cidrs))

	for _, cidr := range cidrs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(err)
		}

		networks = append(networks, network)
	}

	return networks
}
|
|
|
|
// API payload types. Field order is significant: encoding/json emits fields
// in declaration order, so reordering them changes the wire format.
type (
	// IPInfo represents the data for a single IP address.
	IPInfo struct {
		// IP is the address exactly as requested (after whitespace trimming).
		IP string `json:"ip"`
		// IPVersion is 4 or 6 for parsed addresses and stays 0 when IP did not parse.
		IPVersion int `json:"ipVersion"`

		// Address-property flags derived from the parsed address.
		IsPrivate       bool `json:"isPrivate"`
		IsRFC1918       bool `json:"isRfc1918"`
		IsLoopback      bool `json:"isLoopback"`
		IsMulticast     bool `json:"isMulticast"`
		IsLinkLocal     bool `json:"isLinkLocal"`
		IsGlobalUnicast bool `json:"isGlobalUnicast"`
		IsUnspecified   bool `json:"isUnspecified"`
		IsBroadcast     bool `json:"isBroadcast"`
		IsDocumentation bool `json:"isDocumentation"`

		// GeoIP fields; omitted from the JSON when empty.
		Country     string `json:"country,omitempty"`
		CountryCode string `json:"countryCode,omitempty"`
		City        string `json:"city,omitempty"`
		Region      string `json:"region,omitempty"`
		PostalCode  string `json:"postalCode,omitempty"`
		// NOTE(review): omitempty also drops a genuine 0 latitude/longitude.
		Latitude  float64 `json:"latitude,omitempty"`
		Longitude float64 `json:"longitude,omitempty"`
		Timezone  string  `json:"timezone,omitempty"`
		ASN       uint    `json:"asn,omitempty"`
		ASNOrg    string  `json:"asnOrg,omitempty"`

		// DatabaseTimestamp is the newest GeoIP database download time known to
		// the server, or the zero time when unknown.
		DatabaseTimestamp time.Time `json:"databaseTimestamp"`
	}

	// SuccessResponse represents a successful API response.
	SuccessResponse struct {
		Status string    `json:"status"`
		Date   time.Time `json:"date"`
		// Result maps each requested address string to its lookup data.
		Result map[string]*IPInfo `json:"result"`
		Error  bool               `json:"error"`
	}

	// ErrorResponse represents an error API response.
	ErrorResponse struct {
		Status string    `json:"status"`
		Date   time.Time `json:"date"`
		// Error carries the application error "code" and human-readable "msg".
		Error map[string]interface{} `json:"error"`
	}
)
|
|
|
|
// NewRouter creates a new HTTP router with all endpoints configured.
|
|
func NewRouter(logger *slog.Logger, db *database.Manager, state *state.Manager) (chi.Router, error) {
|
|
r := chi.NewRouter()
|
|
|
|
// Middleware
|
|
r.Use(middleware.RequestID)
|
|
r.Use(middleware.RealIP)
|
|
r.Use(middleware.Recoverer)
|
|
const requestTimeout = 60
|
|
r.Use(middleware.Timeout(requestTimeout))
|
|
|
|
// Logging middleware
|
|
r.Use(func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := r.Context().Value(middleware.RequestIDKey).(string)
|
|
logger.Debug("HTTP request",
|
|
"method", r.Method,
|
|
"path", r.URL.Path,
|
|
"remote_addr", r.RemoteAddr,
|
|
"request_id", start,
|
|
)
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
})
|
|
|
|
// Health check
|
|
r.Get("/health", func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
if _, err := w.Write([]byte("OK")); err != nil {
|
|
logger.Error("Failed to write health response", "error", err)
|
|
}
|
|
})
|
|
|
|
// IP lookup endpoint - supports comma-delimited IPs
|
|
r.Get("/api/v1/ipinfo/local/{ips}", handleIPLookup(logger, db, state))
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func handleIPLookup(logger *slog.Logger, db *database.Manager, stateManager *state.Manager) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ipsParam := chi.URLParam(r, "ips")
|
|
|
|
// Split comma-delimited IPs
|
|
ipList := strings.Split(ipsParam, ",")
|
|
|
|
// Check for IP limit
|
|
const maxIPs = 50
|
|
const errorCodeTooManyIPs = 1005
|
|
if len(ipList) > maxIPs {
|
|
writeErrorResponse(w, http.StatusBadRequest, errorCodeTooManyIPs, "Too many IPs requested. Maximum is 50")
|
|
|
|
return
|
|
}
|
|
|
|
// Get database references
|
|
countryDB := db.GetCountryDB()
|
|
cityDB := db.GetCityDB()
|
|
asnDB := db.GetASNDB()
|
|
|
|
// Check if databases are loaded
|
|
const errorCodeDatabasesNotLoaded = 1002
|
|
if countryDB == nil && cityDB == nil && asnDB == nil {
|
|
writeErrorResponse(w, http.StatusServiceUnavailable, errorCodeDatabasesNotLoaded,
|
|
"GeoIP databases are not yet loaded")
|
|
|
|
return
|
|
}
|
|
|
|
// Get database timestamp
|
|
stateData, err := stateManager.Load()
|
|
var dbTimestamp time.Time
|
|
if err == nil && stateData != nil {
|
|
// Use the most recent database download time
|
|
timestamps := []time.Time{
|
|
stateData.LastASNDownload,
|
|
stateData.LastCityDownload,
|
|
stateData.LastCountryDownload,
|
|
}
|
|
dbTimestamp = timestamps[0]
|
|
for _, ts := range timestamps[1:] {
|
|
if ts.After(dbTimestamp) {
|
|
dbTimestamp = ts
|
|
}
|
|
}
|
|
}
|
|
|
|
result := make(map[string]*IPInfo)
|
|
|
|
for _, ipStr := range ipList {
|
|
ipStr = strings.TrimSpace(ipStr)
|
|
if ipStr == "" {
|
|
continue
|
|
}
|
|
|
|
// Validate IP address
|
|
ip := net.ParseIP(ipStr)
|
|
if ip == nil {
|
|
result[ipStr] = &IPInfo{
|
|
IP: ipStr,
|
|
DatabaseTimestamp: dbTimestamp,
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
info := &IPInfo{
|
|
IP: ipStr,
|
|
DatabaseTimestamp: dbTimestamp,
|
|
}
|
|
|
|
// Populate IP properties
|
|
info.IsPrivate = ip.IsPrivate()
|
|
info.IsLoopback = ip.IsLoopback()
|
|
info.IsMulticast = ip.IsMulticast()
|
|
info.IsLinkLocal = ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()
|
|
info.IsGlobalUnicast = ip.IsGlobalUnicast()
|
|
info.IsUnspecified = ip.IsUnspecified()
|
|
|
|
// Determine IP version and additional properties
|
|
if ipv4 := ip.To4(); ipv4 != nil {
|
|
info.IPVersion = 4
|
|
// Check for RFC1918 private addresses
|
|
info.IsRFC1918 = isRFC1918(ipv4)
|
|
// Check for broadcast (255.255.255.255)
|
|
const maxByte = 255
|
|
info.IsBroadcast = ipv4[0] == maxByte && ipv4[1] == maxByte && ipv4[2] == maxByte && ipv4[3] == maxByte
|
|
// Check for documentation addresses (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
|
|
info.IsDocumentation = isDocumentationIPv4(ipv4)
|
|
} else if ip.To16() != nil {
|
|
info.IPVersion = 6
|
|
info.IsRFC1918 = false
|
|
info.IsBroadcast = false
|
|
// Check for documentation addresses (2001:db8::/32)
|
|
info.IsDocumentation = isDocumentationIPv6(ip)
|
|
}
|
|
|
|
// Look up in Country database
|
|
if countryDB != nil {
|
|
country, err := countryDB.Country(ip)
|
|
if err == nil {
|
|
info.Country = country.Country.Names["en"]
|
|
info.CountryCode = country.Country.IsoCode
|
|
}
|
|
}
|
|
|
|
// Look up in City database
|
|
if cityDB != nil {
|
|
city, err := cityDB.City(ip)
|
|
if err == nil {
|
|
info.City = city.City.Names["en"]
|
|
if len(city.Subdivisions) > 0 {
|
|
info.Region = city.Subdivisions[0].Names["en"]
|
|
}
|
|
info.PostalCode = city.Postal.Code
|
|
info.Latitude = city.Location.Latitude
|
|
info.Longitude = city.Location.Longitude
|
|
info.Timezone = city.Location.TimeZone
|
|
}
|
|
}
|
|
|
|
// Look up in ASN database
|
|
if asnDB != nil {
|
|
asn, err := asnDB.ASN(ip)
|
|
if err == nil {
|
|
info.ASN = asn.AutonomousSystemNumber
|
|
info.ASNOrg = asn.AutonomousSystemOrganization
|
|
}
|
|
}
|
|
|
|
result[ipStr] = info
|
|
}
|
|
|
|
// Return success response
|
|
response := SuccessResponse{
|
|
Status: "success",
|
|
Date: time.Now().UTC(),
|
|
Result: result,
|
|
Error: false,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
const errorCodeEncodeFailure = 1004
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
logger.Error("Failed to encode response", "error", err)
|
|
writeErrorResponse(w, http.StatusInternalServerError, errorCodeEncodeFailure, "Failed to encode response")
|
|
}
|
|
}
|
|
}
|
|
|
|
func isRFC1918(ip net.IP) bool {
|
|
for _, network := range rfc1918Networks {
|
|
if network.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isDocumentationIPv4(ip net.IP) bool {
|
|
for _, network := range documentationIPv4Networks {
|
|
if network.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func isDocumentationIPv6(ip net.IP) bool {
|
|
for _, network := range documentationIPv6Networks {
|
|
if network.Contains(ip) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func writeErrorResponse(w http.ResponseWriter, httpCode int, errorCode int, message string) {
|
|
response := ErrorResponse{
|
|
Status: "error",
|
|
Date: time.Now().UTC(),
|
|
Error: map[string]interface{}{
|
|
"code": errorCode,
|
|
"msg": message,
|
|
},
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(httpCode)
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
// Log error but don't try to write again
|
|
_ = err
|
|
}
|
|
}
|
|
|
|
// getClientIP returns a best-guess client address for r.
//
// Sources are tried in this order:
//  1. the first non-empty entry of X-Forwarded-For
//  2. X-Real-IP (trimmed)
//  3. r.RemoteAddr with its port stripped
//
// If RemoteAddr has no port, as happens after middleware.RealIP replaces it
// with a bare IP, it is returned unchanged instead of as "".
//
// SECURITY: both headers are client-controlled unless a trusted reverse proxy
// overwrites them. Do not rely on this value for rate limiting or access
// control on a directly exposed server.
//
//nolint:unused // will be used in future for rate limiting
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// Not in host:port form (e.g. a bare IP); use it verbatim.
		return r.RemoteAddr
	}

	return host
}
|