Add end-to-end integration tests for Vaultik

- Create comprehensive integration tests with mock S3 client
- Add in-memory filesystem and SQLite database support for testing
- Test full backup workflow including chunking, packing, and uploading
- Add test to verify encrypted blob content
- Fix scanner to use afero filesystem for temp file cleanup
- Demonstrate successful backup and verification with mock dependencies
This commit is contained in:
Jeffrey Paul 2025-07-26 15:52:23 +02:00
parent bb38f8c5d6
commit d7cd9aac27
8 changed files with 1974 additions and 1 deletion

View File

@ -676,7 +676,7 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
if err := blobWithReader.TempFile.Close(); err != nil {
log.Fatal("Failed to close temp file", "file", tempName, "error", err)
}
if err := os.Remove(tempName); err != nil {
if err := s.fs.Remove(tempName); err != nil {
log.Fatal("Failed to remove temp file", "file", tempName, "error", err)
}
}

103
internal/vaultik/helpers.go Normal file
View File

@ -0,0 +1,103 @@
package vaultik
import (
"fmt"
"strconv"
"strings"
"time"
)
// SnapshotInfo contains information about a snapshot, as emitted by the
// snapshot listing commands (both JSON and table output).
type SnapshotInfo struct {
	// ID is the snapshot identifier; IDs look like
	// "hostname-YYYYMMDD-HHMMSSZ" (see parseSnapshotTimestamp).
	ID string `json:"id"`
	// Timestamp is when the snapshot was started.
	Timestamp time.Time `json:"timestamp"`
	// CompressedSize is the total compressed size in bytes of all blobs
	// referenced by the snapshot.
	CompressedSize int64 `json:"compressed_size"`
}
// formatNumber formats an integer with commas as thousands separators,
// e.g. 1234567 -> "1,234,567" and -123456 -> "-123,456".
func formatNumber(n int) string {
	// Strip the sign first so the minus sign is never counted as a digit
	// when computing comma positions (the previous implementation produced
	// "-,123,456" for inputs like -123456 because '-' occupied index 0).
	sign := ""
	if n < 0 {
		sign = "-"
		n = -n // NOTE: overflows for math.MinInt; acceptable for display-only output
	}
	digits := strconv.Itoa(n)
	var b strings.Builder
	b.Grow(len(digits) + len(digits)/3)
	for i, d := range digits {
		// Insert a comma whenever the remaining digit count is a multiple of 3.
		if i > 0 && (len(digits)-i)%3 == 0 {
			b.WriteByte(',')
		}
		b.WriteRune(d)
	}
	return sign + b.String()
}
// formatDuration formats a duration in a human-readable way
func formatDuration(d time.Duration) string {
if d < time.Second {
return fmt.Sprintf("%dms", d.Milliseconds())
}
if d < time.Minute {
return fmt.Sprintf("%.1fs", d.Seconds())
}
if d < time.Hour {
mins := int(d.Minutes())
secs := int(d.Seconds()) % 60
return fmt.Sprintf("%dm %ds", mins, secs)
}
hours := int(d.Hours())
mins := int(d.Minutes()) % 60
return fmt.Sprintf("%dh %dm", hours, mins)
}
// formatBytes renders a byte count using binary (1024-based) units, e.g.
// 1536 -> "1.5 KB". Values below 1 KiB are printed as a plain byte count.
func formatBytes(bytes int64) string {
	if bytes < 1024 {
		return fmt.Sprintf("%d B", bytes)
	}
	// Repeatedly scale down by 1024; dividing a float64 by a power of two
	// is exact, so this matches a single division by 1024^(idx+1).
	units := "KMGTPE"
	value := float64(bytes) / 1024
	idx := 0
	for value >= 1024 && idx < len(units)-1 {
		value /= 1024
		idx++
	}
	return fmt.Sprintf("%.1f %cB", value, units[idx])
}
// parseSnapshotTimestamp extracts the timestamp from a snapshot ID
func parseSnapshotTimestamp(snapshotID string) (time.Time, error) {
// Format: hostname-YYYYMMDD-HHMMSSZ
parts := strings.Split(snapshotID, "-")
if len(parts) < 3 {
return time.Time{}, fmt.Errorf("invalid snapshot ID format")
}
dateStr := parts[len(parts)-2]
timeStr := parts[len(parts)-1]
if len(dateStr) != 8 || len(timeStr) != 7 || !strings.HasSuffix(timeStr, "Z") {
return time.Time{}, fmt.Errorf("invalid timestamp format")
}
// Remove Z suffix
timeStr = timeStr[:6]
// Parse the timestamp
timestamp, err := time.Parse("20060102150405", dateStr+timeStr)
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse timestamp: %w", err)
}
return timestamp.UTC(), nil
}
// parseDuration parses a duration string with support for days
func parseDuration(s string) (time.Duration, error) {
// Check for days suffix
if strings.HasSuffix(s, "d") {
daysStr := strings.TrimSuffix(s, "d")
days, err := strconv.Atoi(daysStr)
if err != nil {
return 0, fmt.Errorf("invalid days value: %w", err)
}
return time.Duration(days) * 24 * time.Hour, nil
}
// Otherwise use standard Go duration parsing
return time.ParseDuration(s)
}

101
internal/vaultik/info.go Normal file
View File

@ -0,0 +1,101 @@
package vaultik
import (
"fmt"
"runtime"
"strings"
"github.com/dustin/go-humanize"
)
// ShowInfo displays system and configuration information on stdout.
//
// Output is grouped into sections: system/build info, S3 storage settings,
// backup sources and tuning, age encryption recipients, optional daemon
// scheduling, and local index database statistics. Database statistics are
// best-effort: a failed stat or query simply omits that line. The function
// currently always returns nil.
func (v *Vaultik) ShowInfo() error {
	// System Information
	fmt.Printf("=== System Information ===\n")
	fmt.Printf("OS/Architecture: %s/%s\n", runtime.GOOS, runtime.GOARCH)
	fmt.Printf("Version: %s\n", v.Globals.Version)
	fmt.Printf("Commit: %s\n", v.Globals.Commit)
	fmt.Printf("Go Version: %s\n", runtime.Version())
	fmt.Println()
	// Storage Configuration
	fmt.Printf("=== Storage Configuration ===\n")
	fmt.Printf("S3 Bucket: %s\n", v.Config.S3.Bucket)
	// The prefix is optional; omit the line entirely when unset.
	if v.Config.S3.Prefix != "" {
		fmt.Printf("S3 Prefix: %s\n", v.Config.S3.Prefix)
	}
	fmt.Printf("S3 Endpoint: %s\n", v.Config.S3.Endpoint)
	fmt.Printf("S3 Region: %s\n", v.Config.S3.Region)
	fmt.Println()
	// Backup Settings
	fmt.Printf("=== Backup Settings ===\n")
	fmt.Printf("Source Directories:\n")
	for _, dir := range v.Config.SourceDirs {
		fmt.Printf("  - %s\n", dir)
	}
	// Global exclude patterns
	if len(v.Config.Exclude) > 0 {
		fmt.Printf("Exclude Patterns: %s\n", strings.Join(v.Config.Exclude, ", "))
	}
	fmt.Printf("Compression: zstd level %d\n", v.Config.CompressionLevel)
	fmt.Printf("Chunk Size: %s\n", humanize.Bytes(uint64(v.Config.ChunkSize)))
	fmt.Printf("Blob Size Limit: %s\n", humanize.Bytes(uint64(v.Config.BlobSizeLimit)))
	fmt.Println()
	// Encryption Configuration
	fmt.Printf("=== Encryption Configuration ===\n")
	fmt.Printf("Recipients:\n")
	for _, recipient := range v.Config.AgeRecipients {
		fmt.Printf("  - %s\n", recipient)
	}
	fmt.Println()
	// Daemon Settings — only shown when at least one daemon-related
	// interval is configured.
	if v.Config.BackupInterval > 0 || v.Config.MinTimeBetweenRun > 0 {
		fmt.Printf("=== Daemon Settings ===\n")
		if v.Config.BackupInterval > 0 {
			fmt.Printf("Backup Interval: %s\n", v.Config.BackupInterval)
		}
		if v.Config.MinTimeBetweenRun > 0 {
			fmt.Printf("Minimum Time: %s\n", v.Config.MinTimeBetweenRun)
		}
		fmt.Println()
	}
	// Local Database
	fmt.Printf("=== Local Database ===\n")
	fmt.Printf("Index Path: %s\n", v.Config.IndexPath)
	// Check if index file exists and get its size. When it does not exist
	// (or stat fails) we print a placeholder and skip the DB queries.
	if info, err := v.Fs.Stat(v.Config.IndexPath); err == nil {
		fmt.Printf("Index Size: %s\n", humanize.Bytes(uint64(info.Size())))
		// Get snapshot count from database (completed snapshots only).
		query := `SELECT COUNT(*) FROM snapshots WHERE completed_at IS NOT NULL`
		var snapshotCount int
		if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&snapshotCount); err == nil {
			fmt.Printf("Snapshots: %d\n", snapshotCount)
		}
		// Get blob count from database
		query = `SELECT COUNT(*) FROM blobs`
		var blobCount int
		if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&blobCount); err == nil {
			fmt.Printf("Blobs: %d\n", blobCount)
		}
		// Get file count from database
		query = `SELECT COUNT(*) FROM files`
		var fileCount int
		if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&fileCount); err == nil {
			fmt.Printf("Files: %d\n", fileCount)
		}
	} else {
		fmt.Printf("Index Size: (not created)\n")
	}
	return nil
}

