diff --git a/internal/cli/restore.go b/internal/cli/restore.go index c69bf6e..33e3618 100644 --- a/internal/cli/restore.go +++ b/internal/cli/restore.go @@ -57,76 +57,7 @@ Examples: vaultik restore --verify myhost_docs_2025-01-01T12:00:00Z /restore`, Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - snapshotID := args[0] - opts.TargetDir = args[1] - if len(args) > 2 { - opts.Paths = args[2:] - } - - // Use unified config resolution - configPath, err := ResolveConfigPath() - if err != nil { - return err - } - - // Use the app framework like other commands - rootFlags := GetRootFlags() - return RunWithApp(cmd.Context(), AppOptions{ - ConfigPath: configPath, - LogOptions: log.LogOptions{ - Verbose: rootFlags.Verbose, - Debug: rootFlags.Debug, - Quiet: rootFlags.Quiet, - }, - Modules: []fx.Option{ - fx.Provide(fx.Annotate( - func(g *globals.Globals, cfg *config.Config, - storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp { - return &RestoreApp{ - Globals: g, - Config: cfg, - Storage: storer, - Vaultik: v, - Shutdowner: shutdowner, - } - }, - )), - }, - Invokes: []fx.Option{ - fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - // Start the restore operation in a goroutine - go func() { - // Run the restore operation - restoreOpts := &vaultik.RestoreOptions{ - SnapshotID: snapshotID, - TargetDir: opts.TargetDir, - Paths: opts.Paths, - Verify: opts.Verify, - } - if err := app.Vaultik.Restore(restoreOpts); err != nil { - if err != context.Canceled { - log.Error("Restore operation failed", "error", err) - } - } - - // Shutdown the app when restore completes - if err := app.Shutdowner.Shutdown(); err != nil { - log.Error("Failed to shutdown", "error", err) - } - }() - return nil - }, - OnStop: func(ctx context.Context) error { - log.Debug("Stopping restore operation") - app.Vaultik.Cancel() - return nil - }, - }) - }), - }, - }) + return runRestore(cmd, args, opts) }, } @@ -134,3 +65,87 @@ Examples: return cmd } + +// runRestore parses arguments and runs the restore operation through the app framework +func runRestore(cmd *cobra.Command, args []string, opts *RestoreOptions) error { + snapshotID := args[0] + opts.TargetDir = args[1] + if len(args) > 2 { + opts.Paths = args[2:] + } + + // Use unified config resolution + configPath, err := ResolveConfigPath() + if err != nil { + return err + } + + // Use the app framework like other commands + rootFlags := GetRootFlags() + return RunWithApp(cmd.Context(), AppOptions{ + ConfigPath: configPath, + LogOptions: log.LogOptions{ + Verbose: rootFlags.Verbose, + Debug: rootFlags.Debug, + Quiet: rootFlags.Quiet, + }, + Modules: buildRestoreModules(), + Invokes: buildRestoreInvokes(snapshotID, opts), + }) +} + +// buildRestoreModules returns the fx.Options for dependency injection in restore +func buildRestoreModules() []fx.Option { + return []fx.Option{ + fx.Provide(fx.Annotate( + func(g *globals.Globals, cfg *config.Config, + storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp { + return &RestoreApp{ + Globals: g, + Config: cfg, + Storage: storer, + Vaultik: v, + Shutdowner: shutdowner, + } + }, + )), + } +} + +// buildRestoreInvokes returns the fx.Options that wire up the restore lifecycle +func buildRestoreInvokes(snapshotID string, opts *RestoreOptions) []fx.Option { + return []fx.Option{ + fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + // Start the restore operation in a goroutine + go func() { + // Run the restore operation + restoreOpts := &vaultik.RestoreOptions{ + SnapshotID: snapshotID, + TargetDir: opts.TargetDir, + Paths: opts.Paths, + Verify: opts.Verify, + } + if err := app.Vaultik.Restore(restoreOpts); err != nil { + if err != context.Canceled { + log.Error("Restore operation failed", "error", err) + } + } + + // Shutdown the app when restore completes + if err := app.Shutdowner.Shutdown(); err != nil { + log.Error("Failed to shutdown", "error", err) + } + }() + return nil + }, + OnStop: func(ctx context.Context) error { + log.Debug("Stopping restore operation") + app.Vaultik.Cancel() + return nil + }, + }) + }), + } +} diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index bb01ea1..883c572 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -227,12 +227,39 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st } }() + // Steps 1-5: Copy, clean, vacuum, compress, and read the database + finalData, tempDBPath, err := sm.prepareExportDB(ctx, dbPath, snapshotID, tempDir) + if err != nil { + return err + } + + // Step 6: Generate blob manifest (before closing temp DB) + blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) + if err != nil { + return fmt.Errorf("generating blob manifest: %w", err) + } + + // Step 7: Upload to S3 in snapshot subdirectory + if err := sm.uploadSnapshotArtifacts(ctx, snapshotID, finalData, blobManifest); err != nil { + return err + } + + log.Info("Uploaded snapshot metadata", + "snapshot_id", snapshotID, + "db_size", len(finalData), + "manifest_size", len(blobManifest)) + return nil +} + +// prepareExportDB copies, cleans, vacuums, and compresses the snapshot database for export. +// Returns the compressed data and the path to the temporary database (needed for manifest generation). +func (sm *SnapshotManager) prepareExportDB(ctx context.Context, dbPath, snapshotID, tempDir string) ([]byte, string, error) { // Step 1: Copy database to temp file // The main database should be closed at this point tempDBPath := filepath.Join(tempDir, "snapshot.db") log.Debug("Copying database to temporary location", "source", dbPath, "destination", tempDBPath) if err := sm.copyFile(dbPath, tempDBPath); err != nil { - return fmt.Errorf("copying database: %w", err) + return nil, "", fmt.Errorf("copying database: %w", err) } log.Debug("Database copy complete", "size", sm.getFileSize(tempDBPath)) @@ -240,7 +267,7 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st log.Debug("Cleaning temporary database", "snapshot_id", snapshotID) stats, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID) if err != nil { - return fmt.Errorf("cleaning snapshot database: %w", err) + return nil, "", fmt.Errorf("cleaning snapshot database: %w", err) } log.Info("Temporary database cleanup complete", "db_path", tempDBPath, @@ -255,14 +282,14 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st // Step 3: VACUUM the database to remove deleted data and compact // This is critical for security - ensures no stale/deleted data is uploaded if err := sm.vacuumDatabase(tempDBPath); err != nil { - return fmt.Errorf("vacuuming database: %w", err) + return nil, "", fmt.Errorf("vacuuming database: %w", err) } log.Debug("Database vacuumed", "size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath)))) // Step 4: Compress and encrypt the binary database file compressedPath := filepath.Join(tempDir, "db.zst.age") if err := sm.compressFile(tempDBPath, compressedPath); err != nil { - return fmt.Errorf("compressing database: %w", err) + return nil, "", fmt.Errorf("compressing database: %w", err) } log.Debug("Compression complete", "original_size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath))), @@ -271,49 +298,43 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st // Step 5: Read compressed and encrypted data for upload finalData, err := afero.ReadFile(sm.fs, compressedPath) if err != nil { - return fmt.Errorf("reading compressed dump: %w", err) + return nil, "", fmt.Errorf("reading compressed dump: %w", err) } - // Step 6: Generate blob manifest (before closing temp DB) - blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) - if err != nil { - return fmt.Errorf("generating blob manifest: %w", err) - } + return finalData, tempDBPath, nil +} - // Step 7: Upload to S3 in snapshot subdirectory +// uploadSnapshotArtifacts uploads the database backup and blob manifest to S3 +func (sm *SnapshotManager) uploadSnapshotArtifacts(ctx context.Context, snapshotID string, dbData, manifestData []byte) error { // Upload database backup (compressed and encrypted) dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) dbUploadStart := time.Now() - if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(finalData)); err != nil { + if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(dbData)); err != nil { return fmt.Errorf("uploading snapshot database: %w", err) } dbUploadDuration := time.Since(dbUploadStart) - dbUploadSpeed := float64(len(finalData)) * 8 / dbUploadDuration.Seconds() // bits per second + dbUploadSpeed := float64(len(dbData)) * 8 / dbUploadDuration.Seconds() // bits per second log.Info("Uploaded snapshot database", "path", dbKey, - "size", humanize.Bytes(uint64(len(finalData))), + "size", humanize.Bytes(uint64(len(dbData))), "duration", dbUploadDuration, "speed", humanize.SI(dbUploadSpeed, "bps")) // Upload blob manifest (compressed only, not encrypted) manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) manifestUploadStart := time.Now() - if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil { + if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(manifestData)); err != nil { return fmt.Errorf("uploading blob manifest: %w", err) } manifestUploadDuration := time.Since(manifestUploadStart) - manifestUploadSpeed := float64(len(blobManifest)) * 8 / manifestUploadDuration.Seconds() // bits per second + manifestUploadSpeed := float64(len(manifestData)) * 8 / manifestUploadDuration.Seconds() // bits per second log.Info("Uploaded blob manifest", "path", manifestKey, - "size", humanize.Bytes(uint64(len(blobManifest))), + "size", humanize.Bytes(uint64(len(manifestData))), "duration", manifestUploadDuration, "speed", humanize.SI(manifestUploadSpeed, "bps")) - log.Info("Uploaded snapshot metadata", - "snapshot_id", snapshotID, - "db_size", len(finalData), - "manifest_size", len(blobManifest)) return nil } diff --git a/internal/vaultik/restore.go b/internal/vaultik/restore.go index 20f7ba8..acf2ad8 100644 --- a/internal/vaultik/restore.go +++ b/internal/vaultik/restore.go @@ -55,15 +55,9 @@ type RestoreResult struct { func (v *Vaultik) Restore(opts *RestoreOptions) error { startTime := time.Now() - // Check for age_secret_key - if v.Config.AgeSecretKey == "" { - return fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'") - } - - // Parse the age identity - identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey) + identity, err := v.prepareRestoreIdentity() if err != nil { - return fmt.Errorf("parsing age secret key: %w", err) + return err } log.Info("Starting restore operation", @@ -115,10 +109,73 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { } // Step 5: Restore files + result, err := v.restoreAllFiles(files, repos, opts, identity, chunkToBlobMap) + if err != nil { + return err + } + + result.Duration = time.Since(startTime) + + log.Info("Restore complete", + "files_restored", result.FilesRestored, + "bytes_restored", humanize.Bytes(uint64(result.BytesRestored)), + "blobs_downloaded", result.BlobsDownloaded, + "bytes_downloaded", humanize.Bytes(uint64(result.BytesDownloaded)), + "duration", result.Duration, + ) + + v.printfStdout("Restored %d files (%s) in %s\n", + result.FilesRestored, + humanize.Bytes(uint64(result.BytesRestored)), + result.Duration.Round(time.Second), + ) + + if result.FilesFailed > 0 { + _, _ = fmt.Fprintf(v.Stdout, "\nWARNING: %d file(s) failed to restore:\n", result.FilesFailed) + for _, path := range result.FailedFiles { + _, _ = fmt.Fprintf(v.Stdout, " - %s\n", path) + } + } + + // Run verification if requested + if opts.Verify { + if err := v.handleRestoreVerification(repos, files, opts, result); err != nil { + return err + } + } + + if result.FilesFailed > 0 { + return fmt.Errorf("%d file(s) failed to restore", result.FilesFailed) + } + + return nil +} + +// prepareRestoreIdentity validates that an age secret key is configured and parses it +func (v *Vaultik) prepareRestoreIdentity() (age.Identity, error) { + if v.Config.AgeSecretKey == "" { + return nil, fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'") + } + + identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey) + if err != nil { + return nil, fmt.Errorf("parsing age secret key: %w", err) + } + return identity, nil +} + +// restoreAllFiles iterates over files and restores each one, tracking progress and failures +func (v *Vaultik) restoreAllFiles( + files []*database.File, + repos *database.Repositories, + opts *RestoreOptions, + identity age.Identity, + chunkToBlobMap map[string]*database.BlobChunk, +) (*RestoreResult, error) { result := &RestoreResult{} blobCache, err := newBlobDiskCache(4 * v.Config.BlobSizeLimit.Int64()) if err != nil { - return fmt.Errorf("creating blob cache: %w", err) + return nil, fmt.Errorf("creating blob cache: %w", err) } defer func() { _ = blobCache.Close() }() @@ -133,7 +190,7 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { for i, file := range files { if v.ctx.Err() != nil { - return v.ctx.Err() + return nil, v.ctx.Err() } if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil { @@ -165,53 +222,32 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { _ = bar.Finish() } - result.Duration = time.Since(startTime) + return result, nil +} - log.Info("Restore complete", - "files_restored", result.FilesRestored, - "bytes_restored", humanize.Bytes(uint64(result.BytesRestored)), - "blobs_downloaded", result.BlobsDownloaded, - "bytes_downloaded", humanize.Bytes(uint64(result.BytesDownloaded)), - "duration", result.Duration, - ) - - v.printfStdout("Restored %d files (%s) in %s\n", - result.FilesRestored, - humanize.Bytes(uint64(result.BytesRestored)), - result.Duration.Round(time.Second), - ) +// handleRestoreVerification runs post-restore verification if requested +func (v *Vaultik) handleRestoreVerification( + repos *database.Repositories, + files []*database.File, + opts *RestoreOptions, + result *RestoreResult, +) error { + if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil { + return fmt.Errorf("verification failed: %w", err) + } if result.FilesFailed > 0 { - _, _ = fmt.Fprintf(v.Stdout, "\nWARNING: %d file(s) failed to restore:\n", result.FilesFailed) + v.printfStdout("\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed) for _, path := range result.FailedFiles { - _, _ = fmt.Fprintf(v.Stdout, " - %s\n", path) + v.printfStdout(" - %s\n", path) } + return fmt.Errorf("%d files failed verification", result.FilesFailed) } - // Run verification if requested - if opts.Verify { - if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil { - return fmt.Errorf("verification failed: %w", err) - } - - if result.FilesFailed > 0 { - v.printfStdout("\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed) - for _, path := range result.FailedFiles { - v.printfStdout(" - %s\n", path) - } - return fmt.Errorf("%d files failed verification", result.FilesFailed) - } - - v.printfStdout("Verified %d files (%s)\n", - result.FilesVerified, - humanize.Bytes(uint64(result.BytesVerified)), - ) - } - - if result.FilesFailed > 0 { - return fmt.Errorf("%d file(s) failed to restore", result.FilesFailed) - } - + v.printfStdout("Verified %d files (%s)\n", + result.FilesVerified, + humanize.Bytes(uint64(result.BytesVerified)), + ) return nil } diff --git a/internal/vaultik/snapshot.go b/internal/vaultik/snapshot.go index 6943cd8..21904bf 100644 --- a/internal/vaultik/snapshot.go +++ b/internal/vaultik/snapshot.go @@ -563,26 +563,9 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool) return snapshots[i].Timestamp.After(snapshots[j].Timestamp) }) - var toDelete []SnapshotInfo - - if keepLatest { - // Keep only the most recent snapshot - if len(snapshots) > 1 { - toDelete = snapshots[1:] - } - } else if olderThan != "" { - // Parse duration - duration, err := parseDuration(olderThan) - if err != nil { - return fmt.Errorf("invalid duration: %w", err) - } - - cutoff := time.Now().UTC().Add(-duration) - for _, snap := range snapshots { - if snap.Timestamp.Before(cutoff) { - toDelete = append(toDelete, snap) - } - } + toDelete, err := v.collectSnapshotsToPurge(snapshots, keepLatest, olderThan) + if err != nil { + return err } if len(toDelete) == 0 { @@ -590,6 +573,41 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool) return nil } + return v.confirmAndExecutePurge(toDelete, force) +} + +// collectSnapshotsToPurge determines which snapshots to delete based on retention criteria +func (v *Vaultik) collectSnapshotsToPurge(snapshots []SnapshotInfo, keepLatest bool, olderThan string) ([]SnapshotInfo, error) { + if keepLatest { + // Keep only the most recent snapshot + if len(snapshots) > 1 { + return snapshots[1:], nil + } + return nil, nil + } + + if olderThan != "" { + // Parse duration + duration, err := parseDuration(olderThan) + if err != nil { + return nil, fmt.Errorf("invalid duration: %w", err) + } + + cutoff := time.Now().UTC().Add(-duration) + var toDelete []SnapshotInfo + for _, snap := range snapshots { + if snap.Timestamp.Before(cutoff) { + toDelete = append(toDelete, snap) + } + } + return toDelete, nil + } + + return nil, nil +} + +// confirmAndExecutePurge shows deletion candidates, confirms with user, and deletes snapshots +func (v *Vaultik) confirmAndExecutePurge(toDelete []SnapshotInfo, force bool) error { // Show what will be deleted v.printfStdout("The following snapshots will be deleted:\n\n") for _, snap := range toDelete { @@ -655,29 +673,7 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio result.Mode = "deep" } - // Parse snapshot ID to extract timestamp - parts := strings.Split(snapshotID, "-") - var snapshotTime time.Time - if len(parts) >= 3 { - // Format: hostname-YYYYMMDD-HHMMSSZ - dateStr := parts[len(parts)-2] - timeStr := parts[len(parts)-1] - if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") { - timeStr = timeStr[:6] // Remove Z - timestamp, err := time.Parse("20060102150405", dateStr+timeStr) - if err == nil { - snapshotTime = timestamp - } - } - } - - if !opts.JSON { - v.printfStdout("Verifying snapshot %s\n", snapshotID) - if !snapshotTime.IsZero() { - v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) - } - v.printlnStdout() - } + v.printVerifyHeader(snapshotID, opts) // Download and parse manifest manifest, err := v.downloadManifest(snapshotID) @@ -708,10 +704,40 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio v.printfStdout("Checking blob existence...\n") } - missing := 0 - verified := 0 - missingSize := int64(0) + result.Verified, result.Missing, result.MissingSize = v.verifyManifestBlobsExist(manifest, opts) + return v.formatVerifyResult(result, manifest, opts) +} + +// printVerifyHeader prints the snapshot ID and parsed timestamp for verification output +func (v *Vaultik) printVerifyHeader(snapshotID string, opts *VerifyOptions) { + // Parse snapshot ID to extract timestamp + parts := strings.Split(snapshotID, "-") + var snapshotTime time.Time + if len(parts) >= 3 { + // Format: hostname-YYYYMMDD-HHMMSSZ + dateStr := parts[len(parts)-2] + timeStr := parts[len(parts)-1] + if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") { + timeStr = timeStr[:6] // Remove Z + timestamp, err := time.Parse("20060102150405", dateStr+timeStr) + if err == nil { + snapshotTime = timestamp + } + } + } + + if !opts.JSON { + v.printfStdout("Verifying snapshot %s\n", snapshotID) + if !snapshotTime.IsZero() { + v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) + } + v.printlnStdout() + } +} + +// verifyManifestBlobsExist checks that each blob in the manifest exists in storage +func (v *Vaultik) verifyManifestBlobsExist(manifest *snapshot.Manifest, opts *VerifyOptions) (verified, missing int, missingSize int64) { for _, blob := range manifest.Blobs { blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash) @@ -727,15 +753,15 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio verified++ } } + return verified, missing, missingSize +} - result.Verified = verified - result.Missing = missing - result.MissingSize = missingSize - +// formatVerifyResult outputs the final verification results as JSON or human-readable text +func (v *Vaultik) formatVerifyResult(result *VerifyResult, manifest *snapshot.Manifest, opts *VerifyOptions) error { if opts.JSON { - if missing > 0 { + if result.Missing > 0 { result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("%d blobs are missing", missing) + result.ErrorMessage = fmt.Sprintf("%d blobs are missing", result.Missing) } else { result.Status = "ok" } @@ -743,20 +769,19 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio } v.printfStdout("\nVerification complete:\n") - v.printfStdout(" Verified: %d blobs (%s)\n", verified, - humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize))) - if missing > 0 { - v.printfStdout(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize))) + v.printfStdout(" Verified: %d blobs (%s)\n", result.Verified, + humanize.Bytes(uint64(manifest.TotalCompressedSize-result.MissingSize))) + if result.Missing > 0 { + v.printfStdout(" Missing: %d blobs (%s)\n", result.Missing, humanize.Bytes(uint64(result.MissingSize))) } else { v.printfStdout(" Missing: 0 blobs\n") } v.printfStdout(" Status: ") - if missing > 0 { - v.printfStdout("FAILED - %d blobs are missing\n", missing) - return fmt.Errorf("%d blobs are missing", missing) - } else { - v.printfStdout("OK - All blobs verified\n") + if result.Missing > 0 { + v.printfStdout("FAILED - %d blobs are missing\n", result.Missing) + return fmt.Errorf("%d blobs are missing", result.Missing) } + v.printfStdout("OK - All blobs verified\n") return nil } @@ -952,9 +977,27 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov // RemoveAllSnapshots removes all snapshots from local database and optionally from remote func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) { - result := &RemoveResult{} + snapshotIDs, err := v.listAllRemoteSnapshotIDs() + if err != nil { + return nil, err + } - // List all snapshots + if len(snapshotIDs) == 0 { + if !opts.JSON { + v.printlnStdout("No snapshots found") + } + return &RemoveResult{}, nil + } + + if opts.DryRun { + return v.handleRemoveAllDryRun(snapshotIDs, opts) + } + + return v.executeRemoveAll(snapshotIDs, opts) +} + +// listAllRemoteSnapshotIDs collects all unique snapshot IDs from remote storage +func (v *Vaultik) listAllRemoteSnapshotIDs() ([]string, error) { log.Info("Listing all snapshots") objectCh := v.Storage.ListStream(v.ctx, "metadata/") @@ -986,32 +1029,33 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) } } - if len(snapshotIDs) == 0 { - if !opts.JSON { - v.printlnStdout("No snapshots found") - } - return result, nil - } + return snapshotIDs, nil +} - if opts.DryRun { - result.DryRun = true - result.SnapshotsRemoved = snapshotIDs - if !opts.JSON { - v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs)) - for _, id := range snapshotIDs { - v.printfStdout(" %s\n", id) - } - if opts.Remote { - v.printlnStdout("Would also remove from remote storage") - } - v.printlnStdout("[Dry run - no changes made]") - } - if opts.JSON { - return result, v.outputRemoveJSON(result) - } - return result, nil +// handleRemoveAllDryRun handles the dry-run mode for removing all snapshots +func (v *Vaultik) handleRemoveAllDryRun(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) { + result := &RemoveResult{ + DryRun: true, + SnapshotsRemoved: snapshotIDs, } + if !opts.JSON { + v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs)) + for _, id := range snapshotIDs { + v.printfStdout(" %s\n", id) + } + if opts.Remote { + v.printlnStdout("Would also remove from remote storage") + } + v.printlnStdout("[Dry run - no changes made]") + } + if opts.JSON { + return result, v.outputRemoveJSON(result) + } + return result, nil +} +// executeRemoveAll removes all snapshots from local database and optionally from remote storage +func (v *Vaultik) executeRemoveAll(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) { // --all requires --force if !opts.Force { return nil, fmt.Errorf("--all requires --force") @@ -1019,6 +1063,7 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) log.Info("Removing all snapshots", "count", len(snapshotIDs)) + result := &RemoveResult{} for _, snapshotID := range snapshotIDs { log.Info("Removing snapshot", "snapshot_id", snapshotID) diff --git a/internal/vaultik/verify.go b/internal/vaultik/verify.go index b5f5a49..732d70b 100644 --- a/internal/vaultik/verify.go +++ b/internal/vaultik/verify.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/hex" "fmt" + "hash" "io" "os" "time" @@ -301,7 +302,27 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { } defer decompressor.Close() - // Query blob chunks from database to get offsets and lengths + chunkCount, err := v.verifyBlobChunks(db, blobInfo.Hash, decompressor) + if err != nil { + return err + } + + if err := v.verifyBlobFinalIntegrity(decompressor, blobHasher, blobInfo.Hash); err != nil { + return err + } + + log.Info("Blob verified", + "hash", blobInfo.Hash[:16]+"...", + "chunks", chunkCount, + "size", humanize.Bytes(uint64(blobInfo.CompressedSize)), + ) + + return nil +} + +// verifyBlobChunks queries blob chunks from the database and verifies each chunk's hash +// against the decompressed blob stream +func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.Reader) (int, error) { query := ` SELECT bc.chunk_hash, bc.offset, bc.length FROM blob_chunks bc @@ -309,9 +330,9 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { WHERE b.blob_hash = ? ORDER BY bc.offset ` - rows, err := db.QueryContext(v.ctx, query, blobInfo.Hash) + rows, err := db.QueryContext(v.ctx, query, blobHash) if err != nil { - return fmt.Errorf("failed to query blob chunks: %w", err) + return 0, fmt.Errorf("failed to query blob chunks: %w", err) } defer func() { _ = rows.Close() }() @@ -324,12 +345,12 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { var chunkHash string var offset, length int64 if err := rows.Scan(&chunkHash, &offset, &length); err != nil { - return fmt.Errorf("failed to scan chunk row: %w", err) + return 0, fmt.Errorf("failed to scan chunk row: %w", err) } // Verify chunk ordering if offset <= lastOffset { - return fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) + return 0, fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) } lastOffset = offset @@ -338,7 +359,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Skip to the correct offset skipBytes := offset - totalRead if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil { - return fmt.Errorf("failed to skip to offset %d: %w", offset, err) + return 0, fmt.Errorf("failed to skip to offset %d: %w", offset, err) } totalRead = offset } @@ -346,7 +367,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Read chunk data chunkData := make([]byte, length) if _, err := io.ReadFull(decompressor, chunkData); err != nil { - return fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) + return 0, fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) } totalRead += length @@ -356,7 +377,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { calculatedHash := hex.EncodeToString(hasher.Sum(nil)) if calculatedHash != chunkHash { - return fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", + return 0, fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", offset, calculatedHash, chunkHash) } @@ -364,9 +385,15 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { } if err := rows.Err(); err != nil { - return fmt.Errorf("error iterating blob chunks: %w", err) + return 0, fmt.Errorf("error iterating blob chunks: %w", err) } + return chunkCount, nil +} + +// verifyBlobFinalIntegrity checks that no trailing data exists in the decompressed stream +// and that the encrypted blob hash matches the expected value +func (v *Vaultik) verifyBlobFinalIntegrity(decompressor io.Reader, blobHasher hash.Hash, expectedHash string) error { // Verify no remaining data in blob - if chunk list is accurate, blob should be fully consumed remaining, err := io.Copy(io.Discard, decompressor) if err != nil { @@ -378,17 +405,11 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Verify blob hash matches the encrypted data we downloaded calculatedBlobHash := hex.EncodeToString(blobHasher.Sum(nil)) - if calculatedBlobHash != blobInfo.Hash { + if calculatedBlobHash != expectedHash { return fmt.Errorf("blob hash mismatch: calculated %s, expected %s", - calculatedBlobHash, blobInfo.Hash) + calculatedBlobHash, expectedHash) } - log.Info("Blob verified", - "hash", blobInfo.Hash[:16]+"...", - "chunks", chunkCount, - "size", humanize.Bytes(uint64(blobInfo.CompressedSize)), - ) - return nil }