1 Commits

Author SHA1 Message Date
clawbot
3d3c13cd01 fix: bound blob cache during restore with LRU eviction to prevent OOM
The restore operation cached every downloaded blob in an unbounded map,
which could exhaust system memory when restoring large backups with many
unique blobs (each up to 10GB).

Replaced with an LRU cache bounded to 1 GiB by default, evicting
least-recently-used blobs when the limit is exceeded.
2026-02-08 12:04:50 -08:00
4 changed files with 159 additions and 92 deletions

View File

@@ -51,13 +51,7 @@ func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipien
if err != nil { if err != nil {
return 0, "", fmt.Errorf("creating writer: %w", err) return 0, "", fmt.Errorf("creating writer: %w", err)
} }
defer func() { _ = w.Close() }()
closed := false
defer func() {
if !closed {
_ = w.Close()
}
}()
// Copy data // Copy data
if _, err := io.Copy(w, src); err != nil { if _, err := io.Copy(w, src); err != nil {
@@ -68,7 +62,6 @@ func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipien
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
return 0, "", fmt.Errorf("closing writer: %w", err) return 0, "", fmt.Errorf("closing writer: %w", err)
} }
closed = true
return w.BytesWritten(), hex.EncodeToString(w.Sum256()), nil return w.BytesWritten(), hex.EncodeToString(w.Sum256()), nil
} }

View File

@@ -0,0 +1,83 @@
package vaultik
import (
"container/list"
"sync"
)
// defaultMaxBlobCacheBytes is the default maximum size of the blob cache (1 GB).
const defaultMaxBlobCacheBytes = 1 << 30 // 1 GiB
// blobCacheEntry is an entry in the LRU blob cache.
type blobCacheEntry struct {
key string
data []byte
}
// blobLRUCache is a simple LRU cache for downloaded blob data, bounded by
// total byte size to prevent memory exhaustion during large restores.
type blobLRUCache struct {
mu sync.Mutex
maxBytes int64
curBytes int64
ll *list.List
items map[string]*list.Element
}
// newBlobLRUCache creates a new LRU blob cache with the given maximum size in bytes.
func newBlobLRUCache(maxBytes int64) *blobLRUCache {
return &blobLRUCache{
maxBytes: maxBytes,
ll: list.New(),
items: make(map[string]*list.Element),
}
}
// Get returns the blob data for the given key, or nil if not cached.
func (c *blobLRUCache) Get(key string) ([]byte, bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ele, ok := c.items[key]; ok {
c.ll.MoveToFront(ele)
return ele.Value.(*blobCacheEntry).data, true
}
return nil, false
}
// Put adds a blob to the cache, evicting least-recently-used entries if needed.
func (c *blobLRUCache) Put(key string, data []byte) {
c.mu.Lock()
defer c.mu.Unlock()
entrySize := int64(len(data))
// If this single entry exceeds max, don't cache it
if entrySize > c.maxBytes {
return
}
// If already present, update in place
if ele, ok := c.items[key]; ok {
c.ll.MoveToFront(ele)
old := ele.Value.(*blobCacheEntry)
c.curBytes += entrySize - int64(len(old.data))
old.data = data
} else {
ele := c.ll.PushFront(&blobCacheEntry{key: key, data: data})
c.items[key] = ele
c.curBytes += entrySize
}
// Evict from back until under limit
for c.curBytes > c.maxBytes && c.ll.Len() > 0 {
oldest := c.ll.Back()
if oldest == nil {
break
}
entry := oldest.Value.(*blobCacheEntry)
c.ll.Remove(oldest)
delete(c.items, entry.key)
c.curBytes -= int64(len(entry.data))
}
}

View File