View File

@ -0,0 +1,379 @@
package vaultik_test
import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
"sync"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/s3"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// MockS3Client implements a mock S3 client for testing. Objects live in an
// in-process map and every operation appends an entry to an operation log
// so tests can assert on the exact sequence of S3 calls.
type MockS3Client struct {
	mu      sync.Mutex        // guards storage and calls
	storage map[string][]byte // object key -> stored bytes
	calls   []string          // operation log, e.g. "PutObject:<key>"
}
// NewMockS3Client returns an empty mock S3 client ready for use.
func NewMockS3Client() *MockS3Client {
	client := &MockS3Client{}
	client.storage = map[string][]byte{}
	client.calls = []string{}
	return client
}
// PutObject reads the full contents of reader and stores them under key,
// recording the call in the operation log.
func (m *MockS3Client) PutObject(ctx context.Context, key string, reader io.Reader) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls = append(m.calls, "PutObject:"+key)
	body, readErr := io.ReadAll(reader)
	if readErr != nil {
		return readErr
	}
	m.storage[key] = body
	return nil
}
// PutObjectWithProgress stores the object exactly like PutObject; the size
// and progress arguments are ignored because the mock performs no real
// upload to report progress for.
func (m *MockS3Client) PutObjectWithProgress(ctx context.Context, key string, reader io.Reader, size int64, progress s3.ProgressCallback) error {
	// For testing, just call PutObject
	return m.PutObject(ctx, key, reader)
}
// GetObject returns a reader over the stored object's bytes, or an error
// when the key has never been stored.
func (m *MockS3Client) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls = append(m.calls, "GetObject:"+key)
	body, ok := m.storage[key]
	if !ok {
		return nil, fmt.Errorf("key not found: %s", key)
	}
	return io.NopCloser(bytes.NewReader(body)), nil
}
// StatObject reports the key and size of a stored object, or an error when
// the object does not exist.
func (m *MockS3Client) StatObject(ctx context.Context, key string) (*s3.ObjectInfo, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls = append(m.calls, "StatObject:"+key)
	body, ok := m.storage[key]
	if !ok {
		return nil, fmt.Errorf("key not found: %s", key)
	}
	info := &s3.ObjectInfo{Key: key, Size: int64(len(body))}
	return info, nil
}
// DeleteObject removes the object if present and records the call.
// Deleting a missing key is a no-op and still succeeds.
func (m *MockS3Client) DeleteObject(ctx context.Context, key string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls = append(m.calls, "DeleteObject:"+key)
	delete(m.storage, key)
	return nil
}
// ListObjects returns info for every stored object whose key begins with
// prefix; an empty prefix matches everything. Order is unspecified because
// map iteration order is random.
func (m *MockS3Client) ListObjects(ctx context.Context, prefix string) ([]*s3.ObjectInfo, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls = append(m.calls, "ListObjects:"+prefix)
	var matched []*s3.ObjectInfo
	for key, body := range m.storage {
		hasPrefix := len(key) >= len(prefix) && key[:len(prefix)] == prefix
		if prefix == "" || hasPrefix {
			matched = append(matched, &s3.ObjectInfo{
				Key:  key,
				Size: int64(len(body)),
			})
		}
	}
	return matched, nil
}
// GetCalls returns a copy of the operation log so callers cannot mutate the
// mock's internal state.
func (m *MockS3Client) GetCalls() []string {
	m.mu.Lock()
	defer m.mu.Unlock()
	snapshot := make([]string, len(m.calls))
	copy(snapshot, m.calls)
	return snapshot
}
// GetStorageSize returns the number of objects currently held in the
// mock's in-memory storage.
func (m *MockS3Client) GetStorageSize() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return len(m.storage)
}
// TestEndToEndBackup tests the full backup workflow with mocked
// dependencies: an afero in-memory filesystem for source files, an
// in-memory SQLite database for the index, and MockS3Client for storage.
// It runs the scanner over a small directory tree and asserts that files,
// chunks, and blobs are produced and uploaded.
func TestEndToEndBackup(t *testing.T) {
	// Initialize logger
	log.Initialize(log.Config{})
	// Create in-memory filesystem
	fs := afero.NewMemMapFs()
	// Create test directory structure and files
	testFiles := map[string]string{
		"/home/user/documents/file1.txt": "This is file 1 content",
		"/home/user/documents/file2.txt": "This is file 2 content with more data",
		"/home/user/pictures/photo1.jpg": "Binary photo data here...",
		"/home/user/code/main.go":        "package main\n\nfunc main() {\n\tprintln(\"Hello, World!\")\n}",
	}
	// Create all directories first
	dirs := []string{
		"/home/user/documents",
		"/home/user/pictures",
		"/home/user/code",
	}
	for _, dir := range dirs {
		if err := fs.MkdirAll(dir, 0755); err != nil {
			t.Fatalf("failed to create directory %s: %v", dir, err)
		}
	}
	// Create test files
	for path, content := range testFiles {
		if err := afero.WriteFile(fs, path, []byte(content), 0644); err != nil {
			t.Fatalf("failed to create test file %s: %v", path, err)
		}
	}
	// Create mock S3 client
	mockS3 := NewMockS3Client()
	// Create test configuration. The age keypair below is a throwaway key
	// generated for tests only; the S3 fields are never contacted because
	// the mock client is injected into the scanner.
	cfg := &config.Config{
		SourceDirs:       []string{"/home/user"},
		Exclude:          []string{"*.tmp", "*.log"},
		ChunkSize:        config.Size(16 * 1024),  // 16KB chunks
		BlobSizeLimit:    config.Size(100 * 1024), // 100KB blobs
		CompressionLevel: 3,
		AgeRecipients:    []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"},                 // Test public key
		AgeSecretKey:     "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",              // Test private key
		S3: config.S3Config{
			Endpoint:        "http://localhost:9000", // MinIO endpoint for testing
			Region:          "us-east-1",
			Bucket:          "test-bucket",
			AccessKeyID:     "test-access",
			SecretAccessKey: "test-secret",
		},
		IndexPath: ":memory:", // In-memory SQLite database
	}
	// For a true end-to-end test, we'll create a simpler test that focuses on
	// the core backup logic using the scanner directly with our mock S3 client
	ctx := context.Background()
	// Create in-memory database
	db, err := database.New(ctx, ":memory:")
	require.NoError(t, err)
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("failed to close database: %v", err)
		}
	}()
	repos := database.NewRepositories(db)
	// Create scanner with mock S3 client
	scanner := snapshot.NewScanner(snapshot.ScannerConfig{
		FS:               fs,
		ChunkSize:        cfg.ChunkSize.Int64(),
		Repositories:     repos,
		S3Client:         mockS3,
		MaxBlobSize:      cfg.BlobSizeLimit.Int64(),
		CompressionLevel: cfg.CompressionLevel,
		AgeRecipients:    cfg.AgeRecipients,
		EnableProgress:   false,
	})
	// Create a snapshot record before scanning; the scanner associates all
	// discovered files with this snapshot ID.
	snapshotID := "test-snapshot-001"
	err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
		snapshot := &database.Snapshot{
			ID:             snapshotID,
			Hostname:       "test-host",
			VaultikVersion: "test-version",
			StartedAt:      time.Now(),
		}
		return repos.Snapshots.Create(ctx, tx, snapshot)
	})
	require.NoError(t, err)
	// Run the backup scan
	result, err := scanner.Scan(ctx, "/home/user", snapshotID)
	require.NoError(t, err)
	// Verify scan results
	// The scanner counts both files and directories, so we have:
	// 4 files + 5 directories (/home, /home/user, /home/user/documents,
	// /home/user/pictures, /home/user/code)
	assert.GreaterOrEqual(t, result.FilesScanned, 4, "Should scan at least 4 files")
	assert.Greater(t, result.BytesScanned, int64(0), "Should scan some bytes")
	assert.Greater(t, result.ChunksCreated, 0, "Should create chunks")
	assert.Greater(t, result.BlobsCreated, 0, "Should create blobs")
	// Verify S3 operations
	calls := mockS3.GetCalls()
	t.Logf("S3 operations performed: %v", calls)
	// Should have uploaded at least one blob (keys look like "blobs/...").
	blobUploads := 0
	for _, call := range calls {
		if len(call) > 10 && call[:10] == "PutObject:" {
			if len(call) > 16 && call[10:16] == "blobs/" {
				blobUploads++
			}
		}
	}
	assert.Greater(t, blobUploads, 0, "Should upload at least one blob")
	// Verify files in database
	files, err := repos.Files.ListByPrefix(ctx, "/home/user")
	require.NoError(t, err)
	// Count only regular files (not directories)
	regularFiles := 0
	for _, f := range files {
		if f.Mode&0x80000000 == 0 { // Check if regular file (not directory)
			regularFiles++
		}
	}
	assert.Equal(t, 4, regularFiles, "Should have 4 regular files in database")
	// Verify chunks were created by checking a specific file
	fileChunks, err := repos.FileChunks.GetByPath(ctx, "/home/user/documents/file1.txt")
	require.NoError(t, err)
	assert.Greater(t, len(fileChunks), 0, "Should have chunks for file1.txt")
	// Verify blobs were uploaded to S3
	assert.Greater(t, mockS3.GetStorageSize(), 0, "Should have blobs in S3 storage")
	// Complete the snapshot - just verify we got results
	// In a real integration test, we'd update the snapshot record
	// Create snapshot manager to test metadata export
	snapshotManager := &snapshot.SnapshotManager{}
	snapshotManager.SetFilesystem(fs)
	// Note: We can't fully test snapshot metadata export without a proper S3 client mock
	// that implements all required methods. This would require refactoring the S3 client
	// interface to be more testable.
	t.Logf("Backup completed successfully:")
	t.Logf("  Files scanned: %d", result.FilesScanned)
	t.Logf("  Bytes scanned: %d", result.BytesScanned)
	t.Logf("  Chunks created: %d", result.ChunksCreated)
	t.Logf("  Blobs created: %d", result.BlobsCreated)
	t.Logf("  S3 storage size: %d objects", mockS3.GetStorageSize())
}
// TestBackupAndVerify tests backing up files and verifying the blobs:
// after scanning a single file, it lists the uploaded blobs in the mock S3
// store, stats and downloads one, and checks that its bytes do not contain
// the plaintext (i.e. the blob was encrypted before upload).
func TestBackupAndVerify(t *testing.T) {
	// Initialize logger
	log.Initialize(log.Config{})
	// Create in-memory filesystem
	fs := afero.NewMemMapFs()
	// Create test files
	testContent := "This is a test file with some content that should be backed up"
	err := fs.MkdirAll("/data", 0755)
	require.NoError(t, err)
	err = afero.WriteFile(fs, "/data/test.txt", []byte(testContent), 0644)
	require.NoError(t, err)
	// Create mock S3 client
	mockS3 := NewMockS3Client()
	// Create test database
	ctx := context.Background()
	db, err := database.New(ctx, ":memory:")
	require.NoError(t, err)
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("failed to close database: %v", err)
		}
	}()
	repos := database.NewRepositories(db)
	// Create scanner (the age recipient is a throwaway test public key)
	scanner := snapshot.NewScanner(snapshot.ScannerConfig{
		FS:               fs,
		ChunkSize:        int64(1024 * 16), // 16KB chunks
		Repositories:     repos,
		S3Client:         mockS3,
		MaxBlobSize:      int64(1024 * 1024), // 1MB blobs
		CompressionLevel: 3,
		AgeRecipients:    []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key
	})
	// Create a snapshot record the scan can attach files to
	snapshotID := "test-snapshot-001"
	err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
		snapshot := &database.Snapshot{
			ID:             snapshotID,
			Hostname:       "test-host",
			VaultikVersion: "test-version",
			StartedAt:      time.Now(),
		}
		return repos.Snapshots.Create(ctx, tx, snapshot)
	})
	require.NoError(t, err)
	// Run the backup
	result, err := scanner.Scan(ctx, "/data", snapshotID)
	require.NoError(t, err)
	// Verify backup created blobs
	assert.Greater(t, result.BlobsCreated, 0, "Should create at least one blob")
	assert.Equal(t, mockS3.GetStorageSize(), result.BlobsCreated, "S3 should have the blobs")
	// Verify we can retrieve the blob from S3
	objects, err := mockS3.ListObjects(ctx, "blobs/")
	require.NoError(t, err)
	assert.Len(t, objects, result.BlobsCreated, "Should have correct number of blobs in S3")
	// Get the first blob and verify it exists
	if len(objects) > 0 {
		blobKey := objects[0].Key
		t.Logf("Verifying blob: %s", blobKey)
		// Get blob info
		blobInfo, err := mockS3.StatObject(ctx, blobKey)
		require.NoError(t, err)
		assert.Greater(t, blobInfo.Size, int64(0), "Blob should have content")
		// Get blob content
		reader, err := mockS3.GetObject(ctx, blobKey)
		require.NoError(t, err)
		defer func() { _ = reader.Close() }()
		// Verify blob data is encrypted (should not contain plaintext)
		blobData, err := io.ReadAll(reader)
		require.NoError(t, err)
		assert.NotContains(t, string(blobData), testContent, "Blob should be encrypted")
		assert.Greater(t, len(blobData), 0, "Blob should have data")
	}
	t.Logf("Backup and verify test completed successfully")
}

