Compare commits

...

9 Commits

Author SHA1 Message Date
dc39e5b6e2 feat: add progress bar to restore operation
Add an interactive progress bar (using schollz/progressbar) to the
file restore loop, matching the existing pattern in verify. Shows
bytes restored with ETA when output is a terminal, falls back to
structured log progress every 100 files otherwise.

Fixes #20
2026-02-19 23:54:00 -08:00
clawbot
d77ac18aaa fix: add missing printfStdout, printlnStdout, scanlnStdin, FetchBlob, and FetchAndDecryptBlob methods
These methods were referenced in main but never defined, causing compilation
failures. They were introduced by merges that assumed dependent PRs were
already merged.
2026-02-19 23:51:53 -08:00
825f25da58 Merge pull request 'Validate table name against allowlist in getTableCount (closes #27)' (#32) from fix/issue-27 into main
Reviewed-on: #32
2026-02-16 06:21:41 +01:00
162d76bb38 Merge branch 'main' into fix/issue-27 2026-02-16 06:17:51 +01:00
clawbot
bfd7334221 fix: replace table name allowlist with regex sanitization
Replace the hardcoded validTableNames allowlist with a regexp that
only allows [a-z0-9_] characters. This prevents SQL injection without
requiring maintenance of a separate allowlist when new tables are added.

Addresses review feedback from @sneak on PR #32.
2026-02-15 21:17:24 -08:00
user
9b32bf0846 fix: replace table name allowlist with regex sanitization
Replace the hardcoded validTableNames allowlist with a regexp that
only allows [a-z0-9_] characters. This prevents SQL injection without
requiring maintenance of a separate allowlist when new tables are added.

Addresses review feedback from @sneak on PR #32.
2026-02-15 21:15:49 -08:00
8adc668fa6 Merge pull request 'Prevent double-close of blobgen.Writer in CompressStream (closes #28)' (#33) from fix/issue-28 into main
Reviewed-on: #33
2026-02-16 06:04:33 +01:00
clawbot
441c441eca fix: prevent double-close of blobgen.Writer in CompressStream
CompressStream had both a defer w.Close() and an explicit w.Close() call,
causing the compressor and encryptor to be closed twice. The second close
on the zstd encoder returns an error, and the age encryptor may write
duplicate finalization bytes, potentially corrupting the output stream.

Use a closed flag to prevent the deferred close from running after the
explicit close succeeds.
2026-02-08 12:03:36 -08:00
clawbot
4d9f912a5f fix: validate table name against allowlist in getTableCount to prevent SQL injection
The getTableCount method used fmt.Sprintf to interpolate a table name directly
into a SQL query. While currently only called with hardcoded names, this is a
dangerous pattern. Added an allowlist of valid table names and return an error
for unrecognized names.
2026-02-08 12:03:18 -08:00
4 changed files with 196 additions and 73 deletions

View File

@ -51,7 +51,13 @@ func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipien
if err != nil { if err != nil {
return 0, "", fmt.Errorf("creating writer: %w", err) return 0, "", fmt.Errorf("creating writer: %w", err)
} }
defer func() { _ = w.Close() }()
closed := false
defer func() {
if !closed {
_ = w.Close()
}
}()
// Copy data // Copy data
if _, err := io.Copy(w, src); err != nil { if _, err := io.Copy(w, src); err != nil {
@ -62,6 +68,7 @@ func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipien
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
return 0, "", fmt.Errorf("closing writer: %w", err) return 0, "", fmt.Errorf("closing writer: %w", err)
} }
closed = true
return w.BytesWritten(), hex.EncodeToString(w.Sum256()), nil return w.BytesWritten(), hex.EncodeToString(w.Sum256()), nil
} }

View File