@@ -109,7 +109,7 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error {
// Step 5: Restore files // Step 5: Restore files
result := &RestoreResult{} result := &RestoreResult{}
blobCache := make(map[string][]byte) // Cache downloaded and decrypted blobs blobCache := newBlobLRUCache(defaultMaxBlobCacheBytes) // LRU cache bounded to ~1 GiB
for i, file := range files { for i, file := range files {
if v.ctx.Err() != nil { if v.ctx.Err() != nil {
@@ -299,7 +299,7 @@ func (v *Vaultik) restoreFile(
targetDir string, targetDir string,
identity age.Identity, identity age.Identity,
chunkToBlobMap map[string]*database.BlobChunk, chunkToBlobMap map[string]*database.BlobChunk,
blobCache map[string][]byte, blobCache *blobLRUCache,
result *RestoreResult, result *RestoreResult,
) error { ) error {
// Calculate target path - use full original path under target directory // Calculate target path - use full original path under target directory
@@ -383,7 +383,7 @@ func (v *Vaultik) restoreRegularFile(
targetPath string, targetPath string,
identity age.Identity, identity age.Identity,
chunkToBlobMap map[string]*database.BlobChunk, chunkToBlobMap map[string]*database.BlobChunk,
blobCache map[string][]byte, blobCache *blobLRUCache,
result *RestoreResult, result *RestoreResult,
) error { ) error {
// Get file chunks in order // Get file chunks in order
@@ -417,13 +417,13 @@ func (v *Vaultik) restoreRegularFile(
// Download and decrypt blob if not cached // Download and decrypt blob if not cached
blobHashStr := blob.Hash.String() blobHashStr := blob.Hash.String()
blobData, ok := blobCache[blobHashStr] blobData, ok := blobCache.Get(blobHashStr)
if !ok { if !ok {
blobData, err = v.downloadBlob(ctx, blobHashStr, blob.CompressedSize, identity) blobData, err = v.downloadBlob(ctx, blobHashStr, blob.CompressedSize, identity)
if err != nil { if err != nil {
return fmt.Errorf("downloading blob %s: %w", blobHashStr[:16], err) return fmt.Errorf("downloading blob %s: %w", blobHashStr[:16], err)
} }
blobCache[blobHashStr] = blobData blobCache.Put(blobHashStr, blobData)
result.BlobsDownloaded++ result.BlobsDownloaded++
result.BytesDownloaded += blob.CompressedSize result.BytesDownloaded += blob.CompressedSize
} }

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"regexp"
"path/filepath" "path/filepath"
"sort" "sort"
"strings" "strings"
@@ -87,7 +86,7 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
// Print overall summary if multiple snapshots // Print overall summary if multiple snapshots
if len(snapshotNames) > 1 { if len(snapshotNames) > 1 {
v.printfStdout("\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second)) _, _ = fmt.Fprintf(v.Stdout, "\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second))
} }
return nil return nil
@@ -100,7 +99,7 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
snapConfig := v.Config.Snapshots[snapName] snapConfig := v.Config.Snapshots[snapName]
if total > 1 { if total > 1 {
v.printfStdout("\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName) _, _ = fmt.Fprintf(v.Stdout, "\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName)
} }
// Resolve source directories to absolute paths // Resolve source directories to absolute paths
@@ -153,7 +152,7 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
return fmt.Errorf("creating snapshot: %w", err) return fmt.Errorf("creating snapshot: %w", err)
} }
log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName) log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName)
v.printfStdout("Beginning snapshot: %s\n", snapshotID) _, _ = fmt.Fprintf(v.Stdout, "Beginning snapshot: %s\n", snapshotID)
for i, dir := range resolvedDirs { for i, dir := range resolvedDirs {
// Check if context is cancelled // Check if context is cancelled
@@ -165,7 +164,7 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
} }
log.Info("Scanning directory", "path", dir) log.Info("Scanning directory", "path", dir)
v.printfStdout("Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir) _, _ = fmt.Fprintf(v.Stdout, "Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir)
result, err := scanner.Scan(v.ctx, dir, snapshotID) result, err := scanner.Scan(v.ctx, dir, snapshotID)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan %s: %w", dir, err) return fmt.Errorf("failed to scan %s: %w", dir, err)
@@ -276,35 +275,35 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
} }
// Print comprehensive summary // Print comprehensive summary
v.printfStdout("=== Snapshot Complete ===\n") _, _ = fmt.Fprintf(v.Stdout, "=== Snapshot Complete ===\n")
v.printfStdout("ID: %s\n", snapshotID) _, _ = fmt.Fprintf(v.Stdout, "ID: %s\n", snapshotID)
v.printfStdout("Files: %s examined, %s to process, %s unchanged", _, _ = fmt.Fprintf(v.Stdout, "Files: %s examined, %s to process, %s unchanged",
formatNumber(totalFiles), formatNumber(totalFiles),
formatNumber(totalFilesChanged), formatNumber(totalFilesChanged),
formatNumber(totalFilesSkipped)) formatNumber(totalFilesSkipped))
if totalFilesDeleted > 0 { if totalFilesDeleted > 0 {
v.printfStdout(", %s deleted", formatNumber(totalFilesDeleted)) _, _ = fmt.Fprintf(v.Stdout, ", %s deleted", formatNumber(totalFilesDeleted))
} }
v.printlnStdout() _, _ = fmt.Fprintln(v.Stdout)
v.printfStdout("Data: %s total (%s to process)", _, _ = fmt.Fprintf(v.Stdout, "Data: %s total (%s to process)",
humanize.Bytes(uint64(totalBytesAll)), humanize.Bytes(uint64(totalBytesAll)),
humanize.Bytes(uint64(totalBytesChanged))) humanize.Bytes(uint64(totalBytesChanged)))
if totalBytesDeleted > 0 { if totalBytesDeleted > 0 {
v.printfStdout(", %s deleted", humanize.Bytes(uint64(totalBytesDeleted))) _, _ = fmt.Fprintf(v.Stdout, ", %s deleted", humanize.Bytes(uint64(totalBytesDeleted)))
} }
v.printlnStdout() _, _ = fmt.Fprintln(v.Stdout)
if totalBlobsUploaded > 0 { if totalBlobsUploaded > 0 {
v.printfStdout("Storage: %s compressed from %s (%.2fx)\n", _, _ = fmt.Fprintf(v.Stdout, "Storage: %s compressed from %s (%.2fx)\n",
humanize.Bytes(uint64(totalBlobSizeCompressed)), humanize.Bytes(uint64(totalBlobSizeCompressed)),
humanize.Bytes(uint64(totalBlobSizeUncompressed)), humanize.Bytes(uint64(totalBlobSizeUncompressed)),
compressionRatio) compressionRatio)
v.printfStdout("Upload: %d blobs, %s in %s (%s)\n", _, _ = fmt.Fprintf(v.Stdout, "Upload: %d blobs, %s in %s (%s)\n",
totalBlobsUploaded, totalBlobsUploaded,
humanize.Bytes(uint64(totalBytesUploaded)), humanize.Bytes(uint64(totalBytesUploaded)),
formatDuration(uploadDuration), formatDuration(uploadDuration),
avgUploadSpeed) avgUploadSpeed)
} }
v.printfStdout("Duration: %s\n", formatDuration(snapshotDuration)) _, _ = fmt.Fprintf(v.Stdout, "Duration: %s\n", formatDuration(snapshotDuration))
if opts.Prune { if opts.Prune {
log.Info("Pruning enabled - will delete old snapshots after snapshot") log.Info("Pruning enabled - will delete old snapshots after snapshot")
@@ -423,13 +422,13 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
if jsonOutput { if jsonOutput {
// JSON output // JSON output
encoder := json.NewEncoder(v.Stdout) encoder := json.NewEncoder(os.Stdout)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
return encoder.Encode(snapshots) return encoder.Encode(snapshots)
} }
// Table output // Table output
w := tabwriter.NewWriter(v.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
// Show configured snapshots from config file // Show configured snapshots from config file
if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil { if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil {
@@ -528,14 +527,14 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
} }
if len(toDelete) == 0 { if len(toDelete) == 0 {
v.printlnStdout("No snapshots to delete") fmt.Println("No snapshots to delete")
return nil return nil
} }
// Show what will be deleted // Show what will be deleted
v.printfStdout("The following snapshots will be deleted:\n\n") fmt.Printf("The following snapshots will be deleted:\n\n")
for _, snap := range toDelete { for _, snap := range toDelete {
v.printfStdout(" %s (%s, %s)\n", fmt.Printf(" %s (%s, %s)\n",
snap.ID, snap.ID,
snap.Timestamp.Format("2006-01-02 15:04:05"), snap.Timestamp.Format("2006-01-02 15:04:05"),
formatBytes(snap.CompressedSize)) formatBytes(snap.CompressedSize))
@@ -543,19 +542,19 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
// Confirm unless --force is used // Confirm unless --force is used
if !force { if !force {
v.printfStdout("\nDelete %d snapshot(s)? [y/N] ", len(toDelete)) fmt.Printf("\nDelete %d snapshot(s)? [y/N] ", len(toDelete))
var confirm string var confirm string
if _, err := fmt.Scanln(&confirm); err != nil { if _, err := fmt.Scanln(&confirm); err != nil {
// Treat EOF or error as "no" // Treat EOF or error as "no"
v.printlnStdout("Cancelled") fmt.Println("Cancelled")
return nil return nil
} }
if strings.ToLower(confirm) != "y" { if strings.ToLower(confirm) != "y" {
v.printlnStdout("Cancelled") fmt.Println("Cancelled")
return nil return nil
} }
} else { } else {
v.printfStdout("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete)) fmt.Printf("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete))
} }
// Delete snapshots (both local and remote) // Delete snapshots (both local and remote)
@@ -570,10 +569,10 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
} }
} }
v.printfStdout("Deleted %d snapshot(s)\n", len(toDelete)) fmt.Printf("Deleted %d snapshot(s)\n", len(toDelete))
// Note: Run 'vaultik prune' separately to clean up unreferenced blobs // Note: Run 'vaultik prune' separately to clean up unreferenced blobs
v.printlnStdout("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.") fmt.Println("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.")
return nil return nil
} }
@@ -614,11 +613,11 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
} }
if !opts.JSON { if !opts.JSON {
v.printfStdout("Verifying snapshot %s\n", snapshotID) fmt.Printf("Verifying snapshot %s\n", snapshotID)
if !snapshotTime.IsZero() { if !snapshotTime.IsZero() {
v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) fmt.Printf("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST"))
} }
v.printlnStdout() fmt.Println()
} }
// Download and parse manifest // Download and parse manifest
@@ -636,18 +635,18 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
result.TotalSize = manifest.TotalCompressedSize result.TotalSize = manifest.TotalCompressedSize
if !opts.JSON { if !opts.JSON {
v.printfStdout("Snapshot information:\n") fmt.Printf("Snapshot information:\n")
v.printfStdout(" Blob count: %d\n", manifest.BlobCount) fmt.Printf(" Blob count: %d\n", manifest.BlobCount)
v.printfStdout(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize))) fmt.Printf(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize)))
if manifest.Timestamp != "" { if manifest.Timestamp != "" {
if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil { if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil {
v.printfStdout(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST")) fmt.Printf(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST"))
} }
} }
v.printlnStdout() fmt.Println()
// Check each blob exists // Check each blob exists
v.printfStdout("Checking blob existence...\n") fmt.Printf("Checking blob existence...\n")
} }
missing := 0 missing := 0
@@ -661,7 +660,7 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
_, err := v.Storage.Stat(v.ctx, blobPath) _, err := v.Storage.Stat(v.ctx, blobPath)
if err != nil { if err != nil {
if !opts.JSON { if !opts.JSON {
v.printfStdout(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize))) fmt.Printf(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize)))
} }
missing++ missing++
missingSize += blob.CompressedSize missingSize += blob.CompressedSize
@@ -684,20 +683,20 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
return v.outputVerifyJSON(result) return v.outputVerifyJSON(result)
} }
v.printfStdout("\nVerification complete:\n") fmt.Printf("\nVerification complete:\n")
v.printfStdout(" Verified: %d blobs (%s)\n", verified, fmt.Printf(" Verified: %d blobs (%s)\n", verified,
humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize))) humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize)))
if missing > 0 { if missing > 0 {
v.printfStdout(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize))) fmt.Printf(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize)))
} else { } else {
v.printfStdout(" Missing: 0 blobs\n") fmt.Printf(" Missing: 0 blobs\n")
} }
v.printfStdout(" Status: ") fmt.Printf(" Status: ")
if missing > 0 { if missing > 0 {
v.printfStdout("FAILED - %d blobs are missing\n", missing) fmt.Printf("FAILED - %d blobs are missing\n", missing)
return fmt.Errorf("%d blobs are missing", missing) return fmt.Errorf("%d blobs are missing", missing)
} else { } else {
v.printfStdout("OK - All blobs verified\n") fmt.Printf("OK - All blobs verified\n")
} }
return nil return nil
@@ -705,7 +704,7 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
// outputVerifyJSON outputs the verification result as JSON // outputVerifyJSON outputs the verification result as JSON
func (v *Vaultik) outputVerifyJSON(result *VerifyResult) error { func (v *Vaultik) outputVerifyJSON(result *VerifyResult) error {
encoder := json.NewEncoder(v.Stdout) encoder := json.NewEncoder(os.Stdout)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
if err := encoder.Encode(result); err != nil { if err := encoder.Encode(result); err != nil {
return fmt.Errorf("encoding JSON: %w", err) return fmt.Errorf("encoding JSON: %w", err)
@@ -831,11 +830,11 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov
if opts.DryRun { if opts.DryRun {
result.DryRun = true result.DryRun = true
if !opts.JSON { if !opts.JSON {
v.printfStdout("Would remove snapshot: %s\n", snapshotID) _, _ = fmt.Fprintf(v.Stdout, "Would remove snapshot: %s\n", snapshotID)
if opts.Remote { if opts.Remote {
v.printlnStdout("Would also remove from remote storage") _, _ = fmt.Fprintln(v.Stdout, "Would also remove from remote storage")
} }
v.printlnStdout("[Dry run - no changes made]") _, _ = fmt.Fprintln(v.Stdout, "[Dry run - no changes made]")
} }
if opts.JSON { if opts.JSON {
return result, v.outputRemoveJSON(result) return result, v.outputRemoveJSON(result)
@@ -846,17 +845,17 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov
// Confirm unless --force is used (skip in JSON mode - require --force) // Confirm unless --force is used (skip in JSON mode - require --force)
if !opts.Force && !opts.JSON { if !opts.Force && !opts.JSON {
if opts.Remote { if opts.Remote {
v.printfStdout("Remove snapshot '%s' from local database and remote storage? [y/N] ", snapshotID) _, _ = fmt.Fprintf(v.Stdout, "Remove snapshot '%s' from local database and remote storage? [y/N] ", snapshotID)
} else { } else {
v.printfStdout("Remove snapshot '%s' from local database? [y/N] ", snapshotID) _, _ = fmt.Fprintf(v.Stdout, "Remove snapshot '%s' from local database? [y/N] ", snapshotID)
} }
var confirm string var confirm string
if err := v.scanlnStdin(&confirm); err != nil { if _, err := fmt.Fscanln(v.Stdin, &confirm); err != nil {
v.printlnStdout("Cancelled") _, _ = fmt.Fprintln(v.Stdout, "Cancelled")
return result, nil return result, nil
} }
if strings.ToLower(confirm) != "y" { if strings.ToLower(confirm) != "y" {
v.printlnStdout("Cancelled") _, _ = fmt.Fprintln(v.Stdout, "Cancelled")
return result, nil return result, nil
} }
} }
@@ -883,10 +882,10 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov
} }
// Print summary // Print summary
v.printfStdout("Removed snapshot '%s' from local database\n", snapshotID) _, _ = fmt.Fprintf(v.Stdout, "Removed snapshot '%s' from local database\n", snapshotID)
if opts.Remote { if opts.Remote {
v.printlnStdout("Removed snapshot metadata from remote storage") _, _ = fmt.Fprintln(v.Stdout, "Removed snapshot metadata from remote storage")
v.printlnStdout("\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") _, _ = fmt.Fprintln(v.Stdout, "\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.")
} }
return result, nil return result, nil
@@ -930,7 +929,7 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error)
if len(snapshotIDs) == 0 { if len(snapshotIDs) == 0 {
if !opts.JSON { if !opts.JSON {
v.printlnStdout("No snapshots found") _, _ = fmt.Fprintln(v.Stdout, "No snapshots found")
} }
return result, nil return result, nil
} }
@@ -939,14 +938,14 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error)
result.DryRun = true result.DryRun = true
result.SnapshotsRemoved = snapshotIDs result.SnapshotsRemoved = snapshotIDs
if !opts.JSON { if !opts.JSON {
v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs)) _, _ = fmt.Fprintf(v.Stdout, "Would remove %d snapshot(s):\n", len(snapshotIDs))
for _, id := range snapshotIDs { for _, id := range snapshotIDs {
v.printfStdout(" %s\n", id) _, _ = fmt.Fprintf(v.Stdout, " %s\n", id)
} }
if opts.Remote { if opts.Remote {
v.printlnStdout("Would also remove from remote storage") _, _ = fmt.Fprintln(v.Stdout, "Would also remove from remote storage")
} }
v.printlnStdout("[Dry run - no changes made]") _, _ = fmt.Fprintln(v.Stdout, "[Dry run - no changes made]")
} }
if opts.JSON { if opts.JSON {
return result, v.outputRemoveJSON(result) return result, v.outputRemoveJSON(result)
@@ -987,10 +986,10 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error)
return result, v.outputRemoveJSON(result) return result, v.outputRemoveJSON(result)
} }
v.printfStdout("Removed %d snapshot(s)\n", len(result.SnapshotsRemoved)) _, _ = fmt.Fprintf(v.Stdout, "Removed %d snapshot(s)\n", len(result.SnapshotsRemoved))
if opts.Remote { if opts.Remote {
v.printlnStdout("Removed snapshot metadata from remote storage") _, _ = fmt.Fprintln(v.Stdout, "Removed snapshot metadata from remote storage")
v.printlnStdout("\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") _, _ = fmt.Fprintln(v.Stdout, "\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.")
} }
return result, nil return result, nil
@@ -1044,7 +1043,7 @@ func (v *Vaultik) deleteSnapshotFromRemote(snapshotID string) error {
// outputRemoveJSON outputs the removal result as JSON // outputRemoveJSON outputs the removal result as JSON
func (v *Vaultik) outputRemoveJSON(result *RemoveResult) error { func (v *Vaultik) outputRemoveJSON(result *RemoveResult) error {
encoder := json.NewEncoder(v.Stdout) encoder := json.NewEncoder(os.Stdout)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
return encoder.Encode(result) return encoder.Encode(result)
} }
@@ -1118,29 +1117,21 @@ func (v *Vaultik) PruneDatabase() (*PruneResult, error) {
) )
// Print summary // Print summary
v.printfStdout("Local database prune complete:\n") _, _ = fmt.Fprintf(v.Stdout, "Local database prune complete:\n")
v.printfStdout(" Incomplete snapshots removed: %d\n", result.SnapshotsDeleted) _, _ = fmt.Fprintf(v.Stdout, " Incomplete snapshots removed: %d\n", result.SnapshotsDeleted)
v.printfStdout(" Orphaned files removed: %d\n", result.FilesDeleted) _, _ = fmt.Fprintf(v.Stdout, " Orphaned files removed: %d\n", result.FilesDeleted)
v.printfStdout(" Orphaned chunks removed: %d\n", result.ChunksDeleted) _, _ = fmt.Fprintf(v.Stdout, " Orphaned chunks removed: %d\n", result.ChunksDeleted)
v.printfStdout(" Orphaned blobs removed: %d\n", result.BlobsDeleted) _, _ = fmt.Fprintf(v.Stdout, " Orphaned blobs removed: %d\n", result.BlobsDeleted)
return result, nil return result, nil
} }
// validTableNameRe matches table names containing only lowercase alphanumeric characters and underscores. // getTableCount returns the count of rows in a table
var validTableNameRe = regexp.MustCompile(`^[a-z0-9_]+$`)
// getTableCount returns the count of rows in a table.
// The tableName is sanitized to only allow [a-z0-9_] characters to prevent SQL injection.
func (v *Vaultik) getTableCount(tableName string) (int64, error) { func (v *Vaultik) getTableCount(tableName string) (int64, error) {
if v.DB == nil { if v.DB == nil {
return 0, nil return 0, nil
} }
if !validTableNameRe.MatchString(tableName) {
return 0, fmt.Errorf("invalid table name: %q", tableName)
}
var count int64 var count int64
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName) query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)
err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&count) err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&count)