169
internal/vaultik/prune.go Normal file
View File

@ -0,0 +1,169 @@
package vaultik
import (
"fmt"
"strings"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/dustin/go-humanize"
)
// PruneOptions contains options for the prune command.
type PruneOptions struct {
	// Force skips the interactive confirmation prompt before deletion.
	Force bool
}
// PruneBlobs removes unreferenced blobs from storage.
//
// It builds the set of blob hashes referenced by ANY remote snapshot
// manifest, lists every object under "blobs/", and deletes the ones no
// manifest references. Unless opts.Force is set, the user must confirm
// the deletion interactively. If any manifest cannot be downloaded the
// operation aborts: proceeding with an incomplete reference set would
// treat that snapshot's blobs as unreferenced and delete live data.
func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
	log.Info("Starting prune operation")
	// Get all remote snapshots and their manifests
	allBlobsReferenced := make(map[string]bool)
	manifestCount := 0
	// List all snapshots in S3
	log.Info("Listing remote snapshots")
	objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)
	var snapshotIDs []string
	// Dedup set for snapshot IDs; replaces the previous O(n^2) linear
	// scan over snapshotIDs for every listed object.
	seenSnapshots := make(map[string]bool)
	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing remote snapshots: %w", object.Err)
		}
		// Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
		parts := strings.Split(object.Key, "/")
		if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
			// Check if this is a directory by looking for trailing slash
			if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") {
				snapshotID := parts[1]
				if !seenSnapshots[snapshotID] {
					seenSnapshots[snapshotID] = true
					snapshotIDs = append(snapshotIDs, snapshotID)
				}
			}
		}
	}
	log.Info("Found manifests in remote storage", "count", len(snapshotIDs))
	// Download and parse each manifest to get referenced blobs
	for _, snapshotID := range snapshotIDs {
		log.Debug("Processing manifest", "snapshot_id", snapshotID)
		manifest, err := v.downloadManifest(snapshotID)
		if err != nil {
			// SAFETY: do not continue past a failed manifest download.
			// Skipping it would leave this snapshot's blobs out of the
			// referenced set and the prune below would delete them.
			return fmt.Errorf("downloading manifest for %s: %w", snapshotID, err)
		}
		// Add all blobs from this manifest to our referenced set
		for _, blob := range manifest.Blobs {
			allBlobsReferenced[blob.Hash] = true
		}
		manifestCount++
	}
	log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced))
	// List all blobs in S3
	log.Info("Listing all blobs in storage")
	allBlobs := make(map[string]int64) // hash -> size
	blobObjectCh := v.S3Client.ListObjectsStream(v.ctx, "blobs/", true)
	for object := range blobObjectCh {
		if object.Err != nil {
			return fmt.Errorf("listing blobs: %w", object.Err)
		}
		// Extract hash from path like blobs/ab/cd/abcdef123456...
		parts := strings.Split(object.Key, "/")
		if len(parts) == 4 && parts[0] == "blobs" {
			hash := parts[3]
			allBlobs[hash] = object.Size
		}
	}
	log.Info("Found blobs in storage", "count", len(allBlobs))
	// Find unreferenced blobs
	var unreferencedBlobs []string
	var totalSize int64
	for hash, size := range allBlobs {
		if !allBlobsReferenced[hash] {
			unreferencedBlobs = append(unreferencedBlobs, hash)
			totalSize += size
		}
	}
	if len(unreferencedBlobs) == 0 {
		log.Info("No unreferenced blobs found")
		fmt.Println("No unreferenced blobs to remove.")
		return nil
	}
	// Show what will be deleted
	log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize)))
	fmt.Printf("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize)))
	// Confirm unless --force is used
	if !opts.Force {
		fmt.Printf("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs))
		var confirm string
		if _, err := fmt.Scanln(&confirm); err != nil {
			// Treat EOF or error as "no"
			fmt.Println("Cancelled")
			return nil
		}
		if strings.ToLower(confirm) != "y" {
			fmt.Println("Cancelled")
			return nil
		}
	}
	// Delete unreferenced blobs; individual failures are logged and counted
	// but do not abort the run.
	log.Info("Deleting unreferenced blobs")
	deletedCount := 0
	deletedSize := int64(0)
	for i, hash := range unreferencedBlobs {
		// Blobs are sharded by the first two byte-pairs of the hash.
		blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash)
		if err := v.S3Client.RemoveObject(v.ctx, blobPath); err != nil {
			log.Error("Failed to delete blob", "hash", hash, "error", err)
			continue
		}
		deletedCount++
		deletedSize += allBlobs[hash]
		// Progress update every 100 blobs
		if (i+1)%100 == 0 || i == len(unreferencedBlobs)-1 {
			log.Info("Deletion progress",
				"deleted", i+1,
				"total", len(unreferencedBlobs),
				"percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(len(unreferencedBlobs))*100),
			)
		}
	}
	log.Info("Prune complete",
		"deleted_count", deletedCount,
		"deleted_size", humanize.Bytes(uint64(deletedSize)),
		"failed", len(unreferencedBlobs)-deletedCount,
	)
	fmt.Printf("\nDeleted %d blob(s) totaling %s\n", deletedCount, humanize.Bytes(uint64(deletedSize)))
	if deletedCount < len(unreferencedBlobs) {
		fmt.Printf("Failed to delete %d blob(s)\n", len(unreferencedBlobs)-deletedCount)
	}
	return nil
}