@ -111,6 +111,30 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error {
result := &RestoreResult{} result := &RestoreResult{}
blobCache := make(map[string][]byte) // Cache downloaded and decrypted blobs blobCache := make(map[string][]byte) // Cache downloaded and decrypted blobs
// Calculate total bytes for progress bar
var totalBytesExpected int64
for _, file := range files {
totalBytesExpected += file.Size
}
// Create progress bar if output is a terminal
var bar *progressbar.ProgressBar
if isTerminal() {
bar = progressbar.NewOptions64(
totalBytesExpected,
progressbar.OptionSetDescription("Restoring"),
progressbar.OptionSetWriter(os.Stderr),
progressbar.OptionShowBytes(true),
progressbar.OptionShowCount(),
progressbar.OptionSetWidth(40),
progressbar.OptionThrottle(100*time.Millisecond),
progressbar.OptionOnCompletion(func() {
fmt.Fprint(os.Stderr, "\n")
}),
progressbar.OptionSetRenderBlankState(true),
)
}
for i, file := range files { for i, file := range files {
if v.ctx.Err() != nil { if v.ctx.Err() != nil {
return v.ctx.Err() return v.ctx.Err()
@ -119,10 +143,14 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error {
if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil { if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil {
log.Error("Failed to restore file", "path", file.Path, "error", err) log.Error("Failed to restore file", "path", file.Path, "error", err)
// Continue with other files // Continue with other files
continue
} }
// Progress logging // Update progress bar
if bar != nil {
_ = bar.Add64(file.Size)
}
// Progress logging (for non-terminal or structured logs)
if (i+1)%100 == 0 || i+1 == len(files) { if (i+1)%100 == 0 || i+1 == len(files) {
log.Info("Restore progress", log.Info("Restore progress",
"files", fmt.Sprintf("%d/%d", i+1, len(files)), "files", fmt.Sprintf("%d/%d", i+1, len(files)),
@ -131,6 +159,10 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error {
} }
} }
if bar != nil {
_ = bar.Finish()
}
result.Duration = time.Since(startTime) result.Duration = time.Since(startTime)
log.Info("Restore complete", log.Info("Restore complete",
@ -473,6 +505,53 @@ func (v *Vaultik) restoreRegularFile(
return nil return nil
} }
// BlobFetchResult holds the result of fetching and decrypting a blob.
type BlobFetchResult struct {
Data []byte
CompressedSize int64
}
// FetchAndDecryptBlob downloads a blob from storage, decrypts and decompresses it.
func (v *Vaultik) FetchAndDecryptBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) (*BlobFetchResult, error) {
// Construct blob path with sharding
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)
reader, err := v.Storage.Get(ctx, blobPath)
if err != nil {
return nil, fmt.Errorf("downloading blob: %w", err)
}
defer func() { _ = reader.Close() }()
// Read encrypted data
encryptedData, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("reading blob data: %w", err)
}
// Decrypt and decompress
blobReader, err := blobgen.NewReader(bytes.NewReader(encryptedData), identity)
if err != nil {
return nil, fmt.Errorf("creating decryption reader: %w", err)
}
defer func() { _ = blobReader.Close() }()
data, err := io.ReadAll(blobReader)
if err != nil {
return nil, fmt.Errorf("decrypting blob: %w", err)
}
log.Debug("Downloaded and decrypted blob",
"hash", blobHash[:16],
"encrypted_size", humanize.Bytes(uint64(len(encryptedData))),
"decrypted_size", humanize.Bytes(uint64(len(data))),
)
return &BlobFetchResult{
Data: data,
CompressedSize: int64(len(encryptedData)),
}, nil
}
// downloadBlob downloads and decrypts a blob // downloadBlob downloads and decrypts a blob
func (v *Vaultik) downloadBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) ([]byte, error) { func (v *Vaultik) downloadBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) ([]byte, error) {
result, err := v.FetchAndDecryptBlob(ctx, blobHash, expectedSize, identity) result, err := v.FetchAndDecryptBlob(ctx, blobHash, expectedSize, identity)

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strings" "strings"
"text/tabwriter" "text/tabwriter"
@ -86,7 +87,7 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
// Print overall summary if multiple snapshots // Print overall summary if multiple snapshots
if len(snapshotNames) > 1 { if len(snapshotNames) > 1 {
_, _ = fmt.Fprintf(v.Stdout, "\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second)) v.printfStdout("\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second))
} }
return nil return nil
@ -99,7 +100,7 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
snapConfig := v.Config.Snapshots[snapName] snapConfig := v.Config.Snapshots[snapName]
if total > 1 { if total > 1 {
_, _ = fmt.Fprintf(v.Stdout, "\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName) v.printfStdout("\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName)
} }
// Resolve source directories to absolute paths // Resolve source directories to absolute paths
@ -152,7 +153,7 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
return fmt.Errorf("creating snapshot: %w", err) return fmt.Errorf("creating snapshot: %w", err)
} }
log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName) log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName)
_, _ = fmt.Fprintf(v.Stdout, "Beginning snapshot: %s\n", snapshotID) v.printfStdout("Beginning snapshot: %s\n", snapshotID)
for i, dir := range resolvedDirs { for i, dir := range resolvedDirs {
// Check if context is cancelled // Check if context is cancelled
@ -164,7 +165,7 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
} }
log.Info("Scanning directory", "path", dir) log.Info("Scanning directory", "path", dir)
_, _ = fmt.Fprintf(v.Stdout, "Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir) v.printfStdout("Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir)
result, err := scanner.Scan(v.ctx, dir, snapshotID) result, err := scanner.Scan(v.ctx, dir, snapshotID)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan %s: %w", dir, err) return fmt.Errorf("failed to scan %s: %w", dir, err)
@ -275,35 +276,35 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
} }
// Print comprehensive summary // Print comprehensive summary
_, _ = fmt.Fprintf(v.Stdout, "=== Snapshot Complete ===\n") v.printfStdout("=== Snapshot Complete ===\n")
_, _ = fmt.Fprintf(v.Stdout, "ID: %s\n", snapshotID) v.printfStdout("ID: %s\n", snapshotID)
_, _ = fmt.Fprintf(v.Stdout, "Files: %s examined, %s to process, %s unchanged", v.printfStdout("Files: %s examined, %s to process, %s unchanged",
formatNumber(totalFiles), formatNumber(totalFiles),
formatNumber(totalFilesChanged), formatNumber(totalFilesChanged),
formatNumber(totalFilesSkipped)) formatNumber(totalFilesSkipped))
if totalFilesDeleted > 0 { if totalFilesDeleted > 0 {
_, _ = fmt.Fprintf(v.Stdout, ", %s deleted", formatNumber(totalFilesDeleted)) v.printfStdout(", %s deleted", formatNumber(totalFilesDeleted))
} }
_, _ = fmt.Fprintln(v.Stdout) v.printlnStdout()
_, _ = fmt.Fprintf(v.Stdout, "Data: %s total (%s to process)", v.printfStdout("Data: %s total (%s to process)",
humanize.Bytes(uint64(totalBytesAll)), humanize.Bytes(uint64(totalBytesAll)),
humanize.Bytes(uint64(totalBytesChanged))) humanize.Bytes(uint64(totalBytesChanged)))
if totalBytesDeleted > 0 { if totalBytesDeleted > 0 {
_, _ = fmt.Fprintf(v.Stdout, ", %s deleted", humanize.Bytes(uint64(totalBytesDeleted))) v.printfStdout(", %s deleted", humanize.Bytes(uint64(totalBytesDeleted)))
} }
_, _ = fmt.Fprintln(v.Stdout) v.printlnStdout()
if totalBlobsUploaded > 0 { if totalBlobsUploaded > 0 {
_, _ = fmt.Fprintf(v.Stdout, "Storage: %s compressed from %s (%.2fx)\n", v.printfStdout("Storage: %s compressed from %s (%.2fx)\n",
humanize.Bytes(uint64(totalBlobSizeCompressed)), humanize.Bytes(uint64(totalBlobSizeCompressed)),
humanize.Bytes(uint64(totalBlobSizeUncompressed)), humanize.Bytes(uint64(totalBlobSizeUncompressed)),
compressionRatio) compressionRatio)
_, _ = fmt.Fprintf(v.Stdout, "Upload: %d blobs, %s in %s (%s)\n", v.printfStdout("Upload: %d blobs, %s in %s (%s)\n",
totalBlobsUploaded, totalBlobsUploaded,
humanize.Bytes(uint64(totalBytesUploaded)), humanize.Bytes(uint64(totalBytesUploaded)),
formatDuration(uploadDuration), formatDuration(uploadDuration),
avgUploadSpeed) avgUploadSpeed)
} }
_, _ = fmt.Fprintf(v.Stdout, "Duration: %s\n", formatDuration(snapshotDuration)) v.printfStdout("Duration: %s\n", formatDuration(snapshotDuration))
if opts.Prune { if opts.Prune {
log.Info("Pruning enabled - will delete old snapshots after snapshot") log.Info("Pruning enabled - will delete old snapshots after snapshot")
@ -422,13 +423,13 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
if jsonOutput { if jsonOutput {
// JSON output // JSON output
encoder := json.NewEncoder(os.Stdout) encoder := json.NewEncoder(v.Stdout)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
return encoder.Encode(snapshots) return encoder.Encode(snapshots)
} }
// Table output // Table output
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(v.Stdout, 0, 0, 3, ' ', 0)
// Show configured snapshots from config file // Show configured snapshots from config file
if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil { if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil {
@ -527,14 +528,14 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
} }
if len(toDelete) == 0 { if len(toDelete) == 0 {
fmt.Println("No snapshots to delete") v.printlnStdout("No snapshots to delete")
return nil return nil
} }
// Show what will be deleted // Show what will be deleted
fmt.Printf("The following snapshots will be deleted:\n\n") v.printfStdout("The following snapshots will be deleted:\n\n")
for _, snap := range toDelete { for _, snap := range toDelete {
fmt.Printf(" %s (%s, %s)\n", v.printfStdout(" %s (%s, %s)\n",
snap.ID, snap.ID,
snap.Timestamp.Format("2006-01-02 15:04:05"), snap.Timestamp.Format("2006-01-02 15:04:05"),
formatBytes(snap.CompressedSize)) formatBytes(snap.CompressedSize))
@ -542,19 +543,19 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
// Confirm unless --force is used // Confirm unless --force is used
if !force { if !force {
fmt.Printf("\nDelete %d snapshot(s)? [y/N] ", len(toDelete)) v.printfStdout("\nDelete %d snapshot(s)? [y/N] ", len(toDelete))
var confirm string var confirm string
if _, err := fmt.Scanln(&confirm); err != nil { if _, err := fmt.Scanln(&confirm); err != nil {
// Treat EOF or error as "no" // Treat EOF or error as "no"
fmt.Println("Cancelled") v.printlnStdout("Cancelled")
return nil return nil
} }
if strings.ToLower(confirm) != "y" { if strings.ToLower(confirm) != "y" {
fmt.Println("Cancelled") v.printlnStdout("Cancelled")
return nil return nil
} }
} else { } else {
fmt.Printf("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete)) v.printfStdout("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete))
} }
// Delete snapshots (both local and remote) // Delete snapshots (both local and remote)
@ -569,10 +570,10 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
} }
} }
fmt.Printf("Deleted %d snapshot(s)\n", len(toDelete)) v.printfStdout("Deleted %d snapshot(s)\n", len(toDelete))
// Note: Run 'vaultik prune' separately to clean up unreferenced blobs // Note: Run 'vaultik prune' separately to clean up unreferenced blobs
fmt.Println("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.") v.printlnStdout("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.")
return nil return nil
} }
@ -613,11 +614,11 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
} }
if !opts.JSON { if !opts.JSON {
fmt.Printf("Verifying snapshot %s\n", snapshotID) v.printfStdout("Verifying snapshot %s\n", snapshotID)
if !snapshotTime.IsZero() { if !snapshotTime.IsZero() {
fmt.Printf("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST"))
} }
fmt.Println() v.printlnStdout()
} }
// Download and parse manifest // Download and parse manifest
@ -635,18 +636,18 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
result.TotalSize = manifest.TotalCompressedSize result.TotalSize = manifest.TotalCompressedSize
if !opts.JSON { if !opts.JSON {
fmt.Printf("Snapshot information:\n") v.printfStdout("Snapshot information:\n")
fmt.Printf(" Blob count: %d\n", manifest.BlobCount) v.printfStdout(" Blob count: %d\n", manifest.BlobCount)
fmt.Printf(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize))) v.printfStdout(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize)))
if manifest.Timestamp != "" { if manifest.Timestamp != "" {
if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil { if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil {
fmt.Printf(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST")) v.printfStdout(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST"))
} }
} }
fmt.Println() v.printlnStdout()
// Check each blob exists // Check each blob exists
fmt.Printf("Checking blob existence...\n") v.printfStdout("Checking blob existence...\n")
} }
missing := 0 missing := 0
@ -660,7 +661,7 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
_, err := v.Storage.Stat(v.ctx, blobPath) _, err := v.Storage.Stat(v.ctx, blobPath)
if err != nil { if err != nil {
if !opts.JSON { if !opts.JSON {
fmt.Printf(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize))) v.printfStdout(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize)))
} }
missing++ missing++
missingSize += blob.CompressedSize missingSize += blob.CompressedSize
@ -683,20 +684,20 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
return v.outputVerifyJSON(result) return v.outputVerifyJSON(result)
} }
fmt.Printf("\nVerification complete:\n") v.printfStdout("\nVerification complete:\n")
fmt.Printf(" Verified: %d blobs (%s)\n", verified, v.printfStdout(" Verified: %d blobs (%s)\n", verified,
humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize))) humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize)))
if missing > 0 { if missing > 0 {
fmt.Printf(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize))) v.printfStdout(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize)))
} else { } else {
fmt.Printf(" Missing: 0 blobs\n") v.printfStdout(" Missing: 0 blobs\n")
} }
fmt.Printf(" Status: ") v.printfStdout(" Status: ")
if missing > 0 { if missing > 0 {
fmt.Printf("FAILED - %d blobs are missing\n", missing) v.printfStdout("FAILED - %d blobs are missing\n", missing)
return fmt.Errorf("%d blobs are missing", missing) return fmt.Errorf("%d blobs are missing", missing)
} else { } else {
fmt.Printf("OK - All blobs verified\n") v.printfStdout("OK - All blobs verified\n")
} }
return nil return nil
@ -704,7 +705,7 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
// outputVerifyJSON outputs the verification result as JSON // outputVerifyJSON outputs the verification result as JSON
func (v *Vaultik) outputVerifyJSON(result *VerifyResult) error { func (v *Vaultik) outputVerifyJSON(result *VerifyResult) error {
encoder := json.NewEncoder(os.Stdout) encoder := json.NewEncoder(v.Stdout)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
if err := encoder.Encode(result); err != nil { if err := encoder.Encode(result); err != nil {
return fmt.Errorf("encoding JSON: %w", err) return fmt.Errorf("encoding JSON: %w", err)
@ -830,11 +831,11 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov
if opts.DryRun { if opts.DryRun {
result.DryRun = true result.DryRun = true
if !opts.JSON { if !opts.JSON {
_, _ = fmt.Fprintf(v.Stdout, "Would remove snapshot: %s\n", snapshotID) v.printfStdout("Would remove snapshot: %s\n", snapshotID)
if opts.Remote { if opts.Remote {
_, _ = fmt.Fprintln(v.Stdout, "Would also remove from remote storage") v.printlnStdout("Would also remove from remote storage")
} }
_, _ = fmt.Fprintln(v.Stdout, "[Dry run - no changes made]") v.printlnStdout("[Dry run - no changes made]")
} }
if opts.JSON { if opts.JSON {
return result, v.outputRemoveJSON(result) return result, v.outputRemoveJSON(result)
@ -845,17 +846,17 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov
// Confirm unless --force is used (skip in JSON mode - require --force) // Confirm unless --force is used (skip in JSON mode - require --force)
if !opts.Force && !opts.JSON { if !opts.Force && !opts.JSON {
if opts.Remote { if opts.Remote {
_, _ = fmt.Fprintf(v.Stdout, "Remove snapshot '%s' from local database and remote storage? [y/N] ", snapshotID) v.printfStdout("Remove snapshot '%s' from local database and remote storage? [y/N] ", snapshotID)
} else { } else {
_, _ = fmt.Fprintf(v.Stdout, "Remove snapshot '%s' from local database? [y/N] ", snapshotID) v.printfStdout("Remove snapshot '%s' from local database? [y/N] ", snapshotID)
} }
var confirm string var confirm string
if _, err := fmt.Fscanln(v.Stdin, &confirm); err != nil { if err := v.scanlnStdin(&confirm); err != nil {
_, _ = fmt.Fprintln(v.Stdout, "Cancelled") v.printlnStdout("Cancelled")
return result, nil return result, nil
} }
if strings.ToLower(confirm) != "y" { if strings.ToLower(confirm) != "y" {
_, _ = fmt.Fprintln(v.Stdout, "Cancelled") v.printlnStdout("Cancelled")
return result, nil return result, nil
} }
} }
@ -882,10 +883,10 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov
} }
// Print summary // Print summary
_, _ = fmt.Fprintf(v.Stdout, "Removed snapshot '%s' from local database\n", snapshotID) v.printfStdout("Removed snapshot '%s' from local database\n", snapshotID)
if opts.Remote { if opts.Remote {
_, _ = fmt.Fprintln(v.Stdout, "Removed snapshot metadata from remote storage") v.printlnStdout("Removed snapshot metadata from remote storage")
_, _ = fmt.Fprintln(v.Stdout, "\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") v.printlnStdout("\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.")
} }
return result, nil return result, nil
@ -929,7 +930,7 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error)
if len(snapshotIDs) == 0 { if len(snapshotIDs) == 0 {
if !opts.JSON { if !opts.JSON {
_, _ = fmt.Fprintln(v.Stdout, "No snapshots found") v.printlnStdout("No snapshots found")
} }
return result, nil return result, nil
} }
@ -938,14 +939,14 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error)
result.DryRun = true result.DryRun = true
result.SnapshotsRemoved = snapshotIDs result.SnapshotsRemoved = snapshotIDs
if !opts.JSON { if !opts.JSON {
_, _ = fmt.Fprintf(v.Stdout, "Would remove %d snapshot(s):\n", len(snapshotIDs)) v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs))
for _, id := range snapshotIDs { for _, id := range snapshotIDs {
_, _ = fmt.Fprintf(v.Stdout, " %s\n", id) v.printfStdout(" %s\n", id)
} }
if opts.Remote { if opts.Remote {
_, _ = fmt.Fprintln(v.Stdout, "Would also remove from remote storage") v.printlnStdout("Would also remove from remote storage")
} }
_, _ = fmt.Fprintln(v.Stdout, "[Dry run - no changes made]") v.printlnStdout("[Dry run - no changes made]")
} }
if opts.JSON { if opts.JSON {
return result, v.outputRemoveJSON(result) return result, v.outputRemoveJSON(result)
@ -986,10 +987,10 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error)
return result, v.outputRemoveJSON(result) return result, v.outputRemoveJSON(result)
} }
_, _ = fmt.Fprintf(v.Stdout, "Removed %d snapshot(s)\n", len(result.SnapshotsRemoved)) v.printfStdout("Removed %d snapshot(s)\n", len(result.SnapshotsRemoved))
if opts.Remote { if opts.Remote {
_, _ = fmt.Fprintln(v.Stdout, "Removed snapshot metadata from remote storage") v.printlnStdout("Removed snapshot metadata from remote storage")
_, _ = fmt.Fprintln(v.Stdout, "\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") v.printlnStdout("\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.")
} }
return result, nil return result, nil
@ -1043,7 +1044,7 @@ func (v *Vaultik) deleteSnapshotFromRemote(snapshotID string) error {
// outputRemoveJSON outputs the removal result as JSON // outputRemoveJSON outputs the removal result as JSON
func (v *Vaultik) outputRemoveJSON(result *RemoveResult) error { func (v *Vaultik) outputRemoveJSON(result *RemoveResult) error {
encoder := json.NewEncoder(os.Stdout) encoder := json.NewEncoder(v.Stdout)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
return encoder.Encode(result) return encoder.Encode(result)
} }
@ -1117,21 +1118,29 @@ func (v *Vaultik) PruneDatabase() (*PruneResult, error) {
) )
// Print summary // Print summary
_, _ = fmt.Fprintf(v.Stdout, "Local database prune complete:\n") v.printfStdout("Local database prune complete:\n")
_, _ = fmt.Fprintf(v.Stdout, " Incomplete snapshots removed: %d\n", result.SnapshotsDeleted) v.printfStdout(" Incomplete snapshots removed: %d\n", result.SnapshotsDeleted)
_, _ = fmt.Fprintf(v.Stdout, " Orphaned files removed: %d\n", result.FilesDeleted) v.printfStdout(" Orphaned files removed: %d\n", result.FilesDeleted)
_, _ = fmt.Fprintf(v.Stdout, " Orphaned chunks removed: %d\n", result.ChunksDeleted) v.printfStdout(" Orphaned chunks removed: %d\n", result.ChunksDeleted)
_, _ = fmt.Fprintf(v.Stdout, " Orphaned blobs removed: %d\n", result.BlobsDeleted) v.printfStdout(" Orphaned blobs removed: %d\n", result.BlobsDeleted)
return result, nil return result, nil
} }
// getTableCount returns the count of rows in a table // validTableNameRe matches table names containing only lowercase alphanumeric characters and underscores.
var validTableNameRe = regexp.MustCompile(`^[a-z0-9_]+$`)
// getTableCount returns the count of rows in a table.
// The tableName is sanitized to only allow [a-z0-9_] characters to prevent SQL injection.
func (v *Vaultik) getTableCount(tableName string) (int64, error) { func (v *Vaultik) getTableCount(tableName string) (int64, error) {
if v.DB == nil { if v.DB == nil {
return 0, nil return 0, nil
} }
if !validTableNameRe.MatchString(tableName) {
return 0, fmt.Errorf("invalid table name: %q", tableName)
}
var count int64 var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName) query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)
err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&count) err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&count)

