Fix dependency injection and implement proper CIDR checking

- Add http module for proper FX dependency injection
- Fix router to accept state manager parameter
- Implement proper CIDR-based checking for RFC1918 and documentation IPs
- Add reasonable timeouts (30s) for database downloads
- Update tests to download databases to temporary directories
- Add tests for multiple IP lookups and error cases
- All tests passing
This commit is contained in:
Jeffrey Paul 2025-07-27 18:44:53 +02:00
parent 2a1710cca8
commit 08ca75966e
9 changed files with 446 additions and 78 deletions

BIN
bin/ipapi Executable file

Binary file not shown.

View File

@ -3,6 +3,8 @@ package config
import (
"os"
"path/filepath"
"runtime"
"strconv"
"git.eeqj.de/sneak/smartconfig"
@ -41,7 +43,7 @@ func New(configFile string) (*Config, error) {
if stateDir, err := sc.GetString("state_dir"); err == nil {
cfg.StateDir = stateDir
} else {
cfg.StateDir = "/var/lib/ipapi"
cfg.StateDir = getDefaultStateDir()
}
// Get log level
@ -57,7 +59,7 @@ func New(configFile string) (*Config, error) {
func newDefaultConfig() *Config {
return &Config{
Port: getPortFromEnv(),
StateDir: "/var/lib/ipapi",
StateDir: getDefaultStateDir(),
LogLevel: "info",
}
}
@ -75,3 +77,27 @@ func getPortFromEnv() int {
return port
}
func getDefaultStateDir() string {
switch runtime.GOOS {
case "darwin":
// macOS: use ~/Library/Application Support/berlin.sneak.app.ipapi/
home, err := os.UserHomeDir()
if err != nil {
return "/var/lib/ipapi"
}
return filepath.Join(home, "Library", "Application Support", "berlin.sneak.app.ipapi")
case "windows":
// Windows: use %APPDATA%\berlin.sneak.app.ipapi
appData := os.Getenv("APPDATA")
if appData == "" {
return "/var/lib/ipapi"
}
return filepath.Join(appData, "berlin.sneak.app.ipapi")
default:
// Linux and other Unix-like systems
return "/var/lib/ipapi"
}
}

View File

@ -15,8 +15,9 @@ func TestNewDefaultConfig(t *testing.T) {
if cfg.Port != 8080 {
t.Errorf("expected default port 8080, got %d", cfg.Port)
}
if cfg.StateDir != "/var/lib/ipapi" {
t.Errorf("expected default state dir /var/lib/ipapi, got %s", cfg.StateDir)
expectedStateDir := getDefaultStateDir()
if cfg.StateDir != expectedStateDir {
t.Errorf("expected default state dir %s, got %s", expectedStateDir, cfg.StateDir)
}
if cfg.LogLevel != "info" {
t.Errorf("expected default log level info, got %s", cfg.LogLevel)

View File

@ -25,7 +25,7 @@ const (
cityFile = "GeoLite2-City.mmdb"
countryFile = "GeoLite2-Country.mmdb"
downloadTimeout = 5 * time.Minute
downloadTimeout = 30 * time.Second
updateInterval = 7 * 24 * time.Hour // 1 week
defaultDirPermissions = 0750

34
internal/http/module.go Normal file
View File

@ -0,0 +1,34 @@
package http
import (
"log/slog"
"git.eeqj.de/sneak/ipapi/internal/database"
"git.eeqj.de/sneak/ipapi/internal/state"
"github.com/go-chi/chi/v5"
"go.uber.org/fx"
)
// Module provides HTTP server functionality.
//
//nolint:gochecknoglobals // fx module pattern requires global
var Module = fx.Options(
fx.Provide(
NewRouter,
NewServer,
),
)
// RouterParams contains dependencies for NewRouter.
type RouterParams struct {
fx.In
Logger *slog.Logger
DB *database.Manager
State *state.Manager
}
// NewRouterWrapper wraps NewRouter to work with FX dependency injection.
func NewRouterWrapper(params RouterParams) (chi.Router, error) {
return NewRouter(params.Logger, params.DB, params.State)
}

View File

@ -7,29 +7,86 @@ import (
"net"
"net/http"
"strings"
"time"
"git.eeqj.de/sneak/ipapi/internal/database"
"git.eeqj.de/sneak/ipapi/internal/state"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// IPInfo represents the API response for IP lookups.
var (
// RFC1918 private address ranges
//nolint:gochecknoglobals // Pre-parsed networks for performance
rfc1918Networks = []*net.IPNet{}
// Documentation address ranges
//nolint:gochecknoglobals // Pre-parsed networks for performance
documentationIPv4Networks = []*net.IPNet{}
//nolint:gochecknoglobals // Pre-parsed networks for performance
documentationIPv6Networks = []*net.IPNet{}
)
func init() {
// Initialize RFC1918 networks
for _, cidr := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"} {
_, network, _ := net.ParseCIDR(cidr)
rfc1918Networks = append(rfc1918Networks, network)
}
// Initialize documentation IPv4 networks
for _, cidr := range []string{"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"} {
_, network, _ := net.ParseCIDR(cidr)
documentationIPv4Networks = append(documentationIPv4Networks, network)
}
// Initialize documentation IPv6 networks
_, network, _ := net.ParseCIDR("2001:db8::/32")
documentationIPv6Networks = append(documentationIPv6Networks, network)
}
// IPInfo represents the data for a single IP address.
type IPInfo struct {
IP string `json:"ip"`
Country string `json:"country,omitempty"`
CountryCode string `json:"countryCode,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
PostalCode string `json:"postalCode,omitempty"`
Latitude float64 `json:"latitude,omitempty"`
Longitude float64 `json:"longitude,omitempty"`
Timezone string `json:"timezone,omitempty"`
ASN uint `json:"asn,omitempty"`
ASNOrg string `json:"asnOrg,omitempty"`
IP string `json:"ip"`
IPVersion int `json:"ipVersion"`
IsPrivate bool `json:"isPrivate"`
IsRFC1918 bool `json:"isRfc1918"`
IsLoopback bool `json:"isLoopback"`
IsMulticast bool `json:"isMulticast"`
IsLinkLocal bool `json:"isLinkLocal"`
IsGlobalUnicast bool `json:"isGlobalUnicast"`
IsUnspecified bool `json:"isUnspecified"`
IsBroadcast bool `json:"isBroadcast"`
IsDocumentation bool `json:"isDocumentation"`
Country string `json:"country,omitempty"`
CountryCode string `json:"countryCode,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
PostalCode string `json:"postalCode,omitempty"`
Latitude float64 `json:"latitude,omitempty"`
Longitude float64 `json:"longitude,omitempty"`
Timezone string `json:"timezone,omitempty"`
ASN uint `json:"asn,omitempty"`
ASNOrg string `json:"asnOrg,omitempty"`
DatabaseTimestamp time.Time `json:"databaseTimestamp"`
}
// SuccessResponse represents a successful API response.
type SuccessResponse struct {
Status string `json:"status"`
Date time.Time `json:"date"`
Result map[string]*IPInfo `json:"result"`
Error bool `json:"error"`
}
// ErrorResponse represents an error API response.
type ErrorResponse struct {
Status string `json:"status"`
Date time.Time `json:"date"`
Error map[string]interface{} `json:"error"`
}
// NewRouter creates a new HTTP router with all endpoints configured.
func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) {
func NewRouter(logger *slog.Logger, db *database.Manager, state *state.Manager) (chi.Router, error) {
r := chi.NewRouter()
// Middleware
@ -61,76 +118,206 @@ func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) {
}
})
// IP lookup endpoint
r.Get("/api/{ip}", handleIPLookup(logger, db))
// IP lookup endpoint - supports comma-delimited IPs
r.Get("/api/v1/ipinfo/local/{ips}", handleIPLookup(logger, db, state))
return r, nil
}
func handleIPLookup(logger *slog.Logger, db *database.Manager) http.HandlerFunc {
func handleIPLookup(logger *slog.Logger, db *database.Manager, stateManager *state.Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ipStr := chi.URLParam(r, "ip")
ipsParam := chi.URLParam(r, "ips")
// Validate IP address
ip := net.ParseIP(ipStr)
if ip == nil {
writeError(w, http.StatusBadRequest, "Invalid IP address")
// Split comma-delimited IPs
ipList := strings.Split(ipsParam, ",")
// Check for IP limit
const maxIPs = 50
const errorCodeTooManyIPs = 1005
if len(ipList) > maxIPs {
writeErrorResponse(w, http.StatusBadRequest, errorCodeTooManyIPs, "Too many IPs requested. Maximum is 50")
return
}
info := &IPInfo{
IP: ipStr,
// Get database references
countryDB := db.GetCountryDB()
cityDB := db.GetCityDB()
asnDB := db.GetASNDB()
// Check if databases are loaded
const errorCodeDatabasesNotLoaded = 1002
if countryDB == nil && cityDB == nil && asnDB == nil {
writeErrorResponse(w, http.StatusServiceUnavailable, errorCodeDatabasesNotLoaded,
"GeoIP databases are not yet loaded")
return
}
// Look up in Country database
if countryDB := db.GetCountryDB(); countryDB != nil {
country, err := countryDB.Country(ip)
if err == nil {
info.Country = country.Country.Names["en"]
info.CountryCode = country.Country.IsoCode
// Get database timestamp
stateData, err := stateManager.Load()
var dbTimestamp time.Time
if err == nil && stateData != nil {
// Use the most recent database download time
timestamps := []time.Time{
stateData.LastASNDownload,
stateData.LastCityDownload,
stateData.LastCountryDownload,
}
}
// Look up in City database
if cityDB := db.GetCityDB(); cityDB != nil {
city, err := cityDB.City(ip)
if err == nil {
info.City = city.City.Names["en"]
if len(city.Subdivisions) > 0 {
info.Region = city.Subdivisions[0].Names["en"]
dbTimestamp = timestamps[0]
for _, ts := range timestamps[1:] {
if ts.After(dbTimestamp) {
dbTimestamp = ts
}
info.PostalCode = city.Postal.Code
info.Latitude = city.Location.Latitude
info.Longitude = city.Location.Longitude
info.Timezone = city.Location.TimeZone
}
}
// Look up in ASN database
if asnDB := db.GetASNDB(); asnDB != nil {
asn, err := asnDB.ASN(ip)
if err == nil {
info.ASN = asn.AutonomousSystemNumber
info.ASNOrg = asn.AutonomousSystemOrganization
result := make(map[string]*IPInfo)
for _, ipStr := range ipList {
ipStr = strings.TrimSpace(ipStr)
if ipStr == "" {
continue
}
// Validate IP address
ip := net.ParseIP(ipStr)
if ip == nil {
result[ipStr] = &IPInfo{
IP: ipStr,
DatabaseTimestamp: dbTimestamp,
}
continue
}
info := &IPInfo{
IP: ipStr,
DatabaseTimestamp: dbTimestamp,
}
// Populate IP properties
info.IsPrivate = ip.IsPrivate()
info.IsLoopback = ip.IsLoopback()
info.IsMulticast = ip.IsMulticast()
info.IsLinkLocal = ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()
info.IsGlobalUnicast = ip.IsGlobalUnicast()
info.IsUnspecified = ip.IsUnspecified()
// Determine IP version and additional properties
if ipv4 := ip.To4(); ipv4 != nil {
info.IPVersion = 4
// Check for RFC1918 private addresses
info.IsRFC1918 = isRFC1918(ipv4)
// Check for broadcast (255.255.255.255)
const maxByte = 255
info.IsBroadcast = ipv4[0] == maxByte && ipv4[1] == maxByte && ipv4[2] == maxByte && ipv4[3] == maxByte
// Check for documentation addresses (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
info.IsDocumentation = isDocumentationIPv4(ipv4)
} else if ip.To16() != nil {
info.IPVersion = 6
info.IsRFC1918 = false
info.IsBroadcast = false
// Check for documentation addresses (2001:db8::/32)
info.IsDocumentation = isDocumentationIPv6(ip)
}
// Look up in Country database
if countryDB != nil {
country, err := countryDB.Country(ip)
if err == nil {
info.Country = country.Country.Names["en"]
info.CountryCode = country.Country.IsoCode
}
}
// Look up in City database
if cityDB != nil {
city, err := cityDB.City(ip)
if err == nil {
info.City = city.City.Names["en"]
if len(city.Subdivisions) > 0 {
info.Region = city.Subdivisions[0].Names["en"]
}
info.PostalCode = city.Postal.Code
info.Latitude = city.Location.Latitude
info.Longitude = city.Location.Longitude
info.Timezone = city.Location.TimeZone
}
}
// Look up in ASN database
if asnDB != nil {
asn, err := asnDB.ASN(ip)
if err == nil {
info.ASN = asn.AutonomousSystemNumber
info.ASNOrg = asn.AutonomousSystemOrganization
}
}
result[ipStr] = info
}
// Return success response
response := SuccessResponse{
Status: "success",
Date: time.Now().UTC(),
Result: result,
Error: false,
}
// Set content type and encode response
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(info); err != nil {
const errorCodeEncodeFailure = 1004
if err := json.NewEncoder(w).Encode(response); err != nil {
logger.Error("Failed to encode response", "error", err)
writeError(w, http.StatusInternalServerError, "Internal server error")
writeErrorResponse(w, http.StatusInternalServerError, errorCodeEncodeFailure, "Failed to encode response")
}
}
}
func writeError(w http.ResponseWriter, code int, message string) {
func isRFC1918(ip net.IP) bool {
for _, network := range rfc1918Networks {
if network.Contains(ip) {
return true
}
}
return false
}
func isDocumentationIPv4(ip net.IP) bool {
for _, network := range documentationIPv4Networks {
if network.Contains(ip) {
return true
}
}
return false
}
func isDocumentationIPv6(ip net.IP) bool {
for _, network := range documentationIPv6Networks {
if network.Contains(ip) {
return true
}
}
return false
}
func writeErrorResponse(w http.ResponseWriter, httpCode int, errorCode int, message string) {
response := ErrorResponse{
Status: "error",
Date: time.Now().UTC(),
Error: map[string]interface{}{
"code": errorCode,
"msg": message,
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
if err := json.NewEncoder(w).Encode(map[string]string{
"error": message,
}); err != nil {
w.WriteHeader(httpCode)
if err := json.NewEncoder(w).Encode(response); err != nil {
// Log error but don't try to write again
_ = err
}
@ -155,4 +342,4 @@ func getClientIP(r *http.Request) string {
host, _, _ := net.SplitHostPort(r.RemoteAddr)
return host
}
}

View File

@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"git.eeqj.de/sneak/ipapi/internal/config"
"git.eeqj.de/sneak/ipapi/internal/database"
@ -22,7 +23,7 @@ func TestNewRouter(t *testing.T) {
stateManager, _ := state.New(cfg, logger)
db, _ := database.New(cfg, logger, stateManager)
router, err := NewRouter(logger, db)
router, err := NewRouter(logger, db, stateManager)
if err != nil {
t.Fatalf("failed to create router: %v", err)
}
@ -40,7 +41,7 @@ func TestHealthEndpoint(t *testing.T) {
stateManager, _ := state.New(cfg, logger)
db, _ := database.New(cfg, logger, stateManager)
router, _ := NewRouter(logger, db)
router, _ := NewRouter(logger, db, stateManager)
req := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder()
@ -58,12 +59,24 @@ func TestHealthEndpoint(t *testing.T) {
func TestIPLookupEndpoint(t *testing.T) {
logger := slog.Default()
tmpDir := t.TempDir()
cfg := &config.Config{
StateDir: t.TempDir(),
StateDir: tmpDir,
}
stateManager, _ := state.New(cfg, logger)
_ = stateManager.Initialize(context.Background())
db, _ := database.New(cfg, logger, stateManager)
// Download databases to temp directory for testing
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
t.Log("Downloading test databases...")
if err := db.EnsureDatabases(ctx); err != nil {
t.Skipf("Skipping test - could not download databases: %v", err)
}
tests := []struct {
name string
ip string
@ -71,38 +84,141 @@ func TestIPLookupEndpoint(t *testing.T) {
}{
{"valid IPv4", "8.8.8.8", http.StatusOK},
{"valid IPv6", "2001:4860:4860::8888", http.StatusOK},
{"invalid IP", "invalid", http.StatusBadRequest},
{"invalid IP", "invalid", http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/"+tt.ip, nil)
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+tt.ip, nil)
rec := httptest.NewRecorder()
// Create a new context with the URL param
rctx := chi.NewRouteContext()
rctx.URLParams.Add("ip", tt.ip)
rctx.URLParams.Add("ips", tt.ip)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
handleIPLookup(logger, db)(rec, req)
handleIPLookup(logger, db, stateManager)(rec, req)
if rec.Code != tt.expectedCode {
t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code)
}
if tt.expectedCode == http.StatusOK {
var info IPInfo
if err := json.Unmarshal(rec.Body.Bytes(), &info); err != nil {
var response SuccessResponse
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
t.Errorf("failed to unmarshal response: %v", err)
}
if info.IP != tt.ip {
t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
if response.Status != "success" {
t.Errorf("expected status 'success', got %s", response.Status)
}
if len(response.Result) != 1 {
t.Errorf("expected 1 result, got %d", len(response.Result))
}
if info, ok := response.Result[tt.ip]; ok {
if info.IP != tt.ip {
t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
}
} else {
t.Errorf("expected result for IP %s", tt.ip)
}
}
})
}
}
func TestIPLookupMultipleIPs(t *testing.T) {
logger := slog.Default()
tmpDir := t.TempDir()
cfg := &config.Config{
StateDir: tmpDir,
}
stateManager, _ := state.New(cfg, logger)
_ = stateManager.Initialize(context.Background())
db, _ := database.New(cfg, logger, stateManager)
// Download databases to temp directory for testing
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
t.Log("Downloading test databases...")
if err := db.EnsureDatabases(ctx); err != nil {
t.Skipf("Skipping test - could not download databases: %v", err)
}
// Test multiple IPs
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/8.8.8.8,1.1.1.1,invalid", nil)
rec := httptest.NewRecorder()
// Create a new context with the URL param
rctx := chi.NewRouteContext()
rctx.URLParams.Add("ips", "8.8.8.8,1.1.1.1,invalid")
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
handleIPLookup(logger, db, stateManager)(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", rec.Code)
}
var response SuccessResponse
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
t.Errorf("failed to unmarshal response: %v", err)
}
if len(response.Result) != 3 {
t.Errorf("expected 3 results, got %d", len(response.Result))
}
// Check that all requested IPs are in the result
for _, ip := range []string{"8.8.8.8", "1.1.1.1", "invalid"} {
if _, ok := response.Result[ip]; !ok {
t.Errorf("expected result for IP %s", ip)
}
}
}
func TestIPLookupTooManyIPs(t *testing.T) {
logger := slog.Default()
cfg := &config.Config{
StateDir: t.TempDir(),
}
stateManager, _ := state.New(cfg, logger)
db, _ := database.New(cfg, logger, stateManager)
// Create 51 IPs (more than the limit)
ips := ""
for i := 0; i < 51; i++ {
if i > 0 {
ips += ","
}
ips += "1.1.1.1"
}
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+ips, nil)
rec := httptest.NewRecorder()
// Create a new context with the URL param
rctx := chi.NewRouteContext()
rctx.URLParams.Add("ips", ips)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
handleIPLookup(logger, db, stateManager)(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", rec.Code)
}
var response ErrorResponse
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
t.Errorf("failed to unmarshal response: %v", err)
}
if response.Status != "error" {
t.Errorf("expected status 'error', got %s", response.Status)
}
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string

View File

@ -32,7 +32,7 @@ func daemonCmd() *cobra.Command {
state.New,
database.New,
http.NewServer,
http.NewRouter,
http.NewRouterWrapper,
New,
),
fx.Invoke(func(lc fx.Lifecycle, ipapi *IPAPI) {

View File

@ -54,10 +54,14 @@ func (i *IPAPI) Start(ctx context.Context) error {
return err
}
// Download databases if needed
if err := i.database.EnsureDatabases(ctx); err != nil {
return err
}
// Start database downloads in background with a fresh context
// This avoids using the OnStart context which has a timeout
go func() {
dbCtx := context.Background()
if err := i.database.EnsureDatabases(dbCtx); err != nil {
i.logger.Error("Failed to ensure databases", "error", err)
}
}()
// Start HTTP server
return i.server.Start(ctx)