package http

import (
	"context"
	"encoding/json"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"testing"

	"git.eeqj.de/sneak/ipapi/internal/config"
	"git.eeqj.de/sneak/ipapi/internal/database"
	"git.eeqj.de/sneak/ipapi/internal/state"
	"github.com/go-chi/chi/v5"
)

// TestNewRouter verifies that NewRouter returns a non-nil router when given a
// database built on a per-test temporary state directory.
func TestNewRouter(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	// Setup errors are fatal: continuing with a nil state manager or database
	// would only surface later as an unrelated panic or assertion failure.
	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	router, err := NewRouter(logger, db)
	if err != nil {
		t.Fatalf("failed to create router: %v", err)
	}

	if router == nil {
		t.Fatal("expected router, got nil")
	}
}

// TestHealthEndpoint checks that GET /health through the full router responds
// with status 200 and the literal body "OK".
func TestHealthEndpoint(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	router, err := NewRouter(logger, db)
	if err != nil {
		t.Fatalf("failed to create router: %v", err)
	}

	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	rec := httptest.NewRecorder()

	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	if body := rec.Body.String(); body != "OK" {
		t.Errorf("expected body 'OK', got %q", body)
	}
}

// TestIPLookupEndpoint exercises handleIPLookup directly with IPv4, IPv6 and
// malformed inputs. It checks the status code and, for successful lookups,
// that the JSON response echoes the requested IP.
func TestIPLookupEndpoint(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	tests := []struct {
		name         string
		ip           string
		expectedCode int
	}{
		{"valid IPv4", "8.8.8.8", http.StatusOK},
		{"valid IPv6", "2001:4860:4860::8888", http.StatusOK},
		{"invalid IP", "invalid", http.StatusBadRequest},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/api/"+tt.ip, nil)
			rec := httptest.NewRecorder()

			// The handler is invoked directly rather than through the router,
			// so no chi routing has run; attach a route context carrying the
			// "ip" URL parameter by hand.
			rctx := chi.NewRouteContext()
			rctx.URLParams.Add("ip", tt.ip)
			req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))

			handleIPLookup(logger, db)(rec, req)

			// Stop on a status mismatch: decoding an unexpected (e.g. error)
			// body below would only add noise to the failure report.
			if rec.Code != tt.expectedCode {
				t.Fatalf("expected status %d, got %d (body: %q)",
					tt.expectedCode, rec.Code, rec.Body.String())
			}

			if tt.expectedCode != http.StatusOK {
				return
			}

			var info IPInfo
			if err := json.Unmarshal(rec.Body.Bytes(), &info); err != nil {
				t.Fatalf("failed to unmarshal response: %v", err)
			}

			if info.IP != tt.ip {
				t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
			}
		})
	}
}

// TestGetClientIP checks that getClientIP returns the first address listed in
// X-Forwarded-For, the value of X-Real-IP, or the host part of RemoteAddr
// (port stripped) when no proxy headers are present.
func TestGetClientIP(t *testing.T) {
	tests := []struct {
		name       string
		headers    map[string]string
		remoteAddr string
		expected   string
	}{
		{
			name: "X-Forwarded-For",
			headers: map[string]string{
				"X-Forwarded-For": "1.2.3.4, 5.6.7.8",
			},
			remoteAddr: "9.10.11.12:1234",
			expected:   "1.2.3.4",
		},
		{
			name: "X-Real-IP",
			headers: map[string]string{
				"X-Real-IP": "1.2.3.4",
			},
			remoteAddr: "9.10.11.12:1234",
			expected:   "1.2.3.4",
		},
		{
			name:       "RemoteAddr only",
			headers:    map[string]string{},
			remoteAddr: "9.10.11.12:1234",
			expected:   "9.10.11.12",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tt.remoteAddr

			for k, v := range tt.headers {
				req.Header.Set(k, v)
			}

			if ip := getClientIP(req); ip != tt.expected {
				t.Errorf("expected IP %s, got %s", tt.expected, ip)
			}
		})
	}
}