package audit_test

import (
	"context"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/fx"

	"sneak.berlin/go/upaas/internal/config"
	"sneak.berlin/go/upaas/internal/database"
	"sneak.berlin/go/upaas/internal/globals"
	"sneak.berlin/go/upaas/internal/logger"
	"sneak.berlin/go/upaas/internal/metrics"
	"sneak.berlin/go/upaas/internal/models"
	"sneak.berlin/go/upaas/internal/service/audit"
)

// setupTestAuditService constructs an audit.Service together with the
// database it writes to. The database is configured with a per-test temporary
// data directory, and the metrics instance is built on a fresh Prometheus
// registry. Any construction error fails the calling test immediately.
func setupTestAuditService(t *testing.T) (*audit.Service, *database.Database) {
	t.Helper()

	globals.SetAppname("upaas-test")
	globals.SetVersion("test")

	// Both constructors are handed a nil fx.Lifecycle.
	var lifecycle fx.Lifecycle

	handler := slog.NewTextHandler(os.Stderr, nil)
	testLogger := logger.NewForTest(slog.New(handler))

	store, err := database.New(lifecycle, database.Params{
		Logger: testLogger,
		Config: &config.Config{DataDir: t.TempDir()},
	})
	require.NoError(t, err)

	auditSvc, err := audit.New(lifecycle, audit.ServiceParams{
		Logger:   testLogger,
		Database: store,
		Metrics:  metrics.NewForTest(prometheus.NewRegistry()),
	})
	require.NoError(t, err)

	return auditSvc, store
}

// TestAuditServiceLog checks that a fully populated entry passed to Log can
// be read back with its username, action and remote IP intact.
func TestAuditServiceLog(t *testing.T) {
	t.Parallel()

	auditSvc, store := setupTestAuditService(t)
	ctx := context.Background()

	login := audit.LogEntry{
		UserID:       1,
		Username:     "admin",
		Action:       models.AuditActionLogin,
		ResourceType: models.AuditResourceSession,
		Detail:       "user logged in",
		RemoteIP:     "127.0.0.1",
	}
	auditSvc.Log(ctx, login)

	got, err := models.FindAuditEntries(ctx, store, 10)
	require.NoError(t, err)
	require.Len(t, got, 1)

	stored := got[0]
	assert.Equal(t, "admin", stored.Username)
	assert.Equal(t, models.AuditActionLogin, stored.Action)
	assert.Equal(t, "127.0.0.1", stored.RemoteIP.String)
}

// TestAuditServiceLogFromRequest checks that LogFromRequest records the host
// part of the request's RemoteAddr as the remote IP and keeps the resource ID.
func TestAuditServiceLogFromRequest(t *testing.T) {
	t.Parallel()

	auditSvc, store := setupTestAuditService(t)
	ctx := context.Background()

	req := httptest.NewRequest(http.MethodPost, "/apps", nil)
	req.RemoteAddr = "10.0.0.1:12345"

	created := audit.LogEntry{
		Username:     "admin",
		Action:       models.AuditActionAppCreate,
		ResourceType: models.AuditResourceApp,
		ResourceID:   "app-1",
		Detail:       "created app",
	}
	auditSvc.LogFromRequest(ctx, req, created)

	got, err := models.FindAuditEntries(ctx, store, 10)
	require.NoError(t, err)
	require.Len(t, got, 1)

	stored := got[0]
	assert.Equal(t, "10.0.0.1", stored.RemoteIP.String)
	assert.Equal(t, "app-1", stored.ResourceID.String)
}

// TestAuditServiceLogFromRequestWithXRealIPTrustedProxy checks that the
// X-Real-IP header value is recorded when the peer address is a private one.
func TestAuditServiceLogFromRequestWithXRealIPTrustedProxy(t *testing.T) {
	t.Parallel()

	auditSvc, store := setupTestAuditService(t)
	ctx := context.Background()

	// The peer is a trusted proxy (RFC1918 address), so X-Real-IP is honoured.
	req := httptest.NewRequest(http.MethodPost, "/apps", nil)
	req.RemoteAddr = "10.0.0.1:1234"
	req.Header.Set("X-Real-IP", "203.0.113.50")

	created := audit.LogEntry{
		Username:     "admin",
		Action:       models.AuditActionAppCreate,
		ResourceType: models.AuditResourceApp,
	}
	auditSvc.LogFromRequest(ctx, req, created)

	got, err := models.FindAuditEntries(ctx, store, 10)
	require.NoError(t, err)
	require.Len(t, got, 1)

	assert.Equal(t, "203.0.113.50", got[0].RemoteIP.String)
}