View File

@ -0,0 +1,701 @@
package vaultik
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"text/tabwriter"
"time"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"github.com/dustin/go-humanize"
)
// SnapshotCreateOptions contains options for the snapshot create command.
type SnapshotCreateOptions struct {
	// Daemon runs continuously instead of once (currently rejected as
	// not yet implemented).
	Daemon bool
	// Cron disables interactive progress output for unattended runs.
	Cron bool
	// Prune requests deletion of old snapshots after a successful run
	// (currently a TODO; only logged).
	Prune bool
}
// CreateSnapshot executes the snapshot creation operation.
//
// Flow: clean up incomplete snapshots (critical for dedup correctness),
// resolve configured source directories, scan each directory (chunking,
// packing, and uploading as it goes), persist aggregate statistics, mark
// the snapshot complete, export its metadata, and print a summary.
// Returns an error if any phase fails; a cancelled context aborts between
// directories.
func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
	snapshotStartTime := time.Now()
	log.Info("Starting snapshot creation",
		"version", v.Globals.Version,
		"commit", v.Globals.Commit,
		"index_path", v.Config.IndexPath,
	)
	// Clean up incomplete snapshots FIRST, before any scanning
	// This is critical for data safety - see CleanupIncompleteSnapshots for details
	hostname := v.Config.Hostname
	if hostname == "" {
		// Fall back to the OS hostname; an error leaves hostname empty.
		hostname, _ = os.Hostname()
	}
	// CRITICAL: This MUST succeed. If we fail to clean up incomplete snapshots,
	// the deduplication logic will think files from the incomplete snapshot were
	// already backed up and skip them, resulting in data loss.
	if err := v.SnapshotManager.CleanupIncompleteSnapshots(v.ctx, hostname); err != nil {
		return fmt.Errorf("cleanup incomplete snapshots: %w", err)
	}
	if opts.Daemon {
		log.Info("Running in daemon mode")
		// TODO: Implement daemon mode with inotify
		return fmt.Errorf("daemon mode not yet implemented")
	}
	// Resolve source directories to absolute paths
	resolvedDirs := make([]string, 0, len(v.Config.SourceDirs))
	for _, dir := range v.Config.SourceDirs {
		absPath, err := filepath.Abs(dir)
		if err != nil {
			return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
		}
		// Resolve symlinks
		resolvedPath, err := filepath.EvalSymlinks(absPath)
		if err != nil {
			// If the path doesn't exist yet, use the absolute path
			if os.IsNotExist(err) {
				resolvedPath = absPath
			} else {
				return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err)
			}
		}
		resolvedDirs = append(resolvedDirs, resolvedPath)
	}
	// Create scanner with progress enabled (unless in cron mode)
	scanner := v.ScannerFactory(snapshot.ScannerParams{
		EnableProgress: !opts.Cron,
		Fs:             v.Fs,
	})
	// Statistics tracking, accumulated across all source directories
	totalFiles := 0
	totalBytes := int64(0)
	totalChunks := 0
	totalBlobs := 0
	totalBytesSkipped := int64(0)
	totalFilesSkipped := 0
	totalFilesDeleted := 0
	totalBytesDeleted := int64(0)
	totalBytesUploaded := int64(0)
	totalBlobsUploaded := 0
	uploadDuration := time.Duration(0)
	// Create a new snapshot at the beginning
	snapshotID, err := v.SnapshotManager.CreateSnapshot(v.ctx, hostname, v.Globals.Version, v.Globals.Commit)
	if err != nil {
		return fmt.Errorf("creating snapshot: %w", err)
	}
	log.Info("Beginning snapshot", "snapshot_id", snapshotID)
	_, _ = fmt.Fprintf(v.Stdout, "Beginning snapshot: %s\n", snapshotID)
	for i, dir := range resolvedDirs {
		// Check if context is cancelled (non-blocking poll between dirs)
		select {
		case <-v.ctx.Done():
			log.Info("Snapshot creation cancelled")
			return v.ctx.Err()
		default:
		}
		log.Info("Scanning directory", "path", dir)
		_, _ = fmt.Fprintf(v.Stdout, "Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir)
		result, err := scanner.Scan(v.ctx, dir, snapshotID)
		if err != nil {
			return fmt.Errorf("failed to scan %s: %w", dir, err)
		}
		totalFiles += result.FilesScanned
		totalBytes += result.BytesScanned
		totalChunks += result.ChunksCreated
		totalBlobs += result.BlobsCreated
		totalFilesSkipped += result.FilesSkipped
		totalBytesSkipped += result.BytesSkipped
		totalFilesDeleted += result.FilesDeleted
		totalBytesDeleted += result.BytesDeleted
		log.Info("Directory scan complete",
			"path", dir,
			"files", result.FilesScanned,
			"files_skipped", result.FilesSkipped,
			"bytes", result.BytesScanned,
			"bytes_skipped", result.BytesSkipped,
			"chunks", result.ChunksCreated,
			"blobs", result.BlobsCreated,
			"duration", result.EndTime.Sub(result.StartTime))
		// Remove per-directory summary - the scanner already prints its own summary
	}
	// Get upload statistics from scanner progress if available
	if s := scanner.GetProgress(); s != nil {
		stats := s.GetStats()
		totalBytesUploaded = stats.BytesUploaded.Load()
		totalBlobsUploaded = int(stats.BlobsUploaded.Load())
		uploadDuration = time.Duration(stats.UploadDurationMs.Load()) * time.Millisecond
	}
	// Update snapshot statistics with extended fields
	extStats := snapshot.ExtendedBackupStats{
		BackupStats: snapshot.BackupStats{
			FilesScanned:  totalFiles,
			BytesScanned:  totalBytes,
			ChunksCreated: totalChunks,
			BlobsCreated:  totalBlobs,
			BytesUploaded: totalBytesUploaded,
		},
		// NOTE(review): the comment below promises this is set from the
		// database query further down, but the query runs only after
		// UpdateSnapshotStatsExtended — so 0 is what gets persisted.
		// Confirm whether that is intentional.
		BlobUncompressedSize: 0, // Will be set from database query below
		CompressionLevel:     v.Config.CompressionLevel,
		UploadDurationMs:     uploadDuration.Milliseconds(),
	}
	if err := v.SnapshotManager.UpdateSnapshotStatsExtended(v.ctx, snapshotID, extStats); err != nil {
		return fmt.Errorf("updating snapshot stats: %w", err)
	}
	// Mark snapshot as complete
	if err := v.SnapshotManager.CompleteSnapshot(v.ctx, snapshotID); err != nil {
		return fmt.Errorf("completing snapshot: %w", err)
	}
	// Export snapshot metadata
	// Export snapshot metadata without closing the database
	// The export function should handle its own database connection
	if err := v.SnapshotManager.ExportSnapshotMetadata(v.ctx, v.Config.IndexPath, snapshotID); err != nil {
		return fmt.Errorf("exporting snapshot metadata: %w", err)
	}
	// Calculate final statistics
	snapshotDuration := time.Since(snapshotStartTime)
	totalFilesChanged := totalFiles - totalFilesSkipped
	totalBytesChanged := totalBytes
	totalBytesAll := totalBytes + totalBytesSkipped
	// Calculate upload speed in bits/s, rendered with an SI unit prefix
	var avgUploadSpeed string
	if totalBytesUploaded > 0 && uploadDuration > 0 {
		bytesPerSec := float64(totalBytesUploaded) / uploadDuration.Seconds()
		bitsPerSec := bytesPerSec * 8
		if bitsPerSec >= 1e9 {
			avgUploadSpeed = fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9)
		} else if bitsPerSec >= 1e6 {
			avgUploadSpeed = fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6)
		} else if bitsPerSec >= 1e3 {
			avgUploadSpeed = fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3)
		} else {
			avgUploadSpeed = fmt.Sprintf("%.0f bit/s", bitsPerSec)
		}
	} else {
		avgUploadSpeed = "N/A"
	}
	// Get total blob sizes from database; lookup failures silently skip a
	// blob, so the totals are best-effort.
	totalBlobSizeCompressed := int64(0)
	totalBlobSizeUncompressed := int64(0)
	if blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID); err == nil {
		for _, hash := range blobHashes {
			if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil {
				totalBlobSizeCompressed += blob.CompressedSize
				totalBlobSizeUncompressed += blob.UncompressedSize
			}
		}
	}
	// Calculate compression ratio (compressed/uncompressed; 1.0 when no data)
	var compressionRatio float64
	if totalBlobSizeUncompressed > 0 {
		compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed)
	} else {
		compressionRatio = 1.0
	}
	// Print comprehensive summary
	_, _ = fmt.Fprintf(v.Stdout, "=== Snapshot Complete ===\n")
	_, _ = fmt.Fprintf(v.Stdout, "ID: %s\n", snapshotID)
	_, _ = fmt.Fprintf(v.Stdout, "Files: %s examined, %s to process, %s unchanged",
		formatNumber(totalFiles),
		formatNumber(totalFilesChanged),
		formatNumber(totalFilesSkipped))
	if totalFilesDeleted > 0 {
		_, _ = fmt.Fprintf(v.Stdout, ", %s deleted", formatNumber(totalFilesDeleted))
	}
	_, _ = fmt.Fprintln(v.Stdout)
	_, _ = fmt.Fprintf(v.Stdout, "Data: %s total (%s to process)",
		humanize.Bytes(uint64(totalBytesAll)),
		humanize.Bytes(uint64(totalBytesChanged)))
	if totalBytesDeleted > 0 {
		_, _ = fmt.Fprintf(v.Stdout, ", %s deleted", humanize.Bytes(uint64(totalBytesDeleted)))
	}
	_, _ = fmt.Fprintln(v.Stdout)
	if totalBlobsUploaded > 0 {
		_, _ = fmt.Fprintf(v.Stdout, "Storage: %s compressed from %s (%.2fx)\n",
			humanize.Bytes(uint64(totalBlobSizeCompressed)),
			humanize.Bytes(uint64(totalBlobSizeUncompressed)),
			compressionRatio)
		_, _ = fmt.Fprintf(v.Stdout, "Upload: %d blobs, %s in %s (%s)\n",
			totalBlobsUploaded,
			humanize.Bytes(uint64(totalBytesUploaded)),
			formatDuration(uploadDuration),
			avgUploadSpeed)
	}
	_, _ = fmt.Fprintf(v.Stdout, "Duration: %s\n", formatDuration(snapshotDuration))
	if opts.Prune {
		log.Info("Pruning enabled - will delete old snapshots after snapshot")
		// TODO: Implement pruning
	}
	return nil
}
// ListSnapshots lists all snapshots, treating the remote bucket as the
// source of truth. It enumerates snapshot IDs under the metadata/ prefix
// in S3, removes local database entries for snapshots that no longer
// exist remotely, then prints one row per remote snapshot. Sizes come
// from the local database when available, otherwise from the remote
// manifest. Output is a table by default, or JSON when jsonOutput is true.
func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
	// Get all remote snapshots
	remoteSnapshots := make(map[string]bool)
	objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)
	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing remote snapshots: %w", object.Err)
		}
		// Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
		parts := strings.Split(object.Key, "/")
		if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
			remoteSnapshots[parts[1]] = true
		}
	}
	// Get all local snapshots (10000 is used as an effectively-unbounded limit)
	localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000)
	if err != nil {
		return fmt.Errorf("listing local snapshots: %w", err)
	}
	// Build a map of local snapshots for quick lookup
	localSnapshotMap := make(map[string]*database.Snapshot)
	for _, s := range localSnapshots {
		localSnapshotMap[s.ID] = s
	}
	// Remove local snapshots that don't exist remotely
	for _, snapshot := range localSnapshots {
		if !remoteSnapshots[snapshot.ID] {
			log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID)
			// Delete related records first to avoid foreign key constraints.
			// Failures on child rows are logged but non-fatal; the snapshot
			// row delete below decides whether the entry is dropped.
			if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err)
			}
			if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err)
			}
			if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err)
			}
			// Now delete the snapshot itself
			if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err)
			} else {
				log.Info("Deleted local snapshot not found in remote", "snapshot_id", snapshot.ID)
				// Keep the lookup map consistent with what was actually deleted
				delete(localSnapshotMap, snapshot.ID)
			}
		}
	}
	// Build final snapshot list from the remote set
	snapshots := make([]SnapshotInfo, 0, len(remoteSnapshots))
	for snapshotID := range remoteSnapshots {
		// Check if we have this snapshot locally; only completed local
		// snapshots (CompletedAt set) are trusted for metadata
		if localSnap, exists := localSnapshotMap[snapshotID]; exists && localSnap.CompletedAt != nil {
			// Get total compressed size of all blobs referenced by this snapshot
			totalSize, err := v.Repositories.Snapshots.GetSnapshotTotalCompressedSize(v.ctx, snapshotID)
			if err != nil {
				log.Warn("Failed to get total compressed size", "id", snapshotID, "error", err)
				// Fall back to stored blob size
				totalSize = localSnap.BlobSize
			}
			snapshots = append(snapshots, SnapshotInfo{
				ID:             localSnap.ID,
				Timestamp:      localSnap.StartedAt,
				CompressedSize: totalSize,
			})
		} else {
			// Remote snapshot not in local DB - fetch manifest to get size.
			// Snapshots whose ID does not parse are skipped with a warning.
			timestamp, err := parseSnapshotTimestamp(snapshotID)
			if err != nil {
				log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err)
				continue
			}
			// Try to download manifest to get size
			totalSize, err := v.getManifestSize(snapshotID)
			if err != nil {
				return fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err)
			}
			snapshots = append(snapshots, SnapshotInfo{
				ID:             snapshotID,
				Timestamp:      timestamp,
				CompressedSize: totalSize,
			})
		}
	}
	// Sort by timestamp (newest first)
	sort.Slice(snapshots, func(i, j int) bool {
		return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
	})
	if jsonOutput {
		// JSON output
		encoder := json.NewEncoder(os.Stdout)
		encoder.SetIndent("", "  ")
		return encoder.Encode(snapshots)
	}
	// Table output
	w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
	if _, err := fmt.Fprintln(w, "SNAPSHOT ID\tTIMESTAMP\tCOMPRESSED SIZE"); err != nil {
		return err
	}
	if _, err := fmt.Fprintln(w, "───────────\t─────────\t───────────────"); err != nil {
		return err
	}
	for _, snap := range snapshots {
		if _, err := fmt.Fprintf(w, "%s\t%s\t%s\n",
			snap.ID,
			snap.Timestamp.Format("2006-01-02 15:04:05"),
			formatBytes(snap.CompressedSize)); err != nil {
			return err
		}
	}
	return w.Flush()
}
// PurgeSnapshots removes old snapshots based on criteria. Exactly one of
// two selection modes is applied: keepLatest deletes every completed
// snapshot except the newest; otherwise olderThan (a duration string)
// deletes snapshots started before now-duration. The local database is
// synced with the remote first so selection is based on live data. Unless
// force is true, the user is prompted for confirmation before deletion.
// Blob cleanup is deliberately left to a separate 'vaultik prune' run.
func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool) error {
	// Sync with remote first
	if err := v.syncWithRemote(); err != nil {
		return fmt.Errorf("syncing with remote: %w", err)
	}
	// Get snapshots from local database (10000 acts as "all")
	dbSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000)
	if err != nil {
		return fmt.Errorf("listing snapshots: %w", err)
	}
	// Convert to SnapshotInfo format, only including completed snapshots;
	// in-progress snapshots (CompletedAt == nil) are never purge candidates
	snapshots := make([]SnapshotInfo, 0, len(dbSnapshots))
	for _, s := range dbSnapshots {
		if s.CompletedAt != nil {
			snapshots = append(snapshots, SnapshotInfo{
				ID:             s.ID,
				Timestamp:      s.StartedAt,
				CompressedSize: s.BlobSize,
			})
		}
	}
	// Sort by timestamp (newest first) so snapshots[0] is the most recent
	sort.Slice(snapshots, func(i, j int) bool {
		return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
	})
	var toDelete []SnapshotInfo
	if keepLatest {
		// Keep only the most recent snapshot
		if len(snapshots) > 1 {
			toDelete = snapshots[1:]
		}
	} else if olderThan != "" {
		// Parse duration
		duration, err := parseDuration(olderThan)
		if err != nil {
			return fmt.Errorf("invalid duration: %w", err)
		}
		cutoff := time.Now().UTC().Add(-duration)
		for _, snap := range snapshots {
			if snap.Timestamp.Before(cutoff) {
				toDelete = append(toDelete, snap)
			}
		}
	}
	if len(toDelete) == 0 {
		fmt.Println("No snapshots to delete")
		return nil
	}
	// Show what will be deleted
	fmt.Printf("The following snapshots will be deleted:\n\n")
	for _, snap := range toDelete {
		fmt.Printf("  %s (%s, %s)\n",
			snap.ID,
			snap.Timestamp.Format("2006-01-02 15:04:05"),
			formatBytes(snap.CompressedSize))
	}
	// Confirm unless --force is used
	if !force {
		fmt.Printf("\nDelete %d snapshot(s)? [y/N] ", len(toDelete))
		var confirm string
		if _, err := fmt.Scanln(&confirm); err != nil {
			// Treat EOF or error as "no" (safe default for non-interactive use)
			fmt.Println("Cancelled")
			return nil
		}
		if strings.ToLower(confirm) != "y" {
			fmt.Println("Cancelled")
			return nil
		}
	} else {
		fmt.Printf("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete))
	}
	// Delete snapshots; abort on the first failure so partial state is visible
	for _, snap := range toDelete {
		log.Info("Deleting snapshot", "id", snap.ID)
		if err := v.deleteSnapshot(snap.ID); err != nil {
			return fmt.Errorf("deleting snapshot %s: %w", snap.ID, err)
		}
	}
	fmt.Printf("Deleted %d snapshot(s)\n", len(toDelete))
	// Note: Run 'vaultik prune' separately to clean up unreferenced blobs
	fmt.Println("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.")
	return nil
}
// VerifySnapshot checks snapshot integrity by confirming that every blob
// listed in the snapshot's manifest exists in S3 (a HEAD request per blob).
// Deep verification (downloading and hashing blob contents) is not yet
// implemented; when deep is true the function reports that and returns nil.
// Returns an error if the manifest cannot be fetched or any blob is missing.
func (v *Vaultik) VerifySnapshot(snapshotID string, deep bool) error {
	// Parse snapshot ID to extract timestamp.
	// Format: hostname-YYYYMMDD-HHMMSSZ; the hostname may itself contain
	// dashes, so the date and time fields are taken from the end.
	parts := strings.Split(snapshotID, "-")
	var snapshotTime time.Time
	if len(parts) >= 3 {
		dateStr := parts[len(parts)-2]
		timeStr := parts[len(parts)-1]
		if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") {
			timeStr = timeStr[:6] // Remove Z
			timestamp, err := time.Parse("20060102150405", dateStr+timeStr)
			if err == nil {
				snapshotTime = timestamp
			}
		}
	}
	fmt.Printf("Verifying snapshot %s\n", snapshotID)
	if !snapshotTime.IsZero() {
		fmt.Printf("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST"))
	}
	fmt.Println()
	// Download and parse manifest
	manifest, err := v.downloadManifest(snapshotID)
	if err != nil {
		return fmt.Errorf("downloading manifest: %w", err)
	}
	fmt.Printf("Snapshot information:\n")
	fmt.Printf("  Blob count: %d\n", manifest.BlobCount)
	fmt.Printf("  Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize)))
	if manifest.Timestamp != "" {
		if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil {
			fmt.Printf("  Created: %s\n", t.Format("2006-01-02 15:04:05 MST"))
		}
	}
	fmt.Println()
	// Check each blob exists
	fmt.Printf("Checking blob existence...\n")
	// Deep verification is not implemented yet. Bail out here rather than
	// checking the flag inside the loop: the old form returned on the first
	// blob, and with an empty manifest would fall through and wrongly report
	// a successful deep verification.
	if deep {
		fmt.Printf("Deep verification not yet implemented\n")
		return nil
	}
	missing := 0
	verified := 0
	missingSize := int64(0)
	for _, blob := range manifest.Blobs {
		// Blobs are stored sharded by the first two hash byte pairs
		blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash)
		// Just check existence via a HEAD request; contents are not fetched
		_, err := v.S3Client.StatObject(v.ctx, blobPath)
		if err != nil {
			fmt.Printf("  Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize)))
			missing++
			missingSize += blob.CompressedSize
		} else {
			verified++
		}
	}
	fmt.Printf("\nVerification complete:\n")
	fmt.Printf("  Verified: %d blobs (%s)\n", verified,
		humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize)))
	if missing > 0 {
		fmt.Printf("  Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize)))
	} else {
		fmt.Printf("  Missing: 0 blobs\n")
	}
	fmt.Printf("  Status: ")
	if missing > 0 {
		fmt.Printf("FAILED - %d blobs are missing\n", missing)
		return fmt.Errorf("%d blobs are missing", missing)
	}
	fmt.Printf("OK - All blobs verified\n")
	return nil
}
// Helper methods that were previously on SnapshotApp

