package http import ( "context" "encoding/json" "log/slog" "net/http" "net/http/httptest" "testing" "time" "git.eeqj.de/sneak/ipapi/internal/config" "git.eeqj.de/sneak/ipapi/internal/database" "git.eeqj.de/sneak/ipapi/internal/state" "github.com/go-chi/chi/v5" ) func TestNewRouter(t *testing.T) { logger := slog.Default() cfg := &config.Config{ StateDir: t.TempDir(), } stateManager, _ := state.New(cfg, logger) db, _ := database.New(cfg, logger, stateManager) router, err := NewRouter(logger, db, stateManager) if err != nil { t.Fatalf("failed to create router: %v", err) } if router == nil { t.Fatal("expected router, got nil") } } func TestHealthEndpoint(t *testing.T) { logger := slog.Default() cfg := &config.Config{ StateDir: t.TempDir(), } stateManager, _ := state.New(cfg, logger) db, _ := database.New(cfg, logger, stateManager) router, _ := NewRouter(logger, db, stateManager) req := httptest.NewRequest("GET", "/health", nil) rec := httptest.NewRecorder() router.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected status 200, got %d", rec.Code) } if rec.Body.String() != "OK" { t.Errorf("expected body 'OK', got %s", rec.Body.String()) } } func TestIPLookupEndpoint(t *testing.T) { logger := slog.Default() tmpDir := t.TempDir() cfg := &config.Config{ StateDir: tmpDir, } stateManager, _ := state.New(cfg, logger) _ = stateManager.Initialize(context.Background()) db, _ := database.New(cfg, logger, stateManager) // Download databases to temp directory for testing ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() t.Log("Downloading test databases...") if err := db.EnsureDatabases(ctx); err != nil { t.Skipf("Skipping test - could not download databases: %v", err) } tests := []struct { name string ip string expectedCode int }{ {"valid IPv4", "8.8.8.8", http.StatusOK}, {"valid IPv6", "2001:4860:4860::8888", http.StatusOK}, {"invalid IP", "invalid", http.StatusOK}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+tt.ip, nil) rec := httptest.NewRecorder() // Create a new context with the URL param rctx := chi.NewRouteContext() rctx.URLParams.Add("ips", tt.ip) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) handleIPLookup(logger, db, stateManager)(rec, req) if rec.Code != tt.expectedCode { t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code) } if tt.expectedCode == http.StatusOK { var response SuccessResponse if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { t.Errorf("failed to unmarshal response: %v", err) } if response.Status != "success" { t.Errorf("expected status 'success', got %s", response.Status) } if len(response.Result) != 1 { t.Errorf("expected 1 result, got %d", len(response.Result)) } if info, ok := response.Result[tt.ip]; ok { if info.IP != tt.ip { t.Errorf("expected IP %s, got %s", tt.ip, info.IP) } } else { t.Errorf("expected result for IP %s", tt.ip) } } }) } } func TestIPLookupMultipleIPs(t *testing.T) { logger := slog.Default() tmpDir := t.TempDir() cfg := &config.Config{ StateDir: tmpDir, } stateManager, _ := state.New(cfg, logger) _ = stateManager.Initialize(context.Background()) db, _ := database.New(cfg, logger, stateManager) // Download databases to temp directory for testing ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() t.Log("Downloading test databases...") if err := db.EnsureDatabases(ctx); err != nil { t.Skipf("Skipping test - could not download databases: %v", err) } // Test multiple IPs req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/8.8.8.8,1.1.1.1,invalid", nil) rec := httptest.NewRecorder() // Create a new context with the URL param rctx := chi.NewRouteContext() rctx.URLParams.Add("ips", "8.8.8.8,1.1.1.1,invalid") req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) handleIPLookup(logger, db, stateManager)(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected status 200, got %d", rec.Code) } var response SuccessResponse if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { t.Errorf("failed to unmarshal response: %v", err) } if len(response.Result) != 3 { t.Errorf("expected 3 results, got %d", len(response.Result)) } // Check that all requested IPs are in the result for _, ip := range []string{"8.8.8.8", "1.1.1.1", "invalid"} { if _, ok := response.Result[ip]; !ok { t.Errorf("expected result for IP %s", ip) } } } func TestIPLookupTooManyIPs(t *testing.T) { logger := slog.Default() cfg := &config.Config{ StateDir: t.TempDir(), } stateManager, _ := state.New(cfg, logger) db, _ := database.New(cfg, logger, stateManager) // Create 51 IPs (more than the limit) ips := "" for i := 0; i < 51; i++ { if i > 0 { ips += "," } ips += "1.1.1.1" } req := httptest.NewRequest("GET", "/api/v1/ipinfo/local/"+ips, nil) rec := httptest.NewRecorder() // Create a new context with the URL param rctx := chi.NewRouteContext() rctx.URLParams.Add("ips", ips) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) handleIPLookup(logger, db, stateManager)(rec, req) if rec.Code != http.StatusBadRequest { t.Errorf("expected status 400, got %d", rec.Code) } var response ErrorResponse if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil { t.Errorf("failed to unmarshal response: %v", err) } if response.Status != "error" { t.Errorf("expected status 'error', got %s", response.Status) } } func TestGetClientIP(t *testing.T) { tests := []struct { name string headers map[string]string remoteAddr string expected string }{ { name: "X-Forwarded-For", headers: map[string]string{ "X-Forwarded-For": "1.2.3.4, 5.6.7.8", }, remoteAddr: "9.10.11.12:1234", expected: "1.2.3.4", }, { name: "X-Real-IP", headers: map[string]string{ "X-Real-IP": "1.2.3.4", }, remoteAddr: "9.10.11.12:1234", expected: "1.2.3.4", }, { name: "RemoteAddr only", headers: map[string]string{}, remoteAddr: "9.10.11.12:1234", expected: "9.10.11.12", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) req.RemoteAddr = tt.remoteAddr for k, v := range tt.headers { req.Header.Set(k, v) } ip := getClientIP(req) if ip != tt.expected { t.Errorf("expected IP %s, got %s", tt.expected, ip) } }) } }