// TestAuditServiceLogFromRequestWithXRealIPUntrustedProxy checks that the
// X-Real-IP header is disregarded when the peer address is a public one.
func TestAuditServiceLogFromRequestWithXRealIPUntrustedProxy(t *testing.T) {
	t.Parallel()

	auditSvc, store := setupTestAuditService(t)
	ctx := context.Background()

	// The peer is a public IP, so X-Real-IP is ignored (anti-spoof) and the
	// peer address itself is what must be recorded.
	req := httptest.NewRequest(http.MethodPost, "/apps", nil)
	req.RemoteAddr = "203.0.113.99:1234"
	req.Header.Set("X-Real-IP", "10.0.0.1")

	created := audit.LogEntry{
		Username:     "admin",
		Action:       models.AuditActionAppCreate,
		ResourceType: models.AuditResourceApp,
	}
	auditSvc.LogFromRequest(ctx, req, created)

	got, err := models.FindAuditEntries(ctx, store, 10)
	require.NoError(t, err)
	require.Len(t, got, 1)

	assert.Equal(t, "203.0.113.99", got[0].RemoteIP.String)
}

// TestAuditServiceRecent checks that Recent returns no more entries than the
// requested limit when more have been logged.
func TestAuditServiceRecent(t *testing.T) {
	t.Parallel()

	const (
		logged = 5
		limit  = 3
	)

	auditSvc, _ := setupTestAuditService(t)
	ctx := context.Background()

	for i := 0; i < logged; i++ {
		auditSvc.Log(ctx, audit.LogEntry{
			Username:     "admin",
			Action:       models.AuditActionLogin,
			ResourceType: models.AuditResourceSession,
		})
	}

	got, err := auditSvc.Recent(ctx, limit)
	require.NoError(t, err)
	assert.Len(t, got, limit)
}

// TestAuditServiceForResource checks that ForResource returns only the
// entries logged against the requested resource type and ID.
func TestAuditServiceForResource(t *testing.T) {
	t.Parallel()

	auditSvc, _ := setupTestAuditService(t)
	ctx := context.Background()

	// Two entries target app-1 and one targets app-2.
	seeded := []audit.LogEntry{
		{
			Username:     "admin",
			Action:       models.AuditActionAppCreate,
			ResourceType: models.AuditResourceApp,
			ResourceID:   "app-1",
		},
		{
			Username:     "admin",
			Action:       models.AuditActionAppDeploy,
			ResourceType: models.AuditResourceApp,
			ResourceID:   "app-1",
		},
		{
			Username:     "admin",
			Action:       models.AuditActionAppCreate,
			ResourceType: models.AuditResourceApp,
			ResourceID:   "app-2",
		},
	}
	for _, entry := range seeded {
		auditSvc.Log(ctx, entry)
	}

	got, err := auditSvc.ForResource(ctx, models.AuditResourceApp, "app-1", 10)
	require.NoError(t, err)
	assert.Len(t, got, 2)
}

// TestAuditServiceLogWithNoOptionalFields checks that an entry logged without
// user ID, resource ID, detail or remote IP is stored with those columns
// reported as not valid.
func TestAuditServiceLogWithNoOptionalFields(t *testing.T) {
	t.Parallel()

	auditSvc, store := setupTestAuditService(t)
	ctx := context.Background()

	minimal := audit.LogEntry{
		Username:     "system",
		Action:       models.AuditActionWebhookReceive,
		ResourceType: models.AuditResourceWebhook,
	}
	auditSvc.Log(ctx, minimal)

	got, err := models.FindAuditEntries(ctx, store, 10)
	require.NoError(t, err)
	require.Len(t, got, 1)

	stored := got[0]
	assert.False(t, stored.UserID.Valid)
	assert.False(t, stored.ResourceID.Valid)
	assert.False(t, stored.Detail.Valid)
	assert.False(t, stored.RemoteIP.Valid)
}