// getManifestSize downloads and decodes the manifest for the given snapshot
// and returns the total compressed size it records.
func (v *Vaultik) getManifestSize(snapshotID string) (int64, error) {
	key := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
	rc, err := v.S3Client.GetObject(v.ctx, key)
	if err != nil {
		return 0, fmt.Errorf("downloading manifest: %w", err)
	}
	defer func() { _ = rc.Close() }()
	m, err := snapshot.DecodeManifest(rc)
	if err != nil {
		return 0, fmt.Errorf("decoding manifest: %w", err)
	}
	return m.TotalCompressedSize, nil
}
// downloadManifest fetches the manifest object for a snapshot from S3 and
// decodes it into a snapshot.Manifest.
func (v *Vaultik) downloadManifest(snapshotID string) (*snapshot.Manifest, error) {
	key := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
	rc, err := v.S3Client.GetObject(v.ctx, key)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rc.Close() }()
	m, err := snapshot.DecodeManifest(rc)
	if err != nil {
		return nil, fmt.Errorf("decoding manifest: %w", err)
	}
	return m, nil
}
// deleteSnapshot removes a snapshot from both S3 and the local database.
// Every object under metadata/<snapshotID>/ is removed from the bucket
// first; the local rows follow, child tables before the snapshot itself.
func (v *Vaultik) deleteSnapshot(snapshotID string) error {
	// Gather all S3 keys under the snapshot's metadata prefix up front so
	// a listing error aborts before any object is removed.
	prefix := fmt.Sprintf("metadata/%s/", snapshotID)
	var keys []string
	for obj := range v.S3Client.ListObjectsStream(v.ctx, prefix, true) {
		if obj.Err != nil {
			return fmt.Errorf("listing objects: %w", obj.Err)
		}
		keys = append(keys, obj.Key)
	}
	for _, key := range keys {
		if err := v.S3Client.RemoveObject(v.ctx, key); err != nil {
			return fmt.Errorf("removing %s: %w", key, err)
		}
	}
	// Local cleanup: delete child records first to avoid foreign key
	// constraint failures. Child-row failures are logged and tolerated.
	if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotID); err != nil {
		log.Error("Failed to delete snapshot files", "snapshot_id", snapshotID, "error", err)
	}
	if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotID); err != nil {
		log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshotID, "error", err)
	}
	if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotID); err != nil {
		log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshotID, "error", err)
	}
	// The snapshot row itself must go; failure here is fatal.
	if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotID); err != nil {
		return fmt.Errorf("deleting snapshot from database: %w", err)
	}
	return nil
}
// syncWithRemote reconciles the local snapshot database with the set of
// snapshots present in S3. Any local snapshot with no corresponding
// metadata/<id>/ prefix in the bucket is deleted locally, child records
// first — the same order used by deleteSnapshot and ListSnapshots — so
// foreign key constraints are not violated.
func (v *Vaultik) syncWithRemote() error {
	log.Info("Syncing with remote snapshots")
	// Get all remote snapshot IDs
	remoteSnapshots := make(map[string]bool)
	objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)
	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing remote snapshots: %w", object.Err)
		}
		// Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
		parts := strings.Split(object.Key, "/")
		if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
			remoteSnapshots[parts[1]] = true
		}
	}
	log.Debug("Found remote snapshots", "count", len(remoteSnapshots))
	// Get all local snapshots (use a high limit to get all)
	localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000)
	if err != nil {
		return fmt.Errorf("listing local snapshots: %w", err)
	}
	// Remove local snapshots that don't exist remotely
	removedCount := 0
	for _, snapshot := range localSnapshots {
		if remoteSnapshots[snapshot.ID] {
			continue
		}
		log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID)
		// Delete related records first to avoid foreign key constraint
		// failures on the snapshot delete below. Failures on child rows
		// are logged but non-fatal.
		if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshot.ID); err != nil {
			log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err)
		}
		if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshot.ID); err != nil {
			log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err)
		}
		if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshot.ID); err != nil {
			log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err)
		}
		if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil {
			log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err)
		} else {
			removedCount++
		}
	}
	if removedCount > 0 {
		log.Info("Removed local snapshots not found in remote", "count", removedCount)
	}
	return nil
}

