diff --git a/internal/config/config.go b/internal/config/config.go index 6d8f6f7..b3adafb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -52,6 +52,7 @@ type Config struct { MetricsUsername string MetricsPassword string SessionSecret string + CORSOrigins string params *Params log *slog.Logger } @@ -102,6 +103,7 @@ func setupViper(name string) { viper.SetDefault("METRICS_USERNAME", "") viper.SetDefault("METRICS_PASSWORD", "") viper.SetDefault("SESSION_SECRET", "") + viper.SetDefault("CORS_ORIGINS", "") } func buildConfig(log *slog.Logger, params *Params) (*Config, error) { @@ -136,6 +138,7 @@ func buildConfig(log *slog.Logger, params *Params) (*Config, error) { MetricsUsername: viper.GetString("METRICS_USERNAME"), MetricsPassword: viper.GetString("METRICS_PASSWORD"), SessionSecret: viper.GetString("SESSION_SECRET"), + CORSOrigins: viper.GetString("CORS_ORIGINS"), params: params, log: log, } diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go new file mode 100644 index 0000000..56ecc6b --- /dev/null +++ b/internal/middleware/cors_test.go @@ -0,0 +1,81 @@ +package middleware //nolint:testpackage // tests internal CORS behavior + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "git.eeqj.de/sneak/upaas/internal/config" +) + +//nolint:gosec // test credentials +func newCORSTestMiddleware(corsOrigins string) *Middleware { + return &Middleware{ + log: slog.Default(), + params: &Params{ + Config: &config.Config{ + CORSOrigins: corsOrigins, + SessionSecret: "test-secret-32-bytes-long-enough", + }, + }, + } +} + +func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) { + t.Parallel() + + m := newCORSTestMiddleware("") + handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://evil.com") + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"), + "expected no CORS headers when no origins configured") +} + +func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) { + t.Parallel() + + m := newCORSTestMiddleware("https://app.example.com,https://other.example.com") + handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://app.example.com") + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, "https://app.example.com", + rec.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", + rec.Header().Get("Access-Control-Allow-Credentials")) +} + +func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) { + t.Parallel() + + m := newCORSTestMiddleware("https://app.example.com") + handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://evil.com") + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"), + "expected no CORS headers for non-matching origin") +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 0b8179b..32a121c 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -177,17 +177,48 @@ func realIP(r *http.Request) string { } // CORS returns CORS middleware. +// When UPAAS_CORS_ORIGINS is empty (default), no CORS headers are sent +// (same-origin only). When configured, only the specified origins are +// allowed and credentials (cookies) are permitted. func (m *Middleware) CORS() func(http.Handler) http.Handler { + origins := parseCORSOrigins(m.params.Config.CORSOrigins) + + // No origins configured — no CORS headers (same-origin policy). + if len(origins) == 0 { + return func(next http.Handler) http.Handler { + return next + } + } + return cors.Handler(cors.Options{ - AllowedOrigins: []string{"*"}, + AllowedOrigins: origins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, - AllowCredentials: false, + AllowCredentials: true, MaxAge: corsMaxAge, }) } +// parseCORSOrigins splits a comma-separated origin string into a slice, +// trimming whitespace. Returns nil if the input is empty. +func parseCORSOrigins(raw string) []string { + if raw == "" { + return nil + } + + parts := strings.Split(raw, ",") + origins := make([]string, 0, len(parts)) + + for _, p := range parts { + if o := strings.TrimSpace(p); o != "" { + origins = append(origins, o) + } + } + + return origins +} + // MetricsAuth returns basic auth middleware for metrics endpoint. func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler { if m.params.Config.MetricsUsername == "" { @@ -235,9 +266,9 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler { // loginRateLimit configures the login rate limiter. const ( loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds - loginBurst = 5 // allow burst of 5 - limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes - limiterCleanupEvery = 1 * time.Minute // sweep interval + loginBurst = 5 // allow burst of 5 + limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes + limiterCleanupEvery = 1 * time.Minute // sweep interval ) // ipLimiterEntry stores a rate limiter with its last-seen timestamp. @@ -249,8 +280,8 @@ type ipLimiterEntry struct { // ipLimiter tracks per-IP rate limiters for login attempts with automatic // eviction of stale entries to prevent unbounded memory growth. type ipLimiter struct { - mu sync.Mutex - limiters map[string]*ipLimiterEntry + mu sync.Mutex + limiters map[string]*ipLimiterEntry lastSweep time.Time } diff --git a/internal/models/app.go b/internal/models/app.go index 03f5031..593cf8d 100644 --- a/internal/models/app.go +++ b/internal/models/app.go @@ -32,23 +32,23 @@ const ( type App struct { db *database.Database - ID string - Name string - RepoURL string - Branch string - DockerfilePath string + ID string + Name string + RepoURL string + Branch string + DockerfilePath string WebhookSecret string WebhookSecretHash string SSHPrivateKey string - SSHPublicKey string - ImageID sql.NullString - PreviousImageID sql.NullString - Status AppStatus - DockerNetwork sql.NullString - NtfyTopic sql.NullString - SlackWebhook sql.NullString - CreatedAt time.Time - UpdatedAt time.Time + SSHPublicKey string + ImageID sql.NullString + PreviousImageID sql.NullString + Status AppStatus + DockerNetwork sql.NullString + NtfyTopic sql.NullString + SlackWebhook sql.NullString + CreatedAt time.Time + UpdatedAt time.Time } // NewApp creates a new App with a database reference. diff --git a/internal/server/routes.go b/internal/server/routes.go index 21e4d3d..3339bb8 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -54,51 +54,51 @@ func (s *Server) SetupRoutes() { r.Group(func(r chi.Router) { r.Use(s.mw.SessionAuth()) - // Dashboard - r.Get("/", s.handlers.HandleDashboard()) + // Dashboard + r.Get("/", s.handlers.HandleDashboard()) - // Logout - r.Post("/logout", s.handlers.HandleLogout()) + // Logout + r.Post("/logout", s.handlers.HandleLogout()) - // App routes - r.Get("/apps/new", s.handlers.HandleAppNew()) - r.Post("/apps", s.handlers.HandleAppCreate()) - r.Get("/apps/{id}", s.handlers.HandleAppDetail()) - r.Get("/apps/{id}/edit", s.handlers.HandleAppEdit()) - r.Post("/apps/{id}", s.handlers.HandleAppUpdate()) - r.Post("/apps/{id}/delete", s.handlers.HandleAppDelete()) - r.Post("/apps/{id}/deploy", s.handlers.HandleAppDeploy()) - r.Post("/apps/{id}/deployments/cancel", s.handlers.HandleCancelDeploy()) - r.Get("/apps/{id}/deployments", s.handlers.HandleAppDeployments()) - r.Get("/apps/{id}/deployments/{deploymentID}/logs", s.handlers.HandleDeploymentLogsAPI()) - r.Get("/apps/{id}/deployments/{deploymentID}/download", s.handlers.HandleDeploymentLogDownload()) - r.Get("/apps/{id}/logs", s.handlers.HandleAppLogs()) - r.Get("/apps/{id}/container-logs", s.handlers.HandleContainerLogsAPI()) - r.Get("/apps/{id}/status", s.handlers.HandleAppStatusAPI()) - r.Get("/apps/{id}/recent-deployments", s.handlers.HandleRecentDeploymentsAPI()) - r.Post("/apps/{id}/rollback", s.handlers.HandleAppRollback()) - r.Post("/apps/{id}/restart", s.handlers.HandleAppRestart()) - r.Post("/apps/{id}/stop", s.handlers.HandleAppStop()) - r.Post("/apps/{id}/start", s.handlers.HandleAppStart()) + // App routes + r.Get("/apps/new", s.handlers.HandleAppNew()) + r.Post("/apps", s.handlers.HandleAppCreate()) + r.Get("/apps/{id}", s.handlers.HandleAppDetail()) + r.Get("/apps/{id}/edit", s.handlers.HandleAppEdit()) + r.Post("/apps/{id}", s.handlers.HandleAppUpdate()) + r.Post("/apps/{id}/delete", s.handlers.HandleAppDelete()) + r.Post("/apps/{id}/deploy", s.handlers.HandleAppDeploy()) + r.Post("/apps/{id}/deployments/cancel", s.handlers.HandleCancelDeploy()) + r.Get("/apps/{id}/deployments", s.handlers.HandleAppDeployments()) + r.Get("/apps/{id}/deployments/{deploymentID}/logs", s.handlers.HandleDeploymentLogsAPI()) + r.Get("/apps/{id}/deployments/{deploymentID}/download", s.handlers.HandleDeploymentLogDownload()) + r.Get("/apps/{id}/logs", s.handlers.HandleAppLogs()) + r.Get("/apps/{id}/container-logs", s.handlers.HandleContainerLogsAPI()) + r.Get("/apps/{id}/status", s.handlers.HandleAppStatusAPI()) + r.Get("/apps/{id}/recent-deployments", s.handlers.HandleRecentDeploymentsAPI()) + r.Post("/apps/{id}/rollback", s.handlers.HandleAppRollback()) + r.Post("/apps/{id}/restart", s.handlers.HandleAppRestart()) + r.Post("/apps/{id}/stop", s.handlers.HandleAppStop()) + r.Post("/apps/{id}/start", s.handlers.HandleAppStart()) - // Environment variables - r.Post("/apps/{id}/env-vars", s.handlers.HandleEnvVarAdd()) - r.Post("/apps/{id}/env-vars/{varID}/edit", s.handlers.HandleEnvVarEdit()) - r.Post("/apps/{id}/env-vars/{varID}/delete", s.handlers.HandleEnvVarDelete()) + // Environment variables + r.Post("/apps/{id}/env-vars", s.handlers.HandleEnvVarAdd()) + r.Post("/apps/{id}/env-vars/{varID}/edit", s.handlers.HandleEnvVarEdit()) + r.Post("/apps/{id}/env-vars/{varID}/delete", s.handlers.HandleEnvVarDelete()) - // Labels - r.Post("/apps/{id}/labels", s.handlers.HandleLabelAdd()) - r.Post("/apps/{id}/labels/{labelID}/edit", s.handlers.HandleLabelEdit()) - r.Post("/apps/{id}/labels/{labelID}/delete", s.handlers.HandleLabelDelete()) + // Labels + r.Post("/apps/{id}/labels", s.handlers.HandleLabelAdd()) + r.Post("/apps/{id}/labels/{labelID}/edit", s.handlers.HandleLabelEdit()) + r.Post("/apps/{id}/labels/{labelID}/delete", s.handlers.HandleLabelDelete()) - // Volumes - r.Post("/apps/{id}/volumes", s.handlers.HandleVolumeAdd()) - r.Post("/apps/{id}/volumes/{volumeID}/edit", s.handlers.HandleVolumeEdit()) - r.Post("/apps/{id}/volumes/{volumeID}/delete", s.handlers.HandleVolumeDelete()) + // Volumes + r.Post("/apps/{id}/volumes", s.handlers.HandleVolumeAdd()) + r.Post("/apps/{id}/volumes/{volumeID}/edit", s.handlers.HandleVolumeEdit()) + r.Post("/apps/{id}/volumes/{volumeID}/delete", s.handlers.HandleVolumeDelete()) - // Ports - r.Post("/apps/{id}/ports", s.handlers.HandlePortAdd()) - r.Post("/apps/{id}/ports/{portID}/delete", s.handlers.HandlePortDelete()) + // Ports + r.Post("/apps/{id}/ports", s.handlers.HandlePortAdd()) + r.Post("/apps/{id}/ports/{portID}/delete", s.handlers.HandlePortDelete()) }) })