9 Commits

Author SHA1 Message Date
08377058c2 Merge branch 'main' into fix/repo-url-validation 2026-02-20 05:24:34 +01:00
06e8e66443 Merge pull request 'fix: clean up orphan resources on deploy cancellation (closes #89)' (#93) from fix/deploy-cancel-cleanup into main
Reviewed-on: #93
2026-02-20 05:22:58 +01:00
clawbot
95a690e805 fix: use strings.HasPrefix instead of manual slice comparison
- Replace entry.Name()[:len(prefix)] == prefix with strings.HasPrefix
- Applied consistently in both deploy.go and export_test.go
2026-02-19 20:17:27 -08:00
clawbot
02f0a12626 fix: restrict SCP-like URLs to git user only and reject path traversal
- Changed SCP regex to only accept 'git' as the username
- Added path traversal check: reject URLs containing '..'
- Added test cases for non-git users and path traversal
2026-02-19 20:17:25 -08:00
clawbot
802518b917 fix: clean up orphan resources on deploy cancellation (closes #89) 2026-02-19 20:15:22 -08:00
clawbot
9f2d62da05 fix: validate repo URL format on app creation (closes #88) 2026-02-19 20:15:19 -08:00
b47f871412 Merge pull request 'fix: restrict CORS to configured origins (closes #40)' (#92) from fix/cors-wildcard into main
Reviewed-on: #92
2026-02-20 05:11:33 +01:00
clawbot
02847eea92 fix: restrict CORS to configured origins (closes #40)
- Add CORSOrigins config field (UPAAS_CORS_ORIGINS env var)
- Default to same-origin only (no CORS headers when unconfigured)
- When configured, allow specified origins with AllowCredentials: true
- Add tests for CORS middleware behavior
2026-02-19 13:45:18 -08:00
clawbot
506c795f16 test: add CORS middleware tests (failing - TDD) 2026-02-19 13:43:33 -08:00
13 changed files with 511 additions and 65 deletions

View File

@@ -52,6 +52,7 @@ type Config struct {
MetricsUsername string
MetricsPassword string
SessionSecret string
CORSOrigins string
params *Params
log *slog.Logger
}
@@ -102,6 +103,7 @@ func setupViper(name string) {
viper.SetDefault("METRICS_USERNAME", "")
viper.SetDefault("METRICS_PASSWORD", "")
viper.SetDefault("SESSION_SECRET", "")
viper.SetDefault("CORS_ORIGINS", "")
}
func buildConfig(log *slog.Logger, params *Params) (*Config, error) {
@@ -136,6 +138,7 @@ func buildConfig(log *slog.Logger, params *Params) (*Config, error) {
MetricsUsername: viper.GetString("METRICS_USERNAME"),
MetricsPassword: viper.GetString("METRICS_PASSWORD"),
SessionSecret: viper.GetString("SESSION_SECRET"),
CORSOrigins: viper.GetString("CORS_ORIGINS"),
params: params,
log: log,
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
@@ -739,6 +740,20 @@ func (c *Client) connect(ctx context.Context) error {
return nil
}
// RemoveImage removes a Docker image by ID or tag.
// It returns nil if the image was successfully removed or does not exist.
func (c *Client) RemoveImage(ctx context.Context, imageID string) error {
_, err := c.docker.ImageRemove(ctx, imageID, image.RemoveOptions{
Force: true,
PruneChildren: true,
})
if err != nil && !client.IsErrNotFound(err) {
return fmt.Errorf("failed to remove image %s: %w", imageID, err)
}
return nil
}
func (c *Client) close() error {
if c.docker != nil {
err := c.docker.Close()

View File

@@ -218,6 +218,15 @@ func (h *Handlers) HandleAPICreateApp() http.HandlerFunc {
return
}
repoURLErr := validateRepoURL(req.RepoURL)
if repoURLErr != nil {
h.respondJSON(writer, request,
map[string]string{"error": "invalid repository URL: " + repoURLErr.Error()},
http.StatusBadRequest)
return
}
createdApp, createErr := h.appService.CreateApp(request.Context(), app.CreateAppInput{
Name: req.Name,
RepoURL: req.RepoURL,

View File

@@ -77,6 +77,14 @@ func (h *Handlers) HandleAppCreate() http.HandlerFunc { //nolint:funlen // valid
return
}
repoURLErr := validateRepoURL(repoURL)
if repoURLErr != nil {
data["Error"] = "Invalid repository URL: " + repoURLErr.Error()
h.renderTemplate(writer, tmpl, "app_new.html", data)
return
}
if branch == "" {
branch = "main"
}
@@ -225,6 +233,17 @@ func (h *Handlers) HandleAppUpdate() http.HandlerFunc { //nolint:funlen // valid
return
}
repoURLErr := validateRepoURL(request.FormValue("repo_url"))
if repoURLErr != nil {
data := h.addGlobals(map[string]any{
"App": application,
"Error": "Invalid repository URL: " + repoURLErr.Error(),
}, request)
_ = tmpl.ExecuteTemplate(writer, "app_edit.html", data)
return
}
application.Name = newName
application.RepoURL = request.FormValue("repo_url")
application.Branch = request.FormValue("branch")

View File

@@ -0,0 +1,67 @@
package handlers
import (
"errors"
"net/url"
"regexp"
"strings"
)
// Repo URL validation errors.
var (
errRepoURLEmpty = errors.New("repository URL must not be empty")
errRepoURLScheme = errors.New("file:// URLs are not allowed for security reasons")
errRepoURLInvalid = errors.New("repository URL must use https://, http://, ssh://, git://, or git@host:path format")
errRepoURLNoHost = errors.New("repository URL must include a host")
errRepoURLNoPath = errors.New("repository URL must include a path")
)
// scpLikeRepoRe matches SCP-like git URLs: git@host:path (e.g. git@github.com:user/repo.git).
// Only the "git" user is allowed, as that is the standard for SSH deploy keys.
var scpLikeRepoRe = regexp.MustCompile(`^git@[a-zA-Z0-9._-]+:.+$`)
// validateRepoURL checks that the given repository URL is valid and uses an allowed scheme.
func validateRepoURL(repoURL string) error {
if strings.TrimSpace(repoURL) == "" {
return errRepoURLEmpty
}
// Reject path traversal in any URL format
if strings.Contains(repoURL, "..") {
return errRepoURLInvalid
}
// Check for SCP-like git URLs first (git@host:path)
if scpLikeRepoRe.MatchString(repoURL) {
return nil
}
// Reject file:// explicitly
if strings.HasPrefix(strings.ToLower(repoURL), "file://") {
return errRepoURLScheme
}
// Parse as standard URL
parsed, err := url.Parse(repoURL)
if err != nil {
return errRepoURLInvalid
}
// Must have a recognized scheme
switch strings.ToLower(parsed.Scheme) {
case "https", "http", "ssh", "git":
// OK
default:
return errRepoURLInvalid
}
if parsed.Host == "" {
return errRepoURLNoHost
}
if parsed.Path == "" || parsed.Path == "/" {
return errRepoURLNoPath
}
return nil
}

View File

@@ -0,0 +1,56 @@
package handlers
import "testing"
func TestValidateRepoURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
wantErr bool
}{
// Valid URLs
{name: "https URL", url: "https://github.com/user/repo.git", wantErr: false},
{name: "http URL", url: "http://github.com/user/repo.git", wantErr: false},
{name: "ssh URL", url: "ssh://git@github.com/user/repo.git", wantErr: false},
{name: "git URL", url: "git://github.com/user/repo.git", wantErr: false},
{name: "SCP-like URL", url: "git@github.com:user/repo.git", wantErr: false},
{name: "SCP-like with dots", url: "git@git.example.com:org/repo.git", wantErr: false},
{name: "https without .git", url: "https://github.com/user/repo", wantErr: false},
{name: "https with port", url: "https://git.example.com:8443/user/repo.git", wantErr: false},
// Invalid URLs
{name: "empty string", url: "", wantErr: true},
{name: "whitespace only", url: " ", wantErr: true},
{name: "file URL", url: "file:///etc/passwd", wantErr: true},
{name: "file URL uppercase", url: "FILE:///etc/passwd", wantErr: true},
{name: "bare path", url: "/some/local/path", wantErr: true},
{name: "relative path", url: "../repo", wantErr: true},
{name: "just a word", url: "notaurl", wantErr: true},
{name: "ftp URL", url: "ftp://example.com/repo.git", wantErr: true},
{name: "no host https", url: "https:///path", wantErr: true},
{name: "no path https", url: "https://github.com", wantErr: true},
{name: "no path https trailing slash", url: "https://github.com/", wantErr: true},
{name: "SCP-like non-git user", url: "root@github.com:user/repo.git", wantErr: true},
{name: "SCP-like arbitrary user", url: "admin@github.com:user/repo.git", wantErr: true},
{name: "path traversal SCP", url: "git@github.com:../../etc/passwd", wantErr: true},
{name: "path traversal https", url: "https://github.com/user/../../../etc/passwd", wantErr: true},
{name: "path traversal in middle", url: "https://github.com/user/repo/../secret", wantErr: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := validateRepoURL(tc.url)
if tc.wantErr && err == nil {
t.Errorf("validateRepoURL(%q) = nil, want error", tc.url)
}
if !tc.wantErr && err != nil {
t.Errorf("validateRepoURL(%q) = %v, want nil", tc.url, err)
}
})
}
}

View File

@@ -0,0 +1,81 @@
package middleware //nolint:testpackage // tests internal CORS behavior
import (
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"git.eeqj.de/sneak/upaas/internal/config"
)
//nolint:gosec // test credentials
func newCORSTestMiddleware(corsOrigins string) *Middleware {
return &Middleware{
log: slog.Default(),
params: &Params{
Config: &config.Config{
CORSOrigins: corsOrigins,
SessionSecret: "test-secret-32-bytes-long-enough",
},
},
}
}
func TestCORS_NoOriginsConfigured_NoCORSHeaders(t *testing.T) {
t.Parallel()
m := newCORSTestMiddleware("")
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "https://evil.com")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"),
"expected no CORS headers when no origins configured")
}
func TestCORS_OriginsConfigured_AllowsMatchingOrigin(t *testing.T) {
t.Parallel()
m := newCORSTestMiddleware("https://app.example.com,https://other.example.com")
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "https://app.example.com")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, "https://app.example.com",
rec.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true",
rec.Header().Get("Access-Control-Allow-Credentials"))
}
func TestCORS_OriginsConfigured_RejectsNonMatchingOrigin(t *testing.T) {
t.Parallel()
m := newCORSTestMiddleware("https://app.example.com")
handler := m.CORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Origin", "https://evil.com")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"),
"expected no CORS headers for non-matching origin")
}

View File

@@ -177,17 +177,48 @@ func realIP(r *http.Request) string {
}
// CORS returns CORS middleware.
// When UPAAS_CORS_ORIGINS is empty (default), no CORS headers are sent
// (same-origin only). When configured, only the specified origins are
// allowed and credentials (cookies) are permitted.
func (m *Middleware) CORS() func(http.Handler) http.Handler {
origins := parseCORSOrigins(m.params.Config.CORSOrigins)
// No origins configured — no CORS headers (same-origin policy).
if len(origins) == 0 {
return func(next http.Handler) http.Handler {
return next
}
}
return cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedOrigins: origins,
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: false,
AllowCredentials: true,
MaxAge: corsMaxAge,
})
}
// parseCORSOrigins splits a comma-separated origin string into a slice,
// trimming whitespace. Returns nil if the input is empty.
func parseCORSOrigins(raw string) []string {
if raw == "" {
return nil
}
parts := strings.Split(raw, ",")
origins := make([]string, 0, len(parts))
for _, p := range parts {
if o := strings.TrimSpace(p); o != "" {
origins = append(origins, o)
}
}
return origins
}
// MetricsAuth returns basic auth middleware for metrics endpoint.
func (m *Middleware) MetricsAuth() func(http.Handler) http.Handler {
if m.params.Config.MetricsUsername == "" {
@@ -235,9 +266,9 @@ func (m *Middleware) CSRF() func(http.Handler) http.Handler {
// loginRateLimit configures the login rate limiter.
const (
loginRateLimit = rate.Limit(5.0 / 60.0) // 5 requests per 60 seconds
loginBurst = 5 // allow burst of 5
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
limiterCleanupEvery = 1 * time.Minute // sweep interval
loginBurst = 5 // allow burst of 5
limiterExpiry = 10 * time.Minute // evict entries not seen in 10 minutes
limiterCleanupEvery = 1 * time.Minute // sweep interval
)
// ipLimiterEntry stores a rate limiter with its last-seen timestamp.
@@ -249,8 +280,8 @@ type ipLimiterEntry struct {
// ipLimiter tracks per-IP rate limiters for login attempts with automatic
// eviction of stale entries to prevent unbounded memory growth.
type ipLimiter struct {
mu sync.Mutex
limiters map[string]*ipLimiterEntry
mu sync.Mutex
limiters map[string]*ipLimiterEntry
lastSweep time.Time
}

View File

@@ -32,23 +32,23 @@ const (
type App struct {
db *database.Database
ID string
Name string
RepoURL string
Branch string
DockerfilePath string
ID string
Name string
RepoURL string
Branch string
DockerfilePath string
WebhookSecret string
WebhookSecretHash string
SSHPrivateKey string
SSHPublicKey string
ImageID sql.NullString
PreviousImageID sql.NullString
Status AppStatus
DockerNetwork sql.NullString
NtfyTopic sql.NullString
SlackWebhook sql.NullString
CreatedAt time.Time
UpdatedAt time.Time
SSHPublicKey string
ImageID sql.NullString
PreviousImageID sql.NullString
Status AppStatus
DockerNetwork sql.NullString
NtfyTopic sql.NullString
SlackWebhook sql.NullString
CreatedAt time.Time
UpdatedAt time.Time
}
// NewApp creates a new App with a database reference.

View File

@@ -54,51 +54,51 @@ func (s *Server) SetupRoutes() {
r.Group(func(r chi.Router) {
r.Use(s.mw.SessionAuth())
// Dashboard
r.Get("/", s.handlers.HandleDashboard())
// Dashboard
r.Get("/", s.handlers.HandleDashboard())
// Logout
r.Post("/logout", s.handlers.HandleLogout())
// Logout
r.Post("/logout", s.handlers.HandleLogout())
// App routes
r.Get("/apps/new", s.handlers.HandleAppNew())
r.Post("/apps", s.handlers.HandleAppCreate())
r.Get("/apps/{id}", s.handlers.HandleAppDetail())
r.Get("/apps/{id}/edit", s.handlers.HandleAppEdit())
r.Post("/apps/{id}", s.handlers.HandleAppUpdate())
r.Post("/apps/{id}/delete", s.handlers.HandleAppDelete())
r.Post("/apps/{id}/deploy", s.handlers.HandleAppDeploy())
r.Post("/apps/{id}/deployments/cancel", s.handlers.HandleCancelDeploy())
r.Get("/apps/{id}/deployments", s.handlers.HandleAppDeployments())
r.Get("/apps/{id}/deployments/{deploymentID}/logs", s.handlers.HandleDeploymentLogsAPI())
r.Get("/apps/{id}/deployments/{deploymentID}/download", s.handlers.HandleDeploymentLogDownload())
r.Get("/apps/{id}/logs", s.handlers.HandleAppLogs())
r.Get("/apps/{id}/container-logs", s.handlers.HandleContainerLogsAPI())
r.Get("/apps/{id}/status", s.handlers.HandleAppStatusAPI())
r.Get("/apps/{id}/recent-deployments", s.handlers.HandleRecentDeploymentsAPI())
r.Post("/apps/{id}/rollback", s.handlers.HandleAppRollback())
r.Post("/apps/{id}/restart", s.handlers.HandleAppRestart())
r.Post("/apps/{id}/stop", s.handlers.HandleAppStop())
r.Post("/apps/{id}/start", s.handlers.HandleAppStart())
// App routes
r.Get("/apps/new", s.handlers.HandleAppNew())
r.Post("/apps", s.handlers.HandleAppCreate())
r.Get("/apps/{id}", s.handlers.HandleAppDetail())
r.Get("/apps/{id}/edit", s.handlers.HandleAppEdit())
r.Post("/apps/{id}", s.handlers.HandleAppUpdate())
r.Post("/apps/{id}/delete", s.handlers.HandleAppDelete())
r.Post("/apps/{id}/deploy", s.handlers.HandleAppDeploy())
r.Post("/apps/{id}/deployments/cancel", s.handlers.HandleCancelDeploy())
r.Get("/apps/{id}/deployments", s.handlers.HandleAppDeployments())
r.Get("/apps/{id}/deployments/{deploymentID}/logs", s.handlers.HandleDeploymentLogsAPI())
r.Get("/apps/{id}/deployments/{deploymentID}/download", s.handlers.HandleDeploymentLogDownload())
r.Get("/apps/{id}/logs", s.handlers.HandleAppLogs())
r.Get("/apps/{id}/container-logs", s.handlers.HandleContainerLogsAPI())
r.Get("/apps/{id}/status", s.handlers.HandleAppStatusAPI())
r.Get("/apps/{id}/recent-deployments", s.handlers.HandleRecentDeploymentsAPI())
r.Post("/apps/{id}/rollback", s.handlers.HandleAppRollback())
r.Post("/apps/{id}/restart", s.handlers.HandleAppRestart())
r.Post("/apps/{id}/stop", s.handlers.HandleAppStop())
r.Post("/apps/{id}/start", s.handlers.HandleAppStart())
// Environment variables
r.Post("/apps/{id}/env-vars", s.handlers.HandleEnvVarAdd())
r.Post("/apps/{id}/env-vars/{varID}/edit", s.handlers.HandleEnvVarEdit())
r.Post("/apps/{id}/env-vars/{varID}/delete", s.handlers.HandleEnvVarDelete())
// Environment variables
r.Post("/apps/{id}/env-vars", s.handlers.HandleEnvVarAdd())
r.Post("/apps/{id}/env-vars/{varID}/edit", s.handlers.HandleEnvVarEdit())
r.Post("/apps/{id}/env-vars/{varID}/delete", s.handlers.HandleEnvVarDelete())
// Labels
r.Post("/apps/{id}/labels", s.handlers.HandleLabelAdd())
r.Post("/apps/{id}/labels/{labelID}/edit", s.handlers.HandleLabelEdit())
r.Post("/apps/{id}/labels/{labelID}/delete", s.handlers.HandleLabelDelete())
// Labels
r.Post("/apps/{id}/labels", s.handlers.HandleLabelAdd())
r.Post("/apps/{id}/labels/{labelID}/edit", s.handlers.HandleLabelEdit())
r.Post("/apps/{id}/labels/{labelID}/delete", s.handlers.HandleLabelDelete())
// Volumes
r.Post("/apps/{id}/volumes", s.handlers.HandleVolumeAdd())
r.Post("/apps/{id}/volumes/{volumeID}/edit", s.handlers.HandleVolumeEdit())
r.Post("/apps/{id}/volumes/{volumeID}/delete", s.handlers.HandleVolumeDelete())
// Volumes
r.Post("/apps/{id}/volumes", s.handlers.HandleVolumeAdd())
r.Post("/apps/{id}/volumes/{volumeID}/edit", s.handlers.HandleVolumeEdit())
r.Post("/apps/{id}/volumes/{volumeID}/delete", s.handlers.HandleVolumeDelete())
// Ports
r.Post("/apps/{id}/ports", s.handlers.HandlePortAdd())
r.Post("/apps/{id}/ports/{portID}/delete", s.handlers.HandlePortDelete())
// Ports
r.Post("/apps/{id}/ports", s.handlers.HandlePortAdd())
r.Post("/apps/{id}/ports/{portID}/delete", s.handlers.HandlePortDelete())
})
})

View File

@@ -11,6 +11,7 @@ import (
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -82,7 +83,7 @@ type deploymentLogWriter struct {
lineBuffer bytes.Buffer // buffer for incomplete lines
mu sync.Mutex
done chan struct{}
flushed sync.WaitGroup // waits for flush goroutine to finish
flushed sync.WaitGroup // waits for flush goroutine to finish
flushCtx context.Context //nolint:containedctx // needed for async flush goroutine
}
@@ -472,7 +473,7 @@ func (svc *Service) runBuildAndDeploy(
// Build phase with timeout
imageID, err := svc.buildImageWithTimeout(deployCtx, app, deployment)
if err != nil {
cancelErr := svc.checkCancelled(deployCtx, bgCtx, app, deployment)
cancelErr := svc.checkCancelled(deployCtx, bgCtx, app, deployment, "")
if cancelErr != nil {
return cancelErr
}
@@ -485,7 +486,7 @@ func (svc *Service) runBuildAndDeploy(
// Deploy phase with timeout
err = svc.deployContainerWithTimeout(deployCtx, app, deployment, imageID)
if err != nil {
cancelErr := svc.checkCancelled(deployCtx, bgCtx, app, deployment)
cancelErr := svc.checkCancelled(deployCtx, bgCtx, app, deployment, imageID)
if cancelErr != nil {
return cancelErr
}
@@ -661,24 +662,76 @@ func (svc *Service) cancelActiveDeploy(appID string) {
}
// checkCancelled checks if the deploy context was cancelled (by a newer deploy)
// and if so, marks the deployment as cancelled. Returns ErrDeployCancelled or nil.
// and if so, marks the deployment as cancelled and cleans up orphan resources.
// Returns ErrDeployCancelled or nil.
func (svc *Service) checkCancelled(
deployCtx context.Context,
bgCtx context.Context,
app *models.App,
deployment *models.Deployment,
imageID string,
) error {
if !errors.Is(deployCtx.Err(), context.Canceled) {
return nil
}
svc.log.Info("deployment cancelled by newer deploy", "app", app.Name)
svc.log.Info("deployment cancelled", "app", app.Name)
svc.cleanupCancelledDeploy(bgCtx, app, deployment, imageID)
_ = deployment.MarkFinished(bgCtx, models.DeploymentStatusCancelled)
return ErrDeployCancelled
}
// cleanupCancelledDeploy removes orphan resources left by a cancelled deployment.
func (svc *Service) cleanupCancelledDeploy(
ctx context.Context,
app *models.App,
deployment *models.Deployment,
imageID string,
) {
// Clean up the intermediate Docker image if one was built
if imageID != "" {
removeErr := svc.docker.RemoveImage(ctx, imageID)
if removeErr != nil {
svc.log.Error("failed to remove image from cancelled deploy",
"error", removeErr, "app", app.Name, "image", imageID)
_ = deployment.AppendLog(ctx, "WARNING: failed to clean up image "+imageID+": "+removeErr.Error())
} else {
svc.log.Info("cleaned up image from cancelled deploy",
"app", app.Name, "image", imageID)
_ = deployment.AppendLog(ctx, "Cleaned up intermediate image: "+imageID)
}
}
// Clean up the build directory for this deployment
buildDir := svc.GetBuildDir(app.Name)
entries, err := os.ReadDir(buildDir)
if err != nil {
return
}
prefix := fmt.Sprintf("%d-", deployment.ID)
for _, entry := range entries {
if entry.IsDir() && strings.HasPrefix(entry.Name(), prefix) {
dirPath := filepath.Join(buildDir, entry.Name())
removeErr := os.RemoveAll(dirPath)
if removeErr != nil {
svc.log.Error("failed to remove build dir from cancelled deploy",
"error", removeErr, "path", dirPath)
} else {
svc.log.Info("cleaned up build dir from cancelled deploy",
"app", app.Name, "path", dirPath)
_ = deployment.AppendLog(ctx, "Cleaned up build directory")
}
}
}
}
func (svc *Service) fetchWebhookEvent(
ctx context.Context,
webhookEventID *int64,

View File

@@ -0,0 +1,63 @@
package deploy_test
import (
"context"
"log/slog"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.eeqj.de/sneak/upaas/internal/config"
"git.eeqj.de/sneak/upaas/internal/service/deploy"
)
func TestCleanupCancelledDeploy_RemovesBuildDir(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
cfg := &config.Config{DataDir: tmpDir}
svc := deploy.NewTestServiceWithConfig(slog.Default(), cfg, nil)
// Create a fake build directory matching the deployment pattern
appName := "test-app"
buildDir := svc.GetBuildDirExported(appName)
require.NoError(t, os.MkdirAll(buildDir, 0o750))
// Create deployment-specific dir: <deploymentID>-<random>
deployDir := filepath.Join(buildDir, "42-abc123")
require.NoError(t, os.MkdirAll(deployDir, 0o750))
// Create a file inside to verify full removal
require.NoError(t, os.WriteFile(filepath.Join(deployDir, "work"), []byte("test"), 0o640))
// Also create a dir for a different deployment (should NOT be removed)
otherDir := filepath.Join(buildDir, "99-xyz789")
require.NoError(t, os.MkdirAll(otherDir, 0o750))
// Run cleanup for deployment 42
svc.CleanupCancelledDeploy(context.Background(), appName, 42, "")
// Deployment 42's dir should be gone
_, err := os.Stat(deployDir)
assert.True(t, os.IsNotExist(err), "deployment build dir should be removed")
// Deployment 99's dir should still exist
_, err = os.Stat(otherDir)
assert.NoError(t, err, "other deployment build dir should not be removed")
}
func TestCleanupCancelledDeploy_NoBuildDir(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
cfg := &config.Config{DataDir: tmpDir}
svc := deploy.NewTestServiceWithConfig(slog.Default(), cfg, nil)
// Should not panic when build dir doesn't exist
svc.CleanupCancelledDeploy(context.Background(), "nonexistent-app", 1, "")
}

View File

@@ -2,7 +2,14 @@ package deploy
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"git.eeqj.de/sneak/upaas/internal/config"
"git.eeqj.de/sneak/upaas/internal/docker"
)
// NewTestService creates a Service with minimal dependencies for testing.
@@ -31,3 +38,45 @@ func (svc *Service) TryLockApp(appID string) bool {
func (svc *Service) UnlockApp(appID string) {
svc.unlockApp(appID)
}
// NewTestServiceWithConfig creates a Service with config and docker client for testing.
func NewTestServiceWithConfig(log *slog.Logger, cfg *config.Config, dockerClient *docker.Client) *Service {
return &Service{
log: log,
config: cfg,
docker: dockerClient,
}
}
// CleanupCancelledDeploy exposes the build directory cleanup portion of
// cleanupCancelledDeploy for testing. It removes build directories matching
// the deployment ID prefix.
func (svc *Service) CleanupCancelledDeploy(
ctx context.Context,
appName string,
deploymentID int64,
imageID string,
) {
// We can't create real models.App/Deployment in tests easily,
// so we test the build dir cleanup portion directly.
buildDir := svc.GetBuildDir(appName)
entries, err := os.ReadDir(buildDir)
if err != nil {
return
}
prefix := fmt.Sprintf("%d-", deploymentID)
for _, entry := range entries {
if entry.IsDir() && strings.HasPrefix(entry.Name(), prefix) {
dirPath := filepath.Join(buildDir, entry.Name())
_ = os.RemoveAll(dirPath)
}
}
}
// GetBuildDirExported exposes GetBuildDir for testing.
func (svc *Service) GetBuildDirExported(appName string) string {
return svc.GetBuildDir(appName)
}