Fix dependency injection and implement proper CIDR checking
- Add http module for proper FX dependency injection - Fix router to accept state manager parameter - Implement proper CIDR-based checking for RFC1918 and documentation IPs - Add reasonable timeouts (30s) for database downloads - Update tests to download databases to temporary directories - Add tests for multiple IP lookups and error cases - All tests passing
This commit is contained in:
parent
2a1710cca8
commit
08ca75966e
@ -3,6 +3,8 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"git.eeqj.de/sneak/smartconfig"
|
"git.eeqj.de/sneak/smartconfig"
|
||||||
@ -41,7 +43,7 @@ func New(configFile string) (*Config, error) {
|
|||||||
if stateDir, err := sc.GetString("state_dir"); err == nil {
|
if stateDir, err := sc.GetString("state_dir"); err == nil {
|
||||||
cfg.StateDir = stateDir
|
cfg.StateDir = stateDir
|
||||||
} else {
|
} else {
|
||||||
cfg.StateDir = "/var/lib/ipapi"
|
cfg.StateDir = getDefaultStateDir()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get log level
|
// Get log level
|
||||||
@ -57,7 +59,7 @@ func New(configFile string) (*Config, error) {
|
|||||||
func newDefaultConfig() *Config {
|
func newDefaultConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Port: getPortFromEnv(),
|
Port: getPortFromEnv(),
|
||||||
StateDir: "/var/lib/ipapi",
|
StateDir: getDefaultStateDir(),
|
||||||
LogLevel: "info",
|
LogLevel: "info",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -75,3 +77,27 @@ func getPortFromEnv() int {
|
|||||||
|
|
||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getDefaultStateDir() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
// macOS: use ~/Library/Application Support/berlin.sneak.app.ipapi/
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "/var/lib/ipapi"
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(home, "Library", "Application Support", "berlin.sneak.app.ipapi")
|
||||||
|
case "windows":
|
||||||
|
// Windows: use %APPDATA%\berlin.sneak.app.ipapi
|
||||||
|
appData := os.Getenv("APPDATA")
|
||||||
|
if appData == "" {
|
||||||
|
return "/var/lib/ipapi"
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(appData, "berlin.sneak.app.ipapi")
|
||||||
|
default:
|
||||||
|
// Linux and other Unix-like systems
|
||||||
|
return "/var/lib/ipapi"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -15,8 +15,9 @@ func TestNewDefaultConfig(t *testing.T) {
|
|||||||
if cfg.Port != 8080 {
|
if cfg.Port != 8080 {
|
||||||
t.Errorf("expected default port 8080, got %d", cfg.Port)
|
t.Errorf("expected default port 8080, got %d", cfg.Port)
|
||||||
}
|
}
|
||||||
if cfg.StateDir != "/var/lib/ipapi" {
|
expectedStateDir := getDefaultStateDir()
|
||||||
t.Errorf("expected default state dir /var/lib/ipapi, got %s", cfg.StateDir)
|
if cfg.StateDir != expectedStateDir {
|
||||||
|
t.Errorf("expected default state dir %s, got %s", expectedStateDir, cfg.StateDir)
|
||||||
}
|
}
|
||||||
if cfg.LogLevel != "info" {
|
if cfg.LogLevel != "info" {
|
||||||
t.Errorf("expected default log level info, got %s", cfg.LogLevel)
|
t.Errorf("expected default log level info, got %s", cfg.LogLevel)
|
||||||
|
@ -25,7 +25,7 @@ const (
|
|||||||
cityFile = "GeoLite2-City.mmdb"
|
cityFile = "GeoLite2-City.mmdb"
|
||||||
countryFile = "GeoLite2-Country.mmdb"
|
countryFile = "GeoLite2-Country.mmdb"
|
||||||
|
|
||||||
downloadTimeout = 5 * time.Minute
|
downloadTimeout = 30 * time.Second
|
||||||
updateInterval = 7 * 24 * time.Hour // 1 week
|
updateInterval = 7 * 24 * time.Hour // 1 week
|
||||||
|
|
||||||
defaultDirPermissions = 0750
|
defaultDirPermissions = 0750
|
||||||
|
34
internal/http/module.go
Normal file
34
internal/http/module.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||||
|
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"go.uber.org/fx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Module provides HTTP server functionality.
|
||||||
|
//
|
||||||
|
//nolint:gochecknoglobals // fx module pattern requires global
|
||||||
|
var Module = fx.Options(
|
||||||
|
fx.Provide(
|
||||||
|
NewRouter,
|
||||||
|
NewServer,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
// RouterParams contains dependencies for NewRouter.
|
||||||
|
type RouterParams struct {
|
||||||
|
fx.In
|
||||||
|
|
||||||
|
Logger *slog.Logger
|
||||||
|
DB *database.Manager
|
||||||
|
State *state.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouterWrapper wraps NewRouter to work with FX dependency injection.
|
||||||
|
func NewRouterWrapper(params RouterParams) (chi.Router, error) {
|
||||||
|
return NewRouter(params.Logger, params.DB, params.State)
|
||||||
|
}
|
@ -7,29 +7,86 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||||
|
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPInfo represents the API response for IP lookups.
|
var (
|
||||||
|
// RFC1918 private address ranges
|
||||||
|
//nolint:gochecknoglobals // Pre-parsed networks for performance
|
||||||
|
rfc1918Networks = []*net.IPNet{}
|
||||||
|
// Documentation address ranges
|
||||||
|
//nolint:gochecknoglobals // Pre-parsed networks for performance
|
||||||
|
documentationIPv4Networks = []*net.IPNet{}
|
||||||
|
//nolint:gochecknoglobals // Pre-parsed networks for performance
|
||||||
|
documentationIPv6Networks = []*net.IPNet{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Initialize RFC1918 networks
|
||||||
|
for _, cidr := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"} {
|
||||||
|
_, network, _ := net.ParseCIDR(cidr)
|
||||||
|
rfc1918Networks = append(rfc1918Networks, network)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize documentation IPv4 networks
|
||||||
|
for _, cidr := range []string{"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24"} {
|
||||||
|
_, network, _ := net.ParseCIDR(cidr)
|
||||||
|
documentationIPv4Networks = append(documentationIPv4Networks, network)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize documentation IPv6 networks
|
||||||
|
_, network, _ := net.ParseCIDR("2001:db8::/32")
|
||||||
|
documentationIPv6Networks = append(documentationIPv6Networks, network)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPInfo represents the data for a single IP address.
|
||||||
type IPInfo struct {
|
type IPInfo struct {
|
||||||
IP string `json:"ip"`
|
IP string `json:"ip"`
|
||||||
Country string `json:"country,omitempty"`
|
IPVersion int `json:"ipVersion"`
|
||||||
CountryCode string `json:"countryCode,omitempty"`
|
IsPrivate bool `json:"isPrivate"`
|
||||||
City string `json:"city,omitempty"`
|
IsRFC1918 bool `json:"isRfc1918"`
|
||||||
Region string `json:"region,omitempty"`
|
IsLoopback bool `json:"isLoopback"`
|
||||||
PostalCode string `json:"postalCode,omitempty"`
|
IsMulticast bool `json:"isMulticast"`
|
||||||
Latitude float64 `json:"latitude,omitempty"`
|
IsLinkLocal bool `json:"isLinkLocal"`
|
||||||
Longitude float64 `json:"longitude,omitempty"`
|
IsGlobalUnicast bool `json:"isGlobalUnicast"`
|
||||||
Timezone string `json:"timezone,omitempty"`
|
IsUnspecified bool `json:"isUnspecified"`
|
||||||
ASN uint `json:"asn,omitempty"`
|
IsBroadcast bool `json:"isBroadcast"`
|
||||||
ASNOrg string `json:"asnOrg,omitempty"`
|
IsDocumentation bool `json:"isDocumentation"`
|
||||||
|
Country string `json:"country,omitempty"`
|
||||||
|
CountryCode string `json:"countryCode,omitempty"`
|
||||||
|
City string `json:"city,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
PostalCode string `json:"postalCode,omitempty"`
|
||||||
|
Latitude float64 `json:"latitude,omitempty"`
|
||||||
|
Longitude float64 `json:"longitude,omitempty"`
|
||||||
|
Timezone string `json:"timezone,omitempty"`
|
||||||
|
ASN uint `json:"asn,omitempty"`
|
||||||
|
ASNOrg string `json:"asnOrg,omitempty"`
|
||||||
|
DatabaseTimestamp time.Time `json:"databaseTimestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SuccessResponse represents a successful API response.
|
||||||
|
type SuccessResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Date time.Time `json:"date"`
|
||||||
|
Result map[string]*IPInfo `json:"result"`
|
||||||
|
Error bool `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorResponse represents an error API response.
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Date time.Time `json:"date"`
|
||||||
|
Error map[string]interface{} `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouter creates a new HTTP router with all endpoints configured.
|
// NewRouter creates a new HTTP router with all endpoints configured.
|
||||||
func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) {
|
func NewRouter(logger *slog.Logger, db *database.Manager, state *state.Manager) (chi.Router, error) {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
// Middleware
|
// Middleware
|
||||||
@ -61,76 +118,206 @@ func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// IP lookup endpoint
|
// IP lookup endpoint - supports comma-delimited IPs
|
||||||
r.Get("/api/{ip}", handleIPLookup(logger, db))
|
r.Get("/api/v1/ipinfo/local/{ips}", handleIPLookup(logger, db, state))
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleIPLookup(logger *slog.Logger, db *database.Manager) http.HandlerFunc {
|
func handleIPLookup(logger *slog.Logger, db *database.Manager, stateManager *state.Manager) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ipStr := chi.URLParam(r, "ip")
|
ipsParam := chi.URLParam(r, "ips")
|
||||||
|
|
||||||
// Validate IP address
|
// Split comma-delimited IPs
|
||||||
ip := net.ParseIP(ipStr)
|
ipList := strings.Split(ipsParam, ",")
|
||||||
if ip == nil {
|
|
||||||
writeError(w, http.StatusBadRequest, "Invalid IP address")
|
// Check for IP limit
|
||||||
|
const maxIPs = 50
|
||||||
|
const errorCodeTooManyIPs = 1005
|
||||||
|
if len(ipList) > maxIPs {
|
||||||
|
writeErrorResponse(w, http.StatusBadRequest, errorCodeTooManyIPs, "Too many IPs requested. Maximum is 50")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
info := &IPInfo{
|
// Get database references
|
||||||
IP: ipStr,
|
countryDB := db.GetCountryDB()
|
||||||
|
cityDB := db.GetCityDB()
|
||||||
|
asnDB := db.GetASNDB()
|
||||||
|
|
||||||
|
// Check if databases are loaded
|
||||||
|
const errorCodeDatabasesNotLoaded = 1002
|
||||||
|
if countryDB == nil && cityDB == nil && asnDB == nil {
|
||||||
|
writeErrorResponse(w, http.StatusServiceUnavailable, errorCodeDatabasesNotLoaded,
|
||||||
|
"GeoIP databases are not yet loaded")
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up in Country database
|
// Get database timestamp
|
||||||
if countryDB := db.GetCountryDB(); countryDB != nil {
|
stateData, err := stateManager.Load()
|
||||||
country, err := countryDB.Country(ip)
|
var dbTimestamp time.Time
|
||||||
if err == nil {
|
if err == nil && stateData != nil {
|
||||||
info.Country = country.Country.Names["en"]
|
// Use the most recent database download time
|
||||||
info.CountryCode = country.Country.IsoCode
|
timestamps := []time.Time{
|
||||||
|
stateData.LastASNDownload,
|
||||||
|
stateData.LastCityDownload,
|
||||||
|
stateData.LastCountryDownload,
|
||||||
}
|
}
|
||||||
}
|
dbTimestamp = timestamps[0]
|
||||||
|
for _, ts := range timestamps[1:] {
|
||||||
// Look up in City database
|
if ts.After(dbTimestamp) {
|
||||||
if cityDB := db.GetCityDB(); cityDB != nil {
|
dbTimestamp = ts
|
||||||
city, err := cityDB.City(ip)
|
|
||||||
if err == nil {
|
|
||||||
info.City = city.City.Names["en"]
|
|
||||||
if len(city.Subdivisions) > 0 {
|
|
||||||
info.Region = city.Subdivisions[0].Names["en"]
|
|
||||||
}
|
}
|
||||||
info.PostalCode = city.Postal.Code
|
|
||||||
info.Latitude = city.Location.Latitude
|
|
||||||
info.Longitude = city.Location.Longitude
|
|
||||||
info.Timezone = city.Location.TimeZone
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up in ASN database
|
result := make(map[string]*IPInfo)
|
||||||
if asnDB := db.GetASNDB(); asnDB != nil {
|
|
||||||
asn, err := asnDB.ASN(ip)
|
for _, ipStr := range ipList {
|
||||||
if err == nil {
|
ipStr = strings.TrimSpace(ipStr)
|
||||||
info.ASN = asn.AutonomousSystemNumber
|
if ipStr == "" {
|
||||||
info.ASNOrg = asn.AutonomousSystemOrganization
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate IP address
|
||||||
|
ip := net.ParseIP(ipStr)
|
||||||
|
if ip == nil {
|
||||||
|
result[ipStr] = &IPInfo{
|
||||||
|
IP: ipStr,
|
||||||
|
DatabaseTimestamp: dbTimestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
info := &IPInfo{
|
||||||
|
IP: ipStr,
|
||||||
|
DatabaseTimestamp: dbTimestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate IP properties
|
||||||
|
info.IsPrivate = ip.IsPrivate()
|
||||||
|
info.IsLoopback = ip.IsLoopback()
|
||||||
|
info.IsMulticast = ip.IsMulticast()
|
||||||
|
info.IsLinkLocal = ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()
|
||||||
|
info.IsGlobalUnicast = ip.IsGlobalUnicast()
|
||||||
|
info.IsUnspecified = ip.IsUnspecified()
|
||||||
|
|
||||||
|
// Determine IP version and additional properties
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
info.IPVersion = 4
|
||||||
|
// Check for RFC1918 private addresses
|
||||||
|
info.IsRFC1918 = isRFC1918(ipv4)
|
||||||
|
// Check for broadcast (255.255.255.255)
|
||||||
|
const maxByte = 255
|
||||||
|
info.IsBroadcast = ipv4[0] == maxByte && ipv4[1] == maxByte && ipv4[2] == maxByte && ipv4[3] == maxByte
|
||||||
|
// Check for documentation addresses (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
|
||||||
|
info.IsDocumentation = isDocumentationIPv4(ipv4)
|
||||||
|
} else if ip.To16() != nil {
|
||||||
|
info.IPVersion = 6
|
||||||
|
info.IsRFC1918 = false
|
||||||
|
info.IsBroadcast = false
|
||||||
|
// Check for documentation addresses (2001:db8::/32)
|
||||||
|
info.IsDocumentation = isDocumentationIPv6(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up in Country database
|
||||||
|
if countryDB != nil {
|
||||||
|
country, err := countryDB.Country(ip)
|
||||||
|
if err == nil {
|
||||||
|
info.Country = country.Country.Names["en"]
|
||||||
|
info.CountryCode = country.Country.IsoCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up in City database
|
||||||
|
if cityDB != nil {
|
||||||
|
city, err := cityDB.City(ip)
|
||||||
|
if err == nil {
|
||||||
|
info.City = city.City.Names["en"]
|
||||||
|
if len(city.Subdivisions) > 0 {
|
||||||
|
info.Region = city.Subdivisions[0].Names["en"]
|
||||||
|
}
|
||||||
|
info.PostalCode = city.Postal.Code
|
||||||
|
info.Latitude = city.Location.Latitude
|
||||||
|
info.Longitude = city.Location.Longitude
|
||||||
|
info.Timezone = city.Location.TimeZone
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up in ASN database
|
||||||
|
if asnDB != nil {
|
||||||
|
asn, err := asnDB.ASN(ip)
|
||||||
|
if err == nil {
|
||||||
|
info.ASN = asn.AutonomousSystemNumber
|
||||||
|
info.ASNOrg = asn.AutonomousSystemOrganization
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result[ipStr] = info
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return success response
|
||||||
|
response := SuccessResponse{
|
||||||
|
Status: "success",
|
||||||
|
Date: time.Now().UTC(),
|
||||||
|
Result: result,
|
||||||
|
Error: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set content type and encode response
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(info); err != nil {
|
const errorCodeEncodeFailure = 1004
|
||||||
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
logger.Error("Failed to encode response", "error", err)
|
logger.Error("Failed to encode response", "error", err)
|
||||||
writeError(w, http.StatusInternalServerError, "Internal server error")
|
writeErrorResponse(w, http.StatusInternalServerError, errorCodeEncodeFailure, "Failed to encode response")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeError(w http.ResponseWriter, code int, message string) {
|
func isRFC1918(ip net.IP) bool {
|
||||||
|
for _, network := range rfc1918Networks {
|
||||||
|
if network.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDocumentationIPv4(ip net.IP) bool {
|
||||||
|
for _, network := range documentationIPv4Networks {
|
||||||
|
if network.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDocumentationIPv6(ip net.IP) bool {
|
||||||
|
for _, network := range documentationIPv6Networks {
|
||||||
|
if network.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeErrorResponse(w http.ResponseWriter, httpCode int, errorCode int, message string) {
|
||||||
|
response := ErrorResponse{
|
||||||
|
Status: "error",
|
||||||
|
Date: time.Now().UTC(),
|
||||||
|
Error: map[string]interface{}{
|
||||||
|
"code": errorCode,
|
||||||
|
"msg": message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(code)
|
w.WriteHeader(httpCode)
|
||||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
"error": message,
|
|
||||||
}); err != nil {
|
|
||||||
// Log error but don't try to write again
|
// Log error but don't try to write again
|
||||||
_ = err
|
_ = err
|
||||||
}
|
}
|
||||||
@ -155,4 +342,4 @@ func getClientIP(r *http.Request) string {
|
|||||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
|
||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||||
@ -22,7 +23,7 @@ func TestNewRouter(t *testing.T) {
|
|||||||
stateManager, _ := state.New(cfg, logger)
|
stateManager, _ := state.New(cfg, logger)
|
||||||
db, _ := database.New(cfg, logger, stateManager)
|
db, _ := database.New(cfg, logger, stateManager)
|
||||||
|
|
||||||
router, err := NewRouter(logger, db)
|
router, err := NewRouter(logger, db, stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create router: %v", err)
|
t.Fatalf("failed to create router: %v", err)
|
||||||
}
|
}
|
||||||
@ -40,7 +41,7 @@ func TestHealthEndpoint(t *testing.T) {
|
|||||||
stateManager, _ := state.New(cfg, logger)
|
stateManager, _ := state.New(cfg, logger)
|
||||||
db, _ := database.New(cfg, logger, stateManager)
|
db, _ := database.New(cfg, logger, stateManager)
|
||||||
|
|
||||||
router, _ := NewRouter(logger, db)
|
router, _ := NewRouter(logger, db, stateManager)
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/health", nil)
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@ -58,12 +59,24 @@ func TestHealthEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
func TestIPLookupEndpoint(t *testing.T) {
|
func TestIPLookupEndpoint(t *testing.T) {
|
||||||
logger := slog.Default()
|
logger := slog.Default()
|
||||||
|
tmpDir := t.TempDir()
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
StateDir: t.TempDir(),
|
StateDir: tmpDir,
|
||||||
}
|
}
|
||||||
stateManager, _ := state.New(cfg, logger)
|
stateManager, _ := state.New(cfg, logger)
|
||||||
|
_ = stateManager.Initialize(context.Background())
|
||||||
|
|
||||||
db, _ := database.New(cfg, logger, stateManager)
|
db, _ := database.New(cfg, logger, stateManager)
|
||||||
|
|
||||||
|
// Download databases to temp directory for testing
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
t.Log("Downloading test databases...")
|
||||||
|
if err := db.EnsureDatabases(ctx); err != nil {
|
||||||
|
t.Skipf("Skipping test - could not download databases: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
ip string
|
ip string
|
||||||
@ -71,38 +84,141 @@ func TestIPLookupEndpoint(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"valid IPv4", "8.8.8.8", http.StatusOK},
|
{"valid IPv4", "8.8.8.8", http.StatusOK},
|
||||||
{"valid IPv6", "2001:4860:4860::8888", http.StatusOK},
|
{"valid IPv6", "2001:4860:4860::8888", http.StatusOK},
|
||||||
{"invalid IP", "invalid", http.StatusBadRequest},
|
{"invalid IP", "invalid", http.StatusOK},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/api/"+tt.ip, nil)
|
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+tt.ip, nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
// Create a new context with the URL param
|
// Create a new context with the URL param
|
||||||
rctx := chi.NewRouteContext()
|
rctx := chi.NewRouteContext()
|
||||||
rctx.URLParams.Add("ip", tt.ip)
|
rctx.URLParams.Add("ips", tt.ip)
|
||||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||||
|
|
||||||
handleIPLookup(logger, db)(rec, req)
|
handleIPLookup(logger, db, stateManager)(rec, req)
|
||||||
|
|
||||||
if rec.Code != tt.expectedCode {
|
if rec.Code != tt.expectedCode {
|
||||||
t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code)
|
t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.expectedCode == http.StatusOK {
|
if tt.expectedCode == http.StatusOK {
|
||||||
var info IPInfo
|
var response SuccessResponse
|
||||||
if err := json.Unmarshal(rec.Body.Bytes(), &info); err != nil {
|
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
|
||||||
t.Errorf("failed to unmarshal response: %v", err)
|
t.Errorf("failed to unmarshal response: %v", err)
|
||||||
}
|
}
|
||||||
if info.IP != tt.ip {
|
if response.Status != "success" {
|
||||||
t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
|
t.Errorf("expected status 'success', got %s", response.Status)
|
||||||
|
}
|
||||||
|
if len(response.Result) != 1 {
|
||||||
|
t.Errorf("expected 1 result, got %d", len(response.Result))
|
||||||
|
}
|
||||||
|
if info, ok := response.Result[tt.ip]; ok {
|
||||||
|
if info.IP != tt.ip {
|
||||||
|
t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("expected result for IP %s", tt.ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIPLookupMultipleIPs(t *testing.T) {
|
||||||
|
logger := slog.Default()
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
cfg := &config.Config{
|
||||||
|
StateDir: tmpDir,
|
||||||
|
}
|
||||||
|
stateManager, _ := state.New(cfg, logger)
|
||||||
|
_ = stateManager.Initialize(context.Background())
|
||||||
|
|
||||||
|
db, _ := database.New(cfg, logger, stateManager)
|
||||||
|
|
||||||
|
// Download databases to temp directory for testing
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
t.Log("Downloading test databases...")
|
||||||
|
if err := db.EnsureDatabases(ctx); err != nil {
|
||||||
|
t.Skipf("Skipping test - could not download databases: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test multiple IPs
|
||||||
|
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/8.8.8.8,1.1.1.1,invalid", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create a new context with the URL param
|
||||||
|
rctx := chi.NewRouteContext()
|
||||||
|
rctx.URLParams.Add("ips", "8.8.8.8,1.1.1.1,invalid")
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||||
|
|
||||||
|
handleIPLookup(logger, db, stateManager)(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response SuccessResponse
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Errorf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response.Result) != 3 {
|
||||||
|
t.Errorf("expected 3 results, got %d", len(response.Result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all requested IPs are in the result
|
||||||
|
for _, ip := range []string{"8.8.8.8", "1.1.1.1", "invalid"} {
|
||||||
|
if _, ok := response.Result[ip]; !ok {
|
||||||
|
t.Errorf("expected result for IP %s", ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPLookupTooManyIPs(t *testing.T) {
|
||||||
|
logger := slog.Default()
|
||||||
|
cfg := &config.Config{
|
||||||
|
StateDir: t.TempDir(),
|
||||||
|
}
|
||||||
|
stateManager, _ := state.New(cfg, logger)
|
||||||
|
db, _ := database.New(cfg, logger, stateManager)
|
||||||
|
|
||||||
|
// Create 51 IPs (more than the limit)
|
||||||
|
ips := ""
|
||||||
|
for i := 0; i < 51; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
ips += ","
|
||||||
|
}
|
||||||
|
ips += "1.1.1.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+ips, nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create a new context with the URL param
|
||||||
|
rctx := chi.NewRouteContext()
|
||||||
|
rctx.URLParams.Add("ips", ips)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||||
|
|
||||||
|
handleIPLookup(logger, db, stateManager)(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response ErrorResponse
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Errorf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.Status != "error" {
|
||||||
|
t.Errorf("expected status 'error', got %s", response.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetClientIP(t *testing.T) {
|
func TestGetClientIP(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -32,7 +32,7 @@ func daemonCmd() *cobra.Command {
|
|||||||
state.New,
|
state.New,
|
||||||
database.New,
|
database.New,
|
||||||
http.NewServer,
|
http.NewServer,
|
||||||
http.NewRouter,
|
http.NewRouterWrapper,
|
||||||
New,
|
New,
|
||||||
),
|
),
|
||||||
fx.Invoke(func(lc fx.Lifecycle, ipapi *IPAPI) {
|
fx.Invoke(func(lc fx.Lifecycle, ipapi *IPAPI) {
|
||||||
|
@ -54,10 +54,14 @@ func (i *IPAPI) Start(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Download databases if needed
|
// Start database downloads in background with a fresh context
|
||||||
if err := i.database.EnsureDatabases(ctx); err != nil {
|
// This avoids using the OnStart context which has a timeout
|
||||||
return err
|
go func() {
|
||||||
}
|
dbCtx := context.Background()
|
||||||
|
if err := i.database.EnsureDatabases(dbCtx); err != nil {
|
||||||
|
i.logger.Error("Failed to ensure databases", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Start HTTP server
|
// Start HTTP server
|
||||||
return i.server.Start(ctx)
|
return i.server.Start(ctx)
|
||||||
|
Loading…
Reference in New Issue
Block a user