124
internal/vaultik/vaultik.go Normal file
View File

@ -0,0 +1,124 @@
package vaultik
import (
"context"
"fmt"
"io"
"os"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/s3"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"github.com/spf13/afero"
"go.uber.org/fx"
)
// Vaultik contains all dependencies needed for vaultik operations.
// It is the central application object: commands are implemented as
// methods on it, and all collaborators are injected via fx (see New).
type Vaultik struct {
	Globals         *globals.Globals        // process-wide metadata (version etc.)
	Config          *config.Config          // parsed configuration file
	DB              *database.DB            // local SQLite state database
	Repositories    *database.Repositories  // typed accessors over DB tables
	S3Client        *s3.Client              // remote object storage client
	ScannerFactory  snapshot.ScannerFactory // builds filesystem scanners for backups
	SnapshotManager *snapshot.SnapshotManager
	Shutdowner      fx.Shutdowner // allows triggering fx app shutdown
	Fs              afero.Fs      // filesystem abstraction; in-memory in tests

	// Context management: ctx is cancelled via cancel to stop in-flight work.
	ctx    context.Context
	cancel context.CancelFunc

	// IO streams, overridable in tests to capture output / feed input.
	Stdout io.Writer
	Stderr io.Writer
	Stdin  io.Reader
}
// VaultikParams contains all parameters for New that can be provided by fx.
// It mirrors the dependency fields of Vaultik; Fs is optional so production
// wiring need not register a filesystem (New falls back to the OS one).
type VaultikParams struct {
	fx.In

	Globals         *globals.Globals
	Config          *config.Config
	DB              *database.DB
	Repositories    *database.Repositories
	S3Client        *s3.Client
	ScannerFactory  snapshot.ScannerFactory
	SnapshotManager *snapshot.SnapshotManager
	Shutdowner      fx.Shutdowner
	Fs              afero.Fs `optional:"true"` // nil unless a test injects one
}
// New creates a new Vaultik instance with proper context management
// It automatically includes crypto capabilities if age_secret_key is configured
func New(params VaultikParams) *Vaultik {
	ctx, cancel := context.WithCancel(context.Background())

	// Fall back to the real OS filesystem when no filesystem was injected
	// (tests provide an in-memory one via the optional Fs parameter).
	filesystem := params.Fs
	if filesystem == nil {
		filesystem = afero.NewOsFs()
	}

	// The snapshot manager must share the same filesystem abstraction.
	params.SnapshotManager.SetFilesystem(filesystem)

	v := &Vaultik{
		Globals:         params.Globals,
		Config:          params.Config,
		DB:              params.DB,
		Repositories:    params.Repositories,
		S3Client:        params.S3Client,
		ScannerFactory:  params.ScannerFactory,
		SnapshotManager: params.SnapshotManager,
		Shutdowner:      params.Shutdowner,
		Fs:              filesystem,
		ctx:             ctx,
		cancel:          cancel,
	}
	// Default IO streams to the process's; tests may override these fields.
	v.Stdout = os.Stdout
	v.Stderr = os.Stderr
	v.Stdin = os.Stdin
	return v
}
// Context returns the Vaultik's context, which is cancelled by Cancel
// and should bound all long-running operations started by this instance.
func (v *Vaultik) Context() context.Context {
	return v.ctx
}
// Cancel cancels the Vaultik's context, signalling all operations bound
// to Context() to stop. Safe to call multiple times.
func (v *Vaultik) Cancel() {
	v.cancel()
}
// CanDecrypt returns true if this Vaultik instance has decryption
// capabilities, i.e. an age secret key is present in the configuration.
func (v *Vaultik) CanDecrypt() bool {
	return v.Config.AgeSecretKey != ""
}
// GetEncryptor creates a new Encryptor instance based on the configured age recipients
// Returns an error if no recipients are configured
func (v *Vaultik) GetEncryptor() (*crypto.Encryptor, error) {
	recipients := v.Config.AgeRecipients
	if len(recipients) == 0 {
		return nil, fmt.Errorf("no age recipients configured")
	}
	return crypto.NewEncryptor(recipients)
}
// GetDecryptor creates a new Decryptor instance based on the configured age secret key
// Returns an error if no secret key is configured
func (v *Vaultik) GetDecryptor() (*crypto.Decryptor, error) {
	key := v.Config.AgeSecretKey
	if key == "" {
		return nil, fmt.Errorf("no age secret key configured")
	}
	return crypto.NewDecryptor(key)
}
// GetFilesystem returns the filesystem instance used by Vaultik
// (the injected afero.Fs, or the OS filesystem if none was provided).
func (v *Vaultik) GetFilesystem() afero.Fs {
	return v.Fs
}

