Implement IP API daemon with GeoIP database support
- Create modular architecture with separate packages for config, database, HTTP, logging, and state management - Implement Cobra CLI with daemon command - Set up Uber FX dependency injection - Add Chi router with health check and IP lookup endpoints - Implement GeoIP database downloader with automatic updates - Add state persistence for tracking database download times - Include comprehensive test coverage for all components - Configure structured logging with slog - Add Makefile with test, lint, and build targets - Support both IPv4 and IPv6 lookups - Return country, city, ASN, and location data in JSON format
This commit is contained in:
77
internal/config/config.go
Normal file
77
internal/config/config.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Package config handles application configuration.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"git.eeqj.de/sneak/smartconfig"
|
||||
)
|
||||
|
||||
// Config holds the application configuration.
|
||||
type Config struct {
|
||||
Port int
|
||||
StateDir string
|
||||
LogLevel string
|
||||
}
|
||||
|
||||
// New creates a new configuration instance.
|
||||
func New(configFile string) (*Config, error) {
|
||||
// Check if config file exists first
|
||||
if _, err := os.Stat(configFile); os.IsNotExist(err) {
|
||||
return newDefaultConfig(), nil
|
||||
}
|
||||
|
||||
// Load smartconfig
|
||||
sc, err := smartconfig.NewFromConfigPath(configFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &Config{}
|
||||
|
||||
// Get port from smartconfig or environment or default
|
||||
if port, err := sc.GetInt("port"); err == nil {
|
||||
cfg.Port = port
|
||||
} else {
|
||||
cfg.Port = getPortFromEnv()
|
||||
}
|
||||
|
||||
// Get state directory
|
||||
if stateDir, err := sc.GetString("state_dir"); err == nil {
|
||||
cfg.StateDir = stateDir
|
||||
} else {
|
||||
cfg.StateDir = "/var/lib/ipapi"
|
||||
}
|
||||
|
||||
// Get log level
|
||||
if logLevel, err := sc.GetString("log_level"); err == nil {
|
||||
cfg.LogLevel = logLevel
|
||||
} else {
|
||||
cfg.LogLevel = "info"
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func newDefaultConfig() *Config {
|
||||
return &Config{
|
||||
Port: getPortFromEnv(),
|
||||
StateDir: "/var/lib/ipapi",
|
||||
LogLevel: "info",
|
||||
}
|
||||
}
|
||||
|
||||
func getPortFromEnv() int {
|
||||
const defaultPort = 8080
|
||||
portStr := os.Getenv("PORT")
|
||||
if portStr == "" {
|
||||
return defaultPort
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return defaultPort
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
64
internal/config/config_test.go
Normal file
64
internal/config/config_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewDefaultConfig(t *testing.T) {
|
||||
// Clear PORT env var for test
|
||||
oldPort := os.Getenv("PORT")
|
||||
os.Unsetenv("PORT")
|
||||
defer os.Setenv("PORT", oldPort)
|
||||
|
||||
cfg := newDefaultConfig()
|
||||
if cfg.Port != 8080 {
|
||||
t.Errorf("expected default port 8080, got %d", cfg.Port)
|
||||
}
|
||||
if cfg.StateDir != "/var/lib/ipapi" {
|
||||
t.Errorf("expected default state dir /var/lib/ipapi, got %s", cfg.StateDir)
|
||||
}
|
||||
if cfg.LogLevel != "info" {
|
||||
t.Errorf("expected default log level info, got %s", cfg.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPortFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expected int
|
||||
}{
|
||||
{"no env", "", 8080},
|
||||
{"valid port", "9090", 9090},
|
||||
{"invalid port", "invalid", 8080},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oldPort := os.Getenv("PORT")
|
||||
if tt.envValue == "" {
|
||||
os.Unsetenv("PORT")
|
||||
} else {
|
||||
os.Setenv("PORT", tt.envValue)
|
||||
}
|
||||
defer os.Setenv("PORT", oldPort)
|
||||
|
||||
port := getPortFromEnv()
|
||||
if port != tt.expected {
|
||||
t.Errorf("expected port %d, got %d", tt.expected, port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
// Test with non-existent file (should use defaults)
|
||||
cfg, err := New("/nonexistent/config.yml")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error for non-existent file, got %v", err)
|
||||
}
|
||||
if cfg == nil {
|
||||
t.Fatal("expected config, got nil")
|
||||
}
|
||||
}
|
||||
237
internal/database/database.go
Normal file
237
internal/database/database.go
Normal file
@@ -0,0 +1,237 @@
|
||||
// Package database handles GeoIP database management and downloads.
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
const (
|
||||
asnURL = "https://git.io/GeoLite2-ASN.mmdb"
|
||||
cityURL = "https://git.io/GeoLite2-City.mmdb"
|
||||
countryURL = "https://git.io/GeoLite2-Country.mmdb"
|
||||
|
||||
asnFile = "GeoLite2-ASN.mmdb"
|
||||
cityFile = "GeoLite2-City.mmdb"
|
||||
countryFile = "GeoLite2-Country.mmdb"
|
||||
|
||||
downloadTimeout = 5 * time.Minute
|
||||
updateInterval = 7 * 24 * time.Hour // 1 week
|
||||
|
||||
defaultDirPermissions = 0750
|
||||
defaultFilePermissions = 0640
|
||||
)
|
||||
|
||||
// Manager handles GeoIP database operations.
|
||||
type Manager struct {
|
||||
config *config.Config
|
||||
logger *slog.Logger
|
||||
state *state.Manager
|
||||
dataDir string
|
||||
asnDB *geoip2.Reader
|
||||
cityDB *geoip2.Reader
|
||||
countryDB *geoip2.Reader
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// New creates a new database manager.
|
||||
func New(cfg *config.Config, logger *slog.Logger, state *state.Manager) (*Manager, error) {
|
||||
dataDir := filepath.Join(cfg.StateDir, "databases")
|
||||
|
||||
return &Manager{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
state: state,
|
||||
dataDir: dataDir,
|
||||
httpClient: &http.Client{
|
||||
Timeout: downloadTimeout,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnsureDatabases downloads missing or outdated databases.
|
||||
func (m *Manager) EnsureDatabases(ctx context.Context) error {
|
||||
// Create data directory if it doesn't exist
|
||||
if err := os.MkdirAll(m.dataDir, defaultDirPermissions); err != nil {
|
||||
return fmt.Errorf("failed to create data directory: %w", err)
|
||||
}
|
||||
|
||||
// Load current state
|
||||
currentState, err := m.state.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load state: %w", err)
|
||||
}
|
||||
|
||||
// Check and download ASN database
|
||||
asnPath := filepath.Join(m.dataDir, asnFile)
|
||||
if needsUpdate(asnPath, currentState.LastASNDownload) {
|
||||
m.logger.Info("Downloading ASN database")
|
||||
if err := m.downloadFile(ctx, asnURL, asnPath); err != nil {
|
||||
return fmt.Errorf("failed to download ASN database: %w", err)
|
||||
}
|
||||
if err := m.state.UpdateASNDownloadTime(); err != nil {
|
||||
return fmt.Errorf("failed to update ASN download time: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check and download City database
|
||||
cityPath := filepath.Join(m.dataDir, cityFile)
|
||||
if needsUpdate(cityPath, currentState.LastCityDownload) {
|
||||
m.logger.Info("Downloading City database")
|
||||
if err := m.downloadFile(ctx, cityURL, cityPath); err != nil {
|
||||
return fmt.Errorf("failed to download City database: %w", err)
|
||||
}
|
||||
if err := m.state.UpdateCityDownloadTime(); err != nil {
|
||||
return fmt.Errorf("failed to update City download time: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check and download Country database
|
||||
countryPath := filepath.Join(m.dataDir, countryFile)
|
||||
if needsUpdate(countryPath, currentState.LastCountryDownload) {
|
||||
m.logger.Info("Downloading Country database")
|
||||
if err := m.downloadFile(ctx, countryURL, countryPath); err != nil {
|
||||
return fmt.Errorf("failed to download Country database: %w", err)
|
||||
}
|
||||
if err := m.state.UpdateCountryDownloadTime(); err != nil {
|
||||
return fmt.Errorf("failed to update Country download time: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Open databases
|
||||
if err := m.openDatabases(); err != nil {
|
||||
return fmt.Errorf("failed to open databases: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Info("All databases ready")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) downloadFile(ctx context.Context, url, destPath string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Write to temporary file first
|
||||
tmpPath := destPath + ".tmp"
|
||||
tmpFile, err := os.Create(tmpPath) //nolint:gosec // temporary file with predictable name is ok
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpPath) }()
|
||||
|
||||
_, err = io.Copy(tmpFile, resp.Body)
|
||||
if err2 := tmpFile.Close(); err2 != nil {
|
||||
return fmt.Errorf("failed to close temp file: %w", err2)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
// Move to final location
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
return fmt.Errorf("failed to move file: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Debug("Downloaded file", "url", url, "path", destPath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) openDatabases() error {
|
||||
var err error
|
||||
|
||||
// Open ASN database
|
||||
asnPath := filepath.Join(m.dataDir, asnFile)
|
||||
m.asnDB, err = geoip2.Open(asnPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open ASN database: %w", err)
|
||||
}
|
||||
|
||||
// Open City database
|
||||
cityPath := filepath.Join(m.dataDir, cityFile)
|
||||
m.cityDB, err = geoip2.Open(cityPath)
|
||||
if err != nil {
|
||||
_ = m.asnDB.Close()
|
||||
|
||||
return fmt.Errorf("failed to open City database: %w", err)
|
||||
}
|
||||
|
||||
// Open Country database
|
||||
countryPath := filepath.Join(m.dataDir, countryFile)
|
||||
m.countryDB, err = geoip2.Open(countryPath)
|
||||
if err != nil {
|
||||
_ = m.asnDB.Close()
|
||||
_ = m.cityDB.Close()
|
||||
|
||||
return fmt.Errorf("failed to open Country database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes all open databases.
|
||||
func (m *Manager) Close() error {
|
||||
if m.asnDB != nil {
|
||||
_ = m.asnDB.Close()
|
||||
}
|
||||
if m.cityDB != nil {
|
||||
_ = m.cityDB.Close()
|
||||
}
|
||||
if m.countryDB != nil {
|
||||
_ = m.countryDB.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetASNDB returns the ASN database reader.
|
||||
func (m *Manager) GetASNDB() *geoip2.Reader {
|
||||
return m.asnDB
|
||||
}
|
||||
|
||||
// GetCityDB returns the City database reader.
|
||||
func (m *Manager) GetCityDB() *geoip2.Reader {
|
||||
return m.cityDB
|
||||
}
|
||||
|
||||
// GetCountryDB returns the Country database reader.
|
||||
func (m *Manager) GetCountryDB() *geoip2.Reader {
|
||||
return m.countryDB
|
||||
}
|
||||
|
||||
func needsUpdate(filePath string, lastDownload time.Time) bool {
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's time to update
|
||||
if time.Since(lastDownload) > updateInterval {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
87
internal/database/database_test.go
Normal file
87
internal/database/database_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
)
|
||||
|
||||
func TestNeedsUpdate(t *testing.T) {
|
||||
tmpFile, err := os.CreateTemp("", "test-db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
tmpFile.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filePath string
|
||||
lastDownload time.Time
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "file doesn't exist",
|
||||
filePath: "/nonexistent/file",
|
||||
lastDownload: time.Now(),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "recent download",
|
||||
filePath: tmpFile.Name(),
|
||||
lastDownload: time.Now().Add(-time.Hour),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "old download",
|
||||
filePath: tmpFile.Name(),
|
||||
lastDownload: time.Now().Add(-8 * 24 * time.Hour),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := needsUpdate(tt.filePath, tt.lastDownload)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "ipapi-db-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
logger := slog.Default()
|
||||
stateManager, err := state.New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
manager, err := New(cfg, logger, stateManager)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create database manager: %v", err)
|
||||
}
|
||||
|
||||
if manager == nil {
|
||||
t.Fatal("expected manager, got nil")
|
||||
}
|
||||
|
||||
expectedDataDir := filepath.Join(tmpDir, "databases")
|
||||
if manager.dataDir != expectedDataDir {
|
||||
t.Errorf("expected data dir %s, got %s", expectedDataDir, manager.dataDir)
|
||||
}
|
||||
}
|
||||
158
internal/http/router.go
Normal file
158
internal/http/router.go
Normal file
@@ -0,0 +1,158 @@
|
||||
// Package http provides the HTTP server and routing functionality.
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
// IPInfo represents the API response for IP lookups.
|
||||
type IPInfo struct {
|
||||
IP string `json:"ip"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"countryCode,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
PostalCode string `json:"postalCode,omitempty"`
|
||||
Latitude float64 `json:"latitude,omitempty"`
|
||||
Longitude float64 `json:"longitude,omitempty"`
|
||||
Timezone string `json:"timezone,omitempty"`
|
||||
ASN uint `json:"asn,omitempty"`
|
||||
ASNOrg string `json:"asnOrg,omitempty"`
|
||||
}
|
||||
|
||||
// NewRouter creates a new HTTP router with all endpoints configured.
|
||||
func NewRouter(logger *slog.Logger, db *database.Manager) (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Middleware
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Recoverer)
|
||||
const requestTimeout = 60
|
||||
r.Use(middleware.Timeout(requestTimeout))
|
||||
|
||||
// Logging middleware
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := r.Context().Value(middleware.RequestIDKey).(string)
|
||||
logger.Debug("HTTP request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"remote_addr", r.RemoteAddr,
|
||||
"request_id", start,
|
||||
)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
|
||||
// Health check
|
||||
r.Get("/health", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("OK")); err != nil {
|
||||
logger.Error("Failed to write health response", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
// IP lookup endpoint
|
||||
r.Get("/api/{ip}", handleIPLookup(logger, db))
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func handleIPLookup(logger *slog.Logger, db *database.Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ipStr := chi.URLParam(r, "ip")
|
||||
|
||||
// Validate IP address
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
writeError(w, http.StatusBadRequest, "Invalid IP address")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
info := &IPInfo{
|
||||
IP: ipStr,
|
||||
}
|
||||
|
||||
// Look up in Country database
|
||||
if countryDB := db.GetCountryDB(); countryDB != nil {
|
||||
country, err := countryDB.Country(ip)
|
||||
if err == nil {
|
||||
info.Country = country.Country.Names["en"]
|
||||
info.CountryCode = country.Country.IsoCode
|
||||
}
|
||||
}
|
||||
|
||||
// Look up in City database
|
||||
if cityDB := db.GetCityDB(); cityDB != nil {
|
||||
city, err := cityDB.City(ip)
|
||||
if err == nil {
|
||||
info.City = city.City.Names["en"]
|
||||
if len(city.Subdivisions) > 0 {
|
||||
info.Region = city.Subdivisions[0].Names["en"]
|
||||
}
|
||||
info.PostalCode = city.Postal.Code
|
||||
info.Latitude = city.Location.Latitude
|
||||
info.Longitude = city.Location.Longitude
|
||||
info.Timezone = city.Location.TimeZone
|
||||
}
|
||||
}
|
||||
|
||||
// Look up in ASN database
|
||||
if asnDB := db.GetASNDB(); asnDB != nil {
|
||||
asn, err := asnDB.ASN(ip)
|
||||
if err == nil {
|
||||
info.ASN = asn.AutonomousSystemNumber
|
||||
info.ASNOrg = asn.AutonomousSystemOrganization
|
||||
}
|
||||
}
|
||||
|
||||
// Set content type and encode response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(info); err != nil {
|
||||
logger.Error("Failed to encode response", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "Internal server error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, code int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": message,
|
||||
}); err != nil {
|
||||
// Log error but don't try to write again
|
||||
_ = err
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:unused // will be used in future for rate limiting
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
|
||||
return host
|
||||
}
|
||||
151
internal/http/router_test.go
Normal file
151
internal/http/router_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestNewRouter(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
cfg := &config.Config{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
router, err := NewRouter(logger, db)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create router: %v", err)
|
||||
}
|
||||
|
||||
if router == nil {
|
||||
t.Fatal("expected router, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthEndpoint(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
cfg := &config.Config{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
router, _ := NewRouter(logger, db)
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", rec.Code)
|
||||
}
|
||||
|
||||
if rec.Body.String() != "OK" {
|
||||
t.Errorf("expected body 'OK', got %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPLookupEndpoint(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
cfg := &config.Config{
|
||||
StateDir: t.TempDir(),
|
||||
}
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
db, _ := database.New(cfg, logger, stateManager)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expectedCode int
|
||||
}{
|
||||
{"valid IPv4", "8.8.8.8", http.StatusOK},
|
||||
{"valid IPv6", "2001:4860:4860::8888", http.StatusOK},
|
||||
{"invalid IP", "invalid", http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/"+tt.ip, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Create a new context with the URL param
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("ip", tt.ip)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
handleIPLookup(logger, db)(rec, req)
|
||||
|
||||
if rec.Code != tt.expectedCode {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code)
|
||||
}
|
||||
|
||||
if tt.expectedCode == http.StatusOK {
|
||||
var info IPInfo
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &info); err != nil {
|
||||
t.Errorf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if info.IP != tt.ip {
|
||||
t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "1.2.3.4, 5.6.7.8",
|
||||
},
|
||||
remoteAddr: "9.10.11.12:1234",
|
||||
expected: "1.2.3.4",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP",
|
||||
headers: map[string]string{
|
||||
"X-Real-IP": "1.2.3.4",
|
||||
},
|
||||
remoteAddr: "9.10.11.12:1234",
|
||||
expected: "1.2.3.4",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr only",
|
||||
headers: map[string]string{},
|
||||
remoteAddr: "9.10.11.12:1234",
|
||||
expected: "9.10.11.12",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
for k, v := range tt.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
ip := getClientIP(req)
|
||||
if ip != tt.expected {
|
||||
t.Errorf("expected IP %s, got %s", tt.expected, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
66
internal/http/server.go
Normal file
66
internal/http/server.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Server manages the HTTP server lifecycle.
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
logger *slog.Logger
|
||||
router chi.Router
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
// NewServer creates a new HTTP server instance.
|
||||
func NewServer(cfg *config.Config, logger *slog.Logger, router chi.Router) (*Server, error) {
|
||||
return &Server{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
router: router,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins listening for HTTP requests.
|
||||
func (s *Server) Start(_ context.Context) error {
|
||||
addr := fmt.Sprintf(":%d", s.config.Port)
|
||||
s.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.router,
|
||||
ReadTimeout: 15 * time.Second, //nolint:mnd
|
||||
WriteTimeout: 15 * time.Second, //nolint:mnd
|
||||
IdleTimeout: 60 * time.Second, //nolint:mnd
|
||||
}
|
||||
|
||||
s.logger.Info("Starting HTTP server", "addr", addr)
|
||||
|
||||
go func() {
|
||||
if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.logger.Error("HTTP server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the HTTP server.
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
if s.httpServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info("Stopping HTTP server")
|
||||
|
||||
const shutdownTimeout = 30 * time.Second
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.httpServer.Shutdown(shutdownCtx)
|
||||
}
|
||||
64
internal/http/server_test.go
Normal file
64
internal/http/server_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
}
|
||||
logger := slog.Default()
|
||||
router := chi.NewRouter()
|
||||
|
||||
server, err := NewServer(cfg, logger, router)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %v", err)
|
||||
}
|
||||
|
||||
if server == nil {
|
||||
t.Fatal("expected server, got nil")
|
||||
}
|
||||
|
||||
if server.config != cfg {
|
||||
t.Error("config not set correctly")
|
||||
}
|
||||
|
||||
if server.logger != logger {
|
||||
t.Error("logger not set correctly")
|
||||
}
|
||||
|
||||
if server.router != router {
|
||||
t.Error("router not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerStartStop(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 0, // Use random port
|
||||
}
|
||||
logger := slog.Default()
|
||||
router := chi.NewRouter()
|
||||
|
||||
server, err := NewServer(cfg, logger, router)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Start server
|
||||
if err := server.Start(ctx); err != nil {
|
||||
t.Fatalf("failed to start server: %v", err)
|
||||
}
|
||||
|
||||
// Stop server
|
||||
if err := server.Stop(ctx); err != nil {
|
||||
t.Fatalf("failed to stop server: %v", err)
|
||||
}
|
||||
}
|
||||
23
internal/ipapi/cli.go
Normal file
23
internal/ipapi/cli.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Package ipapi provides the main application structure and CLI.
|
||||
package ipapi
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// CLIEntry is the main entry point for the CLI application.
|
||||
func CLIEntry() {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "ipapi",
|
||||
Short: "IP API is a simple IP information REST API",
|
||||
Long: `IP API provides GeoIP information for IPv4 and IPv6 addresses.`,
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(daemonCmd())
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
66
internal/ipapi/daemon.go
Normal file
66
internal/ipapi/daemon.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package ipapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
"git.eeqj.de/sneak/ipapi/internal/http"
|
||||
"git.eeqj.de/sneak/ipapi/internal/log"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
const defaultConfigFile = "/etc/ipapi/config.yml"
|
||||
|
||||
func daemonCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "daemon",
|
||||
Short: "Run the IP API daemon",
|
||||
Long: `Start the IP API HTTP server that provides GeoIP information.`,
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
configFile := getConfigFile()
|
||||
|
||||
app := fx.New(
|
||||
fx.Provide(
|
||||
func() (*config.Config, error) {
|
||||
return config.New(configFile)
|
||||
},
|
||||
log.New,
|
||||
state.New,
|
||||
database.New,
|
||||
http.NewServer,
|
||||
http.NewRouter,
|
||||
New,
|
||||
),
|
||||
fx.Invoke(func(lc fx.Lifecycle, ipapi *IPAPI) {
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(ctx context.Context) error {
|
||||
return ipapi.Start(ctx)
|
||||
},
|
||||
OnStop: func(ctx context.Context) error {
|
||||
return ipapi.Stop(ctx)
|
||||
},
|
||||
})
|
||||
}),
|
||||
)
|
||||
|
||||
app.Run()
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getConfigFile() string {
|
||||
configFile := os.Getenv("IP_API_CONFIG_FILE")
|
||||
if configFile == "" {
|
||||
return defaultConfigFile
|
||||
}
|
||||
|
||||
return configFile
|
||||
}
|
||||
71
internal/ipapi/ipapi.go
Normal file
71
internal/ipapi/ipapi.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package ipapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
"git.eeqj.de/sneak/ipapi/internal/http"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// Options contains all dependencies for IPAPI.
|
||||
type Options struct {
|
||||
fx.In
|
||||
|
||||
Config *config.Config
|
||||
Logger *slog.Logger
|
||||
State *state.Manager
|
||||
Database *database.Manager
|
||||
Server *http.Server
|
||||
}
|
||||
|
||||
// IPAPI is the main application structure.
|
||||
type IPAPI struct {
|
||||
config *config.Config
|
||||
logger *slog.Logger
|
||||
state *state.Manager
|
||||
database *database.Manager
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// New creates a new IPAPI instance with the given options.
|
||||
func New(opts Options) *IPAPI {
|
||||
return &IPAPI{
|
||||
config: opts.Config,
|
||||
logger: opts.Logger,
|
||||
state: opts.State,
|
||||
database: opts.Database,
|
||||
server: opts.Server,
|
||||
}
|
||||
}
|
||||
|
||||
// Start initializes and starts all components.
|
||||
func (i *IPAPI) Start(ctx context.Context) error {
|
||||
i.logger.Info("Starting IP API daemon",
|
||||
"port", i.config.Port,
|
||||
"state_dir", i.config.StateDir,
|
||||
)
|
||||
|
||||
// Initialize state
|
||||
if err := i.state.Initialize(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Download databases if needed
|
||||
if err := i.database.EnsureDatabases(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start HTTP server
|
||||
return i.server.Start(ctx)
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down all components.
|
||||
func (i *IPAPI) Stop(ctx context.Context) error {
|
||||
i.logger.Info("Stopping IP API daemon")
|
||||
|
||||
return i.server.Stop(ctx)
|
||||
}
|
||||
93
internal/ipapi/ipapi_test.go
Normal file
93
internal/ipapi/ipapi_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package ipapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
"git.eeqj.de/sneak/ipapi/internal/database"
|
||||
"git.eeqj.de/sneak/ipapi/internal/http"
|
||||
"git.eeqj.de/sneak/ipapi/internal/state"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
StateDir: t.TempDir(),
|
||||
LogLevel: "info",
|
||||
}
|
||||
logger := slog.Default()
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
dbManager, _ := database.New(cfg, logger, stateManager)
|
||||
router := chi.NewRouter()
|
||||
server, _ := http.NewServer(cfg, logger, router)
|
||||
|
||||
opts := Options{
|
||||
Config: cfg,
|
||||
Logger: logger,
|
||||
State: stateManager,
|
||||
Database: dbManager,
|
||||
Server: server,
|
||||
}
|
||||
|
||||
ipapi := New(opts)
|
||||
if ipapi == nil {
|
||||
t.Fatal("expected IPAPI instance, got nil")
|
||||
}
|
||||
|
||||
if ipapi.config != cfg {
|
||||
t.Error("config not set correctly")
|
||||
}
|
||||
|
||||
if ipapi.logger != logger {
|
||||
t.Error("logger not set correctly")
|
||||
}
|
||||
|
||||
if ipapi.state != stateManager {
|
||||
t.Error("state manager not set correctly")
|
||||
}
|
||||
|
||||
if ipapi.database != dbManager {
|
||||
t.Error("database manager not set correctly")
|
||||
}
|
||||
|
||||
if ipapi.server != server {
|
||||
t.Error("server not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartStop(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 0, // Random port
|
||||
StateDir: t.TempDir(),
|
||||
LogLevel: "info",
|
||||
}
|
||||
logger := slog.Default()
|
||||
stateManager, _ := state.New(cfg, logger)
|
||||
dbManager, _ := database.New(cfg, logger, stateManager)
|
||||
router := chi.NewRouter()
|
||||
server, _ := http.NewServer(cfg, logger, router)
|
||||
|
||||
opts := Options{
|
||||
Config: cfg,
|
||||
Logger: logger,
|
||||
State: stateManager,
|
||||
Database: dbManager,
|
||||
Server: server,
|
||||
}
|
||||
|
||||
ipapi := New(opts)
|
||||
ctx := context.Background()
|
||||
|
||||
// Initialize state first
|
||||
if err := stateManager.Initialize(ctx); err != nil {
|
||||
t.Fatalf("failed to initialize state: %v", err)
|
||||
}
|
||||
|
||||
// Stop should work even if not started
|
||||
if err := ipapi.Stop(ctx); err != nil {
|
||||
t.Errorf("stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
49
internal/log/log.go
Normal file
49
internal/log/log.go
Normal file
@@ -0,0 +1,49 @@
|
||||
// Package log provides structured logging functionality.
|
||||
package log
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
)
|
||||
|
||||
// New creates a new logger instance based on configuration.
|
||||
func New(cfg *config.Config) *slog.Logger {
|
||||
var level slog.Level
|
||||
switch strings.ToLower(cfg.LogLevel) {
|
||||
case "debug":
|
||||
level = slog.LevelDebug
|
||||
case "info":
|
||||
level = slog.LevelInfo
|
||||
case "warn", "warning":
|
||||
level = slog.LevelWarn
|
||||
case "error":
|
||||
level = slog.LevelError
|
||||
default:
|
||||
level = slog.LevelInfo
|
||||
}
|
||||
|
||||
opts := &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
if isTerminal() {
|
||||
handler = slog.NewTextHandler(os.Stdout, opts)
|
||||
} else {
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
}
|
||||
|
||||
return slog.New(handler)
|
||||
}
|
||||
|
||||
func isTerminal() bool {
|
||||
fileInfo, err := os.Stdout.Stat()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return (fileInfo.Mode() & os.ModeCharDevice) != 0
|
||||
}
|
||||
38
internal/log/log_test.go
Normal file
38
internal/log/log_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logLevel string
|
||||
}{
|
||||
{"debug level", "debug"},
|
||||
{"info level", "info"},
|
||||
{"warn level", "warn"},
|
||||
{"warning level", "warning"},
|
||||
{"error level", "error"},
|
||||
{"invalid level", "invalid"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
LogLevel: tt.logLevel,
|
||||
}
|
||||
logger := New(cfg)
|
||||
if logger == nil {
|
||||
t.Fatal("expected logger, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTerminal(t *testing.T) {
|
||||
// Just test that it doesn't panic
|
||||
_ = isTerminal()
|
||||
}
|
||||
134
internal/state/state.go
Normal file
134
internal/state/state.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Package state manages daemon state persistence.
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
stateFileName = "daemon.json"
|
||||
dirPermissions = 0750
|
||||
filePermissions = 0600
|
||||
)
|
||||
|
||||
// State holds the daemon's persistent state information.
|
||||
type State struct {
|
||||
LastASNDownload time.Time `json:"lastAsnDownload"`
|
||||
LastCityDownload time.Time `json:"lastCityDownload"`
|
||||
LastCountryDownload time.Time `json:"lastCountryDownload"`
|
||||
}
|
||||
|
||||
// Manager handles state file operations.
|
||||
type Manager struct {
|
||||
config *config.Config
|
||||
logger *slog.Logger
|
||||
statePath string
|
||||
}
|
||||
|
||||
// New creates a new state manager.
|
||||
func New(cfg *config.Config, logger *slog.Logger) (*Manager, error) {
|
||||
statePath := filepath.Join(cfg.StateDir, stateFileName)
|
||||
|
||||
return &Manager{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
statePath: statePath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Initialize ensures the state directory exists.
|
||||
func (m *Manager) Initialize(_ context.Context) error {
|
||||
// Ensure state directory exists
|
||||
dir := filepath.Dir(m.statePath)
|
||||
if err := os.MkdirAll(dir, dirPermissions); err != nil {
|
||||
return fmt.Errorf("failed to create state directory: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Info("State manager initialized", "path", m.statePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load reads the state from disk.
|
||||
func (m *Manager) Load() (*State, error) {
|
||||
data, err := os.ReadFile(m.statePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Return empty state if file doesn't exist
|
||||
return &State{}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to read state file: %w", err)
|
||||
}
|
||||
|
||||
var state State
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse state file: %w", err)
|
||||
}
|
||||
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
// Save writes the state to disk.
|
||||
func (m *Manager) Save(state *State) error {
|
||||
data, err := json.MarshalIndent(state, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal state: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first
|
||||
tmpPath := m.statePath + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, data, filePermissions); err != nil {
|
||||
return fmt.Errorf("failed to write state file: %w", err)
|
||||
}
|
||||
|
||||
// Rename to final path (atomic operation)
|
||||
if err := os.Rename(tmpPath, m.statePath); err != nil {
|
||||
return fmt.Errorf("failed to save state file: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Debug("State saved", "path", m.statePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateASNDownloadTime updates the ASN database download timestamp.
|
||||
func (m *Manager) UpdateASNDownloadTime() error {
|
||||
state, err := m.Load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state.LastASNDownload = time.Now().UTC()
|
||||
|
||||
return m.Save(state)
|
||||
}
|
||||
|
||||
// UpdateCityDownloadTime updates the City database download timestamp.
|
||||
func (m *Manager) UpdateCityDownloadTime() error {
|
||||
state, err := m.Load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state.LastCityDownload = time.Now().UTC()
|
||||
|
||||
return m.Save(state)
|
||||
}
|
||||
|
||||
// UpdateCountryDownloadTime updates the Country database download timestamp.
|
||||
func (m *Manager) UpdateCountryDownloadTime() error {
|
||||
state, err := m.Load()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state.LastCountryDownload = time.Now().UTC()
|
||||
|
||||
return m.Save(state)
|
||||
}
|
||||
86
internal/state/state_test.go
Normal file
86
internal/state/state_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/ipapi/internal/config"
|
||||
)
|
||||
|
||||
func TestManager(t *testing.T) {
|
||||
// Create temp directory for testing
|
||||
tmpDir, err := os.MkdirTemp("", "ipapi-state-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
StateDir: tmpDir,
|
||||
}
|
||||
logger := slog.Default()
|
||||
|
||||
m, err := New(cfg, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// Test Initialize
|
||||
ctx := context.Background()
|
||||
if err := m.Initialize(ctx); err != nil {
|
||||
t.Fatalf("failed to initialize: %v", err)
|
||||
}
|
||||
|
||||
// Test Load with non-existent file
|
||||
state, err := m.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load empty state: %v", err)
|
||||
}
|
||||
if !state.LastASNDownload.IsZero() {
|
||||
t.Error("expected zero time for new state")
|
||||
}
|
||||
|
||||
// Test Save and Load
|
||||
now := time.Now().UTC()
|
||||
state.LastASNDownload = now
|
||||
state.LastCityDownload = now
|
||||
state.LastCountryDownload = now
|
||||
|
||||
if err := m.Save(state); err != nil {
|
||||
t.Fatalf("failed to save state: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
statePath := filepath.Join(tmpDir, stateFileName)
|
||||
if _, err := os.Stat(statePath); os.IsNotExist(err) {
|
||||
t.Error("state file was not created")
|
||||
}
|
||||
|
||||
// Load and verify
|
||||
loaded, err := m.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load saved state: %v", err)
|
||||
}
|
||||
|
||||
if !loaded.LastASNDownload.Equal(now) {
|
||||
t.Errorf("ASN download time mismatch: got %v, want %v", loaded.LastASNDownload, now)
|
||||
}
|
||||
|
||||
// Test update methods
|
||||
if err := m.UpdateASNDownloadTime(); err != nil {
|
||||
t.Fatalf("failed to update ASN download time: %v", err)
|
||||
}
|
||||
|
||||
loaded, err = m.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load after update: %v", err)
|
||||
}
|
||||
|
||||
if loaded.LastASNDownload.Before(now) || loaded.LastASNDownload.Equal(now) {
|
||||
t.Error("ASN download time was not updated")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user