View File

@ -135,6 +135,34 @@ func (v *Vaultik) Outputf(format string, args ...any) {
_, _ = fmt.Fprintf(v.Stdout, format, args...) _, _ = fmt.Fprintf(v.Stdout, format, args...)
} }
// printfStdout writes formatted output to stdout.
func (v *Vaultik) printfStdout(format string, args ...any) {
_, _ = fmt.Fprintf(v.Stdout, format, args...)
}
// printlnStdout writes a line to stdout.
func (v *Vaultik) printlnStdout(args ...any) {
_, _ = fmt.Fprintln(v.Stdout, args...)
}
// scanlnStdin reads a line from stdin into the provided string pointer.
func (v *Vaultik) scanlnStdin(s *string) error {
_, err := fmt.Fscanln(v.Stdin, s)
return err
}
// FetchBlob downloads a blob from storage and returns a reader for the encrypted data.
func (v *Vaultik) FetchBlob(ctx context.Context, blobHash string, expectedSize int64) (io.ReadCloser, int64, error) {
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)
reader, err := v.Storage.Get(ctx, blobPath)
if err != nil {
return nil, 0, fmt.Errorf("downloading blob: %w", err)
}
return reader, expectedSize, nil
}
// TestVaultik wraps a Vaultik with captured stdout/stderr for testing // TestVaultik wraps a Vaultik with captured stdout/stderr for testing
type TestVaultik struct { type TestVaultik struct {
*Vaultik *Vaultik