396
internal/vaultik/verify.go Normal file
View File

@ -0,0 +1,396 @@
package vaultik
import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"io"
"os"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"github.com/dustin/go-humanize"
"github.com/klauspost/compress/zstd"
_ "github.com/mattn/go-sqlite3"
)
// VerifyOptions contains options for the verify command.
type VerifyOptions struct {
	// Deep enables downloading every blob and verifying chunk hashes,
	// rather than only checking blob existence and metadata consistency.
	Deep bool
}
// RunDeepVerify executes deep verification operation. It runs up to five
// steps against the remote snapshot: (1) download the manifest, (2) download
// and decrypt the snapshot's SQLite database into a temp file, (3) check the
// blob lists in manifest and database agree, (4) check every blob exists in
// S3 with the expected size, and (5) when opts.Deep is set, download each
// blob and verify its chunk contents. Requires a configured age secret key.
func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
	// Check for decryption capability
	if !v.CanDecrypt() {
		return fmt.Errorf("age_secret_key missing from config - required for deep verification")
	}
	log.Info("Starting snapshot verification",
		"snapshot_id", snapshotID,
		// map-literal indexing is a compact bool-to-string selector
		"mode", map[bool]string{true: "deep", false: "shallow"}[opts.Deep],
	)
	// Step 1: Download manifest
	manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
	log.Info("Downloading manifest", "path", manifestPath)
	manifestReader, err := v.S3Client.GetObject(v.ctx, manifestPath)
	if err != nil {
		return fmt.Errorf("failed to download manifest: %w", err)
	}
	defer func() { _ = manifestReader.Close() }()
	// Decompress manifest
	manifest, err := snapshot.DecodeManifest(manifestReader)
	if err != nil {
		return fmt.Errorf("failed to decode manifest: %w", err)
	}
	log.Info("Manifest loaded",
		"blob_count", manifest.BlobCount,
		"total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize)),
	)
	// Step 2: Download and decrypt database
	dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
	log.Info("Downloading encrypted database", "path", dbPath)
	dbReader, err := v.S3Client.GetObject(v.ctx, dbPath)
	if err != nil {
		return fmt.Errorf("failed to download database: %w", err)
	}
	defer func() { _ = dbReader.Close() }()
	// Decrypt and decompress database into a temporary on-disk SQLite file
	tempDB, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey)
	if err != nil {
		return fmt.Errorf("failed to decrypt database: %w", err)
	}
	defer func() {
		// Defensive nil check; closing also removes the temp file
		if tempDB != nil {
			_ = tempDB.Close()
		}
	}()
	// Step 3: Compare blob lists
	if err := v.verifyBlobLists(snapshotID, manifest, tempDB.DB); err != nil {
		return err
	}
	// Step 4: Verify blob existence
	if err := v.verifyBlobExistence(manifest); err != nil {
		return err
	}
	// Step 5: Deep verification if requested
	if opts.Deep {
		if err := v.performDeepVerification(manifest, tempDB.DB); err != nil {
			return err
		}
	}
	log.Info("✓ Verification completed successfully",
		"snapshot_id", snapshotID,
		"mode", map[bool]string{true: "deep", false: "shallow"}[opts.Deep],
	)
	return nil
}
// tempDB wraps sql.DB with cleanup: the database lives in a temporary file
// that must be removed once the connection is closed.
type tempDB struct {
	*sql.DB
	tempPath string // path of the backing temp file, removed on Close
}

