diff --git a/internal/snapshot/scanner.go b/internal/snapshot/scanner.go index c1c7630..004a11b 100644 --- a/internal/snapshot/scanner.go +++ b/internal/snapshot/scanner.go @@ -676,7 +676,7 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { if err := blobWithReader.TempFile.Close(); err != nil { log.Fatal("Failed to close temp file", "file", tempName, "error", err) } - if err := os.Remove(tempName); err != nil { + if err := s.fs.Remove(tempName); err != nil { log.Fatal("Failed to remove temp file", "file", tempName, "error", err) } } diff --git a/internal/vaultik/helpers.go b/internal/vaultik/helpers.go new file mode 100644 index 0000000..80019d4 --- /dev/null +++ b/internal/vaultik/helpers.go @@ -0,0 +1,103 @@ +package vaultik + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +// SnapshotInfo contains information about a snapshot +type SnapshotInfo struct { + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + CompressedSize int64 `json:"compressed_size"` +} + +// formatNumber formats a number with commas +func formatNumber(n int) string { + str := fmt.Sprintf("%d", n) + var result []string + for i, digit := range str { + if i > 0 && (len(str)-i)%3 == 0 { + result = append(result, ",") + } + result = append(result, string(digit)) + } + return strings.Join(result, "") +} + +// formatDuration formats a duration in a human-readable way +func formatDuration(d time.Duration) string { + if d < time.Second { + return fmt.Sprintf("%dms", d.Milliseconds()) + } + if d < time.Minute { + return fmt.Sprintf("%.1fs", d.Seconds()) + } + if d < time.Hour { + mins := int(d.Minutes()) + secs := int(d.Seconds()) % 60 + return fmt.Sprintf("%dm %ds", mins, secs) + } + hours := int(d.Hours()) + mins := int(d.Minutes()) % 60 + return fmt.Sprintf("%dh %dm", hours, mins) +} + +// formatBytes formats bytes in a human-readable format +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +// parseSnapshotTimestamp extracts the timestamp from a snapshot ID +func parseSnapshotTimestamp(snapshotID string) (time.Time, error) { + // Format: hostname-YYYYMMDD-HHMMSSZ + parts := strings.Split(snapshotID, "-") + if len(parts) < 3 { + return time.Time{}, fmt.Errorf("invalid snapshot ID format") + } + + dateStr := parts[len(parts)-2] + timeStr := parts[len(parts)-1] + + if len(dateStr) != 8 || len(timeStr) != 7 || !strings.HasSuffix(timeStr, "Z") { + return time.Time{}, fmt.Errorf("invalid timestamp format") + } + + // Remove Z suffix + timeStr = timeStr[:6] + + // Parse the timestamp + timestamp, err := time.Parse("20060102150405", dateStr+timeStr) + if err != nil { + return time.Time{}, fmt.Errorf("failed to parse timestamp: %w", err) + } + + return timestamp.UTC(), nil +} + +// parseDuration parses a duration string with support for days +func parseDuration(s string) (time.Duration, error) { + // Check for days suffix + if strings.HasSuffix(s, "d") { + daysStr := strings.TrimSuffix(s, "d") + days, err := strconv.Atoi(daysStr) + if err != nil { + return 0, fmt.Errorf("invalid days value: %w", err) + } + return time.Duration(days) * 24 * time.Hour, nil + } + + // Otherwise use standard Go duration parsing + return time.ParseDuration(s) +} diff --git a/internal/vaultik/info.go b/internal/vaultik/info.go new file mode 100644 
index 0000000..87453c5 --- /dev/null +++ b/internal/vaultik/info.go @@ -0,0 +1,101 @@ +package vaultik + +import ( + "fmt" + "runtime" + "strings" + + "github.com/dustin/go-humanize" +) + +// ShowInfo displays system and configuration information +func (v *Vaultik) ShowInfo() error { + // System Information + fmt.Printf("=== System Information ===\n") + fmt.Printf("OS/Architecture: %s/%s\n", runtime.GOOS, runtime.GOARCH) + fmt.Printf("Version: %s\n", v.Globals.Version) + fmt.Printf("Commit: %s\n", v.Globals.Commit) + fmt.Printf("Go Version: %s\n", runtime.Version()) + fmt.Println() + + // Storage Configuration + fmt.Printf("=== Storage Configuration ===\n") + fmt.Printf("S3 Bucket: %s\n", v.Config.S3.Bucket) + if v.Config.S3.Prefix != "" { + fmt.Printf("S3 Prefix: %s\n", v.Config.S3.Prefix) + } + fmt.Printf("S3 Endpoint: %s\n", v.Config.S3.Endpoint) + fmt.Printf("S3 Region: %s\n", v.Config.S3.Region) + fmt.Println() + + // Backup Settings + fmt.Printf("=== Backup Settings ===\n") + fmt.Printf("Source Directories:\n") + for _, dir := range v.Config.SourceDirs { + fmt.Printf(" - %s\n", dir) + } + + // Global exclude patterns + if len(v.Config.Exclude) > 0 { + fmt.Printf("Exclude Patterns: %s\n", strings.Join(v.Config.Exclude, ", ")) + } + + fmt.Printf("Compression: zstd level %d\n", v.Config.CompressionLevel) + fmt.Printf("Chunk Size: %s\n", humanize.Bytes(uint64(v.Config.ChunkSize))) + fmt.Printf("Blob Size Limit: %s\n", humanize.Bytes(uint64(v.Config.BlobSizeLimit))) + fmt.Println() + + // Encryption Configuration + fmt.Printf("=== Encryption Configuration ===\n") + fmt.Printf("Recipients:\n") + for _, recipient := range v.Config.AgeRecipients { + fmt.Printf(" - %s\n", recipient) + } + fmt.Println() + + // Daemon Settings (if applicable) + if v.Config.BackupInterval > 0 || v.Config.MinTimeBetweenRun > 0 { + fmt.Printf("=== Daemon Settings ===\n") + if v.Config.BackupInterval > 0 { + fmt.Printf("Backup Interval: %s\n", v.Config.BackupInterval) + } + if v.Config.MinTimeBetweenRun > 0 { + fmt.Printf("Minimum Time: %s\n", v.Config.MinTimeBetweenRun) + } + fmt.Println() + } + + // Local Database + fmt.Printf("=== Local Database ===\n") + fmt.Printf("Index Path: %s\n", v.Config.IndexPath) + + // Check if index file exists and get its size + if info, err := v.Fs.Stat(v.Config.IndexPath); err == nil { + fmt.Printf("Index Size: %s\n", humanize.Bytes(uint64(info.Size()))) + + // Get snapshot count from database + query := `SELECT COUNT(*) FROM snapshots WHERE completed_at IS NOT NULL` + var snapshotCount int + if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&snapshotCount); err == nil { + fmt.Printf("Snapshots: %d\n", snapshotCount) + } + + // Get blob count from database + query = `SELECT COUNT(*) FROM blobs` + var blobCount int + if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&blobCount); err == nil { + fmt.Printf("Blobs: %d\n", blobCount) + } + + // Get file count from database + query = `SELECT COUNT(*) FROM files` + var fileCount int + if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&fileCount); err == nil { + fmt.Printf("Files: %d\n", fileCount) + } + } else { + fmt.Printf("Index Size: (not created)\n") + } + + return nil +} diff --git a/internal/vaultik/integration_test.go b/internal/vaultik/integration_test.go new file mode 100644 index 0000000..d99186f --- /dev/null +++ b/internal/vaultik/integration_test.go @@ -0,0 +1,379 @@ +package vaultik_test + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + "sync" + "testing" + "time" + + 
"git.eeqj.de/sneak/vaultik/internal/config" + "git.eeqj.de/sneak/vaultik/internal/database" + "git.eeqj.de/sneak/vaultik/internal/log" + "git.eeqj.de/sneak/vaultik/internal/s3" + "git.eeqj.de/sneak/vaultik/internal/snapshot" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockS3Client implements a mock S3 client for testing +type MockS3Client struct { + mu sync.Mutex + storage map[string][]byte + calls []string +} + +func NewMockS3Client() *MockS3Client { + return &MockS3Client{ + storage: make(map[string][]byte), + calls: make([]string, 0), + } +} + +func (m *MockS3Client) PutObject(ctx context.Context, key string, reader io.Reader) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.calls = append(m.calls, "PutObject:"+key) + data, err := io.ReadAll(reader) + if err != nil { + return err + } + m.storage[key] = data + return nil +} + +func (m *MockS3Client) PutObjectWithProgress(ctx context.Context, key string, reader io.Reader, size int64, progress s3.ProgressCallback) error { + // For testing, just call PutObject + return m.PutObject(ctx, key, reader) +} + +func (m *MockS3Client) GetObject(ctx context.Context, key string) (io.ReadCloser, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.calls = append(m.calls, "GetObject:"+key) + data, exists := m.storage[key] + if !exists { + return nil, fmt.Errorf("key not found: %s", key) + } + return io.NopCloser(bytes.NewReader(data)), nil +} + +func (m *MockS3Client) StatObject(ctx context.Context, key string) (*s3.ObjectInfo, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.calls = append(m.calls, "StatObject:"+key) + data, exists := m.storage[key] + if !exists { + return nil, fmt.Errorf("key not found: %s", key) + } + return &s3.ObjectInfo{ + Key: key, + Size: int64(len(data)), + }, nil +} + +func (m *MockS3Client) DeleteObject(ctx context.Context, key string) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.calls = append(m.calls, "DeleteObject:"+key) + delete(m.storage, key) + return nil +} + +func (m *MockS3Client) ListObjects(ctx context.Context, prefix string) ([]*s3.ObjectInfo, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.calls = append(m.calls, "ListObjects:"+prefix) + var objects []*s3.ObjectInfo + for key, data := range m.storage { + if len(prefix) == 0 || (len(key) >= len(prefix) && key[:len(prefix)] == prefix) { + objects = append(objects, &s3.ObjectInfo{ + Key: key, + Size: int64(len(data)), + }) + } + } + return objects, nil +} + +// GetCalls returns the list of S3 operations that were called +func (m *MockS3Client) GetCalls() []string { + m.mu.Lock() + defer m.mu.Unlock() + + calls := make([]string, len(m.calls)) + copy(calls, m.calls) + return calls +} + +// GetStorageSize returns the number of objects in storage +func (m *MockS3Client) GetStorageSize() int { + m.mu.Lock() + defer m.mu.Unlock() + + return len(m.storage) +} + +// TestEndToEndBackup tests the full backup workflow with mocked dependencies +func TestEndToEndBackup(t *testing.T) { + // Initialize logger + log.Initialize(log.Config{}) + + // Create in-memory filesystem + fs := afero.NewMemMapFs() + + // Create test directory structure and files + testFiles := map[string]string{ + "/home/user/documents/file1.txt": "This is file 1 content", + "/home/user/documents/file2.txt": "This is file 2 content with more data", + "/home/user/pictures/photo1.jpg": "Binary photo data here...", + "/home/user/code/main.go": "package main\n\nfunc main() {\n\tprintln(\"Hello, World!\")\n}", + } + + // Create all 
directories first + dirs := []string{ + "/home/user/documents", + "/home/user/pictures", + "/home/user/code", + } + for _, dir := range dirs { + if err := fs.MkdirAll(dir, 0755); err != nil { + t.Fatalf("failed to create directory %s: %v", dir, err) + } + } + + // Create test files + for path, content := range testFiles { + if err := afero.WriteFile(fs, path, []byte(content), 0644); err != nil { + t.Fatalf("failed to create test file %s: %v", path, err) + } + } + + // Create mock S3 client + mockS3 := NewMockS3Client() + + // Create test configuration + cfg := &config.Config{ + SourceDirs: []string{"/home/user"}, + Exclude: []string{"*.tmp", "*.log"}, + ChunkSize: config.Size(16 * 1024), // 16KB chunks + BlobSizeLimit: config.Size(100 * 1024), // 100KB blobs + CompressionLevel: 3, + AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key + AgeSecretKey: "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5", // Test private key + S3: config.S3Config{ + Endpoint: "http://localhost:9000", // MinIO endpoint for testing + Region: "us-east-1", + Bucket: "test-bucket", + AccessKeyID: "test-access", + SecretAccessKey: "test-secret", + }, + IndexPath: ":memory:", // In-memory SQLite database + } + + // For a true end-to-end test, we'll create a simpler test that focuses on + // the core backup logic using the scanner directly with our mock S3 client + ctx := context.Background() + + // Create in-memory database + db, err := database.New(ctx, ":memory:") + require.NoError(t, err) + defer func() { + if err := db.Close(); err != nil { + t.Errorf("failed to close database: %v", err) + } + }() + + repos := database.NewRepositories(db) + + // Create scanner with mock S3 client + scanner := snapshot.NewScanner(snapshot.ScannerConfig{ + FS: fs, + ChunkSize: cfg.ChunkSize.Int64(), + Repositories: repos, + S3Client: mockS3, + MaxBlobSize: cfg.BlobSizeLimit.Int64(), + CompressionLevel: cfg.CompressionLevel, + AgeRecipients: cfg.AgeRecipients, + EnableProgress: false, + }) + + // Create a snapshot record + snapshotID := "test-snapshot-001" + err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { + snapshot := &database.Snapshot{ + ID: snapshotID, + Hostname: "test-host", + VaultikVersion: "test-version", + StartedAt: time.Now(), + } + return repos.Snapshots.Create(ctx, tx, snapshot) + }) + require.NoError(t, err) + + // Run the backup scan + result, err := scanner.Scan(ctx, "/home/user", snapshotID) + require.NoError(t, err) + + // Verify scan results + // The scanner counts both files and directories, so we have: + // 4 files + 4 directories (/home, /home/user, /home/user/documents, /home/user/pictures, /home/user/code) + assert.GreaterOrEqual(t, result.FilesScanned, 4, "Should scan at least 4 files") + assert.Greater(t, result.BytesScanned, int64(0), "Should scan some bytes") + assert.Greater(t, result.ChunksCreated, 0, "Should create chunks") + assert.Greater(t, result.BlobsCreated, 0, "Should create blobs") + + // Verify S3 operations + calls := mockS3.GetCalls() + t.Logf("S3 operations performed: %v", calls) + + // Should have uploaded at least one blob + blobUploads := 0 + for _, call := range calls { + if len(call) > 10 && call[:10] == "PutObject:" { + if len(call) > 16 && call[10:16] == "blobs/" { + blobUploads++ + } + } + } + assert.Greater(t, blobUploads, 0, "Should upload at least one blob") + + // Verify files in database + files, err := repos.Files.ListByPrefix(ctx, "/home/user") + require.NoError(t, err) + 
// Count only regular files (not directories) + regularFiles := 0 + for _, f := range files { + if f.Mode&0x80000000 == 0 { // Check if regular file (not directory) + regularFiles++ + } + } + assert.Equal(t, 4, regularFiles, "Should have 4 regular files in database") + + // Verify chunks were created by checking a specific file + fileChunks, err := repos.FileChunks.GetByPath(ctx, "/home/user/documents/file1.txt") + require.NoError(t, err) + assert.Greater(t, len(fileChunks), 0, "Should have chunks for file1.txt") + + // Verify blobs were uploaded to S3 + assert.Greater(t, mockS3.GetStorageSize(), 0, "Should have blobs in S3 storage") + + // Complete the snapshot - just verify we got results + // In a real integration test, we'd update the snapshot record + + // Create snapshot manager to test metadata export + snapshotManager := &snapshot.SnapshotManager{} + snapshotManager.SetFilesystem(fs) + + // Note: We can't fully test snapshot metadata export without a proper S3 client mock + // that implements all required methods. This would require refactoring the S3 client + // interface to be more testable. + + t.Logf("Backup completed successfully:") + t.Logf(" Files scanned: %d", result.FilesScanned) + t.Logf(" Bytes scanned: %d", result.BytesScanned) + t.Logf(" Chunks created: %d", result.ChunksCreated) + t.Logf(" Blobs created: %d", result.BlobsCreated) + t.Logf(" S3 storage size: %d objects", mockS3.GetStorageSize()) +} + +// TestBackupAndVerify tests backing up files and verifying the blobs +func TestBackupAndVerify(t *testing.T) { + // Initialize logger + log.Initialize(log.Config{}) + + // Create in-memory filesystem + fs := afero.NewMemMapFs() + + // Create test files + testContent := "This is a test file with some content that should be backed up" + err := fs.MkdirAll("/data", 0755) + require.NoError(t, err) + err = afero.WriteFile(fs, "/data/test.txt", []byte(testContent), 0644) + require.NoError(t, err) + + // Create mock S3 client + mockS3 := NewMockS3Client() + + // Create test database + ctx := context.Background() + db, err := database.New(ctx, ":memory:") + require.NoError(t, err) + defer func() { + if err := db.Close(); err != nil { + t.Errorf("failed to close database: %v", err) + } + }() + + repos := database.NewRepositories(db) + + // Create scanner + scanner := snapshot.NewScanner(snapshot.ScannerConfig{ + FS: fs, + ChunkSize: int64(1024 * 16), // 16KB chunks + Repositories: repos, + S3Client: mockS3, + MaxBlobSize: int64(1024 * 1024), // 1MB blobs + CompressionLevel: 3, + AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key + }) + + // Create a snapshot + snapshotID := "test-snapshot-001" + err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { + snapshot := &database.Snapshot{ + ID: snapshotID, + Hostname: "test-host", + VaultikVersion: "test-version", + StartedAt: time.Now(), + } + return repos.Snapshots.Create(ctx, tx, snapshot) + }) + require.NoError(t, err) + + // Run the backup + result, err := scanner.Scan(ctx, "/data", snapshotID) + require.NoError(t, err) + + // Verify backup created blobs + assert.Greater(t, result.BlobsCreated, 0, "Should create at least one blob") + assert.Equal(t, mockS3.GetStorageSize(), result.BlobsCreated, "S3 should have the blobs") + + // Verify we can retrieve the blob from S3 + objects, err := mockS3.ListObjects(ctx, "blobs/") + require.NoError(t, err) + assert.Len(t, objects, result.BlobsCreated, "Should have correct number of blobs in S3") + + // Get the first 
blob and verify it exists + if len(objects) > 0 { + blobKey := objects[0].Key + t.Logf("Verifying blob: %s", blobKey) + + // Get blob info + blobInfo, err := mockS3.StatObject(ctx, blobKey) + require.NoError(t, err) + assert.Greater(t, blobInfo.Size, int64(0), "Blob should have content") + + // Get blob content + reader, err := mockS3.GetObject(ctx, blobKey) + require.NoError(t, err) + defer func() { _ = reader.Close() }() + + // Verify blob data is encrypted (should not contain plaintext) + blobData, err := io.ReadAll(reader) + require.NoError(t, err) + assert.NotContains(t, string(blobData), testContent, "Blob should be encrypted") + assert.Greater(t, len(blobData), 0, "Blob should have data") + } + + t.Logf("Backup and verify test completed successfully") +} diff --git a/internal/vaultik/prune.go b/internal/vaultik/prune.go new file mode 100644 index 0000000..030098f --- /dev/null +++ b/internal/vaultik/prune.go @@ -0,0 +1,169 @@ +package vaultik + +import ( + "fmt" + "strings" + + "git.eeqj.de/sneak/vaultik/internal/log" + "github.com/dustin/go-humanize" +) + +// PruneOptions contains options for the prune command +type PruneOptions struct { + Force bool +} + +// PruneBlobs removes unreferenced blobs from storage +func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { + log.Info("Starting prune operation") + + // Get all remote snapshots and their manifests + allBlobsReferenced := make(map[string]bool) + manifestCount := 0 + + // List all snapshots in S3 + log.Info("Listing remote snapshots") + objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false) + + var snapshotIDs []string + for object := range objectCh { + if object.Err != nil { + return fmt.Errorf("listing remote snapshots: %w", object.Err) + } + + // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/ + parts := strings.Split(object.Key, "/") + if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { + // Check if this is a directory by looking for trailing slash + if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") { + snapshotID := parts[1] + // Only add unique snapshot IDs + found := false + for _, id := range snapshotIDs { + if id == snapshotID { + found = true + break + } + } + if !found { + snapshotIDs = append(snapshotIDs, snapshotID) + } + } + } + } + + log.Info("Found manifests in remote storage", "count", len(snapshotIDs)) + + // Download and parse each manifest to get referenced blobs + for _, snapshotID := range snapshotIDs { + log.Debug("Processing manifest", "snapshot_id", snapshotID) + + manifest, err := v.downloadManifest(snapshotID) + if err != nil { + log.Error("Failed to download manifest", "snapshot_id", snapshotID, "error", err) + continue + } + + // Add all blobs from this manifest to our referenced set + for _, blob := range manifest.Blobs { + allBlobsReferenced[blob.Hash] = true + } + manifestCount++ + } + + log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced)) + + // List all blobs in S3 + log.Info("Listing all blobs in storage") + allBlobs := make(map[string]int64) // hash -> size + blobObjectCh := v.S3Client.ListObjectsStream(v.ctx, "blobs/", true) + + for object := range blobObjectCh { + if object.Err != nil { + return fmt.Errorf("listing blobs: %w", object.Err) + } + + // Extract hash from path like blobs/ab/cd/abcdef123456... 
+ parts := strings.Split(object.Key, "/") + if len(parts) == 4 && parts[0] == "blobs" { + hash := parts[3] + allBlobs[hash] = object.Size + } + } + + log.Info("Found blobs in storage", "count", len(allBlobs)) + + // Find unreferenced blobs + var unreferencedBlobs []string + var totalSize int64 + for hash, size := range allBlobs { + if !allBlobsReferenced[hash] { + unreferencedBlobs = append(unreferencedBlobs, hash) + totalSize += size + } + } + + if len(unreferencedBlobs) == 0 { + log.Info("No unreferenced blobs found") + fmt.Println("No unreferenced blobs to remove.") + return nil + } + + // Show what will be deleted + log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize))) + fmt.Printf("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize))) + + // Confirm unless --force is used + if !opts.Force { + fmt.Printf("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs)) + var confirm string + if _, err := fmt.Scanln(&confirm); err != nil { + // Treat EOF or error as "no" + fmt.Println("Cancelled") + return nil + } + if strings.ToLower(confirm) != "y" { + fmt.Println("Cancelled") + return nil + } + } + + // Delete unreferenced blobs + log.Info("Deleting unreferenced blobs") + deletedCount := 0 + deletedSize := int64(0) + + for i, hash := range unreferencedBlobs { + blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash) + + if err := v.S3Client.RemoveObject(v.ctx, blobPath); err != nil { + log.Error("Failed to delete blob", "hash", hash, "error", err) + continue + } + + deletedCount++ + deletedSize += allBlobs[hash] + + // Progress update every 100 blobs + if (i+1)%100 == 0 || i == len(unreferencedBlobs)-1 { + log.Info("Deletion progress", + "deleted", i+1, + "total", len(unreferencedBlobs), + "percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(len(unreferencedBlobs))*100), + ) + } + } + + log.Info("Prune complete", + "deleted_count", deletedCount, + "deleted_size", humanize.Bytes(uint64(deletedSize)), + "failed", len(unreferencedBlobs)-deletedCount, + ) + + fmt.Printf("\nDeleted %d blob(s) totaling %s\n", deletedCount, humanize.Bytes(uint64(deletedSize))) + if deletedCount < len(unreferencedBlobs) { + fmt.Printf("Failed to delete %d blob(s)\n", len(unreferencedBlobs)-deletedCount) + } + + return nil +} diff --git a/internal/vaultik/snapshot.go b/internal/vaultik/snapshot.go new file mode 100644 index 0000000..97c317b --- /dev/null +++ b/internal/vaultik/snapshot.go @@ -0,0 +1,701 @@ +package vaultik + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "text/tabwriter" + "time" + + "git.eeqj.de/sneak/vaultik/internal/database" + "git.eeqj.de/sneak/vaultik/internal/log" + "git.eeqj.de/sneak/vaultik/internal/snapshot" + "github.com/dustin/go-humanize" +) + +// SnapshotCreateOptions contains options for the snapshot create command +type SnapshotCreateOptions struct { + Daemon bool + Cron bool + Prune bool +} + +// CreateSnapshot executes the snapshot creation operation +func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error { + snapshotStartTime := time.Now() + + log.Info("Starting snapshot creation", + "version", v.Globals.Version, + "commit", v.Globals.Commit, + "index_path", v.Config.IndexPath, + ) + + // Clean up incomplete snapshots FIRST, before any scanning + // This is critical for data safety - see CleanupIncompleteSnapshots for details + hostname := v.Config.Hostname + if hostname == "" { + hostname, _ = 
os.Hostname() + } + + // CRITICAL: This MUST succeed. If we fail to clean up incomplete snapshots, + // the deduplication logic will think files from the incomplete snapshot were + // already backed up and skip them, resulting in data loss. + if err := v.SnapshotManager.CleanupIncompleteSnapshots(v.ctx, hostname); err != nil { + return fmt.Errorf("cleanup incomplete snapshots: %w", err) + } + + if opts.Daemon { + log.Info("Running in daemon mode") + // TODO: Implement daemon mode with inotify + return fmt.Errorf("daemon mode not yet implemented") + } + + // Resolve source directories to absolute paths + resolvedDirs := make([]string, 0, len(v.Config.SourceDirs)) + for _, dir := range v.Config.SourceDirs { + absPath, err := filepath.Abs(dir) + if err != nil { + return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err) + } + + // Resolve symlinks + resolvedPath, err := filepath.EvalSymlinks(absPath) + if err != nil { + // If the path doesn't exist yet, use the absolute path + if os.IsNotExist(err) { + resolvedPath = absPath + } else { + return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err) + } + } + + resolvedDirs = append(resolvedDirs, resolvedPath) + } + + // Create scanner with progress enabled (unless in cron mode) + scanner := v.ScannerFactory(snapshot.ScannerParams{ + EnableProgress: !opts.Cron, + Fs: v.Fs, + }) + + // Statistics tracking + totalFiles := 0 + totalBytes := int64(0) + totalChunks := 0 + totalBlobs := 0 + totalBytesSkipped := int64(0) + totalFilesSkipped := 0 + totalFilesDeleted := 0 + totalBytesDeleted := int64(0) + totalBytesUploaded := int64(0) + totalBlobsUploaded := 0 + uploadDuration := time.Duration(0) + + // Create a new snapshot at the beginning + snapshotID, err := v.SnapshotManager.CreateSnapshot(v.ctx, hostname, v.Globals.Version, v.Globals.Commit) + if err != nil { + return fmt.Errorf("creating snapshot: %w", err) + } + log.Info("Beginning snapshot", "snapshot_id", snapshotID) + _, _ = fmt.Fprintf(v.Stdout, "Beginning snapshot: %s\n", snapshotID) + + for i, dir := range resolvedDirs { + // Check if context is cancelled + select { + case <-v.ctx.Done(): + log.Info("Snapshot creation cancelled") + return v.ctx.Err() + default: + } + + log.Info("Scanning directory", "path", dir) + _, _ = fmt.Fprintf(v.Stdout, "Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir) + result, err := scanner.Scan(v.ctx, dir, snapshotID) + if err != nil { + return fmt.Errorf("failed to scan %s: %w", dir, err) + } + + totalFiles += result.FilesScanned + totalBytes += result.BytesScanned + totalChunks += result.ChunksCreated + totalBlobs += result.BlobsCreated + totalFilesSkipped += result.FilesSkipped + totalBytesSkipped += result.BytesSkipped + totalFilesDeleted += result.FilesDeleted + totalBytesDeleted += result.BytesDeleted + + log.Info("Directory scan complete", + "path", dir, + "files", result.FilesScanned, + "files_skipped", result.FilesSkipped, + "bytes", result.BytesScanned, + "bytes_skipped", result.BytesSkipped, + "chunks", result.ChunksCreated, + "blobs", result.BlobsCreated, + "duration", result.EndTime.Sub(result.StartTime)) + + // Remove per-directory summary - the scanner already prints its own summary + } + + // Get upload statistics from scanner progress if available + if s := scanner.GetProgress(); s != nil { + stats := s.GetStats() + totalBytesUploaded = stats.BytesUploaded.Load() + totalBlobsUploaded = int(stats.BlobsUploaded.Load()) + uploadDuration = time.Duration(stats.UploadDurationMs.Load()) * 
time.Millisecond + } + + // Update snapshot statistics with extended fields + extStats := snapshot.ExtendedBackupStats{ + BackupStats: snapshot.BackupStats{ + FilesScanned: totalFiles, + BytesScanned: totalBytes, + ChunksCreated: totalChunks, + BlobsCreated: totalBlobs, + BytesUploaded: totalBytesUploaded, + }, + BlobUncompressedSize: 0, // Will be set from database query below + CompressionLevel: v.Config.CompressionLevel, + UploadDurationMs: uploadDuration.Milliseconds(), + } + + if err := v.SnapshotManager.UpdateSnapshotStatsExtended(v.ctx, snapshotID, extStats); err != nil { + return fmt.Errorf("updating snapshot stats: %w", err) + } + + // Mark snapshot as complete + if err := v.SnapshotManager.CompleteSnapshot(v.ctx, snapshotID); err != nil { + return fmt.Errorf("completing snapshot: %w", err) + } + + // Export snapshot metadata + // Export snapshot metadata without closing the database + // The export function should handle its own database connection + if err := v.SnapshotManager.ExportSnapshotMetadata(v.ctx, v.Config.IndexPath, snapshotID); err != nil { + return fmt.Errorf("exporting snapshot metadata: %w", err) + } + + // Calculate final statistics + snapshotDuration := time.Since(snapshotStartTime) + totalFilesChanged := totalFiles - totalFilesSkipped + totalBytesChanged := totalBytes + totalBytesAll := totalBytes + totalBytesSkipped + + // Calculate upload speed + var avgUploadSpeed string + if totalBytesUploaded > 0 && uploadDuration > 0 { + bytesPerSec := float64(totalBytesUploaded) / uploadDuration.Seconds() + bitsPerSec := bytesPerSec * 8 + if bitsPerSec >= 1e9 { + avgUploadSpeed = fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9) + } else if bitsPerSec >= 1e6 { + avgUploadSpeed = fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6) + } else if bitsPerSec >= 1e3 { + avgUploadSpeed = fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3) + } else { + avgUploadSpeed = fmt.Sprintf("%.0f bit/s", bitsPerSec) + } + } else { + avgUploadSpeed = "N/A" + } + + // Get total blob sizes from database + totalBlobSizeCompressed := int64(0) + totalBlobSizeUncompressed := int64(0) + if blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID); err == nil { + for _, hash := range blobHashes { + if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil { + totalBlobSizeCompressed += blob.CompressedSize + totalBlobSizeUncompressed += blob.UncompressedSize + } + } + } + + // Calculate compression ratio + var compressionRatio float64 + if totalBlobSizeUncompressed > 0 { + compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed) + } else { + compressionRatio = 1.0 + } + + // Print comprehensive summary + _, _ = fmt.Fprintf(v.Stdout, "=== Snapshot Complete ===\n") + _, _ = fmt.Fprintf(v.Stdout, "ID: %s\n", snapshotID) + _, _ = fmt.Fprintf(v.Stdout, "Files: %s examined, %s to process, %s unchanged", + formatNumber(totalFiles), + formatNumber(totalFilesChanged), + formatNumber(totalFilesSkipped)) + if totalFilesDeleted > 0 { + _, _ = fmt.Fprintf(v.Stdout, ", %s deleted", formatNumber(totalFilesDeleted)) + } + _, _ = fmt.Fprintln(v.Stdout) + _, _ = fmt.Fprintf(v.Stdout, "Data: %s total (%s to process)", + humanize.Bytes(uint64(totalBytesAll)), + humanize.Bytes(uint64(totalBytesChanged))) + if totalBytesDeleted > 0 { + _, _ = fmt.Fprintf(v.Stdout, ", %s deleted", humanize.Bytes(uint64(totalBytesDeleted))) + } + _, _ = fmt.Fprintln(v.Stdout) + if totalBlobsUploaded > 0 { + _, _ = fmt.Fprintf(v.Stdout, "Storage: %s compressed from %s (%.2fx)\n", + 
humanize.Bytes(uint64(totalBlobSizeCompressed)), + humanize.Bytes(uint64(totalBlobSizeUncompressed)), + compressionRatio) + _, _ = fmt.Fprintf(v.Stdout, "Upload: %d blobs, %s in %s (%s)\n", + totalBlobsUploaded, + humanize.Bytes(uint64(totalBytesUploaded)), + formatDuration(uploadDuration), + avgUploadSpeed) + } + _, _ = fmt.Fprintf(v.Stdout, "Duration: %s\n", formatDuration(snapshotDuration)) + + if opts.Prune { + log.Info("Pruning enabled - will delete old snapshots after snapshot") + // TODO: Implement pruning + } + + return nil +} + +// ListSnapshots lists all snapshots +func (v *Vaultik) ListSnapshots(jsonOutput bool) error { + // Get all remote snapshots + remoteSnapshots := make(map[string]bool) + objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false) + + for object := range objectCh { + if object.Err != nil { + return fmt.Errorf("listing remote snapshots: %w", object.Err) + } + + // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/ + parts := strings.Split(object.Key, "/") + if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { + remoteSnapshots[parts[1]] = true + } + } + + // Get all local snapshots + localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000) + if err != nil { + return fmt.Errorf("listing local snapshots: %w", err) + } + + // Build a map of local snapshots for quick lookup + localSnapshotMap := make(map[string]*database.Snapshot) + for _, s := range localSnapshots { + localSnapshotMap[s.ID] = s + } + + // Remove local snapshots that don't exist remotely + for _, snapshot := range localSnapshots { + if !remoteSnapshots[snapshot.ID] { + log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID) + + // Delete related records first to avoid foreign key constraints + if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshot.ID); err != nil { + log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err) + } + if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshot.ID); err != nil { + log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err) + } + if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshot.ID); err != nil { + log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err) + } + + // Now delete the snapshot itself + if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil { + log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err) + } else { + log.Info("Deleted local snapshot not found in remote", "snapshot_id", snapshot.ID) + delete(localSnapshotMap, snapshot.ID) + } + } + } + + // Build final snapshot list + snapshots := make([]SnapshotInfo, 0, len(remoteSnapshots)) + + for snapshotID := range remoteSnapshots { + // Check if we have this snapshot locally + if localSnap, exists := localSnapshotMap[snapshotID]; exists && localSnap.CompletedAt != nil { + // Get total compressed size of all blobs referenced by this snapshot + totalSize, err := v.Repositories.Snapshots.GetSnapshotTotalCompressedSize(v.ctx, snapshotID) + if err != nil { + log.Warn("Failed to get total compressed size", "id", snapshotID, "error", err) + // Fall back to stored blob size + totalSize = localSnap.BlobSize + } + + snapshots = append(snapshots, SnapshotInfo{ + ID: localSnap.ID, + Timestamp: localSnap.StartedAt, + CompressedSize: totalSize, + }) + } else { + // Remote snapshot not in local DB - fetch manifest to get size + timestamp, 
err := parseSnapshotTimestamp(snapshotID) + if err != nil { + log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err) + continue + } + + // Try to download manifest to get size + totalSize, err := v.getManifestSize(snapshotID) + if err != nil { + return fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err) + } + + snapshots = append(snapshots, SnapshotInfo{ + ID: snapshotID, + Timestamp: timestamp, + CompressedSize: totalSize, + }) + } + } + + // Sort by timestamp (newest first) + sort.Slice(snapshots, func(i, j int) bool { + return snapshots[i].Timestamp.After(snapshots[j].Timestamp) + }) + + if jsonOutput { + // JSON output + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent("", " ") + return encoder.Encode(snapshots) + } + + // Table output + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + if _, err := fmt.Fprintln(w, "SNAPSHOT ID\tTIMESTAMP\tCOMPRESSED SIZE"); err != nil { + return err + } + if _, err := fmt.Fprintln(w, "───────────\t─────────\t───────────────"); err != nil { + return err + } + + for _, snap := range snapshots { + if _, err := fmt.Fprintf(w, "%s\t%s\t%s\n", + snap.ID, + snap.Timestamp.Format("2006-01-02 15:04:05"), + formatBytes(snap.CompressedSize)); err != nil { + return err + } + } + + return w.Flush() +} + +// PurgeSnapshots removes old snapshots based on criteria +func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool) error { + // Sync with remote first + if err := v.syncWithRemote(); err != nil { + return fmt.Errorf("syncing with remote: %w", err) + } + + // Get snapshots from local database + dbSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000) + if err != nil { + return fmt.Errorf("listing snapshots: %w", err) + } + + // Convert to SnapshotInfo format, only including completed snapshots + snapshots := make([]SnapshotInfo, 0, len(dbSnapshots)) + for _, s := range dbSnapshots { + if s.CompletedAt != nil { + snapshots = append(snapshots, SnapshotInfo{ + ID: s.ID, + Timestamp: s.StartedAt, + CompressedSize: s.BlobSize, + }) + } + } + + // Sort by timestamp (newest first) + sort.Slice(snapshots, func(i, j int) bool { + return snapshots[i].Timestamp.After(snapshots[j].Timestamp) + }) + + var toDelete []SnapshotInfo + + if keepLatest { + // Keep only the most recent snapshot + if len(snapshots) > 1 { + toDelete = snapshots[1:] + } + } else if olderThan != "" { + // Parse duration + duration, err := parseDuration(olderThan) + if err != nil { + return fmt.Errorf("invalid duration: %w", err) + } + + cutoff := time.Now().UTC().Add(-duration) + for _, snap := range snapshots { + if snap.Timestamp.Before(cutoff) { + toDelete = append(toDelete, snap) + } + } + } + + if len(toDelete) == 0 { + fmt.Println("No snapshots to delete") + return nil + } + + // Show what will be deleted + fmt.Printf("The following snapshots will be deleted:\n\n") + for _, snap := range toDelete { + fmt.Printf(" %s (%s, %s)\n", + snap.ID, + snap.Timestamp.Format("2006-01-02 15:04:05"), + formatBytes(snap.CompressedSize)) + } + + // Confirm unless --force is used + if !force { + fmt.Printf("\nDelete %d snapshot(s)? 
[y/N] ", len(toDelete)) + var confirm string + if _, err := fmt.Scanln(&confirm); err != nil { + // Treat EOF or error as "no" + fmt.Println("Cancelled") + return nil + } + if strings.ToLower(confirm) != "y" { + fmt.Println("Cancelled") + return nil + } + } else { + fmt.Printf("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete)) + } + + // Delete snapshots + for _, snap := range toDelete { + log.Info("Deleting snapshot", "id", snap.ID) + if err := v.deleteSnapshot(snap.ID); err != nil { + return fmt.Errorf("deleting snapshot %s: %w", snap.ID, err) + } + } + + fmt.Printf("Deleted %d snapshot(s)\n", len(toDelete)) + + // Note: Run 'vaultik prune' separately to clean up unreferenced blobs + fmt.Println("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.") + + return nil +} + +// VerifySnapshot checks snapshot integrity +func (v *Vaultik) VerifySnapshot(snapshotID string, deep bool) error { + // Parse snapshot ID to extract timestamp + parts := strings.Split(snapshotID, "-") + var snapshotTime time.Time + if len(parts) >= 3 { + // Format: hostname-YYYYMMDD-HHMMSSZ + dateStr := parts[len(parts)-2] + timeStr := parts[len(parts)-1] + if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") { + timeStr = timeStr[:6] // Remove Z + timestamp, err := time.Parse("20060102150405", dateStr+timeStr) + if err == nil { + snapshotTime = timestamp + } + } + } + + fmt.Printf("Verifying snapshot %s\n", snapshotID) + if !snapshotTime.IsZero() { + fmt.Printf("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) + } + fmt.Println() + + // Download and parse manifest + manifest, err := v.downloadManifest(snapshotID) + if err != nil { + return fmt.Errorf("downloading manifest: %w", err) + } + + fmt.Printf("Snapshot information:\n") + fmt.Printf(" Blob count: %d\n", manifest.BlobCount) + fmt.Printf(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize))) + if manifest.Timestamp != "" { + if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil { + fmt.Printf(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST")) + } + } + fmt.Println() + + // Check each blob exists + fmt.Printf("Checking blob existence...\n") + missing := 0 + verified := 0 + missingSize := int64(0) + + for _, blob := range manifest.Blobs { + blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash) + + if deep { + // Download and verify hash + // TODO: Implement deep verification + fmt.Printf("Deep verification not yet implemented\n") + return nil + } else { + // Just check existence + _, err := v.S3Client.StatObject(v.ctx, blobPath) + if err != nil { + fmt.Printf(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize))) + missing++ + missingSize += blob.CompressedSize + } else { + verified++ + } + } + } + + fmt.Printf("\nVerification complete:\n") + fmt.Printf(" Verified: %d blobs (%s)\n", verified, + humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize))) + if missing > 0 { + fmt.Printf(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize))) + } else { + fmt.Printf(" Missing: 0 blobs\n") + } + fmt.Printf(" Status: ") + if missing > 0 { + fmt.Printf("FAILED - %d blobs are missing\n", missing) + return fmt.Errorf("%d blobs are missing", missing) + } else { + fmt.Printf("OK - All blobs verified\n") + } + + return nil +} + +// Helper methods that were previously on SnapshotApp + +func (v *Vaultik) getManifestSize(snapshotID string) (int64, error) { + manifestPath := 
fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) + + reader, err := v.S3Client.GetObject(v.ctx, manifestPath) + if err != nil { + return 0, fmt.Errorf("downloading manifest: %w", err) + } + defer func() { _ = reader.Close() }() + + manifest, err := snapshot.DecodeManifest(reader) + if err != nil { + return 0, fmt.Errorf("decoding manifest: %w", err) + } + + return manifest.TotalCompressedSize, nil +} + +func (v *Vaultik) downloadManifest(snapshotID string) (*snapshot.Manifest, error) { + manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) + + reader, err := v.S3Client.GetObject(v.ctx, manifestPath) + if err != nil { + return nil, err + } + defer func() { _ = reader.Close() }() + + manifest, err := snapshot.DecodeManifest(reader) + if err != nil { + return nil, fmt.Errorf("decoding manifest: %w", err) + } + + return manifest, nil +} + +func (v *Vaultik) deleteSnapshot(snapshotID string) error { + // First, delete from S3 + // List all objects under metadata/{snapshotID}/ + prefix := fmt.Sprintf("metadata/%s/", snapshotID) + objectCh := v.S3Client.ListObjectsStream(v.ctx, prefix, true) + + var objectsToDelete []string + for object := range objectCh { + if object.Err != nil { + return fmt.Errorf("listing objects: %w", object.Err) + } + objectsToDelete = append(objectsToDelete, object.Key) + } + + // Delete all objects + for _, key := range objectsToDelete { + if err := v.S3Client.RemoveObject(v.ctx, key); err != nil { + return fmt.Errorf("removing %s: %w", key, err) + } + } + + // Then, delete from local database + // Delete related records first to avoid foreign key constraints + if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotID); err != nil { + log.Error("Failed to delete snapshot files", "snapshot_id", snapshotID, "error", err) + } + if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotID); err != nil { + log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshotID, "error", err) + } + if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotID); err != nil { + log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshotID, "error", err) + } + + // Now delete the snapshot itself + if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotID); err != nil { + return fmt.Errorf("deleting snapshot from database: %w", err) + } + + return nil +} + +func (v *Vaultik) syncWithRemote() error { + log.Info("Syncing with remote snapshots") + + // Get all remote snapshot IDs + remoteSnapshots := make(map[string]bool) + objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false) + + for object := range objectCh { + if object.Err != nil { + return fmt.Errorf("listing remote snapshots: %w", object.Err) + } + + // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/ + parts := strings.Split(object.Key, "/") + if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { + remoteSnapshots[parts[1]] = true + } + } + + log.Debug("Found remote snapshots", "count", len(remoteSnapshots)) + + // Get all local snapshots (use a high limit to get all) + localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000) + if err != nil { + return fmt.Errorf("listing local snapshots: %w", err) + } + + // Remove local snapshots that don't exist remotely + removedCount := 0 + for _, snapshot := range localSnapshots { + if !remoteSnapshots[snapshot.ID] { + log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID) + if err := 
v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil { + log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err) + } else { + removedCount++ + } + } + } + + if removedCount > 0 { + log.Info("Removed local snapshots not found in remote", "count", removedCount) + } + + return nil +} diff --git a/internal/vaultik/vaultik.go b/internal/vaultik/vaultik.go new file mode 100644 index 0000000..ae5f410 --- /dev/null +++ b/internal/vaultik/vaultik.go @@ -0,0 +1,124 @@ +package vaultik + +import ( + "context" + "fmt" + "io" + "os" + + "git.eeqj.de/sneak/vaultik/internal/config" + "git.eeqj.de/sneak/vaultik/internal/crypto" + "git.eeqj.de/sneak/vaultik/internal/database" + "git.eeqj.de/sneak/vaultik/internal/globals" + "git.eeqj.de/sneak/vaultik/internal/s3" + "git.eeqj.de/sneak/vaultik/internal/snapshot" + "github.com/spf13/afero" + "go.uber.org/fx" +) + +// Vaultik contains all dependencies needed for vaultik operations +type Vaultik struct { + Globals *globals.Globals + Config *config.Config + DB *database.DB + Repositories *database.Repositories + S3Client *s3.Client + ScannerFactory snapshot.ScannerFactory + SnapshotManager *snapshot.SnapshotManager + Shutdowner fx.Shutdowner + Fs afero.Fs + + // Context management + ctx context.Context + cancel context.CancelFunc + + // IO + Stdout io.Writer + Stderr io.Writer + Stdin io.Reader +} + +// VaultikParams contains all parameters for New that can be provided by fx +type VaultikParams struct { + fx.In + + Globals *globals.Globals + Config *config.Config + DB *database.DB + Repositories *database.Repositories + S3Client *s3.Client + ScannerFactory snapshot.ScannerFactory + SnapshotManager *snapshot.SnapshotManager + Shutdowner fx.Shutdowner + Fs afero.Fs `optional:"true"` +} + +// New creates a new Vaultik instance with proper context management +// It automatically includes crypto capabilities if age_secret_key is configured +func New(params VaultikParams) *Vaultik { + ctx, cancel := context.WithCancel(context.Background()) + + // Use provided filesystem or default to OS filesystem + fs := params.Fs + if fs == nil { + fs = afero.NewOsFs() + } + + // Set filesystem on SnapshotManager + params.SnapshotManager.SetFilesystem(fs) + + return &Vaultik{ + Globals: params.Globals, + Config: params.Config, + DB: params.DB, + Repositories: params.Repositories, + S3Client: params.S3Client, + ScannerFactory: params.ScannerFactory, + SnapshotManager: params.SnapshotManager, + Shutdowner: params.Shutdowner, + Fs: fs, + ctx: ctx, + cancel: cancel, + Stdout: os.Stdout, + Stderr: os.Stderr, + Stdin: os.Stdin, + } +} + +// Context returns the Vaultik's context +func (v *Vaultik) Context() context.Context { + return v.ctx +} + +// Cancel cancels the Vaultik's context +func (v *Vaultik) Cancel() { + v.cancel() +} + +// CanDecrypt returns true if this Vaultik instance has decryption capabilities +func (v *Vaultik) CanDecrypt() bool { + return v.Config.AgeSecretKey != "" +} + +// GetEncryptor creates a new Encryptor instance based on the configured age recipients +// Returns an error if no recipients are configured +func (v *Vaultik) GetEncryptor() (*crypto.Encryptor, error) { + if len(v.Config.AgeRecipients) == 0 { + return nil, fmt.Errorf("no age recipients configured") + } + return crypto.NewEncryptor(v.Config.AgeRecipients) +} + +// GetDecryptor creates a new Decryptor instance based on the configured age secret key +// Returns an error if no secret key is configured +func (v *Vaultik) GetDecryptor() (*crypto.Decryptor, error) { 
+ if v.Config.AgeSecretKey == "" { + return nil, fmt.Errorf("no age secret key configured") + } + return crypto.NewDecryptor(v.Config.AgeSecretKey) +} + +// GetFilesystem returns the filesystem instance used by Vaultik +func (v *Vaultik) GetFilesystem() afero.Fs { + return v.Fs +} diff --git a/internal/vaultik/verify.go b/internal/vaultik/verify.go new file mode 100644 index 0000000..bb32054 --- /dev/null +++ b/internal/vaultik/verify.go @@ -0,0 +1,396 @@ +package vaultik + +import ( + "crypto/sha256" + "database/sql" + "encoding/hex" + "fmt" + "io" + "os" + + "git.eeqj.de/sneak/vaultik/internal/log" + "git.eeqj.de/sneak/vaultik/internal/snapshot" + "github.com/dustin/go-humanize" + "github.com/klauspost/compress/zstd" + _ "github.com/mattn/go-sqlite3" +) + +// VerifyOptions contains options for the verify command +type VerifyOptions struct { + Deep bool +} + +// RunDeepVerify executes deep verification operation +func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { + // Check for decryption capability + if !v.CanDecrypt() { + return fmt.Errorf("age_secret_key missing from config - required for deep verification") + } + + log.Info("Starting snapshot verification", + "snapshot_id", snapshotID, + "mode", map[bool]string{true: "deep", false: "shallow"}[opts.Deep], + ) + + // Step 1: Download manifest + manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) + log.Info("Downloading manifest", "path", manifestPath) + + manifestReader, err := v.S3Client.GetObject(v.ctx, manifestPath) + if err != nil { + return fmt.Errorf("failed to download manifest: %w", err) + } + defer func() { _ = manifestReader.Close() }() + + // Decompress manifest + manifest, err := snapshot.DecodeManifest(manifestReader) + if err != nil { + return fmt.Errorf("failed to decode manifest: %w", err) + } + + log.Info("Manifest loaded", + "blob_count", manifest.BlobCount, + "total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize)), + ) + + // Step 2: Download and decrypt database + dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) + log.Info("Downloading encrypted database", "path", dbPath) + + dbReader, err := v.S3Client.GetObject(v.ctx, dbPath) + if err != nil { + return fmt.Errorf("failed to download database: %w", err) + } + defer func() { _ = dbReader.Close() }() + + // Decrypt and decompress database + tempDB, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey) + if err != nil { + return fmt.Errorf("failed to decrypt database: %w", err) + } + defer func() { + if tempDB != nil { + _ = tempDB.Close() + } + }() + + // Step 3: Compare blob lists + if err := v.verifyBlobLists(snapshotID, manifest, tempDB.DB); err != nil { + return err + } + + // Step 4: Verify blob existence + if err := v.verifyBlobExistence(manifest); err != nil { + return err + } + + // Step 5: Deep verification if requested + if opts.Deep { + if err := v.performDeepVerification(manifest, tempDB.DB); err != nil { + return err + } + } + + log.Info("✓ Verification completed successfully", + "snapshot_id", snapshotID, + "mode", map[bool]string{true: "deep", false: "shallow"}[opts.Deep], + ) + + return nil +} + +// tempDB wraps sql.DB with cleanup +type tempDB struct { + *sql.DB + tempPath string +} + +func (t *tempDB) Close() error { + err := t.DB.Close() + _ = os.Remove(t.tempPath) + return err +} + +// decryptAndLoadDatabase decrypts and loads the database from the encrypted stream +func (v *Vaultik) decryptAndLoadDatabase(reader io.ReadCloser, secretKey string) (*tempDB, error) { + 
// Get decryptor + decryptor, err := v.GetDecryptor() + if err != nil { + return nil, fmt.Errorf("failed to get decryptor: %w", err) + } + + // Decrypt the stream + decryptedReader, err := decryptor.DecryptStream(reader) + if err != nil { + return nil, fmt.Errorf("failed to decrypt database: %w", err) + } + + // Decompress the database + decompressor, err := zstd.NewReader(decryptedReader) + if err != nil { + return nil, fmt.Errorf("failed to create decompressor: %w", err) + } + defer decompressor.Close() + + // Create temporary file for database + tempFile, err := os.CreateTemp("", "vaultik-verify-*.db") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + tempPath := tempFile.Name() + + // Copy decompressed data to temp file + if _, err := io.Copy(tempFile, decompressor); err != nil { + _ = tempFile.Close() + _ = os.Remove(tempPath) + return nil, fmt.Errorf("failed to write database: %w", err) + } + + // Close temp file before opening with sqlite + if err := tempFile.Close(); err != nil { + _ = os.Remove(tempPath) + return nil, fmt.Errorf("failed to close temp file: %w", err) + } + + // Open the database + db, err := sql.Open("sqlite3", tempPath) + if err != nil { + _ = os.Remove(tempPath) + return nil, fmt.Errorf("failed to open database: %w", err) + } + + return &tempDB{ + DB: db, + tempPath: tempPath, + }, nil +} + +// verifyBlobLists compares the blob lists between manifest and database +func (v *Vaultik) verifyBlobLists(snapshotID string, manifest *snapshot.Manifest, db *sql.DB) error { + log.Info("Verifying blob lists match between manifest and database") + + // Get blobs from database + query := ` + SELECT b.blob_hash, b.compressed_size + FROM snapshot_blobs sb + JOIN blobs b ON sb.blob_hash = b.blob_hash + WHERE sb.snapshot_id = ? 
+ ORDER BY b.blob_hash + ` + rows, err := db.QueryContext(v.ctx, query, snapshotID) + if err != nil { + return fmt.Errorf("failed to query snapshot blobs: %w", err) + } + defer func() { _ = rows.Close() }() + + // Build map of database blobs + dbBlobs := make(map[string]int64) + for rows.Next() { + var hash string + var size int64 + if err := rows.Scan(&hash, &size); err != nil { + return fmt.Errorf("failed to scan blob row: %w", err) + } + dbBlobs[hash] = size + } + + // Build map of manifest blobs + manifestBlobs := make(map[string]int64) + for _, blob := range manifest.Blobs { + manifestBlobs[blob.Hash] = blob.CompressedSize + } + + // Compare counts + if len(dbBlobs) != len(manifestBlobs) { + return fmt.Errorf("blob count mismatch: database has %d blobs, manifest has %d blobs", + len(dbBlobs), len(manifestBlobs)) + } + + // Check each blob exists in both + for hash, dbSize := range dbBlobs { + manifestSize, exists := manifestBlobs[hash] + if !exists { + return fmt.Errorf("blob %s exists in database but not in manifest", hash) + } + if dbSize != manifestSize { + return fmt.Errorf("blob %s size mismatch: database has %d bytes, manifest has %d bytes", + hash, dbSize, manifestSize) + } + } + + for hash := range manifestBlobs { + if _, exists := dbBlobs[hash]; !exists { + return fmt.Errorf("blob %s exists in manifest but not in database", hash) + } + } + + log.Info("✓ Blob lists match", "blob_count", len(dbBlobs)) + return nil +} + +// verifyBlobExistence checks that all blobs exist in S3 +func (v *Vaultik) verifyBlobExistence(manifest *snapshot.Manifest) error { + log.Info("Verifying blob existence in S3", "blob_count", len(manifest.Blobs)) + + for i, blob := range manifest.Blobs { + // Construct blob path + blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash) + + // Check blob exists with HeadObject + stat, err := v.S3Client.StatObject(v.ctx, blobPath) + if err != nil { + return fmt.Errorf("blob %s missing from S3: %w", blob.Hash, err) + } + + // Verify size matches + if stat.Size != blob.CompressedSize { + return fmt.Errorf("blob %s size mismatch: S3 has %d bytes, manifest has %d bytes", + blob.Hash, stat.Size, blob.CompressedSize) + } + + // Progress update every 100 blobs + if (i+1)%100 == 0 || i == len(manifest.Blobs)-1 { + log.Info("Blob existence check progress", + "checked", i+1, + "total", len(manifest.Blobs), + "percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(len(manifest.Blobs))*100), + ) + } + } + + log.Info("✓ All blobs exist in S3") + return nil +} + +// performDeepVerification downloads and verifies the content of each blob +func (v *Vaultik) performDeepVerification(manifest *snapshot.Manifest, db *sql.DB) error { + log.Info("Starting deep verification - downloading and verifying all blobs") + + totalBytes := int64(0) + for i, blobInfo := range manifest.Blobs { + // Verify individual blob + if err := v.verifyBlob(blobInfo, db); err != nil { + return fmt.Errorf("blob %s verification failed: %w", blobInfo.Hash, err) + } + + totalBytes += blobInfo.CompressedSize + + // Progress update + log.Info("Deep verification progress", + "blob", fmt.Sprintf("%d/%d", i+1, len(manifest.Blobs)), + "total_downloaded", humanize.Bytes(uint64(totalBytes)), + "percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(len(manifest.Blobs))*100), + ) + } + + log.Info("✓ Deep verification completed successfully", + "blobs_verified", len(manifest.Blobs), + "total_size", humanize.Bytes(uint64(totalBytes)), + ) + + return nil +} + +// verifyBlob downloads and verifies a 
single blob +func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { + // Download blob + blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobInfo.Hash[:2], blobInfo.Hash[2:4], blobInfo.Hash) + reader, err := v.S3Client.GetObject(v.ctx, blobPath) + if err != nil { + return fmt.Errorf("failed to download: %w", err) + } + defer func() { _ = reader.Close() }() + + // Get decryptor + decryptor, err := v.GetDecryptor() + if err != nil { + return fmt.Errorf("failed to get decryptor: %w", err) + } + + // Decrypt blob + decryptedReader, err := decryptor.DecryptStream(reader) + if err != nil { + return fmt.Errorf("failed to decrypt: %w", err) + } + + // Decompress blob + decompressor, err := zstd.NewReader(decryptedReader) + if err != nil { + return fmt.Errorf("failed to decompress: %w", err) + } + defer decompressor.Close() + + // Query blob chunks from database to get offsets and lengths + query := ` + SELECT bc.chunk_hash, bc.offset, bc.length + FROM blob_chunks bc + JOIN blobs b ON bc.blob_id = b.id + WHERE b.blob_hash = ? + ORDER BY bc.offset + ` + rows, err := db.QueryContext(v.ctx, query, blobInfo.Hash) + if err != nil { + return fmt.Errorf("failed to query blob chunks: %w", err) + } + defer func() { _ = rows.Close() }() + + var lastOffset int64 = -1 + chunkCount := 0 + totalRead := int64(0) + + // Verify each chunk in the blob + for rows.Next() { + var chunkHash string + var offset, length int64 + if err := rows.Scan(&chunkHash, &offset, &length); err != nil { + return fmt.Errorf("failed to scan chunk row: %w", err) + } + + // Verify chunk ordering + if offset <= lastOffset { + return fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) + } + lastOffset = offset + + // Read chunk data from decompressed stream + if offset > totalRead { + // Skip to the correct offset + skipBytes := offset - totalRead + if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil { + return fmt.Errorf("failed to skip to offset %d: %w", offset, err) + } + totalRead = offset + } + + // Read chunk data + chunkData := make([]byte, length) + if _, err := io.ReadFull(decompressor, chunkData); err != nil { + return fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) + } + totalRead += length + + // Verify chunk hash + hasher := sha256.New() + hasher.Write(chunkData) + calculatedHash := hex.EncodeToString(hasher.Sum(nil)) + + if calculatedHash != chunkHash { + return fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", + offset, calculatedHash, chunkHash) + } + + chunkCount++ + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating blob chunks: %w", err) + } + + log.Debug("Blob verified", + "hash", blobInfo.Hash, + "chunks", chunkCount, + "size", humanize.Bytes(uint64(blobInfo.CompressedSize)), + ) + + return nil +}
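Note on blob key layout: prune.go and verify.go both build the object key inline as fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash). Below is a minimal sketch of that scheme pulled out as a standalone helper, assuming hashes are lowercase hex digests at least four characters long; blobKey is a hypothetical name used only for illustration and does not appear in the diff.

    package main

    import "fmt"

    // blobKey derives the S3 object key for a blob from its content hash,
    // mirroring the two-level fan-out (blobs/ab/cd/abcd...) used in prune.go
    // and verify.go so that no single listing prefix holds every blob.
    func blobKey(hash string) string {
        return fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash)
    }

    func main() {
        fmt.Println(blobKey("abcdef0123456789")) // blobs/ab/cd/abcdef0123456789
    }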
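Note on snapshot ID parsing: parseSnapshotTimestamp in helpers.go splits the ID on "-" and reads only the last two segments, so hostnames that themselves contain dashes still parse. A hedged round-trip sketch under the hostname-YYYYMMDD-HHMMSSZ format described in the comments; makeSnapshotID is hypothetical and not part of the diff.

    package main

    import (
        "fmt"
        "time"
    )

    // makeSnapshotID builds an ID in the hostname-YYYYMMDD-HHMMSSZ form that
    // parseSnapshotTimestamp expects (illustrative only).
    func makeSnapshotID(hostname string, t time.Time) string {
        return hostname + "-" + t.UTC().Format("20060102-150405") + "Z"
    }

    func main() {
        id := makeSnapshotID("db-01.example", time.Date(2024, 1, 15, 14, 30, 52, 0, time.UTC))
        fmt.Println(id) // db-01.example-20240115-143052Z
        // Splitting on "-" gives [db 01.example 20240115 143052Z]; the last two
        // segments carry the date and time regardless of dashes in the hostname.
    }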
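Note on duration parsing: parseDuration in helpers.go accepts a "d" suffix for days and defers everything else to time.ParseDuration, which has no day unit. A short usage sketch that assumes it lives inside package vaultik alongside helpers.go (for example as a test); the test name is hypothetical.

    package vaultik

    import "testing"

    func TestParseDurationDays(t *testing.T) {
        for _, s := range []string{"90s", "45m", "12h", "30d"} {
            d, err := parseDuration(s)
            if err != nil {
                t.Fatalf("parseDuration(%q): %v", s, err)
            }
            // "30d" -> 720h0m0s via the custom branch; the others match
            // time.ParseDuration exactly.
            t.Logf("%s -> %v", s, d)
        }
    }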