Files
vaultik/internal/vaultik/restore.go
clawbot d77ac18aaa fix: add missing printfStdout, printlnStdout, scanlnStdin, FetchBlob, and FetchAndDecryptBlob methods
These methods were referenced in main but never defined, causing compilation
failures. They were introduced by merges that assumed dependent PRs were
already merged.
2026-02-19 23:51:53 -08:00

680 lines
20 KiB
Go

package vaultik
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"time"
"filippo.io/age"
"git.eeqj.de/sneak/vaultik/internal/blobgen"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/dustin/go-humanize"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
"golang.org/x/term"
)
// RestoreOptions contains options for the restore operation performed by
// (*Vaultik).Restore.
type RestoreOptions struct {
	// SnapshotID selects the snapshot to restore; its metadata database is
	// downloaded from "metadata/<SnapshotID>/db.zst.age".
	SnapshotID string
	// TargetDir is the directory under which each file's original path is
	// recreated. It is created (mode 0755) if it does not exist.
	TargetDir string
	// Each entry is passed through filepath.Clean before it is used as a filter.
	Paths  []string // Optional paths to restore (empty = all)
	Verify bool     // Verify restored files by checking chunk hashes
}
// RestoreResult contains statistics from a restore operation.
type RestoreResult struct {
	// FilesRestored counts every restored entry: regular files, directories
	// and symlinks all increment it.
	FilesRestored int
	// BytesRestored is the total number of bytes written into regular files.
	BytesRestored int64
	// BlobsDownloaded counts blobs fetched from storage; a blob already held
	// in the in-memory blob cache is not counted again.
	BlobsDownloaded int
	// BytesDownloaded sums the database-recorded CompressedSize of every
	// downloaded blob.
	BytesDownloaded int64
	// Duration is measured from the start of Restore until the files have
	// been written; it does not include verification.
	Duration time.Duration
	// Verification results (only populated if Verify option is set)
	FilesVerified int
	BytesVerified int64
	FilesFailed   int
	FailedFiles   []string // Paths of files that failed verification
}
// Restore restores files from a snapshot to the target directory.
//
// It downloads and decrypts the snapshot's metadata database, selects the
// files matching opts.Paths (every file when empty) and recreates each entry
// under opts.TargetDir. A failure on one entry is logged and the restore
// continues with the rest. Once every entry has been attempted, the failures
// are listed on v.Stdout and Restore returns an error. When opts.Verify is
// set, the restored regular files are re-hashed afterwards.
//
// Restore returns v.ctx.Err() as soon as the context is cancelled.
func (v *Vaultik) Restore(opts *RestoreOptions) error {
	startTime := time.Now()

	// Check for age_secret_key
	if v.Config.AgeSecretKey == "" {
		return fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'")
	}

	// Parse the age identity
	identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey)
	if err != nil {
		return fmt.Errorf("parsing age secret key: %w", err)
	}

	log.Info("Starting restore operation",
		"snapshot_id", opts.SnapshotID,
		"target_dir", opts.TargetDir,
		"paths", opts.Paths,
	)

	// Step 1: Download and decrypt the snapshot metadata database
	log.Info("Downloading snapshot metadata...")
	tempDB, err := v.downloadSnapshotDB(opts.SnapshotID, identity)
	if err != nil {
		return fmt.Errorf("downloading snapshot database: %w", err)
	}
	defer func() {
		if err := tempDB.Close(); err != nil {
			log.Debug("Failed to close temp database", "error", err)
		}
		// Clean up temp file
		if err := v.Fs.Remove(tempDB.Path()); err != nil {
			log.Debug("Failed to remove temp database", "error", err)
		}
	}()

	repos := database.NewRepositories(tempDB)

	// Step 2: Get list of files to restore
	files, err := v.getFilesToRestore(v.ctx, repos, opts.Paths)
	if err != nil {
		return fmt.Errorf("getting files to restore: %w", err)
	}
	if len(files) == 0 {
		log.Warn("No files found to restore")
		return nil
	}
	log.Info("Found files to restore", "count", len(files))

	// Step 3: Create target directory
	if err := v.Fs.MkdirAll(opts.TargetDir, 0755); err != nil {
		return fmt.Errorf("creating target directory: %w", err)
	}

	// Step 4: Build a map of chunks to blobs for efficient restoration
	chunkToBlobMap, err := v.buildChunkToBlobMap(v.ctx, repos)
	if err != nil {
		return fmt.Errorf("building chunk-to-blob map: %w", err)
	}

	// Step 5: Restore files
	result := &RestoreResult{}
	blobCache := make(map[string][]byte) // Cache downloaded and decrypted blobs
	var (
		failedRestores []string         // paths whose restore returned an error
		restoredDirs   []*database.File // directories whose mtime is re-applied below
	)
	for i, file := range files {
		if v.ctx.Err() != nil {
			return v.ctx.Err()
		}

		if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil {
			log.Error("Failed to restore file", "path", file.Path, "error", err)
			// Record the failure and continue with other files
			failedRestores = append(failedRestores, file.Path.String())
		} else if !file.IsSymlink() && file.Mode&uint32(os.ModeDir) != 0 {
			restoredDirs = append(restoredDirs, file)
		}

		// Progress logging. This also runs for failed entries, so the final
		// progress line is emitted even when the last file fails.
		if (i+1)%100 == 0 || i+1 == len(files) {
			log.Info("Restore progress",
				"files", fmt.Sprintf("%d/%d", i+1, len(files)),
				"bytes", humanize.Bytes(uint64(result.BytesRestored)),
			)
		}
	}

	// Creating entries inside a directory updates that directory's mtime.
	// The mtime set during the loop above may therefore have been clobbered
	// by entries written beneath it afterwards. Re-apply it now that every
	// entry exists.
	for _, dir := range restoredDirs {
		dirPath := filepath.Join(opts.TargetDir, dir.Path.String())
		if err := v.Fs.Chtimes(dirPath, dir.MTime, dir.MTime); err != nil {
			log.Debug("Failed to set directory mtime", "path", dirPath, "error", err)
		}
	}

	result.Duration = time.Since(startTime)

	log.Info("Restore complete",
		"files_restored", result.FilesRestored,
		"files_failed", len(failedRestores),
		"bytes_restored", humanize.Bytes(uint64(result.BytesRestored)),
		"blobs_downloaded", result.BlobsDownloaded,
		"bytes_downloaded", humanize.Bytes(uint64(result.BytesDownloaded)),
		"duration", result.Duration,
	)

	_, _ = fmt.Fprintf(v.Stdout, "Restored %d files (%s) in %s\n",
		result.FilesRestored,
		humanize.Bytes(uint64(result.BytesRestored)),
		result.Duration.Round(time.Second),
	)
	if len(failedRestores) > 0 {
		_, _ = fmt.Fprintf(v.Stdout, "\nFailed to restore %d files:\n", len(failedRestores))
		for _, path := range failedRestores {
			_, _ = fmt.Fprintf(v.Stdout, " - %s\n", path)
		}
	}

	// Run verification if requested
	if opts.Verify {
		if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil {
			return fmt.Errorf("verification failed: %w", err)
		}

		if result.FilesFailed > 0 {
			_, _ = fmt.Fprintf(v.Stdout, "\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed)
			for _, path := range result.FailedFiles {
				_, _ = fmt.Fprintf(v.Stdout, " - %s\n", path)
			}
			return fmt.Errorf("%d files failed verification", result.FilesFailed)
		}

		_, _ = fmt.Fprintf(v.Stdout, "Verified %d files (%s)\n",
			result.FilesVerified,
			humanize.Bytes(uint64(result.BytesVerified)),
		)
	}

	// A restore that skipped files must not look successful to the caller.
	if len(failedRestores) > 0 {
		return fmt.Errorf("%d files failed to restore", len(failedRestores))
	}

	return nil
}
// downloadSnapshotDB downloads and decrypts the snapshot metadata database.
//
// The encrypted object "metadata/<snapshotID>/db.zst.age" is read from
// storage and decrypted and decompressed with blobgen. The resulting SQLite
// image is streamed into a temporary file on v.Fs, which is then opened as a
// database.
//
// The caller owns the returned DB and is responsible for closing it and
// removing the file at db.Path(). On any error the temporary file is removed
// before returning.
func (v *Vaultik) downloadSnapshotDB(snapshotID string, identity age.Identity) (*database.DB, error) {
	// Download encrypted database from storage
	dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
	reader, err := v.Storage.Get(v.ctx, dbKey)
	if err != nil {
		return nil, fmt.Errorf("downloading %s: %w", dbKey, err)
	}
	defer func() { _ = reader.Close() }()

	// Read all encrypted data (kept so its size can be logged)
	encryptedData, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("reading encrypted data: %w", err)
	}
	log.Debug("Downloaded encrypted database", "size", humanize.Bytes(uint64(len(encryptedData))))

	// Decrypt and decompress using blobgen.Reader
	blobReader, err := blobgen.NewReader(bytes.NewReader(encryptedData), identity)
	if err != nil {
		return nil, fmt.Errorf("creating decryption reader: %w", err)
	}
	defer func() { _ = blobReader.Close() }()

	tempFile, err := afero.TempFile(v.Fs, "", "vaultik-restore-*.db")
	if err != nil {
		return nil, fmt.Errorf("creating temp file: %w", err)
	}
	tempPath := tempFile.Name()

	// Stream the decrypted SQLite image straight to disk. This avoids holding
	// a second full copy of the database in memory.
	dbSize, err := io.Copy(tempFile, blobReader)
	if err != nil {
		_ = tempFile.Close()
		_ = v.Fs.Remove(tempPath)
		return nil, fmt.Errorf("decrypting and decompressing to temp file: %w", err)
	}
	log.Debug("Decrypted database", "size", humanize.Bytes(uint64(dbSize)))

	if err := tempFile.Close(); err != nil {
		_ = v.Fs.Remove(tempPath)
		return nil, fmt.Errorf("closing temp file: %w", err)
	}
	log.Debug("Created restore database", "path", tempPath)

	// Open the database. If that fails, nobody else holds tempPath, so
	// remove it here rather than leaking it.
	db, err := database.New(v.ctx, tempPath)
	if err != nil {
		_ = v.Fs.Remove(tempPath)
		return nil, fmt.Errorf("opening restore database: %w", err)
	}

	return db, nil
}
// getFilesToRestore returns the list of files to restore based on path filters.
//
// With no filters every file in the snapshot is returned. Otherwise each
// filter is cleaned with filepath.Clean. A file is selected when its path
// equals the filter or lies beneath it on a path-component boundary, so
// "/home/a" selects "/home/a/x" but not "/home/ab". A file matched by several
// filters is returned once, in first-seen order.
func (v *Vaultik) getFilesToRestore(ctx context.Context, repos *database.Repositories, pathFilters []string) ([]*database.File, error) {
	// If no filters, get all files
	if len(pathFilters) == 0 {
		return repos.Files.ListAll(ctx)
	}

	// withinFilter reports whether p is root itself or a descendant of it.
	// NOTE(review): ListByPrefix presumably does a plain string-prefix match,
	// which would also return "/home/ab" for "/home/a". Re-checking here keeps
	// the selection correct either way.
	withinFilter := func(p, root string) bool {
		if p == root {
			return true
		}
		if len(p) <= len(root) || p[:len(root)] != root {
			return false
		}
		// A root such as "/" already ends in a separator. Otherwise the next
		// byte of p must start a new path component.
		return os.IsPathSeparator(root[len(root)-1]) || os.IsPathSeparator(p[len(root)])
	}

	var result []*database.File
	seen := make(map[string]bool)

	for _, filter := range pathFilters {
		// Normalize the filter path (never empty: Clean("") == ".")
		filter = filepath.Clean(filter)

		// Get candidate files with this prefix
		files, err := repos.Files.ListByPrefix(ctx, filter)
		if err != nil {
			return nil, fmt.Errorf("listing files with prefix %s: %w", filter, err)
		}

		for _, file := range files {
			id := file.ID.String()
			if seen[id] || !withinFilter(file.Path.String(), filter) {
				continue
			}
			seen[id] = true
			result = append(result, file)
		}
	}

	return result, nil
}
// buildChunkToBlobMap reads every row of the blob_chunks table and indexes it
// by chunk hash. This lets the restore find, for each chunk, the blob that
// holds it and the byte range inside that blob.
//
// If a chunk hash appears in more than one row, the row scanned last wins.
// The map built so far is returned together with any iteration error
// reported by rows.Err.
func (v *Vaultik) buildChunkToBlobMap(ctx context.Context, repos *database.Repositories) (map[string]*database.BlobChunk, error) {
	const query = `SELECT blob_id, chunk_hash, offset, length FROM blob_chunks`

	rows, err := repos.DB().Conn().QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("querying blob_chunks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	byChunk := make(map[string]*database.BlobChunk)
	for rows.Next() {
		var (
			rawBlobID    string
			rawChunkHash string
			entry        database.BlobChunk // fresh per row: its address is stored below
		)

		if scanErr := rows.Scan(&rawBlobID, &rawChunkHash, &entry.Offset, &entry.Length); scanErr != nil {
			return nil, fmt.Errorf("scanning blob_chunk: %w", scanErr)
		}

		parsedID, parseErr := types.ParseBlobID(rawBlobID)
		if parseErr != nil {
			return nil, fmt.Errorf("parsing blob ID: %w", parseErr)
		}

		entry.BlobID = parsedID
		entry.ChunkHash = types.ChunkHash(rawChunkHash)
		byChunk[rawChunkHash] = &entry
	}

	return byChunk, rows.Err()
}
// restoreFile restores a single snapshot entry beneath targetDir.
//
// The entry is recreated at targetDir joined with its original path. Parent
// directories are created first (mode 0755), then the entry is dispatched by
// kind: symlinks, directories and regular files each have their own helper.
// Entries whose path would resolve outside targetDir are rejected.
func (v *Vaultik) restoreFile(
	ctx context.Context,
	repos *database.Repositories,
	file *database.File,
	targetDir string,
	identity age.Identity,
	chunkToBlobMap map[string]*database.BlobChunk,
	blobCache map[string][]byte,
	result *RestoreResult,
) error {
	// Calculate target path - use full original path under target directory
	targetPath := filepath.Join(targetDir, file.Path.String())

	// filepath.Join resolves ".." lexically, so an entry such as
	// "/../../etc/cron.d/x" would land outside targetDir. The path comes from
	// the downloaded metadata database, so refuse anything that escapes.
	rel, err := filepath.Rel(targetDir, targetPath)
	if err != nil || rel == ".." || (len(rel) > 2 && rel[:2] == ".." && os.IsPathSeparator(rel[2])) {
		return fmt.Errorf("refusing to restore %s: path escapes target directory", file.Path)
	}

	// Create parent directories
	parentDir := filepath.Dir(targetPath)
	if err := v.Fs.MkdirAll(parentDir, 0755); err != nil {
		return fmt.Errorf("creating parent directory: %w", err)
	}

	// Handle symlinks
	if file.IsSymlink() {
		return v.restoreSymlink(file, targetPath, result)
	}

	// Handle directories
	if file.Mode&uint32(os.ModeDir) != 0 {
		return v.restoreDirectory(file, targetPath, result)
	}

	// Handle regular files
	return v.restoreRegularFile(ctx, repos, file, targetPath, identity, chunkToBlobMap, blobCache, result)
}
// restoreSymlink restores a symbolic link at targetPath pointing to
// file.LinkTarget.
//
// Any existing entry at targetPath is removed first (best effort). The link
// is only created when v.Fs is the real OS filesystem. On other filesystems
// creation is skipped with a debug log, but the entry is still counted in
// result.FilesRestored.
func (v *Vaultik) restoreSymlink(file *database.File, targetPath string, result *RestoreResult) error {
	// Remove existing file if it exists, so os.Symlink does not fail on it
	_ = v.Fs.Remove(targetPath)

	// Note: afero.MemMapFs doesn't support symlinks, so we use os for real filesystems
	if _, ok := v.Fs.(*afero.OsFs); ok {
		if err := os.Symlink(file.LinkTarget.String(), targetPath); err != nil {
			return fmt.Errorf("creating symlink: %w", err)
		}

		// Set ownership of the link itself (requires root), as is done for
		// files and directories. Lchown does not follow the link, so the
		// target is left untouched.
		if err := os.Lchown(targetPath, int(file.UID), int(file.GID)); err != nil {
			log.Debug("Failed to set symlink ownership", "path", targetPath, "error", err)
		}
	} else {
		log.Debug("Symlink creation not supported on this filesystem", "path", file.Path, "target", file.LinkTarget)
	}

	result.FilesRestored++
	log.Debug("Restored symlink", "path", file.Path, "target", file.LinkTarget)
	return nil
}
// restoreDirectory restores a directory with proper permissions.
//
// The directory is created at targetPath. Then ownership (only on the real
// OS filesystem, and normally only effective as root), permissions and mtime
// are applied. Metadata failures are logged at debug level and do not fail
// the restore.
func (v *Vaultik) restoreDirectory(file *database.File, targetPath string, result *RestoreResult) error {
	mode := os.FileMode(file.Mode)

	// Create directory
	if err := v.Fs.MkdirAll(targetPath, mode); err != nil {
		return fmt.Errorf("creating directory: %w", err)
	}

	// Set ownership (requires root). This is done before chmod, matching
	// regular files, so that an ownership change cannot strip mode bits that
	// were already applied.
	if _, ok := v.Fs.(*afero.OsFs); ok {
		if err := os.Chown(targetPath, int(file.UID), int(file.GID)); err != nil {
			log.Debug("Failed to set directory ownership", "path", targetPath, "error", err)
		}
	}

	// Set permissions explicitly. MkdirAll's mode is subject to the umask
	// and is not applied to a directory that already exists.
	if err := v.Fs.Chmod(targetPath, mode); err != nil {
		log.Debug("Failed to set directory permissions", "path", targetPath, "error", err)
	}

	// Set mtime
	if err := v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil {
		log.Debug("Failed to set directory mtime", "path", targetPath, "error", err)
	}

	result.FilesRestored++
	return nil
}
// restoreRegularFile restores a regular file by reconstructing it from chunks.
//
// The file's chunks are processed in the order returned by
// FileChunks.GetByFileID. For each chunk, the blob that holds it is
// downloaded and decrypted on first use and kept in blobCache, keyed by blob
// hash. The chunk's byte range is then copied from that blob into targetPath.
//
// Ownership, mode and mtime are applied best-effort afterwards; failures are
// only logged. On success, result.FilesRestored and result.BytesRestored are
// updated. Each newly downloaded blob also increments result.BlobsDownloaded
// and adds to result.BytesDownloaded.
func (v *Vaultik) restoreRegularFile(
	ctx context.Context,
	repos *database.Repositories,
	file *database.File,
	targetPath string,
	identity age.Identity,
	chunkToBlobMap map[string]*database.BlobChunk,
	blobCache map[string][]byte,
	result *RestoreResult,
) error {
	// Get file chunks in order
	fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		return fmt.Errorf("getting file chunks: %w", err)
	}

	// Create output file
	outFile, err := v.Fs.Create(targetPath)
	if err != nil {
		return fmt.Errorf("creating output file: %w", err)
	}
	// Safety net for the early returns below. The success path closes
	// explicitly so that a close error is reported.
	defer func() { _ = outFile.Close() }()

	// Write chunks in order
	var bytesWritten int64
	for _, fc := range fileChunks {
		// Find which blob contains this chunk
		chunkHashStr := fc.ChunkHash.String()
		blobChunk, ok := chunkToBlobMap[chunkHashStr]
		if !ok {
			return fmt.Errorf("chunk %s not found in any blob", chunkHashStr[:16])
		}

		// Get the blob's hash from the database
		blob, err := repos.Blobs.GetByID(ctx, blobChunk.BlobID.String())
		if err != nil {
			return fmt.Errorf("getting blob %s: %w", blobChunk.BlobID, err)
		}

		// Download and decrypt blob if not cached
		blobHashStr := blob.Hash.String()
		blobData, ok := blobCache[blobHashStr]
		if !ok {
			blobData, err = v.downloadBlob(ctx, blobHashStr, blob.CompressedSize, identity)
			if err != nil {
				return fmt.Errorf("downloading blob %s: %w", blobHashStr[:16], err)
			}
			blobCache[blobHashStr] = blobData
			result.BlobsDownloaded++
			result.BytesDownloaded += blob.CompressedSize
		}

		// Offset and length come from the snapshot database. Reject values
		// that would make the slice expression below panic: negative bounds,
		// or an end beyond the blob. The end check is written as
		// Length > size-Offset, so Offset+Length cannot overflow.
		blobSize := int64(len(blobData))
		if blobChunk.Offset < 0 || blobChunk.Length < 0 {
			return fmt.Errorf("chunk %s has invalid bounds (offset=%d, length=%d)",
				fc.ChunkHash[:16], blobChunk.Offset, blobChunk.Length)
		}
		if blobChunk.Length > blobSize-blobChunk.Offset {
			return fmt.Errorf("chunk %s extends beyond blob data (offset=%d, length=%d, blob_size=%d)",
				fc.ChunkHash[:16], blobChunk.Offset, blobChunk.Length, len(blobData))
		}
		chunkData := blobData[blobChunk.Offset : blobChunk.Offset+blobChunk.Length]

		// Write chunk to output file
		n, err := outFile.Write(chunkData)
		if err != nil {
			return fmt.Errorf("writing chunk: %w", err)
		}
		bytesWritten += int64(n)
	}

	// Close file before setting metadata
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("closing output file: %w", err)
	}

	// Set ownership (requires root) BEFORE permissions. chown(2) clears the
	// set-user-ID and set-group-ID bits of regular files (on Linux even when
	// the caller is root). Applying the mode last keeps those bits on
	// restored setuid/setgid executables.
	if _, ok := v.Fs.(*afero.OsFs); ok {
		if err := os.Chown(targetPath, int(file.UID), int(file.GID)); err != nil {
			log.Debug("Failed to set file ownership", "path", targetPath, "error", err)
		}
	}

	// Set permissions
	if err := v.Fs.Chmod(targetPath, os.FileMode(file.Mode)); err != nil {
		log.Debug("Failed to set file permissions", "path", targetPath, "error", err)
	}

	// Set mtime
	if err := v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil {
		log.Debug("Failed to set file mtime", "path", targetPath, "error", err)
	}

	result.FilesRestored++
	result.BytesRestored += bytesWritten

	log.Debug("Restored file", "path", file.Path, "size", humanize.Bytes(uint64(bytesWritten)))
	return nil
}
// BlobFetchResult holds the result of fetching and decrypting a blob.
type BlobFetchResult struct {
	// Data is the blob's decrypted and decompressed contents.
	Data []byte
	// CompressedSize is the number of bytes read from storage, i.e. the size
	// of the encrypted object as stored (FetchAndDecryptBlob sets it to the
	// length of the downloaded data).
	CompressedSize int64
}
// FetchAndDecryptBlob downloads a blob from storage, decrypts and decompresses it.
//
// The blob is read from "blobs/<h[0:2]>/<h[2:4]>/<h>", where h is blobHash,
// and decoded with identity. The returned CompressedSize is the number of
// bytes read from storage. blobHash must be at least 16 characters long;
// shorter values are rejected with an error.
//
// NOTE: expectedSize is not currently used by this function.
func (v *Vaultik) FetchAndDecryptBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) (*BlobFetchResult, error) {
	// The storage path shards on the first four characters, and the debug
	// log below prints a 16-character prefix. Reject anything shorter up
	// front instead of panicking on those slice expressions.
	if len(blobHash) < 16 {
		return nil, fmt.Errorf("invalid blob hash %q: too short", blobHash)
	}

	// Construct blob path with sharding
	blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)

	reader, err := v.Storage.Get(ctx, blobPath)
	if err != nil {
		return nil, fmt.Errorf("downloading blob: %w", err)
	}
	defer func() { _ = reader.Close() }()

	// Read encrypted data
	encryptedData, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("reading blob data: %w", err)
	}

	// Decrypt and decompress
	blobReader, err := blobgen.NewReader(bytes.NewReader(encryptedData), identity)
	if err != nil {
		return nil, fmt.Errorf("creating decryption reader: %w", err)
	}
	defer func() { _ = blobReader.Close() }()

	data, err := io.ReadAll(blobReader)
	if err != nil {
		return nil, fmt.Errorf("decrypting blob: %w", err)
	}

	log.Debug("Downloaded and decrypted blob",
		"hash", blobHash[:16],
		"encrypted_size", humanize.Bytes(uint64(len(encryptedData))),
		"decrypted_size", humanize.Bytes(uint64(len(data))),
	)

	return &BlobFetchResult{
		Data:           data,
		CompressedSize: int64(len(encryptedData)),
	}, nil
}
// downloadBlob fetches the blob identified by blobHash through
// FetchAndDecryptBlob and returns only its decrypted, decompressed bytes.
// Any error from the fetch is returned unchanged.
func (v *Vaultik) downloadBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) ([]byte, error) {
	fetched, fetchErr := v.FetchAndDecryptBlob(ctx, blobHash, expectedSize, identity)
	if fetchErr == nil {
		return fetched.Data, nil
	}
	return nil, fetchErr
}
// verifyRestoredFiles verifies that all restored files match their expected
// chunk hashes.
//
// Only regular files are checked; symlinks and directories are skipped.
// Each failure is logged and recorded in result.FilesFailed and
// result.FailedFiles. Successes are added to result.FilesVerified and
// result.BytesVerified.
//
// The returned error is non-nil only when ctx is cancelled. Individual
// verification failures are reported through result, not the error.
func (v *Vaultik) verifyRestoredFiles(
	ctx context.Context,
	repos *database.Repositories,
	files []*database.File,
	targetDir string,
	result *RestoreResult,
) error {
	// Calculate total bytes to verify for progress bar
	var totalBytes int64
	regularFiles := make([]*database.File, 0, len(files))
	for _, file := range files {
		// Skip symlinks and directories - only verify regular files
		if file.IsSymlink() || file.Mode&uint32(os.ModeDir) != 0 {
			continue
		}
		regularFiles = append(regularFiles, file)
		totalBytes += file.Size
	}

	if len(regularFiles) == 0 {
		log.Info("No regular files to verify")
		return nil
	}

	log.Info("Verifying restored files",
		"files", len(regularFiles),
		"bytes", humanize.Bytes(uint64(totalBytes)),
	)
	_, _ = fmt.Fprintf(v.Stdout, "\nVerifying %d files (%s)...\n",
		len(regularFiles),
		humanize.Bytes(uint64(totalBytes)),
	)

	// Draw a progress bar only when its destination, stderr, is a terminal.
	// Stdout may be redirected independently of stderr, so checking stdout
	// (as isTerminal does) could hide the bar on an interactive terminal or
	// write control sequences into a redirected log.
	var bar *progressbar.ProgressBar
	if term.IsTerminal(int(os.Stderr.Fd())) {
		bar = progressbar.NewOptions64(
			totalBytes,
			progressbar.OptionSetDescription("Verifying"),
			progressbar.OptionSetWriter(os.Stderr),
			progressbar.OptionShowBytes(true),
			progressbar.OptionShowCount(),
			progressbar.OptionSetWidth(40),
			progressbar.OptionThrottle(100*time.Millisecond),
			progressbar.OptionOnCompletion(func() {
				_, _ = fmt.Fprint(os.Stderr, "\n")
			}),
			progressbar.OptionSetRenderBlankState(true),
		)
	}

	// Verify each file
	for _, file := range regularFiles {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		targetPath := filepath.Join(targetDir, file.Path.String())
		bytesVerified, err := v.verifyFile(ctx, repos, file, targetPath)
		if err != nil {
			log.Error("File verification failed", "path", file.Path, "error", err)
			result.FilesFailed++
			result.FailedFiles = append(result.FailedFiles, file.Path.String())
		} else {
			result.FilesVerified++
			result.BytesVerified += bytesVerified
		}

		// Update progress bar by the recorded size, whether or not the file passed
		if bar != nil {
			_ = bar.Add64(file.Size)
		}
	}

	if bar != nil {
		_ = bar.Finish()
	}

	log.Info("Verification complete",
		"files_verified", result.FilesVerified,
		"bytes_verified", humanize.Bytes(uint64(result.BytesVerified)),
		"files_failed", result.FilesFailed,
	)

	return nil
}
// verifyFile verifies a single restored file by checking its chunk hashes.
//
// The file at targetPath is read sequentially. Each recorded chunk is read
// using the chunk size stored in the database, and its SHA-256 (hex) must
// equal the recorded chunk hash. After the last chunk the file must be at
// EOF: a file with extra trailing bytes fails verification.
//
// It returns the number of bytes verified before success or before the
// first failure.
func (v *Vaultik) verifyFile(
	ctx context.Context,
	repos *database.Repositories,
	file *database.File,
	targetPath string,
) (int64, error) {
	// Get file chunks in order
	fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		return 0, fmt.Errorf("getting file chunks: %w", err)
	}

	// Open the restored file
	f, err := v.Fs.Open(targetPath)
	if err != nil {
		return 0, fmt.Errorf("opening file: %w", err)
	}
	defer func() { _ = f.Close() }()

	var (
		bytesVerified int64
		buf           []byte // reused across chunks; grown only when needed
	)
	for _, fc := range fileChunks {
		expectedHash := fc.ChunkHash.String()

		// Get chunk size from database
		chunk, err := repos.Chunks.GetByHash(ctx, expectedHash)
		if err != nil {
			return bytesVerified, fmt.Errorf("getting chunk %s: %w", expectedHash[:16], err)
		}
		if chunk.Size < 0 {
			return bytesVerified, fmt.Errorf("chunk %s has invalid size %d", expectedHash[:16], chunk.Size)
		}

		if int64(cap(buf)) < chunk.Size {
			buf = make([]byte, chunk.Size)
		}
		chunkData := buf[:chunk.Size]

		// Read chunk data from file. io.ReadFull reports a short read as
		// io.ErrUnexpectedEOF, so no separate length check is needed.
		n, err := io.ReadFull(f, chunkData)
		if err != nil {
			return bytesVerified, fmt.Errorf("reading chunk data: %w", err)
		}

		// Calculate hash and compare
		hash := sha256.Sum256(chunkData)
		actualHash := hex.EncodeToString(hash[:])
		if actualHash != expectedHash {
			return bytesVerified, fmt.Errorf("chunk %d hash mismatch: expected %s, got %s",
				fc.Idx, expectedHash[:16], actualHash[:16])
		}

		bytesVerified += int64(n)
	}

	// Every recorded chunk matched. The file must also end here, otherwise
	// it contains bytes that the snapshot does not account for.
	var probe [1]byte
	_, err = io.ReadFull(f, probe[:])
	switch {
	case err == nil:
		return bytesVerified, fmt.Errorf("file has unexpected data beyond its %d recorded chunks", len(fileChunks))
	case err != io.EOF:
		return bytesVerified, fmt.Errorf("checking for trailing data: %w", err)
	}

	log.Debug("File verified", "path", file.Path, "bytes", bytesVerified, "chunks", len(fileChunks))
	return bytesVerified, nil
}
// isTerminal reports whether the process's standard output is attached to a
// terminal.
func isTerminal() bool {
	stdoutFD := int(os.Stdout.Fd())
	return term.IsTerminal(stdoutFD)
}