- Add http module for proper FX dependency injection - Fix router to accept state manager parameter - Implement proper CIDR-based checking for RFC1918 and documentation IPs - Add reasonable timeouts (30s) for database downloads - Update tests to download databases to temporary directories - Add tests for multiple IP lookups and error cases - All tests passing
238 lines
6.0 KiB
Go
238 lines
6.0 KiB
Go
// Package database handles GeoIP database management and downloads.
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"git.eeqj.de/sneak/ipapi/internal/config"
|
|
"git.eeqj.de/sneak/ipapi/internal/state"
|
|
"github.com/oschwald/geoip2-golang"
|
|
)
|
|
|
|
// Download sources for the GeoLite2 databases.
const (
	asnURL     = "https://git.io/GeoLite2-ASN.mmdb"
	cityURL    = "https://git.io/GeoLite2-City.mmdb"
	countryURL = "https://git.io/GeoLite2-Country.mmdb"
)

// File names of the databases inside the manager's data directory.
const (
	asnFile     = "GeoLite2-ASN.mmdb"
	cityFile    = "GeoLite2-City.mmdb"
	countryFile = "GeoLite2-Country.mmdb"
)

const (
	// downloadTimeout is the overall HTTP client timeout used for downloads.
	downloadTimeout = 30 * time.Second
	// updateInterval is how old a recorded download may get before the
	// database is fetched again (one week).
	updateInterval = 7 * 24 * time.Hour

	// Permission bits for the data directory and the files written into it.
	defaultDirPermissions  = 0o750
	defaultFilePermissions = 0o640
)
|
|
|
|
// Manager handles GeoIP database operations.
|
|
type Manager struct {
|
|
config *config.Config
|
|
logger *slog.Logger
|
|
state *state.Manager
|
|
dataDir string
|
|
asnDB *geoip2.Reader
|
|
cityDB *geoip2.Reader
|
|
countryDB *geoip2.Reader
|
|
httpClient *http.Client
|
|
}
|
|
|
|
// New creates a new database manager.
|
|
func New(cfg *config.Config, logger *slog.Logger, state *state.Manager) (*Manager, error) {
|
|
dataDir := filepath.Join(cfg.StateDir, "databases")
|
|
|
|
return &Manager{
|
|
config: cfg,
|
|
logger: logger,
|
|
state: state,
|
|
dataDir: dataDir,
|
|
httpClient: &http.Client{
|
|
Timeout: downloadTimeout,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// EnsureDatabases downloads missing or outdated databases.
|
|
func (m *Manager) EnsureDatabases(ctx context.Context) error {
|
|
// Create data directory if it doesn't exist
|
|
if err := os.MkdirAll(m.dataDir, defaultDirPermissions); err != nil {
|
|
return fmt.Errorf("failed to create data directory: %w", err)
|
|
}
|
|
|
|
// Load current state
|
|
currentState, err := m.state.Load()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load state: %w", err)
|
|
}
|
|
|
|
// Check and download ASN database
|
|
asnPath := filepath.Join(m.dataDir, asnFile)
|
|
if needsUpdate(asnPath, currentState.LastASNDownload) {
|
|
m.logger.Info("Downloading ASN database")
|
|
if err := m.downloadFile(ctx, asnURL, asnPath); err != nil {
|
|
return fmt.Errorf("failed to download ASN database: %w", err)
|
|
}
|
|
if err := m.state.UpdateASNDownloadTime(); err != nil {
|
|
return fmt.Errorf("failed to update ASN download time: %w", err)
|
|
}
|
|
}
|
|
|
|
// Check and download City database
|
|
cityPath := filepath.Join(m.dataDir, cityFile)
|
|
if needsUpdate(cityPath, currentState.LastCityDownload) {
|
|
m.logger.Info("Downloading City database")
|
|
if err := m.downloadFile(ctx, cityURL, cityPath); err != nil {
|
|
return fmt.Errorf("failed to download City database: %w", err)
|
|
}
|
|
if err := m.state.UpdateCityDownloadTime(); err != nil {
|
|
return fmt.Errorf("failed to update City download time: %w", err)
|
|
}
|
|
}
|
|
|
|
// Check and download Country database
|
|
countryPath := filepath.Join(m.dataDir, countryFile)
|
|
if needsUpdate(countryPath, currentState.LastCountryDownload) {
|
|
m.logger.Info("Downloading Country database")
|
|
if err := m.downloadFile(ctx, countryURL, countryPath); err != nil {
|
|
return fmt.Errorf("failed to download Country database: %w", err)
|
|
}
|
|
if err := m.state.UpdateCountryDownloadTime(); err != nil {
|
|
return fmt.Errorf("failed to update Country download time: %w", err)
|
|
}
|
|
}
|
|
|
|
// Open databases
|
|
if err := m.openDatabases(); err != nil {
|
|
return fmt.Errorf("failed to open databases: %w", err)
|
|
}
|
|
|
|
m.logger.Info("All databases ready")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) downloadFile(ctx context.Context, url, destPath string) error {
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
resp, err := m.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to download: %w", err)
|
|
}
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
}
|
|
|
|
// Write to temporary file first
|
|
tmpPath := destPath + ".tmp"
|
|
tmpFile, err := os.Create(tmpPath) //nolint:gosec // temporary file with predictable name is ok
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create temp file: %w", err)
|
|
}
|
|
defer func() { _ = os.Remove(tmpPath) }()
|
|
|
|
_, err = io.Copy(tmpFile, resp.Body)
|
|
if err2 := tmpFile.Close(); err2 != nil {
|
|
return fmt.Errorf("failed to close temp file: %w", err2)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write file: %w", err)
|
|
}
|
|
|
|
// Move to final location
|
|
if err := os.Rename(tmpPath, destPath); err != nil {
|
|
return fmt.Errorf("failed to move file: %w", err)
|
|
}
|
|
|
|
m.logger.Debug("Downloaded file", "url", url, "path", destPath)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) openDatabases() error {
|
|
var err error
|
|
|
|
// Open ASN database
|
|
asnPath := filepath.Join(m.dataDir, asnFile)
|
|
m.asnDB, err = geoip2.Open(asnPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open ASN database: %w", err)
|
|
}
|
|
|
|
// Open City database
|
|
cityPath := filepath.Join(m.dataDir, cityFile)
|
|
m.cityDB, err = geoip2.Open(cityPath)
|
|
if err != nil {
|
|
_ = m.asnDB.Close()
|
|
|
|
return fmt.Errorf("failed to open City database: %w", err)
|
|
}
|
|
|
|
// Open Country database
|
|
countryPath := filepath.Join(m.dataDir, countryFile)
|
|
m.countryDB, err = geoip2.Open(countryPath)
|
|
if err != nil {
|
|
_ = m.asnDB.Close()
|
|
_ = m.cityDB.Close()
|
|
|
|
return fmt.Errorf("failed to open Country database: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close closes all open databases.
|
|
func (m *Manager) Close() error {
|
|
if m.asnDB != nil {
|
|
_ = m.asnDB.Close()
|
|
}
|
|
if m.cityDB != nil {
|
|
_ = m.cityDB.Close()
|
|
}
|
|
if m.countryDB != nil {
|
|
_ = m.countryDB.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetASNDB returns the ASN database reader.
|
|
func (m *Manager) GetASNDB() *geoip2.Reader {
|
|
return m.asnDB
|
|
}
|
|
|
|
// GetCityDB returns the City database reader.
|
|
func (m *Manager) GetCityDB() *geoip2.Reader {
|
|
return m.cityDB
|
|
}
|
|
|
|
// GetCountryDB returns the Country database reader.
|
|
func (m *Manager) GetCountryDB() *geoip2.Reader {
|
|
return m.countryDB
|
|
}
|
|
|
|
func needsUpdate(filePath string, lastDownload time.Time) bool {
|
|
// Check if file exists
|
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
|
return true
|
|
}
|
|
|
|
// Check if it's time to update
|
|
if time.Since(lastDownload) > updateInterval {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|