ipapi/internal/http/router_test.go
sneak 08ca75966e Fix dependency injection and implement proper CIDR checking
- Add http module for proper FX dependency injection
- Fix router to accept state manager parameter
- Implement proper CIDR-based checking for RFC1918 and documentation IPs
- Add reasonable timeouts (30s) for database downloads
- Update tests to download databases to temporary directories
- Add tests for multiple IP lookups and error cases
- All tests passing
2025-07-27 18:44:53 +02:00

268 lines
6.6 KiB
Go

package http
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.eeqj.de/sneak/ipapi/internal/config"
"git.eeqj.de/sneak/ipapi/internal/database"
"git.eeqj.de/sneak/ipapi/internal/state"
"github.com/go-chi/chi/v5"
)
// TestNewRouter verifies that NewRouter builds a non-nil router from a state
// manager and database constructed against a fresh temporary state directory.
// Setup failures abort the test immediately so they are not misreported as
// router construction failures.
func TestNewRouter(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	router, err := NewRouter(logger, db, stateManager)
	if err != nil {
		t.Fatalf("failed to create router: %v", err)
	}
	if router == nil {
		t.Fatal("expected router, got nil")
	}
}
// TestHealthEndpoint routes a GET /health request through the full router and
// expects a 200 response whose body is exactly "OK".
// Setup errors, including a failed NewRouter, are fatal. Otherwise a nil router
// would panic in ServeHTTP and hide the real cause.
func TestHealthEndpoint(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	router, err := NewRouter(logger, db, stateManager)
	if err != nil {
		t.Fatalf("failed to create router: %v", err)
	}

	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}
	// %q makes an empty or whitespace-padded body obvious in the failure output.
	if body := rec.Body.String(); body != "OK" {
		t.Errorf("expected body 'OK', got %q", body)
	}
}
// TestIPLookupEndpoint calls handleIPLookup directly with a single entry in
// the "ips" URL parameter. It first downloads the databases into a temporary
// state directory and skips the test if that download fails (for example,
// when offline).
// Each case expects a 200 "success" response containing exactly one result,
// keyed by the requested string, whose IP field echoes that string.
func TestIPLookupEndpoint(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}
	if err := stateManager.Initialize(context.Background()); err != nil {
		t.Fatalf("failed to initialize state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	// Download databases to temp directory for testing. A failed download
	// skips rather than fails, so the suite still runs without network access.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	t.Log("Downloading test databases...")
	if err := db.EnsureDatabases(ctx); err != nil {
		t.Skipf("Skipping test - could not download databases: %v", err)
	}

	tests := []struct {
		name         string
		ip           string
		expectedCode int
	}{
		{"valid IPv4", "8.8.8.8", http.StatusOK},
		{"valid IPv6", "2001:4860:4860::8888", http.StatusOK},
		// A malformed entry is expected to produce a per-IP result inside a
		// 200 response, not a request-level error.
		{"invalid IP", "invalid", http.StatusOK},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/api/v1/ipinfo/local/"+tt.ip, nil)
			rec := httptest.NewRecorder()

			// The handler is invoked without the router, so the chi route
			// context that carries the "ips" URL param must be injected by hand.
			rctx := chi.NewRouteContext()
			rctx.URLParams.Add("ips", tt.ip)
			req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))

			handleIPLookup(logger, db, stateManager)(rec, req)

			if rec.Code != tt.expectedCode {
				t.Errorf("expected status %d, got %d", tt.expectedCode, rec.Code)
			}
			if tt.expectedCode != http.StatusOK {
				return
			}

			var response SuccessResponse
			if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
				// Asserting on a zero-value response would only add noise.
				t.Fatalf("failed to unmarshal response: %v", err)
			}
			if response.Status != "success" {
				t.Errorf("expected status 'success', got %s", response.Status)
			}
			if len(response.Result) != 1 {
				t.Errorf("expected 1 result, got %d", len(response.Result))
			}

			info, ok := response.Result[tt.ip]
			if !ok {
				t.Fatalf("expected result for IP %s", tt.ip)
			}
			if info.IP != tt.ip {
				t.Errorf("expected IP %s, got %s", tt.ip, info.IP)
			}
		})
	}
}
// TestIPLookupMultipleIPs sends a comma-separated list of three entries (two
// IPv4 addresses and one malformed value) to handleIPLookup. It expects a
// 200 "success" response containing a result for every requested entry.
// The test is skipped if the databases cannot be downloaded.
func TestIPLookupMultipleIPs(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}
	if err := stateManager.Initialize(context.Background()); err != nil {
		t.Fatalf("failed to initialize state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	// Download databases to temp directory for testing. A failed download
	// skips rather than fails, so the suite still runs without network access.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	t.Log("Downloading test databases...")
	if err := db.EnsureDatabases(ctx); err != nil {
		t.Skipf("Skipping test - could not download databases: %v", err)
	}

	// The raw parameter and the expected keys are kept side by side so the
	// request and the assertions cannot drift apart.
	const ipList = "8.8.8.8,1.1.1.1,invalid"
	expectedIPs := []string{"8.8.8.8", "1.1.1.1", "invalid"}

	req := httptest.NewRequest(http.MethodGet, "/api/v1/ipinfo/local/"+ipList, nil)
	rec := httptest.NewRecorder()

	// The handler is invoked without the router, so the chi route context that
	// carries the "ips" URL param must be injected by hand.
	rctx := chi.NewRouteContext()
	rctx.URLParams.Add("ips", ipList)
	req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))

	handleIPLookup(logger, db, stateManager)(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", rec.Code)
	}

	var response SuccessResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
		// Asserting on a zero-value response would only add noise.
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if response.Status != "success" {
		t.Errorf("expected status 'success', got %s", response.Status)
	}
	if len(response.Result) != len(expectedIPs) {
		t.Errorf("expected %d results, got %d", len(expectedIPs), len(response.Result))
	}

	// Check that all requested IPs are in the result.
	for _, ip := range expectedIPs {
		if _, ok := response.Result[ip]; !ok {
			t.Errorf("expected result for IP %s", ip)
		}
	}
}
// TestIPLookupTooManyIPs checks that handleIPLookup rejects an "ips" parameter
// listing more addresses than the handler allows. It expects a 400 response
// whose ErrorResponse status is "error".
// No databases are downloaded here, so this test assumes the limit is
// enforced before any lookup is attempted.
func TestIPLookupTooManyIPs(t *testing.T) {
	logger := slog.Default()
	cfg := &config.Config{
		StateDir: t.TempDir(),
	}

	stateManager, err := state.New(cfg, logger)
	if err != nil {
		t.Fatalf("failed to create state manager: %v", err)
	}

	db, err := database.New(cfg, logger, stateManager)
	if err != nil {
		t.Fatalf("failed to create database: %v", err)
	}

	// ipCount is intended to be one more than the handler's per-request limit.
	const ipCount = 51
	ips := "1.1.1.1"
	for i := 1; i < ipCount; i++ {
		ips += ",1.1.1.1"
	}

	req := httptest.NewRequest(http.MethodGet, "/api/v1/ipinfo/local/"+ips, nil)
	rec := httptest.NewRecorder()

	// The handler is invoked without the router, so the chi route context that
	// carries the "ips" URL param must be injected by hand.
	rctx := chi.NewRouteContext()
	rctx.URLParams.Add("ips", ips)
	req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))

	handleIPLookup(logger, db, stateManager)(rec, req)

	if rec.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", rec.Code)
	}

	var response ErrorResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &response); err != nil {
		// Asserting on a zero-value response would only add noise.
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if response.Status != "error" {
		t.Errorf("expected status 'error', got %s", response.Status)
	}
}
// TestGetClientIP covers three ways getClientIP can derive the caller address:
//   - from the first entry of X-Forwarded-For,
//   - from X-Real-IP,
//   - from the host part of RemoteAddr when neither header is present.
func TestGetClientIP(t *testing.T) {
	type clientIPCase struct {
		name       string
		headers    map[string]string
		remoteAddr string
		expected   string
	}

	// Every case shares the same socket address; only the headers vary.
	const peerAddr = "9.10.11.12:1234"

	cases := []clientIPCase{
		{"X-Forwarded-For", map[string]string{"X-Forwarded-For": "1.2.3.4, 5.6.7.8"}, peerAddr, "1.2.3.4"},
		{"X-Real-IP", map[string]string{"X-Real-IP": "1.2.3.4"}, peerAddr, "1.2.3.4"},
		{"RemoteAddr only", map[string]string{}, peerAddr, "9.10.11.12"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			r := httptest.NewRequest("GET", "/", nil)
			r.RemoteAddr = tc.remoteAddr
			for header, value := range tc.headers {
				r.Header.Set(header, value)
			}

			if got := getClientIP(r); got != tc.expected {
				t.Errorf("expected IP %s, got %s", tc.expected, got)
			}
		})
	}
}