diff --git a/internal/blob/packer.go b/internal/blob/packer.go index c5284ec..7edf15b 100644 --- a/internal/blob/packer.go +++ b/internal/blob/packer.go @@ -361,101 +361,23 @@ func (p *Packer) finalizeCurrentBlob() error { return nil } - // Close blobgen writer to flush all data - if err := p.currentBlob.writer.Close(); err != nil { - p.cleanupTempFile() - return fmt.Errorf("closing blobgen writer: %w", err) - } - - // Sync file to ensure all data is written - if err := p.currentBlob.tempFile.Sync(); err != nil { - p.cleanupTempFile() - return fmt.Errorf("syncing temp file: %w", err) - } - - // Get the final size (encrypted if applicable) - finalSize, err := p.currentBlob.tempFile.Seek(0, io.SeekCurrent) + blobHash, finalSize, err := p.closeBlobWriter() if err != nil { - p.cleanupTempFile() - return fmt.Errorf("getting file size: %w", err) + return err } - // Reset to beginning for reading - if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { - p.cleanupTempFile() - return fmt.Errorf("seeking to start: %w", err) - } + chunkRefs := p.buildChunkRefs() - // Get hash from blobgen writer (of final encrypted data) - finalHash := p.currentBlob.writer.Sum256() - blobHash := hex.EncodeToString(finalHash) - - // Create chunk references with offsets - chunkRefs := make([]*BlobChunkRef, 0, len(p.currentBlob.chunks)) - - for _, chunk := range p.currentBlob.chunks { - chunkRefs = append(chunkRefs, &BlobChunkRef{ - ChunkHash: chunk.Hash, - Offset: chunk.Offset, - Length: chunk.Size, - }) - } - - // Get pending chunks (will be inserted to DB and reported to handler) chunksToInsert := p.pendingChunks - p.pendingChunks = nil // Clear pending list + p.pendingChunks = nil - // Insert pending chunks, blob_chunks, and update blob in a single transaction - if p.repos != nil { - blobIDTyped, parseErr := types.ParseBlobID(p.currentBlob.id) - if parseErr != nil { - p.cleanupTempFile() - return fmt.Errorf("parsing blob ID: %w", parseErr) - } - err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { - // First insert all pending chunks (required for blob_chunks FK) - for _, chunk := range chunksToInsert { - dbChunk := &database.Chunk{ - ChunkHash: types.ChunkHash(chunk.Hash), - Size: chunk.Size, - } - if err := p.repos.Chunks.Create(ctx, tx, dbChunk); err != nil { - return fmt.Errorf("creating chunk: %w", err) - } - } - - // Insert all blob_chunk records in batch - for _, chunk := range p.currentBlob.chunks { - blobChunk := &database.BlobChunk{ - BlobID: blobIDTyped, - ChunkHash: types.ChunkHash(chunk.Hash), - Offset: chunk.Offset, - Length: chunk.Size, - } - if err := p.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil { - return fmt.Errorf("creating blob_chunk: %w", err) - } - } - - // Update blob record with final hash and sizes - return p.repos.Blobs.UpdateFinished(ctx, tx, p.currentBlob.id, blobHash, - p.currentBlob.size, finalSize) - }) - if err != nil { - p.cleanupTempFile() - return fmt.Errorf("finalizing blob transaction: %w", err) - } - - log.Debug("Committed blob transaction", - "chunks_inserted", len(chunksToInsert), - "blob_chunks_inserted", len(p.currentBlob.chunks)) + if err := p.commitBlobToDatabase(blobHash, finalSize, chunksToInsert); err != nil { + return err } - // Create finished blob finished := &FinishedBlob{ ID: p.currentBlob.id, Hash: blobHash, - Data: nil, // We don't load data into memory anymore Chunks: chunkRefs, CreatedTS: p.currentBlob.startTime, Uncompressed: p.currentBlob.size, @@ -464,28 +386,105 @@ func (p *Packer) finalizeCurrentBlob() error { compressionRatio := float64(finished.Compressed) / float64(finished.Uncompressed) log.Info("Finalized blob (compressed and encrypted)", - "hash", blobHash, - "chunks", len(chunkRefs), - "uncompressed", finished.Uncompressed, - "compressed", finished.Compressed, + "hash", blobHash, "chunks", len(chunkRefs), + "uncompressed", finished.Uncompressed, "compressed", finished.Compressed, "ratio", fmt.Sprintf("%.2f", compressionRatio), "duration", time.Since(p.currentBlob.startTime)) - // Collect inserted chunk hashes for the scanner to track var insertedChunkHashes []string for _, chunk := range chunksToInsert { insertedChunkHashes = append(insertedChunkHashes, chunk.Hash) } - // Call blob handler if set + return p.deliverFinishedBlob(finished, insertedChunkHashes) +} + +// closeBlobWriter closes the writer, syncs to disk, and returns the blob hash and final size +func (p *Packer) closeBlobWriter() (string, int64, error) { + if err := p.currentBlob.writer.Close(); err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("closing blobgen writer: %w", err) + } + if err := p.currentBlob.tempFile.Sync(); err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("syncing temp file: %w", err) + } + + finalSize, err := p.currentBlob.tempFile.Seek(0, io.SeekCurrent) + if err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("getting file size: %w", err) + } + if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("seeking to start: %w", err) + } + + finalHash := p.currentBlob.writer.Sum256() + return hex.EncodeToString(finalHash), finalSize, nil +} + +// buildChunkRefs creates BlobChunkRef entries from the current blob's chunks +func (p *Packer) buildChunkRefs() []*BlobChunkRef { + refs := make([]*BlobChunkRef, 0, len(p.currentBlob.chunks)) + for _, chunk := range p.currentBlob.chunks { + refs = append(refs, &BlobChunkRef{ + ChunkHash: chunk.Hash, Offset: chunk.Offset, Length: chunk.Size, + }) + } + return refs +} + +// commitBlobToDatabase inserts pending chunks, blob_chunks, and updates the blob record +func (p *Packer) commitBlobToDatabase(blobHash string, finalSize int64, chunksToInsert []PendingChunk) error { + if p.repos == nil { + return nil + } + + blobIDTyped, parseErr := types.ParseBlobID(p.currentBlob.id) + if parseErr != nil { + p.cleanupTempFile() + return fmt.Errorf("parsing blob ID: %w", parseErr) + } + + err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + for _, chunk := range chunksToInsert { + dbChunk := &database.Chunk{ChunkHash: types.ChunkHash(chunk.Hash), Size: chunk.Size} + if err := p.repos.Chunks.Create(ctx, tx, dbChunk); err != nil { + return fmt.Errorf("creating chunk: %w", err) + } + } + + for _, chunk := range p.currentBlob.chunks { + blobChunk := &database.BlobChunk{ + BlobID: blobIDTyped, ChunkHash: types.ChunkHash(chunk.Hash), + Offset: chunk.Offset, Length: chunk.Size, + } + if err := p.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil { + return fmt.Errorf("creating blob_chunk: %w", err) + } + } + + return p.repos.Blobs.UpdateFinished(ctx, tx, p.currentBlob.id, blobHash, p.currentBlob.size, finalSize) + }) + if err != nil { + p.cleanupTempFile() + return fmt.Errorf("finalizing blob transaction: %w", err) + } + + log.Debug("Committed blob transaction", + "chunks_inserted", len(chunksToInsert), "blob_chunks_inserted", len(p.currentBlob.chunks)) + return nil +} + +// deliverFinishedBlob passes the blob to the handler or stores it internally +func (p *Packer) deliverFinishedBlob(finished *FinishedBlob, insertedChunkHashes []string) error { if p.blobHandler != nil { - // Reset file position for handler if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { p.cleanupTempFile() return fmt.Errorf("seeking for handler: %w", err) } - // Create a blob reader that includes the data stream blobWithReader := &BlobWithReader{ FinishedBlob: finished, Reader: p.currentBlob.tempFile, @@ -497,30 +496,26 @@ func (p *Packer) finalizeCurrentBlob() error { p.cleanupTempFile() return fmt.Errorf("blob handler failed: %w", err) } - // Note: blob handler is responsible for closing/cleaning up temp file - p.currentBlob = nil - } else { - log.Debug("No blob handler callback configured", "blob_hash", blobHash[:8]+"...") - // No handler, need to read data for legacy behavior - if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { - p.cleanupTempFile() - return fmt.Errorf("seeking to read data: %w", err) - } - - data, err := io.ReadAll(p.currentBlob.tempFile) - if err != nil { - p.cleanupTempFile() - return fmt.Errorf("reading blob data: %w", err) - } - finished.Data = data - - p.finishedBlobs = append(p.finishedBlobs, finished) - - // Cleanup - p.cleanupTempFile() p.currentBlob = nil + return nil } + // No handler - read data for legacy behavior + log.Debug("No blob handler callback configured", "blob_hash", finished.Hash[:8]+"...") + if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { + p.cleanupTempFile() + return fmt.Errorf("seeking to read data: %w", err) + } + + data, err := io.ReadAll(p.currentBlob.tempFile) + if err != nil { + p.cleanupTempFile() + return fmt.Errorf("reading blob data: %w", err) + } + finished.Data = data + p.finishedBlobs = append(p.finishedBlobs, finished) + p.cleanupTempFile() + p.currentBlob = nil return nil } diff --git a/internal/cli/restore.go b/internal/cli/restore.go index c69bf6e..33e3618 100644 --- a/internal/cli/restore.go +++ b/internal/cli/restore.go @@ -57,76 +57,7 @@ Examples: vaultik restore --verify myhost_docs_2025-01-01T12:00:00Z /restore`, Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - snapshotID := args[0] - opts.TargetDir = args[1] - if len(args) > 2 { - opts.Paths = args[2:] - } - - // Use unified config resolution - configPath, err := ResolveConfigPath() - if err != nil { - return err - } - - // Use the app framework like other commands - rootFlags := GetRootFlags() - return RunWithApp(cmd.Context(), AppOptions{ - ConfigPath: configPath, - LogOptions: log.LogOptions{ - Verbose: rootFlags.Verbose, - Debug: rootFlags.Debug, - Quiet: rootFlags.Quiet, - }, - Modules: []fx.Option{ - fx.Provide(fx.Annotate( - func(g *globals.Globals, cfg *config.Config, - storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp { - return &RestoreApp{ - Globals: g, - Config: cfg, - Storage: storer, - Vaultik: v, - Shutdowner: shutdowner, - } - }, - )), - }, - Invokes: []fx.Option{ - fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - // Start the restore operation in a goroutine - go func() { - // Run the restore operation - restoreOpts := &vaultik.RestoreOptions{ - SnapshotID: snapshotID, - TargetDir: opts.TargetDir, - Paths: opts.Paths, - Verify: opts.Verify, - } - if err := app.Vaultik.Restore(restoreOpts); err != nil { - if err != context.Canceled { - log.Error("Restore operation failed", "error", err) - } - } - - // Shutdown the app when restore completes - if err := app.Shutdowner.Shutdown(); err != nil { - log.Error("Failed to shutdown", "error", err) - } - }() - return nil - }, - OnStop: func(ctx context.Context) error { - log.Debug("Stopping restore operation") - app.Vaultik.Cancel() - return nil - }, - }) - }), - }, - }) + return runRestore(cmd, args, opts) }, } @@ -134,3 +65,87 @@ Examples: return cmd } + +// runRestore parses arguments and runs the restore operation through the app framework +func runRestore(cmd *cobra.Command, args []string, opts *RestoreOptions) error { + snapshotID := args[0] + opts.TargetDir = args[1] + if len(args) > 2 { + opts.Paths = args[2:] + } + + // Use unified config resolution + configPath, err := ResolveConfigPath() + if err != nil { + return err + } + + // Use the app framework like other commands + rootFlags := GetRootFlags() + return RunWithApp(cmd.Context(), AppOptions{ + ConfigPath: configPath, + LogOptions: log.LogOptions{ + Verbose: rootFlags.Verbose, + Debug: rootFlags.Debug, + Quiet: rootFlags.Quiet, + }, + Modules: buildRestoreModules(), + Invokes: buildRestoreInvokes(snapshotID, opts), + }) +} + +// buildRestoreModules returns the fx.Options for dependency injection in restore +func buildRestoreModules() []fx.Option { + return []fx.Option{ + fx.Provide(fx.Annotate( + func(g *globals.Globals, cfg *config.Config, + storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp { + return &RestoreApp{ + Globals: g, + Config: cfg, + Storage: storer, + Vaultik: v, + Shutdowner: shutdowner, + } + }, + )), + } +} + +// buildRestoreInvokes returns the fx.Options that wire up the restore lifecycle +func buildRestoreInvokes(snapshotID string, opts *RestoreOptions) []fx.Option { + return []fx.Option{ + fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + // Start the restore operation in a goroutine + go func() { + // Run the restore operation + restoreOpts := &vaultik.RestoreOptions{ + SnapshotID: snapshotID, + TargetDir: opts.TargetDir, + Paths: opts.Paths, + Verify: opts.Verify, + } + if err := app.Vaultik.Restore(restoreOpts); err != nil { + if err != context.Canceled { + log.Error("Restore operation failed", "error", err) + } + } + + // Shutdown the app when restore completes + if err := app.Shutdowner.Shutdown(); err != nil { + log.Error("Failed to shutdown", "error", err) + } + }() + return nil + }, + OnStop: func(ctx context.Context) error { + log.Debug("Stopping restore operation") + app.Vaultik.Cancel() + return nil + }, + }) + }), + } +} diff --git a/internal/snapshot/scanner.go b/internal/snapshot/scanner.go index ca403b4..8804dc2 100644 --- a/internal/snapshot/scanner.go +++ b/internal/snapshot/scanner.go @@ -180,18 +180,10 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc } // Phase 0: Load known files and chunks from database into memory for fast lookup - fmt.Println("Loading known files from database...") - knownFiles, err := s.loadKnownFiles(ctx, path) + knownFiles, err := s.loadDatabaseState(ctx, path) if err != nil { - return nil, fmt.Errorf("loading known files: %w", err) + return nil, err } - fmt.Printf("Loaded %s known files from database\n", formatNumber(len(knownFiles))) - - fmt.Println("Loading known chunks from database...") - if err := s.loadKnownChunks(ctx); err != nil { - return nil, fmt.Errorf("loading known chunks: %w", err) - } - fmt.Printf("Loaded %s known chunks from database\n", formatNumber(len(s.knownChunks))) // Phase 1: Scan directory, collect files to process, and track existing files // (builds existingFiles map during walk to avoid double traversal) @@ -216,36 +208,8 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc } } - // Calculate total size to process - var totalSizeToProcess int64 - for _, file := range filesToProcess { - totalSizeToProcess += file.FileInfo.Size() - } - - // Update progress with total size and file count - if s.progress != nil { - s.progress.SetTotalSize(totalSizeToProcess) - s.progress.GetStats().TotalFiles.Store(int64(len(filesToProcess))) - } - - log.Info("Phase 1 complete", - "total_files", len(filesToProcess), - "total_size", humanize.Bytes(uint64(totalSizeToProcess)), - "files_skipped", result.FilesSkipped, - "bytes_skipped", humanize.Bytes(uint64(result.BytesSkipped))) - - // Print scan summary - fmt.Printf("Scan complete: %s examined (%s), %s to process (%s)", - formatNumber(result.FilesScanned), - humanize.Bytes(uint64(totalSizeToProcess+result.BytesSkipped)), - formatNumber(len(filesToProcess)), - humanize.Bytes(uint64(totalSizeToProcess))) - if result.FilesDeleted > 0 { - fmt.Printf(", %s deleted (%s)", - formatNumber(result.FilesDeleted), - humanize.Bytes(uint64(result.BytesDeleted))) - } - fmt.Println() + // Summarize scan phase results and update progress + s.summarizeScanPhase(result, filesToProcess) // Phase 2: Process files and create chunks if len(filesToProcess) > 0 { @@ -259,7 +223,66 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc log.Info("Phase 2/3: Skipping (no files need processing, metadata-only snapshot)") } - // Get final stats from packer + // Finalize result with blob statistics + s.finalizeScanResult(ctx, result) + + return result, nil +} + +// loadDatabaseState loads known files and chunks from the database into memory for fast lookup +// This avoids per-file and per-chunk database queries during the scan and process phases +func (s *Scanner) loadDatabaseState(ctx context.Context, path string) (map[string]*database.File, error) { + fmt.Println("Loading known files from database...") + knownFiles, err := s.loadKnownFiles(ctx, path) + if err != nil { + return nil, fmt.Errorf("loading known files: %w", err) + } + fmt.Printf("Loaded %s known files from database\n", formatNumber(len(knownFiles))) + + fmt.Println("Loading known chunks from database...") + if err := s.loadKnownChunks(ctx); err != nil { + return nil, fmt.Errorf("loading known chunks: %w", err) + } + fmt.Printf("Loaded %s known chunks from database\n", formatNumber(len(s.knownChunks))) + + return knownFiles, nil +} + +// summarizeScanPhase calculates total size to process, updates progress tracking, +// and prints the scan phase summary with file counts and sizes +func (s *Scanner) summarizeScanPhase(result *ScanResult, filesToProcess []*FileToProcess) { + var totalSizeToProcess int64 + for _, file := range filesToProcess { + totalSizeToProcess += file.FileInfo.Size() + } + + if s.progress != nil { + s.progress.SetTotalSize(totalSizeToProcess) + s.progress.GetStats().TotalFiles.Store(int64(len(filesToProcess))) + } + + log.Info("Phase 1 complete", + "total_files", len(filesToProcess), + "total_size", humanize.Bytes(uint64(totalSizeToProcess)), + "files_skipped", result.FilesSkipped, + "bytes_skipped", humanize.Bytes(uint64(result.BytesSkipped))) + + fmt.Printf("Scan complete: %s examined (%s), %s to process (%s)", + formatNumber(result.FilesScanned), + humanize.Bytes(uint64(totalSizeToProcess+result.BytesSkipped)), + formatNumber(len(filesToProcess)), + humanize.Bytes(uint64(totalSizeToProcess))) + if result.FilesDeleted > 0 { + fmt.Printf(", %s deleted (%s)", + formatNumber(result.FilesDeleted), + humanize.Bytes(uint64(result.BytesDeleted))) + } + fmt.Println() +} + +// finalizeScanResult populates final blob statistics in the scan result +// by querying the packer and database for blob/upload counts +func (s *Scanner) finalizeScanResult(ctx context.Context, result *ScanResult) { blobs := s.packer.GetFinishedBlobs() result.BlobsCreated += len(blobs) @@ -276,7 +299,6 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc } result.EndTime = time.Now().UTC() - return result, nil } // loadKnownFiles loads all known files from the database into a map for fast lookup @@ -424,12 +446,38 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { flushStart := time.Now() log.Debug("flushCompletedPendingFiles: starting") + // Partition pending files into those ready to flush and those still waiting + canFlush, stillPendingCount := s.partitionPendingByChunkStatus() + + if len(canFlush) == 0 { + log.Debug("flushCompletedPendingFiles: nothing to flush") + return nil + } + + log.Debug("Flushing completed files after blob finalize", + "files_to_flush", len(canFlush), + "files_still_pending", stillPendingCount) + + // Collect all data for batch operations + allFiles, allFileIDs, allFileChunks, allChunkFiles := s.collectBatchFlushData(canFlush) + + // Execute the batch flush in a single transaction + log.Debug("flushCompletedPendingFiles: starting transaction") + txStart := time.Now() + err := s.executeBatchFileFlush(ctx, allFiles, allFileIDs, allFileChunks, allChunkFiles) + log.Debug("flushCompletedPendingFiles: transaction done", "duration", time.Since(txStart)) + log.Debug("flushCompletedPendingFiles: total duration", "duration", time.Since(flushStart)) + return err +} + +// partitionPendingByChunkStatus separates pending files into those whose chunks +// are all committed to DB (ready to flush) and those still waiting on pending chunks. +// Updates s.pendingFiles to contain only the still-pending files. +func (s *Scanner) partitionPendingByChunkStatus() (canFlush []pendingFileData, stillPendingCount int) { log.Debug("flushCompletedPendingFiles: acquiring pendingFilesMu lock") s.pendingFilesMu.Lock() log.Debug("flushCompletedPendingFiles: acquired lock", "pending_files", len(s.pendingFiles)) - // Separate files into complete (can flush) and incomplete (keep pending) - var canFlush []pendingFileData var stillPending []pendingFileData log.Debug("flushCompletedPendingFiles: checking which files can flush") @@ -454,18 +502,15 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { s.pendingFilesMu.Unlock() log.Debug("flushCompletedPendingFiles: released lock") - if len(canFlush) == 0 { - log.Debug("flushCompletedPendingFiles: nothing to flush") - return nil - } + return canFlush, len(stillPending) +} - log.Debug("Flushing completed files after blob finalize", - "files_to_flush", len(canFlush), - "files_still_pending", len(stillPending)) - - // Collect all data for batch operations +// collectBatchFlushData aggregates file records, IDs, file-chunk mappings, and chunk-file +// mappings from the given pending file data for efficient batch database operations +func (s *Scanner) collectBatchFlushData(canFlush []pendingFileData) ([]*database.File, []types.FileID, []database.FileChunk, []database.ChunkFile) { log.Debug("flushCompletedPendingFiles: collecting data for batch ops") collectStart := time.Now() + var allFileChunks []database.FileChunk var allChunkFiles []database.ChunkFile var allFileIDs []types.FileID @@ -477,16 +522,20 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { allFileIDs = append(allFileIDs, data.file.ID) allFiles = append(allFiles, data.file) } + log.Debug("flushCompletedPendingFiles: collected data", "duration", time.Since(collectStart), "file_chunks", len(allFileChunks), "chunk_files", len(allChunkFiles), "files", len(allFiles)) - // Flush the complete files using batch operations - log.Debug("flushCompletedPendingFiles: starting transaction") - txStart := time.Now() - err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { + return allFiles, allFileIDs, allFileChunks, allChunkFiles +} + +// executeBatchFileFlush writes all collected file data to the database in a single transaction, +// including deleting old mappings, creating file records, and adding snapshot associations +func (s *Scanner) executeBatchFileFlush(ctx context.Context, allFiles []*database.File, allFileIDs []types.FileID, allFileChunks []database.FileChunk, allChunkFiles []database.ChunkFile) error { + return s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { log.Debug("flushCompletedPendingFiles: inside transaction") // Batch delete old file_chunks and chunk_files @@ -539,9 +588,6 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { log.Debug("flushCompletedPendingFiles: transaction complete") return nil }) - log.Debug("flushCompletedPendingFiles: transaction done", "duration", time.Since(txStart)) - log.Debug("flushCompletedPendingFiles: total duration", "duration", time.Since(flushStart)) - return err } // ScanPhaseResult contains the results of the scan phase @@ -623,62 +669,11 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult mu.Unlock() // Update result stats - if needsProcessing { - result.BytesScanned += info.Size() - if s.progress != nil { - s.progress.GetStats().BytesScanned.Add(info.Size()) - } - } else { - result.FilesSkipped++ - result.BytesSkipped += info.Size() - if s.progress != nil { - s.progress.GetStats().FilesSkipped.Add(1) - s.progress.GetStats().BytesSkipped.Add(info.Size()) - } - } - result.FilesScanned++ - if s.progress != nil { - s.progress.GetStats().FilesScanned.Add(1) - } + s.updateScanEntryStats(result, needsProcessing, info) // Output periodic status if time.Since(lastStatusTime) >= statusInterval { - elapsed := time.Since(startTime) - rate := float64(filesScanned) / elapsed.Seconds() - - // Build status line - use estimate if available (not first backup) - if estimatedTotal > 0 { - // Show actual scanned vs estimate (may exceed estimate if files were added) - pct := float64(filesScanned) / float64(estimatedTotal) * 100 - if pct > 100 { - pct = 100 // Cap at 100% for display - } - remaining := estimatedTotal - filesScanned - if remaining < 0 { - remaining = 0 - } - var eta time.Duration - if rate > 0 && remaining > 0 { - eta = time.Duration(float64(remaining)/rate) * time.Second - } - fmt.Printf("Scan: %s files (~%.0f%%), %s changed/new, %.0f files/sec, %s elapsed", - formatNumber(int(filesScanned)), - pct, - formatNumber(changedCount), - rate, - elapsed.Round(time.Second)) - if eta > 0 { - fmt.Printf(", ETA %s", eta.Round(time.Second)) - } - fmt.Println() - } else { - // First backup - no estimate available - fmt.Printf("Scan: %s files, %s changed/new, %.0f files/sec, %s elapsed\n", - formatNumber(int(filesScanned)), - formatNumber(changedCount), - rate, - elapsed.Round(time.Second)) - } + printScanProgressLine(filesScanned, changedCount, estimatedTotal, startTime) lastStatusTime = time.Now() } @@ -695,6 +690,68 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult }, nil } +// updateScanEntryStats updates the scan result and progress reporter statistics +// for a single scanned file entry based on whether it needs processing +func (s *Scanner) updateScanEntryStats(result *ScanResult, needsProcessing bool, info os.FileInfo) { + if needsProcessing { + result.BytesScanned += info.Size() + if s.progress != nil { + s.progress.GetStats().BytesScanned.Add(info.Size()) + } + } else { + result.FilesSkipped++ + result.BytesSkipped += info.Size() + if s.progress != nil { + s.progress.GetStats().FilesSkipped.Add(1) + s.progress.GetStats().BytesSkipped.Add(info.Size()) + } + } + result.FilesScanned++ + if s.progress != nil { + s.progress.GetStats().FilesScanned.Add(1) + } +} + +// printScanProgressLine prints a periodic progress line during the scan phase, +// showing files scanned, percentage complete (if estimate available), and ETA +func printScanProgressLine(filesScanned int64, changedCount int, estimatedTotal int64, startTime time.Time) { + elapsed := time.Since(startTime) + rate := float64(filesScanned) / elapsed.Seconds() + + if estimatedTotal > 0 { + // Show actual scanned vs estimate (may exceed estimate if files were added) + pct := float64(filesScanned) / float64(estimatedTotal) * 100 + if pct > 100 { + pct = 100 // Cap at 100% for display + } + remaining := estimatedTotal - filesScanned + if remaining < 0 { + remaining = 0 + } + var eta time.Duration + if rate > 0 && remaining > 0 { + eta = time.Duration(float64(remaining)/rate) * time.Second + } + fmt.Printf("Scan: %s files (~%.0f%%), %s changed/new, %.0f files/sec, %s elapsed", + formatNumber(int(filesScanned)), + pct, + formatNumber(changedCount), + rate, + elapsed.Round(time.Second)) + if eta > 0 { + fmt.Printf(", ETA %s", eta.Round(time.Second)) + } + fmt.Println() + } else { + // First backup - no estimate available + fmt.Printf("Scan: %s files, %s changed/new, %.0f files/sec, %s elapsed\n", + formatNumber(int(filesScanned)), + formatNumber(changedCount), + rate, + elapsed.Round(time.Second)) + } +} + // checkFileInMemory checks if a file needs processing using the in-memory map // No database access is performed - this is purely CPU/memory work func (s *Scanner) checkFileInMemory(path string, info os.FileInfo, knownFiles map[string]*database.File) (*database.File, bool) { @@ -830,22 +887,13 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc s.progress.GetStats().CurrentFile.Store(fileToProcess.Path) } - // Process file in streaming fashion - if err := s.processFileStreaming(ctx, fileToProcess, result); err != nil { - // Handle files that were deleted between scan and process phases - if errors.Is(err, os.ErrNotExist) { - log.Warn("File was deleted during backup, skipping", "path", fileToProcess.Path) - result.FilesSkipped++ - continue - } - // Skip file read errors if --skip-errors is enabled - if s.skipErrors { - log.Error("ERROR: Failed to process file (skipping due to --skip-errors)", "path", fileToProcess.Path, "error", err) - fmt.Printf("ERROR: Failed to process %s: %v (skipping)\n", fileToProcess.Path, err) - result.FilesSkipped++ - continue - } - return fmt.Errorf("processing file %s: %w", fileToProcess.Path, err) + // Process file with error handling for deleted files and skip-errors mode + skipped, err := s.processFileWithErrorHandling(ctx, fileToProcess, result) + if err != nil { + return err + } + if skipped { + continue } // Update files processed counter @@ -858,36 +906,71 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc // Output periodic status if time.Since(lastStatusTime) >= statusInterval { - elapsed := time.Since(startTime) - pct := float64(bytesProcessed) / float64(totalBytes) * 100 - byteRate := float64(bytesProcessed) / elapsed.Seconds() - fileRate := float64(filesProcessed) / elapsed.Seconds() - - // Calculate ETA based on bytes (more accurate than files) - remainingBytes := totalBytes - bytesProcessed - var eta time.Duration - if byteRate > 0 { - eta = time.Duration(float64(remainingBytes)/byteRate) * time.Second - } - - // Format: Progress [5.7k/610k] 6.7 GB/44 GB (15.4%), 106MB/sec, 500 files/sec, running for 1m30s, ETA: 5m49s - fmt.Printf("Progress [%s/%s] %s/%s (%.1f%%), %s/sec, %.0f files/sec, running for %s", - formatCompact(filesProcessed), - formatCompact(totalFiles), - humanize.Bytes(uint64(bytesProcessed)), - humanize.Bytes(uint64(totalBytes)), - pct, - humanize.Bytes(uint64(byteRate)), - fileRate, - elapsed.Round(time.Second)) - if eta > 0 { - fmt.Printf(", ETA: %s", eta.Round(time.Second)) - } - fmt.Println() + printProcessingProgress(filesProcessed, totalFiles, bytesProcessed, totalBytes, startTime) lastStatusTime = time.Now() } } + // Finalize: flush packer, pending files, and handle local blobs + return s.finalizeProcessPhase(ctx, result) +} + +// processFileWithErrorHandling wraps processFileStreaming with error recovery for +// deleted files and skip-errors mode. Returns (skipped, error). +func (s *Scanner) processFileWithErrorHandling(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) (bool, error) { + if err := s.processFileStreaming(ctx, fileToProcess, result); err != nil { + // Handle files that were deleted between scan and process phases + if errors.Is(err, os.ErrNotExist) { + log.Warn("File was deleted during backup, skipping", "path", fileToProcess.Path) + result.FilesSkipped++ + return true, nil + } + // Skip file read errors if --skip-errors is enabled + if s.skipErrors { + log.Error("ERROR: Failed to process file (skipping due to --skip-errors)", "path", fileToProcess.Path, "error", err) + fmt.Printf("ERROR: Failed to process %s: %v (skipping)\n", fileToProcess.Path, err) + result.FilesSkipped++ + return true, nil + } + return false, fmt.Errorf("processing file %s: %w", fileToProcess.Path, err) + } + return false, nil +} + +// printProcessingProgress prints a periodic progress line during the process phase, +// showing files processed, bytes transferred, throughput, and ETA +func printProcessingProgress(filesProcessed, totalFiles int, bytesProcessed, totalBytes int64, startTime time.Time) { + elapsed := time.Since(startTime) + pct := float64(bytesProcessed) / float64(totalBytes) * 100 + byteRate := float64(bytesProcessed) / elapsed.Seconds() + fileRate := float64(filesProcessed) / elapsed.Seconds() + + // Calculate ETA based on bytes (more accurate than files) + remainingBytes := totalBytes - bytesProcessed + var eta time.Duration + if byteRate > 0 { + eta = time.Duration(float64(remainingBytes)/byteRate) * time.Second + } + + // Format: Progress [5.7k/610k] 6.7 GB/44 GB (15.4%), 106MB/sec, 500 files/sec, running for 1m30s, ETA: 5m49s + fmt.Printf("Progress [%s/%s] %s/%s (%.1f%%), %s/sec, %.0f files/sec, running for %s", + formatCompact(filesProcessed), + formatCompact(totalFiles), + humanize.Bytes(uint64(bytesProcessed)), + humanize.Bytes(uint64(totalBytes)), + pct, + humanize.Bytes(uint64(byteRate)), + fileRate, + elapsed.Round(time.Second)) + if eta > 0 { + fmt.Printf(", ETA: %s", eta.Round(time.Second)) + } + fmt.Println() +} + +// finalizeProcessPhase flushes the packer, writes remaining pending files to the database, +// and handles local blob storage when no remote storage is configured +func (s *Scanner) finalizeProcessPhase(ctx context.Context, result *ScanResult) error { // Final packer flush first - this commits remaining chunks to DB // and handleBlobReady will flush files whose chunks are now committed s.packerMu.Lock() @@ -931,40 +1014,103 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { startTime := time.Now().UTC() finishedBlob := blobWithReader.FinishedBlob - // Report upload start and increment blobs created if s.progress != nil { s.progress.ReportUploadStart(finishedBlob.Hash, finishedBlob.Compressed) s.progress.GetStats().BlobsCreated.Add(1) } - // Upload to storage first (without holding any locks) - // Use scan context for cancellation support ctx := s.scanCtx if ctx == nil { ctx = context.Background() } - // Track bytes uploaded for accurate speed calculation + blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash) + blobExists, err := s.uploadBlobIfNeeded(ctx, blobPath, blobWithReader, startTime) + if err != nil { + s.cleanupBlobTempFile(blobWithReader) + return fmt.Errorf("uploading blob %s: %w", finishedBlob.Hash, err) + } + + if err := s.recordBlobMetadata(ctx, finishedBlob, blobExists, startTime); err != nil { + s.cleanupBlobTempFile(blobWithReader) + return err + } + + s.cleanupBlobTempFile(blobWithReader) + + // Chunks from this blob are now committed to DB - remove from pending set + s.removePendingChunkHashes(blobWithReader.InsertedChunkHashes) + + // Flush files whose chunks are now all committed + if err := s.flushCompletedPendingFiles(ctx); err != nil { + return fmt.Errorf("flushing completed files: %w", err) + } + + return nil +} + +// uploadBlobIfNeeded uploads the blob to storage if it doesn't already exist, returns whether it existed +func (s *Scanner) uploadBlobIfNeeded(ctx context.Context, blobPath string, blobWithReader *blob.BlobWithReader, startTime time.Time) (bool, error) { + finishedBlob := blobWithReader.FinishedBlob + + // Check if blob already exists (deduplication after restart) + if _, err := s.storage.Stat(ctx, blobPath); err == nil { + log.Info("Blob already exists in storage, skipping upload", + "hash", finishedBlob.Hash, "size", humanize.Bytes(uint64(finishedBlob.Compressed))) + fmt.Printf("Blob exists: %s (%s, skipped upload)\n", + finishedBlob.Hash[:12]+"...", humanize.Bytes(uint64(finishedBlob.Compressed))) + return true, nil + } + + progressCallback := s.makeUploadProgressCallback(ctx, finishedBlob) + + if err := s.storage.PutWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil { + log.Error("Failed to upload blob", "hash", finishedBlob.Hash, "error", err) + return false, fmt.Errorf("uploading blob to storage: %w", err) + } + + uploadDuration := time.Since(startTime) + uploadSpeedBps := float64(finishedBlob.Compressed) / uploadDuration.Seconds() + + fmt.Printf("Blob stored: %s (%s, %s/sec, %s)\n", + finishedBlob.Hash[:12]+"...", + humanize.Bytes(uint64(finishedBlob.Compressed)), + humanize.Bytes(uint64(uploadSpeedBps)), + uploadDuration.Round(time.Millisecond)) + + log.Info("Successfully uploaded blob to storage", + "path", blobPath, + "size", humanize.Bytes(uint64(finishedBlob.Compressed)), + "duration", uploadDuration, + "speed", humanize.SI(uploadSpeedBps*8, "bps")) + + if s.progress != nil { + s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration) + stats := s.progress.GetStats() + stats.BlobsUploaded.Add(1) + stats.BytesUploaded.Add(finishedBlob.Compressed) + } + + return false, nil +} + +// makeUploadProgressCallback creates a progress callback for blob uploads +func (s *Scanner) makeUploadProgressCallback(ctx context.Context, finishedBlob *blob.FinishedBlob) func(int64) error { lastProgressTime := time.Now() lastProgressBytes := int64(0) - progressCallback := func(uploaded int64) error { - // Calculate instantaneous speed + return func(uploaded int64) error { now := time.Now() elapsed := now.Sub(lastProgressTime).Seconds() - if elapsed > 0.5 { // Update speed every 0.5 seconds + if elapsed > 0.5 { bytesSinceLastUpdate := uploaded - lastProgressBytes speed := float64(bytesSinceLastUpdate) / elapsed - if s.progress != nil { s.progress.ReportUploadProgress(finishedBlob.Hash, uploaded, finishedBlob.Compressed, speed) } - lastProgressTime = now lastProgressBytes = uploaded } - - // Check for cancellation select { case <-ctx.Done(): return ctx.Err() @@ -972,87 +1118,26 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { return nil } } +} - // Create sharded path: blobs/ca/fe/cafebabe... - blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash) - - // Check if blob already exists in remote storage (deduplication after restart) - blobExists := false - if _, err := s.storage.Stat(ctx, blobPath); err == nil { - blobExists = true - log.Info("Blob already exists in storage, skipping upload", - "hash", finishedBlob.Hash, - "size", humanize.Bytes(uint64(finishedBlob.Compressed))) - fmt.Printf("Blob exists: %s (%s, skipped upload)\n", - finishedBlob.Hash[:12]+"...", - humanize.Bytes(uint64(finishedBlob.Compressed))) - } - - if !blobExists { - if err := s.storage.PutWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil { - return fmt.Errorf("uploading blob %s to storage: %w", finishedBlob.Hash, err) - } - - uploadDuration := time.Since(startTime) - - // Calculate upload speed - uploadSpeedBps := float64(finishedBlob.Compressed) / uploadDuration.Seconds() - - // Print blob stored message - fmt.Printf("Blob stored: %s (%s, %s/sec, %s)\n", - finishedBlob.Hash[:12]+"...", - humanize.Bytes(uint64(finishedBlob.Compressed)), - humanize.Bytes(uint64(uploadSpeedBps)), - uploadDuration.Round(time.Millisecond)) - - // Log upload stats - uploadSpeedBits := uploadSpeedBps * 8 // bits per second - log.Info("Successfully uploaded blob to storage", - "path", blobPath, - "size", humanize.Bytes(uint64(finishedBlob.Compressed)), - "duration", uploadDuration, - "speed", humanize.SI(uploadSpeedBits, "bps")) - - // Report upload complete - if s.progress != nil { - s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration) - } - - // Update progress after upload completes - if s.progress != nil { - stats := s.progress.GetStats() - stats.BlobsUploaded.Add(1) - stats.BytesUploaded.Add(finishedBlob.Compressed) - } - } - - // Store metadata in database (after upload is complete) - dbCtx := s.scanCtx - if dbCtx == nil { - dbCtx = context.Background() - } - - // Parse blob ID for typed operations +// recordBlobMetadata stores blob upload metadata in the database +func (s *Scanner) recordBlobMetadata(ctx context.Context, finishedBlob *blob.FinishedBlob, blobExists bool, startTime time.Time) error { finishedBlobID, err := types.ParseBlobID(finishedBlob.ID) if err != nil { return fmt.Errorf("parsing finished blob ID: %w", err) } - // Track upload duration (0 if blob already existed) uploadDuration := time.Since(startTime) - err = s.repos.WithTx(dbCtx, func(ctx context.Context, tx *sql.Tx) error { - // Update blob upload timestamp - if err := s.repos.Blobs.UpdateUploaded(ctx, tx, finishedBlob.ID); err != nil { + return s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { + if err := s.repos.Blobs.UpdateUploaded(txCtx, tx, finishedBlob.ID); err != nil { return fmt.Errorf("updating blob upload timestamp: %w", err) } - // Add the blob to the snapshot - if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, finishedBlobID, types.BlobHash(finishedBlob.Hash)); err != nil { + if err := s.repos.Snapshots.AddBlob(txCtx, tx, s.snapshotID, finishedBlobID, types.BlobHash(finishedBlob.Hash)); err != nil { return fmt.Errorf("adding blob to snapshot: %w", err) } - // Record upload metrics (only for actual uploads, not deduplicated blobs) if !blobExists { upload := &database.Upload{ BlobHash: finishedBlob.Hash, @@ -1061,15 +1146,17 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { Size: finishedBlob.Compressed, DurationMs: uploadDuration.Milliseconds(), } - if err := s.repos.Uploads.Create(ctx, tx, upload); err != nil { + if err := s.repos.Uploads.Create(txCtx, tx, upload); err != nil { return fmt.Errorf("recording upload metrics: %w", err) } } return nil }) +} - // Cleanup temp file if needed +// cleanupBlobTempFile closes and removes the blob's temporary file +func (s *Scanner) cleanupBlobTempFile(blobWithReader *blob.BlobWithReader) { if blobWithReader.TempFile != nil { tempName := blobWithReader.TempFile.Name() if err := blobWithReader.TempFile.Close(); err != nil { @@ -1079,77 +1166,41 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { log.Fatal("Failed to remove temp file", "file", tempName, "error", err) } } +} - if err != nil { - return err - } - - // Chunks from this blob are now committed to DB - remove from pending set - log.Debug("handleBlobReady: removing pending chunk hashes") - s.removePendingChunkHashes(blobWithReader.InsertedChunkHashes) - log.Debug("handleBlobReady: removed pending chunk hashes") - - // Flush files whose chunks are now all committed - // This maintains database consistency after each blob - log.Debug("handleBlobReady: calling flushCompletedPendingFiles") - if err := s.flushCompletedPendingFiles(dbCtx); err != nil { - return fmt.Errorf("flushing completed files: %w", err) - } - log.Debug("handleBlobReady: flushCompletedPendingFiles returned") - - log.Debug("handleBlobReady: complete") - return nil +// streamingChunkInfo tracks chunk metadata collected during streaming +type streamingChunkInfo struct { + fileChunk database.FileChunk + offset int64 + size int64 } // processFileStreaming processes a file by streaming chunks directly to the packer func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) error { - // Open the file file, err := s.fs.Open(fileToProcess.Path) if err != nil { return fmt.Errorf("opening file: %w", err) } defer func() { _ = file.Close() }() - // We'll collect file chunks for database storage - // but process them for packing as we go - type chunkInfo struct { - fileChunk database.FileChunk - offset int64 - size int64 - } - var chunks []chunkInfo + var chunks []streamingChunkInfo chunkIndex := 0 - // Process chunks in streaming fashion and get full file hash fileHash, err := s.chunker.ChunkReaderStreaming(file, func(chunk chunker.Chunk) error { - // Check for cancellation select { case <-ctx.Done(): return ctx.Err() default: } - log.Debug("Processing content-defined chunk from file", - "file", fileToProcess.Path, - "chunk_index", chunkIndex, - "hash", chunk.Hash, - "size", chunk.Size) - - // Check if chunk already exists (fast in-memory lookup) chunkExists := s.chunkExists(chunk.Hash) - - // Queue new chunks for batch insert when blob finalizes - // This dramatically reduces transaction overhead if !chunkExists { s.packer.AddPendingChunk(chunk.Hash, chunk.Size) - // Add to in-memory cache immediately for fast duplicate detection s.addKnownChunk(chunk.Hash) - // Track as pending until blob finalizes and commits to DB s.addPendingChunkHash(chunk.Hash) } - // Track file chunk association for later storage - chunks = append(chunks, chunkInfo{ + chunks = append(chunks, streamingChunkInfo{ fileChunk: database.FileChunk{ FileID: fileToProcess.File.ID, Idx: chunkIndex, @@ -1159,55 +1210,15 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT size: chunk.Size, }) - // Update stats - if chunkExists { - result.FilesSkipped++ // Track as skipped for now - result.BytesSkipped += chunk.Size - if s.progress != nil { - s.progress.GetStats().BytesSkipped.Add(chunk.Size) - } - } else { - result.ChunksCreated++ - result.BytesScanned += chunk.Size - if s.progress != nil { - s.progress.GetStats().ChunksCreated.Add(1) - s.progress.GetStats().BytesProcessed.Add(chunk.Size) - s.progress.UpdateChunkingActivity() - } - } + s.updateChunkStats(chunkExists, chunk.Size, result) - // Add chunk to packer immediately (streaming) - // This happens outside the database transaction if !chunkExists { - s.packerMu.Lock() - err := s.packer.AddChunk(&blob.ChunkRef{ - Hash: chunk.Hash, - Data: chunk.Data, - }) - if err == blob.ErrBlobSizeLimitExceeded { - // Finalize current blob and retry - if err := s.packer.FinalizeBlob(); err != nil { - s.packerMu.Unlock() - return fmt.Errorf("finalizing blob: %w", err) - } - // Retry adding the chunk - if err := s.packer.AddChunk(&blob.ChunkRef{ - Hash: chunk.Hash, - Data: chunk.Data, - }); err != nil { - s.packerMu.Unlock() - return fmt.Errorf("adding chunk after finalize: %w", err) - } - } else if err != nil { - s.packerMu.Unlock() - return fmt.Errorf("adding chunk to packer: %w", err) + if err := s.addChunkToPacker(chunk); err != nil { + return err } - s.packerMu.Unlock() } - // Clear chunk data from memory immediately after use chunk.Data = nil - chunkIndex++ return nil }) @@ -1217,12 +1228,54 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT } log.Debug("Completed snapshotting file", - "path", fileToProcess.Path, - "file_hash", fileHash, - "chunks", len(chunks)) + "path", fileToProcess.Path, "file_hash", fileHash, "chunks", len(chunks)) - // Build file data for batch insertion - // Update chunk associations with the file ID + s.queueFileForBatchInsert(ctx, fileToProcess, chunks) + return nil +} + +// updateChunkStats updates scan result and progress stats for a processed chunk +func (s *Scanner) updateChunkStats(chunkExists bool, chunkSize int64, result *ScanResult) { + if chunkExists { + result.FilesSkipped++ + result.BytesSkipped += chunkSize + if s.progress != nil { + s.progress.GetStats().BytesSkipped.Add(chunkSize) + } + } else { + result.ChunksCreated++ + result.BytesScanned += chunkSize + if s.progress != nil { + s.progress.GetStats().ChunksCreated.Add(1) + s.progress.GetStats().BytesProcessed.Add(chunkSize) + s.progress.UpdateChunkingActivity() + } + } +} + +// addChunkToPacker adds a chunk to the blob packer, finalizing the current blob if needed +func (s *Scanner) addChunkToPacker(chunk chunker.Chunk) error { + s.packerMu.Lock() + err := s.packer.AddChunk(&blob.ChunkRef{Hash: chunk.Hash, Data: chunk.Data}) + if err == blob.ErrBlobSizeLimitExceeded { + if err := s.packer.FinalizeBlob(); err != nil { + s.packerMu.Unlock() + return fmt.Errorf("finalizing blob: %w", err) + } + if err := s.packer.AddChunk(&blob.ChunkRef{Hash: chunk.Hash, Data: chunk.Data}); err != nil { + s.packerMu.Unlock() + return fmt.Errorf("adding chunk after finalize: %w", err) + } + } else if err != nil { + s.packerMu.Unlock() + return fmt.Errorf("adding chunk to packer: %w", err) + } + s.packerMu.Unlock() + return nil +} + +// queueFileForBatchInsert builds file/chunk associations and queues the file for batch DB insert +func (s *Scanner) queueFileForBatchInsert(ctx context.Context, fileToProcess *FileToProcess, chunks []streamingChunkInfo) { fileChunks := make([]database.FileChunk, len(chunks)) chunkFiles := make([]database.ChunkFile, len(chunks)) for i, ci := range chunks { @@ -1239,14 +1292,11 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT } } - // Queue file for batch insertion - // Files will be flushed when their chunks are committed (after blob finalize) s.addPendingFile(ctx, pendingFileData{ file: fileToProcess.File, fileChunks: fileChunks, chunkFiles: chunkFiles, }) - return nil } // GetProgress returns the progress reporter for this scanner diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index bb01ea1..883c572 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -227,12 +227,39 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st } }() + // Steps 1-5: Copy, clean, vacuum, compress, and read the database + finalData, tempDBPath, err := sm.prepareExportDB(ctx, dbPath, snapshotID, tempDir) + if err != nil { + return err + } + + // Step 6: Generate blob manifest (before closing temp DB) + blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) + if err != nil { + return fmt.Errorf("generating blob manifest: %w", err) + } + + // Step 7: Upload to S3 in snapshot subdirectory + if err := sm.uploadSnapshotArtifacts(ctx, snapshotID, finalData, blobManifest); err != nil { + return err + } + + log.Info("Uploaded snapshot metadata", + "snapshot_id", snapshotID, + "db_size", len(finalData), + "manifest_size", len(blobManifest)) + return nil +} + +// prepareExportDB copies, cleans, vacuums, and compresses the snapshot database for export. +// Returns the compressed data and the path to the temporary database (needed for manifest generation). +func (sm *SnapshotManager) prepareExportDB(ctx context.Context, dbPath, snapshotID, tempDir string) ([]byte, string, error) { // Step 1: Copy database to temp file // The main database should be closed at this point tempDBPath := filepath.Join(tempDir, "snapshot.db") log.Debug("Copying database to temporary location", "source", dbPath, "destination", tempDBPath) if err := sm.copyFile(dbPath, tempDBPath); err != nil { - return fmt.Errorf("copying database: %w", err) + return nil, "", fmt.Errorf("copying database: %w", err) } log.Debug("Database copy complete", "size", sm.getFileSize(tempDBPath)) @@ -240,7 +267,7 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st log.Debug("Cleaning temporary database", "snapshot_id", snapshotID) stats, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID) if err != nil { - return fmt.Errorf("cleaning snapshot database: %w", err) + return nil, "", fmt.Errorf("cleaning snapshot database: %w", err) } log.Info("Temporary database cleanup complete", "db_path", tempDBPath, @@ -255,14 +282,14 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st // Step 3: VACUUM the database to remove deleted data and compact // This is critical for security - ensures no stale/deleted data is uploaded if err := sm.vacuumDatabase(tempDBPath); err != nil { - return fmt.Errorf("vacuuming database: %w", err) + return nil, "", fmt.Errorf("vacuuming database: %w", err) } log.Debug("Database vacuumed", "size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath)))) // Step 4: Compress and encrypt the binary database file compressedPath := filepath.Join(tempDir, "db.zst.age") if err := sm.compressFile(tempDBPath, compressedPath); err != nil { - return fmt.Errorf("compressing database: %w", err) + return nil, "", fmt.Errorf("compressing database: %w", err) } log.Debug("Compression complete", "original_size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath))), @@ -271,49 +298,43 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st // Step 5: Read compressed and encrypted data for upload finalData, err := afero.ReadFile(sm.fs, compressedPath) if err != nil { - return fmt.Errorf("reading compressed dump: %w", err) + return nil, "", fmt.Errorf("reading compressed dump: %w", err) } - // Step 6: Generate blob manifest (before closing temp DB) - blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) - if err != nil { - return fmt.Errorf("generating blob manifest: %w", err) - } + return finalData, tempDBPath, nil +} - // Step 7: Upload to S3 in snapshot subdirectory +// uploadSnapshotArtifacts uploads the database backup and blob manifest to S3 +func (sm *SnapshotManager) uploadSnapshotArtifacts(ctx context.Context, snapshotID string, dbData, manifestData []byte) error { // Upload database backup (compressed and encrypted) dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) dbUploadStart := time.Now() - if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(finalData)); err != nil { + if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(dbData)); err != nil { return fmt.Errorf("uploading snapshot database: %w", err) } dbUploadDuration := time.Since(dbUploadStart) - dbUploadSpeed := float64(len(finalData)) * 8 / dbUploadDuration.Seconds() // bits per second + dbUploadSpeed := float64(len(dbData)) * 8 / dbUploadDuration.Seconds() // bits per second log.Info("Uploaded snapshot database", "path", dbKey, - "size", humanize.Bytes(uint64(len(finalData))), + "size", humanize.Bytes(uint64(len(dbData))), "duration", dbUploadDuration, "speed", humanize.SI(dbUploadSpeed, "bps")) // Upload blob manifest (compressed only, not encrypted) manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) manifestUploadStart := time.Now() - if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil { + if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(manifestData)); err != nil { return fmt.Errorf("uploading blob manifest: %w", err) } manifestUploadDuration := time.Since(manifestUploadStart) - manifestUploadSpeed := float64(len(blobManifest)) * 8 / manifestUploadDuration.Seconds() // bits per second + manifestUploadSpeed := float64(len(manifestData)) * 8 / manifestUploadDuration.Seconds() // bits per second log.Info("Uploaded blob manifest", "path", manifestKey, - "size", humanize.Bytes(uint64(len(blobManifest))), + "size", humanize.Bytes(uint64(len(manifestData))), "duration", manifestUploadDuration, "speed", humanize.SI(manifestUploadSpeed, "bps")) - log.Info("Uploaded snapshot metadata", - "snapshot_id", snapshotID, - "db_size", len(finalData), - "manifest_size", len(blobManifest)) return nil } diff --git a/internal/vaultik/info.go b/internal/vaultik/info.go index 53cfc2c..124fc7b 100644 --- a/internal/vaultik/info.go +++ b/internal/vaultik/info.go @@ -149,9 +149,9 @@ type RemoteInfoResult struct { // RemoteInfo displays information about remote storage func (v *Vaultik) RemoteInfo(jsonOutput bool) error { + log.Info("Starting remote storage info gathering") result := &RemoteInfoResult{} - // Get storage info storageInfo := v.Storage.Info() result.StorageType = storageInfo.Type result.StorageLocation = storageInfo.Location @@ -161,23 +161,52 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { v.printfStdout("Type: %s\n", storageInfo.Type) v.printfStdout("Location: %s\n", storageInfo.Location) v.printlnStdout() - } - - // List all snapshot metadata - if !jsonOutput { v.printfStdout("Scanning snapshot metadata...\n") } + snapshotMetadata, snapshotIDs, err := v.collectSnapshotMetadata() + if err != nil { + return err + } + + if !jsonOutput { + v.printfStdout("Downloading %d manifest(s)...\n", len(snapshotIDs)) + } + + referencedBlobs := v.collectReferencedBlobsFromManifests(snapshotIDs, snapshotMetadata) + + v.populateRemoteInfoResult(result, snapshotMetadata, snapshotIDs, referencedBlobs) + + if err := v.scanRemoteBlobStorage(result, referencedBlobs, jsonOutput); err != nil { + return err + } + + log.Info("Remote info complete", + "snapshots", result.TotalMetadataCount, + "total_blobs", result.TotalBlobCount, + "referenced_blobs", result.ReferencedBlobCount, + "orphaned_blobs", result.OrphanedBlobCount) + + if jsonOutput { + enc := json.NewEncoder(v.Stdout) + enc.SetIndent("", " ") + return enc.Encode(result) + } + + v.printRemoteInfoTable(result) + return nil +} + +// collectSnapshotMetadata scans remote metadata and returns per-snapshot info and sorted IDs +func (v *Vaultik) collectSnapshotMetadata() (map[string]*SnapshotMetadataInfo, []string, error) { snapshotMetadata := make(map[string]*SnapshotMetadataInfo) - // Collect metadata files metadataCh := v.Storage.ListStream(v.ctx, "metadata/") for obj := range metadataCh { if obj.Err != nil { - return fmt.Errorf("listing metadata: %w", obj.Err) + return nil, nil, fmt.Errorf("listing metadata: %w", obj.Err) } - // Parse key: metadata// parts := strings.Split(obj.Key, "/") if len(parts) < 3 { continue @@ -185,14 +214,11 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { snapshotID := parts[1] if _, exists := snapshotMetadata[snapshotID]; !exists { - snapshotMetadata[snapshotID] = &SnapshotMetadataInfo{ - SnapshotID: snapshotID, - } + snapshotMetadata[snapshotID] = &SnapshotMetadataInfo{SnapshotID: snapshotID} } info := snapshotMetadata[snapshotID] filename := parts[2] - if strings.HasPrefix(filename, "manifest") { info.ManifestSize = obj.Size } else if strings.HasPrefix(filename, "db") { @@ -201,19 +227,18 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { info.TotalSize = info.ManifestSize + info.DatabaseSize } - // Sort snapshots by ID for consistent output var snapshotIDs []string for id := range snapshotMetadata { snapshotIDs = append(snapshotIDs, id) } sort.Strings(snapshotIDs) - // Download and parse all manifests to get referenced blobs - if !jsonOutput { - v.printfStdout("Downloading %d manifest(s)...\n", len(snapshotIDs)) - } + return snapshotMetadata, snapshotIDs, nil +} - referencedBlobs := make(map[string]int64) // hash -> compressed size +// collectReferencedBlobsFromManifests downloads manifests and returns referenced blob hashes with sizes +func (v *Vaultik) collectReferencedBlobsFromManifests(snapshotIDs []string, snapshotMetadata map[string]*SnapshotMetadataInfo) map[string]int64 { + referencedBlobs := make(map[string]int64) for _, snapshotID := range snapshotIDs { manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) @@ -230,10 +255,8 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { continue } - // Record blob info from manifest info := snapshotMetadata[snapshotID] info.BlobCount = manifest.BlobCount - var blobsSize int64 for _, blob := range manifest.Blobs { referencedBlobs[blob.Hash] = blob.CompressedSize @@ -242,7 +265,11 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { info.BlobsSize = blobsSize } - // Build result snapshots + return referencedBlobs +} + +// populateRemoteInfoResult fills in the result's snapshot and referenced blob stats +func (v *Vaultik) populateRemoteInfoResult(result *RemoteInfoResult, snapshotMetadata map[string]*SnapshotMetadataInfo, snapshotIDs []string, referencedBlobs map[string]int64) { var totalMetadataSize int64 for _, id := range snapshotIDs { info := snapshotMetadata[id] @@ -252,26 +279,25 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { result.TotalMetadataSize = totalMetadataSize result.TotalMetadataCount = len(snapshotIDs) - // Calculate referenced blob stats for _, size := range referencedBlobs { result.ReferencedBlobCount++ result.ReferencedBlobSize += size } +} - // List all blobs on remote +// scanRemoteBlobStorage lists all blobs on remote and computes orphan stats +func (v *Vaultik) scanRemoteBlobStorage(result *RemoteInfoResult, referencedBlobs map[string]int64, jsonOutput bool) error { if !jsonOutput { v.printfStdout("Scanning blobs...\n") } - allBlobs := make(map[string]int64) // hash -> size from storage - blobCh := v.Storage.ListStream(v.ctx, "blobs/") + allBlobs := make(map[string]int64) + for obj := range blobCh { if obj.Err != nil { return fmt.Errorf("listing blobs: %w", obj.Err) } - - // Extract hash from key: blobs/xx/yy/hash parts := strings.Split(obj.Key, "/") if len(parts) < 4 { continue @@ -282,7 +308,6 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { result.TotalBlobSize += obj.Size } - // Calculate orphaned blobs for hash, size := range allBlobs { if _, referenced := referencedBlobs[hash]; !referenced { result.OrphanedBlobCount++ @@ -290,14 +315,11 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { } } - // Output results - if jsonOutput { - enc := json.NewEncoder(v.Stdout) - enc.SetIndent("", " ") - return enc.Encode(result) - } + return nil +} - // Human-readable output +// printRemoteInfoTable renders the human-readable remote info output +func (v *Vaultik) printRemoteInfoTable(result *RemoteInfoResult) { v.printfStdout("\n=== Snapshot Metadata ===\n") if len(result.Snapshots) == 0 { v.printfStdout("No snapshots found\n") @@ -320,20 +342,15 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { v.printfStdout("\n=== Blob Storage ===\n") v.printfStdout("Total blobs on remote: %s (%s)\n", - humanize.Comma(int64(result.TotalBlobCount)), - humanize.Bytes(uint64(result.TotalBlobSize))) + humanize.Comma(int64(result.TotalBlobCount)), humanize.Bytes(uint64(result.TotalBlobSize))) v.printfStdout("Referenced by snapshots: %s (%s)\n", - humanize.Comma(int64(result.ReferencedBlobCount)), - humanize.Bytes(uint64(result.ReferencedBlobSize))) + humanize.Comma(int64(result.ReferencedBlobCount)), humanize.Bytes(uint64(result.ReferencedBlobSize))) v.printfStdout("Orphaned (unreferenced): %s (%s)\n", - humanize.Comma(int64(result.OrphanedBlobCount)), - humanize.Bytes(uint64(result.OrphanedBlobSize))) + humanize.Comma(int64(result.OrphanedBlobCount)), humanize.Bytes(uint64(result.OrphanedBlobSize))) if result.OrphanedBlobCount > 0 { v.printfStdout("\nRun 'vaultik prune --remote' to remove orphaned blobs.\n") } - - return nil } // truncateString truncates a string to maxLen, adding "..." if truncated diff --git a/internal/vaultik/prune.go b/internal/vaultik/prune.go index dff9dd9..2fb1a35 100644 --- a/internal/vaultik/prune.go +++ b/internal/vaultik/prune.go @@ -27,95 +27,19 @@ type PruneBlobsResult struct { func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { log.Info("Starting prune operation") - // Get all remote snapshots and their manifests - allBlobsReferenced := make(map[string]bool) - manifestCount := 0 - - // List all snapshots in storage - log.Info("Listing remote snapshots") - objectCh := v.Storage.ListStream(v.ctx, "metadata/") - - var snapshotIDs []string - for object := range objectCh { - if object.Err != nil { - return fmt.Errorf("listing remote snapshots: %w", object.Err) - } - - // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/ - parts := strings.Split(object.Key, "/") - if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { - // Check if this is a directory by looking for trailing slash - if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") { - snapshotID := parts[1] - // Only add unique snapshot IDs - found := false - for _, id := range snapshotIDs { - if id == snapshotID { - found = true - break - } - } - if !found { - snapshotIDs = append(snapshotIDs, snapshotID) - } - } - } + allBlobsReferenced, err := v.collectReferencedBlobs() + if err != nil { + return err } - log.Info("Found manifests in remote storage", "count", len(snapshotIDs)) - - // Download and parse each manifest to get referenced blobs - for _, snapshotID := range snapshotIDs { - log.Debug("Processing manifest", "snapshot_id", snapshotID) - - manifest, err := v.downloadManifest(snapshotID) - if err != nil { - log.Error("Failed to download manifest", "snapshot_id", snapshotID, "error", err) - continue - } - - // Add all blobs from this manifest to our referenced set - for _, blob := range manifest.Blobs { - allBlobsReferenced[blob.Hash] = true - } - manifestCount++ + allBlobs, err := v.listAllRemoteBlobs() + if err != nil { + return err } - log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced)) + unreferencedBlobs, totalSize := v.findUnreferencedBlobs(allBlobs, allBlobsReferenced) - // List all blobs in storage - log.Info("Listing all blobs in storage") - allBlobs := make(map[string]int64) // hash -> size - blobObjectCh := v.Storage.ListStream(v.ctx, "blobs/") - - for object := range blobObjectCh { - if object.Err != nil { - return fmt.Errorf("listing blobs: %w", object.Err) - } - - // Extract hash from path like blobs/ab/cd/abcdef123456... - parts := strings.Split(object.Key, "/") - if len(parts) == 4 && parts[0] == "blobs" { - hash := parts[3] - allBlobs[hash] = object.Size - } - } - - log.Info("Found blobs in storage", "count", len(allBlobs)) - - // Find unreferenced blobs - var unreferencedBlobs []string - var totalSize int64 - for hash, size := range allBlobs { - if !allBlobsReferenced[hash] { - unreferencedBlobs = append(unreferencedBlobs, hash) - totalSize += size - } - } - - result := &PruneBlobsResult{ - BlobsFound: len(unreferencedBlobs), - } + result := &PruneBlobsResult{BlobsFound: len(unreferencedBlobs)} if len(unreferencedBlobs) == 0 { log.Info("No unreferenced blobs found") @@ -126,18 +50,15 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { return nil } - // Show what will be deleted log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize))) if !opts.JSON { v.printfStdout("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize))) } - // Confirm unless --force is used (skip in JSON mode - require --force) if !opts.Force && !opts.JSON { v.printfStdout("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs)) var confirm string if _, err := v.scanStdin(&confirm); err != nil { - // Treat EOF or error as "no" v.printlnStdout("Cancelled") return nil } @@ -147,10 +68,109 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { } } - // Delete unreferenced blobs + v.deleteUnreferencedBlobs(unreferencedBlobs, allBlobs, result) + + if opts.JSON { + return v.outputPruneBlobsJSON(result) + } + + v.printfStdout("\nDeleted %d blob(s) totaling %s\n", result.BlobsDeleted, humanize.Bytes(uint64(result.BytesFreed))) + if result.BlobsFailed > 0 { + v.printfStdout("Failed to delete %d blob(s)\n", result.BlobsFailed) + } + + return nil +} + +// collectReferencedBlobs downloads all manifests and returns the set of referenced blob hashes +func (v *Vaultik) collectReferencedBlobs() (map[string]bool, error) { + log.Info("Listing remote snapshots") + snapshotIDs, err := v.listUniqueSnapshotIDs() + if err != nil { + return nil, fmt.Errorf("listing snapshot IDs: %w", err) + } + log.Info("Found manifests in remote storage", "count", len(snapshotIDs)) + + allBlobsReferenced := make(map[string]bool) + manifestCount := 0 + + for _, snapshotID := range snapshotIDs { + log.Debug("Processing manifest", "snapshot_id", snapshotID) + manifest, err := v.downloadManifest(snapshotID) + if err != nil { + log.Error("Failed to download manifest", "snapshot_id", snapshotID, "error", err) + continue + } + for _, blob := range manifest.Blobs { + allBlobsReferenced[blob.Hash] = true + } + manifestCount++ + } + + log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced)) + return allBlobsReferenced, nil +} + +// listUniqueSnapshotIDs returns deduplicated snapshot IDs from remote metadata +func (v *Vaultik) listUniqueSnapshotIDs() ([]string, error) { + objectCh := v.Storage.ListStream(v.ctx, "metadata/") + seen := make(map[string]bool) + var snapshotIDs []string + + for object := range objectCh { + if object.Err != nil { + return nil, fmt.Errorf("listing metadata objects: %w", object.Err) + } + parts := strings.Split(object.Key, "/") + if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { + if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") { + snapshotID := parts[1] + if !seen[snapshotID] { + seen[snapshotID] = true + snapshotIDs = append(snapshotIDs, snapshotID) + } + } + } + } + return snapshotIDs, nil +} + +// listAllRemoteBlobs returns a map of all blob hashes to their sizes in remote storage +func (v *Vaultik) listAllRemoteBlobs() (map[string]int64, error) { + log.Info("Listing all blobs in storage") + allBlobs := make(map[string]int64) + blobObjectCh := v.Storage.ListStream(v.ctx, "blobs/") + + for object := range blobObjectCh { + if object.Err != nil { + return nil, fmt.Errorf("listing blobs: %w", object.Err) + } + parts := strings.Split(object.Key, "/") + if len(parts) == 4 && parts[0] == "blobs" { + allBlobs[parts[3]] = object.Size + } + } + + log.Info("Found blobs in storage", "count", len(allBlobs)) + return allBlobs, nil +} + +// findUnreferencedBlobs returns blob hashes not referenced by any manifest and their total size +func (v *Vaultik) findUnreferencedBlobs(allBlobs map[string]int64, referenced map[string]bool) ([]string, int64) { + var unreferenced []string + var totalSize int64 + for hash, size := range allBlobs { + if !referenced[hash] { + unreferenced = append(unreferenced, hash) + totalSize += size + } + } + return unreferenced, totalSize +} + +// deleteUnreferencedBlobs deletes the given blobs from storage and populates the result +func (v *Vaultik) deleteUnreferencedBlobs(unreferencedBlobs []string, allBlobs map[string]int64, result *PruneBlobsResult) { log.Info("Deleting unreferenced blobs") - deletedCount := 0 - deletedSize := int64(0) for i, hash := range unreferencedBlobs { blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash) @@ -160,10 +180,9 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { continue } - deletedCount++ - deletedSize += allBlobs[hash] + result.BlobsDeleted++ + result.BytesFreed += allBlobs[hash] - // Progress update every 100 blobs if (i+1)%100 == 0 || i == len(unreferencedBlobs)-1 { log.Info("Deletion progress", "deleted", i+1, @@ -173,26 +192,13 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { } } - result.BlobsDeleted = deletedCount - result.BlobsFailed = len(unreferencedBlobs) - deletedCount - result.BytesFreed = deletedSize + result.BlobsFailed = len(unreferencedBlobs) - result.BlobsDeleted log.Info("Prune complete", - "deleted_count", deletedCount, - "deleted_size", humanize.Bytes(uint64(deletedSize)), - "failed", len(unreferencedBlobs)-deletedCount, + "deleted_count", result.BlobsDeleted, + "deleted_size", humanize.Bytes(uint64(result.BytesFreed)), + "failed", result.BlobsFailed, ) - - if opts.JSON { - return v.outputPruneBlobsJSON(result) - } - - v.printfStdout("\nDeleted %d blob(s) totaling %s\n", deletedCount, humanize.Bytes(uint64(deletedSize))) - if deletedCount < len(unreferencedBlobs) { - v.printfStdout("Failed to delete %d blob(s)\n", len(unreferencedBlobs)-deletedCount) - } - - return nil } // outputPruneBlobsJSON outputs the prune result as JSON diff --git a/internal/vaultik/restore.go b/internal/vaultik/restore.go index a92fef5..5797fc8 100644 --- a/internal/vaultik/restore.go +++ b/internal/vaultik/restore.go @@ -55,15 +55,9 @@ type RestoreResult struct { func (v *Vaultik) Restore(opts *RestoreOptions) error { startTime := time.Now() - // Check for age_secret_key - if v.Config.AgeSecretKey == "" { - return fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'") - } - - // Parse the age identity - identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey) + identity, err := v.prepareRestoreIdentity() if err != nil { - return fmt.Errorf("parsing age secret key: %w", err) + return err } log.Info("Starting restore operation", @@ -115,10 +109,73 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { } // Step 5: Restore files + result, err := v.restoreAllFiles(files, repos, opts, identity, chunkToBlobMap) + if err != nil { + return err + } + + result.Duration = time.Since(startTime) + + log.Info("Restore complete", + "files_restored", result.FilesRestored, + "bytes_restored", humanize.Bytes(uint64(result.BytesRestored)), + "blobs_downloaded", result.BlobsDownloaded, + "bytes_downloaded", humanize.Bytes(uint64(result.BytesDownloaded)), + "duration", result.Duration, + ) + + v.printfStdout("Restored %d files (%s) in %s\n", + result.FilesRestored, + humanize.Bytes(uint64(result.BytesRestored)), + result.Duration.Round(time.Second), + ) + + if result.FilesFailed > 0 { + _, _ = fmt.Fprintf(v.Stdout, "\nWARNING: %d file(s) failed to restore:\n", result.FilesFailed) + for _, path := range result.FailedFiles { + _, _ = fmt.Fprintf(v.Stdout, " - %s\n", path) + } + } + + // Run verification if requested + if opts.Verify { + if err := v.handleRestoreVerification(repos, files, opts, result); err != nil { + return err + } + } + + if result.FilesFailed > 0 { + return fmt.Errorf("%d file(s) failed to restore", result.FilesFailed) + } + + return nil +} + +// prepareRestoreIdentity validates that an age secret key is configured and parses it +func (v *Vaultik) prepareRestoreIdentity() (age.Identity, error) { + if v.Config.AgeSecretKey == "" { + return nil, fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'") + } + + identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey) + if err != nil { + return nil, fmt.Errorf("parsing age secret key: %w", err) + } + return identity, nil +} + +// restoreAllFiles iterates over files and restores each one, tracking progress and failures +func (v *Vaultik) restoreAllFiles( + files []*database.File, + repos *database.Repositories, + opts *RestoreOptions, + identity age.Identity, + chunkToBlobMap map[string]*database.BlobChunk, +) (*RestoreResult, error) { result := &RestoreResult{} blobCache, err := newBlobDiskCache(4 * v.Config.BlobSizeLimit.Int64()) if err != nil { - return fmt.Errorf("creating blob cache: %w", err) + return nil, fmt.Errorf("creating blob cache: %w", err) } defer func() { _ = blobCache.Close() }() @@ -133,7 +190,7 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { for i, file := range files { if v.ctx.Err() != nil { - return v.ctx.Err() + return nil, v.ctx.Err() } if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil { @@ -165,53 +222,32 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { _ = bar.Finish() } - result.Duration = time.Since(startTime) + return result, nil +} - log.Info("Restore complete", - "files_restored", result.FilesRestored, - "bytes_restored", humanize.Bytes(uint64(result.BytesRestored)), - "blobs_downloaded", result.BlobsDownloaded, - "bytes_downloaded", humanize.Bytes(uint64(result.BytesDownloaded)), - "duration", result.Duration, - ) - - v.printfStdout("Restored %d files (%s) in %s\n", - result.FilesRestored, - humanize.Bytes(uint64(result.BytesRestored)), - result.Duration.Round(time.Second), - ) +// handleRestoreVerification runs post-restore verification if requested +func (v *Vaultik) handleRestoreVerification( + repos *database.Repositories, + files []*database.File, + opts *RestoreOptions, + result *RestoreResult, +) error { + if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil { + return fmt.Errorf("verification failed: %w", err) + } if result.FilesFailed > 0 { - _, _ = fmt.Fprintf(v.Stdout, "\nWARNING: %d file(s) failed to restore:\n", result.FilesFailed) + v.printfStdout("\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed) for _, path := range result.FailedFiles { - _, _ = fmt.Fprintf(v.Stdout, " - %s\n", path) + v.printfStdout(" - %s\n", path) } + return fmt.Errorf("%d files failed verification", result.FilesFailed) } - // Run verification if requested - if opts.Verify { - if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil { - return fmt.Errorf("verification failed: %w", err) - } - - if result.FilesFailed > 0 { - v.printfStdout("\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed) - for _, path := range result.FailedFiles { - v.printfStdout(" - %s\n", path) - } - return fmt.Errorf("%d files failed verification", result.FilesFailed) - } - - v.printfStdout("Verified %d files (%s)\n", - result.FilesVerified, - humanize.Bytes(uint64(result.BytesVerified)), - ) - } - - if result.FilesFailed > 0 { - return fmt.Errorf("%d file(s) failed to restore", result.FilesFailed) - } - + v.printfStdout("Verified %d files (%s)\n", + result.FilesVerified, + humanize.Bytes(uint64(result.BytesVerified)), + ) return nil } diff --git a/internal/vaultik/snapshot.go b/internal/vaultik/snapshot.go index e0d93b2..21904bf 100644 --- a/internal/vaultik/snapshot.go +++ b/internal/vaultik/snapshot.go @@ -111,40 +111,34 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error { return nil } +// snapshotStats tracks aggregate statistics across directory scans +type snapshotStats struct { + totalFiles int + totalBytes int64 + totalChunks int + totalBlobs int + totalBytesSkipped int64 + totalFilesSkipped int + totalFilesDeleted int + totalBytesDeleted int64 + totalBytesUploaded int64 + totalBlobsUploaded int + uploadDuration time.Duration +} + // createNamedSnapshot creates a single named snapshot func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, snapName string, idx, total int) error { snapshotStartTime := time.Now() - snapConfig := v.Config.Snapshots[snapName] - if total > 1 { v.printfStdout("\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName) } - // Resolve source directories to absolute paths - resolvedDirs := make([]string, 0, len(snapConfig.Paths)) - for _, dir := range snapConfig.Paths { - absPath, err := filepath.Abs(dir) - if err != nil { - return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err) - } - - // Resolve symlinks - resolvedPath, err := filepath.EvalSymlinks(absPath) - if err != nil { - // If the path doesn't exist yet, use the absolute path - if os.IsNotExist(err) { - resolvedPath = absPath - } else { - return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err) - } - } - - resolvedDirs = append(resolvedDirs, resolvedPath) + resolvedDirs, err := v.resolveSnapshotPaths(snapName) + if err != nil { + return err } - // Create scanner with progress enabled (unless in cron mode) - // Pass the combined excludes for this snapshot scanner := v.ScannerFactory(snapshot.ScannerParams{ EnableProgress: !opts.Cron, Fs: v.Fs, @@ -152,20 +146,6 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna SkipErrors: opts.SkipErrors, }) - // Statistics tracking - totalFiles := 0 - totalBytes := int64(0) - totalChunks := 0 - totalBlobs := 0 - totalBytesSkipped := int64(0) - totalFilesSkipped := 0 - totalFilesDeleted := 0 - totalBytesDeleted := int64(0) - totalBytesUploaded := int64(0) - totalBlobsUploaded := 0 - uploadDuration := time.Duration(0) - - // Create a new snapshot at the beginning (with snapshot name in ID) snapshotID, err := v.SnapshotManager.CreateSnapshotWithName(v.ctx, hostname, snapName, v.Globals.Version, v.Globals.Commit) if err != nil { return fmt.Errorf("creating snapshot: %w", err) @@ -173,12 +153,64 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName) v.printfStdout("Beginning snapshot: %s\n", snapshotID) + stats, err := v.scanAllDirectories(scanner, resolvedDirs, snapshotID) + if err != nil { + return err + } + + v.collectUploadStats(scanner, stats) + + if err := v.finalizeSnapshotMetadata(snapshotID, stats); err != nil { + return err + } + + log.Info("Snapshot complete", + "snapshot_id", snapshotID, + "name", snapName, + "files", stats.totalFiles, + "blobs_uploaded", stats.totalBlobsUploaded, + "bytes_uploaded", stats.totalBytesUploaded, + "duration", time.Since(snapshotStartTime)) + + v.printSnapshotSummary(snapshotID, snapshotStartTime, stats) + return nil +} + +// resolveSnapshotPaths resolves source directories to absolute paths with symlink resolution +func (v *Vaultik) resolveSnapshotPaths(snapName string) ([]string, error) { + snapConfig := v.Config.Snapshots[snapName] + resolvedDirs := make([]string, 0, len(snapConfig.Paths)) + + for _, dir := range snapConfig.Paths { + absPath, err := filepath.Abs(dir) + if err != nil { + return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err) + } + + resolvedPath, err := filepath.EvalSymlinks(absPath) + if err != nil { + if os.IsNotExist(err) { + resolvedPath = absPath + } else { + return nil, fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err) + } + } + + resolvedDirs = append(resolvedDirs, resolvedPath) + } + + return resolvedDirs, nil +} + +// scanAllDirectories runs the scanner on each resolved directory and accumulates stats +func (v *Vaultik) scanAllDirectories(scanner *snapshot.Scanner, resolvedDirs []string, snapshotID string) (*snapshotStats, error) { + stats := &snapshotStats{} + for i, dir := range resolvedDirs { - // Check if context is cancelled select { case <-v.ctx.Done(): log.Info("Snapshot creation cancelled") - return v.ctx.Err() + return nil, v.ctx.Err() default: } @@ -186,17 +218,17 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna v.printfStdout("Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir) result, err := scanner.Scan(v.ctx, dir, snapshotID) if err != nil { - return fmt.Errorf("failed to scan %s: %w", dir, err) + return nil, fmt.Errorf("failed to scan %s: %w", dir, err) } - totalFiles += result.FilesScanned - totalBytes += result.BytesScanned - totalChunks += result.ChunksCreated - totalBlobs += result.BlobsCreated - totalFilesSkipped += result.FilesSkipped - totalBytesSkipped += result.BytesSkipped - totalFilesDeleted += result.FilesDeleted - totalBytesDeleted += result.BytesDeleted + stats.totalFiles += result.FilesScanned + stats.totalBytes += result.BytesScanned + stats.totalChunks += result.ChunksCreated + stats.totalBlobs += result.BlobsCreated + stats.totalFilesSkipped += result.FilesSkipped + stats.totalBytesSkipped += result.BytesSkipped + stats.totalFilesDeleted += result.FilesDeleted + stats.totalBytesDeleted += result.BytesDeleted log.Info("Directory scan complete", "path", dir, @@ -207,85 +239,79 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna "chunks", result.ChunksCreated, "blobs", result.BlobsCreated, "duration", result.EndTime.Sub(result.StartTime)) - - // Remove per-directory summary - the scanner already prints its own summary } - // Get upload statistics from scanner progress if available + return stats, nil +} + +// collectUploadStats gathers upload statistics from the scanner's progress reporter +func (v *Vaultik) collectUploadStats(scanner *snapshot.Scanner, stats *snapshotStats) { if s := scanner.GetProgress(); s != nil { - stats := s.GetStats() - totalBytesUploaded = stats.BytesUploaded.Load() - totalBlobsUploaded = int(stats.BlobsUploaded.Load()) - uploadDuration = time.Duration(stats.UploadDurationMs.Load()) * time.Millisecond + progressStats := s.GetStats() + stats.totalBytesUploaded = progressStats.BytesUploaded.Load() + stats.totalBlobsUploaded = int(progressStats.BlobsUploaded.Load()) + stats.uploadDuration = time.Duration(progressStats.UploadDurationMs.Load()) * time.Millisecond } +} - // Update snapshot statistics with extended fields +// finalizeSnapshotMetadata updates stats, marks complete, and exports metadata +func (v *Vaultik) finalizeSnapshotMetadata(snapshotID string, stats *snapshotStats) error { extStats := snapshot.ExtendedBackupStats{ BackupStats: snapshot.BackupStats{ - FilesScanned: totalFiles, - BytesScanned: totalBytes, - ChunksCreated: totalChunks, - BlobsCreated: totalBlobs, - BytesUploaded: totalBytesUploaded, + FilesScanned: stats.totalFiles, + BytesScanned: stats.totalBytes, + ChunksCreated: stats.totalChunks, + BlobsCreated: stats.totalBlobs, + BytesUploaded: stats.totalBytesUploaded, }, - BlobUncompressedSize: 0, // Will be set from database query below + BlobUncompressedSize: 0, CompressionLevel: v.Config.CompressionLevel, - UploadDurationMs: uploadDuration.Milliseconds(), + UploadDurationMs: stats.uploadDuration.Milliseconds(), } if err := v.SnapshotManager.UpdateSnapshotStatsExtended(v.ctx, snapshotID, extStats); err != nil { return fmt.Errorf("updating snapshot stats: %w", err) } - // Mark snapshot as complete if err := v.SnapshotManager.CompleteSnapshot(v.ctx, snapshotID); err != nil { return fmt.Errorf("completing snapshot: %w", err) } - // Export snapshot metadata - // Export snapshot metadata without closing the database - // The export function should handle its own database connection if err := v.SnapshotManager.ExportSnapshotMetadata(v.ctx, v.Config.IndexPath, snapshotID); err != nil { return fmt.Errorf("exporting snapshot metadata: %w", err) } - // Calculate final statistics - snapshotDuration := time.Since(snapshotStartTime) - totalFilesChanged := totalFiles - totalFilesSkipped - totalBytesChanged := totalBytes - totalBytesAll := totalBytes + totalBytesSkipped + return nil +} - // Calculate upload speed - var avgUploadSpeed string - if totalBytesUploaded > 0 && uploadDuration > 0 { - bytesPerSec := float64(totalBytesUploaded) / uploadDuration.Seconds() - bitsPerSec := bytesPerSec * 8 - if bitsPerSec >= 1e9 { - avgUploadSpeed = fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9) - } else if bitsPerSec >= 1e6 { - avgUploadSpeed = fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6) - } else if bitsPerSec >= 1e3 { - avgUploadSpeed = fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3) - } else { - avgUploadSpeed = fmt.Sprintf("%.0f bit/s", bitsPerSec) - } - } else { - avgUploadSpeed = "N/A" +// formatUploadSpeed formats bytes uploaded and duration into a human-readable speed string +func formatUploadSpeed(bytesUploaded int64, duration time.Duration) string { + if bytesUploaded <= 0 || duration <= 0 { + return "N/A" } + bytesPerSec := float64(bytesUploaded) / duration.Seconds() + bitsPerSec := bytesPerSec * 8 + switch { + case bitsPerSec >= 1e9: + return fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9) + case bitsPerSec >= 1e6: + return fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6) + case bitsPerSec >= 1e3: + return fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3) + default: + return fmt.Sprintf("%.0f bit/s", bitsPerSec) + } +} + +// printSnapshotSummary prints the comprehensive snapshot completion summary +func (v *Vaultik) printSnapshotSummary(snapshotID string, startTime time.Time, stats *snapshotStats) { + snapshotDuration := time.Since(startTime) + totalFilesChanged := stats.totalFiles - stats.totalFilesSkipped + totalBytesAll := stats.totalBytes + stats.totalBytesSkipped // Get total blob sizes from database - totalBlobSizeCompressed := int64(0) - totalBlobSizeUncompressed := int64(0) - if blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID); err == nil { - for _, hash := range blobHashes { - if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil { - totalBlobSizeCompressed += blob.CompressedSize - totalBlobSizeUncompressed += blob.UncompressedSize - } - } - } + totalBlobSizeCompressed, totalBlobSizeUncompressed := v.getSnapshotBlobSizes(snapshotID) - // Calculate compression ratio var compressionRatio float64 if totalBlobSizeUncompressed > 0 { compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed) @@ -293,55 +319,96 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna compressionRatio = 1.0 } - // Print comprehensive summary v.printfStdout("=== Snapshot Complete ===\n") v.printfStdout("ID: %s\n", snapshotID) v.printfStdout("Files: %s examined, %s to process, %s unchanged", - formatNumber(totalFiles), + formatNumber(stats.totalFiles), formatNumber(totalFilesChanged), - formatNumber(totalFilesSkipped)) - if totalFilesDeleted > 0 { - v.printfStdout(", %s deleted", formatNumber(totalFilesDeleted)) + formatNumber(stats.totalFilesSkipped)) + if stats.totalFilesDeleted > 0 { + v.printfStdout(", %s deleted", formatNumber(stats.totalFilesDeleted)) } v.printlnStdout() v.printfStdout("Data: %s total (%s to process)", humanize.Bytes(uint64(totalBytesAll)), - humanize.Bytes(uint64(totalBytesChanged))) - if totalBytesDeleted > 0 { - v.printfStdout(", %s deleted", humanize.Bytes(uint64(totalBytesDeleted))) + humanize.Bytes(uint64(stats.totalBytes))) + if stats.totalBytesDeleted > 0 { + v.printfStdout(", %s deleted", humanize.Bytes(uint64(stats.totalBytesDeleted))) } v.printlnStdout() - if totalBlobsUploaded > 0 { + if stats.totalBlobsUploaded > 0 { v.printfStdout("Storage: %s compressed from %s (%.2fx)\n", humanize.Bytes(uint64(totalBlobSizeCompressed)), humanize.Bytes(uint64(totalBlobSizeUncompressed)), compressionRatio) v.printfStdout("Upload: %d blobs, %s in %s (%s)\n", - totalBlobsUploaded, - humanize.Bytes(uint64(totalBytesUploaded)), - formatDuration(uploadDuration), - avgUploadSpeed) + stats.totalBlobsUploaded, + humanize.Bytes(uint64(stats.totalBytesUploaded)), + formatDuration(stats.uploadDuration), + formatUploadSpeed(stats.totalBytesUploaded, stats.uploadDuration)) } v.printfStdout("Duration: %s\n", formatDuration(snapshotDuration)) +} - return nil +// getSnapshotBlobSizes returns total compressed and uncompressed blob sizes for a snapshot +func (v *Vaultik) getSnapshotBlobSizes(snapshotID string) (compressed int64, uncompressed int64) { + blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID) + if err != nil { + return 0, 0 + } + for _, hash := range blobHashes { + if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil { + compressed += blob.CompressedSize + uncompressed += blob.UncompressedSize + } + } + return compressed, uncompressed } // ListSnapshots lists all snapshots func (v *Vaultik) ListSnapshots(jsonOutput bool) error { - // Get all remote snapshots + log.Info("Listing snapshots") + remoteSnapshots, err := v.listRemoteSnapshotIDs() + if err != nil { + return err + } + + localSnapshotMap, err := v.reconcileLocalWithRemote(remoteSnapshots) + if err != nil { + return err + } + + snapshots, err := v.buildSnapshotInfoList(remoteSnapshots, localSnapshotMap) + if err != nil { + return err + } + + // Sort by timestamp (newest first) + sort.Slice(snapshots, func(i, j int) bool { + return snapshots[i].Timestamp.After(snapshots[j].Timestamp) + }) + + if jsonOutput { + encoder := json.NewEncoder(v.Stdout) + encoder.SetIndent("", " ") + return encoder.Encode(snapshots) + } + + return v.printSnapshotTable(snapshots) +} + +// listRemoteSnapshotIDs returns a set of snapshot IDs found in remote storage +func (v *Vaultik) listRemoteSnapshotIDs() (map[string]bool, error) { remoteSnapshots := make(map[string]bool) objectCh := v.Storage.ListStream(v.ctx, "metadata/") for object := range objectCh { if object.Err != nil { - return fmt.Errorf("listing remote snapshots: %w", object.Err) + return nil, fmt.Errorf("listing remote snapshots: %w", object.Err) } - // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/ parts := strings.Split(object.Key, "/") if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { - // Skip macOS resource fork files (._*) and other hidden files if strings.HasPrefix(parts[1], ".") { continue } @@ -349,56 +416,46 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { } } - // Get all local snapshots + return remoteSnapshots, nil +} + +// reconcileLocalWithRemote removes local snapshots not in remote and returns the surviving local map +func (v *Vaultik) reconcileLocalWithRemote(remoteSnapshots map[string]bool) (map[string]*database.Snapshot, error) { localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000) if err != nil { - return fmt.Errorf("listing local snapshots: %w", err) + return nil, fmt.Errorf("listing local snapshots: %w", err) } - // Build a map of local snapshots for quick lookup localSnapshotMap := make(map[string]*database.Snapshot) for _, s := range localSnapshots { localSnapshotMap[s.ID.String()] = s } - // Remove local snapshots that don't exist remotely - for _, snapshot := range localSnapshots { - snapshotIDStr := snapshot.ID.String() + for _, snap := range localSnapshots { + snapshotIDStr := snap.ID.String() if !remoteSnapshots[snapshotIDStr] { - log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID) - - // Delete related records first to avoid foreign key constraints - if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err) - } - if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err) - } - if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err) - } - - // Now delete the snapshot itself - if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err) + log.Info("Removing local snapshot not found in remote", "snapshot_id", snap.ID) + if err := v.deleteSnapshotFromLocalDB(snapshotIDStr); err != nil { + log.Error("Failed to delete local snapshot", "snapshot_id", snap.ID, "error", err) } else { - log.Info("Deleted local snapshot not found in remote", "snapshot_id", snapshot.ID) + log.Info("Deleted local snapshot not found in remote", "snapshot_id", snap.ID) delete(localSnapshotMap, snapshotIDStr) } } } - // Build final snapshot list + return localSnapshotMap, nil +} + +// buildSnapshotInfoList constructs SnapshotInfo entries from remote IDs and local data +func (v *Vaultik) buildSnapshotInfoList(remoteSnapshots map[string]bool, localSnapshotMap map[string]*database.Snapshot) ([]SnapshotInfo, error) { snapshots := make([]SnapshotInfo, 0, len(remoteSnapshots)) for snapshotID := range remoteSnapshots { - // Check if we have this snapshot locally if localSnap, exists := localSnapshotMap[snapshotID]; exists && localSnap.CompletedAt != nil { - // Get total compressed size of all blobs referenced by this snapshot totalSize, err := v.Repositories.Snapshots.GetSnapshotTotalCompressedSize(v.ctx, snapshotID) if err != nil { log.Warn("Failed to get total compressed size", "id", snapshotID, "error", err) - // Fall back to stored blob size totalSize = localSnap.BlobSize } @@ -408,17 +465,15 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { CompressedSize: totalSize, }) } else { - // Remote snapshot not in local DB - fetch manifest to get size timestamp, err := parseSnapshotTimestamp(snapshotID) if err != nil { log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err) continue } - // Try to download manifest to get size totalSize, err := v.getManifestSize(snapshotID) if err != nil { - return fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err) + return nil, fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err) } snapshots = append(snapshots, SnapshotInfo{ @@ -429,22 +484,13 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { } } - // Sort by timestamp (newest first) - sort.Slice(snapshots, func(i, j int) bool { - return snapshots[i].Timestamp.After(snapshots[j].Timestamp) - }) + return snapshots, nil +} - if jsonOutput { - // JSON output - encoder := json.NewEncoder(v.Stdout) - encoder.SetIndent("", " ") - return encoder.Encode(snapshots) - } - - // Table output +// printSnapshotTable renders the snapshot list as a formatted table +func (v *Vaultik) printSnapshotTable(snapshots []SnapshotInfo) error { w := tabwriter.NewWriter(v.Stdout, 0, 0, 3, ' ', 0) - // Show configured snapshots from config file if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil { return err } @@ -465,7 +511,6 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { return err } - // Show remote snapshots if _, err := fmt.Fprintln(w, "REMOTE SNAPSHOTS:"); err != nil { return err } @@ -518,26 +563,9 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool) return snapshots[i].Timestamp.After(snapshots[j].Timestamp) }) - var toDelete []SnapshotInfo - - if keepLatest { - // Keep only the most recent snapshot - if len(snapshots) > 1 { - toDelete = snapshots[1:] - } - } else if olderThan != "" { - // Parse duration - duration, err := parseDuration(olderThan) - if err != nil { - return fmt.Errorf("invalid duration: %w", err) - } - - cutoff := time.Now().UTC().Add(-duration) - for _, snap := range snapshots { - if snap.Timestamp.Before(cutoff) { - toDelete = append(toDelete, snap) - } - } + toDelete, err := v.collectSnapshotsToPurge(snapshots, keepLatest, olderThan) + if err != nil { + return err } if len(toDelete) == 0 { @@ -545,6 +573,41 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool) return nil } + return v.confirmAndExecutePurge(toDelete, force) +} + +// collectSnapshotsToPurge determines which snapshots to delete based on retention criteria +func (v *Vaultik) collectSnapshotsToPurge(snapshots []SnapshotInfo, keepLatest bool, olderThan string) ([]SnapshotInfo, error) { + if keepLatest { + // Keep only the most recent snapshot + if len(snapshots) > 1 { + return snapshots[1:], nil + } + return nil, nil + } + + if olderThan != "" { + // Parse duration + duration, err := parseDuration(olderThan) + if err != nil { + return nil, fmt.Errorf("invalid duration: %w", err) + } + + cutoff := time.Now().UTC().Add(-duration) + var toDelete []SnapshotInfo + for _, snap := range snapshots { + if snap.Timestamp.Before(cutoff) { + toDelete = append(toDelete, snap) + } + } + return toDelete, nil + } + + return nil, nil +} + +// confirmAndExecutePurge shows deletion candidates, confirms with user, and deletes snapshots +func (v *Vaultik) confirmAndExecutePurge(toDelete []SnapshotInfo, force bool) error { // Show what will be deleted v.printfStdout("The following snapshots will be deleted:\n\n") for _, snap := range toDelete { @@ -610,29 +673,7 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio result.Mode = "deep" } - // Parse snapshot ID to extract timestamp - parts := strings.Split(snapshotID, "-") - var snapshotTime time.Time - if len(parts) >= 3 { - // Format: hostname-YYYYMMDD-HHMMSSZ - dateStr := parts[len(parts)-2] - timeStr := parts[len(parts)-1] - if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") { - timeStr = timeStr[:6] // Remove Z - timestamp, err := time.Parse("20060102150405", dateStr+timeStr) - if err == nil { - snapshotTime = timestamp - } - } - } - - if !opts.JSON { - v.printfStdout("Verifying snapshot %s\n", snapshotID) - if !snapshotTime.IsZero() { - v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) - } - v.printlnStdout() - } + v.printVerifyHeader(snapshotID, opts) // Download and parse manifest manifest, err := v.downloadManifest(snapshotID) @@ -663,10 +704,40 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio v.printfStdout("Checking blob existence...\n") } - missing := 0 - verified := 0 - missingSize := int64(0) + result.Verified, result.Missing, result.MissingSize = v.verifyManifestBlobsExist(manifest, opts) + return v.formatVerifyResult(result, manifest, opts) +} + +// printVerifyHeader prints the snapshot ID and parsed timestamp for verification output +func (v *Vaultik) printVerifyHeader(snapshotID string, opts *VerifyOptions) { + // Parse snapshot ID to extract timestamp + parts := strings.Split(snapshotID, "-") + var snapshotTime time.Time + if len(parts) >= 3 { + // Format: hostname-YYYYMMDD-HHMMSSZ + dateStr := parts[len(parts)-2] + timeStr := parts[len(parts)-1] + if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") { + timeStr = timeStr[:6] // Remove Z + timestamp, err := time.Parse("20060102150405", dateStr+timeStr) + if err == nil { + snapshotTime = timestamp + } + } + } + + if !opts.JSON { + v.printfStdout("Verifying snapshot %s\n", snapshotID) + if !snapshotTime.IsZero() { + v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) + } + v.printlnStdout() + } +} + +// verifyManifestBlobsExist checks that each blob in the manifest exists in storage +func (v *Vaultik) verifyManifestBlobsExist(manifest *snapshot.Manifest, opts *VerifyOptions) (verified, missing int, missingSize int64) { for _, blob := range manifest.Blobs { blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash) @@ -682,15 +753,15 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio verified++ } } + return verified, missing, missingSize +} - result.Verified = verified - result.Missing = missing - result.MissingSize = missingSize - +// formatVerifyResult outputs the final verification results as JSON or human-readable text +func (v *Vaultik) formatVerifyResult(result *VerifyResult, manifest *snapshot.Manifest, opts *VerifyOptions) error { if opts.JSON { - if missing > 0 { + if result.Missing > 0 { result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("%d blobs are missing", missing) + result.ErrorMessage = fmt.Sprintf("%d blobs are missing", result.Missing) } else { result.Status = "ok" } @@ -698,20 +769,19 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio } v.printfStdout("\nVerification complete:\n") - v.printfStdout(" Verified: %d blobs (%s)\n", verified, - humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize))) - if missing > 0 { - v.printfStdout(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize))) + v.printfStdout(" Verified: %d blobs (%s)\n", result.Verified, + humanize.Bytes(uint64(manifest.TotalCompressedSize-result.MissingSize))) + if result.Missing > 0 { + v.printfStdout(" Missing: %d blobs (%s)\n", result.Missing, humanize.Bytes(uint64(result.MissingSize))) } else { v.printfStdout(" Missing: 0 blobs\n") } v.printfStdout(" Status: ") - if missing > 0 { - v.printfStdout("FAILED - %d blobs are missing\n", missing) - return fmt.Errorf("%d blobs are missing", missing) - } else { - v.printfStdout("OK - All blobs verified\n") + if result.Missing > 0 { + v.printfStdout("FAILED - %d blobs are missing\n", result.Missing) + return fmt.Errorf("%d blobs are missing", result.Missing) } + v.printfStdout("OK - All blobs verified\n") return nil } @@ -907,9 +977,27 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov // RemoveAllSnapshots removes all snapshots from local database and optionally from remote func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) { - result := &RemoveResult{} + snapshotIDs, err := v.listAllRemoteSnapshotIDs() + if err != nil { + return nil, err + } - // List all snapshots + if len(snapshotIDs) == 0 { + if !opts.JSON { + v.printlnStdout("No snapshots found") + } + return &RemoveResult{}, nil + } + + if opts.DryRun { + return v.handleRemoveAllDryRun(snapshotIDs, opts) + } + + return v.executeRemoveAll(snapshotIDs, opts) +} + +// listAllRemoteSnapshotIDs collects all unique snapshot IDs from remote storage +func (v *Vaultik) listAllRemoteSnapshotIDs() ([]string, error) { log.Info("Listing all snapshots") objectCh := v.Storage.ListStream(v.ctx, "metadata/") @@ -941,32 +1029,33 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) } } - if len(snapshotIDs) == 0 { - if !opts.JSON { - v.printlnStdout("No snapshots found") - } - return result, nil - } + return snapshotIDs, nil +} - if opts.DryRun { - result.DryRun = true - result.SnapshotsRemoved = snapshotIDs - if !opts.JSON { - v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs)) - for _, id := range snapshotIDs { - v.printfStdout(" %s\n", id) - } - if opts.Remote { - v.printlnStdout("Would also remove from remote storage") - } - v.printlnStdout("[Dry run - no changes made]") - } - if opts.JSON { - return result, v.outputRemoveJSON(result) - } - return result, nil +// handleRemoveAllDryRun handles the dry-run mode for removing all snapshots +func (v *Vaultik) handleRemoveAllDryRun(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) { + result := &RemoveResult{ + DryRun: true, + SnapshotsRemoved: snapshotIDs, } + if !opts.JSON { + v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs)) + for _, id := range snapshotIDs { + v.printfStdout(" %s\n", id) + } + if opts.Remote { + v.printlnStdout("Would also remove from remote storage") + } + v.printlnStdout("[Dry run - no changes made]") + } + if opts.JSON { + return result, v.outputRemoveJSON(result) + } + return result, nil +} +// executeRemoveAll removes all snapshots from local database and optionally from remote storage +func (v *Vaultik) executeRemoveAll(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) { // --all requires --force if !opts.Force { return nil, fmt.Errorf("--all requires --force") @@ -974,6 +1063,7 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) log.Info("Removing all snapshots", "count", len(snapshotIDs)) + result := &RemoveResult{} for _, snapshotID := range snapshotIDs { log.Info("Removing snapshot", "snapshot_id", snapshotID) diff --git a/internal/vaultik/verify.go b/internal/vaultik/verify.go index 55213ef..732d70b 100644 --- a/internal/vaultik/verify.go +++ b/internal/vaultik/verify.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/hex" "fmt" + "hash" "io" "os" "time" @@ -35,6 +36,19 @@ type VerifyResult struct { ErrorMessage string `json:"error,omitempty"` } +// deepVerifyFailure records a failure in the result and returns it appropriately +func (v *Vaultik) deepVerifyFailure(result *VerifyResult, opts *VerifyOptions, msg string, err error) error { + result.Status = "failed" + result.ErrorMessage = msg + if opts.JSON { + return v.outputVerifyJSON(result) + } + if err != nil { + return err + } + return fmt.Errorf("%s", msg) +} + // RunDeepVerify executes deep verification operation func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { result := &VerifyResult{ @@ -42,89 +56,20 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { Mode: "deep", } - // Check for decryption capability if !v.CanDecrypt() { - result.Status = "failed" - result.ErrorMessage = "VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification" - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification") + return v.deepVerifyFailure(result, opts, + "VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification", + fmt.Errorf("VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification")) } - log.Info("Starting snapshot verification", - "snapshot_id", snapshotID, - "mode", "deep", - ) - + log.Info("Starting snapshot verification", "snapshot_id", snapshotID, "mode", "deep") if !opts.JSON { v.printfStdout("Deep verification of snapshot: %s\n\n", snapshotID) } - // Step 1: Download manifest - manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) - log.Info("Downloading manifest", "path", manifestPath) - if !opts.JSON { - v.printfStdout("Downloading manifest...\n") - } - - manifestReader, err := v.Storage.Get(v.ctx, manifestPath) + manifest, tempDB, dbBlobs, err := v.loadVerificationData(snapshotID, opts, result) if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to download manifest: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to download manifest: %w", err) - } - defer func() { _ = manifestReader.Close() }() - - // Decompress manifest - manifest, err := snapshot.DecodeManifest(manifestReader) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to decode manifest: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to decode manifest: %w", err) - } - - log.Info("Manifest loaded", - "manifest_blob_count", manifest.BlobCount, - "manifest_total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize)), - ) - if !opts.JSON { - v.printfStdout("Manifest loaded: %d blobs (%s)\n", manifest.BlobCount, humanize.Bytes(uint64(manifest.TotalCompressedSize))) - } - - // Step 2: Download and decrypt database (authoritative source) - dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) - log.Info("Downloading encrypted database", "path", dbPath) - if !opts.JSON { - v.printfStdout("Downloading and decrypting database...\n") - } - - dbReader, err := v.Storage.Get(v.ctx, dbPath) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to download database: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to download database: %w", err) - } - defer func() { _ = dbReader.Close() }() - - // Decrypt and decompress database - tempDB, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to decrypt database: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to decrypt database: %w", err) + return err } defer func() { if tempDB != nil { @@ -132,17 +77,6 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { } }() - // Step 3: Get authoritative blob list from database - dbBlobs, err := v.getBlobsFromDatabase(snapshotID, tempDB.DB) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to get blobs from database: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to get blobs from database: %w", err) - } - result.BlobCount = len(dbBlobs) var totalSize int64 for _, blob := range dbBlobs { @@ -150,54 +84,10 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { } result.TotalSize = totalSize - log.Info("Database loaded", - "db_blob_count", len(dbBlobs), - "db_total_size", humanize.Bytes(uint64(totalSize)), - ) - if !opts.JSON { - v.printfStdout("Database loaded: %d blobs (%s)\n", len(dbBlobs), humanize.Bytes(uint64(totalSize))) - v.printfStdout("Verifying manifest against database...\n") - } - - // Step 4: Verify manifest matches database - if err := v.verifyManifestAgainstDatabase(manifest, dbBlobs); err != nil { - result.Status = "failed" - result.ErrorMessage = err.Error() - if opts.JSON { - return v.outputVerifyJSON(result) - } + if err := v.runVerificationSteps(manifest, dbBlobs, tempDB, opts, result, totalSize); err != nil { return err } - // Step 5: Verify all blobs exist in S3 (using database as source) - if !opts.JSON { - v.printfStdout("Manifest verified.\n") - v.printfStdout("Checking blob existence in remote storage...\n") - } - if err := v.verifyBlobExistenceFromDB(dbBlobs); err != nil { - result.Status = "failed" - result.ErrorMessage = err.Error() - if opts.JSON { - return v.outputVerifyJSON(result) - } - return err - } - - // Step 6: Deep verification - download and verify blob contents - if !opts.JSON { - v.printfStdout("All blobs exist.\n") - v.printfStdout("Downloading and verifying blob contents (%d blobs, %s)...\n", len(dbBlobs), humanize.Bytes(uint64(totalSize))) - } - if err := v.performDeepVerificationFromDB(dbBlobs, tempDB.DB, opts); err != nil { - result.Status = "failed" - result.ErrorMessage = err.Error() - if opts.JSON { - return v.outputVerifyJSON(result) - } - return err - } - - // Success result.Status = "ok" result.Verified = len(dbBlobs) @@ -206,11 +96,7 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { } log.Info("āœ“ Verification completed successfully", - "snapshot_id", snapshotID, - "mode", "deep", - "blobs_verified", len(dbBlobs), - ) - + "snapshot_id", snapshotID, "mode", "deep", "blobs_verified", len(dbBlobs)) v.printfStdout("\nāœ“ Verification completed successfully\n") v.printfStdout(" Snapshot: %s\n", snapshotID) v.printfStdout(" Blobs verified: %d\n", len(dbBlobs)) @@ -219,6 +105,106 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { return nil } +// loadVerificationData downloads manifest, database, and blob list for verification +func (v *Vaultik) loadVerificationData(snapshotID string, opts *VerifyOptions, result *VerifyResult) (*snapshot.Manifest, *tempDB, []snapshot.BlobInfo, error) { + // Download manifest + manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) + log.Info("Downloading manifest", "path", manifestPath) + if !opts.JSON { + v.printfStdout("Downloading manifest...\n") + } + manifestReader, err := v.Storage.Get(v.ctx, manifestPath) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to download manifest: %v", err), + fmt.Errorf("failed to download manifest: %w", err)) + } + defer func() { _ = manifestReader.Close() }() + + manifest, err := snapshot.DecodeManifest(manifestReader) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to decode manifest: %v", err), + fmt.Errorf("failed to decode manifest: %w", err)) + } + + log.Info("Manifest loaded", + "manifest_blob_count", manifest.BlobCount, + "manifest_total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize))) + if !opts.JSON { + v.printfStdout("Manifest loaded: %d blobs (%s)\n", manifest.BlobCount, humanize.Bytes(uint64(manifest.TotalCompressedSize))) + v.printfStdout("Downloading and decrypting database...\n") + } + + // Download and decrypt database + dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) + log.Info("Downloading encrypted database", "path", dbPath) + dbReader, err := v.Storage.Get(v.ctx, dbPath) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to download database: %v", err), + fmt.Errorf("failed to download database: %w", err)) + } + defer func() { _ = dbReader.Close() }() + + tdb, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to decrypt database: %v", err), + fmt.Errorf("failed to decrypt database: %w", err)) + } + + dbBlobs, err := v.getBlobsFromDatabase(snapshotID, tdb.DB) + if err != nil { + _ = tdb.Close() + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to get blobs from database: %v", err), + fmt.Errorf("failed to get blobs from database: %w", err)) + } + + var dbTotalSize int64 + for _, b := range dbBlobs { + dbTotalSize += b.CompressedSize + } + + log.Info("Database loaded", + "db_blob_count", len(dbBlobs), + "db_total_size", humanize.Bytes(uint64(dbTotalSize))) + if !opts.JSON { + v.printfStdout("Database loaded: %d blobs (%s)\n", len(dbBlobs), humanize.Bytes(uint64(dbTotalSize))) + } + + return manifest, tdb, dbBlobs, nil +} + +// runVerificationSteps executes manifest verification, blob existence check, and deep content verification +func (v *Vaultik) runVerificationSteps(manifest *snapshot.Manifest, dbBlobs []snapshot.BlobInfo, tdb *tempDB, opts *VerifyOptions, result *VerifyResult, totalSize int64) error { + if !opts.JSON { + v.printfStdout("Verifying manifest against database...\n") + } + if err := v.verifyManifestAgainstDatabase(manifest, dbBlobs); err != nil { + return v.deepVerifyFailure(result, opts, err.Error(), err) + } + + if !opts.JSON { + v.printfStdout("Manifest verified.\n") + v.printfStdout("Checking blob existence in remote storage...\n") + } + if err := v.verifyBlobExistenceFromDB(dbBlobs); err != nil { + return v.deepVerifyFailure(result, opts, err.Error(), err) + } + + if !opts.JSON { + v.printfStdout("All blobs exist.\n") + v.printfStdout("Downloading and verifying blob contents (%d blobs, %s)...\n", len(dbBlobs), humanize.Bytes(uint64(totalSize))) + } + if err := v.performDeepVerificationFromDB(dbBlobs, tdb.DB, opts); err != nil { + return v.deepVerifyFailure(result, opts, err.Error(), err) + } + + return nil +} + // tempDB wraps sql.DB with cleanup type tempDB struct { *sql.DB @@ -316,7 +302,27 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { } defer decompressor.Close() - // Query blob chunks from database to get offsets and lengths + chunkCount, err := v.verifyBlobChunks(db, blobInfo.Hash, decompressor) + if err != nil { + return err + } + + if err := v.verifyBlobFinalIntegrity(decompressor, blobHasher, blobInfo.Hash); err != nil { + return err + } + + log.Info("Blob verified", + "hash", blobInfo.Hash[:16]+"...", + "chunks", chunkCount, + "size", humanize.Bytes(uint64(blobInfo.CompressedSize)), + ) + + return nil +} + +// verifyBlobChunks queries blob chunks from the database and verifies each chunk's hash +// against the decompressed blob stream +func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.Reader) (int, error) { query := ` SELECT bc.chunk_hash, bc.offset, bc.length FROM blob_chunks bc @@ -324,9 +330,9 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { WHERE b.blob_hash = ? ORDER BY bc.offset ` - rows, err := db.QueryContext(v.ctx, query, blobInfo.Hash) + rows, err := db.QueryContext(v.ctx, query, blobHash) if err != nil { - return fmt.Errorf("failed to query blob chunks: %w", err) + return 0, fmt.Errorf("failed to query blob chunks: %w", err) } defer func() { _ = rows.Close() }() @@ -339,12 +345,12 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { var chunkHash string var offset, length int64 if err := rows.Scan(&chunkHash, &offset, &length); err != nil { - return fmt.Errorf("failed to scan chunk row: %w", err) + return 0, fmt.Errorf("failed to scan chunk row: %w", err) } // Verify chunk ordering if offset <= lastOffset { - return fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) + return 0, fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) } lastOffset = offset @@ -353,7 +359,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Skip to the correct offset skipBytes := offset - totalRead if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil { - return fmt.Errorf("failed to skip to offset %d: %w", offset, err) + return 0, fmt.Errorf("failed to skip to offset %d: %w", offset, err) } totalRead = offset } @@ -361,7 +367,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Read chunk data chunkData := make([]byte, length) if _, err := io.ReadFull(decompressor, chunkData); err != nil { - return fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) + return 0, fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) } totalRead += length @@ -371,7 +377,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { calculatedHash := hex.EncodeToString(hasher.Sum(nil)) if calculatedHash != chunkHash { - return fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", + return 0, fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", offset, calculatedHash, chunkHash) } @@ -379,9 +385,15 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { } if err := rows.Err(); err != nil { - return fmt.Errorf("error iterating blob chunks: %w", err) + return 0, fmt.Errorf("error iterating blob chunks: %w", err) } + return chunkCount, nil +} + +// verifyBlobFinalIntegrity checks that no trailing data exists in the decompressed stream +// and that the encrypted blob hash matches the expected value +func (v *Vaultik) verifyBlobFinalIntegrity(decompressor io.Reader, blobHasher hash.Hash, expectedHash string) error { // Verify no remaining data in blob - if chunk list is accurate, blob should be fully consumed remaining, err := io.Copy(io.Discard, decompressor) if err != nil { @@ -393,17 +405,11 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Verify blob hash matches the encrypted data we downloaded calculatedBlobHash := hex.EncodeToString(blobHasher.Sum(nil)) - if calculatedBlobHash != blobInfo.Hash { + if calculatedBlobHash != expectedHash { return fmt.Errorf("blob hash mismatch: calculated %s, expected %s", - calculatedBlobHash, blobInfo.Hash) + calculatedBlobHash, expectedHash) } - log.Info("Blob verified", - "hash", blobInfo.Hash[:16]+"...", - "chunks", chunkCount, - "size", humanize.Bytes(uint64(blobInfo.CompressedSize)), - ) - return nil }