diff --git a/internal/database/blobs.go b/internal/database/blobs.go index 44072c6..58aed0f 100644 --- a/internal/database/blobs.go +++ b/internal/database/blobs.go @@ -130,6 +130,51 @@ func (r *BlobRepository) GetByID(ctx context.Context, id string) (*Blob, error) return &blob, nil } +// GetAll returns every blob row keyed by blob ID. Useful at restore +// start to translate the per-chunk blob_id references in chunkToBlobMap +// into blob hashes without doing one GetByID query per chunk. +func (r *BlobRepository) GetAll(ctx context.Context) (map[string]*Blob, error) { + query := ` + SELECT id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts + FROM blobs + ` + + rows, err := r.db.conn.QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("querying blobs: %w", err) + } + defer CloseRows(rows) + + out := make(map[string]*Blob) + for rows.Next() { + var blob Blob + var createdTSUnix int64 + var finishedTSUnix, uploadedTSUnix sql.NullInt64 + if err := rows.Scan( + &blob.ID, + &blob.Hash, + &createdTSUnix, + &finishedTSUnix, + &blob.UncompressedSize, + &blob.CompressedSize, + &uploadedTSUnix, + ); err != nil { + return nil, fmt.Errorf("scanning blob: %w", err) + } + blob.CreatedTS = time.Unix(createdTSUnix, 0).UTC() + if finishedTSUnix.Valid { + ts := time.Unix(finishedTSUnix.Int64, 0).UTC() + blob.FinishedTS = &ts + } + if uploadedTSUnix.Valid { + ts := time.Unix(uploadedTSUnix.Int64, 0).UTC() + blob.UploadedTS = &ts + } + out[blob.ID.String()] = &blob + } + return out, rows.Err() +} + // UpdateFinished updates a blob when it's finalized func (r *BlobRepository) UpdateFinished(ctx context.Context, tx *sql.Tx, id string, hash string, uncompressedSize, compressedSize int64) error { query := ` diff --git a/internal/vaultik/blobcache.go b/internal/vaultik/blobcache.go index 553b187..4fbc3ff 100644 --- a/internal/vaultik/blobcache.go +++ b/internal/vaultik/blobcache.go @@ -2,6 +2,7 @@ package vaultik import ( "fmt" + "io" "os" "path/filepath" "sync" @@ -132,6 +133,66 @@ func (c *blobDiskCache) Put(key string, data []byte) error { return nil } +// PutFromReader streams r into the cache file for key, returning the +// total number of bytes written. Unlike Put, the data never has to +// reside fully in memory at any point — io.Copy uses an internal +// 32 KiB buffer. Used by restore to land a freshly decrypted blob on +// disk without buffering its entire plaintext (which may be tens of GB) +// in RAM. +func (c *blobDiskCache) PutFromReader(key string, r io.Reader) (int64, error) { + c.mu.Lock() + // Remove any prior entry first; we'll re-link after the file is + // written successfully. + if e, ok := c.items[key]; ok { + c.unlink(e) + c.curBytes -= e.size + _ = os.Remove(c.path(key)) + delete(c.items, key) + } + c.mu.Unlock() + + f, err := os.OpenFile(c.path(key), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600) + if err != nil { + return 0, fmt.Errorf("creating cache file: %w", err) + } + written, copyErr := io.Copy(f, r) + closeErr := f.Close() + if copyErr != nil { + _ = os.Remove(c.path(key)) + return written, fmt.Errorf("streaming to cache file: %w", copyErr) + } + if closeErr != nil { + _ = os.Remove(c.path(key)) + return written, fmt.Errorf("closing cache file: %w", closeErr) + } + + c.mu.Lock() + defer c.mu.Unlock() + + // If the entry would exceed maxBytes outright, drop it on the + // floor — but the restore path passes math.MaxInt64 as maxBytes + // so this branch is effectively unreachable there. + if written > c.maxBytes { + _ = os.Remove(c.path(key)) + return written, nil + } + + e := &blobDiskCacheEntry{key: key, size: written} + c.pushFront(e) + c.items[key] = e + c.curBytes += written + + for c.curBytes > c.maxBytes && c.tail != nil { + c.evictLRU() + } + + if n := len(c.items); n > c.peakLen { + c.peakLen = n + } + + return written, nil +} + // Get reads a cached blob from disk. Returns data and true on hit. func (c *blobDiskCache) Get(key string) ([]byte, bool) { c.mu.Lock() diff --git a/internal/vaultik/restore.go b/internal/vaultik/restore.go index fb04d7a..e7ac019 100644 --- a/internal/vaultik/restore.go +++ b/internal/vaultik/restore.go @@ -159,7 +159,12 @@ func (v *Vaultik) prepareRestoreIdentity() (age.Identity, error) { return identity, nil } -// restoreAllFiles iterates over files and restores each one, tracking progress and failures +// restoreAllFiles processes files in blob-locality order: drain every +// file whose blob set is on disk, download the missing blobs for the +// pending file with the smallest uncached count, repeat. This keeps +// peak cache occupancy near 1 even on snapshots whose path order +// interleaves blobs, and lets the sweeper free each blob the moment +// its file set is exhausted. func (v *Vaultik) restoreAllFiles( files []*database.File, repos *database.Repositories, @@ -192,6 +197,32 @@ func (v *Vaultik) restoreAllFiles( // are all already restored. sweeper := newRestoreSweeper(v.ctx, repos, blobCache, v.Config.BlobSizeLimit.Int64()/100) + // Pre-fetch every blob row once so chunk extraction can map a + // blob_id to its hash without a DB round-trip per chunk. + blobsByID, err := repos.Blobs.GetAll(v.ctx) + if err != nil { + return nil, fmt.Errorf("fetching blob index: %w", err) + } + blobIDToHash := make(map[string]string, len(blobsByID)) + blobByHash := make(map[string]*database.Blob, len(blobsByID)) + for id, blob := range blobsByID { + hash := blob.Hash.String() + blobIDToHash[id] = hash + blobByHash[hash] = blob + } + + plan, err := newRestorePlan(v.ctx, repos, files, chunkToBlobMap, blobIDToHash) + if err != nil { + return nil, fmt.Errorf("building restore plan: %w", err) + } + + // Index files by ID so the loop can look them up by the IDs the + // plan hands back. + filesByID := make(map[types.FileID]*database.File, len(files)) + for _, f := range files { + filesByID[f.ID] = f + } + // Calculate total bytes expected for percentage / ETA arithmetic. var totalBytesExpected int64 for _, file := range files { @@ -203,17 +234,65 @@ func (v *Vaultik) restoreAllFiles( v.UI.Size(totalBytesExpected), v.UI.Path(opts.TargetDir)) + session := &restoreSession{ + v: v, + ctx: v.ctx, + repos: repos, + opts: opts, + identity: identity, + chunkToBlobMap: chunkToBlobMap, + blobByHash: blobByHash, + blobIDToHash: blobIDToHash, + blobCache: blobCache, + sweeper: sweeper, + result: result, + } + // Periodic progress output, matching the snapshot create cadence. startTime := time.Now() lastStatusTime := startTime const statusInterval = 15 * time.Second - for i, file := range files { + processed := 0 + for plan.hasPending() { if v.ctx.Err() != nil { return nil, v.ctx.Err() } - if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, sweeper, result); err != nil { + fileID, ready := plan.popReady() + if !ready { + // No file is fully cache-served. First free any blobs + // whose file sets are exhausted — without this, the + // blob whose last file we just finished would still be + // cached when we Put the next one, briefly pushing + // peak occupancy from 1 to 2. + sweeper.sweep() + + // Pick the pending file with the smallest uncached + // blob set and download its blobs. After each blob + // lands, the plan moves any pending file whose set + // just emptied onto the ready queue. + next := plan.pickNextDownload() + if next.IsZero() { + break + } + for _, hash := range plan.blobsNeeded(next) { + blob, ok := blobByHash[hash] + if !ok { + return nil, fmt.Errorf("blob hash %s missing from blob index", hash[:16]) + } + if err := session.downloadBlobToCache(hash, blob.CompressedSize); err != nil { + return nil, fmt.Errorf("downloading blob %s: %w", hash[:16], err) + } + result.BlobsDownloaded++ + result.BytesDownloaded += blob.CompressedSize + plan.markBlobCached(hash) + } + continue + } + + file := filesByID[fileID] + if err := session.restoreFile(file); err != nil { log.Error("Failed to restore file", "path", file.Path, "error", err) if !opts.SkipErrors { return nil, fmt.Errorf("restoring %s: %w (pass --skip-errors to continue past restore failures)", file.Path, err) @@ -221,22 +300,26 @@ func (v *Vaultik) restoreAllFiles( v.UI.Error("Failed to restore %s: %v. Skipping (--skip-errors).", v.UI.Path(file.Path.String()), err) result.FilesFailed++ result.FailedFiles = append(result.FailedFiles, file.Path.String()) + plan.finishFile(fileID) continue } - // Record the file as restored so the sweeper can free blobs once - // all referencing files are done. - sweeper.fileRestored(file.ID.String()) + // Record the file as restored so the sweeper can free blobs + // once all referencing files are done, and drop it from the + // plan's indexes so future picks ignore it. + sweeper.fileRestored(fileID.String()) + plan.finishFile(fileID) + processed++ if time.Since(lastStatusTime) >= statusInterval { - v.printRestoreProgress(i+1, len(files), result.BytesRestored, totalBytesExpected, startTime) + v.printRestoreProgress(processed, len(files), result.BytesRestored, totalBytesExpected, startTime) lastStatusTime = time.Now() } // Structured progress log for --verbose / JSON consumers. - if (i+1)%100 == 0 || i+1 == len(files) { + if processed%100 == 0 || processed == len(files) { log.Info("Restore progress", - "files", fmt.Sprintf("%d/%d", i+1, len(files)), + "files", fmt.Sprintf("%d/%d", processed, len(files)), "bytes", humanize.Bytes(uint64(result.BytesRestored)), ) } @@ -432,183 +515,128 @@ func (v *Vaultik) buildChunkToBlobMap(ctx context.Context, repos *database.Repos return result, rows.Err() } -// restoreFile restores a single file -func (v *Vaultik) restoreFile( - ctx context.Context, - repos *database.Repositories, - file *database.File, - targetDir string, - identity age.Identity, - chunkToBlobMap map[string]*database.BlobChunk, - blobCache *blobDiskCache, - sweeper *restoreSweeper, - result *RestoreResult, -) error { - // Calculate target path - use full original path under target directory - targetPath := filepath.Join(targetDir, file.Path.String()) - - // Create parent directories - parentDir := filepath.Dir(targetPath) - if err := v.Fs.MkdirAll(parentDir, 0755); err != nil { - return fmt.Errorf("creating parent directory: %w", err) - } - - // Handle symlinks - if file.IsSymlink() { - return v.restoreSymlink(file, targetPath, result) - } - - // Handle directories - if file.Mode&uint32(os.ModeDir) != 0 { - return v.restoreDirectory(file, targetPath, result) - } - - // Handle regular files - return v.restoreRegularFile(ctx, repos, file, targetPath, identity, chunkToBlobMap, blobCache, sweeper, result) +// restoreSession holds every piece of per-restore state shared by the +// restore-time methods. Each restore builds one of these from the +// snapshot's metadata and then drives the file loop through methods on +// it. Keeping this state on the struct rather than threading it +// through every function signature keeps the inner-loop call sites +// readable: restoreFile(file) instead of a ten-argument helper. +type restoreSession struct { + v *Vaultik + ctx context.Context + repos *database.Repositories + opts *RestoreOptions + identity age.Identity + chunkToBlobMap map[string]*database.BlobChunk + blobByHash map[string]*database.Blob + blobIDToHash map[string]string + blobCache *blobDiskCache + sweeper *restoreSweeper + result *RestoreResult } -// restoreSymlink restores a symbolic link -func (v *Vaultik) restoreSymlink(file *database.File, targetPath string, result *RestoreResult) error { - // Remove existing file if it exists - _ = v.Fs.Remove(targetPath) +// restoreFile dispatches to the right per-kind restorer. +func (s *restoreSession) restoreFile(file *database.File) error { + targetPath := filepath.Join(s.opts.TargetDir, file.Path.String()) + parentDir := filepath.Dir(targetPath) + if err := s.v.Fs.MkdirAll(parentDir, 0755); err != nil { + return fmt.Errorf("creating parent directory: %w", err) + } + if file.IsSymlink() { + return s.restoreSymlink(file, targetPath) + } + if file.Mode&uint32(os.ModeDir) != 0 { + return s.restoreDirectory(file, targetPath) + } + return s.restoreRegularFile(file, targetPath) +} - // Create symlink - // Note: afero.MemMapFs doesn't support symlinks, so we use os for real filesystems - if osFs, ok := v.Fs.(*afero.OsFs); ok { - _ = osFs // silence unused variable warning +// restoreSymlink restores a symbolic link. +func (s *restoreSession) restoreSymlink(file *database.File, targetPath string) error { + _ = s.v.Fs.Remove(targetPath) + // afero.MemMapFs doesn't support symlinks, so route real-FS + // symlinks through os. + if _, ok := s.v.Fs.(*afero.OsFs); ok { if err := os.Symlink(file.LinkTarget.String(), targetPath); err != nil { return fmt.Errorf("creating symlink: %w", err) } } else { log.Debug("Symlink creation not supported on this filesystem", "path", file.Path, "target", file.LinkTarget) } - - result.FilesRestored++ + s.result.FilesRestored++ log.Debug("Restored symlink", "path", file.Path, "target", file.LinkTarget) return nil } -// restoreDirectory restores a directory with proper permissions -func (v *Vaultik) restoreDirectory(file *database.File, targetPath string, result *RestoreResult) error { - // Create directory - if err := v.Fs.MkdirAll(targetPath, os.FileMode(file.Mode)); err != nil { +// restoreDirectory restores a directory with its permissions, mtime, +// and (on real filesystems, with sufficient privileges) ownership. +func (s *restoreSession) restoreDirectory(file *database.File, targetPath string) error { + if err := s.v.Fs.MkdirAll(targetPath, os.FileMode(file.Mode)); err != nil { return fmt.Errorf("creating directory: %w", err) } - - // Set permissions - if err := v.Fs.Chmod(targetPath, os.FileMode(file.Mode)); err != nil { + if err := s.v.Fs.Chmod(targetPath, os.FileMode(file.Mode)); err != nil { log.Debug("Failed to set directory permissions", "path", targetPath, "error", err) } - - // Set ownership (requires root) - if osFs, ok := v.Fs.(*afero.OsFs); ok { - _ = osFs + if _, ok := s.v.Fs.(*afero.OsFs); ok { if err := os.Chown(targetPath, int(file.UID), int(file.GID)); err != nil { log.Debug("Failed to set directory ownership", "path", targetPath, "error", err) } } - - // Set mtime - if err := v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil { + if err := s.v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil { log.Debug("Failed to set directory mtime", "path", targetPath, "error", err) } - - result.FilesRestored++ + s.result.FilesRestored++ return nil } -// restoreRegularFile restores a regular file by reconstructing it from chunks -func (v *Vaultik) restoreRegularFile( - ctx context.Context, - repos *database.Repositories, - file *database.File, - targetPath string, - identity age.Identity, - chunkToBlobMap map[string]*database.BlobChunk, - blobCache *blobDiskCache, - sweeper *restoreSweeper, - result *RestoreResult, -) error { +// restoreRegularFile reconstructs a regular file by reading chunks +// directly out of cached blobs via ReadAt. The expectation when this +// method runs is that every blob this file needs is already in the +// disk cache — the planner guarantees that by only marking files +// "ready" once their full blob set is on disk. +func (s *restoreSession) restoreRegularFile(file *database.File, targetPath string) error { fileStart := time.Now() - // Get file chunks in order t0 := time.Now() - fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID) + fileChunks, err := s.repos.FileChunks.GetByFileID(s.ctx, file.ID) fileChunksQueryDur := time.Since(t0) if err != nil { return fmt.Errorf("getting file chunks: %w", err) } - // Create output file t0 = time.Now() - outFile, err := v.Fs.Create(targetPath) + outFile, err := s.v.Fs.Create(targetPath) createDur := time.Since(t0) if err != nil { return fmt.Errorf("creating output file: %w", err) } defer func() { _ = outFile.Close() }() - // Per-file timing buckets so --debug shows exactly where seconds go. var ( - blobDBLookupDur time.Duration - cacheGetDur time.Duration - downloadDur time.Duration - cachePutDur time.Duration - writeDur time.Duration - sweeperDur time.Duration - downloadCount int - cacheHitCount int - bytesWritten int64 + readAtDur time.Duration + writeDur time.Duration + sweeperDur time.Duration + bytesWritten int64 ) for _, fc := range fileChunks { - // Find which blob contains this chunk chunkHashStr := fc.ChunkHash.String() - blobChunk, ok := chunkToBlobMap[chunkHashStr] + blobChunk, ok := s.chunkToBlobMap[chunkHashStr] if !ok { return fmt.Errorf("chunk %s not found in any blob", chunkHashStr[:16]) } - - // Get the blob's hash from the database (runs per chunk). - t0 = time.Now() - blob, err := repos.Blobs.GetByID(ctx, blobChunk.BlobID.String()) - blobDBLookupDur += time.Since(t0) - if err != nil { - return fmt.Errorf("getting blob %s: %w", blobChunk.BlobID, err) - } - - // Download and decrypt blob if not cached - blobHashStr := blob.Hash.String() - t0 = time.Now() - blobData, ok := blobCache.Get(blobHashStr) - cacheGetDur += time.Since(t0) + blobHash, ok := s.blobIDToHash[blobChunk.BlobID.String()] if !ok { - t0 = time.Now() - blobData, err = v.downloadBlob(ctx, blobHashStr, blob.CompressedSize, identity) - downloadDur += time.Since(t0) - if err != nil { - return fmt.Errorf("downloading blob %s: %w", blobHashStr[:16], err) - } - t0 = time.Now() - if putErr := blobCache.Put(blobHashStr, blobData); putErr != nil { - log.Debug("Failed to cache blob on disk", "hash", blobHashStr[:16], "error", putErr) - } - cachePutDur += time.Since(t0) - downloadCount++ - result.BlobsDownloaded++ - result.BytesDownloaded += blob.CompressedSize - } else { - cacheHitCount++ + return fmt.Errorf("blob id %s missing from hash index", blobChunk.BlobID) } - // Extract chunk from blob - if blobChunk.Offset+blobChunk.Length > int64(len(blobData)) { - return fmt.Errorf("chunk %s extends beyond blob data (offset=%d, length=%d, blob_size=%d)", - fc.ChunkHash[:16], blobChunk.Offset, blobChunk.Length, len(blobData)) + t0 = time.Now() + chunkData, err := s.blobCache.ReadAt(blobHash, blobChunk.Offset, blobChunk.Length) + readAtDur += time.Since(t0) + if err != nil { + return fmt.Errorf("reading chunk %s from cached blob %s: %w", fc.ChunkHash[:16], blobHash[:16], err) } - chunkData := blobData[blobChunk.Offset : blobChunk.Offset+blobChunk.Length] - // Write chunk to output file t0 = time.Now() n, err := outFile.Write(chunkData) writeDur += time.Since(t0) @@ -617,11 +645,8 @@ func (v *Vaultik) restoreRegularFile( } bytesWritten += int64(n) - // Tell the sweeper about the bytes we just restored so it can - // run an eviction sweep once the accumulated total crosses its - // threshold (config.BlobSizeLimit/100). t0 = time.Now() - sweeper.chunkRestored(int64(n)) + s.sweeper.chunkRestored(int64(n)) sweeperDur += time.Since(t0) } @@ -629,89 +654,72 @@ func (v *Vaultik) restoreRegularFile( "path", file.Path, "chunks", len(fileChunks), "bytes_written", bytesWritten, - "downloads", downloadCount, - "cache_hits", cacheHitCount, "ms_total", time.Since(fileStart).Milliseconds(), "ms_file_chunks_query", fileChunksQueryDur.Milliseconds(), "ms_create", createDur.Milliseconds(), - "ms_blob_db_lookups", blobDBLookupDur.Milliseconds(), - "ms_cache_gets", cacheGetDur.Milliseconds(), - "ms_cache_puts", cachePutDur.Milliseconds(), - "ms_downloads", downloadDur.Milliseconds(), + "ms_readat", readAtDur.Milliseconds(), "ms_writes", writeDur.Milliseconds(), "ms_sweeper", sweeperDur.Milliseconds(), ) - // Close file before setting metadata if err := outFile.Close(); err != nil { return fmt.Errorf("closing output file: %w", err) } - - // Set permissions - if err := v.Fs.Chmod(targetPath, os.FileMode(file.Mode)); err != nil { + if err := s.v.Fs.Chmod(targetPath, os.FileMode(file.Mode)); err != nil { log.Debug("Failed to set file permissions", "path", targetPath, "error", err) } - - // Set ownership (requires root) - if osFs, ok := v.Fs.(*afero.OsFs); ok { - _ = osFs + if _, ok := s.v.Fs.(*afero.OsFs); ok { if err := os.Chown(targetPath, int(file.UID), int(file.GID)); err != nil { log.Debug("Failed to set file ownership", "path", targetPath, "error", err) } } - - // Set mtime - if err := v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil { + if err := s.v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil { log.Debug("Failed to set file mtime", "path", targetPath, "error", err) } - result.FilesRestored++ - result.BytesRestored += bytesWritten + s.result.FilesRestored++ + s.result.BytesRestored += bytesWritten log.Debug("Restored file", "path", file.Path, "size", humanize.Bytes(uint64(bytesWritten))) return nil } -// downloadBlob downloads and decrypts a blob, returning the plaintext. -// Emits a debug log line splitting time spent in the network fetch (Get -// + Stat round-trips) from the streaming decrypt/decompress/read phase -// so --debug shows which side of the wire is the bottleneck. -func (v *Vaultik) downloadBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) ([]byte, error) { +// downloadBlobToCache streams a blob from remote storage straight into +// the disk cache, decrypting and decompressing on the fly. The +// plaintext never lives fully in memory — io.Copy through +// blobDiskCache.PutFromReader uses a 32 KiB buffer regardless of blob +// size, which is what makes multi-GB blobs tractable on machines with +// less RAM than the blob. +func (s *restoreSession) downloadBlobToCache(blobHash string, expectedSize int64) error { start := time.Now() t0 := time.Now() - rc, err := v.FetchAndDecryptBlob(ctx, blobHash, expectedSize, identity) + rc, err := s.v.FetchAndDecryptBlob(s.ctx, blobHash, expectedSize, s.identity) fetchSetupDur := time.Since(t0) if err != nil { - return nil, err + return err } t0 = time.Now() - data, err := io.ReadAll(rc) - readAllDur := time.Since(t0) - if err != nil { - _ = rc.Close() - return nil, fmt.Errorf("reading blob data: %w", err) + written, copyErr := s.blobCache.PutFromReader(blobHash, rc) + streamDur := time.Since(t0) + closeErr := rc.Close() + if copyErr != nil { + return copyErr + } + if closeErr != nil { + return closeErr } - // Close triggers hash verification - t0 = time.Now() - if err := rc.Close(); err != nil { - return nil, err - } - closeDur := time.Since(t0) - - log.Debug("Downloaded and decrypted blob (timings)", + log.Debug("Streamed blob into disk cache", "hash", blobHash[:16], "compressed_bytes", expectedSize, - "plaintext_bytes", len(data), + "plaintext_bytes", written, "ms_total", time.Since(start).Milliseconds(), "ms_fetch_setup", fetchSetupDur.Milliseconds(), - "ms_read_decrypt_decompress", readAllDur.Milliseconds(), - "ms_close_verify", closeDur.Milliseconds(), + "ms_stream_decrypt_decompress", streamDur.Milliseconds(), ) - - return data, nil + return nil } // verifyRestoredFiles verifies that all restored files match their expected chunk hashes diff --git a/internal/vaultik/restore_plan.go b/internal/vaultik/restore_plan.go new file mode 100644 index 0000000..5b53680 --- /dev/null +++ b/internal/vaultik/restore_plan.go @@ -0,0 +1,185 @@ +package vaultik + +import ( + "context" + "fmt" + "math" + "os" + + "sneak.berlin/go/vaultik/internal/database" + "sneak.berlin/go/vaultik/internal/types" +) + +// restorePlan orders restore-time file processing by blob locality. The +// goal is to keep the blob disk cache occupancy as small as possible: +// download one blob, drain every file referencing only that blob, let +// the sweeper free the blob, then move on. Files that span multiple +// blobs are processed when their full blob set is on disk. +// +// The plan keeps two indexes: +// +// - fileBlobs: for each pending file, the set of blob hashes it +// still needs that are NOT yet in the cache. Files with an empty +// set are "ready" — they can be restored from the current cache +// with no further downloads. +// - blobFiles: for each blob, the set of pending files referencing +// it. Used to short-circuit "when this blob lands, which files +// become ready" without a global scan. +type restorePlan struct { + fileBlobs map[types.FileID]map[string]struct{} + blobFiles map[string]map[types.FileID]struct{} + ready []types.FileID + cached map[string]struct{} +} + +// newRestorePlan builds the file→blob index for the given files. Files +// whose chunks reference no blobs (symlinks, directories) start in the +// ready queue immediately. +func newRestorePlan( + ctx context.Context, + repos *database.Repositories, + files []*database.File, + chunkToBlobMap map[string]*database.BlobChunk, + blobIDToHash map[string]string, +) (*restorePlan, error) { + p := &restorePlan{ + fileBlobs: make(map[types.FileID]map[string]struct{}, len(files)), + blobFiles: make(map[string]map[types.FileID]struct{}), + ready: make([]types.FileID, 0, len(files)), + cached: make(map[string]struct{}), + } + for _, f := range files { + if f.IsSymlink() || f.Mode&uint32(os.ModeDir) != 0 { + // No chunks to fetch — restore can run immediately. + p.fileBlobs[f.ID] = nil + p.ready = append(p.ready, f.ID) + continue + } + fileChunks, err := repos.FileChunks.GetByFileID(ctx, f.ID) + if err != nil { + return nil, fmt.Errorf("planning %s: %w", f.Path, err) + } + blobs := make(map[string]struct{}) + for _, fc := range fileChunks { + bc, ok := chunkToBlobMap[fc.ChunkHash.String()] + if !ok { + return nil, fmt.Errorf("planning %s: chunk %s missing from blob map", + f.Path, fc.ChunkHash.String()[:16]) + } + hash, ok := blobIDToHash[bc.BlobID.String()] + if !ok { + return nil, fmt.Errorf("planning %s: blob id %s missing from id-to-hash map", + f.Path, bc.BlobID) + } + blobs[hash] = struct{}{} + } + p.fileBlobs[f.ID] = blobs + for hash := range blobs { + set, ok := p.blobFiles[hash] + if !ok { + set = make(map[types.FileID]struct{}) + p.blobFiles[hash] = set + } + set[f.ID] = struct{}{} + } + if len(blobs) == 0 { + p.ready = append(p.ready, f.ID) + } + } + return p, nil +} + +// markBlobCached records that the named blob is now resident in the +// disk cache and moves any pending file whose remaining-uncached-blobs +// set just dropped to empty onto the ready queue. +func (p *restorePlan) markBlobCached(blobHash string) { + if _, already := p.cached[blobHash]; already { + return + } + p.cached[blobHash] = struct{}{} + for fileID := range p.blobFiles[blobHash] { + blobs := p.fileBlobs[fileID] + delete(blobs, blobHash) + if len(blobs) == 0 { + p.ready = append(p.ready, fileID) + } + } +} + +// popReady returns the next ready file, removing it from the queue. If +// no file is ready, the second return value is false. +func (p *restorePlan) popReady() (types.FileID, bool) { + if len(p.ready) == 0 { + return types.FileID{}, false + } + id := p.ready[0] + p.ready = p.ready[1:] + return id, true +} + +// finishFile drops a restored file from both indexes so subsequent +// planning calls don't reconsider it. +func (p *restorePlan) finishFile(fileID types.FileID) { + for hash := range p.fileBlobs[fileID] { + if set, ok := p.blobFiles[hash]; ok { + delete(set, fileID) + if len(set) == 0 { + delete(p.blobFiles, hash) + } + } + } + delete(p.fileBlobs, fileID) + // Also scrub the file from any blobFiles entries where it might + // still appear even after its uncached-blob set was emptied. + for hash, set := range p.blobFiles { + if _, ok := set[fileID]; ok { + delete(set, fileID) + if len(set) == 0 { + delete(p.blobFiles, hash) + } + } + } +} + +// pickNextDownload returns the pending file whose remaining-uncached +// blob set is smallest (with ties broken by FileID string compare so +// the choice is deterministic across runs). This file's blobs are +// downloaded next, after which it — together with any other pending +// files whose blob sets become empty — moves to the ready queue. +// +// The zero FileID return means nothing is pending. +func (p *restorePlan) pickNextDownload() types.FileID { + var best types.FileID + bestCount := math.MaxInt + var bestID string + for id, blobs := range p.fileBlobs { + n := len(blobs) + if n == 0 { + // Already-ready files should have been popped via + // popReady; ignore here just in case. + continue + } + idStr := id.String() + if n < bestCount || (n == bestCount && (best.IsZero() || idStr < bestID)) { + best = id + bestCount = n + bestID = idStr + } + } + return best +} + +// blobsNeeded returns the uncached blob hashes for fileID in any order. +func (p *restorePlan) blobsNeeded(fileID types.FileID) []string { + blobs := p.fileBlobs[fileID] + out := make([]string, 0, len(blobs)) + for h := range blobs { + out = append(out, h) + } + return out +} + +// hasPending reports whether any unfinished files remain. +func (p *restorePlan) hasPending() bool { + return len(p.fileBlobs) > 0 +}