Refactor blob storage to use UUID primary keys and implement streaming chunking

- Changed blob table to use ID (UUID) as primary key instead of hash
- Blob records are now created at packing start, enabling immediate chunk associations
- Implemented streaming chunking to process large files without memory exhaustion
- Fixed blob manifest generation to include all referenced blobs
- Updated all foreign key references from blob_hash to blob_id
- Added progress reporting and improved error handling
- Enforced encryption requirement for all blob packing
- Updated tests to use test encryption keys
- Added Cyrillic transliteration to README
2025-07-22 07:43:39 +02:00
parent 26db096913
commit 86b533d6ee
49 changed files with 5709 additions and 324 deletions
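The streaming chunking introduced here is the heart of the change: the chunker hands each chunk to a callback as soon as it is cut, so at most one chunk is resident in memory regardless of file size. A self-contained sketch of the pattern (fixed-size chunks and the chunkStream helper are simplifications for illustration; the real internal/chunker is content-defined):

package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"strings"
)

// chunkStream reads r in fixed-size chunks and hands each chunk to fn,
// so only one chunk is ever held in memory at a time.
func chunkStream(r io.Reader, size int, fn func(hash string, data []byte) error) error {
	buf := make([]byte, size)
	for {
		n, err := io.ReadFull(r, buf)
		if n > 0 {
			sum := sha256.Sum256(buf[:n])
			if cbErr := fn(fmt.Sprintf("%x", sum), buf[:n]); cbErr != nil {
				return cbErr
			}
		}
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			return nil
		}
		if err != nil {
			return err
		}
	}
}

func main() {
	r := strings.NewReader("example data that would normally come from a large file")
	_ = chunkStream(r, 16, func(hash string, data []byte) error {
		fmt.Printf("chunk %s (%d bytes)\n", hash[:8], len(data))
		return nil
	})
}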


@@ -0,0 +1,524 @@
package backup
import (
"context"
"crypto/sha256"
"database/sql"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"testing"
"testing/fstest"
"time"
"git.eeqj.de/sneak/vaultik/internal/database"
)
// MockS3Client is a mock implementation of S3 operations for testing
type MockS3Client struct {
storage map[string][]byte
}
func NewMockS3Client() *MockS3Client {
return &MockS3Client{
storage: make(map[string][]byte),
}
}
func (m *MockS3Client) PutBlob(ctx context.Context, hash string, data []byte) error {
m.storage[hash] = data
return nil
}
func (m *MockS3Client) GetBlob(ctx context.Context, hash string) ([]byte, error) {
data, ok := m.storage[hash]
if !ok {
return nil, fmt.Errorf("blob not found: %s", hash)
}
return data, nil
}
func (m *MockS3Client) BlobExists(ctx context.Context, hash string) (bool, error) {
_, ok := m.storage[hash]
return ok, nil
}
func (m *MockS3Client) CreateBucket(ctx context.Context, bucket string) error {
return nil
}
func TestBackupWithInMemoryFS(t *testing.T) {
// Create a temporary directory for the database
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")
// Create test filesystem
testFS := fstest.MapFS{
"file1.txt": &fstest.MapFile{
Data: []byte("Hello, World!"),
Mode: 0644,
ModTime: time.Now(),
},
"dir1/file2.txt": &fstest.MapFile{
Data: []byte("This is a test file with some content."),
Mode: 0755,
ModTime: time.Now(),
},
"dir1/subdir/file3.txt": &fstest.MapFile{
Data: []byte("Another file in a subdirectory."),
Mode: 0600,
ModTime: time.Now(),
},
"largefile.bin": &fstest.MapFile{
Data: generateLargeFileContent(10 * 1024 * 1024), // 10MB file with varied content
Mode: 0644,
ModTime: time.Now(),
},
}
// Initialize the database
ctx := context.Background()
db, err := database.New(ctx, dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer func() {
if err := db.Close(); err != nil {
t.Logf("Failed to close database: %v", err)
}
}()
repos := database.NewRepositories(db)
// Create mock S3 client
s3Client := NewMockS3Client()
// Run backup
backupEngine := &BackupEngine{
repos: repos,
s3Client: s3Client,
}
snapshotID, err := backupEngine.Backup(ctx, testFS, ".")
if err != nil {
t.Fatalf("Backup failed: %v", err)
}
// Verify snapshot was created
snapshot, err := repos.Snapshots.GetByID(ctx, snapshotID)
if err != nil {
t.Fatalf("Failed to get snapshot: %v", err)
}
if snapshot == nil {
t.Fatal("Snapshot not found")
}
if snapshot.FileCount == 0 {
t.Error("Expected snapshot to have files")
}
// Verify files in database
files, err := repos.Files.ListByPrefix(ctx, "")
if err != nil {
t.Fatalf("Failed to list files: %v", err)
}
expectedFiles := map[string]bool{
"file1.txt": true,
"dir1/file2.txt": true,
"dir1/subdir/file3.txt": true,
"largefile.bin": true,
}
if len(files) != len(expectedFiles) {
t.Errorf("Expected %d files, got %d", len(expectedFiles), len(files))
}
for _, file := range files {
if !expectedFiles[file.Path] {
t.Errorf("Unexpected file in database: %s", file.Path)
}
delete(expectedFiles, file.Path)
// Verify file metadata
fsFile := testFS[file.Path]
if fsFile == nil {
t.Errorf("File %s not found in test filesystem", file.Path)
continue
}
if file.Size != int64(len(fsFile.Data)) {
t.Errorf("File %s: expected size %d, got %d", file.Path, len(fsFile.Data), file.Size)
}
if file.Mode != uint32(fsFile.Mode) {
t.Errorf("File %s: expected mode %o, got %o", file.Path, fsFile.Mode, file.Mode)
}
}
if len(expectedFiles) > 0 {
t.Errorf("Files not found in database: %v", expectedFiles)
}
// Verify chunks
chunks, err := repos.Chunks.List(ctx)
if err != nil {
t.Fatalf("Failed to list chunks: %v", err)
}
if len(chunks) == 0 {
t.Error("No chunks found in database")
}
// The large file should create 10 chunks (10MB / 1MB chunk size)
// Plus the small files
minExpectedChunks := 10 + 3
if len(chunks) < minExpectedChunks {
t.Errorf("Expected at least %d chunks, got %d", minExpectedChunks, len(chunks))
}
// Verify at least one blob was created and uploaded
// We can't list blobs directly, but we can check via snapshot blobs
blobHashes, err := repos.Snapshots.GetBlobHashes(ctx, snapshotID)
if err != nil {
t.Fatalf("Failed to get blob hashes: %v", err)
}
if len(blobHashes) == 0 {
t.Error("Expected at least one blob to be created")
}
for _, blobHash := range blobHashes {
// Check blob exists in mock S3
exists, err := s3Client.BlobExists(ctx, blobHash)
if err != nil {
t.Errorf("Failed to check blob %s: %v", blobHash, err)
}
if !exists {
t.Errorf("Blob %s not found in S3", blobHash)
}
}
}
func TestBackupDeduplication(t *testing.T) {
// Create a temporary directory for the database
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")
// Create test filesystem with duplicate content
testFS := fstest.MapFS{
"file1.txt": &fstest.MapFile{
Data: []byte("Duplicate content"),
Mode: 0644,
ModTime: time.Now(),
},
"file2.txt": &fstest.MapFile{
Data: []byte("Duplicate content"),
Mode: 0644,
ModTime: time.Now(),
},
"file3.txt": &fstest.MapFile{
Data: []byte("Unique content"),
Mode: 0644,
ModTime: time.Now(),
},
}
// Initialize the database
ctx := context.Background()
db, err := database.New(ctx, dbPath)
if err != nil {
t.Fatalf("Failed to create database: %v", err)
}
defer func() {
if err := db.Close(); err != nil {
t.Logf("Failed to close database: %v", err)
}
}()
repos := database.NewRepositories(db)
// Create mock S3 client
s3Client := NewMockS3Client()
// Run backup
backupEngine := &BackupEngine{
repos: repos,
s3Client: s3Client,
}
_, err = backupEngine.Backup(ctx, testFS, ".")
if err != nil {
t.Fatalf("Backup failed: %v", err)
}
// Verify deduplication
chunks, err := repos.Chunks.List(ctx)
if err != nil {
t.Fatalf("Failed to list chunks: %v", err)
}
// Should have only 2 unique chunks (duplicate content + unique content)
if len(chunks) != 2 {
t.Errorf("Expected 2 unique chunks, got %d", len(chunks))
}
// Verify chunk references
for _, chunk := range chunks {
files, err := repos.ChunkFiles.GetByChunkHash(ctx, chunk.ChunkHash)
if err != nil {
t.Errorf("Failed to get files for chunk %s: %v", chunk.ChunkHash, err)
}
// The duplicate content chunk should be referenced by 2 files
if chunk.Size == int64(len("Duplicate content")) && len(files) != 2 {
t.Errorf("Expected duplicate chunk to be referenced by 2 files, got %d", len(files))
}
}
}
// BackupEngine performs backup operations
type BackupEngine struct {
repos *database.Repositories
s3Client interface {
PutBlob(ctx context.Context, hash string, data []byte) error
BlobExists(ctx context.Context, hash string) (bool, error)
}
}
// Backup performs a backup of the given filesystem
func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (string, error) {
// Create a new snapshot
hostname, _ := os.Hostname()
snapshotID := time.Now().Format(time.RFC3339)
snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: hostname,
VaultikVersion: "test",
StartedAt: time.Now(),
CompletedAt: nil,
}
// Create initial snapshot record
err := b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return b.repos.Snapshots.Create(ctx, tx, snapshot)
})
if err != nil {
return "", err
}
// Track counters
var fileCount, chunkCount, blobCount, totalSize, blobSize int64
// Track which chunks we've seen to handle deduplication
processedChunks := make(map[string]bool)
// Scan the filesystem and process files
err = fs.WalkDir(fsys, root, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
// Skip directories
if d.IsDir() {
return nil
}
// Get file info
info, err := d.Info()
if err != nil {
return err
}
// Handle symlinks
if info.Mode()&fs.ModeSymlink != 0 {
// For testing, we'll skip symlinks since fstest doesn't support them well
return nil
}
// Process this file in a transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
// Create file record
file := &database.File{
Path: path,
Size: info.Size(),
Mode: uint32(info.Mode()),
MTime: info.ModTime(),
CTime: info.ModTime(), // Use mtime as ctime for test
UID: 1000, // Default UID for test
GID: 1000, // Default GID for test
}
if err := b.repos.Files.Create(ctx, tx, file); err != nil {
return err
}
fileCount++
totalSize += info.Size()
// Read and process file in chunks
f, err := fsys.Open(path)
if err != nil {
return err
}
defer func() {
if err := f.Close(); err != nil {
// Log but don't fail; we may already be unwinding another error path
fmt.Fprintf(os.Stderr, "Failed to close file: %v\n", err)
}
}()
// Process file in chunks
chunkIndex := 0
buffer := make([]byte, defaultChunkSize)
for {
n, err := f.Read(buffer)
if err != nil && err != io.EOF {
return err
}
if n == 0 {
break
}
chunkData := buffer[:n]
chunkHash := calculateHash(chunkData)
// Check if chunk already exists
existingChunk, _ := b.repos.Chunks.GetByHash(ctx, chunkHash)
if existingChunk == nil {
// Create new chunk
chunk := &database.Chunk{
ChunkHash: chunkHash,
SHA256: chunkHash,
Size: int64(n),
}
if err := b.repos.Chunks.Create(ctx, tx, chunk); err != nil {
return err
}
processedChunks[chunkHash] = true
}
// Create file-chunk mapping
fileChunk := &database.FileChunk{
Path: path,
Idx: chunkIndex,
ChunkHash: chunkHash,
}
if err := b.repos.FileChunks.Create(ctx, tx, fileChunk); err != nil {
return err
}
// Create chunk-file mapping
chunkFile := &database.ChunkFile{
ChunkHash: chunkHash,
FilePath: path,
FileOffset: int64(chunkIndex * defaultChunkSize),
Length: int64(n),
}
if err := b.repos.ChunkFiles.Create(ctx, tx, chunkFile); err != nil {
return err
}
chunkIndex++
}
return nil
})
return err
})
if err != nil {
return "", err
}
// After all files are processed, create blobs for new chunks
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
for chunkHash := range processedChunks {
// Get chunk data
chunk, err := b.repos.Chunks.GetByHash(ctx, chunkHash)
if err != nil {
return err
}
chunkCount++
// In a real system, blobs would contain multiple chunks and be encrypted
// For testing, we'll create a blob with a "blob-" prefix to differentiate
blobHash := "blob-" + chunkHash
// For the test, we'll create dummy data since we don't have the original
dummyData := []byte(chunkHash)
// Upload to S3 as a blob
if err := b.s3Client.PutBlob(ctx, blobHash, dummyData); err != nil {
return err
}
// Create blob entry
blob := &database.Blob{
ID: "test-blob-" + blobHash[:8],
Hash: blobHash,
CreatedTS: time.Now(),
}
if err := b.repos.Blobs.Create(ctx, tx, blob); err != nil {
return err
}
blobCount++
blobSize += chunk.Size
// Create blob-chunk mapping
blobChunk := &database.BlobChunk{
BlobID: blob.ID,
ChunkHash: chunkHash,
Offset: 0,
Length: chunk.Size,
}
if err := b.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil {
return err
}
// Add blob to snapshot
if err := b.repos.Snapshots.AddBlob(ctx, tx, snapshotID, blob.ID, blob.Hash); err != nil {
return err
}
}
return nil
})
if err != nil {
return "", err
}
// Update snapshot with final counts
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return b.repos.Snapshots.UpdateCounts(ctx, tx, snapshotID, fileCount, chunkCount, blobCount, totalSize, blobSize)
})
if err != nil {
return "", err
}
return snapshotID, nil
}
func calculateHash(data []byte) string {
h := sha256.New()
h.Write(data)
return fmt.Sprintf("%x", h.Sum(nil))
}
func generateLargeFileContent(size int) []byte {
data := make([]byte, size)
// Fill with pattern that changes every chunk to avoid deduplication
for i := 0; i < size; i++ {
chunkNum := i / defaultChunkSize
data[i] = byte((i + chunkNum) % 256)
}
return data
}
const defaultChunkSize = 1024 * 1024 // 1MB chunks
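The deduplication test above leans on content addressing: identical bytes always hash to the same chunk ID, so the store keeps one copy no matter how many files reference it. A minimal self-contained illustration of that property:

package main

import (
	"crypto/sha256"
	"fmt"
)

func main() {
	a := sha256.Sum256([]byte("Duplicate content"))
	b := sha256.Sum256([]byte("Duplicate content"))
	c := sha256.Sum256([]byte("Unique content"))
	fmt.Println(a == b) // true: one stored chunk, referenced by both files
	fmt.Println(a == c) // false: distinct content, distinct chunk
}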


@@ -1,6 +1,39 @@
package backup
import "go.uber.org/fx"
import (
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/s3"
"github.com/spf13/afero"
"go.uber.org/fx"
)
// ScannerParams holds parameters for scanner creation
type ScannerParams struct {
EnableProgress bool
}
// Module exports backup functionality
var Module = fx.Module("backup")
var Module = fx.Module("backup",
fx.Provide(
provideScannerFactory,
),
)
// ScannerFactory creates scanners with custom parameters
type ScannerFactory func(params ScannerParams) *Scanner
func provideScannerFactory(cfg *config.Config, repos *database.Repositories, s3Client *s3.Client) ScannerFactory {
return func(params ScannerParams) *Scanner {
return NewScanner(ScannerConfig{
FS: afero.NewOsFs(),
ChunkSize: cfg.ChunkSize.Int64(),
Repositories: repos,
S3Client: s3Client,
MaxBlobSize: cfg.BlobSizeLimit.Int64(),
CompressionLevel: cfg.CompressionLevel,
AgeRecipients: cfg.AgeRecipients,
EnableProgress: params.EnableProgress,
})
}
}
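How a call site might consume the factory from an fx application (a hedged sketch: the fx.Invoke hook and variable names are illustrative, and this only resolves inside the real application where the config, database, and s3 providers are registered):

app := fx.New(
	backup.Module,
	// Hypothetical invocation: fx injects the ScannerFactory provided above.
	fx.Invoke(func(newScanner backup.ScannerFactory) {
		scanner := newScanner(backup.ScannerParams{EnableProgress: true})
		_ = scanner // call scanner.Scan(ctx, path, snapshotID) here
	}),
)
app.Run()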

internal/backup/progress.go (new file, 389 lines)

@@ -0,0 +1,389 @@
package backup
import (
"context"
"fmt"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/dustin/go-humanize"
)
const (
// Progress reporting intervals
SummaryInterval = 10 * time.Second // One-line status updates
DetailInterval = 60 * time.Second // Multi-line detailed status
)
// ProgressStats holds atomic counters for progress tracking
type ProgressStats struct {
FilesScanned atomic.Int64 // Total files seen during scan (includes skipped)
FilesProcessed atomic.Int64 // Files actually processed in phase 2
FilesSkipped atomic.Int64 // Files skipped due to no changes
BytesScanned atomic.Int64 // Bytes from new/changed files only
BytesSkipped atomic.Int64 // Bytes from unchanged files
BytesProcessed atomic.Int64 // Actual bytes processed (for ETA calculation)
ChunksCreated atomic.Int64
BlobsCreated atomic.Int64
BlobsUploaded atomic.Int64
BytesUploaded atomic.Int64
CurrentFile atomic.Value // stores string
TotalSize atomic.Int64 // Total size to process (set after scan phase)
TotalFiles atomic.Int64 // Total files to process in phase 2
ProcessStartTime atomic.Value // stores time.Time when processing starts
StartTime time.Time
mu sync.RWMutex
lastDetailTime time.Time
// Upload tracking
CurrentUpload atomic.Value // stores *UploadInfo
lastChunkingTime time.Time // Track when we last showed chunking progress
}
// UploadInfo tracks current upload progress
type UploadInfo struct {
BlobHash string
Size int64
StartTime time.Time
}
// ProgressReporter handles periodic progress reporting
type ProgressReporter struct {
stats *ProgressStats
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
detailTicker *time.Ticker
summaryTicker *time.Ticker
sigChan chan os.Signal
}
// NewProgressReporter creates a new progress reporter
func NewProgressReporter() *ProgressReporter {
stats := &ProgressStats{
StartTime: time.Now(),
lastDetailTime: time.Now(),
}
stats.CurrentFile.Store("")
ctx, cancel := context.WithCancel(context.Background())
pr := &ProgressReporter{
stats: stats,
ctx: ctx,
cancel: cancel,
summaryTicker: time.NewTicker(SummaryInterval),
detailTicker: time.NewTicker(DetailInterval),
sigChan: make(chan os.Signal, 1),
}
// Register for SIGUSR1
signal.Notify(pr.sigChan, syscall.SIGUSR1)
return pr
}
// Start begins the progress reporting
func (pr *ProgressReporter) Start() {
pr.wg.Add(1)
go pr.run()
// Print initial multi-line status
pr.printDetailedStatus()
}
// Stop stops the progress reporting
func (pr *ProgressReporter) Stop() {
pr.cancel()
pr.summaryTicker.Stop()
pr.detailTicker.Stop()
signal.Stop(pr.sigChan)
close(pr.sigChan)
pr.wg.Wait()
}
// GetStats returns the progress stats for updating
func (pr *ProgressReporter) GetStats() *ProgressStats {
return pr.stats
}
// SetTotalSize sets the total size to process (after scan phase)
func (pr *ProgressReporter) SetTotalSize(size int64) {
pr.stats.TotalSize.Store(size)
pr.stats.ProcessStartTime.Store(time.Now())
}
// run is the main progress reporting loop
func (pr *ProgressReporter) run() {
defer pr.wg.Done()
for {
select {
case <-pr.ctx.Done():
return
case <-pr.summaryTicker.C:
pr.printSummaryStatus()
case <-pr.detailTicker.C:
pr.printDetailedStatus()
case <-pr.sigChan:
// SIGUSR1 received, print detailed status
log.Info("SIGUSR1 received, printing detailed status")
pr.printDetailedStatus()
}
}
}
// printSummaryStatus prints a one-line status update
func (pr *ProgressReporter) printSummaryStatus() {
// Check if we're currently uploading
if uploadInfo, ok := pr.stats.CurrentUpload.Load().(*UploadInfo); ok && uploadInfo != nil {
// Show upload progress instead
pr.printUploadProgress(uploadInfo)
return
}
// Only show chunking progress if we've done chunking recently
pr.stats.mu.RLock()
timeSinceLastChunk := time.Since(pr.stats.lastChunkingTime)
pr.stats.mu.RUnlock()
if timeSinceLastChunk > SummaryInterval*2 {
// No recent chunking activity, don't show progress
return
}
elapsed := time.Since(pr.stats.StartTime)
bytesScanned := pr.stats.BytesScanned.Load()
bytesSkipped := pr.stats.BytesSkipped.Load()
bytesProcessed := pr.stats.BytesProcessed.Load()
totalSize := pr.stats.TotalSize.Load()
currentFile := pr.stats.CurrentFile.Load().(string)
// Calculate ETA if we have total size and are processing
etaStr := ""
if totalSize > 0 && bytesProcessed > 0 {
processStart, ok := pr.stats.ProcessStartTime.Load().(time.Time)
if ok && !processStart.IsZero() {
processElapsed := time.Since(processStart)
rate := float64(bytesProcessed) / processElapsed.Seconds()
if rate > 0 {
remainingBytes := totalSize - bytesProcessed
remainingSeconds := float64(remainingBytes) / rate
eta := time.Duration(remainingSeconds * float64(time.Second))
etaStr = fmt.Sprintf(" | ETA: %s", formatDuration(eta))
}
}
}
rate := float64(bytesScanned+bytesSkipped) / elapsed.Seconds()
// Show files processed / total files to process
filesProcessed := pr.stats.FilesProcessed.Load()
totalFiles := pr.stats.TotalFiles.Load()
status := fmt.Sprintf("Progress: %d/%d files, %s/%s (%.1f%%), %s/s%s",
filesProcessed,
totalFiles,
humanize.Bytes(uint64(bytesProcessed)),
humanize.Bytes(uint64(totalSize)),
float64(bytesProcessed)/float64(totalSize)*100,
humanize.Bytes(uint64(rate)),
etaStr,
)
if currentFile != "" {
status += fmt.Sprintf(" | Current: %s", truncatePath(currentFile, 40))
}
log.Info(status)
}
// printDetailedStatus prints a multi-line detailed status
func (pr *ProgressReporter) printDetailedStatus() {
pr.stats.mu.Lock()
pr.stats.lastDetailTime = time.Now()
pr.stats.mu.Unlock()
elapsed := time.Since(pr.stats.StartTime)
filesScanned := pr.stats.FilesScanned.Load()
filesSkipped := pr.stats.FilesSkipped.Load()
bytesScanned := pr.stats.BytesScanned.Load()
bytesSkipped := pr.stats.BytesSkipped.Load()
bytesProcessed := pr.stats.BytesProcessed.Load()
totalSize := pr.stats.TotalSize.Load()
chunksCreated := pr.stats.ChunksCreated.Load()
blobsCreated := pr.stats.BlobsCreated.Load()
blobsUploaded := pr.stats.BlobsUploaded.Load()
bytesUploaded := pr.stats.BytesUploaded.Load()
currentFile := pr.stats.CurrentFile.Load().(string)
totalBytes := bytesScanned + bytesSkipped
rate := float64(totalBytes) / elapsed.Seconds()
log.Notice("=== Backup Progress Report ===")
log.Info("Elapsed time", "duration", formatDuration(elapsed))
// Calculate and show ETA if we have data
if totalSize > 0 && bytesProcessed > 0 {
processStart, ok := pr.stats.ProcessStartTime.Load().(time.Time)
if ok && !processStart.IsZero() {
processElapsed := time.Since(processStart)
processRate := float64(bytesProcessed) / processElapsed.Seconds()
if processRate > 0 {
remainingBytes := totalSize - bytesProcessed
remainingSeconds := float64(remainingBytes) / processRate
eta := time.Duration(remainingSeconds * float64(time.Second))
percentComplete := float64(bytesProcessed) / float64(totalSize) * 100
log.Info("Overall progress",
"percent", fmt.Sprintf("%.1f%%", percentComplete),
"processed", humanize.Bytes(uint64(bytesProcessed)),
"total", humanize.Bytes(uint64(totalSize)),
"rate", humanize.Bytes(uint64(processRate))+"/s",
"eta", formatDuration(eta))
}
}
}
log.Info("Files processed",
"scanned", filesScanned,
"skipped", filesSkipped,
"total", filesScanned,
"skip_rate", formatPercent(filesSkipped, filesScanned))
log.Info("Data scanned",
"new", humanize.Bytes(uint64(bytesScanned)),
"skipped", humanize.Bytes(uint64(bytesSkipped)),
"total", humanize.Bytes(uint64(totalBytes)),
"scan_rate", humanize.Bytes(uint64(rate))+"/s")
log.Info("Chunks created", "count", chunksCreated)
log.Info("Blobs status",
"created", blobsCreated,
"uploaded", blobsUploaded,
"pending", blobsCreated-blobsUploaded)
log.Info("Upload progress",
"uploaded", humanize.Bytes(uint64(bytesUploaded)),
"compression_ratio", formatRatio(bytesUploaded, bytesScanned))
if currentFile != "" {
log.Info("Current file", "path", currentFile)
}
log.Notice("=============================")
}
// Helper functions
func formatDuration(d time.Duration) string {
if d < 0 {
return "unknown"
}
if d < time.Minute {
return fmt.Sprintf("%ds", int(d.Seconds()))
}
if d < time.Hour {
return fmt.Sprintf("%dm%ds", int(d.Minutes()), int(d.Seconds())%60)
}
return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60)
}
func formatPercent(numerator, denominator int64) string {
if denominator == 0 {
return "0.0%"
}
return fmt.Sprintf("%.1f%%", float64(numerator)/float64(denominator)*100)
}
func formatRatio(compressed, uncompressed int64) string {
if uncompressed == 0 {
return "1.00"
}
ratio := float64(compressed) / float64(uncompressed)
return fmt.Sprintf("%.2f", ratio)
}
func truncatePath(path string, maxLen int) string {
if len(path) <= maxLen {
return path
}
// Keep the last maxLen-3 characters and prepend "..."
return "..." + path[len(path)-(maxLen-3):]
}
// printUploadProgress prints upload progress
func (pr *ProgressReporter) printUploadProgress(info *UploadInfo) {
elapsed := time.Since(info.StartTime)
if elapsed < time.Millisecond {
elapsed = time.Millisecond // Avoid division by zero
}
bytesPerSec := float64(info.Size) / elapsed.Seconds()
bitsPerSec := bytesPerSec * 8
// Format speed in bits/second
var speedStr string
if bitsPerSec >= 1e9 {
speedStr = fmt.Sprintf("%.1fGbit/sec", bitsPerSec/1e9)
} else if bitsPerSec >= 1e6 {
speedStr = fmt.Sprintf("%.0fMbit/sec", bitsPerSec/1e6)
} else if bitsPerSec >= 1e3 {
speedStr = fmt.Sprintf("%.0fKbit/sec", bitsPerSec/1e3)
} else {
speedStr = fmt.Sprintf("%.0fbit/sec", bitsPerSec)
}
log.Info("Uploading blob",
"hash", info.BlobHash[:8]+"...",
"size", humanize.Bytes(uint64(info.Size)),
"elapsed", formatDuration(elapsed),
"speed", speedStr)
}
// ReportUploadStart marks the beginning of a blob upload
func (pr *ProgressReporter) ReportUploadStart(blobHash string, size int64) {
info := &UploadInfo{
BlobHash: blobHash,
Size: size,
StartTime: time.Now(),
}
pr.stats.CurrentUpload.Store(info)
}
// ReportUploadComplete marks the completion of a blob upload
func (pr *ProgressReporter) ReportUploadComplete(blobHash string, size int64, duration time.Duration) {
// Clear current upload
pr.stats.CurrentUpload.Store((*UploadInfo)(nil))
// Calculate speed
if duration < time.Millisecond {
duration = time.Millisecond
}
bytesPerSec := float64(size) / duration.Seconds()
bitsPerSec := bytesPerSec * 8
// Format speed
var speedStr string
if bitsPerSec >= 1e9 {
speedStr = fmt.Sprintf("%.1fGbit/sec", bitsPerSec/1e9)
} else if bitsPerSec >= 1e6 {
speedStr = fmt.Sprintf("%.0fMbit/sec", bitsPerSec/1e6)
} else if bitsPerSec >= 1e3 {
speedStr = fmt.Sprintf("%.0fKbit/sec", bitsPerSec/1e3)
} else {
speedStr = fmt.Sprintf("%.0fbit/sec", bitsPerSec)
}
log.Info("Blob uploaded",
"hash", blobHash[:8]+"...",
"size", humanize.Bytes(uint64(size)),
"duration", formatDuration(duration),
"speed", speedStr)
}
// UpdateChunkingActivity updates the last chunking time
func (pr *ProgressReporter) UpdateChunkingActivity() {
pr.stats.mu.Lock()
pr.stats.lastChunkingTime = time.Now()
pr.stats.mu.Unlock()
}
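A condensed sketch of the reporter's intended lifecycle, matching how the scanner drives it (the files loop and totalBytes are hypothetical stand-ins for the caller's state):

pr := NewProgressReporter()
pr.Start()      // prints an initial detailed status, then ticks at 10s/60s
defer pr.Stop() // stops tickers and unregisters the SIGUSR1 handler

stats := pr.GetStats()
stats.TotalFiles.Store(int64(len(files))) // hypothetical slice of files to process
pr.SetTotalSize(totalBytes)               // enables ETA in subsequent status lines
for _, f := range files {
	stats.CurrentFile.Store(f.Path)
	// ... chunk and upload f ...
	stats.FilesProcessed.Add(1)
}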


@@ -2,71 +2,197 @@ package backup
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"git.eeqj.de/sneak/vaultik/internal/blob"
"git.eeqj.de/sneak/vaultik/internal/chunker"
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/dustin/go-humanize"
"github.com/spf13/afero"
)
// FileToProcess holds information about a file that needs processing
type FileToProcess struct {
Path string
FileInfo os.FileInfo
File *database.File
}
// Scanner scans directories and populates the database with file and chunk information
type Scanner struct {
fs afero.Fs
chunker *chunker.Chunker
packer *blob.Packer
repos *database.Repositories
s3Client S3Client
maxBlobSize int64
compressionLevel int
ageRecipient string
snapshotID string // Current snapshot being processed
progress *ProgressReporter
// Mutex for coordinating blob creation
packerMu sync.Mutex // Blocks chunk production during blob creation
// Context for cancellation
scanCtx context.Context
}
// S3Client interface for blob storage operations
type S3Client interface {
PutObject(ctx context.Context, key string, data io.Reader) error
}
// ScannerConfig contains configuration for the scanner
type ScannerConfig struct {
FS afero.Fs
ChunkSize int64
Repositories *database.Repositories
S3Client S3Client
MaxBlobSize int64
CompressionLevel int
AgeRecipients []string // Optional, empty means no encryption
EnableProgress bool // Enable progress reporting
}
// ScanResult contains the results of a scan operation
type ScanResult struct {
FilesScanned int
FilesSkipped int
BytesScanned int64
BytesSkipped int64
ChunksCreated int
BlobsCreated int
StartTime time.Time
EndTime time.Time
}
// NewScanner creates a new scanner instance
func NewScanner(cfg ScannerConfig) *Scanner {
// Create encryptor (required for blob packing)
if len(cfg.AgeRecipients) == 0 {
log.Error("No age recipients configured - encryption is required")
return nil
}
enc, err := crypto.NewEncryptor(cfg.AgeRecipients)
if err != nil {
log.Error("Failed to create encryptor", "error", err)
return nil
}
// Create blob packer with encryption
packerCfg := blob.PackerConfig{
MaxBlobSize: cfg.MaxBlobSize,
CompressionLevel: cfg.CompressionLevel,
Encryptor: enc,
Repositories: cfg.Repositories,
}
packer, err := blob.NewPacker(packerCfg)
if err != nil {
log.Error("Failed to create packer", "error", err)
return nil
}
var progress *ProgressReporter
if cfg.EnableProgress {
progress = NewProgressReporter()
}
return &Scanner{
fs: cfg.FS,
chunker: chunker.NewChunker(cfg.ChunkSize),
packer: packer,
repos: cfg.Repositories,
s3Client: cfg.S3Client,
maxBlobSize: cfg.MaxBlobSize,
compressionLevel: cfg.CompressionLevel,
ageRecipient: strings.Join(cfg.AgeRecipients, ","),
progress: progress,
}
}
// Scan scans a directory and populates the database
func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*ScanResult, error) {
s.snapshotID = snapshotID
s.scanCtx = ctx
result := &ScanResult{
StartTime: time.Now(),
}
// Set blob handler for concurrent upload
if s.s3Client != nil {
log.Debug("Setting blob handler for S3 uploads")
s.packer.SetBlobHandler(s.handleBlobReady)
} else {
log.Debug("No S3 client configured, blobs will not be uploaded")
}
// Start progress reporting if enabled
if s.progress != nil {
s.progress.Start()
defer s.progress.Stop()
}
// Phase 1: Scan directory and collect files to process
log.Info("Phase 1: Scanning directory structure")
filesToProcess, err := s.scanPhase(ctx, path, result)
if err != nil {
return nil, fmt.Errorf("scan phase failed: %w", err)
}
// Calculate total size to process
var totalSizeToProcess int64
for _, file := range filesToProcess {
totalSizeToProcess += file.FileInfo.Size()
}
// Update progress with total size and file count
if s.progress != nil {
s.progress.SetTotalSize(totalSizeToProcess)
s.progress.GetStats().TotalFiles.Store(int64(len(filesToProcess)))
}
log.Info("Phase 1 complete",
"total_files", len(filesToProcess),
"total_size", humanize.Bytes(uint64(totalSizeToProcess)),
"files_skipped", result.FilesSkipped,
"bytes_skipped", humanize.Bytes(uint64(result.BytesSkipped)))
// Phase 2: Process files and create chunks
if len(filesToProcess) > 0 {
log.Info("Phase 2: Processing files and creating chunks")
if err := s.processPhase(ctx, filesToProcess, result); err != nil {
return nil, fmt.Errorf("process phase failed: %w", err)
}
}
// Get final stats from packer
blobs := s.packer.GetFinishedBlobs()
result.BlobsCreated += len(blobs)
result.EndTime = time.Now()
return result, nil
}
// scanPhase performs the initial directory scan to identify files to process
func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult) ([]*FileToProcess, error) {
var filesToProcess []*FileToProcess
var mu sync.Mutex
log.Debug("Starting directory walk", "path", path)
err := afero.Walk(s.fs, path, func(path string, info os.FileInfo, err error) error {
log.Debug("Walking file", "path", path)
if err != nil {
log.Debug("Error walking path", "path", path, "error", err)
return err
}
@@ -77,21 +203,108 @@ func (s *Scanner) scanDirectory(ctx context.Context, tx *sql.Tx, path string, re
default:
}
// Check file and update metadata
file, needsProcessing, err := s.checkFileAndUpdateMetadata(ctx, path, info, result)
if err != nil {
// Don't log context cancellation as an error
if err == context.Canceled {
return err
}
return fmt.Errorf("failed to check %s: %w", path, err)
}
// If file needs processing, add to list
if needsProcessing && info.Mode().IsRegular() && info.Size() > 0 {
mu.Lock()
filesToProcess = append(filesToProcess, &FileToProcess{
Path: path,
FileInfo: info,
File: file,
})
mu.Unlock()
}
return nil
})
if err != nil {
return nil, err
}
return filesToProcess, nil
}
// processPhase processes the files that need backing up
func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProcess, result *ScanResult) error {
// Process each file
for _, fileToProcess := range filesToProcess {
// Update progress
if s.progress != nil {
s.progress.GetStats().CurrentFile.Store(fileToProcess.Path)
}
// Process file in streaming fashion
if err := s.processFileStreaming(ctx, fileToProcess, result); err != nil {
return fmt.Errorf("processing file %s: %w", fileToProcess.Path, err)
}
// Update files processed counter
if s.progress != nil {
s.progress.GetStats().FilesProcessed.Add(1)
}
}
// Final flush (outside any transaction)
s.packerMu.Lock()
if err := s.packer.Flush(); err != nil {
s.packerMu.Unlock()
return fmt.Errorf("flushing packer: %w", err)
}
s.packerMu.Unlock()
// If no S3 client, store any remaining blobs
if s.s3Client == nil {
blobs := s.packer.GetFinishedBlobs()
for _, b := range blobs {
// Blob metadata is already stored incrementally during packing
// Just add the blob to the snapshot
err := s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, b.ID, b.Hash)
})
if err != nil {
return fmt.Errorf("storing blob metadata: %w", err)
}
}
result.BlobsCreated += len(blobs)
}
return nil
}
// checkFileAndUpdateMetadata checks if a file needs processing and updates metadata
func (s *Scanner) checkFileAndUpdateMetadata(ctx context.Context, path string, info os.FileInfo, result *ScanResult) (*database.File, bool, error) {
// Check context cancellation
select {
case <-ctx.Done():
return nil, false, ctx.Err()
default:
}
var file *database.File
var needsProcessing bool
// Use a short transaction just for the database operations
err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
var err error
file, needsProcessing, err = s.checkFile(txCtx, tx, path, info, result)
return err
})
return file, needsProcessing, err
}
// checkFile checks if a file needs processing and updates metadata within a transaction
func (s *Scanner) checkFile(ctx context.Context, tx *sql.Tx, path string, info os.FileInfo, result *ScanResult) (*database.File, bool, error) {
// Get file stats
stat, ok := info.Sys().(interface {
Uid() uint32
@@ -125,92 +338,378 @@ func (s *Scanner) processFile(ctx context.Context, tx *sql.Tx, path string, info
LinkTarget: linkTarget,
}
// Check if file has changed since last backup
log.Debug("Checking if file exists in database", "path", path)
existingFile, err := s.repos.Files.GetByPathTx(ctx, tx, path)
if err != nil {
return nil, false, fmt.Errorf("checking existing file: %w", err)
}
fileChanged := existingFile == nil || s.hasFileChanged(existingFile, file)
// Always update file metadata
log.Debug("Updating file metadata", "path", path, "changed", fileChanged)
if err := s.repos.Files.Create(ctx, tx, file); err != nil {
return nil, false, err
}
log.Debug("File metadata updated", "path", path)
// Add file to snapshot
log.Debug("Adding file to snapshot", "path", path, "snapshot", s.snapshotID)
if err := s.repos.Snapshots.AddFile(ctx, tx, s.snapshotID, path); err != nil {
return nil, false, fmt.Errorf("adding file to snapshot: %w", err)
}
log.Debug("File added to snapshot", "path", path)
result.FilesScanned++
// Update progress
if s.progress != nil {
stats := s.progress.GetStats()
stats.FilesScanned.Add(1)
stats.CurrentFile.Store(path)
}
// Track skipped files
if info.Mode().IsRegular() && info.Size() > 0 && !fileChanged {
result.FilesSkipped++
result.BytesSkipped += info.Size()
if s.progress != nil {
stats := s.progress.GetStats()
stats.FilesSkipped.Add(1)
stats.BytesSkipped.Add(info.Size())
}
// File hasn't changed, but we still need to associate existing chunks with this snapshot
log.Debug("File hasn't changed, associating existing chunks", "path", path)
if err := s.associateExistingChunks(ctx, tx, path); err != nil {
return nil, false, fmt.Errorf("associating existing chunks: %w", err)
}
log.Debug("Existing chunks associated", "path", path)
} else {
// File changed or is not a regular file
result.BytesScanned += info.Size()
if s.progress != nil {
s.progress.GetStats().BytesScanned.Add(info.Size())
}
}
return nil
return file, fileChanged, nil
}
// hasFileChanged determines if a file has changed since last backup
func (s *Scanner) hasFileChanged(existingFile, newFile *database.File) bool {
// Check if any metadata has changed
if existingFile.Size != newFile.Size {
return true
}
if existingFile.MTime.Unix() != newFile.MTime.Unix() {
return true
}
if existingFile.Mode != newFile.Mode {
return true
}
if existingFile.UID != newFile.UID {
return true
}
if existingFile.GID != newFile.GID {
return true
}
if existingFile.LinkTarget != newFile.LinkTarget {
return true
}
return false
}
// associateExistingChunks links existing chunks from an unchanged file to the current snapshot
func (s *Scanner) associateExistingChunks(ctx context.Context, tx *sql.Tx, path string) error {
log.Debug("associateExistingChunks start", "path", path)
// Get existing file chunks
log.Debug("Getting existing file chunks", "path", path)
fileChunks, err := s.repos.FileChunks.GetByFileTx(ctx, tx, path)
if err != nil {
return fmt.Errorf("getting existing file chunks: %w", err)
}
log.Debug("Got file chunks", "path", path, "count", len(fileChunks))
// For each chunk, find its blob and associate with current snapshot
processedBlobs := make(map[string]bool)
for i, fc := range fileChunks {
log.Debug("Processing chunk", "path", path, "chunk_index", i, "chunk_hash", fc.ChunkHash)
// Find which blob contains this chunk
log.Debug("Finding blob for chunk", "chunk_hash", fc.ChunkHash)
blobChunk, err := s.repos.BlobChunks.GetByChunkHashTx(ctx, tx, fc.ChunkHash)
if err != nil {
return fmt.Errorf("finding blob for chunk %s: %w", fc.ChunkHash, err)
}
if blobChunk == nil {
log.Warn("Chunk exists but not in any blob", "chunk", fc.ChunkHash, "file", path)
continue
}
log.Debug("Found blob for chunk", "chunk_hash", fc.ChunkHash, "blob_id", blobChunk.BlobID)
// Get blob to find its hash
blob, err := s.repos.Blobs.GetByID(ctx, blobChunk.BlobID)
if err != nil {
return fmt.Errorf("getting blob %s: %w", blobChunk.BlobID, err)
}
if blob == nil {
log.Warn("Blob record not found", "blob_id", blobChunk.BlobID)
continue
}
// Add blob to snapshot if not already processed
if !processedBlobs[blobChunk.BlobID] {
log.Debug("Adding blob to snapshot", "blob_id", blobChunk.BlobID, "blob_hash", blob.Hash, "snapshot", s.snapshotID)
if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, blobChunk.BlobID, blob.Hash); err != nil {
return fmt.Errorf("adding existing blob to snapshot: %w", err)
}
log.Debug("Added blob to snapshot", "blob_id", blobChunk.BlobID)
processedBlobs[blobChunk.BlobID] = true
}
}
log.Debug("associateExistingChunks complete", "path", path, "blobs_processed", len(processedBlobs))
return nil
}
// handleBlobReady is called by the packer when a blob is finalized
func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
log.Debug("Blob handler called", "blob_hash", blobWithReader.Hash[:8]+"...")
startTime := time.Now()
finishedBlob := blobWithReader.FinishedBlob
// Report upload start
if s.progress != nil {
s.progress.ReportUploadStart(finishedBlob.Hash, finishedBlob.Compressed)
}
// Upload to S3 first (without holding any locks)
// Use scan context for cancellation support
ctx := s.scanCtx
if ctx == nil {
ctx = context.Background()
}
if err := s.s3Client.PutObject(ctx, "blobs/"+finishedBlob.Hash, blobWithReader.Reader); err != nil {
return fmt.Errorf("uploading blob %s to S3: %w", finishedBlob.Hash, err)
}
uploadDuration := time.Since(startTime)
// Report upload complete
if s.progress != nil {
s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration)
}
// Update progress
if s.progress != nil {
stats := s.progress.GetStats()
stats.BlobsUploaded.Add(1)
stats.BytesUploaded.Add(finishedBlob.Compressed)
stats.BlobsCreated.Add(1)
}
// Store metadata in database (after upload is complete)
dbCtx := s.scanCtx
if dbCtx == nil {
dbCtx = context.Background()
}
err := s.repos.WithTx(dbCtx, func(ctx context.Context, tx *sql.Tx) error {
// Update blob upload timestamp
if err := s.repos.Blobs.UpdateUploaded(ctx, tx, finishedBlob.ID); err != nil {
return fmt.Errorf("updating blob upload timestamp: %w", err)
}
// Add the blob to the snapshot
if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, finishedBlob.ID, finishedBlob.Hash); err != nil {
return fmt.Errorf("adding blob to snapshot: %w", err)
}
// Record upload metrics
upload := &database.Upload{
BlobHash: finishedBlob.Hash,
UploadedAt: startTime,
Size: finishedBlob.Compressed,
DurationMs: uploadDuration.Milliseconds(),
}
if err := s.repos.Uploads.Create(ctx, tx, upload); err != nil {
return fmt.Errorf("recording upload metrics: %w", err)
}
return nil
})
// Cleanup temp file if needed
if blobWithReader.TempFile != nil {
tempName := blobWithReader.TempFile.Name()
if err := blobWithReader.TempFile.Close(); err != nil {
log.Fatal("Failed to close temp file", "file", tempName, "error", err)
}
if err := os.Remove(tempName); err != nil {
log.Fatal("Failed to remove temp file", "file", tempName, "error", err)
}
}
return err
}
// processFileStreaming processes a file by streaming chunks directly to the packer
func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) error {
// Open the file
file, err := s.fs.Open(fileToProcess.Path)
if err != nil {
return fmt.Errorf("opening file: %w", err)
}
defer func() { _ = file.Close() }()
// We'll collect file chunks for database storage
// but process them for packing as we go
type chunkInfo struct {
fileChunk database.FileChunk
offset int64
size int64
}
var chunks []chunkInfo
chunkIndex := 0
// Process chunks in streaming fashion
err = s.chunker.ChunkReaderStreaming(file, func(chunk chunker.Chunk) error {
// Check for cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
log.Debug("Processing chunk",
"file", fileToProcess.Path,
"chunk", chunkIndex,
"hash", chunk.Hash,
"size", chunk.Size)
// Check if chunk already exists
chunkExists := false
err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
existing, err := s.repos.Chunks.GetByHash(txCtx, chunk.Hash)
if err != nil {
return err
}
chunkExists = (existing != nil)
// Store chunk if new
if !chunkExists {
dbChunk := &database.Chunk{
ChunkHash: chunk.Hash,
SHA256: chunk.Hash,
Size: chunk.Size,
}
if err := s.repos.Chunks.Create(txCtx, tx, dbChunk); err != nil {
return fmt.Errorf("creating chunk: %w", err)
}
}
return nil
})
if err != nil {
return fmt.Errorf("checking/storing chunk: %w", err)
}
// Track file chunk association for later storage
chunks = append(chunks, chunkInfo{
fileChunk: database.FileChunk{
Path: fileToProcess.Path,
Idx: chunkIndex,
ChunkHash: chunk.Hash,
},
offset: chunk.Offset,
size: chunk.Size,
})
// Update stats
if chunkExists {
result.FilesSkipped++ // Track as skipped for now
result.BytesSkipped += chunk.Size
if s.progress != nil {
s.progress.GetStats().BytesSkipped.Add(chunk.Size)
}
} else {
result.ChunksCreated++
result.BytesScanned += chunk.Size
if s.progress != nil {
s.progress.GetStats().ChunksCreated.Add(1)
s.progress.GetStats().BytesProcessed.Add(chunk.Size)
s.progress.UpdateChunkingActivity()
}
}
// Add chunk to packer immediately (streaming)
// This happens outside the database transaction
if !chunkExists {
s.packerMu.Lock()
err := s.packer.AddChunk(&blob.ChunkRef{
Hash: chunk.Hash,
Data: chunk.Data,
})
if err == blob.ErrBlobSizeLimitExceeded {
// Finalize current blob and retry
if err := s.packer.FinalizeBlob(); err != nil {
s.packerMu.Unlock()
return fmt.Errorf("finalizing blob: %w", err)
}
// Retry adding the chunk
if err := s.packer.AddChunk(&blob.ChunkRef{
Hash: chunk.Hash,
Data: chunk.Data,
}); err != nil {
s.packerMu.Unlock()
return fmt.Errorf("adding chunk after finalize: %w", err)
}
} else if err != nil {
s.packerMu.Unlock()
return fmt.Errorf("adding chunk to packer: %w", err)
}
s.packerMu.Unlock()
}
// Clear chunk data from memory immediately after use
chunk.Data = nil
chunkIndex++
return nil
})
if err != nil {
return fmt.Errorf("chunking file: %w", err)
}
// Store file-chunk associations and chunk-file mappings in database
err = s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
for _, ci := range chunks {
// Create file-chunk mapping
if err := s.repos.FileChunks.Create(txCtx, tx, &ci.fileChunk); err != nil {
return fmt.Errorf("creating file chunk: %w", err)
}
// Create chunk-file mapping
chunkFile := &database.ChunkFile{
ChunkHash: ci.fileChunk.ChunkHash,
FilePath: fileToProcess.Path,
FileOffset: ci.offset,
Length: ci.size,
}
if err := s.repos.ChunkFiles.Create(txCtx, tx, chunkFile); err != nil {
return fmt.Errorf("creating chunk file: %w", err)
}
}
// Add file to snapshot
if err := s.repos.Snapshots.AddFile(txCtx, tx, s.snapshotID, fileToProcess.Path); err != nil {
return fmt.Errorf("adding file to snapshot: %w", err)
}
return nil
})
return err
}
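The AddChunk / ErrBlobSizeLimitExceeded / FinalizeBlob dance inside processFileStreaming is easier to see in isolation. A condensed sketch of that protocol (addOne is a hypothetical helper name; the packer calls are the ones used above):

// addOne pushes one chunk into the packer, finalizing the current blob and
// retrying once when the size limit is hit.
func (s *Scanner) addOne(ref *blob.ChunkRef) error {
	s.packerMu.Lock()
	defer s.packerMu.Unlock()
	err := s.packer.AddChunk(ref)
	if err == blob.ErrBlobSizeLimitExceeded {
		if ferr := s.packer.FinalizeBlob(); ferr != nil {
			return fmt.Errorf("finalizing blob: %w", ferr)
		}
		return s.packer.AddChunk(ref) // retry into the fresh blob
	}
	return err
}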


@@ -2,16 +2,21 @@ package backup_test
import (
"context"
"database/sql"
"path/filepath"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/backup"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/spf13/afero"
)
func TestScannerSimpleDirectory(t *testing.T) {
// Initialize logger for tests
log.Initialize(log.Config{})
// Create in-memory filesystem
fs := afero.NewMemMapFs()
@@ -56,25 +61,53 @@ func TestScannerSimpleDirectory(t *testing.T) {
// Create scanner
scanner := backup.NewScanner(backup.ScannerConfig{
FS: fs,
ChunkSize: int64(1024 * 16), // 16KB chunks for testing
Repositories: repos,
MaxBlobSize: int64(1024 * 1024), // 1MB blobs
CompressionLevel: 3,
AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key
})
// Create a snapshot record for testing
ctx := context.Background()
snapshotID := "test-snapshot-001"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),
CompletedAt: nil,
FileCount: 0,
ChunkCount: 0,
BlobCount: 0,
TotalSize: 0,
BlobSize: 0,
CompressionRatio: 1.0,
}
return repos.Snapshots.Create(ctx, tx, snapshot)
})
if err != nil {
t.Fatalf("failed to create snapshot: %v", err)
}
// Scan the directory
var result *backup.ScanResult
result, err = scanner.Scan(ctx, "/source", snapshotID)
if err != nil {
t.Fatalf("scan failed: %v", err)
}
// Verify results
// We now scan 6 files + 3 directories (source, subdir, subdir2) = 9 entries
if result.FilesScanned != 9 {
t.Errorf("expected 9 entries scanned, got %d", result.FilesScanned)
}
// Directories have their own sizes, so the total will be more than just file content
if result.BytesScanned < 97 { // At minimum we have 97 bytes of file content
t.Errorf("expected at least 97 bytes scanned, got %d", result.BytesScanned)
}
// Verify files in database
@@ -83,8 +116,9 @@ func TestScannerSimpleDirectory(t *testing.T) {
t.Fatalf("failed to list files: %v", err)
}
// We should have 6 files + 3 directories = 9 entries
if len(files) != 9 {
t.Errorf("expected 9 entries in database, got %d", len(files))
}
// Verify specific file
@@ -126,6 +160,9 @@ func TestScannerSimpleDirectory(t *testing.T) {
}
func TestScannerWithSymlinks(t *testing.T) {
// Initialize logger for tests
log.Initialize(log.Config{})
// Create in-memory filesystem
fs := afero.NewMemMapFs()
@@ -171,14 +208,40 @@ func TestScannerWithSymlinks(t *testing.T) {
// Create scanner
scanner := backup.NewScanner(backup.ScannerConfig{
FS: fs,
ChunkSize: 1024 * 16,
Repositories: repos,
MaxBlobSize: int64(1024 * 1024),
CompressionLevel: 3,
AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key; NewScanner returns nil without a recipient
})
// Create a snapshot record for testing
ctx := context.Background()
snapshotID := "test-snapshot-001"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),
CompletedAt: nil,
FileCount: 0,
ChunkCount: 0,
BlobCount: 0,
TotalSize: 0,
BlobSize: 0,
CompressionRatio: 1.0,
}
return repos.Snapshots.Create(ctx, tx, snapshot)
})
if err != nil {
t.Fatalf("failed to create snapshot: %v", err)
}
// Scan the directory
var result *backup.ScanResult
result, err = scanner.Scan(ctx, "/source", snapshotID)
if err != nil {
t.Fatalf("scan failed: %v", err)
}
@@ -209,13 +272,19 @@ func TestScannerWithSymlinks(t *testing.T) {
}
func TestScannerLargeFile(t *testing.T) {
// Initialize logger for tests
log.Initialize(log.Config{})
// Create in-memory filesystem
fs := afero.NewMemMapFs()
// Create a large file that will require multiple chunks
// Use random content to ensure good chunk boundaries
largeContent := make([]byte, 1024*1024) // 1MB
// Fill with pseudo-random data to ensure chunk boundaries
for i := 0; i < len(largeContent); i++ {
// Simple pseudo-random generator for deterministic tests
largeContent[i] = byte((i * 7919) ^ (i >> 3))
}
if err := fs.MkdirAll("/source", 0755); err != nil {
@@ -238,22 +307,54 @@ func TestScannerLargeFile(t *testing.T) {
repos := database.NewRepositories(db)
// Create scanner with 64KB average chunk size
scanner := backup.NewScanner(backup.ScannerConfig{
FS: fs,
ChunkSize: int64(1024 * 64), // 64KB average chunks
Repositories: repos,
MaxBlobSize: int64(1024 * 1024),
CompressionLevel: 3,
AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key; NewScanner returns nil without a recipient
})
// Create a snapshot record for testing
ctx := context.Background()
snapshotID := "test-snapshot-001"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),
CompletedAt: nil,
FileCount: 0,
ChunkCount: 0,
BlobCount: 0,
TotalSize: 0,
BlobSize: 0,
CompressionRatio: 1.0,
}
return repos.Snapshots.Create(ctx, tx, snapshot)
})
if err != nil {
t.Fatalf("failed to create snapshot: %v", err)
}
// Scan the directory
var result *backup.ScanResult
result, err = scanner.Scan(ctx, "/source", snapshotID)
if err != nil {
t.Fatalf("scan failed: %v", err)
}
// We scan 1 file + 1 directory = 2 entries
if result.FilesScanned != 2 {
t.Errorf("expected 2 entries scanned, got %d", result.FilesScanned)
}
// The file size should be at least 1MB
if result.BytesScanned < 1024*1024 {
t.Errorf("expected at least %d bytes scanned, got %d", 1024*1024, result.BytesScanned)
}
// Verify chunks
@@ -262,11 +363,15 @@ func TestScannerLargeFile(t *testing.T) {
t.Fatalf("failed to get chunks: %v", err)
}
// With content-defined chunking, the number of chunks depends on content
// For a 1MB file, we should get at least 1 chunk
if len(chunks) < 1 {
t.Errorf("expected at least 1 chunk, got %d", len(chunks))
}
// Log the actual number of chunks for debugging
t.Logf("1MB file produced %d chunks with 64KB average chunk size", len(chunks))
// Verify chunk sequence
for i, fc := range chunks {
if fc.Idx != i {

internal/backup/snapshot.go (new file, 542 lines)

@@ -0,0 +1,542 @@
package backup
// Snapshot Metadata Export Process
// ================================
//
// The snapshot metadata contains all information needed to restore a backup.
// Instead of creating a custom format, we use a trimmed copy of the SQLite
// database containing only data relevant to the current snapshot.
//
// Process Overview:
// 1. After all files/chunks/blobs are backed up, create a snapshot record
// 2. Close the main database to ensure consistency
// 3. Copy the entire database to a temporary file
// 4. Open the temporary database
// 5. Delete all snapshots except the current one
// 6. Delete all orphaned records:
// - Files not referenced by any remaining snapshot
// - Chunks not referenced by any remaining files
// - Blobs not containing any remaining chunks
// - All related mapping tables (file_chunks, chunk_files, blob_chunks)
// 7. Close the temporary database
// 8. Use sqlite3 to dump the cleaned database to SQL
// 9. Delete the temporary database file
// 10. Compress the SQL dump with zstd
// 11. Encrypt the compressed dump with age (if encryption is enabled)
// 12. Upload to S3 as: snapshots/{snapshot-id}.sql.zst[.age]
// 13. Reopen the main database
//
// Advantages of this approach:
// - No custom metadata format needed
// - Reuses existing database schema and relationships
// - SQL dumps are portable and compress well
// - Restore process can simply execute the SQL
// - Atomic and consistent snapshot of all metadata
//
// TODO: Future improvements:
// - Add snapshot-file relationships to track which files belong to which snapshot
// - Implement incremental snapshots that reference previous snapshots
// - Add snapshot manifest with additional metadata (size, chunk count, etc.)
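//
// A condensed sketch of the intended call sequence (the driver code is
// hypothetical; the methods themselves are defined below):
//
//	sm := NewSnapshotManager(repos, s3Client, encryptor)
//	id, _ := sm.CreateSnapshot(ctx, hostname, version)
//	// ... run the backup ...
//	_ = sm.UpdateSnapshotStats(ctx, id, stats)
//	_ = sm.CompleteSnapshot(ctx, id)
//	// caller closes the main database, then:
//	_ = sm.ExportSnapshotMetadata(ctx, dbPath, id)
//	// ... and reopens it afterwards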
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"time"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/klauspost/compress/zstd"
)
// SnapshotManager handles snapshot creation and metadata export
type SnapshotManager struct {
repos *database.Repositories
s3Client S3Client
encryptor Encryptor
}
// Encryptor interface for snapshot encryption
type Encryptor interface {
Encrypt(data []byte) ([]byte, error)
}
// NewSnapshotManager creates a new snapshot manager
func NewSnapshotManager(repos *database.Repositories, s3Client S3Client, encryptor Encryptor) *SnapshotManager {
return &SnapshotManager{
repos: repos,
s3Client: s3Client,
encryptor: encryptor,
}
}
// CreateSnapshot creates a new snapshot record in the database at the start of a backup
func (sm *SnapshotManager) CreateSnapshot(ctx context.Context, hostname, version string) (string, error) {
snapshotID := fmt.Sprintf("%s-%s", hostname, time.Now().Format("20060102-150405"))
snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: hostname,
VaultikVersion: version,
StartedAt: time.Now(),
CompletedAt: nil, // Not completed yet
FileCount: 0,
ChunkCount: 0,
BlobCount: 0,
TotalSize: 0,
BlobSize: 0,
CompressionRatio: 1.0,
}
err := sm.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return sm.repos.Snapshots.Create(ctx, tx, snapshot)
})
if err != nil {
return "", fmt.Errorf("creating snapshot: %w", err)
}
log.Info("Created snapshot", "snapshot_id", snapshotID)
return snapshotID, nil
}
// UpdateSnapshotStats updates the statistics for a snapshot during backup
func (sm *SnapshotManager) UpdateSnapshotStats(ctx context.Context, snapshotID string, stats BackupStats) error {
err := sm.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return sm.repos.Snapshots.UpdateCounts(ctx, tx, snapshotID,
int64(stats.FilesScanned),
int64(stats.ChunksCreated),
int64(stats.BlobsCreated),
stats.BytesScanned,
stats.BytesUploaded,
)
})
if err != nil {
return fmt.Errorf("updating snapshot stats: %w", err)
}
return nil
}
// CompleteSnapshot marks a snapshot as completed and exports its metadata
func (sm *SnapshotManager) CompleteSnapshot(ctx context.Context, snapshotID string) error {
// Mark the snapshot as completed
err := sm.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return sm.repos.Snapshots.MarkComplete(ctx, tx, snapshotID)
})
if err != nil {
return fmt.Errorf("marking snapshot complete: %w", err)
}
log.Info("Completed snapshot", "snapshot_id", snapshotID)
return nil
}
// ExportSnapshotMetadata exports snapshot metadata to S3
//
// This method executes the complete snapshot metadata export process:
// 1. Creates a temporary directory for working files
// 2. Copies the main database to preserve its state
// 3. Cleans the copy to contain only current snapshot data
// 4. Dumps the cleaned database to SQL
// 5. Compresses the SQL dump with zstd
// 6. Encrypts the compressed data (if encryption is enabled)
// 7. Uploads to S3 at: metadata/{snapshot-id}/db.zst[.age], together with
//    a blob manifest at: metadata/{snapshot-id}/manifest.json.zst
//
// The caller is responsible for:
// - Ensuring the main database is closed before calling this method
// - Reopening the main database after this method returns
//
// This ensures database consistency during the copy operation.
func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath string, snapshotID string) error {
log.Info("Exporting snapshot metadata", "snapshot_id", snapshotID)
// Create temp directory for all temporary files
tempDir, err := os.MkdirTemp("", "vaultik-snapshot-*")
if err != nil {
return fmt.Errorf("creating temp dir: %w", err)
}
defer func() {
if err := os.RemoveAll(tempDir); err != nil {
log.Debug("Failed to remove temp dir", "path", tempDir, "error", err)
}
}()
// Step 1: Copy database to temp file
// The main database should be closed at this point
tempDBPath := filepath.Join(tempDir, "snapshot.db")
if err := copyFile(dbPath, tempDBPath); err != nil {
return fmt.Errorf("copying database: %w", err)
}
// Step 2: Clean the temp database to only contain current snapshot data
if err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID); err != nil {
return fmt.Errorf("cleaning snapshot database: %w", err)
}
// Step 3: Dump the cleaned database to SQL
dumpPath := filepath.Join(tempDir, "snapshot.sql")
if err := sm.dumpDatabase(tempDBPath, dumpPath); err != nil {
return fmt.Errorf("dumping database: %w", err)
}
// Step 4: Compress the SQL dump
compressedPath := filepath.Join(tempDir, "snapshot.sql.zst")
if err := sm.compressDump(dumpPath, compressedPath); err != nil {
return fmt.Errorf("compressing dump: %w", err)
}
// Step 5: Read compressed data for encryption/upload
compressedData, err := os.ReadFile(compressedPath)
if err != nil {
return fmt.Errorf("reading compressed dump: %w", err)
}
// Step 6: Encrypt if encryptor is available
finalData := compressedData
if sm.encryptor != nil {
encrypted, err := sm.encryptor.Encrypt(compressedData)
if err != nil {
return fmt.Errorf("encrypting snapshot: %w", err)
}
finalData = encrypted
}
	// Step 7: Generate blob manifest from the cleaned temp database
blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID)
if err != nil {
return fmt.Errorf("generating blob manifest: %w", err)
}
// Step 8: Upload to S3 in snapshot subdirectory
// Upload database backup (encrypted)
dbKey := fmt.Sprintf("metadata/%s/db.zst", snapshotID)
if sm.encryptor != nil {
dbKey += ".age"
}
if err := sm.s3Client.PutObject(ctx, dbKey, bytes.NewReader(finalData)); err != nil {
return fmt.Errorf("uploading snapshot database: %w", err)
}
// Upload blob manifest (unencrypted, compressed)
manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
if err := sm.s3Client.PutObject(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil {
return fmt.Errorf("uploading blob manifest: %w", err)
}
log.Info("Uploaded snapshot metadata",
"snapshot_id", snapshotID,
"db_size", len(finalData),
"manifest_size", len(blobManifest))
return nil
}
// cleanSnapshotDB removes all data except for the specified snapshot
//
// It deletes, in dependency order:
//
//  1. All snapshots except the current one
//  2. Files not referenced by snapshot_files for the current snapshot
//  3. file_chunks for deleted files (removed via CASCADE)
//  4. chunk_files rows whose file no longer exists
//  5. Chunks with no remaining file_chunks references
//  6. blob_chunks rows whose chunk no longer exists
//  7. Blobs not referenced by snapshot_blobs for the current snapshot
//  8. snapshot_files and snapshot_blobs rows belonging to other snapshots
//
// The order is important to maintain referential integrity. The core
// queries have this shape:
//
//	DELETE FROM snapshots WHERE id != ?;
//	DELETE FROM files WHERE path NOT IN (
//	    SELECT file_path FROM snapshot_files WHERE snapshot_id = ?
//	);
//	DELETE FROM chunks WHERE chunk_hash NOT IN (
//	    SELECT DISTINCT chunk_hash FROM file_chunks
//	);
//	DELETE FROM blobs WHERE blob_hash NOT IN (
//	    SELECT blob_hash FROM snapshot_blobs WHERE snapshot_id = ?
//	);
func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, snapshotID string) error {
// Open the temp database
db, err := database.New(ctx, dbPath)
if err != nil {
return fmt.Errorf("opening temp database: %w", err)
}
defer func() {
if err := db.Close(); err != nil {
log.Debug("Failed to close temp database", "error", err)
}
}()
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("beginning transaction: %w", err)
}
defer func() {
if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
log.Debug("Failed to rollback transaction", "error", rbErr)
}
}()
// Step 1: Delete all other snapshots
_, err = tx.ExecContext(ctx, "DELETE FROM snapshots WHERE id != ?", snapshotID)
if err != nil {
return fmt.Errorf("deleting other snapshots: %w", err)
}
// Step 2: Delete files not in this snapshot
_, err = tx.ExecContext(ctx, `
DELETE FROM files
WHERE path NOT IN (
SELECT file_path FROM snapshot_files WHERE snapshot_id = ?
)`, snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned files: %w", err)
}
// Step 3: file_chunks will be deleted via CASCADE from files
// Step 4: Delete chunk_files for deleted files
_, err = tx.ExecContext(ctx, `
DELETE FROM chunk_files
WHERE file_path NOT IN (
SELECT path FROM files
)`)
if err != nil {
return fmt.Errorf("deleting orphaned chunk_files: %w", err)
}
// Step 5: Delete chunks with no remaining file references
_, err = tx.ExecContext(ctx, `
DELETE FROM chunks
WHERE chunk_hash NOT IN (
SELECT DISTINCT chunk_hash FROM file_chunks
)`)
if err != nil {
return fmt.Errorf("deleting orphaned chunks: %w", err)
}
// Step 6: Delete blob_chunks for deleted chunks
_, err = tx.ExecContext(ctx, `
DELETE FROM blob_chunks
WHERE chunk_hash NOT IN (
SELECT chunk_hash FROM chunks
)`)
if err != nil {
return fmt.Errorf("deleting orphaned blob_chunks: %w", err)
}
// Step 7: Delete blobs not in this snapshot
_, err = tx.ExecContext(ctx, `
DELETE FROM blobs
WHERE blob_hash NOT IN (
SELECT blob_hash FROM snapshot_blobs WHERE snapshot_id = ?
)`, snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned blobs: %w", err)
}
// Step 8: Delete orphaned snapshot_files and snapshot_blobs
_, err = tx.ExecContext(ctx, "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned snapshot_files: %w", err)
}
_, err = tx.ExecContext(ctx, "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned snapshot_blobs: %w", err)
}
// Commit transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}
// dumpDatabase creates a SQL dump of the database
func (sm *SnapshotManager) dumpDatabase(dbPath, dumpPath string) error {
cmd := exec.Command("sqlite3", dbPath, ".dump")
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("running sqlite3 dump: %w", err)
}
if err := os.WriteFile(dumpPath, output, 0644); err != nil {
return fmt.Errorf("writing dump file: %w", err)
}
return nil
}
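// Note: the textual dump is used because it is portable and compresses
// well. A possible alternative (assuming SQLite >= 3.27) is
// `VACUUM INTO 'file.db'`, which yields a compact binary copy without
// shelling out, at the cost of a less portable artifact.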
// compressDump compresses the SQL dump using zstd
func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error {
input, err := os.Open(inputPath)
if err != nil {
return fmt.Errorf("opening input file: %w", err)
}
defer func() {
if err := input.Close(); err != nil {
log.Debug("Failed to close input file", "error", err)
}
}()
output, err := os.Create(outputPath)
if err != nil {
return fmt.Errorf("creating output file: %w", err)
}
defer func() {
if err := output.Close(); err != nil {
log.Debug("Failed to close output file", "error", err)
}
}()
// Create zstd encoder with good compression and multithreading
zstdWriter, err := zstd.NewWriter(output,
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
zstd.WithEncoderConcurrency(runtime.NumCPU()),
zstd.WithWindowSize(4<<20), // 4MB window for metadata files
)
if err != nil {
return fmt.Errorf("creating zstd writer: %w", err)
}
defer func() {
if err := zstdWriter.Close(); err != nil {
log.Debug("Failed to close zstd writer", "error", err)
}
}()
if _, err := io.Copy(zstdWriter, input); err != nil {
return fmt.Errorf("compressing data: %w", err)
}
return nil
}
// copyFile copies a file from src to dst
func copyFile(src, dst string) error {
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer func() {
if err := sourceFile.Close(); err != nil {
log.Debug("Failed to close source file", "error", err)
}
}()
destFile, err := os.Create(dst)
if err != nil {
return err
}
defer func() {
if err := destFile.Close(); err != nil {
log.Debug("Failed to close destination file", "error", err)
}
}()
if _, err := io.Copy(destFile, sourceFile); err != nil {
return err
}
return nil
}
// generateBlobManifest creates a compressed JSON list of all blobs in the snapshot
func (sm *SnapshotManager) generateBlobManifest(ctx context.Context, dbPath string, snapshotID string) ([]byte, error) {
// Open the cleaned database using the database package
db, err := database.New(ctx, dbPath)
if err != nil {
return nil, fmt.Errorf("opening database: %w", err)
}
defer func() { _ = db.Close() }()
// Create repositories to access the data
repos := database.NewRepositories(db)
// Get all blobs for this snapshot
blobs, err := repos.Snapshots.GetBlobHashes(ctx, snapshotID)
if err != nil {
return nil, fmt.Errorf("getting snapshot blobs: %w", err)
}
// Create manifest structure
manifest := struct {
SnapshotID string `json:"snapshot_id"`
Timestamp string `json:"timestamp"`
BlobCount int `json:"blob_count"`
Blobs []string `json:"blobs"`
}{
SnapshotID: snapshotID,
Timestamp: time.Now().UTC().Format(time.RFC3339),
BlobCount: len(blobs),
Blobs: blobs,
}
// Marshal to JSON
jsonData, err := json.MarshalIndent(manifest, "", " ")
if err != nil {
return nil, fmt.Errorf("marshaling manifest: %w", err)
}
// Compress with zstd
compressed, err := compressData(jsonData)
if err != nil {
return nil, fmt.Errorf("compressing manifest: %w", err)
}
log.Info("Generated blob manifest",
"snapshot_id", snapshotID,
"blob_count", len(blobs),
"json_size", len(jsonData),
"compressed_size", len(compressed))
return compressed, nil
}
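// An example manifest, before compression (values illustrative):
//
//	{
//	  "snapshot_id": "myhost-20060102-150405",
//	  "timestamp": "2006-01-02T15:04:05Z",
//	  "blob_count": 2,
//	  "blobs": ["3a7f...", "9c21..."]
//	}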
// compressData compresses data using zstd
func compressData(data []byte) ([]byte, error) {
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf,
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
)
if err != nil {
return nil, err
}
if _, err := w.Write(data); err != nil {
_ = w.Close()
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
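// decompressData is the inverse of compressData, included as a minimal
// sketch for symmetry; it is not part of the original API.
func decompressData(data []byte) ([]byte, error) {
	r, err := zstd.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	// Release decoder resources once the frame has been consumed
	defer r.Close()
	return io.ReadAll(r)
}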
// BackupStats contains statistics from a backup operation
type BackupStats struct {
FilesScanned int
BytesScanned int64
ChunksCreated int
BlobsCreated int
BytesUploaded int64
}


@@ -0,0 +1,147 @@
package backup
import (
"context"
"database/sql"
"path/filepath"
"testing"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
)
func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
// Initialize logger
log.Initialize(log.Config{})
ctx := context.Background()
// Create a test database
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")
db, err := database.New(ctx, dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
repos := database.NewRepositories(db)
// Create an empty snapshot
snapshot := &database.Snapshot{
ID: "empty-snapshot",
Hostname: "test-host",
}
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return repos.Snapshots.Create(ctx, tx, snapshot)
})
if err != nil {
t.Fatalf("failed to create snapshot: %v", err)
}
// Create some files and chunks not associated with any snapshot
file := &database.File{Path: "/orphan/file.txt", Size: 1000}
chunk := &database.Chunk{ChunkHash: "orphan-chunk", SHA256: "orphan-chunk", Size: 500}
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
if err := repos.Files.Create(ctx, tx, file); err != nil {
return err
}
return repos.Chunks.Create(ctx, tx, chunk)
})
if err != nil {
t.Fatalf("failed to create orphan data: %v", err)
}
// Close the database
if err := db.Close(); err != nil {
t.Fatalf("failed to close database: %v", err)
}
// Copy database
tempDBPath := filepath.Join(tempDir, "temp.db")
if err := copyFile(dbPath, tempDBPath); err != nil {
t.Fatalf("failed to copy database: %v", err)
}
// Clean the database
sm := &SnapshotManager{}
if err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshot.ID); err != nil {
t.Fatalf("failed to clean snapshot database: %v", err)
}
// Verify the cleaned database
cleanedDB, err := database.New(ctx, tempDBPath)
if err != nil {
t.Fatalf("failed to open cleaned database: %v", err)
}
defer func() {
if err := cleanedDB.Close(); err != nil {
t.Errorf("failed to close database: %v", err)
}
}()
cleanedRepos := database.NewRepositories(cleanedDB)
// Verify snapshot exists
verifySnapshot, err := cleanedRepos.Snapshots.GetByID(ctx, snapshot.ID)
if err != nil {
t.Fatalf("failed to get snapshot: %v", err)
}
if verifySnapshot == nil {
t.Error("snapshot should exist")
}
// Verify orphan file is gone
f, err := cleanedRepos.Files.GetByPath(ctx, file.Path)
if err != nil {
t.Fatalf("failed to check file: %v", err)
}
if f != nil {
t.Error("orphan file should not exist")
}
// Verify orphan chunk is gone
c, err := cleanedRepos.Chunks.GetByHash(ctx, chunk.ChunkHash)
if err != nil {
t.Fatalf("failed to check chunk: %v", err)
}
if c != nil {
t.Error("orphan chunk should not exist")
}
}
func TestCleanSnapshotDBNonExistentSnapshot(t *testing.T) {
// Initialize logger
log.Initialize(log.Config{})
ctx := context.Background()
// Create a test database
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")
db, err := database.New(ctx, dbPath)
if err != nil {
t.Fatalf("failed to create database: %v", err)
}
// Close immediately
if err := db.Close(); err != nil {
t.Fatalf("failed to close database: %v", err)
}
// Copy database
tempDBPath := filepath.Join(tempDir, "temp.db")
if err := copyFile(dbPath, tempDBPath); err != nil {
t.Fatalf("failed to copy database: %v", err)
}
// Try to clean with non-existent snapshot
sm := &SnapshotManager{}
err = sm.cleanSnapshotDB(ctx, tempDBPath, "non-existent-snapshot")
// Should not error - it will just delete everything
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}

internal/blob/errors.go Normal file

@@ -0,0 +1,6 @@
package blob
import "errors"
// ErrBlobSizeLimitExceeded is returned when adding a chunk would exceed the blob size limit
var ErrBlobSizeLimitExceeded = errors.New("adding chunk would exceed blob size limit")
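// Callers typically finalize the current blob and retry the chunk
// (see Packer.PackChunks for the canonical loop):
//
//	if err := p.AddChunk(chunk); err == ErrBlobSizeLimitExceeded {
//		if err := p.FinalizeBlob(); err != nil {
//			return err
//		}
//		err = p.AddChunk(chunk)
//	}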

internal/blob/packer.go Normal file

@@ -0,0 +1,517 @@
package blob
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"hash"
"io"
"math/bits"
"os"
"runtime"
"sync"
"time"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/google/uuid"
"github.com/klauspost/compress/zstd"
)
// BlobHandler is called when a blob is finalized
type BlobHandler func(blob *BlobWithReader) error
// PackerConfig holds configuration for creating a Packer
type PackerConfig struct {
MaxBlobSize int64
CompressionLevel int
Encryptor Encryptor // Required - blobs are always encrypted
Repositories *database.Repositories // For creating blob records
BlobHandler BlobHandler // Optional - called when blob is ready
}
// Packer combines chunks into blobs with compression and encryption
type Packer struct {
maxBlobSize int64
compressionLevel int
encryptor Encryptor // Required - blobs are always encrypted
blobHandler BlobHandler // Called when blob is ready
repos *database.Repositories // For creating blob records
// Mutex for thread-safe blob creation
mu sync.Mutex
// Current blob being packed
currentBlob *blobInProgress
finishedBlobs []*FinishedBlob // Only used if no handler provided
}
// Encryptor interface for encryption support
type Encryptor interface {
Encrypt(data []byte) ([]byte, error)
EncryptWriter(dst io.Writer) (io.WriteCloser, error)
}
// blobInProgress represents a blob being assembled
type blobInProgress struct {
id string // UUID of the blob
chunks []*chunkInfo // Track chunk metadata
chunkSet map[string]bool // Track unique chunks in this blob
tempFile *os.File // Temporary file for encrypted compressed data
hasher hash.Hash // For computing hash of final encrypted data
compressor io.WriteCloser // Compression writer
encryptor io.WriteCloser // Encryption writer (if encryption enabled)
finalWriter io.Writer // The final writer in the chain
startTime time.Time
size int64 // Current uncompressed size
compressedSize int64 // Current compressed size (estimated)
}
// ChunkRef represents a chunk to be added to a blob
type ChunkRef struct {
Hash string
Data []byte
}
// chunkInfo tracks chunk metadata in a blob
type chunkInfo struct {
Hash string
Offset int64
Size int64
}
// FinishedBlob represents a completed blob ready for storage
type FinishedBlob struct {
ID string
Hash string
Data []byte // Compressed data
Chunks []*BlobChunkRef
CreatedTS time.Time
Uncompressed int64
Compressed int64
}
// BlobChunkRef represents a chunk's position within a blob
type BlobChunkRef struct {
ChunkHash string
Offset int64
Length int64
}
// BlobWithReader wraps a FinishedBlob with its data reader
type BlobWithReader struct {
*FinishedBlob
Reader io.ReadSeeker
TempFile *os.File // Optional, only set for disk-based blobs
}
// NewPacker creates a new blob packer
func NewPacker(cfg PackerConfig) (*Packer, error) {
if cfg.Encryptor == nil {
return nil, fmt.Errorf("encryptor is required - blobs must be encrypted")
}
if cfg.MaxBlobSize <= 0 {
return nil, fmt.Errorf("max blob size must be positive")
}
return &Packer{
maxBlobSize: cfg.MaxBlobSize,
compressionLevel: cfg.CompressionLevel,
encryptor: cfg.Encryptor,
blobHandler: cfg.BlobHandler,
repos: cfg.Repositories,
finishedBlobs: make([]*FinishedBlob, 0),
}, nil
}
// SetBlobHandler sets the handler to be called when a blob is finalized
func (p *Packer) SetBlobHandler(handler BlobHandler) {
p.mu.Lock()
defer p.mu.Unlock()
p.blobHandler = handler
}
// AddChunk adds a chunk to the current blob
// Returns ErrBlobSizeLimitExceeded if adding the chunk would exceed the size limit
func (p *Packer) AddChunk(chunk *ChunkRef) error {
p.mu.Lock()
defer p.mu.Unlock()
// Initialize new blob if needed
if p.currentBlob == nil {
if err := p.startNewBlob(); err != nil {
return fmt.Errorf("starting new blob: %w", err)
}
}
// Check if adding this chunk would exceed blob size limit
// Use conservative estimate: assume no compression
// Skip size check if chunk already exists in blob
if !p.currentBlob.chunkSet[chunk.Hash] {
currentSize := p.currentBlob.size
newSize := currentSize + int64(len(chunk.Data))
if newSize > p.maxBlobSize && len(p.currentBlob.chunks) > 0 {
// Return error indicating size limit would be exceeded
return ErrBlobSizeLimitExceeded
}
}
// Add chunk to current blob
if err := p.addChunkToCurrentBlob(chunk); err != nil {
return err
}
return nil
}
// Flush finalizes any pending blob
func (p *Packer) Flush() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.currentBlob != nil && len(p.currentBlob.chunks) > 0 {
if err := p.finalizeCurrentBlob(); err != nil {
return fmt.Errorf("finalizing blob: %w", err)
}
}
return nil
}
// FinalizeBlob finalizes the current blob being assembled
// Caller must handle retrying the chunk that triggered size limit
func (p *Packer) FinalizeBlob() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.currentBlob == nil {
return nil
}
return p.finalizeCurrentBlob()
}
// GetFinishedBlobs returns all completed blobs and clears the list
func (p *Packer) GetFinishedBlobs() []*FinishedBlob {
p.mu.Lock()
defer p.mu.Unlock()
blobs := p.finishedBlobs
p.finishedBlobs = make([]*FinishedBlob, 0)
return blobs
}
// startNewBlob initializes a new blob (must be called with lock held)
func (p *Packer) startNewBlob() error {
// Generate UUID for the blob
blobID := uuid.New().String()
// Create blob record in database
if p.repos != nil {
blob := &database.Blob{
ID: blobID,
Hash: "", // Will be set when finalized
CreatedTS: time.Now(),
FinishedTS: nil,
UncompressedSize: 0,
CompressedSize: 0,
UploadedTS: nil,
}
err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return p.repos.Blobs.Create(ctx, tx, blob)
})
if err != nil {
return fmt.Errorf("creating blob record: %w", err)
}
}
// Create temporary file
tempFile, err := os.CreateTemp("", "vaultik-blob-*.tmp")
if err != nil {
return fmt.Errorf("creating temp file: %w", err)
}
p.currentBlob = &blobInProgress{
id: blobID,
chunks: make([]*chunkInfo, 0),
chunkSet: make(map[string]bool),
startTime: time.Now(),
tempFile: tempFile,
hasher: sha256.New(),
size: 0,
compressedSize: 0,
}
	// Build writer chain: compressor -> encryptor -> hasher+file
// This ensures only encrypted data touches disk
// Final destination: write to both file and hasher
finalWriter := io.MultiWriter(tempFile, p.currentBlob.hasher)
// Set up encryption (required - closest to disk)
encWriter, err := p.encryptor.EncryptWriter(finalWriter)
if err != nil {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
return fmt.Errorf("creating encryption writer: %w", err)
}
p.currentBlob.encryptor = encWriter
currentWriter := encWriter
// Set up compression (processes data before encryption)
	// Map the configured level onto zstd's predefined encoder levels.
	// Values outside SpeedFastest..SpeedBestCompression are invalid and
	// would cause NewWriter to return an error, so clamp them.
	encoderLevel := zstd.EncoderLevel(p.compressionLevel)
	if p.compressionLevel < 1 {
		encoderLevel = zstd.SpeedDefault
	} else if encoderLevel > zstd.SpeedBestCompression {
		encoderLevel = zstd.SpeedBestCompression
	}
// Calculate window size based on blob size
windowSize := p.maxBlobSize / 100
if windowSize < (1 << 20) { // Min 1MB
windowSize = 1 << 20
} else if windowSize > (128 << 20) { // Max 128MB
windowSize = 128 << 20
}
	// Round down to a power of two, as WithWindowSize requires
	windowSize = 1 << uint(63-bits.LeadingZeros64(uint64(windowSize)))
compWriter, err := zstd.NewWriter(currentWriter,
zstd.WithEncoderLevel(encoderLevel),
zstd.WithEncoderConcurrency(runtime.NumCPU()),
zstd.WithWindowSize(int(windowSize)),
)
if err != nil {
if p.currentBlob.encryptor != nil {
_ = p.currentBlob.encryptor.Close()
}
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
return fmt.Errorf("creating compression writer: %w", err)
}
p.currentBlob.compressor = compWriter
p.currentBlob.finalWriter = compWriter
log.Debug("Started new blob", "blob_id", blobID, "temp_file", tempFile.Name())
return nil
}
// addChunkToCurrentBlob adds a chunk to the current blob (must be called with lock held)
func (p *Packer) addChunkToCurrentBlob(chunk *ChunkRef) error {
// Skip if chunk already in current blob
if p.currentBlob.chunkSet[chunk.Hash] {
log.Debug("Skipping duplicate chunk in blob", "chunk_hash", chunk.Hash)
return nil
}
// Track offset before writing
offset := p.currentBlob.size
// Write to the final writer (compression -> encryption -> disk)
if _, err := p.currentBlob.finalWriter.Write(chunk.Data); err != nil {
return fmt.Errorf("writing to blob stream: %w", err)
}
// Track chunk info
chunkSize := int64(len(chunk.Data))
chunkInfo := &chunkInfo{
Hash: chunk.Hash,
Offset: offset,
Size: chunkSize,
}
p.currentBlob.chunks = append(p.currentBlob.chunks, chunkInfo)
p.currentBlob.chunkSet[chunk.Hash] = true
// Store blob-chunk association in database immediately
if p.repos != nil {
blobChunk := &database.BlobChunk{
BlobID: p.currentBlob.id,
ChunkHash: chunk.Hash,
Offset: offset,
Length: chunkSize,
}
err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return p.repos.BlobChunks.Create(ctx, tx, blobChunk)
})
if err != nil {
log.Error("Failed to store blob-chunk association", "error", err,
"blob_id", p.currentBlob.id, "chunk_hash", chunk.Hash)
// Continue anyway - we can reconstruct this later if needed
}
}
// Update total size
p.currentBlob.size += chunkSize
log.Debug("Added chunk to blob",
"blob_id", p.currentBlob.id,
"chunk_hash", chunk.Hash,
"chunk_size", len(chunk.Data),
"offset", offset,
"blob_chunks", len(p.currentBlob.chunks),
"uncompressed_size", p.currentBlob.size)
return nil
}
// finalizeCurrentBlob completes the current blob (must be called with lock held)
func (p *Packer) finalizeCurrentBlob() error {
if p.currentBlob == nil {
return nil
}
// Close compression writer to flush all data
if err := p.currentBlob.compressor.Close(); err != nil {
p.cleanupTempFile()
return fmt.Errorf("closing compression writer: %w", err)
}
// Close encryption writer
if err := p.currentBlob.encryptor.Close(); err != nil {
p.cleanupTempFile()
return fmt.Errorf("closing encryption writer: %w", err)
}
// Sync file to ensure all data is written
if err := p.currentBlob.tempFile.Sync(); err != nil {
p.cleanupTempFile()
return fmt.Errorf("syncing temp file: %w", err)
}
	// Get the final size of the encrypted, compressed stream
finalSize, err := p.currentBlob.tempFile.Seek(0, io.SeekCurrent)
if err != nil {
p.cleanupTempFile()
return fmt.Errorf("getting file size: %w", err)
}
// Reset to beginning for reading
if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil {
p.cleanupTempFile()
return fmt.Errorf("seeking to start: %w", err)
}
// Get hash from hasher (of final encrypted data)
finalHash := p.currentBlob.hasher.Sum(nil)
blobHash := hex.EncodeToString(finalHash)
// Create chunk references with offsets
chunkRefs := make([]*BlobChunkRef, 0, len(p.currentBlob.chunks))
for _, chunk := range p.currentBlob.chunks {
chunkRefs = append(chunkRefs, &BlobChunkRef{
ChunkHash: chunk.Hash,
Offset: chunk.Offset,
Length: chunk.Size,
})
}
// Update blob record in database with hash and sizes
if p.repos != nil {
err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return p.repos.Blobs.UpdateFinished(ctx, tx, p.currentBlob.id, blobHash,
p.currentBlob.size, finalSize)
})
if err != nil {
p.cleanupTempFile()
return fmt.Errorf("updating blob record: %w", err)
}
}
// Create finished blob
finished := &FinishedBlob{
ID: p.currentBlob.id,
Hash: blobHash,
Data: nil, // We don't load data into memory anymore
Chunks: chunkRefs,
CreatedTS: p.currentBlob.startTime,
Uncompressed: p.currentBlob.size,
Compressed: finalSize,
}
compressionRatio := float64(finished.Compressed) / float64(finished.Uncompressed)
log.Info("Finalized blob",
"hash", blobHash,
"chunks", len(chunkRefs),
"uncompressed", finished.Uncompressed,
"compressed", finished.Compressed,
"ratio", fmt.Sprintf("%.2f", compressionRatio),
"duration", time.Since(p.currentBlob.startTime))
// Call blob handler if set
if p.blobHandler != nil {
log.Debug("Calling blob handler", "blob_hash", blobHash[:8]+"...")
// Reset file position for handler
if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil {
p.cleanupTempFile()
return fmt.Errorf("seeking for handler: %w", err)
}
// Create a blob reader that includes the data stream
blobWithReader := &BlobWithReader{
FinishedBlob: finished,
Reader: p.currentBlob.tempFile,
TempFile: p.currentBlob.tempFile,
}
if err := p.blobHandler(blobWithReader); err != nil {
p.cleanupTempFile()
return fmt.Errorf("blob handler failed: %w", err)
}
// Note: blob handler is responsible for closing/cleaning up temp file
p.currentBlob = nil
} else {
log.Debug("No blob handler set", "blob_hash", blobHash[:8]+"...")
// No handler, need to read data for legacy behavior
if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil {
p.cleanupTempFile()
return fmt.Errorf("seeking to read data: %w", err)
}
data, err := io.ReadAll(p.currentBlob.tempFile)
if err != nil {
p.cleanupTempFile()
return fmt.Errorf("reading blob data: %w", err)
}
finished.Data = data
p.finishedBlobs = append(p.finishedBlobs, finished)
// Cleanup
p.cleanupTempFile()
p.currentBlob = nil
}
return nil
}
// cleanupTempFile removes the temporary file
func (p *Packer) cleanupTempFile() {
if p.currentBlob != nil && p.currentBlob.tempFile != nil {
name := p.currentBlob.tempFile.Name()
_ = p.currentBlob.tempFile.Close()
_ = os.Remove(name)
}
}
// PackChunks is a convenience method to pack multiple chunks at once
func (p *Packer) PackChunks(chunks []*ChunkRef) error {
for _, chunk := range chunks {
err := p.AddChunk(chunk)
if err == ErrBlobSizeLimitExceeded {
// Finalize current blob and retry
if err := p.FinalizeBlob(); err != nil {
return fmt.Errorf("finalizing blob before retry: %w", err)
}
// Retry the chunk
if err := p.AddChunk(chunk); err != nil {
return fmt.Errorf("adding chunk %s after finalize: %w", chunk.Hash, err)
}
} else if err != nil {
return fmt.Errorf("adding chunk %s: %w", chunk.Hash, err)
}
}
return p.Flush()
}
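// Example (illustrative): stream finished blobs straight to storage via a
// BlobHandler instead of buffering them with GetFinishedBlobs. The ctx,
// enc, repos, and s3Client values - and s3Client's
// PutObject(ctx, key, io.Reader) method - are assumptions, not part of
// this package:
//
//	p, err := NewPacker(PackerConfig{
//		MaxBlobSize:      1 << 30,
//		CompressionLevel: 3,
//		Encryptor:        enc,
//		Repositories:     repos,
//		BlobHandler: func(b *BlobWithReader) error {
//			// The handler owns the temp file and must clean it up
//			defer func() {
//				_ = b.TempFile.Close()
//				_ = os.Remove(b.TempFile.Name())
//			}()
//			return s3Client.PutObject(ctx, "blobs/"+b.Hash, b.Reader)
//		},
//	})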


@@ -0,0 +1,328 @@
package blob
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"io"
"testing"
"filippo.io/age"
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/klauspost/compress/zstd"
)
const (
// Test key from test/insecure-integration-test.key
testPrivateKey = "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5"
testPublicKey = "age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"
)
func TestPacker(t *testing.T) {
// Initialize logger for tests
log.Initialize(log.Config{})
// Parse test identity
identity, err := age.ParseX25519Identity(testPrivateKey)
if err != nil {
t.Fatalf("failed to parse test identity: %v", err)
}
// Create test encryptor using the public key
enc, err := crypto.NewEncryptor([]string{testPublicKey})
if err != nil {
t.Fatalf("failed to create encryptor: %v", err)
}
t.Run("single chunk creates single blob", func(t *testing.T) {
// Create test database
db, err := database.NewTestDB()
if err != nil {
t.Fatalf("failed to create test db: %v", err)
}
defer func() { _ = db.Close() }()
repos := database.NewRepositories(db)
cfg := PackerConfig{
MaxBlobSize: 10 * 1024 * 1024, // 10MB
CompressionLevel: 3,
Encryptor: enc,
Repositories: repos,
}
packer, err := NewPacker(cfg)
if err != nil {
t.Fatalf("failed to create packer: %v", err)
}
// Create a chunk
data := []byte("Hello, World!")
hash := sha256.Sum256(data)
chunk := &ChunkRef{
Hash: hex.EncodeToString(hash[:]),
Data: data,
}
// Add chunk
if err := packer.AddChunk(chunk); err != nil {
t.Fatalf("failed to add chunk: %v", err)
}
// Flush
if err := packer.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
// Get finished blobs
blobs := packer.GetFinishedBlobs()
if len(blobs) != 1 {
t.Fatalf("expected 1 blob, got %d", len(blobs))
}
blob := blobs[0]
if len(blob.Chunks) != 1 {
t.Errorf("expected 1 chunk in blob, got %d", len(blob.Chunks))
}
// Note: Very small data may not compress well
t.Logf("Compression: %d -> %d bytes", blob.Uncompressed, blob.Compressed)
// Decrypt the blob data
decrypted, err := age.Decrypt(bytes.NewReader(blob.Data), identity)
if err != nil {
t.Fatalf("failed to decrypt blob: %v", err)
}
// Decompress the decrypted data
reader, err := zstd.NewReader(decrypted)
if err != nil {
t.Fatalf("failed to create decompressor: %v", err)
}
defer reader.Close()
var decompressed bytes.Buffer
if _, err := io.Copy(&decompressed, reader); err != nil {
t.Fatalf("failed to decompress: %v", err)
}
if !bytes.Equal(decompressed.Bytes(), data) {
t.Error("decompressed data doesn't match original")
}
})
t.Run("multiple chunks packed together", func(t *testing.T) {
// Create test database
db, err := database.NewTestDB()
if err != nil {
t.Fatalf("failed to create test db: %v", err)
}
defer func() { _ = db.Close() }()
repos := database.NewRepositories(db)
cfg := PackerConfig{
MaxBlobSize: 10 * 1024 * 1024, // 10MB
CompressionLevel: 3,
Encryptor: enc,
Repositories: repos,
}
packer, err := NewPacker(cfg)
if err != nil {
t.Fatalf("failed to create packer: %v", err)
}
// Create multiple small chunks
chunks := make([]*ChunkRef, 10)
for i := 0; i < 10; i++ {
data := bytes.Repeat([]byte{byte(i)}, 1000)
hash := sha256.Sum256(data)
chunks[i] = &ChunkRef{
Hash: hex.EncodeToString(hash[:]),
Data: data,
}
}
// Add all chunks
for _, chunk := range chunks {
err := packer.AddChunk(chunk)
if err != nil {
t.Fatalf("failed to add chunk: %v", err)
}
}
// Flush
if err := packer.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
// Should have one blob with all chunks
blobs := packer.GetFinishedBlobs()
if len(blobs) != 1 {
t.Fatalf("expected 1 blob, got %d", len(blobs))
}
if len(blobs[0].Chunks) != 10 {
t.Errorf("expected 10 chunks in blob, got %d", len(blobs[0].Chunks))
}
// Verify offsets are correct
expectedOffset := int64(0)
for i, chunkRef := range blobs[0].Chunks {
if chunkRef.Offset != expectedOffset {
t.Errorf("chunk %d: expected offset %d, got %d", i, expectedOffset, chunkRef.Offset)
}
if chunkRef.Length != 1000 {
t.Errorf("chunk %d: expected length 1000, got %d", i, chunkRef.Length)
}
expectedOffset += chunkRef.Length
}
})
t.Run("blob size limit enforced", func(t *testing.T) {
// Create test database
db, err := database.NewTestDB()
if err != nil {
t.Fatalf("failed to create test db: %v", err)
}
defer func() { _ = db.Close() }()
repos := database.NewRepositories(db)
// Small blob size limit to force multiple blobs
cfg := PackerConfig{
MaxBlobSize: 5000, // 5KB max
CompressionLevel: 3,
Encryptor: enc,
Repositories: repos,
}
packer, err := NewPacker(cfg)
if err != nil {
t.Fatalf("failed to create packer: %v", err)
}
// Create chunks that will exceed the limit
chunks := make([]*ChunkRef, 10)
for i := 0; i < 10; i++ {
data := bytes.Repeat([]byte{byte(i)}, 1000) // 1KB each
hash := sha256.Sum256(data)
chunks[i] = &ChunkRef{
Hash: hex.EncodeToString(hash[:]),
Data: data,
}
}
blobCount := 0
// Add chunks and handle size limit errors
for _, chunk := range chunks {
err := packer.AddChunk(chunk)
if err == ErrBlobSizeLimitExceeded {
// Finalize current blob
if err := packer.FinalizeBlob(); err != nil {
t.Fatalf("failed to finalize blob: %v", err)
}
blobCount++
// Retry adding the chunk
if err := packer.AddChunk(chunk); err != nil {
t.Fatalf("failed to add chunk after finalize: %v", err)
}
} else if err != nil {
t.Fatalf("failed to add chunk: %v", err)
}
}
// Flush remaining
if err := packer.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
// Get all blobs
blobs := packer.GetFinishedBlobs()
totalBlobs := blobCount + len(blobs)
// Should have multiple blobs due to size limit
if totalBlobs < 2 {
t.Errorf("expected multiple blobs due to size limit, got %d", totalBlobs)
}
// Verify each blob respects size limit (approximately)
for _, blob := range blobs {
if blob.Compressed > 6000 { // Allow some overhead
t.Errorf("blob size %d exceeds limit", blob.Compressed)
}
}
})
t.Run("with encryption", func(t *testing.T) {
// Create test database
db, err := database.NewTestDB()
if err != nil {
t.Fatalf("failed to create test db: %v", err)
}
defer func() { _ = db.Close() }()
repos := database.NewRepositories(db)
// Generate test identity (using the one from parent test)
cfg := PackerConfig{
MaxBlobSize: 10 * 1024 * 1024, // 10MB
CompressionLevel: 3,
Encryptor: enc,
Repositories: repos,
}
packer, err := NewPacker(cfg)
if err != nil {
t.Fatalf("failed to create packer: %v", err)
}
// Create test data
data := bytes.Repeat([]byte("Test data for encryption!"), 100)
hash := sha256.Sum256(data)
chunk := &ChunkRef{
Hash: hex.EncodeToString(hash[:]),
Data: data,
}
// Add chunk and flush
if err := packer.AddChunk(chunk); err != nil {
t.Fatalf("failed to add chunk: %v", err)
}
if err := packer.Flush(); err != nil {
t.Fatalf("failed to flush: %v", err)
}
// Get blob
blobs := packer.GetFinishedBlobs()
if len(blobs) != 1 {
t.Fatalf("expected 1 blob, got %d", len(blobs))
}
blob := blobs[0]
// Decrypt the blob
decrypted, err := age.Decrypt(bytes.NewReader(blob.Data), identity)
if err != nil {
t.Fatalf("failed to decrypt blob: %v", err)
}
var decryptedData bytes.Buffer
if _, err := decryptedData.ReadFrom(decrypted); err != nil {
t.Fatalf("failed to read decrypted data: %v", err)
}
// Decompress
reader, err := zstd.NewReader(&decryptedData)
if err != nil {
t.Fatalf("failed to create decompressor: %v", err)
}
defer reader.Close()
var decompressed bytes.Buffer
if _, err := decompressed.ReadFrom(reader); err != nil {
t.Fatalf("failed to decompress: %v", err)
}
// Verify data
if !bytes.Equal(decompressed.Bytes(), data) {
t.Error("decrypted and decompressed data doesn't match original")
}
})
}

internal/chunker/chunker.go Normal file

@@ -0,0 +1,146 @@
package chunker
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"github.com/jotfs/fastcdc-go"
)
// Chunk represents a single chunk of data
type Chunk struct {
Hash string // Content hash of the chunk
Data []byte // Chunk data
Offset int64 // Offset in the original file
Size int64 // Size of the chunk
}
// Chunker provides content-defined chunking using FastCDC
type Chunker struct {
avgChunkSize int
minChunkSize int
maxChunkSize int
}
// NewChunker creates a new chunker with the specified average chunk size
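// For example, an average of 256 KiB yields a 64 KiB minimum and a 1 MiB maximum.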
func NewChunker(avgChunkSize int64) *Chunker {
// FastCDC recommends min = avg/4 and max = avg*4
return &Chunker{
avgChunkSize: int(avgChunkSize),
minChunkSize: int(avgChunkSize / 4),
maxChunkSize: int(avgChunkSize * 4),
}
}
// ChunkReader splits the reader into content-defined chunks
func (c *Chunker) ChunkReader(r io.Reader) ([]Chunk, error) {
opts := fastcdc.Options{
MinSize: c.minChunkSize,
AverageSize: c.avgChunkSize,
MaxSize: c.maxChunkSize,
}
chunker, err := fastcdc.NewChunker(r, opts)
if err != nil {
return nil, fmt.Errorf("creating chunker: %w", err)
}
var chunks []Chunk
offset := int64(0)
for {
chunk, err := chunker.Next()
if err == io.EOF {
break
}
if err != nil {
return nil, fmt.Errorf("reading chunk: %w", err)
}
// Calculate hash
hash := sha256.Sum256(chunk.Data)
// Make a copy of the data since FastCDC reuses the buffer
chunkData := make([]byte, len(chunk.Data))
copy(chunkData, chunk.Data)
chunks = append(chunks, Chunk{
Hash: hex.EncodeToString(hash[:]),
Data: chunkData,
Offset: offset,
Size: int64(len(chunk.Data)),
})
offset += int64(len(chunk.Data))
}
return chunks, nil
}
// ChunkCallback is called for each chunk as it's processed
type ChunkCallback func(chunk Chunk) error
// ChunkReaderStreaming splits the reader into chunks and calls the callback for each
func (c *Chunker) ChunkReaderStreaming(r io.Reader, callback ChunkCallback) error {
opts := fastcdc.Options{
MinSize: c.minChunkSize,
AverageSize: c.avgChunkSize,
MaxSize: c.maxChunkSize,
}
chunker, err := fastcdc.NewChunker(r, opts)
if err != nil {
return fmt.Errorf("creating chunker: %w", err)
}
offset := int64(0)
for {
chunk, err := chunker.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("reading chunk: %w", err)
}
// Calculate hash
hash := sha256.Sum256(chunk.Data)
// Make a copy of the data since FastCDC reuses the buffer
chunkData := make([]byte, len(chunk.Data))
copy(chunkData, chunk.Data)
if err := callback(Chunk{
Hash: hex.EncodeToString(hash[:]),
Data: chunkData,
Offset: offset,
Size: int64(len(chunk.Data)),
}); err != nil {
return fmt.Errorf("callback error: %w", err)
}
offset += int64(len(chunk.Data))
}
return nil
}
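// Example (illustrative): stream a large file chunk-by-chunk without
// holding the whole file in memory; handleChunk is a hypothetical
// per-chunk consumer such as a packer or uploader:
//
//	f, err := os.Open(path)
//	if err != nil {
//		return err
//	}
//	defer func() { _ = f.Close() }()
//	err = c.ChunkReaderStreaming(f, func(ch Chunk) error {
//		return handleChunk(ch)
//	})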
// ChunkFile splits a file into content-defined chunks
func (c *Chunker) ChunkFile(path string) ([]Chunk, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("opening file: %w", err)
}
	defer func() {
		// Best-effort close: the file is opened read-only, so a close
		// error cannot affect the chunking result.
		_ = file.Close()
	}()
return c.ChunkReader(file)
}


@@ -0,0 +1,77 @@
package chunker
import (
"bytes"
"testing"
)
func TestChunkerExpectedChunkCount(t *testing.T) {
tests := []struct {
name string
fileSize int
avgChunkSize int64
minExpected int
maxExpected int
}{
{
name: "1MB file with 64KB average",
fileSize: 1024 * 1024,
avgChunkSize: 64 * 1024,
minExpected: 8, // At least half the expected count
maxExpected: 32, // At most double the expected count
},
{
name: "10MB file with 256KB average",
fileSize: 10 * 1024 * 1024,
avgChunkSize: 256 * 1024,
minExpected: 10, // FastCDC may produce larger chunks
maxExpected: 80,
},
{
name: "512KB file with 64KB average",
fileSize: 512 * 1024,
avgChunkSize: 64 * 1024,
minExpected: 4, // ~8 expected
maxExpected: 16,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chunker := NewChunker(tt.avgChunkSize)
// Create data with some variation to trigger chunk boundaries
data := make([]byte, tt.fileSize)
for i := 0; i < len(data); i++ {
// Use a pattern that should create boundaries
data[i] = byte((i * 17) ^ (i >> 5))
}
chunks, err := chunker.ChunkReader(bytes.NewReader(data))
if err != nil {
t.Fatalf("chunking failed: %v", err)
}
t.Logf("Created %d chunks for %d bytes with %d average chunk size",
len(chunks), tt.fileSize, tt.avgChunkSize)
if len(chunks) < tt.minExpected {
t.Errorf("too few chunks: got %d, expected at least %d",
len(chunks), tt.minExpected)
}
if len(chunks) > tt.maxExpected {
t.Errorf("too many chunks: got %d, expected at most %d",
len(chunks), tt.maxExpected)
}
// Verify chunks reconstruct to original
var reconstructed []byte
for _, chunk := range chunks {
reconstructed = append(reconstructed, chunk.Data...)
}
if !bytes.Equal(data, reconstructed) {
t.Error("reconstructed data doesn't match original")
}
})
}
}


@@ -0,0 +1,128 @@
package chunker
import (
"bytes"
"crypto/rand"
"testing"
)
func TestChunker(t *testing.T) {
t.Run("small file produces single chunk", func(t *testing.T) {
chunker := NewChunker(1024 * 1024) // 1MB average
data := bytes.Repeat([]byte("hello"), 100) // 500 bytes
chunks, err := chunker.ChunkReader(bytes.NewReader(data))
if err != nil {
t.Fatalf("chunking failed: %v", err)
}
if len(chunks) != 1 {
t.Errorf("expected 1 chunk, got %d", len(chunks))
}
if chunks[0].Size != int64(len(data)) {
t.Errorf("expected chunk size %d, got %d", len(data), chunks[0].Size)
}
})
t.Run("large file produces multiple chunks", func(t *testing.T) {
chunker := NewChunker(256 * 1024) // 256KB average chunk size
// Generate 2MB of random data
data := make([]byte, 2*1024*1024)
if _, err := rand.Read(data); err != nil {
t.Fatalf("failed to generate random data: %v", err)
}
chunks, err := chunker.ChunkReader(bytes.NewReader(data))
if err != nil {
t.Fatalf("chunking failed: %v", err)
}
// Should produce multiple chunks - with FastCDC we expect around 8 chunks for 2MB with 256KB average
if len(chunks) < 4 || len(chunks) > 16 {
t.Errorf("expected 4-16 chunks, got %d", len(chunks))
}
// Verify chunks reconstruct original data
var reconstructed []byte
for _, chunk := range chunks {
reconstructed = append(reconstructed, chunk.Data...)
}
if !bytes.Equal(data, reconstructed) {
t.Error("reconstructed data doesn't match original")
}
// Verify offsets
var expectedOffset int64
for i, chunk := range chunks {
if chunk.Offset != expectedOffset {
t.Errorf("chunk %d: expected offset %d, got %d", i, expectedOffset, chunk.Offset)
}
expectedOffset += chunk.Size
}
})
t.Run("deterministic chunking", func(t *testing.T) {
chunker1 := NewChunker(256 * 1024)
chunker2 := NewChunker(256 * 1024)
// Use deterministic data
data := bytes.Repeat([]byte("abcdefghijklmnopqrstuvwxyz"), 20000) // ~520KB
chunks1, err := chunker1.ChunkReader(bytes.NewReader(data))
if err != nil {
t.Fatalf("chunking failed: %v", err)
}
chunks2, err := chunker2.ChunkReader(bytes.NewReader(data))
if err != nil {
t.Fatalf("chunking failed: %v", err)
}
// Should produce same chunks
if len(chunks1) != len(chunks2) {
t.Fatalf("different number of chunks: %d vs %d", len(chunks1), len(chunks2))
}
for i := range chunks1 {
if chunks1[i].Hash != chunks2[i].Hash {
t.Errorf("chunk %d: different hashes", i)
}
if chunks1[i].Size != chunks2[i].Size {
t.Errorf("chunk %d: different sizes", i)
}
}
})
}
func TestChunkBoundaries(t *testing.T) {
chunker := NewChunker(256 * 1024) // 256KB average
// FastCDC uses avg/4 for min and avg*4 for max
avgSize := int64(256 * 1024)
minSize := avgSize / 4
maxSize := avgSize * 4
// Test that minimum chunk size is respected
data := make([]byte, minSize+1024)
if _, err := rand.Read(data); err != nil {
t.Fatalf("failed to generate random data: %v", err)
}
chunks, err := chunker.ChunkReader(bytes.NewReader(data))
if err != nil {
t.Fatalf("chunking failed: %v", err)
}
for i, chunk := range chunks {
// Last chunk can be smaller than minimum
if i < len(chunks)-1 && chunk.Size < minSize {
t.Errorf("chunk %d size %d is below minimum %d", i, chunk.Size, minSize)
}
if chunk.Size > maxSize {
t.Errorf("chunk %d size %d exceeds maximum %d", i, chunk.Size, maxSize)
}
}
}


@@ -3,17 +3,22 @@ package cli
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log"
"go.uber.org/fx"
)
// AppOptions contains common options for creating the fx application
type AppOptions struct {
ConfigPath string
LogOptions log.LogOptions
Modules []fx.Option
Invokes []fx.Option
}
@@ -32,9 +37,12 @@ func setupGlobals(lc fx.Lifecycle, g *globals.Globals) {
func NewApp(opts AppOptions) *fx.App {
baseModules := []fx.Option{
fx.Supply(config.ConfigPath(opts.ConfigPath)),
fx.Supply(opts.LogOptions),
fx.Provide(globals.New),
fx.Provide(log.New),
config.Module,
database.Module,
log.Module,
fx.Invoke(setupGlobals),
fx.NopLogger,
}
@@ -47,18 +55,50 @@ func NewApp(opts AppOptions) *fx.App {
// RunApp starts and stops the fx application within the given context
func RunApp(ctx context.Context, app *fx.App) error {
// Set up signal handling for graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
// Create a context that will be cancelled on signal
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Start the app
if err := app.Start(ctx); err != nil {
return fmt.Errorf("failed to start app: %w", err)
}
	// Handle shutdown
shutdownComplete := make(chan struct{})
go func() {
defer close(shutdownComplete)
<-sigChan
log.Notice("Received interrupt signal, shutting down gracefully...")
// Create a timeout context for shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
if err := app.Stop(shutdownCtx); err != nil {
log.Error("Error during shutdown", "error", err)
}
}()
	// Wait for either the signal handler to complete shutdown or the app to request shutdown
select {
case <-shutdownComplete:
// Shutdown completed via signal
return nil
case <-ctx.Done():
// Context cancelled (shouldn't happen in normal operation)
if err := app.Stop(context.Background()); err != nil {
log.Error("Error stopping app", "error", err)
}
return ctx.Err()
case <-app.Done():
// App finished running (e.g., backup completed)
return nil
}
}
// RunWithApp is a helper that creates and runs an fx app with the given options


@@ -4,10 +4,15 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"git.eeqj.de/sneak/vaultik/internal/backup"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/s3"
"github.com/spf13/cobra"
"go.uber.org/fx"
)
@@ -20,6 +25,18 @@ type BackupOptions struct {
Prune bool
}
// BackupApp contains all dependencies needed for running backups
type BackupApp struct {
Globals *globals.Globals
Config *config.Config
Repositories *database.Repositories
ScannerFactory backup.ScannerFactory
S3Client *s3.Client
DB *database.DB
Lifecycle fx.Lifecycle
Shutdowner fx.Shutdowner
}
// NewBackupCommand creates the backup command
func NewBackupCommand() *cobra.Command {
opts := &BackupOptions{}
@@ -59,25 +76,212 @@ a path using --config or by setting VAULTIK_CONFIG to a path.`,
}
func runBackup(ctx context.Context, opts *BackupOptions) error {
rootFlags := GetRootFlags()
return RunWithApp(ctx, AppOptions{
ConfigPath: opts.ConfigPath,
LogOptions: log.LogOptions{
Verbose: rootFlags.Verbose,
Debug: rootFlags.Debug,
Cron: opts.Cron,
},
Modules: []fx.Option{
backup.Module,
s3.Module,
fx.Provide(fx.Annotate(
func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
scannerFactory backup.ScannerFactory, s3Client *s3.Client, db *database.DB,
lc fx.Lifecycle, shutdowner fx.Shutdowner) *BackupApp {
return &BackupApp{
Globals: g,
Config: cfg,
Repositories: repos,
ScannerFactory: scannerFactory,
S3Client: s3Client,
DB: db,
Lifecycle: lc,
Shutdowner: shutdowner,
}
},
)),
},
Invokes: []fx.Option{
fx.Invoke(func(app *BackupApp, lc fx.Lifecycle) {
// Create a cancellable context for the backup
backupCtx, backupCancel := context.WithCancel(context.Background())
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
// Start the backup in a goroutine
go func() {
// Run the backup
if err := app.runBackup(backupCtx, opts); err != nil {
if err != context.Canceled {
log.Error("Backup failed", "error", err)
}
}
// Shutdown the app when backup completes
if err := app.Shutdowner.Shutdown(); err != nil {
log.Error("Failed to shutdown", "error", err)
}
}()
return nil
},
OnStop: func(ctx context.Context) error {
log.Debug("Stopping backup")
// Cancel the backup context
backupCancel()
return nil
},
})
}),
},
})
}
// runBackup executes the backup operation
func (app *BackupApp) runBackup(ctx context.Context, opts *BackupOptions) error {
log.Info("Starting backup",
"config", opts.ConfigPath,
"version", app.Globals.Version,
"commit", app.Globals.Commit,
"index_path", app.Config.IndexPath,
)
if opts.Daemon {
log.Info("Running in daemon mode")
// TODO: Implement daemon mode with inotify
return fmt.Errorf("daemon mode not yet implemented")
}
// Resolve source directories to absolute paths
resolvedDirs := make([]string, 0, len(app.Config.SourceDirs))
for _, dir := range app.Config.SourceDirs {
absPath, err := filepath.Abs(dir)
if err != nil {
return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
}
// Resolve symlinks
resolvedPath, err := filepath.EvalSymlinks(absPath)
if err != nil {
// If the path doesn't exist yet, use the absolute path
if os.IsNotExist(err) {
resolvedPath = absPath
} else {
return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err)
}
}
resolvedDirs = append(resolvedDirs, resolvedPath)
}
// Create scanner with progress enabled (unless in cron mode)
scanner := app.ScannerFactory(backup.ScannerParams{
EnableProgress: !opts.Cron,
})
// Perform a single backup run
log.Notice("Starting backup", "source_dirs", len(resolvedDirs))
for i, dir := range resolvedDirs {
log.Info("Source directory", "index", i+1, "path", dir)
}
totalFiles := 0
totalBytes := int64(0)
totalChunks := 0
totalBlobs := 0
// Create a new snapshot at the beginning of backup
hostname := app.Config.Hostname
if hostname == "" {
hostname, _ = os.Hostname()
}
// Create encryptor if age recipients are configured
var encryptor backup.Encryptor
if len(app.Config.AgeRecipients) > 0 {
cryptoEncryptor, err := crypto.NewEncryptor(app.Config.AgeRecipients)
if err != nil {
return fmt.Errorf("creating encryptor: %w", err)
}
encryptor = cryptoEncryptor
}
snapshotManager := backup.NewSnapshotManager(app.Repositories, app.S3Client, encryptor)
snapshotID, err := snapshotManager.CreateSnapshot(ctx, hostname, app.Globals.Version)
if err != nil {
return fmt.Errorf("creating snapshot: %w", err)
}
log.Info("Created snapshot", "snapshot_id", snapshotID)
for _, dir := range resolvedDirs {
// Check if context is cancelled
select {
case <-ctx.Done():
log.Info("Backup cancelled")
return ctx.Err()
default:
}
log.Info("Scanning directory", "path", dir)
result, err := scanner.Scan(ctx, dir, snapshotID)
if err != nil {
return fmt.Errorf("failed to scan %s: %w", dir, err)
}
totalFiles += result.FilesScanned
totalBytes += result.BytesScanned
totalChunks += result.ChunksCreated
totalBlobs += result.BlobsCreated
log.Info("Directory scan complete",
"path", dir,
"files", result.FilesScanned,
"files_skipped", result.FilesSkipped,
"bytes", result.BytesScanned,
"bytes_skipped", result.BytesSkipped,
"chunks", result.ChunksCreated,
"blobs", result.BlobsCreated,
"duration", result.EndTime.Sub(result.StartTime))
}
// Update snapshot statistics
stats := backup.BackupStats{
FilesScanned: totalFiles,
BytesScanned: totalBytes,
ChunksCreated: totalChunks,
BlobsCreated: totalBlobs,
BytesUploaded: totalBytes, // TODO: Track actual uploaded bytes
}
if err := snapshotManager.UpdateSnapshotStats(ctx, snapshotID, stats); err != nil {
return fmt.Errorf("updating snapshot stats: %w", err)
}
// Mark snapshot as complete
if err := snapshotManager.CompleteSnapshot(ctx, snapshotID); err != nil {
return fmt.Errorf("completing snapshot: %w", err)
}
	// Export snapshot metadata without closing the database.
	// The export function handles its own database connection.
if err := snapshotManager.ExportSnapshotMetadata(ctx, app.Config.IndexPath, snapshotID); err != nil {
return fmt.Errorf("exporting snapshot metadata: %w", err)
}
log.Notice("Backup complete",
"snapshot_id", snapshotID,
"total_files", totalFiles,
"total_bytes", totalBytes,
"total_chunks", totalChunks,
"total_blobs", totalBlobs)
if opts.Prune {
log.Info("Pruning enabled - will delete old snapshots after backup")
// TODO: Implement pruning
}
return nil
}


@@ -4,6 +4,14 @@ import (
"github.com/spf13/cobra"
)
// RootFlags holds global flags
type RootFlags struct {
Verbose bool
Debug bool
}
var rootFlags RootFlags
// NewRootCommand creates the root cobra command
func NewRootCommand() *cobra.Command {
cmd := &cobra.Command{
@@ -15,6 +23,10 @@ on the source system.`,
SilenceUsage: true,
}
// Add global flags
cmd.PersistentFlags().BoolVarP(&rootFlags.Verbose, "verbose", "v", false, "Enable verbose output")
cmd.PersistentFlags().BoolVar(&rootFlags.Debug, "debug", false, "Enable debug output")
// Add subcommands
cmd.AddCommand(
NewBackupCommand(),
@@ -27,3 +39,8 @@ on the source system.`,
return cmd
}
// GetRootFlags returns the global flags
func GetRootFlags() RootFlags {
return rootFlags
}


@@ -11,10 +11,10 @@ import (
// Config represents the application configuration
type Config struct {
AgeRecipients []string `yaml:"age_recipients"`
BackupInterval time.Duration `yaml:"backup_interval"`
BlobSizeLimit Size `yaml:"blob_size_limit"`
ChunkSize Size `yaml:"chunk_size"`
Exclude []string `yaml:"exclude"`
FullScanInterval time.Duration `yaml:"full_scan_interval"`
Hostname string `yaml:"hostname"`
@@ -35,7 +35,7 @@ type S3Config struct {
SecretAccessKey string `yaml:"secret_access_key"`
Region string `yaml:"region"`
UseSSL bool `yaml:"use_ssl"`
PartSize Size `yaml:"part_size"`
}
// ConfigPath wraps the config file path for fx injection
@@ -64,8 +64,8 @@ func Load(path string) (*Config, error) {
cfg := &Config{
// Set defaults
BlobSizeLimit: Size(10 * 1024 * 1024 * 1024), // 10GB
ChunkSize: Size(10 * 1024 * 1024), // 10MB
BackupInterval: 1 * time.Hour,
FullScanInterval: 24 * time.Hour,
MinTimeBetweenRun: 15 * time.Minute,
@@ -97,7 +97,7 @@ func Load(path string) (*Config, error) {
cfg.S3.Region = "us-east-1"
}
if cfg.S3.PartSize == 0 {
cfg.S3.PartSize = 5 * 1024 * 1024 // 5MB
cfg.S3.PartSize = Size(5 * 1024 * 1024) // 5MB
}
if err := cfg.Validate(); err != nil {
@@ -109,8 +109,8 @@ func Load(path string) (*Config, error) {
// Validate checks if the configuration is valid
func (c *Config) Validate() error {
if c.AgeRecipient == "" {
return fmt.Errorf("age_recipient is required")
if len(c.AgeRecipients) == 0 {
return fmt.Errorf("at least one age_recipient is required")
}
if len(c.SourceDirs) == 0 {
@@ -133,11 +133,11 @@ func (c *Config) Validate() error {
return fmt.Errorf("s3.secret_access_key is required")
}
if c.ChunkSize < 1024*1024 { // 1MB minimum
if c.ChunkSize.Int64() < 1024*1024 { // 1MB minimum
return fmt.Errorf("chunk_size must be at least 1MB")
}
if c.BlobSizeLimit < c.ChunkSize {
if c.BlobSizeLimit.Int64() < c.ChunkSize.Int64() {
return fmt.Errorf("blob_size_limit must be at least chunk_size")
}

View File

@@ -6,6 +6,12 @@ import (
"testing"
)
const (
TEST_SNEAK_AGE_PUBLIC_KEY = "age1278m9q7dp3chsh2dcy82qk27v047zywyvtxwnj4cvt0z65jw6a7q5dqhfj"
TEST_INTEGRATION_AGE_PUBLIC_KEY = "age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"
TEST_INTEGRATION_AGE_PRIVATE_KEY = "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5"
)
func TestMain(m *testing.M) {
// Set up test environment
testConfigPath := filepath.Join("..", "..", "test", "config.yaml")
@@ -32,8 +38,11 @@ func TestConfigLoad(t *testing.T) {
}
// Basic validation
if cfg.AgeRecipient != "age1xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" {
t.Errorf("Expected age recipient to be set, got '%s'", cfg.AgeRecipient)
if len(cfg.AgeRecipients) != 2 {
t.Errorf("Expected 2 age recipients, got %d", len(cfg.AgeRecipients))
}
if cfg.AgeRecipients[0] != TEST_SNEAK_AGE_PUBLIC_KEY {
t.Errorf("Expected first age recipient to be %s, got '%s'", TEST_SNEAK_AGE_PUBLIC_KEY, cfg.AgeRecipients[0])
}
if len(cfg.SourceDirs) != 2 {

internal/config/size.go Normal file (45 lines)
View File

@@ -0,0 +1,45 @@
package config
import (
"fmt"
"github.com/dustin/go-humanize"
)
// Size is a custom type that can unmarshal from both int64 and string
type Size int64
// UnmarshalYAML implements yaml.Unmarshaler for Size
func (s *Size) UnmarshalYAML(unmarshal func(interface{}) error) error {
// Try to unmarshal as int64 first
var intVal int64
if err := unmarshal(&intVal); err == nil {
*s = Size(intVal)
return nil
}
// Try to unmarshal as string
var strVal string
if err := unmarshal(&strVal); err != nil {
return fmt.Errorf("size must be a number or string")
}
// Parse the string using go-humanize
bytes, err := humanize.ParseBytes(strVal)
if err != nil {
return fmt.Errorf("invalid size format: %w", err)
}
*s = Size(bytes)
return nil
}
// Int64 returns the size as int64
func (s Size) Int64() int64 {
return int64(s)
}
// String returns the size as a human-readable string
func (s Size) String() string {
return humanize.Bytes(uint64(s))
}
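For reference, a minimal sketch of how Size behaves when loading configuration. It assumes gopkg.in/yaml.v2, whose Unmarshaler interface matches the callback signature above; the test itself is not part of this commit:

package config

import (
	"testing"

	yaml "gopkg.in/yaml.v2"
)

func TestSizeYAML(t *testing.T) {
	var c struct {
		BlobSizeLimit Size `yaml:"blob_size_limit"`
		ChunkSize     Size `yaml:"chunk_size"`
	}
	input := []byte("blob_size_limit: 10GB\nchunk_size: 10485760\n")
	if err := yaml.Unmarshal(input, &c); err != nil {
		t.Fatal(err)
	}
	if got := c.ChunkSize.Int64(); got != 10*1024*1024 {
		t.Errorf("chunk_size: got %d", got)
	}
	// go-humanize parses "GB" as decimal (1e9) and "GiB" as binary (1<<30).
	if got := c.BlobSizeLimit.Int64(); got != 10_000_000_000 {
		t.Errorf("blob_size_limit: got %d", got)
	}
}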

View File

@@ -0,0 +1,125 @@
package crypto
import (
"bytes"
"fmt"
"io"
"sync"
"filippo.io/age"
)
// Encryptor provides thread-safe encryption using age
type Encryptor struct {
recipients []age.Recipient
mu sync.RWMutex
}
// NewEncryptor creates a new encryptor with the given age public keys
func NewEncryptor(publicKeys []string) (*Encryptor, error) {
if len(publicKeys) == 0 {
return nil, fmt.Errorf("at least one recipient is required")
}
recipients := make([]age.Recipient, 0, len(publicKeys))
for _, key := range publicKeys {
recipient, err := age.ParseX25519Recipient(key)
if err != nil {
return nil, fmt.Errorf("parsing age recipient %s: %w", key, err)
}
recipients = append(recipients, recipient)
}
return &Encryptor{
recipients: recipients,
}, nil
}
// Encrypt encrypts data using age encryption
func (e *Encryptor) Encrypt(data []byte) ([]byte, error) {
e.mu.RLock()
recipients := e.recipients
e.mu.RUnlock()
var buf bytes.Buffer
// Create encrypted writer for all recipients
w, err := age.Encrypt(&buf, recipients...)
if err != nil {
return nil, fmt.Errorf("creating encrypted writer: %w", err)
}
// Write data
if _, err := w.Write(data); err != nil {
return nil, fmt.Errorf("writing encrypted data: %w", err)
}
// Close to flush
if err := w.Close(); err != nil {
return nil, fmt.Errorf("closing encrypted writer: %w", err)
}
return buf.Bytes(), nil
}
// EncryptStream encrypts data from reader to writer
func (e *Encryptor) EncryptStream(dst io.Writer, src io.Reader) error {
e.mu.RLock()
recipients := e.recipients
e.mu.RUnlock()
// Create encrypted writer for all recipients
w, err := age.Encrypt(dst, recipients...)
if err != nil {
return fmt.Errorf("creating encrypted writer: %w", err)
}
// Copy data
if _, err := io.Copy(w, src); err != nil {
return fmt.Errorf("copying encrypted data: %w", err)
}
// Close to flush
if err := w.Close(); err != nil {
return fmt.Errorf("closing encrypted writer: %w", err)
}
return nil
}
// EncryptWriter creates a writer that encrypts data written to it
func (e *Encryptor) EncryptWriter(dst io.Writer) (io.WriteCloser, error) {
e.mu.RLock()
recipients := e.recipients
e.mu.RUnlock()
// Create encrypted writer for all recipients
w, err := age.Encrypt(dst, recipients...)
if err != nil {
return nil, fmt.Errorf("creating encrypted writer: %w", err)
}
return w, nil
}
// UpdateRecipients updates the recipients (thread-safe)
func (e *Encryptor) UpdateRecipients(publicKeys []string) error {
if len(publicKeys) == 0 {
return fmt.Errorf("at least one recipient is required")
}
recipients := make([]age.Recipient, 0, len(publicKeys))
for _, key := range publicKeys {
recipient, err := age.ParseX25519Recipient(key)
if err != nil {
return fmt.Errorf("parsing age recipient %s: %w", key, err)
}
recipients = append(recipients, recipient)
}
e.mu.Lock()
e.recipients = recipients
e.mu.Unlock()
return nil
}
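A sketch of the streaming path EncryptWriter enables during blob packing; the function and its parameters are illustrative (and assume the io and os imports), not part of this commit:

func writeEncryptedBlob(enc *Encryptor, outPath string, src io.Reader) (err error) {
	f, err := os.Create(outPath)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()
	w, err := enc.EncryptWriter(f)
	if err != nil {
		return err
	}
	// Plaintext streams through the age writer without buffering the blob.
	if _, err := io.Copy(w, src); err != nil {
		return err
	}
	// Close flushes age's final chunk; skipping it truncates the ciphertext.
	return w.Close()
}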

View File

@@ -0,0 +1,157 @@
package crypto
import (
"bytes"
"testing"
"filippo.io/age"
)
func TestEncryptor(t *testing.T) {
// Generate a test key pair
identity, err := age.GenerateX25519Identity()
if err != nil {
t.Fatalf("failed to generate identity: %v", err)
}
publicKey := identity.Recipient().String()
// Create encryptor
enc, err := NewEncryptor([]string{publicKey})
if err != nil {
t.Fatalf("failed to create encryptor: %v", err)
}
// Test data
plaintext := []byte("Hello, World! This is a test message.")
// Encrypt
ciphertext, err := enc.Encrypt(plaintext)
if err != nil {
t.Fatalf("failed to encrypt: %v", err)
}
// Verify it's actually encrypted (should be larger and different)
if bytes.Equal(plaintext, ciphertext) {
t.Error("ciphertext equals plaintext")
}
// Decrypt to verify
r, err := age.Decrypt(bytes.NewReader(ciphertext), identity)
if err != nil {
t.Fatalf("failed to decrypt: %v", err)
}
var decrypted bytes.Buffer
if _, err := decrypted.ReadFrom(r); err != nil {
t.Fatalf("failed to read decrypted data: %v", err)
}
if !bytes.Equal(plaintext, decrypted.Bytes()) {
t.Error("decrypted data doesn't match original")
}
}
func TestEncryptorMultipleRecipients(t *testing.T) {
// Generate three test key pairs
identity1, err := age.GenerateX25519Identity()
if err != nil {
t.Fatalf("failed to generate identity1: %v", err)
}
identity2, err := age.GenerateX25519Identity()
if err != nil {
t.Fatalf("failed to generate identity2: %v", err)
}
identity3, err := age.GenerateX25519Identity()
if err != nil {
t.Fatalf("failed to generate identity3: %v", err)
}
publicKeys := []string{
identity1.Recipient().String(),
identity2.Recipient().String(),
identity3.Recipient().String(),
}
// Create encryptor with multiple recipients
enc, err := NewEncryptor(publicKeys)
if err != nil {
t.Fatalf("failed to create encryptor: %v", err)
}
// Test data
plaintext := []byte("Secret message for multiple recipients")
// Encrypt
ciphertext, err := enc.Encrypt(plaintext)
if err != nil {
t.Fatalf("failed to encrypt: %v", err)
}
// Verify each recipient can decrypt
identities := []age.Identity{identity1, identity2, identity3}
for i, identity := range identities {
r, err := age.Decrypt(bytes.NewReader(ciphertext), identity)
if err != nil {
t.Fatalf("recipient %d failed to decrypt: %v", i+1, err)
}
var decrypted bytes.Buffer
if _, err := decrypted.ReadFrom(r); err != nil {
t.Fatalf("recipient %d failed to read decrypted data: %v", i+1, err)
}
if !bytes.Equal(plaintext, decrypted.Bytes()) {
t.Errorf("recipient %d: decrypted data doesn't match original", i+1)
}
}
}
func TestEncryptorUpdateRecipients(t *testing.T) {
// Generate two identities
identity1, _ := age.GenerateX25519Identity()
identity2, _ := age.GenerateX25519Identity()
publicKey1 := identity1.Recipient().String()
publicKey2 := identity2.Recipient().String()
// Create encryptor with first key
enc, err := NewEncryptor([]string{publicKey1})
if err != nil {
t.Fatalf("failed to create encryptor: %v", err)
}
// Encrypt with first key
plaintext := []byte("test data")
ciphertext1, err := enc.Encrypt(plaintext)
if err != nil {
t.Fatalf("failed to encrypt: %v", err)
}
// Update to second key
if err := enc.UpdateRecipients([]string{publicKey2}); err != nil {
t.Fatalf("failed to update recipients: %v", err)
}
// Encrypt with second key
ciphertext2, err := enc.Encrypt(plaintext)
if err != nil {
t.Fatalf("failed to encrypt: %v", err)
}
// First ciphertext should only decrypt with first identity
if _, err := age.Decrypt(bytes.NewReader(ciphertext1), identity1); err != nil {
t.Error("failed to decrypt with identity1")
}
if _, err := age.Decrypt(bytes.NewReader(ciphertext1), identity2); err == nil {
t.Error("should not decrypt with identity2")
}
// Second ciphertext should only decrypt with second identity
if _, err := age.Decrypt(bytes.NewReader(ciphertext2), identity2); err != nil {
t.Error("failed to decrypt with identity2")
}
if _, err := age.Decrypt(bytes.NewReader(ciphertext2), identity1); err == nil {
t.Error("should not decrypt with identity1")
}
}

View File

@@ -16,15 +16,15 @@ func NewBlobChunkRepository(db *DB) *BlobChunkRepository {
func (r *BlobChunkRepository) Create(ctx context.Context, tx *sql.Tx, bc *BlobChunk) error {
query := `
INSERT INTO blob_chunks (blob_hash, chunk_hash, offset, length)
INSERT INTO blob_chunks (blob_id, chunk_hash, offset, length)
VALUES (?, ?, ?, ?)
`
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, bc.BlobHash, bc.ChunkHash, bc.Offset, bc.Length)
_, err = tx.ExecContext(ctx, query, bc.BlobID, bc.ChunkHash, bc.Offset, bc.Length)
} else {
_, err = r.db.ExecWithLock(ctx, query, bc.BlobHash, bc.ChunkHash, bc.Offset, bc.Length)
_, err = r.db.ExecWithLock(ctx, query, bc.BlobID, bc.ChunkHash, bc.Offset, bc.Length)
}
if err != nil {
@@ -34,15 +34,15 @@ func (r *BlobChunkRepository) Create(ctx context.Context, tx *sql.Tx, bc *BlobCh
return nil
}
func (r *BlobChunkRepository) GetByBlobHash(ctx context.Context, blobHash string) ([]*BlobChunk, error) {
func (r *BlobChunkRepository) GetByBlobID(ctx context.Context, blobID string) ([]*BlobChunk, error) {
query := `
SELECT blob_hash, chunk_hash, offset, length
SELECT blob_id, chunk_hash, offset, length
FROM blob_chunks
WHERE blob_hash = ?
WHERE blob_id = ?
ORDER BY offset
`
rows, err := r.db.conn.QueryContext(ctx, query, blobHash)
rows, err := r.db.conn.QueryContext(ctx, query, blobID)
if err != nil {
return nil, fmt.Errorf("querying blob chunks: %w", err)
}
@@ -51,7 +51,7 @@ func (r *BlobChunkRepository) GetByBlobHash(ctx context.Context, blobHash string
var blobChunks []*BlobChunk
for rows.Next() {
var bc BlobChunk
err := rows.Scan(&bc.BlobHash, &bc.ChunkHash, &bc.Offset, &bc.Length)
err := rows.Scan(&bc.BlobID, &bc.ChunkHash, &bc.Offset, &bc.Length)
if err != nil {
return nil, fmt.Errorf("scanning blob chunk: %w", err)
}
@@ -63,26 +63,61 @@ func (r *BlobChunkRepository) GetByBlobHash(ctx context.Context, blobHash string
func (r *BlobChunkRepository) GetByChunkHash(ctx context.Context, chunkHash string) (*BlobChunk, error) {
query := `
SELECT blob_hash, chunk_hash, offset, length
SELECT blob_id, chunk_hash, offset, length
FROM blob_chunks
WHERE chunk_hash = ?
LIMIT 1
`
LogSQL("GetByChunkHash", query, chunkHash)
var bc BlobChunk
err := r.db.conn.QueryRowContext(ctx, query, chunkHash).Scan(
&bc.BlobHash,
&bc.BlobID,
&bc.ChunkHash,
&bc.Offset,
&bc.Length,
)
if err == sql.ErrNoRows {
LogSQL("GetByChunkHash", "No rows found", chunkHash)
return nil, nil
}
if err != nil {
LogSQL("GetByChunkHash", "Error", chunkHash, err)
return nil, fmt.Errorf("querying blob chunk: %w", err)
}
LogSQL("GetByChunkHash", "Found blob", chunkHash, "blob", bc.BlobID)
return &bc, nil
}
// GetByChunkHashTx retrieves a blob chunk within a transaction
func (r *BlobChunkRepository) GetByChunkHashTx(ctx context.Context, tx *sql.Tx, chunkHash string) (*BlobChunk, error) {
query := `
SELECT blob_id, chunk_hash, offset, length
FROM blob_chunks
WHERE chunk_hash = ?
LIMIT 1
`
LogSQL("GetByChunkHashTx", query, chunkHash)
var bc BlobChunk
err := tx.QueryRowContext(ctx, query, chunkHash).Scan(
&bc.BlobID,
&bc.ChunkHash,
&bc.Offset,
&bc.Length,
)
if err == sql.ErrNoRows {
LogSQL("GetByChunkHashTx", "No rows found", chunkHash)
return nil, nil
}
if err != nil {
LogSQL("GetByChunkHashTx", "Error", chunkHash, err)
return nil, fmt.Errorf("querying blob chunk: %w", err)
}
LogSQL("GetByChunkHashTx", "Found blob", chunkHash, "blob", bc.BlobID)
return &bc, nil
}
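Returning nil without an error on a miss keeps the pack-time deduplication check simple; an illustrative caller (ensureChunkPacked and the pack callback are hypothetical):

func ensureChunkPacked(ctx context.Context, repos *Repositories, chunkHash string, pack func() error) error {
	bc, err := repos.BlobChunks.GetByChunkHash(ctx, chunkHash)
	if err != nil {
		return err
	}
	if bc != nil {
		// Deduplicated: the chunk already lives at bc.Offset/bc.Length
		// inside blob bc.BlobID, so nothing needs to be written.
		return nil
	}
	return pack() // append to the in-progress blob and record the mapping
}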

View File

@@ -14,7 +14,7 @@ func TestBlobChunkRepository(t *testing.T) {
// Test Create
bc1 := &BlobChunk{
BlobHash: "blob1",
BlobID: "blob1-uuid",
ChunkHash: "chunk1",
Offset: 0,
Length: 1024,
@@ -27,7 +27,7 @@ func TestBlobChunkRepository(t *testing.T) {
// Add more chunks to the same blob
bc2 := &BlobChunk{
BlobHash: "blob1",
BlobID: "blob1-uuid",
ChunkHash: "chunk2",
Offset: 1024,
Length: 2048,
@@ -38,7 +38,7 @@ func TestBlobChunkRepository(t *testing.T) {
}
bc3 := &BlobChunk{
BlobHash: "blob1",
BlobID: "blob1-uuid",
ChunkHash: "chunk3",
Offset: 3072,
Length: 512,
@@ -48,8 +48,8 @@ func TestBlobChunkRepository(t *testing.T) {
t.Fatalf("failed to create third blob chunk: %v", err)
}
// Test GetByBlobHash
chunks, err := repo.GetByBlobHash(ctx, "blob1")
// Test GetByBlobID
chunks, err := repo.GetByBlobID(ctx, "blob1-uuid")
if err != nil {
t.Fatalf("failed to get blob chunks: %v", err)
}
@@ -73,8 +73,8 @@ func TestBlobChunkRepository(t *testing.T) {
if bc == nil {
t.Fatal("expected blob chunk, got nil")
}
if bc.BlobHash != "blob1" {
t.Errorf("wrong blob hash: expected blob1, got %s", bc.BlobHash)
if bc.BlobID != "blob1-uuid" {
t.Errorf("wrong blob ID: expected blob1-uuid, got %s", bc.BlobID)
}
if bc.Offset != 1024 {
t.Errorf("wrong offset: expected 1024, got %d", bc.Offset)
@@ -100,10 +100,10 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
// Create chunks across multiple blobs
// Some chunks are shared between blobs (deduplication scenario)
blobChunks := []BlobChunk{
{BlobHash: "blob1", ChunkHash: "chunk1", Offset: 0, Length: 1024},
{BlobHash: "blob1", ChunkHash: "chunk2", Offset: 1024, Length: 1024},
{BlobHash: "blob2", ChunkHash: "chunk2", Offset: 0, Length: 1024}, // chunk2 is shared
{BlobHash: "blob2", ChunkHash: "chunk3", Offset: 1024, Length: 1024},
{BlobID: "blob1-uuid", ChunkHash: "chunk1", Offset: 0, Length: 1024},
{BlobID: "blob1-uuid", ChunkHash: "chunk2", Offset: 1024, Length: 1024},
{BlobID: "blob2-uuid", ChunkHash: "chunk2", Offset: 0, Length: 1024}, // chunk2 is shared
{BlobID: "blob2-uuid", ChunkHash: "chunk3", Offset: 1024, Length: 1024},
}
for _, bc := range blobChunks {
@@ -114,7 +114,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
}
// Verify blob1 chunks
chunks, err := repo.GetByBlobHash(ctx, "blob1")
chunks, err := repo.GetByBlobID(ctx, "blob1-uuid")
if err != nil {
t.Fatalf("failed to get blob1 chunks: %v", err)
}
@@ -123,7 +123,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
}
// Verify blob2 chunks
chunks, err = repo.GetByBlobHash(ctx, "blob2")
chunks, err = repo.GetByBlobID(ctx, "blob2-uuid")
if err != nil {
t.Fatalf("failed to get blob2 chunks: %v", err)
}
@@ -140,7 +140,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
t.Fatal("expected shared chunk, got nil")
}
// GetByChunkHash returns first match, should be blob1
if bc.BlobHash != "blob1" {
t.Errorf("expected blob1 for shared chunk, got %s", bc.BlobHash)
if bc.BlobID != "blob1-uuid" {
t.Errorf("expected blob1-uuid for shared chunk, got %s", bc.BlobID)
}
}

View File

@@ -17,15 +17,27 @@ func NewBlobRepository(db *DB) *BlobRepository {
func (r *BlobRepository) Create(ctx context.Context, tx *sql.Tx, blob *Blob) error {
query := `
INSERT INTO blobs (blob_hash, created_ts)
VALUES (?, ?)
INSERT INTO blobs (id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
var finishedTS, uploadedTS *int64
if blob.FinishedTS != nil {
ts := blob.FinishedTS.Unix()
finishedTS = &ts
}
if blob.UploadedTS != nil {
ts := blob.UploadedTS.Unix()
uploadedTS = &ts
}
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, blob.BlobHash, blob.CreatedTS.Unix())
_, err = tx.ExecContext(ctx, query, blob.ID, blob.Hash, blob.CreatedTS.Unix(),
finishedTS, blob.UncompressedSize, blob.CompressedSize, uploadedTS)
} else {
_, err = r.db.ExecWithLock(ctx, query, blob.BlobHash, blob.CreatedTS.Unix())
_, err = r.db.ExecWithLock(ctx, query, blob.ID, blob.Hash, blob.CreatedTS.Unix(),
finishedTS, blob.UncompressedSize, blob.CompressedSize, uploadedTS)
}
if err != nil {
@@ -37,17 +49,23 @@ func (r *BlobRepository) Create(ctx context.Context, tx *sql.Tx, blob *Blob) err
func (r *BlobRepository) GetByHash(ctx context.Context, hash string) (*Blob, error) {
query := `
SELECT blob_hash, created_ts
SELECT id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts
FROM blobs
WHERE blob_hash = ?
`
var blob Blob
var createdTSUnix int64
var finishedTSUnix, uploadedTSUnix sql.NullInt64
err := r.db.conn.QueryRowContext(ctx, query, hash).Scan(
&blob.BlobHash,
&blob.ID,
&blob.Hash,
&createdTSUnix,
&finishedTSUnix,
&blob.UncompressedSize,
&blob.CompressedSize,
&uploadedTSUnix,
)
if err == sql.ErrNoRows {
@@ -58,39 +76,100 @@ func (r *BlobRepository) GetByHash(ctx context.Context, hash string) (*Blob, err
}
blob.CreatedTS = time.Unix(createdTSUnix, 0)
if finishedTSUnix.Valid {
ts := time.Unix(finishedTSUnix.Int64, 0)
blob.FinishedTS = &ts
}
if uploadedTSUnix.Valid {
ts := time.Unix(uploadedTSUnix.Int64, 0)
blob.UploadedTS = &ts
}
return &blob, nil
}
func (r *BlobRepository) List(ctx context.Context, limit, offset int) ([]*Blob, error) {
// GetByID retrieves a blob by its ID
func (r *BlobRepository) GetByID(ctx context.Context, id string) (*Blob, error) {
query := `
SELECT blob_hash, created_ts
SELECT id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts
FROM blobs
ORDER BY blob_hash
LIMIT ? OFFSET ?
WHERE id = ?
`
rows, err := r.db.conn.QueryContext(ctx, query, limit, offset)
var blob Blob
var createdTSUnix int64
var finishedTSUnix, uploadedTSUnix sql.NullInt64
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&blob.ID,
&blob.Hash,
&createdTSUnix,
&finishedTSUnix,
&blob.UncompressedSize,
&blob.CompressedSize,
&uploadedTSUnix,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("querying blobs: %w", err)
}
defer CloseRows(rows)
var blobs []*Blob
for rows.Next() {
var blob Blob
var createdTSUnix int64
err := rows.Scan(
&blob.BlobHash,
&createdTSUnix,
)
if err != nil {
return nil, fmt.Errorf("scanning blob: %w", err)
}
blob.CreatedTS = time.Unix(createdTSUnix, 0)
blobs = append(blobs, &blob)
return nil, fmt.Errorf("querying blob: %w", err)
}
return blobs, rows.Err()
blob.CreatedTS = time.Unix(createdTSUnix, 0)
if finishedTSUnix.Valid {
ts := time.Unix(finishedTSUnix.Int64, 0)
blob.FinishedTS = &ts
}
if uploadedTSUnix.Valid {
ts := time.Unix(uploadedTSUnix.Int64, 0)
blob.UploadedTS = &ts
}
return &blob, nil
}
// UpdateFinished updates a blob when it's finalized
func (r *BlobRepository) UpdateFinished(ctx context.Context, tx *sql.Tx, id string, hash string, uncompressedSize, compressedSize int64) error {
query := `
UPDATE blobs
SET blob_hash = ?, finished_ts = ?, uncompressed_size = ?, compressed_size = ?
WHERE id = ?
`
now := time.Now().Unix()
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, hash, now, uncompressedSize, compressedSize, id)
} else {
_, err = r.db.ExecWithLock(ctx, query, hash, now, uncompressedSize, compressedSize, id)
}
if err != nil {
return fmt.Errorf("updating blob: %w", err)
}
return nil
}
// UpdateUploaded marks a blob as uploaded
func (r *BlobRepository) UpdateUploaded(ctx context.Context, tx *sql.Tx, id string) error {
query := `
UPDATE blobs
SET uploaded_ts = ?
WHERE id = ?
`
now := time.Now().Unix()
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, now, id)
} else {
_, err = r.db.ExecWithLock(ctx, query, now, id)
}
if err != nil {
return fmt.Errorf("marking blob as uploaded: %w", err)
}
return nil
}
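Together these methods give a blob row a three-stage lifecycle: created empty at packing start, finalized with hash and sizes, then marked uploaded. A sketch (github.com/google/uuid is an assumption here; the actual packer may generate IDs differently):

func packBlobLifecycle(ctx context.Context, blobs *BlobRepository) error {
	// Stage 1: create the row before any data exists, so blob_chunks
	// rows can reference the UUID immediately. Hash stays empty for now.
	id := uuid.New().String()
	if err := blobs.Create(ctx, nil, &Blob{ID: id, CreatedTS: time.Now()}); err != nil {
		return err
	}
	// ... stream chunks into the blob, compressing and encrypting ...
	// Stage 2: record the ciphertext hash and sizes once known.
	if err := blobs.UpdateFinished(ctx, nil, id, "hash-of-final-ciphertext", 4096, 2048); err != nil {
		return err
	}
	// ... upload to S3 under the content hash ...
	// Stage 3: mark the upload time.
	return blobs.UpdateUploaded(ctx, nil, id)
}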

View File

@@ -15,7 +15,8 @@ func TestBlobRepository(t *testing.T) {
// Test Create
blob := &Blob{
BlobHash: "blobhash123",
ID: "test-blob-id-123",
Hash: "blobhash123",
CreatedTS: time.Now().Truncate(time.Second),
}
@@ -25,23 +26,36 @@ func TestBlobRepository(t *testing.T) {
}
// Test GetByHash
retrieved, err := repo.GetByHash(ctx, blob.BlobHash)
retrieved, err := repo.GetByHash(ctx, blob.Hash)
if err != nil {
t.Fatalf("failed to get blob: %v", err)
}
if retrieved == nil {
t.Fatal("expected blob, got nil")
}
if retrieved.BlobHash != blob.BlobHash {
t.Errorf("blob hash mismatch: got %s, want %s", retrieved.BlobHash, blob.BlobHash)
if retrieved.Hash != blob.Hash {
t.Errorf("blob hash mismatch: got %s, want %s", retrieved.Hash, blob.Hash)
}
if !retrieved.CreatedTS.Equal(blob.CreatedTS) {
t.Errorf("created timestamp mismatch: got %v, want %v", retrieved.CreatedTS, blob.CreatedTS)
}
// Test List
// Test GetByID
retrievedByID, err := repo.GetByID(ctx, blob.ID)
if err != nil {
t.Fatalf("failed to get blob by ID: %v", err)
}
if retrievedByID == nil {
t.Fatal("expected blob, got nil")
}
if retrievedByID.ID != blob.ID {
t.Errorf("blob ID mismatch: got %s, want %s", retrievedByID.ID, blob.ID)
}
// Test with second blob
blob2 := &Blob{
BlobHash: "blobhash456",
ID: "test-blob-id-456",
Hash: "blobhash456",
CreatedTS: time.Now().Truncate(time.Second),
}
err = repo.Create(ctx, nil, blob2)
@@ -49,29 +63,45 @@ func TestBlobRepository(t *testing.T) {
t.Fatalf("failed to create second blob: %v", err)
}
blobs, err := repo.List(ctx, 10, 0)
// Test UpdateFinished
now := time.Now()
err = repo.UpdateFinished(ctx, nil, blob.ID, blob.Hash, 1000, 500)
if err != nil {
t.Fatalf("failed to list blobs: %v", err)
}
if len(blobs) != 2 {
t.Errorf("expected 2 blobs, got %d", len(blobs))
t.Fatalf("failed to update blob as finished: %v", err)
}
// Test pagination
blobs, err = repo.List(ctx, 1, 0)
// Verify update
updated, err := repo.GetByID(ctx, blob.ID)
if err != nil {
t.Fatalf("failed to list blobs with limit: %v", err)
t.Fatalf("failed to get updated blob: %v", err)
}
if len(blobs) != 1 {
t.Errorf("expected 1 blob with limit, got %d", len(blobs))
if updated.FinishedTS == nil {
t.Fatal("expected finished timestamp to be set")
}
if updated.UncompressedSize != 1000 {
t.Errorf("expected uncompressed size 1000, got %d", updated.UncompressedSize)
}
if updated.CompressedSize != 500 {
t.Errorf("expected compressed size 500, got %d", updated.CompressedSize)
}
blobs, err = repo.List(ctx, 1, 1)
// Test UpdateUploaded
err = repo.UpdateUploaded(ctx, nil, blob.ID)
if err != nil {
t.Fatalf("failed to list blobs with offset: %v", err)
t.Fatalf("failed to update blob as uploaded: %v", err)
}
if len(blobs) != 1 {
t.Errorf("expected 1 blob with offset, got %d", len(blobs))
// Verify upload update
uploaded, err := repo.GetByID(ctx, blob.ID)
if err != nil {
t.Fatalf("failed to get uploaded blob: %v", err)
}
if uploaded.UploadedTS == nil {
t.Fatal("expected uploaded timestamp to be set")
}
// Allow 1 second tolerance for timestamp comparison
if uploaded.UploadedTS.Before(now.Add(-1 * time.Second)) {
t.Error("uploaded timestamp should be around test time")
}
}
@@ -83,7 +113,8 @@ func TestBlobRepositoryDuplicate(t *testing.T) {
repo := NewBlobRepository(db)
blob := &Blob{
BlobHash: "duplicate_blob",
ID: "duplicate-test-id",
Hash: "duplicate_blob",
CreatedTS: time.Now().Truncate(time.Second),
}

View File

@@ -4,8 +4,11 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"sync"
"git.eeqj.de/sneak/vaultik/internal/log"
_ "modernc.org/sqlite"
)
@@ -15,23 +18,54 @@ type DB struct {
}
func New(ctx context.Context, path string) (*DB, error) {
conn, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000")
if err != nil {
return nil, fmt.Errorf("opening database: %w", err)
// First, try to recover from any stale locks
if err := recoverDatabase(ctx, path); err != nil {
log.Warn("Failed to recover database", "error", err)
}
if err := conn.PingContext(ctx); err != nil {
if closeErr := conn.Close(); closeErr != nil {
Fatal("failed to close database connection: %v", closeErr)
// First attempt with standard WAL mode
conn, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=10000&_locking_mode=NORMAL")
if err == nil {
// Set connection pool settings to ensure proper cleanup
conn.SetMaxOpenConns(1) // SQLite only supports one writer
conn.SetMaxIdleConns(1)
if err := conn.PingContext(ctx); err == nil {
// Success on first try
db := &DB{conn: conn}
if err := db.createSchema(ctx); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("creating schema: %w", err)
}
return db, nil
}
return nil, fmt.Errorf("pinging database: %w", err)
_ = conn.Close()
}
// If first attempt failed, try with TRUNCATE mode to clear any locks
log.Info("Database appears locked, attempting recovery with TRUNCATE mode")
conn, err = sql.Open("sqlite", path+"?_journal_mode=TRUNCATE&_synchronous=NORMAL&_busy_timeout=10000")
if err != nil {
return nil, fmt.Errorf("opening database in recovery mode: %w", err)
}
// Set connection pool settings
conn.SetMaxOpenConns(1)
conn.SetMaxIdleConns(1)
if err := conn.PingContext(ctx); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("database still locked after recovery attempt: %w", err)
}
// Switch back to WAL mode
if _, err := conn.ExecContext(ctx, "PRAGMA journal_mode=WAL"); err != nil {
log.Warn("Failed to switch back to WAL mode", "error", err)
}
db := &DB{conn: conn}
if err := db.createSchema(ctx); err != nil {
if closeErr := conn.Close(); closeErr != nil {
Fatal("failed to close database connection: %v", closeErr)
}
_ = conn.Close()
return nil, fmt.Errorf("creating schema: %w", err)
}
@@ -39,9 +73,68 @@ func New(ctx context.Context, path string) (*DB, error) {
}
func (db *DB) Close() error {
log.Debug("Closing database connection")
if err := db.conn.Close(); err != nil {
Fatal("failed to close database: %v", err)
log.Error("Failed to close database", "error", err)
return fmt.Errorf("failed to close database: %w", err)
}
log.Debug("Database connection closed successfully")
return nil
}
// recoverDatabase attempts to recover a locked database
func recoverDatabase(ctx context.Context, path string) error {
// Check if database file exists
if _, err := os.Stat(path); os.IsNotExist(err) {
// No database file, nothing to recover
return nil
}
// Remove stale lock files
// SQLite creates -wal and -shm files for WAL mode
walPath := path + "-wal"
shmPath := path + "-shm"
journalPath := path + "-journal"
log.Info("Attempting database recovery", "path", path)
// Always remove lock files on startup to ensure clean state
removed := false
// Check for and remove journal file (from non-WAL mode)
if _, err := os.Stat(journalPath); err == nil {
log.Info("Found journal file, removing", "path", journalPath)
if err := os.Remove(journalPath); err != nil {
log.Warn("Failed to remove journal file", "error", err)
} else {
removed = true
}
}
// Remove WAL file
if _, err := os.Stat(walPath); err == nil {
log.Info("Found WAL file, removing", "path", walPath)
if err := os.Remove(walPath); err != nil {
log.Warn("Failed to remove WAL file", "error", err)
} else {
removed = true
}
}
// Remove SHM file
if _, err := os.Stat(shmPath); err == nil {
log.Info("Found shared memory file, removing", "path", shmPath)
if err := os.Remove(shmPath); err != nil {
log.Warn("Failed to remove shared memory file", "error", err)
} else {
removed = true
}
}
if removed {
log.Info("Database lock files removed")
}
return nil
}
@@ -55,18 +148,24 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
// LockForWrite acquires the write lock
func (db *DB) LockForWrite() {
log.Debug("Attempting to acquire write lock")
db.writeLock.Lock()
log.Debug("Write lock acquired")
}
// UnlockWrite releases the write lock
func (db *DB) UnlockWrite() {
log.Debug("Releasing write lock")
db.writeLock.Unlock()
log.Debug("Write lock released")
}
// ExecWithLock executes a write query with the write lock held
func (db *DB) ExecWithLock(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
db.writeLock.Lock()
defer db.writeLock.Unlock()
LogSQL("Execute", query, args...)
return db.conn.ExecContext(ctx, query, args...)
}
@@ -104,16 +203,22 @@ func (db *DB) createSchema(ctx context.Context) error {
);
CREATE TABLE IF NOT EXISTS blobs (
blob_hash TEXT PRIMARY KEY,
created_ts INTEGER NOT NULL
id TEXT PRIMARY KEY,
blob_hash TEXT UNIQUE,
created_ts INTEGER NOT NULL,
finished_ts INTEGER,
uncompressed_size INTEGER NOT NULL DEFAULT 0,
compressed_size INTEGER NOT NULL DEFAULT 0,
uploaded_ts INTEGER
);
CREATE TABLE IF NOT EXISTS blob_chunks (
blob_hash TEXT NOT NULL,
blob_id TEXT NOT NULL,
chunk_hash TEXT NOT NULL,
offset INTEGER NOT NULL,
length INTEGER NOT NULL,
PRIMARY KEY (blob_hash, chunk_hash)
PRIMARY KEY (blob_id, chunk_hash),
FOREIGN KEY (blob_id) REFERENCES blobs(id)
);
CREATE TABLE IF NOT EXISTS chunk_files (
@@ -128,13 +233,38 @@ func (db *DB) createSchema(ctx context.Context) error {
id TEXT PRIMARY KEY,
hostname TEXT NOT NULL,
vaultik_version TEXT NOT NULL,
created_ts INTEGER NOT NULL,
file_count INTEGER NOT NULL,
chunk_count INTEGER NOT NULL,
blob_count INTEGER NOT NULL,
total_size INTEGER NOT NULL,
blob_size INTEGER NOT NULL,
compression_ratio REAL NOT NULL
started_at INTEGER NOT NULL,
completed_at INTEGER,
file_count INTEGER NOT NULL DEFAULT 0,
chunk_count INTEGER NOT NULL DEFAULT 0,
blob_count INTEGER NOT NULL DEFAULT 0,
total_size INTEGER NOT NULL DEFAULT 0,
blob_size INTEGER NOT NULL DEFAULT 0,
compression_ratio REAL NOT NULL DEFAULT 1.0
);
CREATE TABLE IF NOT EXISTS snapshot_files (
snapshot_id TEXT NOT NULL,
file_path TEXT NOT NULL,
PRIMARY KEY (snapshot_id, file_path),
FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
FOREIGN KEY (file_path) REFERENCES files(path) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS snapshot_blobs (
snapshot_id TEXT NOT NULL,
blob_id TEXT NOT NULL,
blob_hash TEXT NOT NULL,
PRIMARY KEY (snapshot_id, blob_id),
FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS uploads (
blob_hash TEXT PRIMARY KEY,
uploaded_at INTEGER NOT NULL,
size INTEGER NOT NULL,
duration_ms INTEGER NOT NULL
);
`
@@ -146,3 +276,10 @@ func (db *DB) createSchema(ctx context.Context) error {
func NewTestDB() (*DB, error) {
return New(context.Background(), ":memory:")
}
// LogSQL logs SQL queries if debug mode is enabled
func LogSQL(operation, query string, args ...interface{}) {
if strings.Contains(os.Getenv("GODEBUG"), "vaultik") {
log.Debug("SQL "+operation, "query", strings.TrimSpace(query), "args", fmt.Sprintf("%v", args))
}
}

View File

@@ -62,6 +62,36 @@ func (r *FileChunkRepository) GetByPath(ctx context.Context, path string) ([]*Fi
return fileChunks, rows.Err()
}
// GetByPathTx retrieves file chunks within a transaction
func (r *FileChunkRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path string) ([]*FileChunk, error) {
query := `
SELECT path, idx, chunk_hash
FROM file_chunks
WHERE path = ?
ORDER BY idx
`
LogSQL("GetByPathTx", query, path)
rows, err := tx.QueryContext(ctx, query, path)
if err != nil {
return nil, fmt.Errorf("querying file chunks: %w", err)
}
defer CloseRows(rows)
var fileChunks []*FileChunk
for rows.Next() {
var fc FileChunk
err := rows.Scan(&fc.Path, &fc.Idx, &fc.ChunkHash)
if err != nil {
return nil, fmt.Errorf("scanning file chunk: %w", err)
}
fileChunks = append(fileChunks, &fc)
}
LogSQL("GetByPathTx", "Complete", path, "count", len(fileChunks))
return fileChunks, rows.Err()
}
func (r *FileChunkRepository) DeleteByPath(ctx context.Context, tx *sql.Tx, path string) error {
query := `DELETE FROM file_chunks WHERE path = ?`
@@ -81,5 +111,16 @@ func (r *FileChunkRepository) DeleteByPath(ctx context.Context, tx *sql.Tx, path
// GetByFile is an alias for GetByPath for compatibility
func (r *FileChunkRepository) GetByFile(ctx context.Context, path string) ([]*FileChunk, error) {
return r.GetByPath(ctx, path)
LogSQL("GetByFile", "Starting", path)
result, err := r.GetByPath(ctx, path)
LogSQL("GetByFile", "Complete", path, "count", len(result))
return result, err
}
// GetByFileTx retrieves file chunks within a transaction
func (r *FileChunkRepository) GetByFileTx(ctx context.Context, tx *sql.Tx, path string) ([]*FileChunk, error) {
LogSQL("GetByFileTx", "Starting", path)
result, err := r.GetByPathTx(ctx, tx, path)
LogSQL("GetByFileTx", "Complete", path, "count", len(result))
return result, err
}

View File

@@ -31,6 +31,7 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err
var err error
if tx != nil {
LogSQL("Execute", query, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
_, err = tx.ExecContext(ctx, query, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
} else {
_, err = r.db.ExecWithLock(ctx, query, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
@@ -81,6 +82,46 @@ func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, err
return &file, nil
}
func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path string) (*File, error) {
query := `
SELECT path, mtime, ctime, size, mode, uid, gid, link_target
FROM files
WHERE path = ?
`
var file File
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
LogSQL("GetByPathTx QueryRowContext", query, path)
err := tx.QueryRowContext(ctx, query, path).Scan(
&file.Path,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
LogSQL("GetByPathTx Scan complete", query, path)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("querying file: %w", err)
}
file.MTime = time.Unix(mtimeUnix, 0)
file.CTime = time.Unix(ctimeUnix, 0)
if linkTarget.Valid {
file.LinkTarget = linkTarget.String
}
return &file, nil
}
func (r *FileRepository) ListModifiedSince(ctx context.Context, since time.Time) ([]*File, error) {
query := `
SELECT path, mtime, ctime, size, mode, uid, gid, link_target

View File

@@ -35,13 +35,18 @@ type Chunk struct {
// Blob represents a blob record in the database
type Blob struct {
BlobHash string
CreatedTS time.Time
ID string
Hash string // Can be empty until blob is finalized
CreatedTS time.Time
FinishedTS *time.Time // nil if not yet finalized
UncompressedSize int64
CompressedSize int64
UploadedTS *time.Time // nil if not yet uploaded
}
// BlobChunk represents the mapping between blobs and chunks
type BlobChunk struct {
BlobHash string
BlobID string
ChunkHash string
Offset int64
Length int64
@@ -60,7 +65,8 @@ type Snapshot struct {
ID string
Hostname string
VaultikVersion string
CreatedTS time.Time
StartedAt time.Time
CompletedAt *time.Time // nil if still in progress
FileCount int64
ChunkCount int64
BlobCount int64
@@ -68,3 +74,21 @@ type Snapshot struct {
BlobSize int64 // Total size of all referenced blobs (compressed and encrypted)
CompressionRatio float64 // Compression ratio (BlobSize / TotalSize)
}
// IsComplete returns true if the snapshot has completed
func (s *Snapshot) IsComplete() bool {
return s.CompletedAt != nil
}
// SnapshotFile represents the mapping between snapshots and files
type SnapshotFile struct {
SnapshotID string
FilePath string
}
// SnapshotBlob represents the mapping between snapshots and blobs
type SnapshotBlob struct {
SnapshotID string
BlobID string
BlobHash string // Denormalized for easier manifest generation
}

View File

@@ -7,6 +7,7 @@ import (
"path/filepath"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/log"
"go.uber.org/fx"
)
@@ -32,7 +33,13 @@ func provideDatabase(lc fx.Lifecycle, cfg *config.Config) (*DB, error) {
lc.Append(fx.Hook{
OnStop: func(ctx context.Context) error {
return db.Close()
log.Debug("Database module OnStop hook called")
if err := db.Close(); err != nil {
log.Error("Failed to close database in OnStop hook", "error", err)
return err
}
log.Debug("Database closed successfully in OnStop hook")
return nil
},
})

View File

@@ -15,6 +15,7 @@ type Repositories struct {
BlobChunks *BlobChunkRepository
ChunkFiles *ChunkFileRepository
Snapshots *SnapshotRepository
Uploads *UploadRepository
}
func NewRepositories(db *DB) *Repositories {
@@ -27,6 +28,7 @@ func NewRepositories(db *DB) *Repositories {
BlobChunks: NewBlobChunkRepository(db),
ChunkFiles: NewChunkFileRepository(db),
Snapshots: NewSnapshotRepository(db),
Uploads: NewUploadRepository(db.conn),
}
}
@@ -34,13 +36,19 @@ type TxFunc func(ctx context.Context, tx *sql.Tx) error
func (r *Repositories) WithTx(ctx context.Context, fn TxFunc) error {
// Acquire write lock for the entire transaction
LogSQL("WithTx", "Acquiring write lock", "")
r.db.LockForWrite()
defer r.db.UnlockWrite()
defer func() {
LogSQL("WithTx", "Releasing write lock", "")
r.db.UnlockWrite()
}()
LogSQL("WithTx", "Beginning transaction", "")
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("beginning transaction: %w", err)
}
LogSQL("WithTx", "Transaction started", "")
defer func() {
if p := recover(); p != nil {

View File

@@ -71,7 +71,8 @@ func TestRepositoriesTransaction(t *testing.T) {
// Create blob
blob := &Blob{
BlobHash: "tx_blob1",
ID: "tx-blob-id-1",
Hash: "tx_blob1",
CreatedTS: time.Now().Truncate(time.Second),
}
if err := repos.Blobs.Create(ctx, tx, blob); err != nil {
@@ -80,7 +81,7 @@ func TestRepositoriesTransaction(t *testing.T) {
// Map chunks to blob
bc1 := &BlobChunk{
BlobHash: blob.BlobHash,
BlobID: blob.ID,
ChunkHash: chunk1.ChunkHash,
Offset: 0,
Length: 512,
@@ -90,7 +91,7 @@ func TestRepositoriesTransaction(t *testing.T) {
}
bc2 := &BlobChunk{
BlobHash: blob.BlobHash,
BlobID: blob.ID,
ChunkHash: chunk2.ChunkHash,
Offset: 512,
Length: 512,

View File

@@ -0,0 +1,11 @@
-- Track blob upload metrics
CREATE TABLE IF NOT EXISTS uploads (
blob_hash TEXT PRIMARY KEY,
uploaded_at TIMESTAMP NOT NULL,
size INTEGER NOT NULL,
duration_ms INTEGER NOT NULL,
FOREIGN KEY (blob_hash) REFERENCES blobs(blob_hash)
);
CREATE INDEX idx_uploads_uploaded_at ON uploads(uploaded_at);
CREATE INDEX idx_uploads_duration ON uploads(duration_ms);

View File

@@ -17,17 +17,23 @@ func NewSnapshotRepository(db *DB) *SnapshotRepository {
func (r *SnapshotRepository) Create(ctx context.Context, tx *sql.Tx, snapshot *Snapshot) error {
query := `
INSERT INTO snapshots (id, hostname, vaultik_version, created_ts, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO snapshots (id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
var completedAt *int64
if snapshot.CompletedAt != nil {
ts := snapshot.CompletedAt.Unix()
completedAt = &ts
}
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.CreatedTS.Unix(),
snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.CompressionRatio)
_, err = tx.ExecContext(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.StartedAt.Unix(),
completedAt, snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.CompressionRatio)
} else {
_, err = r.db.ExecWithLock(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.CreatedTS.Unix(),
snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.CompressionRatio)
_, err = r.db.ExecWithLock(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.StartedAt.Unix(),
completedAt, snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.CompressionRatio)
}
if err != nil {
@@ -70,19 +76,21 @@ func (r *SnapshotRepository) UpdateCounts(ctx context.Context, tx *sql.Tx, snaps
func (r *SnapshotRepository) GetByID(ctx context.Context, snapshotID string) (*Snapshot, error) {
query := `
SELECT id, hostname, vaultik_version, created_ts, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
SELECT id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
FROM snapshots
WHERE id = ?
`
var snapshot Snapshot
var createdTSUnix int64
var startedAtUnix int64
var completedAtUnix *int64
err := r.db.conn.QueryRowContext(ctx, query, snapshotID).Scan(
&snapshot.ID,
&snapshot.Hostname,
&snapshot.VaultikVersion,
&createdTSUnix,
&startedAtUnix,
&completedAtUnix,
&snapshot.FileCount,
&snapshot.ChunkCount,
&snapshot.BlobCount,
@@ -98,16 +106,20 @@ func (r *SnapshotRepository) GetByID(ctx context.Context, snapshotID string) (*S
return nil, fmt.Errorf("querying snapshot: %w", err)
}
snapshot.CreatedTS = time.Unix(createdTSUnix, 0)
snapshot.StartedAt = time.Unix(startedAtUnix, 0)
if completedAtUnix != nil {
t := time.Unix(*completedAtUnix, 0)
snapshot.CompletedAt = &t
}
return &snapshot, nil
}
func (r *SnapshotRepository) ListRecent(ctx context.Context, limit int) ([]*Snapshot, error) {
query := `
SELECT id, hostname, vaultik_version, created_ts, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
SELECT id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
FROM snapshots
ORDER BY created_ts DESC
ORDER BY started_at DESC
LIMIT ?
`
@@ -120,13 +132,15 @@ func (r *SnapshotRepository) ListRecent(ctx context.Context, limit int) ([]*Snap
var snapshots []*Snapshot
for rows.Next() {
var snapshot Snapshot
var createdTSUnix int64
var startedAtUnix int64
var completedAtUnix *int64
err := rows.Scan(
&snapshot.ID,
&snapshot.Hostname,
&snapshot.VaultikVersion,
&createdTSUnix,
&startedAtUnix,
&completedAtUnix,
&snapshot.FileCount,
&snapshot.ChunkCount,
&snapshot.BlobCount,
@@ -138,7 +152,154 @@ func (r *SnapshotRepository) ListRecent(ctx context.Context, limit int) ([]*Snap
return nil, fmt.Errorf("scanning snapshot: %w", err)
}
snapshot.CreatedTS = time.Unix(createdTSUnix, 0)
snapshot.StartedAt = time.Unix(startedAtUnix, 0)
if completedAtUnix != nil {
t := time.Unix(*completedAtUnix, 0)
snapshot.CompletedAt = &t
}
snapshots = append(snapshots, &snapshot)
}
return snapshots, rows.Err()
}
// MarkComplete marks a snapshot as completed with the current timestamp
func (r *SnapshotRepository) MarkComplete(ctx context.Context, tx *sql.Tx, snapshotID string) error {
query := `
UPDATE snapshots
SET completed_at = ?
WHERE id = ?
`
completedAt := time.Now().Unix()
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, completedAt, snapshotID)
} else {
_, err = r.db.ExecWithLock(ctx, query, completedAt, snapshotID)
}
if err != nil {
return fmt.Errorf("marking snapshot complete: %w", err)
}
return nil
}
// AddFile adds a file to a snapshot
func (r *SnapshotRepository) AddFile(ctx context.Context, tx *sql.Tx, snapshotID string, filePath string) error {
query := `
INSERT OR IGNORE INTO snapshot_files (snapshot_id, file_path)
VALUES (?, ?)
`
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, snapshotID, filePath)
} else {
_, err = r.db.ExecWithLock(ctx, query, snapshotID, filePath)
}
if err != nil {
return fmt.Errorf("adding file to snapshot: %w", err)
}
return nil
}
// AddBlob adds a blob to a snapshot
func (r *SnapshotRepository) AddBlob(ctx context.Context, tx *sql.Tx, snapshotID string, blobID string, blobHash string) error {
query := `
INSERT OR IGNORE INTO snapshot_blobs (snapshot_id, blob_id, blob_hash)
VALUES (?, ?, ?)
`
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, snapshotID, blobID, blobHash)
} else {
_, err = r.db.ExecWithLock(ctx, query, snapshotID, blobID, blobHash)
}
if err != nil {
return fmt.Errorf("adding blob to snapshot: %w", err)
}
return nil
}
// GetBlobHashes returns all blob hashes for a snapshot
func (r *SnapshotRepository) GetBlobHashes(ctx context.Context, snapshotID string) ([]string, error) {
query := `
SELECT sb.blob_hash
FROM snapshot_blobs sb
WHERE sb.snapshot_id = ?
ORDER BY sb.blob_hash
`
rows, err := r.db.conn.QueryContext(ctx, query, snapshotID)
if err != nil {
return nil, fmt.Errorf("querying blob hashes: %w", err)
}
defer CloseRows(rows)
var blobs []string
for rows.Next() {
var blobHash string
if err := rows.Scan(&blobHash); err != nil {
return nil, fmt.Errorf("scanning blob hash: %w", err)
}
blobs = append(blobs, blobHash)
}
return blobs, rows.Err()
}
// GetIncompleteSnapshots returns all snapshots that haven't been completed
func (r *SnapshotRepository) GetIncompleteSnapshots(ctx context.Context) ([]*Snapshot, error) {
query := `
SELECT id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
FROM snapshots
WHERE completed_at IS NULL
ORDER BY started_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("querying incomplete snapshots: %w", err)
}
defer CloseRows(rows)
var snapshots []*Snapshot
for rows.Next() {
var snapshot Snapshot
var startedAtUnix int64
var completedAtUnix *int64
err := rows.Scan(
&snapshot.ID,
&snapshot.Hostname,
&snapshot.VaultikVersion,
&startedAtUnix,
&completedAtUnix,
&snapshot.FileCount,
&snapshot.ChunkCount,
&snapshot.BlobCount,
&snapshot.TotalSize,
&snapshot.BlobSize,
&snapshot.CompressionRatio,
)
if err != nil {
return nil, fmt.Errorf("scanning snapshot: %w", err)
}
snapshot.StartedAt = time.Unix(startedAtUnix, 0)
if completedAtUnix != nil {
t := time.Unix(*completedAtUnix, 0)
snapshot.CompletedAt = &t
}
snapshots = append(snapshots, &snapshot)
}
return snapshots, rows.Err()
}
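A sketch of the intended lifecycle across these methods (caller-side names are hypothetical):

func runSnapshot(ctx context.Context, snaps *SnapshotRepository) error {
	snap := &Snapshot{
		ID:             time.Now().UTC().Format(time.RFC3339),
		Hostname:       "host1",
		VaultikVersion: "dev",
		StartedAt:      time.Now(),
	}
	if err := snaps.Create(ctx, nil, snap); err != nil {
		return err
	}
	// Files and finished blobs are registered as the scan proceeds.
	if err := snaps.AddFile(ctx, nil, snap.ID, "/etc/hosts"); err != nil {
		return err
	}
	if err := snaps.AddBlob(ctx, nil, snap.ID, "blob-uuid", "blob-hash"); err != nil {
		return err
	}
	// Completion is recorded last, so interrupted runs stay visible
	// through GetIncompleteSnapshots.
	return snaps.MarkComplete(ctx, nil, snap.ID)
}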

View File

@@ -30,7 +30,8 @@ func TestSnapshotRepository(t *testing.T) {
ID: "2024-01-01T12:00:00Z",
Hostname: "test-host",
VaultikVersion: "1.0.0",
CreatedTS: time.Now().Truncate(time.Second),
StartedAt: time.Now().Truncate(time.Second),
CompletedAt: nil,
FileCount: 100,
ChunkCount: 500,
BlobCount: 10,
@@ -99,7 +100,8 @@ func TestSnapshotRepository(t *testing.T) {
ID: fmt.Sprintf("2024-01-0%dT12:00:00Z", i),
Hostname: "test-host",
VaultikVersion: "1.0.0",
CreatedTS: time.Now().Add(time.Duration(i) * time.Hour).Truncate(time.Second),
StartedAt: time.Now().Add(time.Duration(i) * time.Hour).Truncate(time.Second),
CompletedAt: nil,
FileCount: int64(100 * i),
ChunkCount: int64(500 * i),
BlobCount: int64(10 * i),
@@ -121,7 +123,7 @@ func TestSnapshotRepository(t *testing.T) {
// Verify order (most recent first)
for i := 0; i < len(recent)-1; i++ {
if recent[i].CreatedTS.Before(recent[i+1].CreatedTS) {
if recent[i].StartedAt.Before(recent[i+1].StartedAt) {
t.Error("snapshots not in descending order")
}
}
@@ -162,7 +164,8 @@ func TestSnapshotRepositoryDuplicate(t *testing.T) {
ID: "2024-01-01T12:00:00Z",
Hostname: "test-host",
VaultikVersion: "1.0.0",
CreatedTS: time.Now().Truncate(time.Second),
StartedAt: time.Now().Truncate(time.Second),
CompletedAt: nil,
FileCount: 100,
ChunkCount: 500,
BlobCount: 10,

View File

@@ -0,0 +1,135 @@
package database
import (
"context"
"database/sql"
"time"
"git.eeqj.de/sneak/vaultik/internal/log"
)
// Upload represents a blob upload record
type Upload struct {
BlobHash string
UploadedAt time.Time
Size int64
DurationMs int64
}
// UploadRepository handles upload records
type UploadRepository struct {
conn *sql.DB
}
// NewUploadRepository creates a new upload repository
func NewUploadRepository(conn *sql.DB) *UploadRepository {
return &UploadRepository{conn: conn}
}
// Create inserts a new upload record
func (r *UploadRepository) Create(ctx context.Context, tx *sql.Tx, upload *Upload) error {
query := `
INSERT INTO uploads (blob_hash, uploaded_at, size, duration_ms)
VALUES (?, ?, ?, ?)
`
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, upload.BlobHash, upload.UploadedAt, upload.Size, upload.DurationMs)
} else {
_, err = r.conn.ExecContext(ctx, query, upload.BlobHash, upload.UploadedAt, upload.Size, upload.DurationMs)
}
return err
}
// GetByBlobHash retrieves an upload record by blob hash
func (r *UploadRepository) GetByBlobHash(ctx context.Context, blobHash string) (*Upload, error) {
query := `
SELECT blob_hash, uploaded_at, size, duration_ms
FROM uploads
WHERE blob_hash = ?
`
var upload Upload
err := r.conn.QueryRowContext(ctx, query, blobHash).Scan(
&upload.BlobHash,
&upload.UploadedAt,
&upload.Size,
&upload.DurationMs,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &upload, nil
}
// GetRecentUploads retrieves recent uploads ordered by upload time
func (r *UploadRepository) GetRecentUploads(ctx context.Context, limit int) ([]*Upload, error) {
query := `
SELECT blob_hash, uploaded_at, size, duration_ms
FROM uploads
ORDER BY uploaded_at DESC
LIMIT ?
`
rows, err := r.conn.QueryContext(ctx, query, limit)
if err != nil {
return nil, err
}
defer func() {
if err := rows.Close(); err != nil {
log.Error("failed to close rows", "error", err)
}
}()
var uploads []*Upload
for rows.Next() {
var upload Upload
if err := rows.Scan(&upload.BlobHash, &upload.UploadedAt, &upload.Size, &upload.DurationMs); err != nil {
return nil, err
}
uploads = append(uploads, &upload)
}
return uploads, rows.Err()
}
// GetUploadStats returns aggregate statistics for uploads
func (r *UploadRepository) GetUploadStats(ctx context.Context, since time.Time) (*UploadStats, error) {
query := `
SELECT
COUNT(*) as count,
COALESCE(SUM(size), 0) as total_size,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms,
COALESCE(MIN(duration_ms), 0) as min_duration_ms,
COALESCE(MAX(duration_ms), 0) as max_duration_ms
FROM uploads
WHERE uploaded_at >= ?
`
var stats UploadStats
err := r.conn.QueryRowContext(ctx, query, since).Scan(
&stats.Count,
&stats.TotalSize,
&stats.AvgDurationMs,
&stats.MinDurationMs,
&stats.MaxDurationMs,
)
return &stats, err
}
// UploadStats contains aggregate upload statistics
type UploadStats struct {
Count int64
TotalSize int64
AvgDurationMs float64
MinDurationMs int64
MaxDurationMs int64
}
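An illustrative use of the repository when recording upload metrics; the put callback stands in for the real S3 upload path:

func uploadWithMetrics(ctx context.Context, uploads *UploadRepository,
	put func(context.Context, string, []byte) error, hash string, data []byte) error {
	start := time.Now()
	if err := put(ctx, hash, data); err != nil {
		return err
	}
	return uploads.Create(ctx, nil, &Upload{
		BlobHash:   hash,
		UploadedAt: time.Now(),
		Size:       int64(len(data)),
		DurationMs: time.Since(start).Milliseconds(),
	})
}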

internal/log/log.go Normal file (175 lines)
View File

@@ -0,0 +1,175 @@
package log
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"golang.org/x/term"
)
// LogLevel represents the logging level
type LogLevel int
const (
LevelFatal LogLevel = iota
LevelError
LevelWarn
LevelNotice
LevelInfo
LevelDebug
)
// Logger configuration
type Config struct {
Verbose bool
Debug bool
Cron bool
}
var logger *slog.Logger
// Initialize sets up the global logger based on the provided configuration
func Initialize(cfg Config) {
// Determine log level based on configuration
var level slog.Level
if cfg.Cron {
// In cron mode, only show fatal errors (which we'll handle specially)
level = slog.LevelError
} else if cfg.Debug || strings.Contains(os.Getenv("GODEBUG"), "vaultik") {
level = slog.LevelDebug
} else if cfg.Verbose {
level = slog.LevelInfo
} else {
level = slog.LevelWarn
}
// Create handler with appropriate level
opts := &slog.HandlerOptions{
Level: level,
}
// Check if stdout is a TTY
if term.IsTerminal(int(os.Stdout.Fd())) {
// Use colorized TTY handler
logger = slog.New(NewTTYHandler(os.Stdout, opts))
} else {
// Use JSON format for non-TTY output
logger = slog.New(slog.NewJSONHandler(os.Stdout, opts))
}
// Set as default logger
slog.SetDefault(logger)
}
// getCaller returns the caller information as a string
func getCaller(skip int) string {
_, file, line, ok := runtime.Caller(skip)
if !ok {
return "unknown"
}
return fmt.Sprintf("%s:%d", filepath.Base(file), line)
}
// Fatal logs a fatal error and exits
func Fatal(msg string, args ...any) {
if logger != nil {
// Add caller info to args
args = append(args, "caller", getCaller(2))
logger.Error(msg, args...)
}
os.Exit(1)
}
// Fatalf logs a formatted fatal error and exits
func Fatalf(format string, args ...any) {
Fatal(fmt.Sprintf(format, args...))
}
// Error logs an error
func Error(msg string, args ...any) {
if logger != nil {
args = append(args, "caller", getCaller(2))
logger.Error(msg, args...)
}
}
// Errorf logs a formatted error
func Errorf(format string, args ...any) {
Error(fmt.Sprintf(format, args...))
}
// Warn logs a warning
func Warn(msg string, args ...any) {
if logger != nil {
args = append(args, "caller", getCaller(2))
logger.Warn(msg, args...)
}
}
// Warnf logs a formatted warning
func Warnf(format string, args ...any) {
Warn(fmt.Sprintf(format, args...))
}
// Notice logs a notice (mapped to Info level)
func Notice(msg string, args ...any) {
if logger != nil {
args = append(args, "caller", getCaller(2))
logger.Info(msg, args...)
}
}
// Noticef logs a formatted notice
func Noticef(format string, args ...any) {
Notice(fmt.Sprintf(format, args...))
}
// Info logs an info message
func Info(msg string, args ...any) {
if logger != nil {
args = append(args, "caller", getCaller(2))
logger.Info(msg, args...)
}
}
// Infof logs a formatted info message
func Infof(format string, args ...any) {
Info(fmt.Sprintf(format, args...))
}
// Debug logs a debug message
func Debug(msg string, args ...any) {
if logger != nil {
args = append(args, "caller", getCaller(2))
logger.Debug(msg, args...)
}
}
// Debugf logs a formatted debug message
func Debugf(format string, args ...any) {
Debug(fmt.Sprintf(format, args...))
}
// With returns a logger with additional context
func With(args ...any) *slog.Logger {
if logger != nil {
return logger.With(args...)
}
return slog.Default()
}
// WithContext returns a logger with context
func WithContext(ctx context.Context) *slog.Logger {
return logger
}
// Logger returns the underlying slog.Logger
func Logger() *slog.Logger {
return logger
}
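Typical usage, sketched (the opts and cfg values are hypothetical caller state):

// Initialize once at startup; the level falls back to Warn unless
// --verbose/--debug (or GODEBUG containing "vaultik") raises it, and
// output is colorized only when stdout is a TTY.
log.Initialize(log.Config{Verbose: opts.Verbose, Debug: opts.Debug})
log.Info("starting backup", "source_dirs", len(cfg.SourceDirs))
log.Debug("effective chunk size", "bytes", cfg.ChunkSize.Int64())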

internal/log/module.go Normal file (24 lines)
View File

@@ -0,0 +1,24 @@
package log
import (
"go.uber.org/fx"
)
// Module exports logging functionality
var Module = fx.Module("log",
fx.Invoke(func(cfg Config) {
Initialize(cfg)
}),
)
// New creates a new logger configuration from provided options
func New(opts LogOptions) Config {
return Config(opts)
}
// LogOptions are provided by the CLI
type LogOptions struct {
Verbose bool
Debug bool
Cron bool
}

internal/log/tty_handler.go Normal file (140 lines)
View File

@@ -0,0 +1,140 @@
package log
import (
"context"
"fmt"
"io"
"log/slog"
"sync"
"time"
)
// ANSI color codes
const (
colorReset = "\033[0m"
colorRed = "\033[31m"
colorYellow = "\033[33m"
colorBlue = "\033[34m"
colorGray = "\033[90m"
colorGreen = "\033[32m"
colorCyan = "\033[36m"
colorBold = "\033[1m"
)
// TTYHandler is a custom handler for TTY output with colors
type TTYHandler struct {
opts slog.HandlerOptions
mu sync.Mutex
out io.Writer
}
// NewTTYHandler creates a new TTY handler
func NewTTYHandler(out io.Writer, opts *slog.HandlerOptions) *TTYHandler {
if opts == nil {
opts = &slog.HandlerOptions{}
}
return &TTYHandler{
out: out,
opts: *opts,
}
}
// Enabled reports whether the handler handles records at the given level
func (h *TTYHandler) Enabled(_ context.Context, level slog.Level) bool {
return level >= h.opts.Level.Level()
}
// Handle writes the log record
func (h *TTYHandler) Handle(_ context.Context, r slog.Record) error {
h.mu.Lock()
defer h.mu.Unlock()
// Format timestamp
timestamp := r.Time.Format("15:04:05")
// Level and color
level := r.Level.String()
var levelColor string
switch r.Level {
case slog.LevelDebug:
levelColor = colorGray
level = "DEBUG"
case slog.LevelInfo:
levelColor = colorGreen
level = "INFO "
case slog.LevelWarn:
levelColor = colorYellow
level = "WARN "
case slog.LevelError:
levelColor = colorRed
level = "ERROR"
default:
levelColor = colorReset
}
// Print main message
_, _ = fmt.Fprintf(h.out, "%s%s%s %s%s%s %s%s%s",
colorGray, timestamp, colorReset,
levelColor, level, colorReset,
colorBold, r.Message, colorReset)
// Print attributes
r.Attrs(func(a slog.Attr) bool {
value := a.Value.String()
// Special handling for certain attribute types
switch a.Value.Kind() {
case slog.KindDuration:
if d, ok := a.Value.Any().(time.Duration); ok {
value = formatDuration(d)
}
case slog.KindInt64:
if a.Key == "bytes" {
value = formatBytes(a.Value.Int64())
}
}
_, _ = fmt.Fprintf(h.out, " %s%s%s=%s%s%s",
colorCyan, a.Key, colorReset,
colorBlue, value, colorReset)
return true
})
_, _ = fmt.Fprintln(h.out)
return nil
}
// WithAttrs returns a new handler with the given attributes
func (h *TTYHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return h // Simplified for now
}
// WithGroup returns a new handler with the given group name
func (h *TTYHandler) WithGroup(name string) slog.Handler {
return h // Simplified for now
}
// formatDuration formats a duration in a human-readable way
func formatDuration(d time.Duration) string {
if d < time.Millisecond {
return fmt.Sprintf("%dµs", d.Microseconds())
} else if d < time.Second {
return fmt.Sprintf("%dms", d.Milliseconds())
} else if d < time.Minute {
return fmt.Sprintf("%.1fs", d.Seconds())
}
return d.String()
}
// formatBytes formats bytes in a human-readable way
func formatBytes(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
}

internal/s3/module.go Normal file (40 lines)
View File

@@ -0,0 +1,40 @@
package s3
import (
"context"
"git.eeqj.de/sneak/vaultik/internal/config"
"go.uber.org/fx"
)
// Module exports S3 functionality
var Module = fx.Module("s3",
fx.Provide(
provideClient,
),
)
func provideClient(lc fx.Lifecycle, cfg *config.Config) (*Client, error) {
ctx := context.Background()
client, err := NewClient(ctx, Config{
Endpoint: cfg.S3.Endpoint,
Bucket: cfg.S3.Bucket,
Prefix: cfg.S3.Prefix,
AccessKeyID: cfg.S3.AccessKeyID,
SecretAccessKey: cfg.S3.SecretAccessKey,
Region: cfg.S3.Region,
})
if err != nil {
return nil, err
}
lc.Append(fx.Hook{
OnStop: func(ctx context.Context) error {
// S3 client doesn't need explicit cleanup
return nil
},
})
return client, nil
}
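For context, a sketch of how the fx modules introduced in this commit might compose into the application container; the module set shown is illustrative, not the commit's actual wiring:

app := fx.New(
	log.Module,      // internal/log/module.go above
	database.Module, // assumed name for the database package's fx export
	s3.Module,       // this file
)
app.Run()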