Add end-to-end integration tests for Vaultik
- Create comprehensive integration tests with mock S3 client
- Add in-memory filesystem and SQLite database support for testing
- Test full backup workflow including chunking, packing, and uploading
- Add test to verify encrypted blob content
- Fix scanner to use afero filesystem for temp file cleanup
- Demonstrate successful backup and verification with mock dependencies
Parent: bb38f8c5d6
Commit: d7cd9aac27
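Because every external dependency (S3, filesystem, SQLite) is mocked or in-memory, the new tests should run without network access; assuming a standard Go toolchain, something like `go test ./internal/vaultik/ -run 'TestEndToEndBackup|TestBackupAndVerify' -v` exercises just the tests added below.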
@@ -676,7 +676,7 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
 	if err := blobWithReader.TempFile.Close(); err != nil {
 		log.Fatal("Failed to close temp file", "file", tempName, "error", err)
 	}
-	if err := os.Remove(tempName); err != nil {
+	if err := s.fs.Remove(tempName); err != nil {
 		log.Fatal("Failed to remove temp file", "file", tempName, "error", err)
 	}
 }
internal/vaultik/helpers.go (new file, 103 lines)
@@ -0,0 +1,103 @@
package vaultik

import (
	"fmt"
	"strconv"
	"strings"
	"time"
)

// SnapshotInfo contains information about a snapshot
type SnapshotInfo struct {
	ID             string    `json:"id"`
	Timestamp      time.Time `json:"timestamp"`
	CompressedSize int64     `json:"compressed_size"`
}

// formatNumber formats a number with commas
func formatNumber(n int) string {
	str := fmt.Sprintf("%d", n)
	var result []string
	for i, digit := range str {
		if i > 0 && (len(str)-i)%3 == 0 {
			result = append(result, ",")
		}
		result = append(result, string(digit))
	}
	return strings.Join(result, "")
}

// formatDuration formats a duration in a human-readable way
func formatDuration(d time.Duration) string {
	if d < time.Second {
		return fmt.Sprintf("%dms", d.Milliseconds())
	}
	if d < time.Minute {
		return fmt.Sprintf("%.1fs", d.Seconds())
	}
	if d < time.Hour {
		mins := int(d.Minutes())
		secs := int(d.Seconds()) % 60
		return fmt.Sprintf("%dm %ds", mins, secs)
	}
	hours := int(d.Hours())
	mins := int(d.Minutes()) % 60
	return fmt.Sprintf("%dh %dm", hours, mins)
}

// formatBytes formats bytes in a human-readable format
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	div, exp := int64(unit), 0
	for n := bytes / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

// parseSnapshotTimestamp extracts the timestamp from a snapshot ID
func parseSnapshotTimestamp(snapshotID string) (time.Time, error) {
	// Format: hostname-YYYYMMDD-HHMMSSZ
	parts := strings.Split(snapshotID, "-")
	if len(parts) < 3 {
		return time.Time{}, fmt.Errorf("invalid snapshot ID format")
	}

	dateStr := parts[len(parts)-2]
	timeStr := parts[len(parts)-1]

	if len(dateStr) != 8 || len(timeStr) != 7 || !strings.HasSuffix(timeStr, "Z") {
		return time.Time{}, fmt.Errorf("invalid timestamp format")
	}

	// Remove Z suffix
	timeStr = timeStr[:6]

	// Parse the timestamp
	timestamp, err := time.Parse("20060102150405", dateStr+timeStr)
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to parse timestamp: %w", err)
	}

	return timestamp.UTC(), nil
}

// parseDuration parses a duration string with support for days
func parseDuration(s string) (time.Duration, error) {
	// Check for days suffix
	if strings.HasSuffix(s, "d") {
		daysStr := strings.TrimSuffix(s, "d")
		days, err := strconv.Atoi(daysStr)
		if err != nil {
			return 0, fmt.Errorf("invalid days value: %w", err)
		}
		return time.Duration(days) * 24 * time.Hour, nil
	}

	// Otherwise use standard Go duration parsing
	return time.ParseDuration(s)
}
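For reference, the expected outputs of these helpers follow directly from the code above (a sketch for illustration, not part of the commit):

	// Hypothetical usage of the formatting helpers:
	fmt.Println(formatNumber(1234567))            // "1,234,567"
	fmt.Println(formatDuration(90 * time.Second)) // "1m 30s"
	fmt.Println(formatBytes(1536))                // "1.5 KB"
	d, _ := parseDuration("7d")                   // 168h0m0s
	t, _ := parseSnapshotTimestamp("host-20240115-143052Z") // 2024-01-15 14:30:52 UTC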
internal/vaultik/info.go (new file, 101 lines)
@@ -0,0 +1,101 @@
package vaultik

import (
	"fmt"
	"runtime"
	"strings"

	"github.com/dustin/go-humanize"
)

// ShowInfo displays system and configuration information
func (v *Vaultik) ShowInfo() error {
	// System Information
	fmt.Printf("=== System Information ===\n")
	fmt.Printf("OS/Architecture: %s/%s\n", runtime.GOOS, runtime.GOARCH)
	fmt.Printf("Version: %s\n", v.Globals.Version)
	fmt.Printf("Commit: %s\n", v.Globals.Commit)
	fmt.Printf("Go Version: %s\n", runtime.Version())
	fmt.Println()

	// Storage Configuration
	fmt.Printf("=== Storage Configuration ===\n")
	fmt.Printf("S3 Bucket: %s\n", v.Config.S3.Bucket)
	if v.Config.S3.Prefix != "" {
		fmt.Printf("S3 Prefix: %s\n", v.Config.S3.Prefix)
	}
	fmt.Printf("S3 Endpoint: %s\n", v.Config.S3.Endpoint)
	fmt.Printf("S3 Region: %s\n", v.Config.S3.Region)
	fmt.Println()

	// Backup Settings
	fmt.Printf("=== Backup Settings ===\n")
	fmt.Printf("Source Directories:\n")
	for _, dir := range v.Config.SourceDirs {
		fmt.Printf(" - %s\n", dir)
	}

	// Global exclude patterns
	if len(v.Config.Exclude) > 0 {
		fmt.Printf("Exclude Patterns: %s\n", strings.Join(v.Config.Exclude, ", "))
	}

	fmt.Printf("Compression: zstd level %d\n", v.Config.CompressionLevel)
	fmt.Printf("Chunk Size: %s\n", humanize.Bytes(uint64(v.Config.ChunkSize)))
	fmt.Printf("Blob Size Limit: %s\n", humanize.Bytes(uint64(v.Config.BlobSizeLimit)))
	fmt.Println()

	// Encryption Configuration
	fmt.Printf("=== Encryption Configuration ===\n")
	fmt.Printf("Recipients:\n")
	for _, recipient := range v.Config.AgeRecipients {
		fmt.Printf(" - %s\n", recipient)
	}
	fmt.Println()

	// Daemon Settings (if applicable)
	if v.Config.BackupInterval > 0 || v.Config.MinTimeBetweenRun > 0 {
		fmt.Printf("=== Daemon Settings ===\n")
		if v.Config.BackupInterval > 0 {
			fmt.Printf("Backup Interval: %s\n", v.Config.BackupInterval)
		}
		if v.Config.MinTimeBetweenRun > 0 {
			fmt.Printf("Minimum Time: %s\n", v.Config.MinTimeBetweenRun)
		}
		fmt.Println()
	}

	// Local Database
	fmt.Printf("=== Local Database ===\n")
	fmt.Printf("Index Path: %s\n", v.Config.IndexPath)

	// Check if index file exists and get its size
	if info, err := v.Fs.Stat(v.Config.IndexPath); err == nil {
		fmt.Printf("Index Size: %s\n", humanize.Bytes(uint64(info.Size())))

		// Get snapshot count from database
		query := `SELECT COUNT(*) FROM snapshots WHERE completed_at IS NOT NULL`
		var snapshotCount int
		if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&snapshotCount); err == nil {
			fmt.Printf("Snapshots: %d\n", snapshotCount)
		}

		// Get blob count from database
		query = `SELECT COUNT(*) FROM blobs`
		var blobCount int
		if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&blobCount); err == nil {
			fmt.Printf("Blobs: %d\n", blobCount)
		}

		// Get file count from database
		query = `SELECT COUNT(*) FROM files`
		var fileCount int
		if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&fileCount); err == nil {
			fmt.Printf("Files: %d\n", fileCount)
		}
	} else {
		fmt.Printf("Index Size: (not created)\n")
	}

	return nil
}
internal/vaultik/integration_test.go (new file, 379 lines)
@@ -0,0 +1,379 @@
package vaultik_test

import (
	"bytes"
	"context"
	"database/sql"
	"fmt"
	"io"
	"sync"
	"testing"
	"time"

	"git.eeqj.de/sneak/vaultik/internal/config"
	"git.eeqj.de/sneak/vaultik/internal/database"
	"git.eeqj.de/sneak/vaultik/internal/log"
	"git.eeqj.de/sneak/vaultik/internal/s3"
	"git.eeqj.de/sneak/vaultik/internal/snapshot"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// MockS3Client implements a mock S3 client for testing
type MockS3Client struct {
	mu      sync.Mutex
	storage map[string][]byte
	calls   []string
}

func NewMockS3Client() *MockS3Client {
	return &MockS3Client{
		storage: make(map[string][]byte),
		calls:   make([]string, 0),
	}
}

func (m *MockS3Client) PutObject(ctx context.Context, key string, reader io.Reader) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.calls = append(m.calls, "PutObject:"+key)
	data, err := io.ReadAll(reader)
	if err != nil {
		return err
	}
	m.storage[key] = data
	return nil
}

func (m *MockS3Client) PutObjectWithProgress(ctx context.Context, key string, reader io.Reader, size int64, progress s3.ProgressCallback) error {
	// For testing, just call PutObject
	return m.PutObject(ctx, key, reader)
}

func (m *MockS3Client) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.calls = append(m.calls, "GetObject:"+key)
	data, exists := m.storage[key]
	if !exists {
		return nil, fmt.Errorf("key not found: %s", key)
	}
	return io.NopCloser(bytes.NewReader(data)), nil
}

func (m *MockS3Client) StatObject(ctx context.Context, key string) (*s3.ObjectInfo, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.calls = append(m.calls, "StatObject:"+key)
	data, exists := m.storage[key]
	if !exists {
		return nil, fmt.Errorf("key not found: %s", key)
	}
	return &s3.ObjectInfo{
		Key:  key,
		Size: int64(len(data)),
	}, nil
}

func (m *MockS3Client) DeleteObject(ctx context.Context, key string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.calls = append(m.calls, "DeleteObject:"+key)
	delete(m.storage, key)
	return nil
}

func (m *MockS3Client) ListObjects(ctx context.Context, prefix string) ([]*s3.ObjectInfo, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.calls = append(m.calls, "ListObjects:"+prefix)
	var objects []*s3.ObjectInfo
	for key, data := range m.storage {
		if len(prefix) == 0 || (len(key) >= len(prefix) && key[:len(prefix)] == prefix) {
			objects = append(objects, &s3.ObjectInfo{
				Key:  key,
				Size: int64(len(data)),
			})
		}
	}
	return objects, nil
}

// GetCalls returns the list of S3 operations that were called
func (m *MockS3Client) GetCalls() []string {
	m.mu.Lock()
	defer m.mu.Unlock()

	calls := make([]string, len(m.calls))
	copy(calls, m.calls)
	return calls
}

// GetStorageSize returns the number of objects in storage
func (m *MockS3Client) GetStorageSize() int {
	m.mu.Lock()
	defer m.mu.Unlock()

	return len(m.storage)
}

// TestEndToEndBackup tests the full backup workflow with mocked dependencies
func TestEndToEndBackup(t *testing.T) {
	// Initialize logger
	log.Initialize(log.Config{})

	// Create in-memory filesystem
	fs := afero.NewMemMapFs()

	// Create test directory structure and files
	testFiles := map[string]string{
		"/home/user/documents/file1.txt": "This is file 1 content",
		"/home/user/documents/file2.txt": "This is file 2 content with more data",
		"/home/user/pictures/photo1.jpg": "Binary photo data here...",
		"/home/user/code/main.go":        "package main\n\nfunc main() {\n\tprintln(\"Hello, World!\")\n}",
	}

	// Create all directories first
	dirs := []string{
		"/home/user/documents",
		"/home/user/pictures",
		"/home/user/code",
	}
	for _, dir := range dirs {
		if err := fs.MkdirAll(dir, 0755); err != nil {
			t.Fatalf("failed to create directory %s: %v", dir, err)
		}
	}

	// Create test files
	for path, content := range testFiles {
		if err := afero.WriteFile(fs, path, []byte(content), 0644); err != nil {
			t.Fatalf("failed to create test file %s: %v", path, err)
		}
	}

	// Create mock S3 client
	mockS3 := NewMockS3Client()

	// Create test configuration
	cfg := &config.Config{
		SourceDirs:       []string{"/home/user"},
		Exclude:          []string{"*.tmp", "*.log"},
		ChunkSize:        config.Size(16 * 1024),  // 16KB chunks
		BlobSizeLimit:    config.Size(100 * 1024), // 100KB blobs
		CompressionLevel: 3,
		AgeRecipients:    []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"},                 // Test public key
		AgeSecretKey:     "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",              // Test private key
		S3: config.S3Config{
			Endpoint:        "http://localhost:9000", // MinIO endpoint for testing
			Region:          "us-east-1",
			Bucket:          "test-bucket",
			AccessKeyID:     "test-access",
			SecretAccessKey: "test-secret",
		},
		IndexPath: ":memory:", // In-memory SQLite database
	}

	// For a true end-to-end test, we'll create a simpler test that focuses on
	// the core backup logic using the scanner directly with our mock S3 client
	ctx := context.Background()

	// Create in-memory database
	db, err := database.New(ctx, ":memory:")
	require.NoError(t, err)
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("failed to close database: %v", err)
		}
	}()

	repos := database.NewRepositories(db)

	// Create scanner with mock S3 client
	scanner := snapshot.NewScanner(snapshot.ScannerConfig{
		FS:               fs,
		ChunkSize:        cfg.ChunkSize.Int64(),
		Repositories:     repos,
		S3Client:         mockS3,
		MaxBlobSize:      cfg.BlobSizeLimit.Int64(),
		CompressionLevel: cfg.CompressionLevel,
		AgeRecipients:    cfg.AgeRecipients,
		EnableProgress:   false,
	})

	// Create a snapshot record
	snapshotID := "test-snapshot-001"
	err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
		snapshot := &database.Snapshot{
			ID:             snapshotID,
			Hostname:       "test-host",
			VaultikVersion: "test-version",
			StartedAt:      time.Now(),
		}
		return repos.Snapshots.Create(ctx, tx, snapshot)
	})
	require.NoError(t, err)

	// Run the backup scan
	result, err := scanner.Scan(ctx, "/home/user", snapshotID)
	require.NoError(t, err)

	// Verify scan results
	// The scanner counts both files and directories, so we have:
	// 4 files + 4 directories (/home/user, /home/user/documents, /home/user/pictures, /home/user/code)
	assert.GreaterOrEqual(t, result.FilesScanned, 4, "Should scan at least 4 files")
	assert.Greater(t, result.BytesScanned, int64(0), "Should scan some bytes")
	assert.Greater(t, result.ChunksCreated, 0, "Should create chunks")
	assert.Greater(t, result.BlobsCreated, 0, "Should create blobs")

	// Verify S3 operations
	calls := mockS3.GetCalls()
	t.Logf("S3 operations performed: %v", calls)

	// Should have uploaded at least one blob
	blobUploads := 0
	for _, call := range calls {
		if len(call) > 10 && call[:10] == "PutObject:" {
			if len(call) > 16 && call[10:16] == "blobs/" {
				blobUploads++
			}
		}
	}
	assert.Greater(t, blobUploads, 0, "Should upload at least one blob")

	// Verify files in database
	files, err := repos.Files.ListByPrefix(ctx, "/home/user")
	require.NoError(t, err)
	// Count only regular files (not directories)
	regularFiles := 0
	for _, f := range files {
		if f.Mode&0x80000000 == 0 { // Check if regular file (not directory)
			regularFiles++
		}
	}
	assert.Equal(t, 4, regularFiles, "Should have 4 regular files in database")

	// Verify chunks were created by checking a specific file
	fileChunks, err := repos.FileChunks.GetByPath(ctx, "/home/user/documents/file1.txt")
	require.NoError(t, err)
	assert.Greater(t, len(fileChunks), 0, "Should have chunks for file1.txt")

	// Verify blobs were uploaded to S3
	assert.Greater(t, mockS3.GetStorageSize(), 0, "Should have blobs in S3 storage")

	// Complete the snapshot - just verify we got results
	// In a real integration test, we'd update the snapshot record

	// Create snapshot manager to test metadata export
	snapshotManager := &snapshot.SnapshotManager{}
	snapshotManager.SetFilesystem(fs)

	// Note: We can't fully test snapshot metadata export without a proper S3 client mock
	// that implements all required methods. This would require refactoring the S3 client
	// interface to be more testable.

	t.Logf("Backup completed successfully:")
	t.Logf(" Files scanned: %d", result.FilesScanned)
	t.Logf(" Bytes scanned: %d", result.BytesScanned)
	t.Logf(" Chunks created: %d", result.ChunksCreated)
	t.Logf(" Blobs created: %d", result.BlobsCreated)
	t.Logf(" S3 storage size: %d objects", mockS3.GetStorageSize())
}

// TestBackupAndVerify tests backing up files and verifying the blobs
func TestBackupAndVerify(t *testing.T) {
	// Initialize logger
	log.Initialize(log.Config{})

	// Create in-memory filesystem
	fs := afero.NewMemMapFs()

	// Create test files
	testContent := "This is a test file with some content that should be backed up"
	err := fs.MkdirAll("/data", 0755)
	require.NoError(t, err)
	err = afero.WriteFile(fs, "/data/test.txt", []byte(testContent), 0644)
	require.NoError(t, err)

	// Create mock S3 client
	mockS3 := NewMockS3Client()

	// Create test database
	ctx := context.Background()
	db, err := database.New(ctx, ":memory:")
	require.NoError(t, err)
	defer func() {
		if err := db.Close(); err != nil {
			t.Errorf("failed to close database: %v", err)
		}
	}()

	repos := database.NewRepositories(db)

	// Create scanner
	scanner := snapshot.NewScanner(snapshot.ScannerConfig{
		FS:               fs,
		ChunkSize:        int64(1024 * 16), // 16KB chunks
		Repositories:     repos,
		S3Client:         mockS3,
		MaxBlobSize:      int64(1024 * 1024), // 1MB blobs
		CompressionLevel: 3,
		AgeRecipients:    []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key
	})

	// Create a snapshot
	snapshotID := "test-snapshot-001"
	err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
		snapshot := &database.Snapshot{
			ID:             snapshotID,
			Hostname:       "test-host",
			VaultikVersion: "test-version",
			StartedAt:      time.Now(),
		}
		return repos.Snapshots.Create(ctx, tx, snapshot)
	})
	require.NoError(t, err)

	// Run the backup
	result, err := scanner.Scan(ctx, "/data", snapshotID)
	require.NoError(t, err)

	// Verify backup created blobs
	assert.Greater(t, result.BlobsCreated, 0, "Should create at least one blob")
	assert.Equal(t, mockS3.GetStorageSize(), result.BlobsCreated, "S3 should have the blobs")

	// Verify we can retrieve the blob from S3
	objects, err := mockS3.ListObjects(ctx, "blobs/")
	require.NoError(t, err)
	assert.Len(t, objects, result.BlobsCreated, "Should have correct number of blobs in S3")

	// Get the first blob and verify it exists
	if len(objects) > 0 {
		blobKey := objects[0].Key
		t.Logf("Verifying blob: %s", blobKey)

		// Get blob info
		blobInfo, err := mockS3.StatObject(ctx, blobKey)
		require.NoError(t, err)
		assert.Greater(t, blobInfo.Size, int64(0), "Blob should have content")

		// Get blob content
		reader, err := mockS3.GetObject(ctx, blobKey)
		require.NoError(t, err)
		defer func() { _ = reader.Close() }()

		// Verify blob data is encrypted (should not contain plaintext)
		blobData, err := io.ReadAll(reader)
		require.NoError(t, err)
		assert.NotContains(t, string(blobData), testContent, "Blob should be encrypted")
		assert.Greater(t, len(blobData), 0, "Blob should have data")
	}

	t.Logf("Backup and verify test completed successfully")
}
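The mock is deliberately minimal and always succeeds; if failure paths ever need coverage, it could be wrapped with error injection along these lines (a hypothetical sketch, not part of this commit):

	// FailingS3Client is a hypothetical wrapper that injects upload errors.
	type FailingS3Client struct {
		*MockS3Client
		failPut bool
	}

	func (f *FailingS3Client) PutObject(ctx context.Context, key string, r io.Reader) error {
		if f.failPut {
			return fmt.Errorf("injected PutObject failure for %s", key)
		}
		return f.MockS3Client.PutObject(ctx, key, r)
	}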
internal/vaultik/prune.go (new file, 169 lines)
@@ -0,0 +1,169 @@
package vaultik

import (
	"fmt"
	"strings"

	"git.eeqj.de/sneak/vaultik/internal/log"
	"github.com/dustin/go-humanize"
)

// PruneOptions contains options for the prune command
type PruneOptions struct {
	Force bool
}

// PruneBlobs removes unreferenced blobs from storage
func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
	log.Info("Starting prune operation")

	// Get all remote snapshots and their manifests
	allBlobsReferenced := make(map[string]bool)
	manifestCount := 0

	// List all snapshots in S3
	log.Info("Listing remote snapshots")
	objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)

	var snapshotIDs []string
	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing remote snapshots: %w", object.Err)
		}

		// Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
		parts := strings.Split(object.Key, "/")
		if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
			// Check if this is a directory by looking for trailing slash
			if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") {
				snapshotID := parts[1]
				// Only add unique snapshot IDs
				found := false
				for _, id := range snapshotIDs {
					if id == snapshotID {
						found = true
						break
					}
				}
				if !found {
					snapshotIDs = append(snapshotIDs, snapshotID)
				}
			}
		}
	}

	log.Info("Found manifests in remote storage", "count", len(snapshotIDs))

	// Download and parse each manifest to get referenced blobs
	for _, snapshotID := range snapshotIDs {
		log.Debug("Processing manifest", "snapshot_id", snapshotID)

		manifest, err := v.downloadManifest(snapshotID)
		if err != nil {
			log.Error("Failed to download manifest", "snapshot_id", snapshotID, "error", err)
			continue
		}

		// Add all blobs from this manifest to our referenced set
		for _, blob := range manifest.Blobs {
			allBlobsReferenced[blob.Hash] = true
		}
		manifestCount++
	}

	log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced))

	// List all blobs in S3
	log.Info("Listing all blobs in storage")
	allBlobs := make(map[string]int64) // hash -> size
	blobObjectCh := v.S3Client.ListObjectsStream(v.ctx, "blobs/", true)

	for object := range blobObjectCh {
		if object.Err != nil {
			return fmt.Errorf("listing blobs: %w", object.Err)
		}

		// Extract hash from path like blobs/ab/cd/abcdef123456...
		parts := strings.Split(object.Key, "/")
		if len(parts) == 4 && parts[0] == "blobs" {
			hash := parts[3]
			allBlobs[hash] = object.Size
		}
	}

	log.Info("Found blobs in storage", "count", len(allBlobs))

	// Find unreferenced blobs
	var unreferencedBlobs []string
	var totalSize int64
	for hash, size := range allBlobs {
		if !allBlobsReferenced[hash] {
			unreferencedBlobs = append(unreferencedBlobs, hash)
			totalSize += size
		}
	}

	if len(unreferencedBlobs) == 0 {
		log.Info("No unreferenced blobs found")
		fmt.Println("No unreferenced blobs to remove.")
		return nil
	}

	// Show what will be deleted
	log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize)))
	fmt.Printf("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize)))

	// Confirm unless --force is used
	if !opts.Force {
		fmt.Printf("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs))
		var confirm string
		if _, err := fmt.Scanln(&confirm); err != nil {
			// Treat EOF or error as "no"
			fmt.Println("Cancelled")
			return nil
		}
		if strings.ToLower(confirm) != "y" {
			fmt.Println("Cancelled")
			return nil
		}
	}

	// Delete unreferenced blobs
	log.Info("Deleting unreferenced blobs")
	deletedCount := 0
	deletedSize := int64(0)

	for i, hash := range unreferencedBlobs {
		blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash)

		if err := v.S3Client.RemoveObject(v.ctx, blobPath); err != nil {
			log.Error("Failed to delete blob", "hash", hash, "error", err)
			continue
		}

		deletedCount++
		deletedSize += allBlobs[hash]

		// Progress update every 100 blobs
		if (i+1)%100 == 0 || i == len(unreferencedBlobs)-1 {
			log.Info("Deletion progress",
				"deleted", i+1,
				"total", len(unreferencedBlobs),
				"percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(len(unreferencedBlobs))*100),
			)
		}
	}

	log.Info("Prune complete",
		"deleted_count", deletedCount,
		"deleted_size", humanize.Bytes(uint64(deletedSize)),
		"failed", len(unreferencedBlobs)-deletedCount,
	)

	fmt.Printf("\nDeleted %d blob(s) totaling %s\n", deletedCount, humanize.Bytes(uint64(deletedSize)))
	if deletedCount < len(unreferencedBlobs) {
		fmt.Printf("Failed to delete %d blob(s)\n", len(unreferencedBlobs)-deletedCount)
	}

	return nil
}
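The two-level fan-out used when addressing blobs follows directly from the fmt.Sprintf calls above; a hypothetical helper (not present in the commit) makes the layout explicit:

	// blobKey returns the S3 object key for a blob hash, e.g.
	// "abcdef1234..." -> "blobs/ab/cd/abcdef1234...".
	func blobKey(hash string) string {
		return fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash)
	}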
internal/vaultik/snapshot.go (new file, 701 lines)
@@ -0,0 +1,701 @@
package vaultik

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"text/tabwriter"
	"time"

	"git.eeqj.de/sneak/vaultik/internal/database"
	"git.eeqj.de/sneak/vaultik/internal/log"
	"git.eeqj.de/sneak/vaultik/internal/snapshot"
	"github.com/dustin/go-humanize"
)

// SnapshotCreateOptions contains options for the snapshot create command
type SnapshotCreateOptions struct {
	Daemon bool
	Cron   bool
	Prune  bool
}

// CreateSnapshot executes the snapshot creation operation
func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
	snapshotStartTime := time.Now()

	log.Info("Starting snapshot creation",
		"version", v.Globals.Version,
		"commit", v.Globals.Commit,
		"index_path", v.Config.IndexPath,
	)

	// Clean up incomplete snapshots FIRST, before any scanning
	// This is critical for data safety - see CleanupIncompleteSnapshots for details
	hostname := v.Config.Hostname
	if hostname == "" {
		hostname, _ = os.Hostname()
	}

	// CRITICAL: This MUST succeed. If we fail to clean up incomplete snapshots,
	// the deduplication logic will think files from the incomplete snapshot were
	// already backed up and skip them, resulting in data loss.
	if err := v.SnapshotManager.CleanupIncompleteSnapshots(v.ctx, hostname); err != nil {
		return fmt.Errorf("cleanup incomplete snapshots: %w", err)
	}

	if opts.Daemon {
		log.Info("Running in daemon mode")
		// TODO: Implement daemon mode with inotify
		return fmt.Errorf("daemon mode not yet implemented")
	}

	// Resolve source directories to absolute paths
	resolvedDirs := make([]string, 0, len(v.Config.SourceDirs))
	for _, dir := range v.Config.SourceDirs {
		absPath, err := filepath.Abs(dir)
		if err != nil {
			return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
		}

		// Resolve symlinks
		resolvedPath, err := filepath.EvalSymlinks(absPath)
		if err != nil {
			// If the path doesn't exist yet, use the absolute path
			if os.IsNotExist(err) {
				resolvedPath = absPath
			} else {
				return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err)
			}
		}

		resolvedDirs = append(resolvedDirs, resolvedPath)
	}

	// Create scanner with progress enabled (unless in cron mode)
	scanner := v.ScannerFactory(snapshot.ScannerParams{
		EnableProgress: !opts.Cron,
		Fs:             v.Fs,
	})

	// Statistics tracking
	totalFiles := 0
	totalBytes := int64(0)
	totalChunks := 0
	totalBlobs := 0
	totalBytesSkipped := int64(0)
	totalFilesSkipped := 0
	totalFilesDeleted := 0
	totalBytesDeleted := int64(0)
	totalBytesUploaded := int64(0)
	totalBlobsUploaded := 0
	uploadDuration := time.Duration(0)

	// Create a new snapshot at the beginning
	snapshotID, err := v.SnapshotManager.CreateSnapshot(v.ctx, hostname, v.Globals.Version, v.Globals.Commit)
	if err != nil {
		return fmt.Errorf("creating snapshot: %w", err)
	}
	log.Info("Beginning snapshot", "snapshot_id", snapshotID)
	_, _ = fmt.Fprintf(v.Stdout, "Beginning snapshot: %s\n", snapshotID)

	for i, dir := range resolvedDirs {
		// Check if context is cancelled
		select {
		case <-v.ctx.Done():
			log.Info("Snapshot creation cancelled")
			return v.ctx.Err()
		default:
		}

		log.Info("Scanning directory", "path", dir)
		_, _ = fmt.Fprintf(v.Stdout, "Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir)
		result, err := scanner.Scan(v.ctx, dir, snapshotID)
		if err != nil {
			return fmt.Errorf("failed to scan %s: %w", dir, err)
		}

		totalFiles += result.FilesScanned
		totalBytes += result.BytesScanned
		totalChunks += result.ChunksCreated
		totalBlobs += result.BlobsCreated
		totalFilesSkipped += result.FilesSkipped
		totalBytesSkipped += result.BytesSkipped
		totalFilesDeleted += result.FilesDeleted
		totalBytesDeleted += result.BytesDeleted

		log.Info("Directory scan complete",
			"path", dir,
			"files", result.FilesScanned,
			"files_skipped", result.FilesSkipped,
			"bytes", result.BytesScanned,
			"bytes_skipped", result.BytesSkipped,
			"chunks", result.ChunksCreated,
			"blobs", result.BlobsCreated,
			"duration", result.EndTime.Sub(result.StartTime))

		// Remove per-directory summary - the scanner already prints its own summary
	}

	// Get upload statistics from scanner progress if available
	if s := scanner.GetProgress(); s != nil {
		stats := s.GetStats()
		totalBytesUploaded = stats.BytesUploaded.Load()
		totalBlobsUploaded = int(stats.BlobsUploaded.Load())
		uploadDuration = time.Duration(stats.UploadDurationMs.Load()) * time.Millisecond
	}

	// Update snapshot statistics with extended fields
	extStats := snapshot.ExtendedBackupStats{
		BackupStats: snapshot.BackupStats{
			FilesScanned:  totalFiles,
			BytesScanned:  totalBytes,
			ChunksCreated: totalChunks,
			BlobsCreated:  totalBlobs,
			BytesUploaded: totalBytesUploaded,
		},
		BlobUncompressedSize: 0, // Will be set from database query below
		CompressionLevel:     v.Config.CompressionLevel,
		UploadDurationMs:     uploadDuration.Milliseconds(),
	}

	if err := v.SnapshotManager.UpdateSnapshotStatsExtended(v.ctx, snapshotID, extStats); err != nil {
		return fmt.Errorf("updating snapshot stats: %w", err)
	}

	// Mark snapshot as complete
	if err := v.SnapshotManager.CompleteSnapshot(v.ctx, snapshotID); err != nil {
		return fmt.Errorf("completing snapshot: %w", err)
	}

	// Export snapshot metadata
	// Export snapshot metadata without closing the database
	// The export function should handle its own database connection
	if err := v.SnapshotManager.ExportSnapshotMetadata(v.ctx, v.Config.IndexPath, snapshotID); err != nil {
		return fmt.Errorf("exporting snapshot metadata: %w", err)
	}

	// Calculate final statistics
	snapshotDuration := time.Since(snapshotStartTime)
	totalFilesChanged := totalFiles - totalFilesSkipped
	totalBytesChanged := totalBytes
	totalBytesAll := totalBytes + totalBytesSkipped

	// Calculate upload speed
	var avgUploadSpeed string
	if totalBytesUploaded > 0 && uploadDuration > 0 {
		bytesPerSec := float64(totalBytesUploaded) / uploadDuration.Seconds()
		bitsPerSec := bytesPerSec * 8
		if bitsPerSec >= 1e9 {
			avgUploadSpeed = fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9)
		} else if bitsPerSec >= 1e6 {
			avgUploadSpeed = fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6)
		} else if bitsPerSec >= 1e3 {
			avgUploadSpeed = fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3)
		} else {
			avgUploadSpeed = fmt.Sprintf("%.0f bit/s", bitsPerSec)
		}
	} else {
		avgUploadSpeed = "N/A"
	}

	// Get total blob sizes from database
	totalBlobSizeCompressed := int64(0)
	totalBlobSizeUncompressed := int64(0)
	if blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID); err == nil {
		for _, hash := range blobHashes {
			if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil {
				totalBlobSizeCompressed += blob.CompressedSize
				totalBlobSizeUncompressed += blob.UncompressedSize
			}
		}
	}

	// Calculate compression ratio
	var compressionRatio float64
	if totalBlobSizeUncompressed > 0 {
		compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed)
	} else {
		compressionRatio = 1.0
	}

	// Print comprehensive summary
	_, _ = fmt.Fprintf(v.Stdout, "=== Snapshot Complete ===\n")
	_, _ = fmt.Fprintf(v.Stdout, "ID: %s\n", snapshotID)
	_, _ = fmt.Fprintf(v.Stdout, "Files: %s examined, %s to process, %s unchanged",
		formatNumber(totalFiles),
		formatNumber(totalFilesChanged),
		formatNumber(totalFilesSkipped))
	if totalFilesDeleted > 0 {
		_, _ = fmt.Fprintf(v.Stdout, ", %s deleted", formatNumber(totalFilesDeleted))
	}
	_, _ = fmt.Fprintln(v.Stdout)
	_, _ = fmt.Fprintf(v.Stdout, "Data: %s total (%s to process)",
		humanize.Bytes(uint64(totalBytesAll)),
		humanize.Bytes(uint64(totalBytesChanged)))
	if totalBytesDeleted > 0 {
		_, _ = fmt.Fprintf(v.Stdout, ", %s deleted", humanize.Bytes(uint64(totalBytesDeleted)))
	}
	_, _ = fmt.Fprintln(v.Stdout)
	if totalBlobsUploaded > 0 {
		_, _ = fmt.Fprintf(v.Stdout, "Storage: %s compressed from %s (%.2fx)\n",
			humanize.Bytes(uint64(totalBlobSizeCompressed)),
			humanize.Bytes(uint64(totalBlobSizeUncompressed)),
			compressionRatio)
		_, _ = fmt.Fprintf(v.Stdout, "Upload: %d blobs, %s in %s (%s)\n",
			totalBlobsUploaded,
			humanize.Bytes(uint64(totalBytesUploaded)),
			formatDuration(uploadDuration),
			avgUploadSpeed)
	}
	_, _ = fmt.Fprintf(v.Stdout, "Duration: %s\n", formatDuration(snapshotDuration))

	if opts.Prune {
		log.Info("Pruning enabled - will delete old snapshots after snapshot")
		// TODO: Implement pruning
	}

	return nil
}

// ListSnapshots lists all snapshots
func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
	// Get all remote snapshots
	remoteSnapshots := make(map[string]bool)
	objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)

	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing remote snapshots: %w", object.Err)
		}

		// Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
		parts := strings.Split(object.Key, "/")
		if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
			remoteSnapshots[parts[1]] = true
		}
	}

	// Get all local snapshots
	localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000)
	if err != nil {
		return fmt.Errorf("listing local snapshots: %w", err)
	}

	// Build a map of local snapshots for quick lookup
	localSnapshotMap := make(map[string]*database.Snapshot)
	for _, s := range localSnapshots {
		localSnapshotMap[s.ID] = s
	}

	// Remove local snapshots that don't exist remotely
	for _, snapshot := range localSnapshots {
		if !remoteSnapshots[snapshot.ID] {
			log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID)

			// Delete related records first to avoid foreign key constraints
			if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err)
			}
			if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err)
			}
			if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err)
			}

			// Now delete the snapshot itself
			if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err)
			} else {
				log.Info("Deleted local snapshot not found in remote", "snapshot_id", snapshot.ID)
				delete(localSnapshotMap, snapshot.ID)
			}
		}
	}

	// Build final snapshot list
	snapshots := make([]SnapshotInfo, 0, len(remoteSnapshots))

	for snapshotID := range remoteSnapshots {
		// Check if we have this snapshot locally
		if localSnap, exists := localSnapshotMap[snapshotID]; exists && localSnap.CompletedAt != nil {
			// Get total compressed size of all blobs referenced by this snapshot
			totalSize, err := v.Repositories.Snapshots.GetSnapshotTotalCompressedSize(v.ctx, snapshotID)
			if err != nil {
				log.Warn("Failed to get total compressed size", "id", snapshotID, "error", err)
				// Fall back to stored blob size
				totalSize = localSnap.BlobSize
			}

			snapshots = append(snapshots, SnapshotInfo{
				ID:             localSnap.ID,
				Timestamp:      localSnap.StartedAt,
				CompressedSize: totalSize,
			})
		} else {
			// Remote snapshot not in local DB - fetch manifest to get size
			timestamp, err := parseSnapshotTimestamp(snapshotID)
			if err != nil {
				log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err)
				continue
			}

			// Try to download manifest to get size
			totalSize, err := v.getManifestSize(snapshotID)
			if err != nil {
				return fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err)
			}

			snapshots = append(snapshots, SnapshotInfo{
				ID:             snapshotID,
				Timestamp:      timestamp,
				CompressedSize: totalSize,
			})
		}
	}

	// Sort by timestamp (newest first)
	sort.Slice(snapshots, func(i, j int) bool {
		return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
	})

	if jsonOutput {
		// JSON output
		encoder := json.NewEncoder(os.Stdout)
		encoder.SetIndent("", " ")
		return encoder.Encode(snapshots)
	}

	// Table output
	w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
	if _, err := fmt.Fprintln(w, "SNAPSHOT ID\tTIMESTAMP\tCOMPRESSED SIZE"); err != nil {
		return err
	}
	if _, err := fmt.Fprintln(w, "───────────\t─────────\t───────────────"); err != nil {
		return err
	}

	for _, snap := range snapshots {
		if _, err := fmt.Fprintf(w, "%s\t%s\t%s\n",
			snap.ID,
			snap.Timestamp.Format("2006-01-02 15:04:05"),
			formatBytes(snap.CompressedSize)); err != nil {
			return err
		}
	}

	return w.Flush()
}

// PurgeSnapshots removes old snapshots based on criteria
func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool) error {
	// Sync with remote first
	if err := v.syncWithRemote(); err != nil {
		return fmt.Errorf("syncing with remote: %w", err)
	}

	// Get snapshots from local database
	dbSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000)
	if err != nil {
		return fmt.Errorf("listing snapshots: %w", err)
	}

	// Convert to SnapshotInfo format, only including completed snapshots
	snapshots := make([]SnapshotInfo, 0, len(dbSnapshots))
	for _, s := range dbSnapshots {
		if s.CompletedAt != nil {
			snapshots = append(snapshots, SnapshotInfo{
				ID:             s.ID,
				Timestamp:      s.StartedAt,
				CompressedSize: s.BlobSize,
			})
		}
	}

	// Sort by timestamp (newest first)
	sort.Slice(snapshots, func(i, j int) bool {
		return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
	})

	var toDelete []SnapshotInfo

	if keepLatest {
		// Keep only the most recent snapshot
		if len(snapshots) > 1 {
			toDelete = snapshots[1:]
		}
	} else if olderThan != "" {
		// Parse duration
		duration, err := parseDuration(olderThan)
		if err != nil {
			return fmt.Errorf("invalid duration: %w", err)
		}

		cutoff := time.Now().UTC().Add(-duration)
		for _, snap := range snapshots {
			if snap.Timestamp.Before(cutoff) {
				toDelete = append(toDelete, snap)
			}
		}
	}

	if len(toDelete) == 0 {
		fmt.Println("No snapshots to delete")
		return nil
	}

	// Show what will be deleted
	fmt.Printf("The following snapshots will be deleted:\n\n")
	for _, snap := range toDelete {
		fmt.Printf(" %s (%s, %s)\n",
			snap.ID,
			snap.Timestamp.Format("2006-01-02 15:04:05"),
			formatBytes(snap.CompressedSize))
	}

	// Confirm unless --force is used
	if !force {
		fmt.Printf("\nDelete %d snapshot(s)? [y/N] ", len(toDelete))
		var confirm string
		if _, err := fmt.Scanln(&confirm); err != nil {
			// Treat EOF or error as "no"
			fmt.Println("Cancelled")
			return nil
		}
		if strings.ToLower(confirm) != "y" {
			fmt.Println("Cancelled")
			return nil
		}
	} else {
		fmt.Printf("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete))
	}

	// Delete snapshots
	for _, snap := range toDelete {
		log.Info("Deleting snapshot", "id", snap.ID)
		if err := v.deleteSnapshot(snap.ID); err != nil {
			return fmt.Errorf("deleting snapshot %s: %w", snap.ID, err)
		}
	}

	fmt.Printf("Deleted %d snapshot(s)\n", len(toDelete))

	// Note: Run 'vaultik prune' separately to clean up unreferenced blobs
	fmt.Println("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.")

	return nil
}

// VerifySnapshot checks snapshot integrity
func (v *Vaultik) VerifySnapshot(snapshotID string, deep bool) error {
	// Parse snapshot ID to extract timestamp
	parts := strings.Split(snapshotID, "-")
	var snapshotTime time.Time
	if len(parts) >= 3 {
		// Format: hostname-YYYYMMDD-HHMMSSZ
		dateStr := parts[len(parts)-2]
		timeStr := parts[len(parts)-1]
		if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") {
			timeStr = timeStr[:6] // Remove Z
			timestamp, err := time.Parse("20060102150405", dateStr+timeStr)
			if err == nil {
				snapshotTime = timestamp
			}
		}
	}

	fmt.Printf("Verifying snapshot %s\n", snapshotID)
	if !snapshotTime.IsZero() {
		fmt.Printf("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST"))
	}
	fmt.Println()

	// Download and parse manifest
	manifest, err := v.downloadManifest(snapshotID)
	if err != nil {
		return fmt.Errorf("downloading manifest: %w", err)
	}

	fmt.Printf("Snapshot information:\n")
	fmt.Printf(" Blob count: %d\n", manifest.BlobCount)
	fmt.Printf(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize)))
	if manifest.Timestamp != "" {
		if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil {
			fmt.Printf(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST"))
		}
	}
	fmt.Println()

	// Check each blob exists
	fmt.Printf("Checking blob existence...\n")
	missing := 0
	verified := 0
	missingSize := int64(0)

	for _, blob := range manifest.Blobs {
		blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash)

		if deep {
			// Download and verify hash
			// TODO: Implement deep verification
			fmt.Printf("Deep verification not yet implemented\n")
			return nil
		} else {
			// Just check existence
			_, err := v.S3Client.StatObject(v.ctx, blobPath)
			if err != nil {
				fmt.Printf(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize)))
				missing++
				missingSize += blob.CompressedSize
			} else {
				verified++
			}
		}
	}

	fmt.Printf("\nVerification complete:\n")
	fmt.Printf(" Verified: %d blobs (%s)\n", verified,
		humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize)))
	if missing > 0 {
		fmt.Printf(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize)))
	} else {
		fmt.Printf(" Missing: 0 blobs\n")
	}
	fmt.Printf(" Status: ")
	if missing > 0 {
		fmt.Printf("FAILED - %d blobs are missing\n", missing)
		return fmt.Errorf("%d blobs are missing", missing)
	} else {
		fmt.Printf("OK - All blobs verified\n")
	}

	return nil
}

// Helper methods that were previously on SnapshotApp

func (v *Vaultik) getManifestSize(snapshotID string) (int64, error) {
	manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)

	reader, err := v.S3Client.GetObject(v.ctx, manifestPath)
	if err != nil {
		return 0, fmt.Errorf("downloading manifest: %w", err)
	}
	defer func() { _ = reader.Close() }()

	manifest, err := snapshot.DecodeManifest(reader)
	if err != nil {
		return 0, fmt.Errorf("decoding manifest: %w", err)
	}

	return manifest.TotalCompressedSize, nil
}

func (v *Vaultik) downloadManifest(snapshotID string) (*snapshot.Manifest, error) {
	manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)

	reader, err := v.S3Client.GetObject(v.ctx, manifestPath)
	if err != nil {
		return nil, err
	}
	defer func() { _ = reader.Close() }()

	manifest, err := snapshot.DecodeManifest(reader)
	if err != nil {
		return nil, fmt.Errorf("decoding manifest: %w", err)
	}

	return manifest, nil
}

func (v *Vaultik) deleteSnapshot(snapshotID string) error {
	// First, delete from S3
	// List all objects under metadata/{snapshotID}/
	prefix := fmt.Sprintf("metadata/%s/", snapshotID)
	objectCh := v.S3Client.ListObjectsStream(v.ctx, prefix, true)

	var objectsToDelete []string
	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing objects: %w", object.Err)
		}
		objectsToDelete = append(objectsToDelete, object.Key)
	}

	// Delete all objects
	for _, key := range objectsToDelete {
		if err := v.S3Client.RemoveObject(v.ctx, key); err != nil {
			return fmt.Errorf("removing %s: %w", key, err)
		}
	}

	// Then, delete from local database
	// Delete related records first to avoid foreign key constraints
	if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotID); err != nil {
		log.Error("Failed to delete snapshot files", "snapshot_id", snapshotID, "error", err)
	}
	if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotID); err != nil {
		log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshotID, "error", err)
	}
	if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotID); err != nil {
		log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshotID, "error", err)
	}

	// Now delete the snapshot itself
	if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotID); err != nil {
		return fmt.Errorf("deleting snapshot from database: %w", err)
	}

	return nil
}

func (v *Vaultik) syncWithRemote() error {
	log.Info("Syncing with remote snapshots")

	// Get all remote snapshot IDs
	remoteSnapshots := make(map[string]bool)
	objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)

	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing remote snapshots: %w", object.Err)
		}

		// Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
		parts := strings.Split(object.Key, "/")
		if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
			remoteSnapshots[parts[1]] = true
		}
	}

	log.Debug("Found remote snapshots", "count", len(remoteSnapshots))

	// Get all local snapshots (use a high limit to get all)
	localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000)
	if err != nil {
		return fmt.Errorf("listing local snapshots: %w", err)
	}

	// Remove local snapshots that don't exist remotely
	removedCount := 0
	for _, snapshot := range localSnapshots {
		if !remoteSnapshots[snapshot.ID] {
			log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID)
			if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil {
				log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err)
			} else {
				removedCount++
			}
		}
	}

	if removedCount > 0 {
		log.Info("Removed local snapshots not found in remote", "count", removedCount)
	}

	return nil
}
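The snapshot ID convention assumed throughout this file (hostname-YYYYMMDD-HHMMSSZ) is produced by SnapshotManager.CreateSnapshot, which is not part of this diff; a hypothetical sketch of the format that parseSnapshotTimestamp and VerifySnapshot expect:

	// makeSnapshotID illustrates the hostname-YYYYMMDD-HHMMSSZ ID format.
	// Hypothetical helper for illustration only.
	func makeSnapshotID(hostname string, now time.Time) string {
		return hostname + "-" + now.UTC().Format("20060102-150405") + "Z"
	}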
124
internal/vaultik/vaultik.go
Normal file
124
internal/vaultik/vaultik.go
Normal file
@ -0,0 +1,124 @@
|
||||
package vaultik
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"git.eeqj.de/sneak/vaultik/internal/config"
|
||||
"git.eeqj.de/sneak/vaultik/internal/crypto"
|
||||
"git.eeqj.de/sneak/vaultik/internal/database"
|
||||
"git.eeqj.de/sneak/vaultik/internal/globals"
|
||||
"git.eeqj.de/sneak/vaultik/internal/s3"
|
||||
"git.eeqj.de/sneak/vaultik/internal/snapshot"
|
||||
"github.com/spf13/afero"
|
||||
"go.uber.org/fx"
|
||||
)
|
||||
|
||||
// Vaultik contains all dependencies needed for vaultik operations
|
||||
type Vaultik struct {
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
DB *database.DB
|
||||
Repositories *database.Repositories
|
||||
S3Client *s3.Client
|
||||
ScannerFactory snapshot.ScannerFactory
|
||||
SnapshotManager *snapshot.SnapshotManager
|
||||
Shutdowner fx.Shutdowner
|
||||
Fs afero.Fs
|
||||
|
||||
// Context management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// IO
|
||||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
Stdin io.Reader
|
||||
}
|
||||
|
||||
// VaultikParams contains all parameters for New that can be provided by fx
|
||||
type VaultikParams struct {
|
||||
fx.In
|
||||
|
||||
Globals *globals.Globals
|
||||
Config *config.Config
|
||||
DB *database.DB
|
||||
Repositories *database.Repositories
|
||||
S3Client *s3.Client
|
||||
ScannerFactory snapshot.ScannerFactory
|
||||
SnapshotManager *snapshot.SnapshotManager
|
||||
Shutdowner fx.Shutdowner
|
||||
Fs afero.Fs `optional:"true"`
|
||||
}
|
||||
|
||||
// New creates a new Vaultik instance with proper context management
|
||||
// It automatically includes crypto capabilities if age_secret_key is configured
|
||||
func New(params VaultikParams) *Vaultik {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Use provided filesystem or default to OS filesystem
|
||||
fs := params.Fs
|
||||
if fs == nil {
|
||||
fs = afero.NewOsFs()
|
||||
}
|
||||
|
||||
// Set filesystem on SnapshotManager
|
||||
params.SnapshotManager.SetFilesystem(fs)
|
||||
|
||||
return &Vaultik{
|
||||
Globals: params.Globals,
|
||||
Config: params.Config,
|
||||
DB: params.DB,
|
||||
Repositories: params.Repositories,
|
||||
S3Client: params.S3Client,
|
||||
ScannerFactory: params.ScannerFactory,
|
||||
SnapshotManager: params.SnapshotManager,
|
||||
Shutdowner: params.Shutdowner,
|
||||
Fs: fs,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
Stdout: os.Stdout,
|
||||
Stderr: os.Stderr,
|
||||
Stdin: os.Stdin,
|
||||
}
|
||||
}

// Context returns the Vaultik's context
func (v *Vaultik) Context() context.Context {
	return v.ctx
}

// Cancel cancels the Vaultik's context
func (v *Vaultik) Cancel() {
	v.cancel()
}

// CanDecrypt returns true if this Vaultik instance has decryption capabilities
func (v *Vaultik) CanDecrypt() bool {
	return v.Config.AgeSecretKey != ""
}

// GetEncryptor creates a new Encryptor instance based on the configured age recipients
// Returns an error if no recipients are configured
func (v *Vaultik) GetEncryptor() (*crypto.Encryptor, error) {
	if len(v.Config.AgeRecipients) == 0 {
		return nil, fmt.Errorf("no age recipients configured")
	}
	return crypto.NewEncryptor(v.Config.AgeRecipients)
}

// GetDecryptor creates a new Decryptor instance based on the configured age secret key
// Returns an error if no secret key is configured
func (v *Vaultik) GetDecryptor() (*crypto.Decryptor, error) {
	if v.Config.AgeSecretKey == "" {
		return nil, fmt.Errorf("no age secret key configured")
	}
	return crypto.NewDecryptor(v.Config.AgeSecretKey)
}

// GetFilesystem returns the filesystem instance used by Vaultik
func (v *Vaultik) GetFilesystem() afero.Fs {
	return v.Fs
}
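
Because Fs is marked optional in VaultikParams, the integration tests added in this commit can inject afero's in-memory filesystem instead of touching disk. A minimal sketch of that wiring under assumptions: the option name and test package are illustrative, and the rest of the fx provider graph is not shown in this diff.

package vaultik_test

import (
	"github.com/spf13/afero"
	"go.uber.org/fx"
)

// testFsOption satisfies the optional Fs parameter in VaultikParams so that
// New() (and the SnapshotManager it configures) never writes to the real
// filesystem. It would be passed to the test fx app alongside the
// application's other providers, which are not part of this diff.
var testFsOption = fx.Provide(func() afero.Fs {
	return afero.NewMemMapFs()
})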

396
internal/vaultik/verify.go
Normal file
@ -0,0 +1,396 @@
package vaultik

import (
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"fmt"
	"io"
	"os"

	"git.eeqj.de/sneak/vaultik/internal/log"
	"git.eeqj.de/sneak/vaultik/internal/snapshot"
	"github.com/dustin/go-humanize"
	"github.com/klauspost/compress/zstd"
	_ "github.com/mattn/go-sqlite3"
)

// VerifyOptions contains options for the verify command
type VerifyOptions struct {
	Deep bool
}

// RunDeepVerify executes the snapshot verification operation: shallow checks
// by default, plus full blob content verification when opts.Deep is set
func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
	// Check for decryption capability
	if !v.CanDecrypt() {
		return fmt.Errorf("age_secret_key missing from config - required for deep verification")
	}

	log.Info("Starting snapshot verification",
		"snapshot_id", snapshotID,
		"mode", map[bool]string{true: "deep", false: "shallow"}[opts.Deep],
	)

	// Step 1: Download manifest
	manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
	log.Info("Downloading manifest", "path", manifestPath)

	manifestReader, err := v.S3Client.GetObject(v.ctx, manifestPath)
	if err != nil {
		return fmt.Errorf("failed to download manifest: %w", err)
	}
	defer func() { _ = manifestReader.Close() }()

	// Decompress manifest
	manifest, err := snapshot.DecodeManifest(manifestReader)
	if err != nil {
		return fmt.Errorf("failed to decode manifest: %w", err)
	}

	log.Info("Manifest loaded",
		"blob_count", manifest.BlobCount,
		"total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize)),
	)

	// Step 2: Download and decrypt database
	dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
	log.Info("Downloading encrypted database", "path", dbPath)

	dbReader, err := v.S3Client.GetObject(v.ctx, dbPath)
	if err != nil {
		return fmt.Errorf("failed to download database: %w", err)
	}
	defer func() { _ = dbReader.Close() }()

	// Decrypt and decompress database
	tempDB, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey)
	if err != nil {
		return fmt.Errorf("failed to decrypt database: %w", err)
	}
	defer func() {
		if tempDB != nil {
			_ = tempDB.Close()
		}
	}()

	// Step 3: Compare blob lists
	if err := v.verifyBlobLists(snapshotID, manifest, tempDB.DB); err != nil {
		return err
	}

	// Step 4: Verify blob existence
	if err := v.verifyBlobExistence(manifest); err != nil {
		return err
	}

	// Step 5: Deep verification if requested
	if opts.Deep {
		if err := v.performDeepVerification(manifest, tempDB.DB); err != nil {
			return err
		}
	}

	log.Info("✓ Verification completed successfully",
		"snapshot_id", snapshotID,
		"mode", map[bool]string{true: "deep", false: "shallow"}[opts.Deep],
	)

	return nil
}
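
A hedged usage sketch; the snapshot ID below is illustrative and only follows the hostname-YYYYMMDD-HHMMSSZ naming convention used elsewhere in this codebase.

// Shallow mode checks manifest/database consistency and blob existence;
// Deep: true additionally downloads, decrypts, and re-hashes every chunk.
opts := &VerifyOptions{Deep: true}
if err := v.RunDeepVerify("myhost-20250101-120000Z", opts); err != nil {
	log.Fatal("verification failed", "error", err)
}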

// tempDB wraps sql.DB with cleanup
type tempDB struct {
	*sql.DB
	tempPath string
}

func (t *tempDB) Close() error {
	err := t.DB.Close()
	_ = os.Remove(t.tempPath)
	return err
}

// decryptAndLoadDatabase decrypts and loads the database from the encrypted stream
func (v *Vaultik) decryptAndLoadDatabase(reader io.ReadCloser, secretKey string) (*tempDB, error) {
	// Get decryptor
	decryptor, err := v.GetDecryptor()
	if err != nil {
		return nil, fmt.Errorf("failed to get decryptor: %w", err)
	}

	// Decrypt the stream
	decryptedReader, err := decryptor.DecryptStream(reader)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt database: %w", err)
	}

	// Decompress the database
	decompressor, err := zstd.NewReader(decryptedReader)
	if err != nil {
		return nil, fmt.Errorf("failed to create decompressor: %w", err)
	}
	defer decompressor.Close()

	// Create temporary file for database
	tempFile, err := os.CreateTemp("", "vaultik-verify-*.db")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp file: %w", err)
	}
	tempPath := tempFile.Name()

	// Copy decompressed data to temp file
	if _, err := io.Copy(tempFile, decompressor); err != nil {
		_ = tempFile.Close()
		_ = os.Remove(tempPath)
		return nil, fmt.Errorf("failed to write database: %w", err)
	}

	// Close temp file before opening with sqlite
	if err := tempFile.Close(); err != nil {
		_ = os.Remove(tempPath)
		return nil, fmt.Errorf("failed to close temp file: %w", err)
	}

	// Open the database
	db, err := sql.Open("sqlite3", tempPath)
	if err != nil {
		_ = os.Remove(tempPath)
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	return &tempDB{
		DB:       db,
		tempPath: tempPath,
	}, nil
}

// verifyBlobLists compares the blob lists between manifest and database
func (v *Vaultik) verifyBlobLists(snapshotID string, manifest *snapshot.Manifest, db *sql.DB) error {
	log.Info("Verifying blob lists match between manifest and database")

	// Get blobs from database
	query := `
		SELECT b.blob_hash, b.compressed_size
		FROM snapshot_blobs sb
		JOIN blobs b ON sb.blob_hash = b.blob_hash
		WHERE sb.snapshot_id = ?
		ORDER BY b.blob_hash
	`
	rows, err := db.QueryContext(v.ctx, query, snapshotID)
	if err != nil {
		return fmt.Errorf("failed to query snapshot blobs: %w", err)
	}
	defer func() { _ = rows.Close() }()

	// Build map of database blobs
	dbBlobs := make(map[string]int64)
	for rows.Next() {
		var hash string
		var size int64
		if err := rows.Scan(&hash, &size); err != nil {
			return fmt.Errorf("failed to scan blob row: %w", err)
		}
		dbBlobs[hash] = size
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating snapshot blobs: %w", err)
	}

	// Build map of manifest blobs
	manifestBlobs := make(map[string]int64)
	for _, blob := range manifest.Blobs {
		manifestBlobs[blob.Hash] = blob.CompressedSize
	}

	// Compare counts
	if len(dbBlobs) != len(manifestBlobs) {
		return fmt.Errorf("blob count mismatch: database has %d blobs, manifest has %d blobs",
			len(dbBlobs), len(manifestBlobs))
	}

	// Check each blob exists in both
	for hash, dbSize := range dbBlobs {
		manifestSize, exists := manifestBlobs[hash]
		if !exists {
			return fmt.Errorf("blob %s exists in database but not in manifest", hash)
		}
		if dbSize != manifestSize {
			return fmt.Errorf("blob %s size mismatch: database has %d bytes, manifest has %d bytes",
				hash, dbSize, manifestSize)
		}
	}

	for hash := range manifestBlobs {
		if _, exists := dbBlobs[hash]; !exists {
			return fmt.Errorf("blob %s exists in manifest but not in database", hash)
		}
	}

	log.Info("✓ Blob lists match", "blob_count", len(dbBlobs))
	return nil
}

// verifyBlobExistence checks that all blobs exist in S3
func (v *Vaultik) verifyBlobExistence(manifest *snapshot.Manifest) error {
	log.Info("Verifying blob existence in S3", "blob_count", len(manifest.Blobs))

	for i, blob := range manifest.Blobs {
		// Construct blob path
		blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash)

		// Check blob exists with HeadObject
		stat, err := v.S3Client.StatObject(v.ctx, blobPath)
		if err != nil {
			return fmt.Errorf("blob %s missing from S3: %w", blob.Hash, err)
		}

		// Verify size matches
		if stat.Size != blob.CompressedSize {
			return fmt.Errorf("blob %s size mismatch: S3 has %d bytes, manifest has %d bytes",
				blob.Hash, stat.Size, blob.CompressedSize)
		}

		// Progress update every 100 blobs
		if (i+1)%100 == 0 || i == len(manifest.Blobs)-1 {
			log.Info("Blob existence check progress",
				"checked", i+1,
				"total", len(manifest.Blobs),
				"percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(len(manifest.Blobs))*100),
			)
		}
	}

	log.Info("✓ All blobs exist in S3")
	return nil
}
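
The blobs/<aa>/<bb>/<hash> key layout used above (and again in verifyBlob) fans blobs out by the first two and next two hex characters of the hash. A small sketch of the convention; the helper name is hypothetical, since the verify code inlines the formatting:

// blobKey returns the S3 object key for a blob hash, e.g.
// blobKey("abcd1234...") -> "blobs/ab/cd/abcd1234...".
func blobKey(hash string) string {
	return fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash)
}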

// performDeepVerification downloads and verifies the content of each blob
func (v *Vaultik) performDeepVerification(manifest *snapshot.Manifest, db *sql.DB) error {
	log.Info("Starting deep verification - downloading and verifying all blobs")

	totalBytes := int64(0)
	for i, blobInfo := range manifest.Blobs {
		// Verify individual blob
		if err := v.verifyBlob(blobInfo, db); err != nil {
			return fmt.Errorf("blob %s verification failed: %w", blobInfo.Hash, err)
		}

		totalBytes += blobInfo.CompressedSize

		// Progress update
		log.Info("Deep verification progress",
			"blob", fmt.Sprintf("%d/%d", i+1, len(manifest.Blobs)),
			"total_downloaded", humanize.Bytes(uint64(totalBytes)),
			"percent", fmt.Sprintf("%.1f%%", float64(i+1)/float64(len(manifest.Blobs))*100),
		)
	}

	log.Info("✓ Deep verification completed successfully",
		"blobs_verified", len(manifest.Blobs),
		"total_size", humanize.Bytes(uint64(totalBytes)),
	)

	return nil
}

// verifyBlob downloads and verifies a single blob
func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error {
	// Download blob
	blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobInfo.Hash[:2], blobInfo.Hash[2:4], blobInfo.Hash)
	reader, err := v.S3Client.GetObject(v.ctx, blobPath)
	if err != nil {
		return fmt.Errorf("failed to download: %w", err)
	}
	defer func() { _ = reader.Close() }()

	// Get decryptor
	decryptor, err := v.GetDecryptor()
	if err != nil {
		return fmt.Errorf("failed to get decryptor: %w", err)
	}

	// Decrypt blob
	decryptedReader, err := decryptor.DecryptStream(reader)
	if err != nil {
		return fmt.Errorf("failed to decrypt: %w", err)
	}

	// Decompress blob
	decompressor, err := zstd.NewReader(decryptedReader)
	if err != nil {
		return fmt.Errorf("failed to decompress: %w", err)
	}
	defer decompressor.Close()

	// Query blob chunks from database to get offsets and lengths
	query := `
		SELECT bc.chunk_hash, bc.offset, bc.length
		FROM blob_chunks bc
		JOIN blobs b ON bc.blob_id = b.id
		WHERE b.blob_hash = ?
		ORDER BY bc.offset
	`
	rows, err := db.QueryContext(v.ctx, query, blobInfo.Hash)
	if err != nil {
		return fmt.Errorf("failed to query blob chunks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var lastOffset int64 = -1
	chunkCount := 0
	totalRead := int64(0)

	// Verify each chunk in the blob
	for rows.Next() {
		var chunkHash string
		var offset, length int64
		if err := rows.Scan(&chunkHash, &offset, &length); err != nil {
			return fmt.Errorf("failed to scan chunk row: %w", err)
		}

		// Verify chunk ordering
		if offset <= lastOffset {
			return fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset)
		}
		lastOffset = offset

		// Read chunk data from decompressed stream
		if offset > totalRead {
			// Skip to the correct offset
			skipBytes := offset - totalRead
			if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil {
				return fmt.Errorf("failed to skip to offset %d: %w", offset, err)
			}
			totalRead = offset
		}

		// Read chunk data
		chunkData := make([]byte, length)
		if _, err := io.ReadFull(decompressor, chunkData); err != nil {
			return fmt.Errorf("failed to read chunk at offset %d: %w", offset, err)
		}
		totalRead += length

		// Verify chunk hash
		hasher := sha256.New()
		hasher.Write(chunkData)
		calculatedHash := hex.EncodeToString(hasher.Sum(nil))

		if calculatedHash != chunkHash {
			return fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s",
				offset, calculatedHash, chunkHash)
		}

		chunkCount++
	}

	if err := rows.Err(); err != nil {
		return fmt.Errorf("error iterating blob chunks: %w", err)
	}

	log.Debug("Blob verified",
		"hash", blobInfo.Hash,
		"chunks", chunkCount,
		"size", humanize.Bytes(uint64(blobInfo.CompressedSize)),
	)

	return nil
}
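
The chunk hashes compared above are lowercase hex SHA-256 over the decompressed chunk bytes. A minimal sketch of that convention, with a hypothetical helper name:

// chunkHash produces the same value that verifyBlob compares against
// blob_chunks.chunk_hash for each chunk read from the decompressed blob.
func chunkHash(data []byte) string {
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}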