diff --git a/bin/ipapi b/bin/ipapi new file mode 100755 index 0000000..de3a5ca Binary files /dev/null and b/bin/ipapi differ diff --git a/internal/config/config.go b/internal/config/config.go index 61eebab..3524a0b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "os" + "path/filepath" + "runtime" "strconv" "git.eeqj.de/sneak/smartconfig" @@ -41,7 +43,7 @@ func New(configFile string) (*Config, error) { if stateDir, err := sc.GetString("state_dir"); err == nil { cfg.StateDir = stateDir } else { - cfg.StateDir = "/var/lib/ipapi" + cfg.StateDir = getDefaultStateDir() } // Get log level @@ -57,7 +59,7 @@ func New(configFile string) (*Config, error) { func newDefaultConfig() *Config { return &Config{ Port: getPortFromEnv(), - StateDir: "/var/lib/ipapi", + StateDir: getDefaultStateDir(), LogLevel: "info", } } @@ -75,3 +77,27 @@ func getPortFromEnv() int { return port } + +func getDefaultStateDir() string { + switch runtime.GOOS { + case "darwin": + // macOS: use ~/Library/Application Support/berlin.sneak.app.ipapi/ + home, err := os.UserHomeDir() + if err != nil { + return "/var/lib/ipapi" + } + + return filepath.Join(home, "Library", "Application Support", "berlin.sneak.app.ipapi") + case "windows": + // Windows: use %APPDATA%\berlin.sneak.app.ipapi + appData := os.Getenv("APPDATA") + if appData == "" { + return "/var/lib/ipapi" + } + + return filepath.Join(appData, "berlin.sneak.app.ipapi") + default: + // Linux and other Unix-like systems + return "/var/lib/ipapi" + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 97fa22d..43c1b0b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -15,8 +15,9 @@ func TestNewDefaultConfig(t *testing.T) { if cfg.Port != 8080 { t.Errorf("expected default port 8080, got %d", cfg.Port) } - if cfg.StateDir != "/var/lib/ipapi" { - t.Errorf("expected default state dir /var/lib/ipapi, got %s", cfg.StateDir) + expectedStateDir := getDefaultStateDir() + if cfg.StateDir != expectedStateDir { + t.Errorf("expected default state dir %s, got %s", expectedStateDir, cfg.StateDir) } if cfg.LogLevel != "info" { t.Errorf("expected default log level info, got %s", cfg.LogLevel) diff --git a/internal/database/database.go b/internal/database/database.go index 6f91de0..124a9e3 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -25,7 +25,7 @@ const ( cityFile = "GeoLite2-City.mmdb" countryFile = "GeoLite2-Country.mmdb" - downloadTimeout = 5 * time.Minute + downloadTimeout = 30 * time.Second updateInterval = 7 * 24 * time.Hour // 1 week defaultDirPermissions = 0750 diff --git a/internal/http/module.go b/internal/http/module.go new file mode 100644 index 0000000..eadac19 --- /dev/null +++ b/internal/http/module.go @@ -0,0 +1,34 @@ +package http + +import ( + "log/slog" + + "git.eeqj.de/sneak/ipapi/internal/database" + "git.eeqj.de/sneak/ipapi/internal/state" + "github.com/go-chi/chi/v5" + "go.uber.org/fx" +) + +// Module provides HTTP server functionality. +// +//nolint:gochecknoglobals // fx module pattern requires global +var Module = fx.Options( + fx.Provide( + NewRouter, + NewServer, + ), +) + +// RouterParams contains dependencies for NewRouter. +type RouterParams struct { + fx.In + + Logger *slog.Logger + DB *database.Manager + State *state.Manager +} + +// NewRouterWrapper wraps NewRouter to work with FX dependency injection. +func NewRouterWrapper(params RouterParams) (chi.Router, error) { + return NewRouter(params.Logger, params.DB, params.State) +} diff --git a/internal/http/router.go b/internal/http/router.go index 6a20cff..12a0d59 100644 --- a/internal/http/router.go +++ b/internal/http/router.go @@ -7,29 +7,86 @@ import ( "net" "net/http" "strings" + "time" "git.eeqj.de/sneak/ipapi/internal/database" + "git.eeqj.de/sneak/ipapi/internal/state" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" ) -// IPInfo represents the API response for IP lookups. +var ( + // RFC1918 private address ranges + //nolint:gochecknoglobals // Pre-parsed networks for performance + rfc1918Networks = []*net.IPNet{} + // Documentation address ranges + //nolint:gochecknoglobals // Pre-parsed networks for performance + documentationIPv4Networks = []*net.IPNet{} + //nolint:gochecknoglobals // Pre-parsed networks for performance + documentationIPv6Networks = []*net.IPNet{} +) + +func init() { + // Initialize RFC1918 networks + for _, cidr := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"} { + _, network, _ := net.ParseCIDR(cidr) + rfc1918Networks = append(rfc1918Networks, network) + } + + // Initialize documentation IPv4 networks + for _, cidr := range []string{"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"} { + _, network, _ := net.ParseCIDR(cidr) + documentationIPv4Networks = append(documentationIPv4Networks, network) + } + + // Initialize documentation IPv6 networks + _, network, _ := net.ParseCIDR("2001:db8::/32") + documentationIPv6Networks = append(documentationIPv6Networks, network) +} + +// IPInfo represents the data for a single IP address. type IPInfo struct { - IP string `json:"ip"` - Country string `json:"country,omitempty"` - CountryCode string `json:"countryCode,omitempty"` - City string `json:"city,omitempty"` - Region string `json:"region,omitempty"` - PostalCode string `json:"postalCode,omitempty"` - Latitude float64 `json:"latitude,omitempty"` - Longitude float64 `json:"longitude,omitempty"` - Timezone string `json:"timezone,omitempty"` - ASN uint `json:"asn,omitempty"` - ASNOrg string `json:"asnOrg,omitempty"` + IP string `json:"ip"` + IPVersion int `json:"ipVersion"` + IsPrivate bool `json:"isPrivate"` + IsRFC1918 bool `json:"isRfc1918"` + IsLoopback bool `json:"isLoopback"` + IsMulticast bool `json:"isMulticast"` + IsLinkLocal bool `json:"isLinkLocal"` + IsGlobalUnicast bool `json:"isGlobalUnicast"` + IsUnspecified bool `json:"isUnspecified"` + IsBroadcast bool `json:"isBroadcast"` + IsDocumentation bool `json:"isDocumentation"` + Country string `json:"country,omitempty"` + CountryCode string `json:"countryCode,omitempty"` + City string `json:"city,omitempty"` + Region string `json:"region,omitempty"` + PostalCode string `json:"postalCode,omitempty"` + Latitude float64 `json:"latitude,omitempty"` + Longitude float64 `json:"longitude,omitempty"` + Timezone string `json:"timezone,omitempty"` + ASN uint `json:"asn,omitempty"` + ASNOrg string `json:"asnOrg,omitempty"` + DatabaseTimestamp time.Time `json:"databaseTimestamp"` +} + +// SuccessResponse represents a successful API response. +type SuccessResponse struct { + Status string `json:"status"` + Date time.Time `json:"date"` + Result map[string]*IPInfo `json:"result"` + Error bool `json:"error"` +} + +// ErrorResponse represents an error API response. +type ErrorResponse struct { + Status string `json:"status"` + Date time.Time `json:"date"` + Error map[string]interface{} `json:"error"` } // NewRouter creates a new HTTP router with all endpoints configured. -func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) { +func NewRouter(logger *slog.Logger, db *database.Manager, state *state.Manager) (chi.Router, error) { r := chi.NewRouter() // Middleware @@ -61,76 +118,206 @@ func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) { } }) - // IP lookup endpoint - r.Get("/api/{ip}", handleIPLookup(logger, db)) + // IP lookup endpoint - supports comma-delimited IPs + r.Get("/api/v1/ipinfo/local/{ips}", handleIPLookup(logger, db, state)) return r, nil } -func handleIPLookup(logger *slog.Logger, db *database.Manager) http.HandlerFunc { +func handleIPLookup(logger *slog.Logger, db *database.Manager, stateManager *state.Manager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - ipStr := chi.URLParam(r, "ip") + ipsParam := chi.URLParam(r, "ips") - // Validate IP address - ip := net.ParseIP(ipStr) - if ip == nil { - writeError(w, http.StatusBadRequest, "Invalid IP address") + // Split comma-delimited IPs + ipList := strings.Split(ipsParam, ",") + + // Check for IP limit + const maxIPs = 50 + const errorCodeTooManyIPs = 1005 + if len(ipList) > maxIPs { + writeErrorResponse(w, http.StatusBadRequest, errorCodeTooManyIPs, "Too many IPs requested. Maximum is 50") return } - info := &IPInfo{ - IP: ipStr, + // Get database references + countryDB := db.GetCountryDB() + cityDB := db.GetCityDB() + asnDB := db.GetASNDB() + + // Check if databases are loaded + const errorCodeDatabasesNotLoaded = 1002 + if countryDB == nil && cityDB == nil && asnDB == nil { + writeErrorResponse(w, http.StatusServiceUnavailable, errorCodeDatabasesNotLoaded, + "GeoIP databases are not yet loaded") + + return } - // Look up in Country database - if countryDB := db.GetCountryDB(); countryDB != nil { - country, err := countryDB.Country(ip) - if err == nil { - info.Country = country.Country.Names["en"] - info.CountryCode = country.Country.IsoCode + // Get database timestamp + stateData, err := stateManager.Load() + var dbTimestamp time.Time + if err == nil && stateData != nil { + // Use the most recent database download time + timestamps := []time.Time{ + stateData.LastASNDownload, + stateData.LastCityDownload, + stateData.LastCountryDownload, } - } - - // Look up in City database - if cityDB := db.GetCityDB(); cityDB != nil { - city, err := cityDB.City(ip) - if err == nil { - info.City = city.City.Names["en"] - if len(city.Subdivisions) > 0 { - info.Region = city.Subdivisions[0].Names["en"] + dbTimestamp = timestamps[0] + for _, ts := range timestamps[1:] { + if ts.After(dbTimestamp) { + dbTimestamp = ts } - info.PostalCode = city.Postal.Code - info.Latitude = city.Location.Latitude - info.Longitude = city.Location.Longitude - info.Timezone = city.Location.TimeZone } } - // Look up in ASN database - if asnDB := db.GetASNDB(); asnDB != nil { - asn, err := asnDB.ASN(ip) - if err == nil { - info.ASN = asn.AutonomousSystemNumber - info.ASNOrg = asn.AutonomousSystemOrganization + result := make(map[string]*IPInfo) + + for _, ipStr := range ipList { + ipStr = strings.TrimSpace(ipStr) + if ipStr == "" { + continue } + + // Validate IP address + ip := net.ParseIP(ipStr) + if ip == nil { + result[ipStr] = &IPInfo{ + IP: ipStr, + DatabaseTimestamp: dbTimestamp, + } + + continue + } + + info := &IPInfo{ + IP: ipStr, + DatabaseTimestamp: dbTimestamp, + } + + // Populate IP properties + info.IsPrivate = ip.IsPrivate() + info.IsLoopback = ip.IsLoopback() + info.IsMulticast = ip.IsMulticast() + info.IsLinkLocal = ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() + info.IsGlobalUnicast = ip.IsGlobalUnicast() + info.IsUnspecified = ip.IsUnspecified() + + // Determine IP version and additional properties + if ipv4 := ip.To4(); ipv4 != nil { + info.IPVersion = 4 + // Check for RFC1918 private addresses + info.IsRFC1918 = isRFC1918(ipv4) + // Check for broadcast (255.255.255.255) + const maxByte = 255 + info.IsBroadcast = ipv4[0] == maxByte && ipv4[1] == maxByte && ipv4[2] == maxByte && ipv4[3] == maxByte + // Check for documentation addresses (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24) + info.IsDocumentation = isDocumentationIPv4(ipv4) + } else if ip.To16() != nil { + info.IPVersion = 6 + info.IsRFC1918 = false + info.IsBroadcast = false + // Check for documentation addresses (2001:db8::/32) + info.IsDocumentation = isDocumentationIPv6(ip) + } + + // Look up in Country database + if countryDB != nil { + country, err := countryDB.Country(ip) + if err == nil { + info.Country = country.Country.Names["en"] + info.CountryCode = country.Country.IsoCode + } + } + + // Look up in City database + if cityDB != nil { + city, err := cityDB.City(ip) + if err == nil { + info.City = city.City.Names["en"] + if len(city.Subdivisions) > 0 { + info.Region = city.Subdivisions[0].Names["en"] + } + info.PostalCode = city.Postal.Code + info.Latitude = city.Location.Latitude + info.Longitude = city.Location.Longitude + info.Timezone = city.Location.TimeZone + } + } + + // Look up in ASN database + if asnDB != nil { + asn, err := asnDB.ASN(ip) + if err == nil { + info.ASN = asn.AutonomousSystemNumber + info.ASNOrg = asn.AutonomousSystemOrganization + } + } + + result[ipStr] = info + } + + // Return success response + response := SuccessResponse{ + Status: "success", + Date: time.Now().UTC(), + Result: result, + Error: false, } - // Set content type and encode response w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(info); err != nil { + const errorCodeEncodeFailure = 1004 + if err := json.NewEncoder(w).Encode(response); err != nil { logger.Error("Failed to encode response", "error", err) - writeError(w, http.StatusInternalServerError, "Internal server error") + writeErrorResponse(w, http.StatusInternalServerError, errorCodeEncodeFailure, "Failed to encode response") } } } -func writeError(w http.ResponseWriter, code int, message string) { +func isRFC1918(ip net.IP) bool { + for _, network := range rfc1918Networks { + if network.Contains(ip) { + return true + } + } + + return false +} + +func isDocumentationIPv4(ip net.IP) bool { + for _, network := range documentationIPv4Networks { + if network.Contains(ip) { + return true + } + } + + return false +} + +func isDocumentationIPv6(ip net.IP) bool { + for _, network := range documentationIPv6Networks { + if network.Contains(ip) { + return true + } + } + + return false +} + +func writeErrorResponse(w http.ResponseWriter, httpCode int, errorCode int, message string) { + response := ErrorResponse{ + Status: "error", + Date: time.Now().UTC(), + Error: map[string]interface{}{ + "code": errorCode, + "msg": message, + }, + } + w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - if err := json.NewEncoder(w).Encode(map[string]string{ - "error": message, - }); err != nil { + w.WriteHeader(httpCode) + if err := json.NewEncoder(w).Encode(response); err != nil { // Log error but don't try to write again _ = err } @@ -155,4 +342,4 @@ func getClientIP(r *http.Request) string { host, _, _ := net.SplitHostPort(r.RemoteAddr) return host -} \ No newline at end of file +} diff --git a/internal/http/router_test.go b/internal/http/router_test.go index 78c93d7..b7c28bd 100644 --- a/internal/http/router_test.go +++ b/internal/http/router_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "git.eeqj.de/sneak/ipapi/internal/config" "git.eeqj.de/sneak/ipapi/internal/database" @@ -22,7 +23,7 @@ func TestNewRouter(t *testing.T) { stateManager, _ := state.New(cfg, logger) db, _ := database.New(cfg, logger, stateManager) - router, err := NewRouter(logger, db) + router, err := NewRouter(logger, db, stateManager) if err != nil { t.Fatalf("failed to create router: %v", err) } @@ -40,7 +41,7 @@ func TestHealthEndpoint(t *testing.T) { stateManager, _ := state.New(cfg, logger) db, _ := database.New(cfg, logger, stateManager) - router, _ := NewRouter(logger, db) + router, _ := NewRouter(logger, db, stateManager) req := httptest.NewRequest("GET", "/health", nil) rec := httptest.NewRecorder() @@ -58,12 +59,24 @@ func TestHealthEndpoint(t *testing.T) { func TestIPLookupEndpoint(t *testing.T) { logger := slog.Default() + tmpDir := t.TempDir() cfg := &config.Config{ - StateDir: t.TempDir(), + StateDir: tmpDir, } stateManager, _ := state.New(cfg, logger) + _ = stateManager.Initialize(context.Background()) + db, _ := database.New(cfg, logger, stateManager) + // Download databases to temp directory for testing + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + t.Log("Downloading test databases...") + if err := db.EnsureDatabases(ctx); err != nil { + t.Skipf("Skipping test - could not download databases: %v", err) + } + tests := []struct { name string ip string @@ -71,38 +84,141 @@ func TestIPLookupEndpoint(t *testing.T) { }{ {"valid IPv4", "8.8.8.8", http.StatusOK}, {"valid IPv6", "2001:4860:4860::8888", http.StatusOK}, - {"invalid IP", "invalid", http.StatusBadRequest}, + {"invalid IP", "invalid", http.StatusOK}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/api/"+tt.ip, nil) + req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+tt.ip, nil) rec := httptest.NewRecorder() // Create a new context with the URL param rctx := chi.NewRouteContext() - rctx.URLParams.Add("ip", tt.ip) + rctx.URLParams.Add("ips", tt.ip) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - handleIPLookup(logger, db)(rec, req) + handleIPLookup(logger, db, stateManager)(rec, req) if rec.Code != tt.expectedCode { t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code) } if tt.expectedCode == http.StatusOK { - var info IPInfo - if err := json.Unmarshal(rec.Body.Bytes(), &info); err != nil { + var response SuccessResponse + if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { t.Errorf("failed to unmarshal response: %v", err) } - if info.IP != tt.ip { - t.Errorf("expected IP %s, got %s", tt.ip, info.IP) + if response.Status != "success" { + t.Errorf("expected status 'success', got %s", response.Status) + } + if len(response.Result) != 1 { + t.Errorf("expected 1 result, got %d", len(response.Result)) + } + if info, ok := response.Result[tt.ip]; ok { + if info.IP != tt.ip { + t.Errorf("expected IP %s, got %s", tt.ip, info.IP) + } + } else { + t.Errorf("expected result for IP %s", tt.ip) } } }) } } +func TestIPLookupMultipleIPs(t *testing.T) { + logger := slog.Default() + tmpDir := t.TempDir() + cfg := &config.Config{ + StateDir: tmpDir, + } + stateManager, _ := state.New(cfg, logger) + _ = stateManager.Initialize(context.Background()) + + db, _ := database.New(cfg, logger, stateManager) + + // Download databases to temp directory for testing + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + t.Log("Downloading test databases...") + if err := db.EnsureDatabases(ctx); err != nil { + t.Skipf("Skipping test - could not download databases: %v", err) + } + + // Test multiple IPs + req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/8.8.8.8,1.1.1.1,invalid", nil) + rec := httptest.NewRecorder() + + // Create a new context with the URL param + rctx := chi.NewRouteContext() + rctx.URLParams.Add("ips", "8.8.8.8,1.1.1.1,invalid") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + handleIPLookup(logger, db, stateManager)(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rec.Code) + } + + var response SuccessResponse + if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { + t.Errorf("failed to unmarshal response: %v", err) + } + + if len(response.Result) != 3 { + t.Errorf("expected 3 results, got %d", len(response.Result)) + } + + // Check that all requested IPs are in the result + for _, ip := range []string{"8.8.8.8", "1.1.1.1", "invalid"} { + if _, ok := response.Result[ip]; !ok { + t.Errorf("expected result for IP %s", ip) + } + } +} + +func TestIPLookupTooManyIPs(t *testing.T) { + logger := slog.Default() + cfg := &config.Config{ + StateDir: t.TempDir(), + } + stateManager, _ := state.New(cfg, logger) + db, _ := database.New(cfg, logger, stateManager) + + // Create 51 IPs (more than the limit) + ips := "" + for i := 0; i < 51; i++ { + if i > 0 { + ips += "," + } + ips += "1.1.1.1" + } + + req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+ips, nil) + rec := httptest.NewRecorder() + + // Create a new context with the URL param + rctx := chi.NewRouteContext() + rctx.URLParams.Add("ips", ips) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + handleIPLookup(logger, db, stateManager)(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", rec.Code) + } + + var response ErrorResponse + if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { + t.Errorf("failed to unmarshal response: %v", err) + } + + if response.Status != "error" { + t.Errorf("expected status 'error', got %s", response.Status) + } +} + func TestGetClientIP(t *testing.T) { tests := []struct { name string diff --git a/internal/ipapi/daemon.go b/internal/ipapi/daemon.go index 90f331d..336fba3 100644 --- a/internal/ipapi/daemon.go +++ b/internal/ipapi/daemon.go @@ -32,7 +32,7 @@ func daemonCmd() *cobra.Command { state.New, database.New, http.NewServer, - http.NewRouter, + http.NewRouterWrapper, New, ), fx.Invoke(func(lc fx.Lifecycle, ipapi *IPAPI) { diff --git a/internal/ipapi/ipapi.go b/internal/ipapi/ipapi.go index 64b5847..131fbe1 100644 --- a/internal/ipapi/ipapi.go +++ b/internal/ipapi/ipapi.go @@ -54,10 +54,14 @@ func (i *IPAPI) Start(ctx context.Context) error { return err } - // Download databases if needed - if err := i.database.EnsureDatabases(ctx); err != nil { - return err - } + // Start database downloads in background with a fresh context + // This avoids using the OnStart context which has a timeout + go func() { + dbCtx := context.Background() + if err := i.database.EnsureDatabases(dbCtx); err != nil { + i.logger.Error("Failed to ensure databases", "error", err) + } + }() // Start HTTP server return i.server.Start(ctx)