// Close closes the underlying database and best-effort removes the backing
// temp file (the removal error is intentionally discarded; the close error
// is what callers care about).
func (t *tempDB) Close() error {
	err := t.DB.Close()
	_ = os.Remove(t.tempPath)
	return err
}
// decryptAndLoadDatabase decrypts and loads the database from the encrypted
// stream: age-decrypt, zstd-decompress, spill to a temporary file, and open
// it with the sqlite3 driver. The caller must Close the returned tempDB,
// which also removes the temp file.
//
// NOTE(review): the secretKey parameter is currently unused — the decryptor
// is built from v.Config via GetDecryptor. Consider removing the parameter
// or using it; confirm no other callers rely on the signature.
func (v *Vaultik) decryptAndLoadDatabase(reader io.ReadCloser, secretKey string) (*tempDB, error) {
	// Get decryptor
	decryptor, err := v.GetDecryptor()
	if err != nil {
		return nil, fmt.Errorf("failed to get decryptor: %w", err)
	}
	// Decrypt the stream
	decryptedReader, err := decryptor.DecryptStream(reader)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt database: %w", err)
	}
	// Decompress the database
	decompressor, err := zstd.NewReader(decryptedReader)
	if err != nil {
		return nil, fmt.Errorf("failed to create decompressor: %w", err)
	}
	defer decompressor.Close()
	// Create temporary file for database. The real OS filesystem is used
	// (not v.Fs) because the sqlite3 driver needs a real file path.
	tempFile, err := os.CreateTemp("", "vaultik-verify-*.db")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp file: %w", err)
	}
	tempPath := tempFile.Name()
	// Copy decompressed data to temp file; clean up on any failure
	if _, err := io.Copy(tempFile, decompressor); err != nil {
		_ = tempFile.Close()
		_ = os.Remove(tempPath)
		return nil, fmt.Errorf("failed to write database: %w", err)
	}
	// Close temp file before opening with sqlite
	if err := tempFile.Close(); err != nil {
		_ = os.Remove(tempPath)
		return nil, fmt.Errorf("failed to close temp file: %w", err)
	}
	// Open the database
	db, err := sql.Open("sqlite3", tempPath)
	if err != nil {
		_ = os.Remove(tempPath)
		return nil, fmt.Errorf("failed to open database: %w", err)
	}
	return &tempDB{
		DB:       db,
		tempPath: tempPath,
	}, nil
}
// verifyBlobLists compares the blob lists between manifest and database.
// It verifies that both sides reference exactly the same set of blob hashes
// for the snapshot and that the recorded compressed size of each blob
// agrees. Returns a descriptive error on the first discrepancy found.
func (v *Vaultik) verifyBlobLists(snapshotID string, manifest *snapshot.Manifest, db *sql.DB) error {
	log.Info("Verifying blob lists match between manifest and database")
	// Get blobs from database
	query := `
		SELECT b.blob_hash, b.compressed_size
		FROM snapshot_blobs sb
		JOIN blobs b ON sb.blob_hash = b.blob_hash
		WHERE sb.snapshot_id = ?
		ORDER BY b.blob_hash
	`
	rows, err := db.QueryContext(v.ctx, query, snapshotID)
	if err != nil {
		return fmt.Errorf("failed to query snapshot blobs: %w", err)
	}
	defer func() { _ = rows.Close() }()
	// Build map of database blobs (hash -> compressed size)
	dbBlobs := make(map[string]int64)
	for rows.Next() {
		var hash string
		var size int64
		if err := rows.Scan(&hash, &size); err != nil {
			return fmt.Errorf("failed to scan blob row: %w", err)
		}
		dbBlobs[hash] = size
	}
	// A row-iteration error would otherwise silently truncate dbBlobs and
	// produce a bogus count mismatch; surface it explicitly (same pattern
	// as verifyBlob).
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating snapshot blobs: %w", err)
	}
	// Build map of manifest blobs
	manifestBlobs := make(map[string]int64)
	for _, blob := range manifest.Blobs {
		manifestBlobs[blob.Hash] = blob.CompressedSize
	}
	// Compare counts
	if len(dbBlobs) != len(manifestBlobs) {
		return fmt.Errorf("blob count mismatch: database has %d blobs, manifest has %d blobs",
			len(dbBlobs), len(manifestBlobs))
	}
	// Check each database blob exists in the manifest with matching size
	for hash, dbSize := range dbBlobs {
		manifestSize, exists := manifestBlobs[hash]
		if !exists {
			return fmt.Errorf("blob %s exists in database but not in manifest", hash)
		}
		if dbSize != manifestSize {
			return fmt.Errorf("blob %s size mismatch: database has %d bytes, manifest has %d bytes",
				hash, dbSize, manifestSize)
		}
	}
	// And the reverse direction: every manifest blob must be in the database
	for hash := range manifestBlobs {
		if _, exists := dbBlobs[hash]; !exists {
			return fmt.Errorf("blob %s exists in manifest but not in database", hash)
		}
	}
	log.Info("✓ Blob lists match", "blob_count", len(dbBlobs))
	return nil
}
// verifyBlobExistence checks that all blobs exist in S3 with the expected
// compressed size, using HEAD requests only (no blob content is downloaded).
func (v *Vaultik) verifyBlobExistence(manifest *snapshot.Manifest) error {
	total := len(manifest.Blobs)
	log.Info("Verifying blob existence in S3", "blob_count", total)
	for i, b := range manifest.Blobs {
		// Blobs are sharded under blobs/<aa>/<bb>/<full-hash>
		key := fmt.Sprintf("blobs/%s/%s/%s", b.Hash[:2], b.Hash[2:4], b.Hash)
		info, err := v.S3Client.StatObject(v.ctx, key)
		if err != nil {
			return fmt.Errorf("blob %s missing from S3: %w", b.Hash, err)
		}
		if info.Size != b.CompressedSize {
			return fmt.Errorf("blob %s size mismatch: S3 has %d bytes, manifest has %d bytes",
				b.Hash, info.Size, b.CompressedSize)
		}
		// Report progress every 100 blobs and after the final blob
		if (i+1)%100 == 0 || i == total-1 {
			log.Info("Blob existence check progress",
				"checked", i+1,
				"total", total,
				"percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(total)*100),
			)
		}
	}
	log.Info("✓ All blobs exist in S3")
	return nil
}
// performDeepVerification downloads and verifies the content of each blob
// in the manifest, delegating per-blob chunk checking to verifyBlob and
// logging cumulative progress after every blob.
func (v *Vaultik) performDeepVerification(manifest *snapshot.Manifest, db *sql.DB) error {
	log.Info("Starting deep verification - downloading and verifying all blobs")
	count := len(manifest.Blobs)
	var downloaded int64
	for i, info := range manifest.Blobs {
		// Any single blob failure aborts the whole verification
		if err := v.verifyBlob(info, db); err != nil {
			return fmt.Errorf("blob %s verification failed: %w", info.Hash, err)
		}
		downloaded += info.CompressedSize
		log.Info("Deep verification progress",
			"blob", fmt.Sprintf("%d/%d", i+1, count),
			"total_downloaded", humanize.Bytes(uint64(downloaded)),
			"percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(count)*100),
		)
	}
	log.Info("✓ Deep verification completed successfully",
		"blobs_verified", count,
		"total_size", humanize.Bytes(uint64(downloaded)),
	)
	return nil
}
// verifyBlob downloads and verifies a single blob: it streams the blob from
// S3, age-decrypts and zstd-decompresses it, then walks the blob's chunk
// records (ordered by offset) from the database, reading each chunk from the
// decompressed stream and comparing its SHA-256 against the recorded hash.
// The stream is consumed strictly forward; gaps between chunks are skipped.
//
// NOTE(review): this assumes chunks never overlap (each chunk's offset is at
// or past the end of the previous one) — the ordering check below only
// enforces strictly increasing offsets. Confirm the schema guarantees this.
func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error {
	// Download blob (sharded path: blobs/<aa>/<bb>/<full-hash>)
	blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobInfo.Hash[:2], blobInfo.Hash[2:4], blobInfo.Hash)
	reader, err := v.S3Client.GetObject(v.ctx, blobPath)
	if err != nil {
		return fmt.Errorf("failed to download: %w", err)
	}
	defer func() { _ = reader.Close() }()
	// Get decryptor
	decryptor, err := v.GetDecryptor()
	if err != nil {
		return fmt.Errorf("failed to get decryptor: %w", err)
	}
	// Decrypt blob
	decryptedReader, err := decryptor.DecryptStream(reader)
	if err != nil {
		return fmt.Errorf("failed to decrypt: %w", err)
	}
	// Decompress blob
	decompressor, err := zstd.NewReader(decryptedReader)
	if err != nil {
		return fmt.Errorf("failed to decompress: %w", err)
	}
	defer decompressor.Close()
	// Query blob chunks from database to get offsets and lengths,
	// ordered by offset so the stream can be read sequentially
	query := `
		SELECT bc.chunk_hash, bc.offset, bc.length
		FROM blob_chunks bc
		JOIN blobs b ON bc.blob_id = b.id
		WHERE b.blob_hash = ?
		ORDER BY bc.offset
	`
	rows, err := db.QueryContext(v.ctx, query, blobInfo.Hash)
	if err != nil {
		return fmt.Errorf("failed to query blob chunks: %w", err)
	}
	defer func() { _ = rows.Close() }()
	var lastOffset int64 = -1 // offset of the previously processed chunk
	chunkCount := 0           // chunks successfully verified
	totalRead := int64(0)     // bytes consumed from the decompressed stream
	// Verify each chunk in the blob
	for rows.Next() {
		var chunkHash string
		var offset, length int64
		if err := rows.Scan(&chunkHash, &offset, &length); err != nil {
			return fmt.Errorf("failed to scan chunk row: %w", err)
		}
		// Verify chunk ordering (strictly increasing offsets)
		if offset <= lastOffset {
			return fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset)
		}
		lastOffset = offset
		// Read chunk data from decompressed stream
		if offset > totalRead {
			// Skip to the correct offset (the stream is not seekable, so
			// intervening bytes are discarded)
			skipBytes := offset - totalRead
			if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil {
				return fmt.Errorf("failed to skip to offset %d: %w", offset, err)
			}
			totalRead = offset
		}
		// Read chunk data
		chunkData := make([]byte, length)
		if _, err := io.ReadFull(decompressor, chunkData); err != nil {
			return fmt.Errorf("failed to read chunk at offset %d: %w", offset, err)
		}
		totalRead += length
		// Verify chunk hash
		hasher := sha256.New()
		hasher.Write(chunkData)
		calculatedHash := hex.EncodeToString(hasher.Sum(nil))
		if calculatedHash != chunkHash {
			return fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s",
				offset, calculatedHash, chunkHash)
		}
		chunkCount++
	}
	// Surface any row-iteration error (would otherwise truncate the walk)
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating blob chunks: %w", err)
	}
	log.Debug("Blob verified",
		"hash", blobInfo.Hash,
		"chunks", chunkCount,
		"size", humanize.Bytes(uint64(blobInfo.CompressedSize)),
	)
	return nil
}