Fix dependency injection and implement proper CIDR checking
- Add http module for proper FX dependency injection - Fix router to accept state manager parameter - Implement proper CIDR-based checking for RFC1918 and documentation IPs - Add reasonable timeouts (30s) for database downloads - Update tests to download databases to temporary directories - Add tests for multiple IP lookups and error cases - All tests passing
This commit is contained in:
parent
2a1710cca8
commit
08ca75966e
@ -3,6 +3,8 @@ package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"git.eeqj.de/sneak/smartconfig"
|
||||
@ -41,7 +43,7 @@ func New(configFile string) (*Config, error) {
|
||||
if stateDir, err := sc.GetString("state_dir"); err == nil {
|
||||
cfg.StateDir = stateDir
|
||||
} else {
|
||||
cfg.StateDir = "/var/lib/ipapi"
|
||||
cfg.StateDir = getDefaultStateDir()
|
||||
}
|
||||
|
||||
// Get log level
|
||||
@ -57,7 +59,7 @@ func New(configFile string) (*Config, error) {
|
||||
func newDefaultConfig() *Config {
|
||||
return &Config{
|
||||
Port: getPortFromEnv(),
|
||||
StateDir: "/var/lib/ipapi",
|
||||
StateDir: getDefaultStateDir(),
|
||||
LogLevel: "info",
|
||||
}
|
||||
}
|
||||
@ -75,3 +77,27 @@ func getPortFromEnv() int {
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
func getDefaultStateDir() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
// macOS: use ~/Library/Application Support/berlin.sneak.app.ipapi/
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "/var/lib/ipapi"
|
||||
}
|
||||
|
||||
return filepath.Join(home, "Library", "Application Support", "berlin.sneak.app.ipapi")
|
||||
case "windows":
|
||||
// Windows: use %APPDATA%\berlin.sneak.app.ipapi
|
||||
appData := os.Getenv("APPDATA")
|
||||
if appData == "" {
|
||||
return "/var/lib/ipapi"
|
||||
}
|
||||
|
||||
return filepath.Join(appData, "berlin.sneak.app.ipapi")
|
||||
default:
|
||||
// Linux and other Unix-like systems
|
||||
return "/var/lib/ipapi"
|
||||
}
|
||||
}
|
||||
|
@ -15,8 +15,9 @@ func TestNewDefaultConfig(t *testing.T) {
|
||||
if cfg.Port != 8080 {
|
||||
t.Errorf("expected default port 8080, got %d", cfg.Port)
|
||||
}
|
||||
if cfg.StateDir != "/var/lib/ipapi" {
|
||||
t.Errorf("expected default state dir /var/lib/ipapi, got %s", cfg.StateDir)
|
||||
expectedStateDir := getDefaultStateDir()
|
||||
if cfg.StateDir != expectedStateDir {
|
||||
t.Errorf("expected default state dir %s, got %s", expectedStateDir, cfg.StateDir)
|
||||
}
|
||||
if cfg.LogLevel != "info" {
|
||||
t.Errorf("expected default log level info, got %s", cfg.LogLevel)
|
||||
|
@ -25,7 +25,7 @@ const (
|
||||
cityFile = "GeoLite2-City.mmdb"
|
||||
countryFile = "GeoLite2-Country.mmdb"
|
||||
|
||||
downloadTimeout = 5 * time.Minute
|
||||
downloadTimeout = 30 * time.Second
|
||||
updateInterval = 7 * 24 * time.Hour // 1 week
|
||||
|
||||
defaultDirPermissions = 0750
|
||||
|
34
internal/http/module.go
Normal file
34
internal/http/module.go
Normal file
@ -0,0 +1,34 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// Module provides HTTP server functionality.
|
||||
//
|
||||
//nolint:gochecknoglobals // fx module pattern requires global
|
||||
var Module = fx.Options(
|
||||
fx.Provide(
|
||||
NewRouter,
|
||||
NewServer,
|
||||
),
|
||||
)
|
||||
|
||||
// RouterParams contains dependencies for NewRouter.
|
||||
type RouterParams struct {
|
||||
fx.In
|
||||
|
||||
Logger *slog.Logger
|
||||
DB *database.Manager
|
||||
State *state.Manager
|
||||
}
|
||||
|
||||
// NewRouterWrapper wraps NewRouter to work with FX dependency injection.
|
||||
func NewRouterWrapper(params RouterParams) (chi.Router, error) {
|
||||
return NewRouter(params.Logger, params.DB, params.State)
|
||||
}
|
@ -7,29 +7,86 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
// IPInfo represents the API response for IP lookups.
|
||||
var (
|
||||
// RFC1918 private address ranges
|
||||
//nolint:gochecknoglobals // Pre-parsed networks for performance
|
||||
rfc1918Networks = []*net.IPNet{}
|
||||
// Documentation address ranges
|
||||
//nolint:gochecknoglobals // Pre-parsed networks for performance
|
||||
documentationIPv4Networks = []*net.IPNet{}
|
||||
//nolint:gochecknoglobals // Pre-parsed networks for performance
|
||||
documentationIPv6Networks = []*net.IPNet{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize RFC1918 networks
|
||||
for _, cidr := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"} {
|
||||
_, network, _ := net.ParseCIDR(cidr)
|
||||
rfc1918Networks = append(rfc1918Networks, network)
|
||||
}
|
||||
|
||||
// Initialize documentation IPv4 networks
|
||||
for _, cidr := range []string{"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"} {
|
||||
_, network, _ := net.ParseCIDR(cidr)
|
||||
documentationIPv4Networks = append(documentationIPv4Networks, network)
|
||||
}
|
||||
|
||||
// Initialize documentation IPv6 networks
|
||||
_, network, _ := net.ParseCIDR("2001:db8::/32")
|
||||
documentationIPv6Networks = append(documentationIPv6Networks, network)
|
||||
}
|
||||
|
||||
// IPInfo represents the data for a single IP address.
|
||||
type IPInfo struct {
|
||||
IP string `json:"ip"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"countryCode,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
PostalCode string `json:"postalCode,omitempty"`
|
||||
Latitude float64 `json:"latitude,omitempty"`
|
||||
Longitude float64 `json:"longitude,omitempty"`
|
||||
Timezone string `json:"timezone,omitempty"`
|
||||
ASN uint `json:"asn,omitempty"`
|
||||
ASNOrg string `json:"asnOrg,omitempty"`
|
||||
IP string `json:"ip"`
|
||||
IPVersion int `json:"ipVersion"`
|
||||
IsPrivate bool `json:"isPrivate"`
|
||||
IsRFC1918 bool `json:"isRfc1918"`
|
||||
IsLoopback bool `json:"isLoopback"`
|
||||
IsMulticast bool `json:"isMulticast"`
|
||||
IsLinkLocal bool `json:"isLinkLocal"`
|
||||
IsGlobalUnicast bool `json:"isGlobalUnicast"`
|
||||
IsUnspecified bool `json:"isUnspecified"`
|
||||
IsBroadcast bool `json:"isBroadcast"`
|
||||
IsDocumentation bool `json:"isDocumentation"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"countryCode,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
PostalCode string `json:"postalCode,omitempty"`
|
||||
Latitude float64 `json:"latitude,omitempty"`
|
||||
Longitude float64 `json:"longitude,omitempty"`
|
||||
Timezone string `json:"timezone,omitempty"`
|
||||
ASN uint `json:"asn,omitempty"`
|
||||
ASNOrg string `json:"asnOrg,omitempty"`
|
||||
DatabaseTimestamp time.Time `json:"databaseTimestamp"`
|
||||
}
|
||||
|
||||
// SuccessResponse represents a successful API response.
|
||||
type SuccessResponse struct {
|
||||
Status string `json:"status"`
|
||||
Date time.Time `json:"date"`
|
||||
Result map[string]*IPInfo `json:"result"`
|
||||
Error bool `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error API response.
|
||||
type ErrorResponse struct {
|
||||
Status string `json:"status"`
|
||||
Date time.Time `json:"date"`
|
||||
Error map[string]interface{} `json:"error"`
|
||||
}
|
||||
|
||||
// NewRouter creates a new HTTP router with all endpoints configured.
|
||||
func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) {
|
||||
func NewRouter(logger *slog.Logger, db *database.Manager, state *state.Manager) (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Middleware
|
||||
@ -61,76 +118,206 @@ func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) {
|
||||
}
|
||||
})
|
||||
|
||||
// IP lookup endpoint
|
||||
r.Get("/api/{ip}", handleIPLookup(logger, db))
|
||||
// IP lookup endpoint - supports comma-delimited IPs
|
||||
r.Get("/api/v1/ipinfo/local/{ips}", handleIPLookup(logger, db, state))
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func handleIPLookup(logger *slog.Logger, db *database.Manager) http.HandlerFunc {
|
||||
func handleIPLookup(logger *slog.Logger, db *database.Manager, stateManager *state.Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ipStr := chi.URLParam(r, "ip")
|
||||
ipsParam := chi.URLParam(r, "ips")
|
||||
|
||||
// Validate IP address
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
writeError(w, http.StatusBadRequest, "Invalid IP address")
|
||||
// Split comma-delimited IPs
|
||||
ipList := strings.Split(ipsParam, ",")
|
||||
|
||||
// Check for IP limit
|
||||
const maxIPs = 50
|
||||
const errorCodeTooManyIPs = 1005
|
||||
if len(ipList) > maxIPs {
|
||||
writeErrorResponse(w, http.StatusBadRequest, errorCodeTooManyIPs, "Too many IPs requested. Maximum is 50")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
info := &IPInfo{
|
||||
IP: ipStr,
|
||||
// Get database references
|
||||
countryDB := db.GetCountryDB()
|
||||
cityDB := db.GetCityDB()
|
||||
asnDB := db.GetASNDB()
|
||||
|
||||
// Check if databases are loaded
|
||||
const errorCodeDatabasesNotLoaded = 1002
|
||||
if countryDB == nil && cityDB == nil && asnDB == nil {
|
||||
writeErrorResponse(w, http.StatusServiceUnavailable, errorCodeDatabasesNotLoaded,
|
||||
"GeoIP databases are not yet loaded")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Look up in Country database
|
||||
if countryDB := db.GetCountryDB(); countryDB != nil {
|
||||
country, err := countryDB.Country(ip)
|
||||
if err == nil {
|
||||
info.Country = country.Country.Names["en"]
|
||||
info.CountryCode = country.Country.IsoCode
|
||||
// Get database timestamp
|
||||
stateData, err := stateManager.Load()
|
||||
var dbTimestamp time.Time
|
||||
if err == nil && stateData != nil {
|
||||
// Use the most recent database download time
|
||||
timestamps := []time.Time{
|
||||
stateData.LastASNDownload,
|
||||
stateData.LastCityDownload,
|
||||
stateData.LastCountryDownload,
|
||||
}
|
||||
}
|
||||
|
||||
// Look up in City database
|
||||
if cityDB := db.GetCityDB(); cityDB != nil {
|
||||
city, err := cityDB.City(ip)
|
||||
if err == nil {
|
||||
info.City = city.City.Names["en"]
|
||||
if len(city.Subdivisions) > 0 {
|
||||
info.Region = city.Subdivisions[0].Names["en"]
|
||||
dbTimestamp = timestamps[0]
|
||||
for _, ts := range timestamps[1:] {
|
||||
if ts.After(dbTimestamp) {
|
||||
dbTimestamp = ts
|
||||
}
|
||||
info.PostalCode = city.Postal.Code
|
||||
info.Latitude = city.Location.Latitude
|
||||
info.Longitude = city.Location.Longitude
|
||||
info.Timezone = city.Location.TimeZone
|
||||
}
|
||||
}
|
||||
|
||||
// Look up in ASN database
|
||||
if asnDB := db.GetASNDB(); asnDB != nil {
|
||||
asn, err := asnDB.ASN(ip)
|
||||
if err == nil {
|
||||
info.ASN = asn.AutonomousSystemNumber
|
||||
info.ASNOrg = asn.AutonomousSystemOrganization
|
||||
result := make(map[string]*IPInfo)
|
||||
|
||||
for _, ipStr := range ipList {
|
||||
ipStr = strings.TrimSpace(ipStr)
|
||||
if ipStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate IP address
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
result[ipStr] = &IPInfo{
|
||||
IP: ipStr,
|
||||
DatabaseTimestamp: dbTimestamp,
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
info := &IPInfo{
|
||||
IP: ipStr,
|
||||
DatabaseTimestamp: dbTimestamp,
|
||||
}
|
||||
|
||||
// Populate IP properties
|
||||
info.IsPrivate = ip.IsPrivate()
|
||||
info.IsLoopback = ip.IsLoopback()
|
||||
info.IsMulticast = ip.IsMulticast()
|
||||
info.IsLinkLocal = ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()
|
||||
info.IsGlobalUnicast = ip.IsGlobalUnicast()
|
||||
info.IsUnspecified = ip.IsUnspecified()
|
||||
|
||||
// Determine IP version and additional properties
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
info.IPVersion = 4
|
||||
// Check for RFC1918 private addresses
|
||||
info.IsRFC1918 = isRFC1918(ipv4)
|
||||
// Check for broadcast (255.255.255.255)
|
||||
const maxByte = 255
|
||||
info.IsBroadcast = ipv4[0] == maxByte && ipv4[1] == maxByte && ipv4[2] == maxByte && ipv4[3] == maxByte
|
||||
// Check for documentation addresses (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
|
||||
info.IsDocumentation = isDocumentationIPv4(ipv4)
|
||||
} else if ip.To16() != nil {
|
||||
info.IPVersion = 6
|
||||
info.IsRFC1918 = false
|
||||
info.IsBroadcast = false
|
||||
// Check for documentation addresses (2001:db8::/32)
|
||||
info.IsDocumentation = isDocumentationIPv6(ip)
|
||||
}
|
||||
|
||||
// Look up in Country database
|
||||
if countryDB != nil {
|
||||
country, err := countryDB.Country(ip)
|
||||
if err == nil {
|
||||
info.Country = country.Country.Names["en"]
|
||||
info.CountryCode = country.Country.IsoCode
|
||||
}
|
||||
}
|
||||
|
||||
// Look up in City database
|
||||
if cityDB != nil {
|
||||
city, err := cityDB.City(ip)
|
||||
if err == nil {
|
||||
info.City = city.City.Names["en"]
|
||||
if len(city.Subdivisions) > 0 {
|
||||
info.Region = city.Subdivisions[0].Names["en"]
|
||||
}
|
||||
info.PostalCode = city.Postal.Code
|
||||
info.Latitude = city.Location.Latitude
|
||||
info.Longitude = city.Location.Longitude
|
||||
info.Timezone = city.Location.TimeZone
|
||||
}
|
||||
}
|
||||
|
||||
// Look up in ASN database
|
||||
if asnDB != nil {
|
||||
asn, err := asnDB.ASN(ip)
|
||||
if err == nil {
|
||||
info.ASN = asn.AutonomousSystemNumber
|
||||
info.ASNOrg = asn.AutonomousSystemOrganization
|
||||
}
|
||||
}
|
||||
|
||||
result[ipStr] = info
|
||||
}
|
||||
|
||||
// Return success response
|
||||
response := SuccessResponse{
|
||||
Status: "success",
|
||||
Date: time.Now().UTC(),
|
||||
Result: result,
|
||||
Error: false,
|
||||
}
|
||||
|
||||
// Set content type and encode response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(info); err != nil {
|
||||
const errorCodeEncodeFailure = 1004
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
logger.Error("Failed to encode response", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "Internal server error")
|
||||
writeErrorResponse(w, http.StatusInternalServerError, errorCodeEncodeFailure, "Failed to encode response")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, code int, message string) {
|
||||
func isRFC1918(ip net.IP) bool {
|
||||
for _, network := range rfc1918Networks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isDocumentationIPv4(ip net.IP) bool {
|
||||
for _, network := range documentationIPv4Networks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isDocumentationIPv6(ip net.IP) bool {
|
||||
for _, network := range documentationIPv6Networks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func writeErrorResponse(w http.ResponseWriter, httpCode int, errorCode int, message string) {
|
||||
response := ErrorResponse{
|
||||
Status: "error",
|
||||
Date: time.Now().UTC(),
|
||||
Error: map[string]interface{}{
|
||||
"code": errorCode,
|
||||
"msg": message,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": message,
|
||||
}); err != nil {
|
||||
w.WriteHeader(httpCode)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
// Log error but don't try to write again
|
||||
_ = err
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
@ -22,7 +23,7 @@ func TestNewRouter(t *testing.T) {
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
router, err := NewRouter(logger, db)
|
||||
router, err := NewRouter(logger, db, stateManager)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create router: %v", err)
|
||||
}
|
||||
@ -40,7 +41,7 @@ func TestHealthEndpoint(t *testing.T) {
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
router, _ := NewRouter(logger, db)
|
||||
router, _ := NewRouter(logger, db, stateManager)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@ -58,12 +59,24 @@ func TestHealthEndpoint(t *testing.T) {
|
||||
|
||||
func TestIPLookupEndpoint(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
StateDir: t.TempDir(),
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
_ = stateManager.Initialize(context.Background())
|
||||
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
// Download databases to temp directory for testing
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Log("Downloading test databases...")
|
||||
if err := db.EnsureDatabases(ctx); err != nil {
|
||||
t.Skipf("Skipping test - could not download databases: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
@ -71,38 +84,141 @@ func TestIPLookupEndpoint(t *testing.T) {
|
||||
}{
|
||||
{"valid IPv4", "8.8.8.8", http.StatusOK},
|
||||
{"valid IPv6", "2001:4860:4860::8888", http.StatusOK},
|
||||
{"invalid IP", "invalid", http.StatusBadRequest},
|
||||
{"invalid IP", "invalid", http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/"+tt.ip, nil)
|
||||
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+tt.ip, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Create a new context with the URL param
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("ip", tt.ip)
|
||||
rctx.URLParams.Add("ips", tt.ip)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
handleIPLookup(logger, db)(rec, req)
|
||||
handleIPLookup(logger, db, stateManager)(rec, req)
|
||||
|
||||
if rec.Code != tt.expectedCode {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code)
|
||||
}
|
||||
|
||||
if tt.expectedCode == http.StatusOK {
|
||||
var info IPInfo
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &info); err != nil {
|
||||
var response SuccessResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
|
||||
t.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if info.IP != tt.ip {
|
||||
t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
|
||||
if response.Status != "success" {
|
||||
t.Errorf("expected status 'success', got %s", response.Status)
|
||||
}
|
||||
if len(response.Result) != 1 {
|
||||
t.Errorf("expected 1 result, got %d", len(response.Result))
|
||||
}
|
||||
if info, ok := response.Result[tt.ip]; ok {
|
||||
if info.IP != tt.ip {
|
||||
t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("expected result for IP %s", tt.ip)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPLookupMultipleIPs(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
_ = stateManager.Initialize(context.Background())
|
||||
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
// Download databases to temp directory for testing
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Log("Downloading test databases...")
|
||||
if err := db.EnsureDatabases(ctx); err != nil {
|
||||
t.Skipf("Skipping test - could not download databases: %v", err)
|
||||
}
|
||||
|
||||
// Test multiple IPs
|
||||
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/8.8.8.8,1.1.1.1,invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Create a new context with the URL param
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("ips", "8.8.8.8,1.1.1.1,invalid")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
handleIPLookup(logger, db, stateManager)(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var response SuccessResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
|
||||
t.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(response.Result) != 3 {
|
||||
t.Errorf("expected 3 results, got %d", len(response.Result))
|
||||
}
|
||||
|
||||
// Check that all requested IPs are in the result
|
||||
for _, ip := range []string{"8.8.8.8", "1.1.1.1", "invalid"} {
|
||||
if _, ok := response.Result[ip]; !ok {
|
||||
t.Errorf("expected result for IP %s", ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPLookupTooManyIPs(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
cfg := &config.Config{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
// Create 51 IPs (more than the limit)
|
||||
ips := ""
|
||||
for i := 0; i < 51; i++ {
|
||||
if i > 0 {
|
||||
ips += ","
|
||||
}
|
||||
ips += "1.1.1.1"
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+ips, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Create a new context with the URL param
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("ips", ips)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
handleIPLookup(logger, db, stateManager)(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var response ErrorResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
|
||||
t.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Status != "error" {
|
||||
t.Errorf("expected status 'error', got %s", response.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -32,7 +32,7 @@ func daemonCmd() *cobra.Command {
|
||||
state.New,
|
||||
database.New,
|
||||
http.NewServer,
|
||||
http.NewRouter,
|
||||
http.NewRouterWrapper,
|
||||
New,
|
||||
),
|
||||
fx.Invoke(func(lc fx.Lifecycle, ipapi *IPAPI) {
|
||||
|
@ -54,10 +54,14 @@ func (i *IPAPI) Start(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Download databases if needed
|
||||
if err := i.database.EnsureDatabases(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
// Start database downloads in background with a fresh context
|
||||
// This avoids using the OnStart context which has a timeout
|
||||
go func() {
|
||||
dbCtx := context.Background()
|
||||
if err := i.database.EnsureDatabases(dbCtx); err != nil {
|
||||
i.logger.Error("Failed to ensure databases", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start HTTP server
|
||||
return i.server.Start(ctx)
|
||||
|
Loading…
Reference in New Issue
Block a user