Major refactoring: UUID-based storage, streaming architecture, and CLI improvements
This commit represents a significant architectural overhaul of vaultik:

Database Schema Changes:
- Switch files table to use UUID primary keys instead of path-based keys
- Add UUID primary keys to blobs table for immediate chunk association
- Update all foreign key relationships to use UUIDs
- Add comprehensive schema documentation in DATAMODEL.md
- Add SQLite busy timeout handling for concurrent operations

Streaming and Performance Improvements:
- Implement true streaming blob packing without intermediate storage
- Add streaming chunk processing to reduce memory usage
- Improve progress reporting with real-time metrics
- Add upload metrics tracking in new uploads table

CLI Refactoring:
- Restructure CLI to use subcommands: snapshot create/list/purge/verify
- Add store info command for S3 configuration display
- Add custom duration parser supporting days/weeks/months/years
- Remove old backup.go in favor of enhanced snapshot.go
- Add --cron flag for silent operation

Configuration Changes:
- Remove unused index_prefix configuration option
- Add support for snapshot pruning retention policies
- Improve configuration validation and error messages

Testing Improvements:
- Add comprehensive repository tests with edge cases
- Add cascade delete debugging tests
- Fix concurrent operation tests to use SQLite busy timeout
- Remove tolerance for SQLITE_BUSY errors in tests

Documentation:
- Add MIT LICENSE file
- Update README with new command structure
- Add comprehensive DATAMODEL.md explaining database schema
- Update DESIGN.md with UUID-based architecture

Other Changes:
- Add test-config.yml for testing
- Update Makefile with better test output formatting
- Fix various race conditions in concurrent operations
- Improve error handling throughout
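The custom duration parser mentioned under "CLI Refactoring" extends Go's time.ParseDuration, which only understands units up to hours, so that retention policies can be written as "30d", "8w", or "1y". The following is a minimal sketch of the idea; the function name, unit table, and month/year approximations are illustrative assumptions, not vaultik's actual code:

package main

import (
	"fmt"
	"strconv"
	"strings"
	"time"
)

// parseRetention parses durations with day/week/month/year suffixes in
// addition to the standard Go units. Months and years are approximated
// as 30 and 365 days; note that in this sketch "m" means months,
// shadowing the standard library's minutes.
func parseRetention(s string) (time.Duration, error) {
	units := []struct {
		suffix string
		unit   time.Duration
	}{
		{"d", 24 * time.Hour},
		{"w", 7 * 24 * time.Hour},
		{"m", 30 * 24 * time.Hour},
		{"y", 365 * 24 * time.Hour},
	}
	for _, u := range units {
		if strings.HasSuffix(s, u.suffix) {
			n, err := strconv.Atoi(strings.TrimSuffix(s, u.suffix))
			if err != nil {
				return 0, fmt.Errorf("invalid duration %q: %w", s, err)
			}
			return time.Duration(n) * u.unit, nil
		}
	}
	// Fall back to the standard parser for "h", "s", "ms", and friends.
	return time.ParseDuration(s)
}

func main() {
	d, err := parseRetention("30d")
	fmt.Println(d, err) // 720h0m0s <nil>
}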
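The SQLite busy timeout mentioned under "Database Schema Changes" addresses SQLITE_BUSY failures under concurrent access: instead of erroring immediately when another connection holds a lock, SQLite retries internally for up to the configured interval, which is what lets the short-transaction pattern in the diff below work. With the mattn/go-sqlite3 driver this is typically set via the DSN; a minimal sketch, assuming that driver (the 5-second value and helper name are illustrative):

package main

import (
	"database/sql"

	_ "github.com/mattn/go-sqlite3"
)

// openIndexDB opens the local index database with a 5-second busy
// timeout, so short concurrent transactions queue up behind each other
// instead of failing with SQLITE_BUSY.
func openIndexDB(path string) (*sql.DB, error) {
	// _busy_timeout maps to PRAGMA busy_timeout (milliseconds).
	return sql.Open("sqlite3", path+"?_busy_timeout=5000")
}

func main() {
	db, err := openIndexDB("index.db")
	if err != nil {
		panic(err)
	}
	defer db.Close()
}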
@@ -338,97 +338,103 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
return nil
}

// Process this file in a transaction
// Create file record in a short transaction
file := &database.File{
Path: path,
Size: info.Size(),
Mode: uint32(info.Mode()),
MTime: info.ModTime(),
CTime: info.ModTime(), // Use mtime as ctime for test
UID: 1000, // Default UID for test
GID: 1000, // Default GID for test
}
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
// Create file record
file := &database.File{
Path: path,
Size: info.Size(),
Mode: uint32(info.Mode()),
MTime: info.ModTime(),
CTime: info.ModTime(), // Use mtime as ctime for test
UID: 1000, // Default UID for test
GID: 1000, // Default GID for test
}
return b.repos.Files.Create(ctx, tx, file)
})
if err != nil {
return err
}

if err := b.repos.Files.Create(ctx, tx, file); err != nil {
fileCount++
totalSize += info.Size()

// Read and process file in chunks
f, err := fsys.Open(path)
if err != nil {
return err
}
defer func() {
if err := f.Close(); err != nil {
// Log but don't fail since we're already in an error path potentially
fmt.Fprintf(os.Stderr, "Failed to close file: %v\n", err)
}
}()

// Process file in chunks
chunkIndex := 0
buffer := make([]byte, defaultChunkSize)

for {
n, err := f.Read(buffer)
if err != nil && err != io.EOF {
return err
}

fileCount++
totalSize += info.Size()

// Read and process file in chunks
f, err := fsys.Open(path)
if err != nil {
return err
if n == 0 {
break
}
defer func() {
if err := f.Close(); err != nil {
// Log but don't fail since we're already in an error path potentially
fmt.Fprintf(os.Stderr, "Failed to close file: %v\n", err)
}
}()

// Process file in chunks
chunkIndex := 0
buffer := make([]byte, defaultChunkSize)
chunkData := buffer[:n]
chunkHash := calculateHash(chunkData)

for {
n, err := f.Read(buffer)
if err != nil && err != io.EOF {
return err
}
if n == 0 {
break
}

chunkData := buffer[:n]
chunkHash := calculateHash(chunkData)

// Check if chunk already exists
existingChunk, _ := b.repos.Chunks.GetByHash(ctx, chunkHash)
if existingChunk == nil {
// Create new chunk
// Check if chunk already exists (outside of transaction)
existingChunk, _ := b.repos.Chunks.GetByHash(ctx, chunkHash)
if existingChunk == nil {
// Create new chunk in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
chunk := &database.Chunk{
ChunkHash: chunkHash,
SHA256: chunkHash,
Size: int64(n),
}
if err := b.repos.Chunks.Create(ctx, tx, chunk); err != nil {
return err
}
processedChunks[chunkHash] = true
return b.repos.Chunks.Create(ctx, tx, chunk)
})
if err != nil {
return err
}
processedChunks[chunkHash] = true
}

// Create file-chunk mapping
// Create file-chunk mapping in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
fileChunk := &database.FileChunk{
Path: path,
FileID: file.ID,
Idx: chunkIndex,
ChunkHash: chunkHash,
}
if err := b.repos.FileChunks.Create(ctx, tx, fileChunk); err != nil {
return err
}
return b.repos.FileChunks.Create(ctx, tx, fileChunk)
})
if err != nil {
return err
}

// Create chunk-file mapping
// Create chunk-file mapping in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
chunkFile := &database.ChunkFile{
ChunkHash: chunkHash,
FilePath: path,
FileID: file.ID,
FileOffset: int64(chunkIndex * defaultChunkSize),
Length: int64(n),
}
if err := b.repos.ChunkFiles.Create(ctx, tx, chunkFile); err != nil {
return err
}

chunkIndex++
return b.repos.ChunkFiles.Create(ctx, tx, chunkFile)
})
if err != nil {
return err
}

return nil
})
chunkIndex++
}

return err
return nil
})

if err != nil {
@@ -436,61 +442,64 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
}

// After all files are processed, create blobs for new chunks
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
for chunkHash := range processedChunks {
// Get chunk data
chunk, err := b.repos.Chunks.GetByHash(ctx, chunkHash)
if err != nil {
return err
}
for chunkHash := range processedChunks {
// Get chunk data (outside of transaction)
chunk, err := b.repos.Chunks.GetByHash(ctx, chunkHash)
if err != nil {
return "", err
}

chunkCount++
chunkCount++

// In a real system, blobs would contain multiple chunks and be encrypted
// For testing, we'll create a blob with a "blob-" prefix to differentiate
blobHash := "blob-" + chunkHash
// In a real system, blobs would contain multiple chunks and be encrypted
// For testing, we'll create a blob with a "blob-" prefix to differentiate
blobHash := "blob-" + chunkHash

// For the test, we'll create dummy data since we don't have the original
dummyData := []byte(chunkHash)
// For the test, we'll create dummy data since we don't have the original
dummyData := []byte(chunkHash)

// Upload to S3 as a blob
if err := b.s3Client.PutBlob(ctx, blobHash, dummyData); err != nil {
return err
}
// Upload to S3 as a blob
if err := b.s3Client.PutBlob(ctx, blobHash, dummyData); err != nil {
return "", err
}

// Create blob entry
// Create blob entry in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
blob := &database.Blob{
ID: "test-blob-" + blobHash[:8],
Hash: blobHash,
CreatedTS: time.Now(),
}
if err := b.repos.Blobs.Create(ctx, tx, blob); err != nil {
return err
}
blobCount++
blobSize += chunk.Size
return b.repos.Blobs.Create(ctx, tx, blob)
})
if err != nil {
return "", err
}

// Create blob-chunk mapping
blobCount++
blobSize += chunk.Size

// Create blob-chunk mapping in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
blobChunk := &database.BlobChunk{
BlobID: blob.ID,
BlobID: "test-blob-" + blobHash[:8],
ChunkHash: chunkHash,
Offset: 0,
Length: chunk.Size,
}
if err := b.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil {
return err
}

// Add blob to snapshot
if err := b.repos.Snapshots.AddBlob(ctx, tx, snapshotID, blob.ID, blob.Hash); err != nil {
return err
}
return b.repos.BlobChunks.Create(ctx, tx, blobChunk)
})
if err != nil {
return "", err
}
return nil
})

if err != nil {
return "", err
// Add blob to snapshot in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return b.repos.Snapshots.AddBlob(ctx, tx, snapshotID, "test-blob-"+blobHash[:8], blobHash)
})
if err != nil {
return "", err
}
}

// Update snapshot with final counts

@@ -13,7 +13,9 @@ type ScannerParams struct {
EnableProgress bool
}

// Module exports backup functionality
// Module exports backup functionality as an fx module.
// It provides a ScannerFactory that can create Scanner instances
// with custom parameters while sharing common dependencies.
var Module = fx.Module("backup",
fx.Provide(
provideScannerFactory,

@@ -15,9 +15,13 @@ import (
)

const (
// Progress reporting intervals
SummaryInterval = 10 * time.Second // One-line status updates
DetailInterval = 60 * time.Second // Multi-line detailed status
// SummaryInterval defines how often one-line status updates are printed.
// These updates show current progress, ETA, and the file being processed.
SummaryInterval = 10 * time.Second

// DetailInterval defines how often multi-line detailed status reports are printed.
// These reports include comprehensive statistics about files, chunks, blobs, and uploads.
DetailInterval = 60 * time.Second
)

// ProgressStats holds atomic counters for progress tracking
@@ -32,6 +36,7 @@ type ProgressStats struct {
BlobsCreated atomic.Int64
BlobsUploaded atomic.Int64
BytesUploaded atomic.Int64
UploadDurationMs atomic.Int64 // Total milliseconds spent uploading to S3
CurrentFile atomic.Value // stores string
TotalSize atomic.Int64 // Total size to process (set after scan phase)
TotalFiles atomic.Int64 // Total files to process in phase 2
@@ -66,8 +71,8 @@ type ProgressReporter struct {
// NewProgressReporter creates a new progress reporter
func NewProgressReporter() *ProgressReporter {
stats := &ProgressStats{
StartTime: time.Now(),
lastDetailTime: time.Now(),
StartTime: time.Now().UTC(),
lastDetailTime: time.Now().UTC(),
}
stats.CurrentFile.Store("")

@@ -115,7 +120,7 @@ func (pr *ProgressReporter) GetStats() *ProgressStats {
// SetTotalSize sets the total size to process (after scan phase)
func (pr *ProgressReporter) SetTotalSize(size int64) {
pr.stats.TotalSize.Store(size)
pr.stats.ProcessStartTime.Store(time.Now())
pr.stats.ProcessStartTime.Store(time.Now().UTC())
}

// run is the main progress reporting loop
@@ -186,7 +191,7 @@ func (pr *ProgressReporter) printSummaryStatus() {
filesProcessed := pr.stats.FilesProcessed.Load()
totalFiles := pr.stats.TotalFiles.Load()

status := fmt.Sprintf("Progress: %d/%d files, %s/%s (%.1f%%), %s/s%s",
status := fmt.Sprintf("Snapshot progress: %d/%d files, %s/%s (%.1f%%), %s/s%s",
filesProcessed,
totalFiles,
humanize.Bytes(uint64(bytesProcessed)),
@@ -206,7 +211,7 @@ func (pr *ProgressReporter) printSummaryStatus() {
// printDetailedStatus prints a multi-line detailed status
func (pr *ProgressReporter) printDetailedStatus() {
pr.stats.mu.Lock()
pr.stats.lastDetailTime = time.Now()
pr.stats.lastDetailTime = time.Now().UTC()
pr.stats.mu.Unlock()

elapsed := time.Since(pr.stats.StartTime)
@@ -225,7 +230,7 @@ func (pr *ProgressReporter) printDetailedStatus() {
totalBytes := bytesScanned + bytesSkipped
rate := float64(totalBytes) / elapsed.Seconds()

log.Notice("=== Backup Progress Report ===")
log.Notice("=== Snapshot Progress Report ===")
log.Info("Elapsed time", "duration", formatDuration(elapsed))

// Calculate and show ETA if we have data
@@ -264,7 +269,7 @@ func (pr *ProgressReporter) printDetailedStatus() {
"created", blobsCreated,
"uploaded", blobsUploaded,
"pending", blobsCreated-blobsUploaded)
log.Info("Upload progress",
log.Info("Total uploaded to S3",
"uploaded", humanize.Bytes(uint64(bytesUploaded)),
"compression_ratio", formatRatio(bytesUploaded, bytesScanned))
if currentFile != "" {
@@ -313,31 +318,8 @@ func truncatePath(path string, maxLen int) string {

// printUploadProgress prints upload progress
func (pr *ProgressReporter) printUploadProgress(info *UploadInfo) {
elapsed := time.Since(info.StartTime)
if elapsed < time.Millisecond {
elapsed = time.Millisecond // Avoid division by zero
}

bytesPerSec := float64(info.Size) / elapsed.Seconds()
bitsPerSec := bytesPerSec * 8

// Format speed in bits/second
var speedStr string
if bitsPerSec >= 1e9 {
speedStr = fmt.Sprintf("%.1fGbit/sec", bitsPerSec/1e9)
} else if bitsPerSec >= 1e6 {
speedStr = fmt.Sprintf("%.0fMbit/sec", bitsPerSec/1e6)
} else if bitsPerSec >= 1e3 {
speedStr = fmt.Sprintf("%.0fKbit/sec", bitsPerSec/1e3)
} else {
speedStr = fmt.Sprintf("%.0fbit/sec", bitsPerSec)
}

log.Info("Uploading blob",
"hash", info.BlobHash[:8]+"...",
"size", humanize.Bytes(uint64(info.Size)),
"elapsed", formatDuration(elapsed),
"speed", speedStr)
// This function is called repeatedly during upload, not just at start
// Don't print anything here - the actual progress is shown by ReportUploadProgress
}

// ReportUploadStart marks the beginning of a blob upload
@@ -345,7 +327,7 @@ func (pr *ProgressReporter) ReportUploadStart(blobHash string, size int64) {
info := &UploadInfo{
BlobHash: blobHash,
Size: size,
StartTime: time.Now(),
StartTime: time.Now().UTC(),
}
pr.stats.CurrentUpload.Store(info)
}
@@ -355,6 +337,9 @@ func (pr *ProgressReporter) ReportUploadComplete(blobHash string, size int64, du
// Clear current upload
pr.stats.CurrentUpload.Store((*UploadInfo)(nil))

// Add to total upload duration
pr.stats.UploadDurationMs.Add(duration.Milliseconds())

// Calculate speed
if duration < time.Millisecond {
duration = time.Millisecond
@@ -374,7 +359,7 @@ func (pr *ProgressReporter) ReportUploadComplete(blobHash string, size int64, du
speedStr = fmt.Sprintf("%.0fbit/sec", bitsPerSec)
}

log.Info("Blob uploaded",
log.Info("Blob upload completed",
"hash", blobHash[:8]+"...",
"size", humanize.Bytes(uint64(size)),
"duration", formatDuration(duration),
@@ -384,6 +369,44 @@ func (pr *ProgressReporter) ReportUploadComplete(blobHash string, size int64, du
// UpdateChunkingActivity updates the last chunking time
func (pr *ProgressReporter) UpdateChunkingActivity() {
pr.stats.mu.Lock()
pr.stats.lastChunkingTime = time.Now()
pr.stats.lastChunkingTime = time.Now().UTC()
pr.stats.mu.Unlock()
}

// ReportUploadProgress reports current upload progress with instantaneous speed
func (pr *ProgressReporter) ReportUploadProgress(blobHash string, bytesUploaded, totalSize int64, instantSpeed float64) {
// Update the current upload info with progress
if uploadInfo, ok := pr.stats.CurrentUpload.Load().(*UploadInfo); ok && uploadInfo != nil {
// Format speed in bits/second
bitsPerSec := instantSpeed * 8
var speedStr string
if bitsPerSec >= 1e9 {
speedStr = fmt.Sprintf("%.1fGbit/sec", bitsPerSec/1e9)
} else if bitsPerSec >= 1e6 {
speedStr = fmt.Sprintf("%.0fMbit/sec", bitsPerSec/1e6)
} else if bitsPerSec >= 1e3 {
speedStr = fmt.Sprintf("%.0fKbit/sec", bitsPerSec/1e3)
} else {
speedStr = fmt.Sprintf("%.0fbit/sec", bitsPerSec)
}

percent := float64(bytesUploaded) / float64(totalSize) * 100

// Calculate ETA based on current speed
etaStr := "unknown"
if instantSpeed > 0 && bytesUploaded < totalSize {
remainingBytes := totalSize - bytesUploaded
remainingSeconds := float64(remainingBytes) / instantSpeed
eta := time.Duration(remainingSeconds * float64(time.Second))
etaStr = formatDuration(eta)
}

log.Info("Blob upload progress",
"hash", blobHash[:8]+"...",
"progress", fmt.Sprintf("%.1f%%", percent),
"uploaded", humanize.Bytes(uint64(bytesUploaded)),
"total", humanize.Bytes(uint64(totalSize)),
"speed", speedStr,
"eta", etaStr)
}
}

@@ -15,6 +15,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/s3"
"github.com/dustin/go-humanize"
"github.com/spf13/afero"
)
@@ -49,6 +50,8 @@ type Scanner struct {
// S3Client interface for blob storage operations
type S3Client interface {
PutObject(ctx context.Context, key string, data io.Reader) error
PutObjectWithProgress(ctx context.Context, key string, data io.Reader, size int64, progress s3.ProgressCallback) error
StatObject(ctx context.Context, key string) (*s3.ObjectInfo, error)
}

// ScannerConfig contains configuration for the scanner
@@ -125,7 +128,7 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
s.snapshotID = snapshotID
s.scanCtx = ctx
result := &ScanResult{
StartTime: time.Now(),
StartTime: time.Now().UTC(),
}

// Set blob handler for concurrent upload
@@ -143,7 +146,7 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
}

// Phase 1: Scan directory and collect files to process
log.Info("Phase 1: Scanning directory structure")
log.Info("Phase 1/3: Scanning directory structure")
filesToProcess, err := s.scanPhase(ctx, path, result)
if err != nil {
return nil, fmt.Errorf("scan phase failed: %w", err)
@@ -169,7 +172,7 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc

// Phase 2: Process files and create chunks
if len(filesToProcess) > 0 {
log.Info("Phase 2: Processing files and creating chunks")
log.Info("Phase 2/3: Creating snapshot (chunking, compressing, encrypting, and uploading blobs)")
if err := s.processPhase(ctx, filesToProcess, result); err != nil {
return nil, fmt.Errorf("process phase failed: %w", err)
}
@@ -179,7 +182,7 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
blobs := s.packer.GetFinishedBlobs()
result.BlobsCreated += len(blobs)

result.EndTime = time.Now()
result.EndTime = time.Now().UTC()
return result, nil
}

@@ -290,21 +293,12 @@ func (s *Scanner) checkFileAndUpdateMetadata(ctx context.Context, path string, i
default:
}

var file *database.File
var needsProcessing bool

// Use a short transaction just for the database operations
err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
var err error
file, needsProcessing, err = s.checkFile(txCtx, tx, path, info, result)
return err
})

return file, needsProcessing, err
// Process file without holding a long transaction
return s.checkFile(ctx, path, info, result)
}

// checkFile checks if a file needs processing and updates metadata within a transaction
func (s *Scanner) checkFile(ctx context.Context, tx *sql.Tx, path string, info os.FileInfo, result *ScanResult) (*database.File, bool, error) {
// checkFile checks if a file needs processing and updates metadata
func (s *Scanner) checkFile(ctx context.Context, path string, info os.FileInfo, result *ScanResult) (*database.File, bool, error) {
// Get file stats
stat, ok := info.Sys().(interface {
Uid() uint32
@@ -338,25 +332,31 @@ func (s *Scanner) checkFile(ctx context.Context, tx *sql.Tx, path string, info o
LinkTarget: linkTarget,
}

// Check if file has changed since last backup
// Check if file has changed since last backup (no transaction needed for read)
log.Debug("Checking if file exists in database", "path", path)
existingFile, err := s.repos.Files.GetByPathTx(ctx, tx, path)
existingFile, err := s.repos.Files.GetByPath(ctx, path)
if err != nil {
return nil, false, fmt.Errorf("checking existing file: %w", err)
}

fileChanged := existingFile == nil || s.hasFileChanged(existingFile, file)

// Always update file metadata
// Update file metadata in a short transaction
log.Debug("Updating file metadata", "path", path, "changed", fileChanged)
if err := s.repos.Files.Create(ctx, tx, file); err != nil {
err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return s.repos.Files.Create(ctx, tx, file)
})
if err != nil {
return nil, false, err
}
log.Debug("File metadata updated", "path", path)

// Add file to snapshot
// Add file to snapshot in a short transaction
log.Debug("Adding file to snapshot", "path", path, "snapshot", s.snapshotID)
if err := s.repos.Snapshots.AddFile(ctx, tx, s.snapshotID, path); err != nil {
err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return s.repos.Snapshots.AddFile(ctx, tx, s.snapshotID, path)
})
if err != nil {
return nil, false, fmt.Errorf("adding file to snapshot: %w", err)
}
log.Debug("File added to snapshot", "path", path)
@@ -381,7 +381,7 @@ func (s *Scanner) checkFile(ctx context.Context, tx *sql.Tx, path string, info o
}
// File hasn't changed, but we still need to associate existing chunks with this snapshot
log.Debug("File hasn't changed, associating existing chunks", "path", path)
if err := s.associateExistingChunks(ctx, tx, path); err != nil {
if err := s.associateExistingChunks(ctx, path); err != nil {
return nil, false, fmt.Errorf("associating existing chunks: %w", err)
}
log.Debug("Existing chunks associated", "path", path)
@@ -421,25 +421,25 @@ func (s *Scanner) hasFileChanged(existingFile, newFile *database.File) bool {
}

// associateExistingChunks links existing chunks from an unchanged file to the current snapshot
func (s *Scanner) associateExistingChunks(ctx context.Context, tx *sql.Tx, path string) error {
func (s *Scanner) associateExistingChunks(ctx context.Context, path string) error {
log.Debug("associateExistingChunks start", "path", path)

// Get existing file chunks
// Get existing file chunks (no transaction needed for read)
log.Debug("Getting existing file chunks", "path", path)
fileChunks, err := s.repos.FileChunks.GetByFileTx(ctx, tx, path)
fileChunks, err := s.repos.FileChunks.GetByFile(ctx, path)
if err != nil {
return fmt.Errorf("getting existing file chunks: %w", err)
}
log.Debug("Got file chunks", "path", path, "count", len(fileChunks))

// For each chunk, find its blob and associate with current snapshot
processedBlobs := make(map[string]bool)
// Collect unique blob IDs that need to be added to snapshot
blobsToAdd := make(map[string]string) // blob ID -> blob hash
for i, fc := range fileChunks {
log.Debug("Processing chunk", "path", path, "chunk_index", i, "chunk_hash", fc.ChunkHash)

// Find which blob contains this chunk
// Find which blob contains this chunk (no transaction needed for read)
log.Debug("Finding blob for chunk", "chunk_hash", fc.ChunkHash)
blobChunk, err := s.repos.BlobChunks.GetByChunkHashTx(ctx, tx, fc.ChunkHash)
blobChunk, err := s.repos.BlobChunks.GetByChunkHash(ctx, fc.ChunkHash)
if err != nil {
return fmt.Errorf("finding blob for chunk %s: %w", fc.ChunkHash, err)
}
@@ -449,28 +449,39 @@ func (s *Scanner) associateExistingChunks(ctx context.Context, tx *sql.Tx, path
}
log.Debug("Found blob for chunk", "chunk_hash", fc.ChunkHash, "blob_id", blobChunk.BlobID)

// Get blob to find its hash
blob, err := s.repos.Blobs.GetByID(ctx, blobChunk.BlobID)
if err != nil {
return fmt.Errorf("getting blob %s: %w", blobChunk.BlobID, err)
}
if blob == nil {
log.Warn("Blob record not found", "blob_id", blobChunk.BlobID)
continue
}

// Add blob to snapshot if not already processed
if !processedBlobs[blobChunk.BlobID] {
log.Debug("Adding blob to snapshot", "blob_id", blobChunk.BlobID, "blob_hash", blob.Hash, "snapshot", s.snapshotID)
if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, blobChunk.BlobID, blob.Hash); err != nil {
return fmt.Errorf("adding existing blob to snapshot: %w", err)
}
log.Debug("Added blob to snapshot", "blob_id", blobChunk.BlobID)
processedBlobs[blobChunk.BlobID] = true
// Track blob ID for later processing
if _, exists := blobsToAdd[blobChunk.BlobID]; !exists {
blobsToAdd[blobChunk.BlobID] = "" // We'll get the hash later
}
}

log.Debug("associateExistingChunks complete", "path", path, "blobs_processed", len(processedBlobs))
// Now get blob hashes outside of transaction operations
for blobID := range blobsToAdd {
blob, err := s.repos.Blobs.GetByID(ctx, blobID)
if err != nil {
return fmt.Errorf("getting blob %s: %w", blobID, err)
}
if blob == nil {
log.Warn("Blob record not found", "blob_id", blobID)
delete(blobsToAdd, blobID)
continue
}
blobsToAdd[blobID] = blob.Hash
}

// Add blobs to snapshot using short transactions
for blobID, blobHash := range blobsToAdd {
log.Debug("Adding blob to snapshot", "blob_id", blobID, "blob_hash", blobHash, "snapshot", s.snapshotID)
err := s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, blobID, blobHash)
})
if err != nil {
return fmt.Errorf("adding existing blob to snapshot: %w", err)
}
log.Debug("Added blob to snapshot", "blob_id", blobID)
}

log.Debug("associateExistingChunks complete", "path", path, "blobs_processed", len(blobsToAdd))
return nil
}

@@ -478,7 +489,7 @@ func (s *Scanner) associateExistingChunks(ctx context.Context, tx *sql.Tx, path
func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
log.Debug("Blob handler called", "blob_hash", blobWithReader.Hash[:8]+"...")

startTime := time.Now()
startTime := time.Now().UTC()
finishedBlob := blobWithReader.FinishedBlob

// Report upload start
@@ -492,7 +503,40 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
if ctx == nil {
ctx = context.Background()
}
if err := s.s3Client.PutObject(ctx, "blobs/"+finishedBlob.Hash, blobWithReader.Reader); err != nil {

// Track bytes uploaded for accurate speed calculation
lastProgressTime := time.Now()
lastProgressBytes := int64(0)

progressCallback := func(uploaded int64) error {

// Calculate instantaneous speed
now := time.Now()
elapsed := now.Sub(lastProgressTime).Seconds()
if elapsed > 0.5 { // Update speed every 0.5 seconds
bytesSinceLastUpdate := uploaded - lastProgressBytes
speed := float64(bytesSinceLastUpdate) / elapsed

if s.progress != nil {
s.progress.ReportUploadProgress(finishedBlob.Hash, uploaded, finishedBlob.Compressed, speed)
}

lastProgressTime = now
lastProgressBytes = uploaded
}

// Check for cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}

// Create sharded path: blobs/ca/fe/cafebabe...
blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash)
if err := s.s3Client.PutObjectWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil {
return fmt.Errorf("uploading blob %s to S3: %w", finishedBlob.Hash, err)
}

@@ -574,8 +618,8 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
var chunks []chunkInfo
chunkIndex := 0

// Process chunks in streaming fashion
err = s.chunker.ChunkReaderStreaming(file, func(chunk chunker.Chunk) error {
// Process chunks in streaming fashion and get full file hash
fileHash, err := s.chunker.ChunkReaderStreaming(file, func(chunk chunker.Chunk) error {
// Check for cancellation
select {
case <-ctx.Done():
@@ -589,17 +633,16 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
"hash", chunk.Hash,
"size", chunk.Size)

// Check if chunk already exists
chunkExists := false
err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
existing, err := s.repos.Chunks.GetByHash(txCtx, chunk.Hash)
if err != nil {
return err
}
chunkExists = (existing != nil)
// Check if chunk already exists (outside of transaction)
existing, err := s.repos.Chunks.GetByHash(ctx, chunk.Hash)
if err != nil {
return fmt.Errorf("checking chunk existence: %w", err)
}
chunkExists := (existing != nil)

// Store chunk if new
if !chunkExists {
// Store chunk if new
if !chunkExists {
err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
dbChunk := &database.Chunk{
ChunkHash: chunk.Hash,
SHA256: chunk.Hash,
@@ -608,17 +651,17 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
if err := s.repos.Chunks.Create(txCtx, tx, dbChunk); err != nil {
return fmt.Errorf("creating chunk: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("storing chunk: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("checking/storing chunk: %w", err)
}

// Track file chunk association for later storage
chunks = append(chunks, chunkInfo{
fileChunk: database.FileChunk{
Path: fileToProcess.Path,
FileID: fileToProcess.File.ID,
Idx: chunkIndex,
ChunkHash: chunk.Hash,
},
@@ -683,6 +726,11 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
return fmt.Errorf("chunking file: %w", err)
}

log.Debug("Completed chunking file",
"path", fileToProcess.Path,
"file_hash", fileHash,
"chunks", len(chunks))

// Store file-chunk associations and chunk-file mappings in database
err = s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
for _, ci := range chunks {
@@ -694,7 +742,7 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
// Create chunk-file mapping
chunkFile := &database.ChunkFile{
ChunkHash: ci.fileChunk.ChunkHash,
FilePath: fileToProcess.Path,
FileID: fileToProcess.File.ID,
FileOffset: ci.offset,
Length: ci.size,
}
@@ -704,7 +752,7 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
}

// Add file to snapshot
if err := s.repos.Snapshots.AddFile(txCtx, tx, s.snapshotID, fileToProcess.Path); err != nil {
if err := s.repos.Snapshots.AddFileByID(txCtx, tx, s.snapshotID, fileToProcess.File.ID); err != nil {
return fmt.Errorf("adding file to snapshot: %w", err)
}

@@ -713,3 +761,8 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT

return err
}

// GetProgress returns the progress reporter for this scanner
func (s *Scanner) GetProgress() *ProgressReporter {
return s.progress
}

@@ -213,7 +213,7 @@ func TestScannerWithSymlinks(t *testing.T) {
Repositories: repos,
MaxBlobSize: int64(1024 * 1024),
CompressionLevel: 3,
AgeRecipients: []string{},
AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key
})

// Create a snapshot record for testing
@@ -314,7 +314,7 @@ func TestScannerLargeFile(t *testing.T) {
Repositories: repos,
MaxBlobSize: int64(1024 * 1024),
CompressionLevel: 3,
AgeRecipients: []string{},
AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key
})

// Create a snapshot record for testing

@@ -78,21 +78,22 @@ func NewSnapshotManager(repos *database.Repositories, s3Client S3Client, encrypt
}

// CreateSnapshot creates a new snapshot record in the database at the start of a backup
func (sm *SnapshotManager) CreateSnapshot(ctx context.Context, hostname, version string) (string, error) {
snapshotID := fmt.Sprintf("%s-%s", hostname, time.Now().Format("20060102-150405"))
func (sm *SnapshotManager) CreateSnapshot(ctx context.Context, hostname, version, gitRevision string) (string, error) {
snapshotID := fmt.Sprintf("%s-%s", hostname, time.Now().UTC().Format("20060102-150405Z"))

snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: hostname,
VaultikVersion: version,
StartedAt: time.Now(),
CompletedAt: nil, // Not completed yet
FileCount: 0,
ChunkCount: 0,
BlobCount: 0,
TotalSize: 0,
BlobSize: 0,
CompressionRatio: 1.0,
ID: snapshotID,
Hostname: hostname,
VaultikVersion: version,
VaultikGitRevision: gitRevision,
StartedAt: time.Now().UTC(),
CompletedAt: nil, // Not completed yet
FileCount: 0,
ChunkCount: 0,
BlobCount: 0,
TotalSize: 0,
BlobSize: 0,
CompressionRatio: 1.0,
}

err := sm.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
@@ -126,6 +127,30 @@ func (sm *SnapshotManager) UpdateSnapshotStats(ctx context.Context, snapshotID s
return nil
}

// UpdateSnapshotStatsExtended updates snapshot statistics with extended metrics.
// This includes compression level, uncompressed blob size, and upload duration.
func (sm *SnapshotManager) UpdateSnapshotStatsExtended(ctx context.Context, snapshotID string, stats ExtendedBackupStats) error {
return sm.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
// First update basic stats
if err := sm.repos.Snapshots.UpdateCounts(ctx, tx, snapshotID,
int64(stats.FilesScanned),
int64(stats.ChunksCreated),
int64(stats.BlobsCreated),
stats.BytesScanned,
stats.BytesUploaded,
); err != nil {
return err
}

// Then update extended stats
return sm.repos.Snapshots.UpdateExtendedStats(ctx, tx, snapshotID,
stats.BlobUncompressedSize,
stats.CompressionLevel,
stats.UploadDurationMs,
)
})
}

// CompleteSnapshot marks a snapshot as completed and exports its metadata
func (sm *SnapshotManager) CompleteSnapshot(ctx context.Context, snapshotID string) error {
// Mark the snapshot as completed
@@ -158,14 +183,16 @@ func (sm *SnapshotManager) CompleteSnapshot(ctx context.Context, snapshotID stri
//
// This ensures database consistency during the copy operation.
func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath string, snapshotID string) error {
log.Info("Exporting snapshot metadata", "snapshot_id", snapshotID)
log.Info("Phase 3/3: Exporting snapshot metadata", "snapshot_id", snapshotID, "source_db", dbPath)

// Create temp directory for all temporary files
tempDir, err := os.MkdirTemp("", "vaultik-snapshot-*")
if err != nil {
return fmt.Errorf("creating temp dir: %w", err)
}
log.Debug("Created temporary directory", "path", tempDir)
defer func() {
log.Debug("Cleaning up temporary directory", "path", tempDir)
if err := os.RemoveAll(tempDir); err != nil {
log.Debug("Failed to remove temp dir", "path", tempDir, "error", err)
}
@@ -174,28 +201,37 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
// Step 1: Copy database to temp file
// The main database should be closed at this point
tempDBPath := filepath.Join(tempDir, "snapshot.db")
log.Debug("Copying database to temporary location", "source", dbPath, "destination", tempDBPath)
if err := copyFile(dbPath, tempDBPath); err != nil {
return fmt.Errorf("copying database: %w", err)
}
log.Debug("Database copy complete", "size", getFileSize(tempDBPath))

// Step 2: Clean the temp database to only contain current snapshot data
log.Debug("Cleaning snapshot database to contain only current snapshot", "snapshot_id", snapshotID)
if err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID); err != nil {
return fmt.Errorf("cleaning snapshot database: %w", err)
}
log.Debug("Database cleaning complete", "size_after_clean", getFileSize(tempDBPath))

// Step 3: Dump the cleaned database to SQL
dumpPath := filepath.Join(tempDir, "snapshot.sql")
log.Debug("Dumping database to SQL", "source", tempDBPath, "destination", dumpPath)
if err := sm.dumpDatabase(tempDBPath, dumpPath); err != nil {
return fmt.Errorf("dumping database: %w", err)
}
log.Debug("SQL dump complete", "size", getFileSize(dumpPath))

// Step 4: Compress the SQL dump
compressedPath := filepath.Join(tempDir, "snapshot.sql.zst")
log.Debug("Compressing SQL dump", "source", dumpPath, "destination", compressedPath)
if err := sm.compressDump(dumpPath, compressedPath); err != nil {
return fmt.Errorf("compressing dump: %w", err)
}
log.Debug("Compression complete", "original_size", getFileSize(dumpPath), "compressed_size", getFileSize(compressedPath))

// Step 5: Read compressed data for encryption/upload
log.Debug("Reading compressed data for upload", "path", compressedPath)
compressedData, err := os.ReadFile(compressedPath)
if err != nil {
return fmt.Errorf("reading compressed dump: %w", err)
@@ -204,14 +240,19 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
// Step 6: Encrypt if encryptor is available
finalData := compressedData
if sm.encryptor != nil {
log.Debug("Encrypting snapshot data", "size_before", len(compressedData))
encrypted, err := sm.encryptor.Encrypt(compressedData)
if err != nil {
return fmt.Errorf("encrypting snapshot: %w", err)
}
finalData = encrypted
log.Debug("Encryption complete", "size_after", len(encrypted))
} else {
log.Debug("No encryption configured, using compressed data as-is")
}

// Step 7: Generate blob manifest (before closing temp DB)
log.Debug("Generating blob manifest from temporary database", "db_path", tempDBPath)
blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID)
if err != nil {
return fmt.Errorf("generating blob manifest: %w", err)
@@ -224,15 +265,19 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
dbKey += ".age"
}

log.Debug("Uploading snapshot database to S3", "key", dbKey, "size", len(finalData))
if err := sm.s3Client.PutObject(ctx, dbKey, bytes.NewReader(finalData)); err != nil {
return fmt.Errorf("uploading snapshot database: %w", err)
}
log.Debug("Database upload complete", "key", dbKey)

// Upload blob manifest (unencrypted, compressed)
manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
log.Debug("Uploading blob manifest to S3", "key", manifestKey, "size", len(blobManifest))
if err := sm.s3Client.PutObject(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil {
return fmt.Errorf("uploading blob manifest: %w", err)
}
log.Debug("Manifest upload complete", "key", manifestKey)

log.Info("Uploaded snapshot metadata",
"snapshot_id", snapshotID,
@@ -260,14 +305,18 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
// Future implementation when we have snapshot_files table:
//
// DELETE FROM snapshots WHERE id != ?;
// DELETE FROM files WHERE path NOT IN (
// SELECT file_path FROM snapshot_files WHERE snapshot_id = ?
// DELETE FROM files WHERE NOT EXISTS (
// SELECT 1 FROM snapshot_files
// WHERE snapshot_files.file_id = files.id
// AND snapshot_files.snapshot_id = ?
// );
// DELETE FROM chunks WHERE chunk_hash NOT IN (
// SELECT DISTINCT chunk_hash FROM file_chunks
// DELETE FROM chunks WHERE NOT EXISTS (
// SELECT 1 FROM file_chunks
// WHERE file_chunks.chunk_hash = chunks.chunk_hash
// );
// DELETE FROM blobs WHERE blob_hash NOT IN (
// SELECT DISTINCT blob_hash FROM blob_chunks
// DELETE FROM blobs WHERE NOT EXISTS (
// SELECT 1 FROM blob_chunks
// WHERE blob_chunks.blob_hash = blobs.blob_hash
// );
func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, snapshotID string) error {
// Open the temp database
@@ -293,84 +342,127 @@ func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, s
|
||||
}()
|
||||
|
||||
// Step 1: Delete all other snapshots
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM snapshots WHERE id != ?", snapshotID)
|
||||
log.Debug("Deleting other snapshots", "keeping", snapshotID)
|
||||
database.LogSQL("Execute", "DELETE FROM snapshots WHERE id != ?", snapshotID)
|
||||
result, err := tx.ExecContext(ctx, "DELETE FROM snapshots WHERE id != ?", snapshotID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting other snapshots: %w", err)
|
||||
}
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
log.Debug("Deleted snapshots", "count", rowsAffected)
|
||||
|
||||
// Step 2: Delete files not in this snapshot
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
log.Debug("Deleting files not in current snapshot")
|
||||
database.LogSQL("Execute", `DELETE FROM files WHERE NOT EXISTS (SELECT 1 FROM snapshot_files WHERE snapshot_files.file_id = files.id AND snapshot_files.snapshot_id = ?)`, snapshotID)
|
||||
result, err = tx.ExecContext(ctx, `
|
||||
DELETE FROM files
|
||||
WHERE path NOT IN (
|
||||
SELECT file_path FROM snapshot_files WHERE snapshot_id = ?
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM snapshot_files
|
||||
WHERE snapshot_files.file_id = files.id
|
||||
AND snapshot_files.snapshot_id = ?
|
||||
)`, snapshotID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting orphaned files: %w", err)
|
||||
}
|
||||
rowsAffected, _ = result.RowsAffected()
|
||||
log.Debug("Deleted files", "count", rowsAffected)
|
||||
|
||||
// Step 3: file_chunks will be deleted via CASCADE from files
|
||||
log.Debug("file_chunks will be deleted via CASCADE")
|
||||
|
||||
// Step 4: Delete chunk_files for deleted files
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
log.Debug("Deleting orphaned chunk_files")
|
||||
database.LogSQL("Execute", `DELETE FROM chunk_files WHERE NOT EXISTS (SELECT 1 FROM files WHERE files.id = chunk_files.file_id)`)
|
||||
result, err = tx.ExecContext(ctx, `
|
||||
DELETE FROM chunk_files
|
||||
WHERE file_path NOT IN (
|
||||
SELECT path FROM files
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM files
|
||||
WHERE files.id = chunk_files.file_id
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting orphaned chunk_files: %w", err)
|
||||
}
|
||||
rowsAffected, _ = result.RowsAffected()
|
||||
log.Debug("Deleted chunk_files", "count", rowsAffected)
|
||||
|
||||
// Step 5: Delete chunks with no remaining file references
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
log.Debug("Deleting orphaned chunks")
|
||||
database.LogSQL("Execute", `DELETE FROM chunks WHERE NOT EXISTS (SELECT 1 FROM file_chunks WHERE file_chunks.chunk_hash = chunks.chunk_hash)`)
|
||||
result, err = tx.ExecContext(ctx, `
|
||||
DELETE FROM chunks
|
||||
WHERE chunk_hash NOT IN (
|
||||
SELECT DISTINCT chunk_hash FROM file_chunks
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM file_chunks
|
||||
WHERE file_chunks.chunk_hash = chunks.chunk_hash
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting orphaned chunks: %w", err)
|
||||
}
|
||||
rowsAffected, _ = result.RowsAffected()
|
||||
log.Debug("Deleted chunks", "count", rowsAffected)
|
||||
|
||||
// Step 6: Delete blob_chunks for deleted chunks
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
log.Debug("Deleting orphaned blob_chunks")
|
||||
database.LogSQL("Execute", `DELETE FROM blob_chunks WHERE NOT EXISTS (SELECT 1 FROM chunks WHERE chunks.chunk_hash = blob_chunks.chunk_hash)`)
|
||||
result, err = tx.ExecContext(ctx, `
|
||||
DELETE FROM blob_chunks
|
||||
WHERE chunk_hash NOT IN (
|
||||
SELECT chunk_hash FROM chunks
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM chunks
|
||||
WHERE chunks.chunk_hash = blob_chunks.chunk_hash
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting orphaned blob_chunks: %w", err)
|
||||
}
|
||||
rowsAffected, _ = result.RowsAffected()
|
||||
log.Debug("Deleted blob_chunks", "count", rowsAffected)
|
||||
|
||||
// Step 7: Delete blobs not in this snapshot
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
log.Debug("Deleting blobs not in current snapshot")
|
||||
database.LogSQL("Execute", `DELETE FROM blobs WHERE NOT EXISTS (SELECT 1 FROM snapshot_blobs WHERE snapshot_blobs.blob_hash = blobs.blob_hash AND snapshot_blobs.snapshot_id = ?)`, snapshotID)
|
||||
result, err = tx.ExecContext(ctx, `
|
||||
DELETE FROM blobs
|
||||
WHERE blob_hash NOT IN (
|
||||
SELECT blob_hash FROM snapshot_blobs WHERE snapshot_id = ?
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM snapshot_blobs
|
||||
WHERE snapshot_blobs.blob_hash = blobs.blob_hash
|
||||
AND snapshot_blobs.snapshot_id = ?
|
||||
)`, snapshotID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting orphaned blobs: %w", err)
|
||||
}
|
||||
rowsAffected, _ = result.RowsAffected()
|
||||
log.Debug("Deleted blobs not in snapshot", "count", rowsAffected)
|
||||
|
||||
// Step 8: Delete orphaned snapshot_files and snapshot_blobs
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID)
|
||||
log.Debug("Deleting orphaned snapshot_files")
|
||||
database.LogSQL("Execute", "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID)
|
||||
result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting orphaned snapshot_files: %w", err)
|
||||
}
|
||||
rowsAffected, _ = result.RowsAffected()
|
||||
log.Debug("Deleted snapshot_files", "count", rowsAffected)
|
||||
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID)
|
||||
log.Debug("Deleting orphaned snapshot_blobs")
|
||||
database.LogSQL("Execute", "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID)
|
||||
result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting orphaned snapshot_blobs: %w", err)
|
||||
}
|
||||
rowsAffected, _ = result.RowsAffected()
|
||||
log.Debug("Deleted snapshot_blobs", "count", rowsAffected)
|
||||
|
||||
// Commit transaction
|
||||
log.Debug("Committing cleanup transaction")
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("committing transaction: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("Database cleanup complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// dumpDatabase creates a SQL dump of the database
|
||||
func (sm *SnapshotManager) dumpDatabase(dbPath, dumpPath string) error {
|
||||
log.Debug("Running sqlite3 dump command", "source", dbPath, "destination", dumpPath)
|
||||
cmd := exec.Command("sqlite3", dbPath, ".dump")
|
||||
|
||||
output, err := cmd.Output()
|
||||
@@ -378,6 +470,7 @@ func (sm *SnapshotManager) dumpDatabase(dbPath, dumpPath string) error {
|
||||
return fmt.Errorf("running sqlite3 dump: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("SQL dump generated", "size", len(output))
|
||||
if err := os.WriteFile(dumpPath, output, 0644); err != nil {
|
||||
return fmt.Errorf("writing dump file: %w", err)
|
||||
}
|
||||
@@ -387,27 +480,32 @@ func (sm *SnapshotManager) dumpDatabase(dbPath, dumpPath string) error {
|
||||
|
||||
// compressDump compresses the SQL dump using zstd
|
||||
func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error {
|
||||
log.Debug("Opening SQL dump for compression", "path", inputPath)
|
||||
input, err := os.Open(inputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening input file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Debug("Closing input file", "path", inputPath)
|
||||
if err := input.Close(); err != nil {
|
||||
log.Debug("Failed to close input file", "error", err)
|
||||
log.Debug("Failed to close input file", "path", inputPath, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("Creating output file for compressed data", "path", outputPath)
|
||||
output, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating output file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
log.Debug("Closing output file", "path", outputPath)
|
||||
if err := output.Close(); err != nil {
|
||||
log.Debug("Failed to close output file", "error", err)
|
||||
log.Debug("Failed to close output file", "path", outputPath, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create zstd encoder with good compression and multithreading
|
||||
log.Debug("Creating zstd compressor", "level", "SpeedBetterCompression", "concurrency", runtime.NumCPU())
|
||||
zstdWriter, err := zstd.NewWriter(output,
|
||||
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
|
||||
zstd.WithEncoderConcurrency(runtime.NumCPU()),
|
||||
@@ -422,6 +520,7 @@ func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("Compressing data")
|
||||
if _, err := io.Copy(zstdWriter, input); err != nil {
|
||||
return fmt.Errorf("compressing data: %w", err)
|
||||
}
|
||||
@@ -431,35 +530,44 @@ func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error {
|
||||
|
||||
// copyFile copies a file from src to dst
|
||||
func copyFile(src, dst string) error {
|
||||
log.Debug("Opening source file for copy", "path", src)
|
||||
sourceFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
log.Debug("Closing source file", "path", src)
|
||||
if err := sourceFile.Close(); err != nil {
|
||||
log.Debug("Failed to close source file", "error", err)
|
||||
log.Debug("Failed to close source file", "path", src, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("Creating destination file", "path", dst)
|
||||
destFile, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
log.Debug("Closing destination file", "path", dst)
|
||||
if err := destFile.Close(); err != nil {
|
||||
log.Debug("Failed to close destination file", "error", err)
|
||||
log.Debug("Failed to close destination file", "path", dst, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(destFile, sourceFile); err != nil {
|
||||
log.Debug("Copying file data")
|
||||
n, err := io.Copy(destFile, sourceFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug("File copy complete", "bytes_copied", n)
|
||||
|
||||
return nil
|
||||
}

// generateBlobManifest creates a compressed JSON list of all blobs in the snapshot
func (sm *SnapshotManager) generateBlobManifest(ctx context.Context, dbPath string, snapshotID string) ([]byte, error) {
    log.Debug("Generating blob manifest", "db_path", dbPath, "snapshot_id", snapshotID)

    // Open the cleaned database using the database package
    db, err := database.New(ctx, dbPath)
    if err != nil {
@@ -471,10 +579,12 @@ func (sm *SnapshotManager) generateBlobManifest(ctx context.Context, dbPath stri
    repos := database.NewRepositories(db)

    // Get all blobs for this snapshot
    log.Debug("Querying blobs for snapshot", "snapshot_id", snapshotID)
    blobs, err := repos.Snapshots.GetBlobHashes(ctx, snapshotID)
    if err != nil {
        return nil, fmt.Errorf("getting snapshot blobs: %w", err)
    }
    log.Debug("Found blobs", "count", len(blobs))

    // Create manifest structure
    manifest := struct {
@@ -490,16 +600,20 @@ func (sm *SnapshotManager) generateBlobManifest(ctx context.Context, dbPath stri
    }

    // Marshal to JSON
    log.Debug("Marshaling manifest to JSON")
    jsonData, err := json.MarshalIndent(manifest, "", " ")
    if err != nil {
        return nil, fmt.Errorf("marshaling manifest: %w", err)
    }
    log.Debug("JSON manifest created", "size", len(jsonData))

    // Compress with zstd
    log.Debug("Compressing manifest with zstd")
    compressed, err := compressData(jsonData)
    if err != nil {
        return nil, fmt.Errorf("compressing manifest: %w", err)
    }
    log.Debug("Manifest compressed", "original_size", len(jsonData), "compressed_size", len(compressed))

    log.Info("Generated blob manifest",
        "snapshot_id", snapshotID,
@@ -532,6 +646,15 @@ func compressData(data []byte) ([]byte, error) {
    return buf.Bytes(), nil
}

// getFileSize returns the size of a file in bytes, or -1 on error
func getFileSize(path string) int64 {
    info, err := os.Stat(path)
    if err != nil {
        return -1
    }
    return info.Size()
}

// BackupStats contains statistics from a backup operation
type BackupStats struct {
    FilesScanned int
@@ -540,3 +663,108 @@ type BackupStats struct {
    BlobsCreated int
    BytesUploaded int64
}

// ExtendedBackupStats contains additional statistics for comprehensive tracking
type ExtendedBackupStats struct {
    BackupStats
    BlobUncompressedSize int64 // Total uncompressed size of all referenced blobs
    CompressionLevel int // Compression level used for this snapshot
    UploadDurationMs int64 // Total milliseconds spent uploading to S3
}

// CleanupIncompleteSnapshots removes incomplete snapshots that don't have metadata in S3.
// This is critical for data safety: incomplete snapshots can cause deduplication to skip
// files that were never successfully backed up, resulting in data loss.
func (sm *SnapshotManager) CleanupIncompleteSnapshots(ctx context.Context, hostname string) error {
    log.Info("Checking for incomplete snapshots", "hostname", hostname)

    // Get all incomplete snapshots for this hostname
    incompleteSnapshots, err := sm.repos.Snapshots.GetIncompleteByHostname(ctx, hostname)
    if err != nil {
        return fmt.Errorf("getting incomplete snapshots: %w", err)
    }

    if len(incompleteSnapshots) == 0 {
        log.Debug("No incomplete snapshots found")
        return nil
    }

    log.Info("Found incomplete snapshots", "count", len(incompleteSnapshots))

    // Check each incomplete snapshot for metadata in S3
    for _, snapshot := range incompleteSnapshots {
        // Check if metadata exists in S3
        metadataKey := fmt.Sprintf("metadata/%s/db.zst", snapshot.ID)
        _, err := sm.s3Client.StatObject(ctx, metadataKey)

        if err != nil {
            // Metadata doesn't exist in S3 - this is an incomplete snapshot
            log.Info("Cleaning up incomplete snapshot", "snapshot_id", snapshot.ID, "started_at", snapshot.StartedAt)

            // Delete the snapshot and all its associations
            if err := sm.deleteSnapshot(ctx, snapshot.ID); err != nil {
                return fmt.Errorf("deleting incomplete snapshot %s: %w", snapshot.ID, err)
            }

            log.Info("Deleted incomplete snapshot", "snapshot_id", snapshot.ID)
        } else {
            // Metadata exists - this snapshot was completed but database wasn't updated
            // This shouldn't happen in normal operation, but mark it complete
            log.Warn("Found snapshot with metadata but incomplete in DB", "snapshot_id", snapshot.ID)
            if err := sm.repos.Snapshots.MarkComplete(ctx, nil, snapshot.ID); err != nil {
                log.Error("Failed to mark snapshot complete", "snapshot_id", snapshot.ID, "error", err)
            }
        }
    }

    return nil
}

// deleteSnapshot removes a snapshot and all its associations from the database
func (sm *SnapshotManager) deleteSnapshot(ctx context.Context, snapshotID string) error {
    // Delete snapshot_files entries
    if err := sm.repos.Snapshots.DeleteSnapshotFiles(ctx, snapshotID); err != nil {
        return fmt.Errorf("deleting snapshot files: %w", err)
    }

    // Delete snapshot_blobs entries
    if err := sm.repos.Snapshots.DeleteSnapshotBlobs(ctx, snapshotID); err != nil {
        return fmt.Errorf("deleting snapshot blobs: %w", err)
    }

    // Delete the snapshot itself
    if err := sm.repos.Snapshots.Delete(ctx, snapshotID); err != nil {
        return fmt.Errorf("deleting snapshot: %w", err)
    }

    // Clean up orphaned data
    log.Debug("Cleaning up orphaned data")
    if err := sm.cleanupOrphanedData(ctx); err != nil {
        return fmt.Errorf("cleaning up orphaned data: %w", err)
    }

    return nil
}

// cleanupOrphanedData removes files, chunks, and blobs that are no longer referenced by any snapshot
func (sm *SnapshotManager) cleanupOrphanedData(ctx context.Context) error {
    // Delete orphaned files (files not in any snapshot)
    log.Debug("Deleting orphaned files")
    if err := sm.repos.Files.DeleteOrphaned(ctx); err != nil {
        return fmt.Errorf("deleting orphaned files: %w", err)
    }

    // Delete orphaned chunks (chunks not referenced by any file)
    log.Debug("Deleting orphaned chunks")
    if err := sm.repos.Chunks.DeleteOrphaned(ctx); err != nil {
        return fmt.Errorf("deleting orphaned chunks: %w", err)
    }

    // Delete orphaned blobs (blobs not in any snapshot)
    log.Debug("Deleting orphaned blobs")
    if err := sm.repos.Blobs.DeleteOrphaned(ctx); err != nil {
        return fmt.Errorf("deleting orphaned blobs: %w", err)
    }

    return nil
}

@@ -1,3 +1,17 @@
// Package blob handles the creation of blobs - the final storage units for Vaultik.
// A blob is a large file (up to 10GB) containing many compressed and encrypted chunks
// from multiple source files. Blobs are content-addressed, meaning their filename
// is derived from the SHA256 hash of their compressed and encrypted content.
//
// The blob creation process:
// 1. Chunks are accumulated from multiple files
// 2. The collection is compressed using zstd
// 3. The compressed data is encrypted using age
// 4. The encrypted blob is hashed to create its content-addressed name
// 5. The blob is uploaded to S3 using the hash as the filename
//
// This design optimizes storage efficiency by batching many small chunks into
// larger blobs, reducing the number of S3 operations and associated costs.
package blob
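
// A minimal sketch of steps 2-5 above, assuming a hypothetical ageEncrypt
// helper for the age step; the real Packer streams chunks through a temp file
// rather than buffering a whole blob in memory:
func blobNameSketch(chunks [][]byte, ageEncrypt func([]byte) ([]byte, error)) (string, []byte, error) {
    var buf bytes.Buffer
    zw, err := zstd.NewWriter(&buf) // step 2: compress accumulated chunks
    if err != nil {
        return "", nil, err
    }
    for _, c := range chunks { // step 1: chunks gathered from many files
        if _, err := zw.Write(c); err != nil {
            return "", nil, err
        }
    }
    if err := zw.Close(); err != nil {
        return "", nil, err
    }
    enc, err := ageEncrypt(buf.Bytes()) // step 3: encrypt with age
    if err != nil {
        return "", nil, err
    }
    sum := sha256.Sum256(enc) // step 4: hash the encrypted bytes
    return hex.EncodeToString(sum[:]), enc, nil // step 5 uploads enc under this name
}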

import (
@@ -20,19 +34,25 @@ import (
    "github.com/klauspost/compress/zstd"
)

// BlobHandler is called when a blob is finalized
// BlobHandler is a callback function invoked when a blob is finalized and ready for upload.
// The handler receives a BlobWithReader containing the blob metadata and a reader for
// the compressed and encrypted blob content. The handler is responsible for uploading
// the blob to storage and cleaning up any temporary files.
type BlobHandler func(blob *BlobWithReader) error

// PackerConfig holds configuration for creating a Packer
// PackerConfig holds configuration for creating a Packer.
// All fields except BlobHandler are required.
type PackerConfig struct {
    MaxBlobSize int64
    CompressionLevel int
    Encryptor Encryptor // Required - blobs are always encrypted
    Repositories *database.Repositories // For creating blob records
    BlobHandler BlobHandler // Optional - called when blob is ready
    MaxBlobSize int64 // Maximum size of a blob before forcing finalization
    CompressionLevel int // Zstd compression level (1-19, higher = better compression)
    Encryptor Encryptor // Age encryptor for blob encryption (required)
    Repositories *database.Repositories // Database repositories for tracking blob metadata
    BlobHandler BlobHandler // Optional callback when blob is ready for upload
}

// Packer combines chunks into blobs with compression and encryption
// Packer accumulates chunks and packs them into blobs.
// It handles compression, encryption, and coordination with the database
// to track blob metadata. Packer is thread-safe.
type Packer struct {
    maxBlobSize int64
    compressionLevel int
@@ -69,10 +89,13 @@ type blobInProgress struct {
    compressedSize int64 // Current compressed size (estimated)
}

// ChunkRef represents a chunk to be added to a blob
// ChunkRef represents a chunk to be added to a blob.
// The Hash is the content-addressed identifier (SHA256) of the chunk,
// and Data contains the raw chunk bytes. After adding to a blob,
// the Data can be safely discarded as it's written to the blob immediately.
type ChunkRef struct {
    Hash string
    Data []byte
    Hash string // SHA256 hash of the chunk data
    Data []byte // Raw chunk content
}

// chunkInfo tracks chunk metadata in a blob
@@ -107,7 +130,9 @@ type BlobWithReader struct {
    TempFile *os.File // Optional, only set for disk-based blobs
}

// NewPacker creates a new blob packer
// NewPacker creates a new blob packer that accumulates chunks into blobs.
// The packer will automatically finalize blobs when they reach MaxBlobSize.
// Returns an error if required configuration fields are missing or invalid.
func NewPacker(cfg PackerConfig) (*Packer, error) {
    if cfg.Encryptor == nil {
        return nil, fmt.Errorf("encryptor is required - blobs must be encrypted")
@@ -125,15 +150,21 @@ func NewPacker(cfg PackerConfig) (*Packer, error) {
    }, nil
}

// SetBlobHandler sets the handler to be called when a blob is finalized
// SetBlobHandler sets the handler to be called when a blob is finalized.
// The handler is responsible for uploading the blob to storage.
// If no handler is set, finalized blobs are stored in memory and can be
// retrieved with GetFinishedBlobs().
func (p *Packer) SetBlobHandler(handler BlobHandler) {
    p.mu.Lock()
    defer p.mu.Unlock()
    p.blobHandler = handler
}

// AddChunk adds a chunk to the current blob
// Returns ErrBlobSizeLimitExceeded if adding the chunk would exceed the size limit
// AddChunk adds a chunk to the current blob being packed.
// If adding the chunk would exceed MaxBlobSize, returns ErrBlobSizeLimitExceeded.
// In this case, the caller should finalize the current blob and retry.
// The chunk data is written immediately and can be garbage collected after this call.
// Thread-safe.
func (p *Packer) AddChunk(chunk *ChunkRef) error {
    p.mu.Lock()
    defer p.mu.Unlock()
@@ -166,7 +197,10 @@ func (p *Packer) AddChunk(chunk *ChunkRef) error {
    return nil
}
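
// The finalize-and-retry protocol described above looks roughly like this on
// the caller's side (a sketch; the packer and chunk variables are assumed):
if err := packer.AddChunk(chunk); err == ErrBlobSizeLimitExceeded {
    if err := packer.FinalizeBlob(); err != nil { // ship the full blob first
        return err
    }
    if err := packer.AddChunk(chunk); err != nil { // retry into a fresh blob
        return err
    }
} else if err != nil {
    return err
}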

// Flush finalizes any pending blob
// Flush finalizes any in-progress blob, compressing, encrypting, and hashing it.
// This should be called after all chunks have been added to ensure no data is lost.
// If a BlobHandler is set, it will be called with the finalized blob.
// Thread-safe.
func (p *Packer) Flush() error {
    p.mu.Lock()
    defer p.mu.Unlock()
@@ -180,8 +214,12 @@ func (p *Packer) Flush() error {
    return nil
}

// FinalizeBlob finalizes the current blob being assembled
// Caller must handle retrying the chunk that triggered size limit
// FinalizeBlob finalizes the current blob being assembled.
// This compresses the accumulated chunks, encrypts the result, and computes
// the content-addressed hash. The finalized blob is either passed to the
// BlobHandler (if set) or stored internally.
// Caller must handle retrying any chunk that triggered size limit exceeded.
// Thread-safe.
func (p *Packer) FinalizeBlob() error {
    p.mu.Lock()
    defer p.mu.Unlock()
@@ -193,7 +231,10 @@ func (p *Packer) FinalizeBlob() error {
    return p.finalizeCurrentBlob()
}

// GetFinishedBlobs returns all completed blobs and clears the list
// GetFinishedBlobs returns all completed blobs and clears the internal list.
// This is only used when no BlobHandler is set. After calling this method,
// the caller is responsible for uploading the blobs to storage.
// Thread-safe.
func (p *Packer) GetFinishedBlobs() []*FinishedBlob {
    p.mu.Lock()
    defer p.mu.Unlock()
@@ -212,8 +253,8 @@ func (p *Packer) startNewBlob() error {
    if p.repos != nil {
        blob := &database.Blob{
            ID: blobID,
            Hash: "", // Will be set when finalized
            CreatedTS: time.Now(),
            Hash: "temp-placeholder-" + blobID, // Temporary placeholder until finalized
            CreatedTS: time.Now().UTC(),
            FinishedTS: nil,
            UncompressedSize: 0,
            CompressedSize: 0,
@@ -237,7 +278,7 @@ func (p *Packer) startNewBlob() error {
        id: blobID,
        chunks: make([]*chunkInfo, 0),
        chunkSet: make(map[string]bool),
        startTime: time.Now(),
        startTime: time.Now().UTC(),
        tempFile: tempFile,
        hasher: sha256.New(),
        size: 0,

@@ -10,7 +10,9 @@ import (
    "github.com/jotfs/fastcdc-go"
)

// Chunk represents a single chunk of data
// Chunk represents a single chunk of data produced by the content-defined chunking algorithm.
// Each chunk is identified by its SHA256 hash and contains the raw data along with
// its position and size information from the original file.
type Chunk struct {
    Hash string // Content hash of the chunk
    Data []byte // Chunk data
@@ -18,14 +20,20 @@ type Chunk struct {
    Size int64 // Size of the chunk
}

// Chunker provides content-defined chunking using FastCDC
// Chunker provides content-defined chunking using the FastCDC algorithm.
// It splits data into variable-sized chunks based on content patterns, ensuring
// that identical data sequences produce identical chunks regardless of their
// position in the file. This enables efficient deduplication.
type Chunker struct {
    avgChunkSize int
    minChunkSize int
    maxChunkSize int
}

// NewChunker creates a new chunker with the specified average chunk size
// NewChunker creates a new chunker with the specified average chunk size.
// The actual chunk sizes will vary between avgChunkSize/4 and avgChunkSize*4
// as recommended by the FastCDC algorithm. Typical values for avgChunkSize
// are 64KB (65536), 256KB (262144), or 1MB (1048576).
func NewChunker(avgChunkSize int64) *Chunker {
    // FastCDC recommends min = avg/4 and max = avg*4
    return &Chunker{
@@ -35,7 +43,10 @@ func NewChunker(avgChunkSize int64) *Chunker {
    }
}
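
// For example, with the 256KB average mentioned above, the recommended bounds
// work out to min = 262144/4 = 65536 (64KB) and max = 262144*4 = 1048576 (1MB):
chunker := NewChunker(262144)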

// ChunkReader splits the reader into content-defined chunks
// ChunkReader splits the reader into content-defined chunks and returns all chunks at once.
// This method loads all chunk data into memory, so it should only be used for
// reasonably sized inputs. For large files or streams, use ChunkReaderStreaming instead.
// Returns an error if chunking fails or if reading from the input fails.
func (c *Chunker) ChunkReader(r io.Reader) ([]Chunk, error) {
    opts := fastcdc.Options{
        MinSize: c.minChunkSize,
@@ -80,20 +91,31 @@ func (c *Chunker) ChunkReader(r io.Reader) ([]Chunk, error) {
    return chunks, nil
}

// ChunkCallback is called for each chunk as it's processed
// ChunkCallback is a function called for each chunk as it's processed.
// The callback receives a Chunk containing the hash, data, offset, and size.
// If the callback returns an error, chunk processing stops and the error is propagated.
type ChunkCallback func(chunk Chunk) error

// ChunkReaderStreaming splits the reader into chunks and calls the callback for each
func (c *Chunker) ChunkReaderStreaming(r io.Reader, callback ChunkCallback) error {
// ChunkReaderStreaming splits the reader into chunks and calls the callback for each chunk.
// This is the preferred method for processing large files or streams as it doesn't
// accumulate all chunks in memory. The callback is invoked for each chunk as it's
// produced, allowing for streaming processing and immediate storage or transmission.
// Returns the SHA256 hash of the entire file content and an error if chunking fails,
// reading fails, or if the callback returns an error.
func (c *Chunker) ChunkReaderStreaming(r io.Reader, callback ChunkCallback) (string, error) {
    // Create a tee reader to calculate full file hash while chunking
    fileHasher := sha256.New()
    teeReader := io.TeeReader(r, fileHasher)

    opts := fastcdc.Options{
        MinSize: c.minChunkSize,
        AverageSize: c.avgChunkSize,
        MaxSize: c.maxChunkSize,
    }

    chunker, err := fastcdc.NewChunker(r, opts)
    chunker, err := fastcdc.NewChunker(teeReader, opts)
    if err != nil {
        return fmt.Errorf("creating chunker: %w", err)
        return "", fmt.Errorf("creating chunker: %w", err)
    }

    offset := int64(0)
@@ -104,10 +126,10 @@ func (c *Chunker) ChunkReaderStreaming(r io.Reader, callback ChunkCallback) erro
            break
        }
        if err != nil {
            return fmt.Errorf("reading chunk: %w", err)
            return "", fmt.Errorf("reading chunk: %w", err)
        }

        // Calculate hash
        // Calculate chunk hash
        hash := sha256.Sum256(chunk.Data)

        // Make a copy of the data since FastCDC reuses the buffer
@@ -120,16 +142,20 @@ func (c *Chunker) ChunkReaderStreaming(r io.Reader, callback ChunkCallback) erro
            Offset: offset,
            Size: int64(len(chunk.Data)),
        }); err != nil {
            return fmt.Errorf("callback error: %w", err)
            return "", fmt.Errorf("callback error: %w", err)
        }

        offset += int64(len(chunk.Data))
    }

    return nil
    // Return the full file hash
    return hex.EncodeToString(fileHasher.Sum(nil)), nil
}
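
// A usage sketch tying the streaming chunker to a blob packer (the chunker and
// packer variables are assumed; the size-limit retry shown earlier is omitted):
f, err := os.Open(path)
if err != nil {
    return err
}
defer f.Close()
fileHash, err := chunker.ChunkReaderStreaming(f, func(c Chunk) error {
    return packer.AddChunk(&blob.ChunkRef{Hash: c.Hash, Data: c.Data})
})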

// ChunkFile splits a file into content-defined chunks
// ChunkFile splits a file into content-defined chunks by reading the entire file.
// This is a convenience method that opens the file and passes it to ChunkReader.
// For large files, consider using ChunkReaderStreaming with a file handle instead.
// Returns an error if the file cannot be opened or if chunking fails.
func (c *Chunker) ChunkFile(path string) ([]Chunk, error) {
    file, err := os.Open(path)
    if err != nil {

@@ -15,7 +15,9 @@ import (
    "go.uber.org/fx"
)

// AppOptions contains common options for creating the fx application
// AppOptions contains common options for creating the fx application.
// It includes the configuration file path, logging options, and additional
// fx modules and invocations that should be included in the application.
type AppOptions struct {
    ConfigPath string
    LogOptions log.LogOptions
@@ -27,13 +29,16 @@ type AppOptions struct {
func setupGlobals(lc fx.Lifecycle, g *globals.Globals) {
    lc.Append(fx.Hook{
        OnStart: func(ctx context.Context) error {
            g.StartTime = time.Now()
            g.StartTime = time.Now().UTC()
            return nil
        },
    })
}

// NewApp creates a new fx application with common modules
// NewApp creates a new fx application with common modules.
// It sets up the base modules (config, database, logging, globals) and
// combines them with any additional modules specified in the options.
// The returned fx.App is ready to be started with RunApp.
func NewApp(opts AppOptions) *fx.App {
    baseModules := []fx.Option{
        fx.Supply(config.ConfigPath(opts.ConfigPath)),
@@ -53,7 +58,10 @@ func NewApp(opts AppOptions) *fx.App {
    return fx.New(allOptions...)
}

// RunApp starts and stops the fx application within the given context
// RunApp starts and stops the fx application within the given context.
// It handles graceful shutdown on interrupt signals (SIGINT, SIGTERM) and
// ensures the application stops cleanly. The function blocks until the
// application completes or is interrupted. Returns an error if startup fails.
func RunApp(ctx context.Context, app *fx.App) error {
    // Set up signal handling for graceful shutdown
    sigChan := make(chan os.Signal, 1)
@@ -101,7 +109,9 @@ func RunApp(ctx context.Context, app *fx.App) error {
    }
}

// RunWithApp is a helper that creates and runs an fx app with the given options
// RunWithApp is a helper that creates and runs an fx app with the given options.
// It combines NewApp and RunApp into a single convenient function. This is the
// preferred way to run CLI commands that need the full application context.
func RunWithApp(ctx context.Context, opts AppOptions) error {
    app := NewApp(opts)
    return RunApp(ctx, app)

@@ -1,287 +0,0 @@
package cli

import (
    "context"
    "fmt"
    "os"
    "path/filepath"

    "git.eeqj.de/sneak/vaultik/internal/backup"
    "git.eeqj.de/sneak/vaultik/internal/config"
    "git.eeqj.de/sneak/vaultik/internal/crypto"
    "git.eeqj.de/sneak/vaultik/internal/database"
    "git.eeqj.de/sneak/vaultik/internal/globals"
    "git.eeqj.de/sneak/vaultik/internal/log"
    "git.eeqj.de/sneak/vaultik/internal/s3"
    "github.com/spf13/cobra"
    "go.uber.org/fx"
)

// BackupOptions contains options for the backup command
type BackupOptions struct {
    ConfigPath string
    Daemon bool
    Cron bool
    Prune bool
}

// BackupApp contains all dependencies needed for running backups
type BackupApp struct {
    Globals *globals.Globals
    Config *config.Config
    Repositories *database.Repositories
    ScannerFactory backup.ScannerFactory
    S3Client *s3.Client
    DB *database.DB
    Lifecycle fx.Lifecycle
    Shutdowner fx.Shutdowner
}

// NewBackupCommand creates the backup command
func NewBackupCommand() *cobra.Command {
    opts := &BackupOptions{}

    cmd := &cobra.Command{
        Use: "backup",
        Short: "Perform incremental backup",
        Long: `Backup configured directories using incremental deduplication and encryption.

Config is located at /etc/vaultik/config.yml, but can be overridden by specifying
a path using --config or by setting VAULTIK_CONFIG to a path.`,
        Args: cobra.NoArgs,
        RunE: func(cmd *cobra.Command, args []string) error {
            // If --config not specified, check environment variable
            if opts.ConfigPath == "" {
                opts.ConfigPath = os.Getenv("VAULTIK_CONFIG")
            }
            // If still not specified, use default
            if opts.ConfigPath == "" {
                defaultConfig := "/etc/vaultik/config.yml"
                if _, err := os.Stat(defaultConfig); err == nil {
                    opts.ConfigPath = defaultConfig
                } else {
                    return fmt.Errorf("no config file specified, VAULTIK_CONFIG not set, and %s not found", defaultConfig)
                }
            }
            return runBackup(cmd.Context(), opts)
        },
    }

    cmd.Flags().StringVar(&opts.ConfigPath, "config", "", "Path to config file")
    cmd.Flags().BoolVar(&opts.Daemon, "daemon", false, "Run in daemon mode with inotify monitoring")
    cmd.Flags().BoolVar(&opts.Cron, "cron", false, "Run in cron mode (silent unless error)")
    cmd.Flags().BoolVar(&opts.Prune, "prune", false, "Delete all previous snapshots and unreferenced blobs after backup")

    return cmd
}

func runBackup(ctx context.Context, opts *BackupOptions) error {
    rootFlags := GetRootFlags()
    return RunWithApp(ctx, AppOptions{
        ConfigPath: opts.ConfigPath,
        LogOptions: log.LogOptions{
            Verbose: rootFlags.Verbose,
            Debug: rootFlags.Debug,
            Cron: opts.Cron,
        },
        Modules: []fx.Option{
            backup.Module,
            s3.Module,
            fx.Provide(fx.Annotate(
                func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
                    scannerFactory backup.ScannerFactory, s3Client *s3.Client, db *database.DB,
                    lc fx.Lifecycle, shutdowner fx.Shutdowner) *BackupApp {
                    return &BackupApp{
                        Globals: g,
                        Config: cfg,
                        Repositories: repos,
                        ScannerFactory: scannerFactory,
                        S3Client: s3Client,
                        DB: db,
                        Lifecycle: lc,
                        Shutdowner: shutdowner,
                    }
                },
            )),
        },
        Invokes: []fx.Option{
            fx.Invoke(func(app *BackupApp, lc fx.Lifecycle) {
                // Create a cancellable context for the backup
                backupCtx, backupCancel := context.WithCancel(context.Background())

                lc.Append(fx.Hook{
                    OnStart: func(ctx context.Context) error {
                        // Start the backup in a goroutine
                        go func() {
                            // Run the backup
                            if err := app.runBackup(backupCtx, opts); err != nil {
                                if err != context.Canceled {
                                    log.Error("Backup failed", "error", err)
                                }
                            }

                            // Shutdown the app when backup completes
                            if err := app.Shutdowner.Shutdown(); err != nil {
                                log.Error("Failed to shutdown", "error", err)
                            }
                        }()
                        return nil
                    },
                    OnStop: func(ctx context.Context) error {
                        log.Debug("Stopping backup")
                        // Cancel the backup context
                        backupCancel()
                        return nil
                    },
                })
            }),
        },
    })
}

// runBackup executes the backup operation
func (app *BackupApp) runBackup(ctx context.Context, opts *BackupOptions) error {
    log.Info("Starting backup",
        "config", opts.ConfigPath,
        "version", app.Globals.Version,
        "commit", app.Globals.Commit,
        "index_path", app.Config.IndexPath,
    )

    if opts.Daemon {
        log.Info("Running in daemon mode")
        // TODO: Implement daemon mode with inotify
        return fmt.Errorf("daemon mode not yet implemented")
    }

    // Resolve source directories to absolute paths
    resolvedDirs := make([]string, 0, len(app.Config.SourceDirs))
    for _, dir := range app.Config.SourceDirs {
        absPath, err := filepath.Abs(dir)
        if err != nil {
            return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
        }

        // Resolve symlinks
        resolvedPath, err := filepath.EvalSymlinks(absPath)
        if err != nil {
            // If the path doesn't exist yet, use the absolute path
            if os.IsNotExist(err) {
                resolvedPath = absPath
            } else {
                return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err)
            }
        }

        resolvedDirs = append(resolvedDirs, resolvedPath)
    }

    // Create scanner with progress enabled (unless in cron mode)
    scanner := app.ScannerFactory(backup.ScannerParams{
        EnableProgress: !opts.Cron,
    })

    // Perform a single backup run
    log.Notice("Starting backup", "source_dirs", len(resolvedDirs))
    for i, dir := range resolvedDirs {
        log.Info("Source directory", "index", i+1, "path", dir)
    }

    totalFiles := 0
    totalBytes := int64(0)
    totalChunks := 0
    totalBlobs := 0

    // Create a new snapshot at the beginning of backup
    hostname := app.Config.Hostname
    if hostname == "" {
        hostname, _ = os.Hostname()
    }

    // Create encryptor if age recipients are configured
    var encryptor backup.Encryptor
    if len(app.Config.AgeRecipients) > 0 {
        cryptoEncryptor, err := crypto.NewEncryptor(app.Config.AgeRecipients)
        if err != nil {
            return fmt.Errorf("creating encryptor: %w", err)
        }
        encryptor = cryptoEncryptor
    }

    snapshotManager := backup.NewSnapshotManager(app.Repositories, app.S3Client, encryptor)
    snapshotID, err := snapshotManager.CreateSnapshot(ctx, hostname, app.Globals.Version)
    if err != nil {
        return fmt.Errorf("creating snapshot: %w", err)
    }
    log.Info("Created snapshot", "snapshot_id", snapshotID)

    for _, dir := range resolvedDirs {
        // Check if context is cancelled
        select {
        case <-ctx.Done():
            log.Info("Backup cancelled")
            return ctx.Err()
        default:
        }

        log.Info("Scanning directory", "path", dir)
        result, err := scanner.Scan(ctx, dir, snapshotID)
        if err != nil {
            return fmt.Errorf("failed to scan %s: %w", dir, err)
        }

        totalFiles += result.FilesScanned
        totalBytes += result.BytesScanned
        totalChunks += result.ChunksCreated
        totalBlobs += result.BlobsCreated

        log.Info("Directory scan complete",
            "path", dir,
            "files", result.FilesScanned,
            "files_skipped", result.FilesSkipped,
            "bytes", result.BytesScanned,
            "bytes_skipped", result.BytesSkipped,
            "chunks", result.ChunksCreated,
            "blobs", result.BlobsCreated,
            "duration", result.EndTime.Sub(result.StartTime))
    }

    // Update snapshot statistics
    stats := backup.BackupStats{
        FilesScanned: totalFiles,
        BytesScanned: totalBytes,
        ChunksCreated: totalChunks,
        BlobsCreated: totalBlobs,
        BytesUploaded: totalBytes, // TODO: Track actual uploaded bytes
    }

    if err := snapshotManager.UpdateSnapshotStats(ctx, snapshotID, stats); err != nil {
        return fmt.Errorf("updating snapshot stats: %w", err)
    }

    // Mark snapshot as complete
    if err := snapshotManager.CompleteSnapshot(ctx, snapshotID); err != nil {
        return fmt.Errorf("completing snapshot: %w", err)
    }

    // Export snapshot metadata
    // Export snapshot metadata without closing the database
    // The export function should handle its own database connection
    if err := snapshotManager.ExportSnapshotMetadata(ctx, app.Config.IndexPath, snapshotID); err != nil {
        return fmt.Errorf("exporting snapshot metadata: %w", err)
    }

    log.Notice("Backup complete",
        "snapshot_id", snapshotID,
        "total_files", totalFiles,
        "total_bytes", totalBytes,
        "total_chunks", totalChunks,
        "total_blobs", totalBlobs)

    if opts.Prune {
        log.Info("Pruning enabled - will delete old snapshots after backup")
        // TODO: Implement pruning
    }

    return nil
}

internal/cli/duration.go (new file, +94 lines)
@@ -0,0 +1,94 @@
package cli

import (
    "fmt"
    "regexp"
    "strconv"
    "strings"
    "time"
)

// parseDuration parses duration strings. Supports standard Go duration format
// (e.g., "3h30m", "1h45m30s") as well as extended units:
//   - d: days (e.g., "30d", "7d")
//   - w: weeks (e.g., "2w", "4w")
//   - mo: months (30 days) (e.g., "6mo", "1mo")
//   - y: years (365 days) (e.g., "1y", "2y")
//
// Can combine units: "1y6mo", "2w3d", "1d12h30m"
func parseDuration(s string) (time.Duration, error) {
    // First try standard Go duration parsing
    if d, err := time.ParseDuration(s); err == nil {
        return d, nil
    }

    // Extended duration parsing
    // Check for negative values
    if strings.HasPrefix(strings.TrimSpace(s), "-") {
        return 0, fmt.Errorf("negative durations are not supported")
    }

    // Pattern matches: number + unit, repeated
    re := regexp.MustCompile(`(\d+(?:\.\d+)?)\s*([a-zA-Z]+)`)
    matches := re.FindAllStringSubmatch(s, -1)

    if len(matches) == 0 {
        return 0, fmt.Errorf("invalid duration format: %q", s)
    }

    var total time.Duration

    for _, match := range matches {
        valueStr := match[1]
        unit := strings.ToLower(match[2])

        value, err := strconv.ParseFloat(valueStr, 64)
        if err != nil {
            return 0, fmt.Errorf("invalid number %q: %w", valueStr, err)
        }

        var d time.Duration
        switch unit {
        // Standard time units
        case "ns", "nanosecond", "nanoseconds":
            d = time.Duration(value)
        case "us", "µs", "microsecond", "microseconds":
            d = time.Duration(value * float64(time.Microsecond))
        case "ms", "millisecond", "milliseconds":
            d = time.Duration(value * float64(time.Millisecond))
        case "s", "sec", "second", "seconds":
            d = time.Duration(value * float64(time.Second))
        case "m", "min", "minute", "minutes":
            d = time.Duration(value * float64(time.Minute))
        case "h", "hr", "hour", "hours":
            d = time.Duration(value * float64(time.Hour))
        // Extended units
        case "d", "day", "days":
            d = time.Duration(value * float64(24*time.Hour))
        case "w", "week", "weeks":
            d = time.Duration(value * float64(7*24*time.Hour))
        case "mo", "month", "months":
            // Using 30 days as approximation
            d = time.Duration(value * float64(30*24*time.Hour))
        case "y", "year", "years":
            // Using 365 days as approximation
            d = time.Duration(value * float64(365*24*time.Hour))
        default:
            // Try parsing as standard Go duration unit
            testStr := fmt.Sprintf("1%s", unit)
            if _, err := time.ParseDuration(testStr); err == nil {
                // It's a valid Go duration unit, parse the full value
                fullStr := fmt.Sprintf("%g%s", value, unit)
                if d, err = time.ParseDuration(fullStr); err != nil {
                    return 0, fmt.Errorf("invalid duration %q: %w", fullStr, err)
                }
            } else {
                return 0, fmt.Errorf("unknown time unit %q", unit)
            }
        }

        total += d
    }

    return total, nil
}
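
// A quick worked example using the approximations above:
//   parseDuration("1y6mo") = 365*24h + 6*30*24h = 8760h + 4320h = 13080h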

internal/cli/duration_test.go (new file, +263 lines)
@@ -0,0 +1,263 @@
package cli

import (
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
)

func TestParseDuration(t *testing.T) {
    tests := []struct {
        name string
        input string
        expected time.Duration
        wantErr bool
    }{
        // Standard Go durations
        {
            name: "standard seconds",
            input: "30s",
            expected: 30 * time.Second,
        },
        {
            name: "standard minutes",
            input: "45m",
            expected: 45 * time.Minute,
        },
        {
            name: "standard hours",
            input: "2h",
            expected: 2 * time.Hour,
        },
        {
            name: "standard combined",
            input: "3h30m",
            expected: 3*time.Hour + 30*time.Minute,
        },
        {
            name: "standard complex",
            input: "1h45m30s",
            expected: 1*time.Hour + 45*time.Minute + 30*time.Second,
        },
        {
            name: "standard with milliseconds",
            input: "1s500ms",
            expected: 1*time.Second + 500*time.Millisecond,
        },
        // Extended units - days
        {
            name: "single day",
            input: "1d",
            expected: 24 * time.Hour,
        },
        {
            name: "multiple days",
            input: "7d",
            expected: 7 * 24 * time.Hour,
        },
        {
            name: "fractional days",
            input: "1.5d",
            expected: 36 * time.Hour,
        },
        {
            name: "days spelled out",
            input: "3days",
            expected: 3 * 24 * time.Hour,
        },
        // Extended units - weeks
        {
            name: "single week",
            input: "1w",
            expected: 7 * 24 * time.Hour,
        },
        {
            name: "multiple weeks",
            input: "4w",
            expected: 4 * 7 * 24 * time.Hour,
        },
        {
            name: "weeks spelled out",
            input: "2weeks",
            expected: 2 * 7 * 24 * time.Hour,
        },
        // Extended units - months
        {
            name: "single month",
            input: "1mo",
            expected: 30 * 24 * time.Hour,
        },
        {
            name: "multiple months",
            input: "6mo",
            expected: 6 * 30 * 24 * time.Hour,
        },
        {
            name: "months spelled out",
            input: "3months",
            expected: 3 * 30 * 24 * time.Hour,
        },
        // Extended units - years
        {
            name: "single year",
            input: "1y",
            expected: 365 * 24 * time.Hour,
        },
        {
            name: "multiple years",
            input: "2y",
            expected: 2 * 365 * 24 * time.Hour,
        },
        {
            name: "years spelled out",
            input: "1year",
            expected: 365 * 24 * time.Hour,
        },
        // Combined extended units
        {
            name: "weeks and days",
            input: "2w3d",
            expected: 2*7*24*time.Hour + 3*24*time.Hour,
        },
        {
            name: "years and months",
            input: "1y6mo",
            expected: 365*24*time.Hour + 6*30*24*time.Hour,
        },
        {
            name: "days and hours",
            input: "1d12h",
            expected: 24*time.Hour + 12*time.Hour,
        },
        {
            name: "complex combination",
            input: "1y2mo3w4d5h6m7s",
            expected: 365*24*time.Hour + 2*30*24*time.Hour + 3*7*24*time.Hour + 4*24*time.Hour + 5*time.Hour + 6*time.Minute + 7*time.Second,
        },
        {
            name: "with spaces",
            input: "1d 12h 30m",
            expected: 24*time.Hour + 12*time.Hour + 30*time.Minute,
        },
        // Edge cases
        {
            name: "zero duration",
            input: "0s",
            expected: 0,
        },
        {
            name: "large duration",
            input: "10y",
            expected: 10 * 365 * 24 * time.Hour,
        },
        // Error cases
        {
            name: "empty string",
            input: "",
            wantErr: true,
        },
        {
            name: "invalid format",
            input: "abc",
            wantErr: true,
        },
        {
            name: "unknown unit",
            input: "5x",
            wantErr: true,
        },
        {
            name: "invalid number",
            input: "xyzd",
            wantErr: true,
        },
        {
            name: "negative not supported",
            input: "-5d",
            wantErr: true,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := parseDuration(tt.input)

            if tt.wantErr {
                assert.Error(t, err, "expected error for input %q", tt.input)
                return
            }

            assert.NoError(t, err, "unexpected error for input %q", tt.input)
            assert.Equal(t, tt.expected, got, "duration mismatch for input %q", tt.input)
        })
    }
}

func TestParseDurationSpecialCases(t *testing.T) {
    // Test that standard Go durations work exactly as expected
    standardDurations := []string{
        "300ms",
        "1.5h",
        "2h45m",
        "72h",
        "1us",
        "1µs",
        "1ns",
    }

    for _, d := range standardDurations {
        expected, err := time.ParseDuration(d)
        assert.NoError(t, err)

        got, err := parseDuration(d)
        assert.NoError(t, err)
        assert.Equal(t, expected, got, "standard duration %q should parse identically", d)
    }
}

func TestParseDurationRealWorldExamples(t *testing.T) {
    // Test real-world snapshot purge scenarios
    tests := []struct {
        description string
        input string
        olderThan time.Duration
    }{
        {
            description: "keep snapshots from last 30 days",
            input: "30d",
            olderThan: 30 * 24 * time.Hour,
        },
        {
            description: "keep snapshots from last 6 months",
            input: "6mo",
            olderThan: 6 * 30 * 24 * time.Hour,
        },
        {
            description: "keep snapshots from last year",
            input: "1y",
            olderThan: 365 * 24 * time.Hour,
        },
        {
            description: "keep snapshots from last week and a half",
            input: "1w3d",
            olderThan: 10 * 24 * time.Hour,
        },
        {
            description: "keep snapshots from last 90 days",
            input: "90d",
            olderThan: 90 * 24 * time.Hour,
        },
    }

    for _, tt := range tests {
        t.Run(tt.description, func(t *testing.T) {
            got, err := parseDuration(tt.input)
            assert.NoError(t, err)
            assert.Equal(t, tt.olderThan, got)

            // Verify the duration makes sense for snapshot purging
            assert.Greater(t, got, time.Hour, "snapshot purge duration should be at least an hour")
        })
    }
}
@@ -4,7 +4,9 @@ import (
    "os"
)

// CLIEntry is the main entry point for the CLI application
// CLIEntry is the main entry point for the CLI application.
// It creates the root command, executes it, and exits with status 1
// if an error occurs. This function should be called from main().
func CLIEntry() {
    rootCmd := NewRootCommand()
    if err := rootCmd.Execute(); err != nil {

@@ -18,7 +18,7 @@ func TestCLIEntry(t *testing.T) {
    }

    // Verify all subcommands are registered
    expectedCommands := []string{"backup", "restore", "prune", "verify", "fetch"}
    expectedCommands := []string{"snapshot", "store", "restore", "prune", "verify", "fetch"}
    for _, expected := range expectedCommands {
        found := false
        for _, cmd := range cmd.Commands() {
@@ -32,19 +32,24 @@ func TestCLIEntry(t *testing.T) {
        }
    }

    // Verify backup command has proper flags
    backupCmd, _, err := cmd.Find([]string{"backup"})
    // Verify snapshot command has subcommands
    snapshotCmd, _, err := cmd.Find([]string{"snapshot"})
    if err != nil {
        t.Errorf("Failed to find backup command: %v", err)
        t.Errorf("Failed to find snapshot command: %v", err)
    } else {
        if backupCmd.Flag("config") == nil {
            t.Error("Backup command missing --config flag")
        }
        if backupCmd.Flag("daemon") == nil {
            t.Error("Backup command missing --daemon flag")
        }
        if backupCmd.Flag("cron") == nil {
            t.Error("Backup command missing --cron flag")
        // Check snapshot subcommands
        expectedSubCommands := []string{"create", "list", "purge", "verify"}
        for _, expected := range expectedSubCommands {
            found := false
            for _, subcmd := range snapshotCmd.Commands() {
                if subcmd.Use == expected || subcmd.Name() == expected {
                    found = true
                    break
                }
            }
            if !found {
                t.Errorf("Expected snapshot subcommand '%s' not found", expected)
            }
        }
    }
}

@@ -1,18 +1,25 @@
package cli

import (
    "fmt"
    "os"

    "github.com/spf13/cobra"
)

// RootFlags holds global flags
// RootFlags holds global flags that apply to all commands.
// These flags are defined on the root command and inherited by all subcommands.
type RootFlags struct {
    Verbose bool
    Debug bool
    ConfigPath string
    Verbose bool
    Debug bool
}

var rootFlags RootFlags

// NewRootCommand creates the root cobra command
// NewRootCommand creates the root cobra command for the vaultik CLI.
// It sets up the command structure, global flags, and adds all subcommands.
// This is the main entry point for the CLI command hierarchy.
func NewRootCommand() *cobra.Command {
    cmd := &cobra.Command{
        Use: "vaultik",
@@ -24,23 +31,49 @@ on the source system.`,
    }

    // Add global flags
    cmd.PersistentFlags().StringVar(&rootFlags.ConfigPath, "config", "", "Path to config file (default: $VAULTIK_CONFIG or /etc/vaultik/config.yml)")
    cmd.PersistentFlags().BoolVarP(&rootFlags.Verbose, "verbose", "v", false, "Enable verbose output")
    cmd.PersistentFlags().BoolVar(&rootFlags.Debug, "debug", false, "Enable debug output")

    // Add subcommands
    cmd.AddCommand(
        NewBackupCommand(),
        NewRestoreCommand(),
        NewPruneCommand(),
        NewVerifyCommand(),
        NewFetchCommand(),
        SnapshotCmd(),
        NewStoreCommand(),
        NewSnapshotCommand(),
    )

    return cmd
}

// GetRootFlags returns the global flags
// GetRootFlags returns the global flags that were parsed from the command line.
// This allows subcommands to access global flag values like verbosity and config path.
func GetRootFlags() RootFlags {
    return rootFlags
}

// ResolveConfigPath resolves the config file path from flags, environment, or default.
// It checks in order: 1) --config flag, 2) VAULTIK_CONFIG environment variable,
// 3) default location /etc/vaultik/config.yml. Returns an error if no valid
// config file can be found through any of these methods.
func ResolveConfigPath() (string, error) {
    // First check global flag
    if rootFlags.ConfigPath != "" {
        return rootFlags.ConfigPath, nil
    }

    // Then check environment variable
    if envPath := os.Getenv("VAULTIK_CONFIG"); envPath != "" {
        return envPath, nil
    }

    // Finally check default location
    defaultPath := "/etc/vaultik/config.yml"
    if _, err := os.Stat(defaultPath); err == nil {
        return defaultPath, nil
    }

    return "", fmt.Errorf("no config file specified, VAULTIK_CONFIG not set, and %s not found", defaultPath)
}
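
// The precedence is easiest to see by example (hypothetical paths; the flag
// beats the environment, which beats the default):
//   vaultik --config /tmp/a.yml snapshot create        -> /tmp/a.yml
//   VAULTIK_CONFIG=/tmp/b.yml vaultik snapshot create  -> /tmp/b.yml
//   vaultik snapshot create                            -> /etc/vaultik/config.yml (if present)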

@@ -1,90 +1,892 @@
package cli

import (
    "context"
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
    "sort"
    "strings"
    "text/tabwriter"
    "time"

    "git.eeqj.de/sneak/vaultik/internal/backup"
    "git.eeqj.de/sneak/vaultik/internal/config"
    "git.eeqj.de/sneak/vaultik/internal/crypto"
    "git.eeqj.de/sneak/vaultik/internal/database"
    "git.eeqj.de/sneak/vaultik/internal/globals"
    "git.eeqj.de/sneak/vaultik/internal/log"
    "git.eeqj.de/sneak/vaultik/internal/s3"
    "github.com/dustin/go-humanize"
    "github.com/klauspost/compress/zstd"
    "github.com/spf13/cobra"
    "go.uber.org/fx"
)

func SnapshotCmd() *cobra.Command {
// SnapshotCreateOptions contains options for the snapshot create command
type SnapshotCreateOptions struct {
    Daemon bool
    Cron bool
    Prune bool
}

// SnapshotCreateApp contains all dependencies needed for creating snapshots
type SnapshotCreateApp struct {
    Globals *globals.Globals
    Config *config.Config
    Repositories *database.Repositories
    ScannerFactory backup.ScannerFactory
    S3Client *s3.Client
    DB *database.DB
    Lifecycle fx.Lifecycle
    Shutdowner fx.Shutdowner
}

// SnapshotApp contains dependencies for snapshot commands
type SnapshotApp struct {
    *SnapshotCreateApp // Reuse snapshot creation functionality
    S3Client *s3.Client
}

// SnapshotInfo represents snapshot information for listing
type SnapshotInfo struct {
    ID string `json:"id"`
    Timestamp time.Time `json:"timestamp"`
    CompressedSize int64 `json:"compressed_size"`
}

// NewSnapshotCommand creates the snapshot command and subcommands
func NewSnapshotCommand() *cobra.Command {
    cmd := &cobra.Command{
        Use: "snapshot",
        Short: "Manage snapshots",
        Long: "Commands for listing, removing, and querying snapshots",
        Short: "Snapshot management commands",
        Long: "Commands for creating, listing, and managing snapshots",
    }

    cmd.AddCommand(snapshotListCmd())
    cmd.AddCommand(snapshotRmCmd())
    cmd.AddCommand(snapshotLatestCmd())
    // Add subcommands
    cmd.AddCommand(newSnapshotCreateCommand())
    cmd.AddCommand(newSnapshotListCommand())
    cmd.AddCommand(newSnapshotPurgeCommand())
    cmd.AddCommand(newSnapshotVerifyCommand())

    return cmd
}

func snapshotListCmd() *cobra.Command {
    var (
        bucket string
        prefix string
        limit int
// newSnapshotCreateCommand creates the 'snapshot create' subcommand
func newSnapshotCreateCommand() *cobra.Command {
    opts := &SnapshotCreateOptions{}

    cmd := &cobra.Command{
        Use: "create",
        Short: "Create a new snapshot",
        Long: `Creates a new snapshot of the configured directories.

Config is located at /etc/vaultik/config.yml by default, but can be overridden by
specifying a path using --config or by setting VAULTIK_CONFIG to a path.`,
        Args: cobra.NoArgs,
        RunE: func(cmd *cobra.Command, args []string) error {
            // Use unified config resolution
            configPath, err := ResolveConfigPath()
            if err != nil {
                return err
            }

            // Use the backup functionality from cli package
            rootFlags := GetRootFlags()
            return RunWithApp(cmd.Context(), AppOptions{
                ConfigPath: configPath,
                LogOptions: log.LogOptions{
                    Verbose: rootFlags.Verbose,
                    Debug: rootFlags.Debug,
                    Cron: opts.Cron,
                },
                Modules: []fx.Option{
                    backup.Module,
                    s3.Module,
                    fx.Provide(fx.Annotate(
                        func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
                            scannerFactory backup.ScannerFactory, s3Client *s3.Client, db *database.DB,
                            lc fx.Lifecycle, shutdowner fx.Shutdowner) *SnapshotCreateApp {
                            return &SnapshotCreateApp{
                                Globals: g,
                                Config: cfg,
                                Repositories: repos,
                                ScannerFactory: scannerFactory,
                                S3Client: s3Client,
                                DB: db,
                                Lifecycle: lc,
                                Shutdowner: shutdowner,
                            }
                        },
                    )),
                },
                Invokes: []fx.Option{
                    fx.Invoke(func(app *SnapshotCreateApp, lc fx.Lifecycle) {
                        // Create a cancellable context for the snapshot
                        snapshotCtx, snapshotCancel := context.WithCancel(context.Background())

                        lc.Append(fx.Hook{
                            OnStart: func(ctx context.Context) error {
                                // Start the snapshot creation in a goroutine
                                go func() {
                                    // Run the snapshot creation
                                    if err := app.runSnapshot(snapshotCtx, opts); err != nil {
                                        if err != context.Canceled {
                                            log.Error("Snapshot creation failed", "error", err)
                                        }
                                    }

                                    // Shutdown the app when snapshot completes
                                    if err := app.Shutdowner.Shutdown(); err != nil {
                                        log.Error("Failed to shutdown", "error", err)
                                    }
                                }()
                                return nil
                            },
                            OnStop: func(ctx context.Context) error {
                                log.Debug("Stopping snapshot creation")
                                // Cancel the snapshot context
                                snapshotCancel()
                                return nil
                            },
                        })
                    }),
                },
            })
        },
    }

    cmd.Flags().BoolVar(&opts.Daemon, "daemon", false, "Run in daemon mode with inotify monitoring")
    cmd.Flags().BoolVar(&opts.Cron, "cron", false, "Run in cron mode (silent unless error)")
    cmd.Flags().BoolVar(&opts.Prune, "prune", false, "Delete all previous snapshots and unreferenced blobs after backup")

    return cmd
}

// runSnapshot executes the snapshot creation operation
func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCreateOptions) error {
    snapshotStartTime := time.Now()

    log.Info("Starting snapshot creation",
        "version", app.Globals.Version,
        "commit", app.Globals.Commit,
        "index_path", app.Config.IndexPath,
    )

    // Clean up incomplete snapshots FIRST, before any scanning
    // This is critical for data safety - see CleanupIncompleteSnapshots for details
    hostname := app.Config.Hostname
    if hostname == "" {
        hostname, _ = os.Hostname()
    }

    // Create encryptor if needed for snapshot manager
    var encryptor backup.Encryptor
    if len(app.Config.AgeRecipients) > 0 {
        cryptoEncryptor, err := crypto.NewEncryptor(app.Config.AgeRecipients)
        if err != nil {
            return fmt.Errorf("creating encryptor: %w", err)
        }
        encryptor = cryptoEncryptor
    }

    snapshotManager := backup.NewSnapshotManager(app.Repositories, app.S3Client, encryptor)
    // CRITICAL: This MUST succeed. If we fail to clean up incomplete snapshots,
    // the deduplication logic will think files from the incomplete snapshot were
    // already backed up and skip them, resulting in data loss.
    if err := snapshotManager.CleanupIncompleteSnapshots(ctx, hostname); err != nil {
        return fmt.Errorf("cleanup incomplete snapshots: %w", err)
    }

    if opts.Daemon {
        log.Info("Running in daemon mode")
        // TODO: Implement daemon mode with inotify
        return fmt.Errorf("daemon mode not yet implemented")
    }

    // Resolve source directories to absolute paths
    resolvedDirs := make([]string, 0, len(app.Config.SourceDirs))
    for _, dir := range app.Config.SourceDirs {
        absPath, err := filepath.Abs(dir)
        if err != nil {
            return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
        }

        // Resolve symlinks
        resolvedPath, err := filepath.EvalSymlinks(absPath)
        if err != nil {
            // If the path doesn't exist yet, use the absolute path
            if os.IsNotExist(err) {
                resolvedPath = absPath
            } else {
                return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err)
            }
        }

        resolvedDirs = append(resolvedDirs, resolvedPath)
    }

    // Create scanner with progress enabled (unless in cron mode)
    scanner := app.ScannerFactory(backup.ScannerParams{
        EnableProgress: !opts.Cron,
    })

    // Perform a single snapshot run
    log.Notice("Starting snapshot", "source_dirs", len(resolvedDirs))
    for i, dir := range resolvedDirs {
        log.Info("Source directory", "index", i+1, "path", dir)
    }

    // Statistics tracking
    totalFiles := 0
    totalBytes := int64(0)
    totalChunks := 0
    totalBlobs := 0
    totalBytesSkipped := int64(0)
    totalFilesSkipped := 0
    totalBytesUploaded := int64(0)
    totalBlobsUploaded := 0
    uploadDuration := time.Duration(0)

    // Create a new snapshot at the beginning
    // (hostname, encryptor, and snapshotManager already created above for cleanup)
    snapshotID, err := snapshotManager.CreateSnapshot(ctx, hostname, app.Globals.Version, app.Globals.Commit)
    if err != nil {
        return fmt.Errorf("creating snapshot: %w", err)
    }
    log.Info("Created snapshot", "snapshot_id", snapshotID)

    for _, dir := range resolvedDirs {
        // Check if context is cancelled
        select {
        case <-ctx.Done():
            log.Info("Snapshot creation cancelled")
            return ctx.Err()
        default:
        }

        log.Info("Scanning directory", "path", dir)
        result, err := scanner.Scan(ctx, dir, snapshotID)
        if err != nil {
            return fmt.Errorf("failed to scan %s: %w", dir, err)
        }

        totalFiles += result.FilesScanned
        totalBytes += result.BytesScanned
        totalChunks += result.ChunksCreated
        totalBlobs += result.BlobsCreated
        totalFilesSkipped += result.FilesSkipped
        totalBytesSkipped += result.BytesSkipped

        log.Info("Directory scan complete",
            "path", dir,
            "files", result.FilesScanned,
            "files_skipped", result.FilesSkipped,
            "bytes", result.BytesScanned,
            "bytes_skipped", result.BytesSkipped,
            "chunks", result.ChunksCreated,
            "blobs", result.BlobsCreated,
            "duration", result.EndTime.Sub(result.StartTime))
    }

    // Get upload statistics from scanner progress if available
    if s := scanner.GetProgress(); s != nil {
        stats := s.GetStats()
        totalBytesUploaded = stats.BytesUploaded.Load()
        totalBlobsUploaded = int(stats.BlobsUploaded.Load())
        uploadDuration = time.Duration(stats.UploadDurationMs.Load()) * time.Millisecond
    }

    // Update snapshot statistics with extended fields
    extStats := backup.ExtendedBackupStats{
        BackupStats: backup.BackupStats{
            FilesScanned: totalFiles,
            BytesScanned: totalBytes,
            ChunksCreated: totalChunks,
            BlobsCreated: totalBlobs,
            BytesUploaded: totalBytesUploaded,
        },
        BlobUncompressedSize: 0, // Will be set from database query below
        CompressionLevel: app.Config.CompressionLevel,
        UploadDurationMs: uploadDuration.Milliseconds(),
    }

    if err := snapshotManager.UpdateSnapshotStatsExtended(ctx, snapshotID, extStats); err != nil {
        return fmt.Errorf("updating snapshot stats: %w", err)
    }

    // Mark snapshot as complete
    if err := snapshotManager.CompleteSnapshot(ctx, snapshotID); err != nil {
        return fmt.Errorf("completing snapshot: %w", err)
    }

    // Export snapshot metadata
    // Export snapshot metadata without closing the database
    // The export function should handle its own database connection
    if err := snapshotManager.ExportSnapshotMetadata(ctx, app.Config.IndexPath, snapshotID); err != nil {
        return fmt.Errorf("exporting snapshot metadata: %w", err)
    }

    // Calculate final statistics
    snapshotDuration := time.Since(snapshotStartTime)
    totalFilesChanged := totalFiles - totalFilesSkipped
    totalBytesChanged := totalBytes
    totalBytesAll := totalBytes + totalBytesSkipped

    // Calculate upload speed
    var avgUploadSpeed string
    if totalBytesUploaded > 0 && uploadDuration > 0 {
        bytesPerSec := float64(totalBytesUploaded) / uploadDuration.Seconds()
        bitsPerSec := bytesPerSec * 8
        if bitsPerSec >= 1e9 {
            avgUploadSpeed = fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9)
        } else if bitsPerSec >= 1e6 {
            avgUploadSpeed = fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6)
        } else if bitsPerSec >= 1e3 {
            avgUploadSpeed = fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3)
        } else {
            avgUploadSpeed = fmt.Sprintf("%.0f bit/s", bitsPerSec)
        }
    } else {
        avgUploadSpeed = "N/A"
    }
|
||||
|
||||
// Get total blob sizes from database
|
||||
totalBlobSizeCompressed := int64(0)
|
||||
totalBlobSizeUncompressed := int64(0)
|
||||
if blobHashes, err := app.Repositories.Snapshots.GetBlobHashes(ctx, snapshotID); err == nil {
|
||||
for _, hash := range blobHashes {
|
||||
if blob, err := app.Repositories.Blobs.GetByHash(ctx, hash); err == nil && blob != nil {
|
||||
totalBlobSizeCompressed += blob.CompressedSize
|
||||
totalBlobSizeUncompressed += blob.UncompressedSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate compression ratio
|
||||
var compressionRatio float64
|
||||
if totalBlobSizeUncompressed > 0 {
|
||||
compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed)
|
||||
} else {
|
||||
compressionRatio = 1.0
|
||||
}
|
||||
|
||||
// Print comprehensive summary
|
||||
log.Notice("=== Snapshot Summary ===")
|
||||
log.Info("Snapshot ID", "id", snapshotID)
|
||||
log.Info("Source files",
|
||||
"total_count", formatNumber(totalFiles),
|
||||
"total_size", humanize.Bytes(uint64(totalBytesAll)))
|
||||
log.Info("Changed files",
|
||||
"count", formatNumber(totalFilesChanged),
|
||||
"size", humanize.Bytes(uint64(totalBytesChanged)))
|
||||
log.Info("Unchanged files",
|
||||
"count", formatNumber(totalFilesSkipped),
|
||||
"size", humanize.Bytes(uint64(totalBytesSkipped)))
|
||||
log.Info("Blob storage",
|
||||
"total_uncompressed", humanize.Bytes(uint64(totalBlobSizeUncompressed)),
|
||||
"total_compressed", humanize.Bytes(uint64(totalBlobSizeCompressed)),
|
||||
"compression_ratio", fmt.Sprintf("%.2fx", compressionRatio),
|
||||
"compression_level", app.Config.CompressionLevel)
|
||||
log.Info("Upload activity",
|
||||
"bytes_uploaded", humanize.Bytes(uint64(totalBytesUploaded)),
|
||||
"blobs_uploaded", totalBlobsUploaded,
|
||||
"upload_time", formatDuration(uploadDuration),
|
||||
"avg_speed", avgUploadSpeed)
|
||||
log.Info("Total time", "duration", formatDuration(snapshotDuration))
|
||||
log.Notice("==========================")
|
||||
|
||||
if opts.Prune {
|
||||
log.Info("Pruning enabled - will delete old snapshots after snapshot")
|
||||
// TODO: Implement pruning
|
||||
}
|
||||
|
||||
return nil
|
||||
}

// newSnapshotListCommand creates the 'snapshot list' subcommand
func newSnapshotListCommand() *cobra.Command {
	var jsonOutput bool

	cmd := &cobra.Command{
		Use:   "list",
		Short: "List all snapshots",
		Long:  "Lists all snapshots with their ID, timestamp, and compressed size",
		RunE: func(cmd *cobra.Command, args []string) error {
			return runSnapshotCommand(cmd.Context(), func(app *SnapshotApp) error {
				return app.List(cmd.Context(), jsonOutput)
			})
		},
	}

	cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output in JSON format")

	return cmd
}

// newSnapshotPurgeCommand creates the 'snapshot purge' subcommand
func newSnapshotPurgeCommand() *cobra.Command {
	var keepLatest bool
	var olderThan string
	var force bool

	cmd := &cobra.Command{
		Use:   "purge",
		Short: "Purge old snapshots",
		Long:  "Removes snapshots based on age or count criteria",
		RunE: func(cmd *cobra.Command, args []string) error {
			// Validate flags
			if !keepLatest && olderThan == "" {
				return fmt.Errorf("must specify either --keep-latest or --older-than")
			}
			if keepLatest && olderThan != "" {
				return fmt.Errorf("cannot specify both --keep-latest and --older-than")
			}

			return runSnapshotCommand(cmd.Context(), func(app *SnapshotApp) error {
				return app.Purge(cmd.Context(), keepLatest, olderThan, force)
			})
		},
	}

	cmd.Flags().BoolVar(&keepLatest, "keep-latest", false, "Keep only the latest snapshot")
	cmd.Flags().StringVar(&olderThan, "older-than", "", "Remove snapshots older than duration (e.g., 30d, 6m, 1y)")
	cmd.Flags().BoolVar(&force, "force", false, "Skip confirmation prompt")

	return cmd
}

// newSnapshotVerifyCommand creates the 'snapshot verify' subcommand
func newSnapshotVerifyCommand() *cobra.Command {
	var deep bool

	cmd := &cobra.Command{
		Use:   "verify <snapshot-id>",
		Short: "Verify snapshot integrity",
		Long:  "Verifies that all blobs referenced in a snapshot exist",
		Args:  cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			return runSnapshotCommand(cmd.Context(), func(app *SnapshotApp) error {
				return app.Verify(cmd.Context(), args[0], deep)
			})
		},
	}

	cmd.Flags().BoolVar(&deep, "deep", false, "Download and verify blob hashes")

	return cmd
}

// List lists all snapshots
func (app *SnapshotApp) List(ctx context.Context, jsonOutput bool) error {
	snapshots, err := app.getSnapshots(ctx)
	if err != nil {
		return err
	}

	// Sort by timestamp (newest first)
	sort.Slice(snapshots, func(i, j int) bool {
		return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
	})

	if jsonOutput {
		// JSON output
		encoder := json.NewEncoder(os.Stdout)
		encoder.SetIndent("", " ")
		return encoder.Encode(snapshots)
	}

	// Table output
	w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
	if _, err := fmt.Fprintln(w, "SNAPSHOT ID\tTIMESTAMP\tCOMPRESSED SIZE"); err != nil {
		return err
	}
	if _, err := fmt.Fprintln(w, "───────────\t─────────\t───────────────"); err != nil {
		return err
	}

	for _, snap := range snapshots {
		if _, err := fmt.Fprintf(w, "%s\t%s\t%s\n",
			snap.ID,
			snap.Timestamp.Format("2006-01-02 15:04:05"),
			formatBytes(snap.CompressedSize)); err != nil {
			return err
		}
	}

	return w.Flush()
}
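
The table branch above renders output along these lines (values are placeholders, not from a real run):

	SNAPSHOT ID               TIMESTAMP             COMPRESSED SIZE
	───────────               ─────────             ───────────────
	myhost-20240115-143052Z   2024-01-15 14:30:52   1.2 GB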

// Purge removes old snapshots based on criteria
func (app *SnapshotApp) Purge(ctx context.Context, keepLatest bool, olderThan string, force bool) error {
	snapshots, err := app.getSnapshots(ctx)
	if err != nil {
		return err
	}

	// Sort by timestamp (newest first)
	sort.Slice(snapshots, func(i, j int) bool {
		return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
	})

	var toDelete []SnapshotInfo

	if keepLatest {
		// Keep only the most recent snapshot
		if len(snapshots) > 1 {
			toDelete = snapshots[1:]
		}
	} else if olderThan != "" {
		// Parse duration
		duration, err := parseDuration(olderThan)
		if err != nil {
			return fmt.Errorf("invalid duration: %w", err)
		}

		cutoff := time.Now().UTC().Add(-duration)
		for _, snap := range snapshots {
			if snap.Timestamp.Before(cutoff) {
				toDelete = append(toDelete, snap)
			}
		}
	}

	if len(toDelete) == 0 {
		fmt.Println("No snapshots to delete")
		return nil
	}

	// Show what will be deleted
	fmt.Printf("The following snapshots will be deleted:\n\n")
	for _, snap := range toDelete {
		fmt.Printf(" %s (%s, %s)\n",
			snap.ID,
			snap.Timestamp.Format("2006-01-02 15:04:05"),
			formatBytes(snap.CompressedSize))
	}

	// Confirm unless --force is used
	if !force {
		fmt.Printf("\nDelete %d snapshot(s)? [y/N] ", len(toDelete))
		var confirm string
		if _, err := fmt.Scanln(&confirm); err != nil {
			// Treat EOF or error as "no"
			fmt.Println("Cancelled")
			return nil
		}
		if strings.ToLower(confirm) != "y" {
			fmt.Println("Cancelled")
			return nil
		}
	} else {
		fmt.Printf("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete))
	}

	// Delete snapshots
	for _, snap := range toDelete {
		log.Info("Deleting snapshot", "id", snap.ID)
		if err := app.deleteSnapshot(ctx, snap.ID); err != nil {
			return fmt.Errorf("deleting snapshot %s: %w", snap.ID, err)
		}
	}

	fmt.Printf("Deleted %d snapshot(s)\n", len(toDelete))

	// TODO: Run blob pruning to clean up unreferenced blobs

	return nil
}

// Verify checks snapshot integrity
func (app *SnapshotApp) Verify(ctx context.Context, snapshotID string, deep bool) error {
	fmt.Printf("Verifying snapshot %s...\n", snapshotID)

	// Download and parse manifest
	manifest, err := app.downloadManifest(ctx, snapshotID)
	if err != nil {
		return fmt.Errorf("downloading manifest: %w", err)
	}

	fmt.Printf("Manifest contains %d blobs\n", len(manifest))

	// Check each blob exists
	missing := 0
	verified := 0

	for _, blobHash := range manifest {
		blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)

		if deep {
			// Download and verify hash
			// TODO: Implement deep verification
			fmt.Printf("Deep verification not yet implemented\n")
			return nil
		}

		// Just check existence
		_, err := app.S3Client.StatObject(ctx, blobPath)
		if err != nil {
			fmt.Printf(" Missing: %s\n", blobHash)
			missing++
		} else {
			verified++
		}
	}

	fmt.Printf("\nVerification complete:\n")
	fmt.Printf(" Verified: %d\n", verified)
	fmt.Printf(" Missing: %d\n", missing)

	if missing > 0 {
		return fmt.Errorf("%d blobs are missing", missing)
	}

	return nil
}

// getSnapshots retrieves all snapshots from S3
func (app *SnapshotApp) getSnapshots(ctx context.Context) ([]SnapshotInfo, error) {
	var snapshots []SnapshotInfo

	// List all objects under metadata/
	objectCh := app.S3Client.ListObjectsStream(ctx, "metadata/", true)

	// Track unique snapshots
	snapshotMap := make(map[string]*SnapshotInfo)

	for object := range objectCh {
		if object.Err != nil {
			return nil, fmt.Errorf("listing objects: %w", object.Err)
		}

		// Extract snapshot ID from paths like metadata/2024-01-15-143052-hostname/manifest.json.zst
		parts := strings.Split(object.Key, "/")
		if len(parts) < 3 || parts[0] != "metadata" {
			continue
		}

		snapshotID := parts[1]
		if snapshotID == "" {
			continue
		}

		// Initialize snapshot info if not seen
		if _, exists := snapshotMap[snapshotID]; !exists {
			timestamp, err := parseSnapshotTimestamp(snapshotID)
			if err != nil {
				log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err)
				continue
			}

			snapshotMap[snapshotID] = &SnapshotInfo{
				ID:             snapshotID,
				Timestamp:      timestamp,
				CompressedSize: 0,
			}
		}
	}

	// For each snapshot, download the manifest and calculate total blob size
	for _, snap := range snapshotMap {
		manifest, err := app.downloadManifest(ctx, snap.ID)
		if err != nil {
			log.Warn("Failed to download manifest", "id", snap.ID, "error", err)
			continue
		}

		// Calculate total size of referenced blobs
		for _, blobHash := range manifest {
			blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)
			info, err := app.S3Client.StatObject(ctx, blobPath)
			if err != nil {
				log.Warn("Failed to stat blob", "blob", blobHash, "error", err)
				continue
			}
			snap.CompressedSize += info.Size
		}

		snapshots = append(snapshots, *snap)
	}

	return snapshots, nil
}

// downloadManifest downloads and parses a snapshot manifest
func (app *SnapshotApp) downloadManifest(ctx context.Context, snapshotID string) ([]string, error) {
	manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)

	reader, err := app.S3Client.GetObject(ctx, manifestPath)
	if err != nil {
		return nil, err
	}
	defer func() { _ = reader.Close() }()

	// Decompress
	zr, err := zstd.NewReader(reader)
	if err != nil {
		return nil, fmt.Errorf("creating zstd reader: %w", err)
	}
	defer zr.Close()

	// Decode JSON
	var manifest []string
	if err := json.NewDecoder(zr).Decode(&manifest); err != nil {
		return nil, fmt.Errorf("decoding manifest: %w", err)
	}

	return manifest, nil
}
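
Per the decoder above, a decompressed manifest is simply a JSON array of blob hashes; illustrative content (hashes are placeholders):

	["3f2a...", "9b41..."]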

// deleteSnapshot removes a snapshot and its metadata
func (app *SnapshotApp) deleteSnapshot(ctx context.Context, snapshotID string) error {
	// List all objects under metadata/{snapshotID}/
	prefix := fmt.Sprintf("metadata/%s/", snapshotID)
	objectCh := app.S3Client.ListObjectsStream(ctx, prefix, true)

	var objectsToDelete []string
	for object := range objectCh {
		if object.Err != nil {
			return fmt.Errorf("listing objects: %w", object.Err)
		}
		objectsToDelete = append(objectsToDelete, object.Key)
	}

	// Delete all objects
	for _, key := range objectsToDelete {
		if err := app.S3Client.RemoveObject(ctx, key); err != nil {
			return fmt.Errorf("removing %s: %w", key, err)
		}
	}

	return nil
}

// parseSnapshotTimestamp extracts the timestamp from a snapshot ID.
// Format: hostname-20240115-143052Z
func parseSnapshotTimestamp(snapshotID string) (time.Time, error) {
	// Find the last hyphen, which separates the date part from the time part
	lastHyphen := strings.LastIndex(snapshotID, "-")
	if lastHyphen == -1 {
		return time.Time{}, fmt.Errorf("invalid snapshot ID format")
	}

	// Find where the hostname ends by looking for the hyphen before the date
	hostnameEnd := strings.LastIndex(snapshotID[:lastHyphen], "-")
	if hostnameEnd == -1 {
		return time.Time{}, fmt.Errorf("invalid snapshot ID format: missing date separator")
	}

	// The full timestamp (YYYYMMDD-HHMMSSZ) is everything after the hostname.
	// Note: the length check must apply to the full timestamp, not just the
	// time portion after the last hyphen, which is only 7 characters.
	fullTimestamp := snapshotID[hostnameEnd+1:]
	if len(fullTimestamp) != len("20060102-150405Z") {
		return time.Time{}, fmt.Errorf("invalid snapshot ID format: timestamp has wrong length")
	}

	// Parse the timestamp with Z suffix
	return time.Parse("20060102-150405Z", fullTimestamp)
}
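
Illustrative behavior of the parser above (hostname and timestamp chosen for the example):

	parseSnapshotTimestamp("myhost-20240115-143052Z")
	// => time.Date(2024, 1, 15, 14, 30, 52, 0, time.UTC), nil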

// parseDuration is now in duration.go
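
For reference, a minimal sketch (not part of this commit) of a duration parser supporting the d/w/m/y suffixes used by --older-than; the function name, the strconv dependency, and the 30-day month / 365-day year approximations are assumptions, not the actual duration.go implementation:

	func parseDurationSketch(s string) (time.Duration, error) {
		if len(s) < 2 {
			return 0, fmt.Errorf("invalid duration: %q", s)
		}
		n, err := strconv.Atoi(s[:len(s)-1]) // requires "strconv"
		if err != nil {
			return time.ParseDuration(s) // fall back to standard Go durations
		}
		const day = 24 * time.Hour
		switch s[len(s)-1] {
		case 'd':
			return time.Duration(n) * day, nil
		case 'w':
			return time.Duration(n) * 7 * day, nil
		case 'm': // months, approximated here as 30 days
			return time.Duration(n) * 30 * day, nil
		case 'y': // years, approximated here as 365 days
			return time.Duration(n) * 365 * day, nil
		default:
			return time.ParseDuration(s)
		}
	}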

// runSnapshotCommand creates the FX app and runs the given function
func runSnapshotCommand(ctx context.Context, fn func(*SnapshotApp) error) error {
	var result error
	rootFlags := GetRootFlags()

	// Use unified config resolution
	configPath, err := ResolveConfigPath()
	if err != nil {
		return err
	}

	err = RunWithApp(ctx, AppOptions{
		ConfigPath: configPath,
		LogOptions: log.LogOptions{
			Verbose: rootFlags.Verbose,
			Debug:   rootFlags.Debug,
		},
		Modules: []fx.Option{
			s3.Module,
			fx.Provide(func(
				g *globals.Globals,
				cfg *config.Config,
				db *database.DB,
				repos *database.Repositories,
				s3Client *s3.Client,
				lc fx.Lifecycle,
				shutdowner fx.Shutdowner,
			) *SnapshotApp {
				snapshotCreateApp := &SnapshotCreateApp{
					Globals:        g,
					Config:         cfg,
					Repositories:   repos,
					ScannerFactory: nil, // Not needed for snapshot commands
					S3Client:       s3Client,
					DB:             db,
					Lifecycle:      lc,
					Shutdowner:     shutdowner,
				}
				return &SnapshotApp{
					SnapshotCreateApp: snapshotCreateApp,
					S3Client:          s3Client,
				}
			}),
		},
		Invokes: []fx.Option{
			fx.Invoke(func(app *SnapshotApp, shutdowner fx.Shutdowner) {
				result = fn(app)
				// Shutdown after command completes
				go func() {
					time.Sleep(100 * time.Millisecond) // Brief delay to ensure clean shutdown
					if err := shutdowner.Shutdown(); err != nil {
						log.Error("Failed to shutdown", "error", err)
					}
				}()
			}),
		},
	})

	if err != nil {
		return err
	}
	return result
}

// formatNumber formats a number with comma separators
func formatNumber(n int) string {
	if n < 1000 {
		return fmt.Sprintf("%d", n)
	}
	return humanize.Comma(int64(n))
}

// formatDuration formats a duration in a human-readable way
func formatDuration(d time.Duration) string {
	if d < time.Second {
		return fmt.Sprintf("%dms", d.Milliseconds())
	}
	if d < time.Minute {
		return fmt.Sprintf("%.1fs", d.Seconds())
	}
	if d < time.Hour {
		mins := int(d.Minutes())
		secs := int(d.Seconds()) % 60
		if secs > 0 {
			return fmt.Sprintf("%dm%ds", mins, secs)
		}
		return fmt.Sprintf("%dm", mins)
	}
	hours := int(d.Hours())
	mins := int(d.Minutes()) % 60
	if mins > 0 {
		return fmt.Sprintf("%dh%dm", hours, mins)
	}
	return fmt.Sprintf("%dh", hours)
}
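
Expected outputs, traced from the branches above (illustrative):

	formatDuration(450*time.Millisecond)      // "450ms"
	formatDuration(90*time.Second)            // "1m30s"
	formatDuration(2*time.Hour+5*time.Minute) // "2h5m"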

internal/cli/store.go (new file, 159 lines)
@@ -0,0 +1,159 @@

package cli

import (
	"context"
	"fmt"
	"strings"
	"time"

	"git.eeqj.de/sneak/vaultik/internal/log"
	"git.eeqj.de/sneak/vaultik/internal/s3"
	"github.com/spf13/cobra"
	"go.uber.org/fx"
)

// StoreApp contains dependencies for store commands
type StoreApp struct {
	S3Client   *s3.Client
	Shutdowner fx.Shutdowner
}

// NewStoreCommand creates the store command and subcommands
func NewStoreCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "store",
		Short: "Storage information commands",
		Long:  "Commands for viewing information about the S3 storage backend",
	}

	// Add subcommands
	cmd.AddCommand(newStoreInfoCommand())

	return cmd
}

// newStoreInfoCommand creates the 'store info' subcommand
func newStoreInfoCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "info",
		Short: "Display storage information",
		Long:  "Shows S3 bucket configuration and storage statistics including snapshots and blobs",
		RunE: func(cmd *cobra.Command, args []string) error {
			return runWithApp(cmd.Context(), func(app *StoreApp) error {
				return app.Info(cmd.Context())
			})
		},
	}
}
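
Given the command tree above, this is invoked as `vaultik store info` (assuming the root command is named vaultik, per the new command structure described in this commit).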

// Info displays storage information
func (app *StoreApp) Info(ctx context.Context) error {
	// Get bucket info
	bucketName := app.S3Client.BucketName()
	endpoint := app.S3Client.Endpoint()

	fmt.Printf("Storage Information\n")
	fmt.Printf("===================\n\n")
	fmt.Printf("S3 Configuration:\n")
	fmt.Printf(" Endpoint: %s\n", endpoint)
	fmt.Printf(" Bucket: %s\n\n", bucketName)

	// Count snapshots by listing the metadata/ prefix
	snapshotCh := app.S3Client.ListObjectsStream(ctx, "metadata/", true)
	snapshotDirs := make(map[string]bool)

	for object := range snapshotCh {
		if object.Err != nil {
			return fmt.Errorf("listing snapshots: %w", object.Err)
		}
		// Extract snapshot ID from path like metadata/2024-01-15-143052-hostname/
		parts := strings.Split(object.Key, "/")
		if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
			snapshotDirs[parts[1]] = true
		}
	}
	snapshotCount := len(snapshotDirs)

	// Count blobs and calculate total size by listing the blobs/ prefix
	blobCount := 0
	var totalSize int64

	blobCh := app.S3Client.ListObjectsStream(ctx, "blobs/", false)
	for object := range blobCh {
		if object.Err != nil {
			return fmt.Errorf("listing blobs: %w", object.Err)
		}
		if !strings.HasSuffix(object.Key, "/") { // Skip directories
			blobCount++
			totalSize += object.Size
		}
	}

	fmt.Printf("Storage Statistics:\n")
	fmt.Printf(" Snapshots: %d\n", snapshotCount)
	fmt.Printf(" Blobs: %d\n", blobCount)
	fmt.Printf(" Total Size: %s\n", formatBytes(totalSize))

	return nil
}

// formatBytes formats bytes into human-readable format
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	div, exp := int64(unit), 0
	for n := bytes / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
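
Sanity-check values for formatBytes, derived from the 1024-based loop above:

	formatBytes(512)        // "512 B"
	formatBytes(1536)       // "1.5 KB"
	formatBytes(1073741824) // "1.0 GB"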

// runWithApp creates the FX app and runs the given function
func runWithApp(ctx context.Context, fn func(*StoreApp) error) error {
	var result error
	rootFlags := GetRootFlags()

	// Use unified config resolution
	configPath, err := ResolveConfigPath()
	if err != nil {
		return err
	}

	err = RunWithApp(ctx, AppOptions{
		ConfigPath: configPath,
		LogOptions: log.LogOptions{
			Verbose: rootFlags.Verbose,
			Debug:   rootFlags.Debug,
		},
		Modules: []fx.Option{
			s3.Module,
			fx.Provide(func(s3Client *s3.Client, shutdowner fx.Shutdowner) *StoreApp {
				return &StoreApp{
					S3Client:   s3Client,
					Shutdowner: shutdowner,
				}
			}),
		},
		Invokes: []fx.Option{
			fx.Invoke(func(app *StoreApp, shutdowner fx.Shutdowner) {
				result = fn(app)
				// Shutdown after command completes
				go func() {
					time.Sleep(100 * time.Millisecond) // Brief delay to ensure clean shutdown
					if err := shutdowner.Shutdown(); err != nil {
						log.Error("Failed to shutdown", "error", err)
					}
				}()
			}),
		},
	})

	if err != nil {
		return err
	}
	return result
}

@@ -9,7 +9,10 @@ import (
	"gopkg.in/yaml.v3"
)

// Config represents the application configuration for Vaultik.
// It defines all settings for backup operations, including source directories,
// encryption recipients, S3 storage configuration, and performance tuning parameters.
// Configuration is typically loaded from a YAML file.
type Config struct {
	AgeRecipients  []string      `yaml:"age_recipients"`
	BackupInterval time.Duration `yaml:"backup_interval"`
@@ -19,14 +22,15 @@ type Config struct {
	FullScanInterval  time.Duration `yaml:"full_scan_interval"`
	Hostname          string        `yaml:"hostname"`
	IndexPath         string        `yaml:"index_path"`
	MinTimeBetweenRun time.Duration `yaml:"min_time_between_run"`
	S3                S3Config      `yaml:"s3"`
	SourceDirs        []string      `yaml:"source_dirs"`
	CompressionLevel  int           `yaml:"compression_level"`
}

// S3Config represents S3 storage configuration for backup storage.
// It supports both AWS S3 and S3-compatible storage services.
// All fields except UseSSL and PartSize are required.
type S3Config struct {
	Endpoint string `yaml:"endpoint"`
	Bucket   string `yaml:"bucket"`
@@ -38,10 +42,14 @@ type S3Config struct {
	PartSize Size `yaml:"part_size"`
}

// ConfigPath wraps the config file path for fx dependency injection.
// This type allows the config file path to be injected as a distinct type
// rather than a plain string, avoiding conflicts with other string dependencies.
type ConfigPath string

// New creates a new Config instance by loading from the specified path.
// This function is used by the fx dependency injection framework.
// Returns an error if the path is empty or if loading fails.
func New(path ConfigPath) (*Config, error) {
	if path == "" {
		return nil, fmt.Errorf("config path not provided")
@@ -55,7 +63,11 @@ func New(path ConfigPath) (*Config, error) {
	return cfg, nil
}

// Load reads and parses the configuration file from the specified path.
// It applies default values for optional fields, performs environment variable
// substitution for certain fields (like IndexPath), and validates the configuration.
// The configuration file should be in YAML format. Returns an error if the file
// cannot be read, parsed, or if validation fails.
func Load(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
@@ -70,7 +82,6 @@ func Load(path string) (*Config, error) {
		FullScanInterval:  24 * time.Hour,
		MinTimeBetweenRun: 15 * time.Minute,
		IndexPath:         "/var/lib/vaultik/index.sqlite",
		CompressionLevel:  3,
	}

@@ -107,7 +118,15 @@ func Load(path string) (*Config, error) {
	return cfg, nil
}

// Validate checks if the configuration is valid and complete.
// It ensures all required fields are present and have valid values:
//   - At least one age recipient must be specified
//   - At least one source directory must be configured
//   - S3 credentials and endpoint must be provided
//   - Chunk size must be at least 1MB
//   - Blob size limit must be at least the chunk size
//   - Compression level must be between 1 and 19
// Returns an error describing the first validation failure encountered.
func (c *Config) Validate() error {
	if len(c.AgeRecipients) == 0 {
		return fmt.Errorf("at least one age_recipient is required")
@@ -148,7 +167,8 @@ func (c *Config) Validate() error {
	return nil
}

// Module exports the config module for fx dependency injection.
// It provides the Config type to other modules in the application.
var Module = fx.Module("config",
	fx.Provide(New),
)

@@ -6,10 +6,14 @@ import (
	"github.com/dustin/go-humanize"
)

// Size represents a byte size that can be specified in configuration files.
// It can unmarshal from both numeric values (interpreted as bytes) and
// human-readable strings like "10MB", "2.5GB", or "1TB".
type Size int64

// UnmarshalYAML implements yaml.Unmarshaler for Size, allowing it to be
// parsed from YAML configuration files. It accepts both numeric values
// (interpreted as bytes) and string values with units (e.g., "10MB").
func (s *Size) UnmarshalYAML(unmarshal func(interface{}) error) error {
	// Try to unmarshal as int64 first
	var intVal int64
@@ -34,12 +38,16 @@ func (s *Size) UnmarshalYAML(unmarshal func(interface{}) error) error {
	return nil
}

// Int64 returns the size as int64 bytes.
// This is useful when the size needs to be passed to APIs that expect
// a numeric byte count.
func (s Size) Int64() int64 {
	return int64(s)
}

// String returns the size as a human-readable string.
// For example, 1048576 bytes would be formatted as "1.0 MB".
// This implements the fmt.Stringer interface.
func (s Size) String() string {
	return humanize.Bytes(uint64(s))
}
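
Both YAML forms below are accepted by the unmarshaler above (field name taken from S3Config; the string branch presumably parses via go-humanize given the import, which treats "MB" as 10^6 bytes):

	part_size: 10485760   # numeric value, interpreted as bytes
	part_size: "10MB"     # human-readable string with a unit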

@@ -9,13 +9,19 @@ import (
	"filippo.io/age"
)

// Encryptor provides thread-safe encryption using the age encryption library.
// It supports encrypting data for multiple recipients simultaneously, allowing
// any of the corresponding private keys to decrypt the data. This is useful
// for backup scenarios where multiple parties should be able to decrypt the data.
type Encryptor struct {
	recipients []age.Recipient
	mu         sync.RWMutex
}

// NewEncryptor creates a new encryptor with the given age public keys.
// Each public key should be a valid age X25519 recipient string (e.g., "age1...")
// At least one recipient must be provided. Returns an error if any of the
// public keys are invalid or if no recipients are specified.
func NewEncryptor(publicKeys []string) (*Encryptor, error) {
	if len(publicKeys) == 0 {
		return nil, fmt.Errorf("at least one recipient is required")
@@ -35,7 +41,10 @@ func NewEncryptor(publicKeys []string) (*Encryptor, error) {
	}, nil
}

// Encrypt encrypts data using age encryption for all configured recipients.
// The encrypted data can be decrypted by any of the corresponding private keys.
// This method is suitable for small to medium amounts of data that fit in memory.
// For large data streams, use EncryptStream or EncryptWriter instead.
func (e *Encryptor) Encrypt(data []byte) ([]byte, error) {
	e.mu.RLock()
	recipients := e.recipients
@@ -62,7 +71,10 @@ func (e *Encryptor) Encrypt(data []byte) ([]byte, error) {
	return buf.Bytes(), nil
}

// EncryptStream encrypts data from reader to writer using age encryption.
// This method is suitable for encrypting large files or streams as it processes
// data in a streaming fashion without loading everything into memory.
// The encrypted data is written directly to the destination writer.
func (e *Encryptor) EncryptStream(dst io.Writer, src io.Reader) error {
	e.mu.RLock()
	recipients := e.recipients
@@ -87,7 +99,11 @@ func (e *Encryptor) EncryptStream(dst io.Writer, src io.Reader) error {
	return nil
}

// EncryptWriter creates a writer that encrypts data written to it.
// All data written to the returned WriteCloser will be encrypted and written
// to the destination writer. The caller must call Close() on the returned
// writer to ensure all encrypted data is properly flushed and finalized.
// This is useful for integrating encryption into existing writer-based pipelines.
func (e *Encryptor) EncryptWriter(dst io.Writer) (io.WriteCloser, error) {
	e.mu.RLock()
	recipients := e.recipients
@@ -102,7 +118,11 @@ func (e *Encryptor) EncryptWriter(dst io.Writer) (io.WriteCloser, error) {
	return w, nil
}

// UpdateRecipients updates the recipients for future encryption operations.
// This method is thread-safe and can be called while other encryption operations
// are in progress. Existing encryption operations will continue with the old
// recipients. At least one recipient must be provided. Returns an error if any
// of the public keys are invalid or if no recipients are specified.
func (e *Encryptor) UpdateRecipients(publicKeys []string) error {
	if len(publicKeys) == 0 {
		return fmt.Errorf("at least one recipient is required")
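
An illustrative use of the documented API above (error handling abbreviated; the recipient string is a placeholder):

	enc, err := crypto.NewEncryptor([]string{"age1..."})
	if err != nil {
		return err
	}
	w, err := enc.EncryptWriter(dst)
	if err != nil {
		return err
	}
	if _, err := io.Copy(w, src); err != nil {
		return err
	}
	return w.Close() // Close flushes and finalizes the encrypted stream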

@@ -24,7 +24,7 @@ func (r *BlobChunkRepository) Create(ctx context.Context, tx *sql.Tx, bc *BlobCh
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, bc.BlobID, bc.ChunkHash, bc.Offset, bc.Length)
	} else {
		_, err = r.db.ExecWithLog(ctx, query, bc.BlobID, bc.ChunkHash, bc.Offset, bc.Length)
	}

	if err != nil {

@@ -2,7 +2,9 @@ package database

import (
	"context"
	"strings"
	"testing"
	"time"
)

func TestBlobChunkRepository(t *testing.T) {
@@ -10,78 +12,112 @@ func TestBlobChunkRepository(t *testing.T) {
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create blob first
	blob := &Blob{
		ID:        "blob1-uuid",
		Hash:      "blob1-hash",
		CreatedTS: time.Now(),
	}
	err := repos.Blobs.Create(ctx, nil, blob)
	if err != nil {
		t.Fatalf("failed to create blob: %v", err)
	}

	// Create chunks
	chunks := []string{"chunk1", "chunk2", "chunk3"}
	for _, chunkHash := range chunks {
		chunk := &Chunk{
			ChunkHash: chunkHash,
			SHA256:    chunkHash + "-sha",
			Size:      1024,
		}
		err = repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Fatalf("failed to create chunk %s: %v", chunkHash, err)
		}
	}

	// Test Create
	bc1 := &BlobChunk{
		BlobID:    blob.ID,
		ChunkHash: "chunk1",
		Offset:    0,
		Length:    1024,
	}

	err = repos.BlobChunks.Create(ctx, nil, bc1)
	if err != nil {
		t.Fatalf("failed to create blob chunk: %v", err)
	}

	// Add more chunks to the same blob
	bc2 := &BlobChunk{
		BlobID:    blob.ID,
		ChunkHash: "chunk2",
		Offset:    1024,
		Length:    2048,
	}
	err = repos.BlobChunks.Create(ctx, nil, bc2)
	if err != nil {
		t.Fatalf("failed to create second blob chunk: %v", err)
	}

	bc3 := &BlobChunk{
		BlobID:    blob.ID,
		ChunkHash: "chunk3",
		Offset:    3072,
		Length:    512,
	}
	err = repos.BlobChunks.Create(ctx, nil, bc3)
	if err != nil {
		t.Fatalf("failed to create third blob chunk: %v", err)
	}

	// Test GetByBlobID
	blobChunks, err := repos.BlobChunks.GetByBlobID(ctx, blob.ID)
	if err != nil {
		t.Fatalf("failed to get blob chunks: %v", err)
	}
	if len(blobChunks) != 3 {
		t.Errorf("expected 3 chunks, got %d", len(blobChunks))
	}

	// Verify order by offset
	expectedOffsets := []int64{0, 1024, 3072}
	for i, bc := range blobChunks {
		if bc.Offset != expectedOffsets[i] {
			t.Errorf("wrong chunk order: expected offset %d, got %d", expectedOffsets[i], bc.Offset)
		}
	}

	// Test GetByChunkHash
	bc, err := repos.BlobChunks.GetByChunkHash(ctx, "chunk2")
	if err != nil {
		t.Fatalf("failed to get blob chunk by chunk hash: %v", err)
	}
	if bc == nil {
		t.Fatal("expected blob chunk, got nil")
	}
	if bc.BlobID != blob.ID {
		t.Errorf("wrong blob ID: expected %s, got %s", blob.ID, bc.BlobID)
	}
	if bc.Offset != 1024 {
		t.Errorf("wrong offset: expected 1024, got %d", bc.Offset)
	}

	// Test duplicate insert (should fail due to primary key constraint)
	err = repos.BlobChunks.Create(ctx, nil, bc1)
	if err == nil {
		t.Fatal("duplicate blob_chunk insert should fail due to primary key constraint")
	}
	if !strings.Contains(err.Error(), "UNIQUE") && !strings.Contains(err.Error(), "constraint") {
		t.Fatalf("expected constraint error, got: %v", err)
	}

	// Test non-existent chunk
	bc, err = repos.BlobChunks.GetByChunkHash(ctx, "nonexistent")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
@@ -95,26 +131,61 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create blobs
	blob1 := &Blob{
		ID:        "blob1-uuid",
		Hash:      "blob1-hash",
		CreatedTS: time.Now(),
	}
	blob2 := &Blob{
		ID:        "blob2-uuid",
		Hash:      "blob2-hash",
		CreatedTS: time.Now(),
	}

	err := repos.Blobs.Create(ctx, nil, blob1)
	if err != nil {
		t.Fatalf("failed to create blob1: %v", err)
	}
	err = repos.Blobs.Create(ctx, nil, blob2)
	if err != nil {
		t.Fatalf("failed to create blob2: %v", err)
	}

	// Create chunks
	chunkHashes := []string{"chunk1", "chunk2", "chunk3"}
	for _, chunkHash := range chunkHashes {
		chunk := &Chunk{
			ChunkHash: chunkHash,
			SHA256:    chunkHash + "-sha",
			Size:      1024,
		}
		err = repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Fatalf("failed to create chunk %s: %v", chunkHash, err)
		}
	}

	// Create chunks across multiple blobs.
	// Some chunks are shared between blobs (deduplication scenario).
	blobChunks := []BlobChunk{
		{BlobID: blob1.ID, ChunkHash: "chunk1", Offset: 0, Length: 1024},
		{BlobID: blob1.ID, ChunkHash: "chunk2", Offset: 1024, Length: 1024},
		{BlobID: blob2.ID, ChunkHash: "chunk2", Offset: 0, Length: 1024}, // chunk2 is shared
		{BlobID: blob2.ID, ChunkHash: "chunk3", Offset: 1024, Length: 1024},
	}

	for _, bc := range blobChunks {
		err := repos.BlobChunks.Create(ctx, nil, &bc)
		if err != nil {
			t.Fatalf("failed to create blob chunk: %v", err)
		}
	}

	// Verify blob1 chunks
	chunks, err := repos.BlobChunks.GetByBlobID(ctx, blob1.ID)
	if err != nil {
		t.Fatalf("failed to get blob1 chunks: %v", err)
	}
@@ -123,7 +194,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
	}

	// Verify blob2 chunks
	chunks, err = repos.BlobChunks.GetByBlobID(ctx, blob2.ID)
	if err != nil {
		t.Fatalf("failed to get blob2 chunks: %v", err)
	}
@@ -132,7 +203,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
	}

	// Verify shared chunk
	bc, err := repos.BlobChunks.GetByChunkHash(ctx, "chunk2")
	if err != nil {
		t.Fatalf("failed to get shared chunk: %v", err)
	}
@@ -140,7 +211,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
		t.Fatal("expected shared chunk, got nil")
	}
	// GetByChunkHash returns first match, should be blob1
	if bc.BlobID != blob1.ID {
		t.Errorf("expected %s for shared chunk, got %s", blob1.ID, bc.BlobID)
	}
}
@@ -5,6 +5,8 @@ import (
	"database/sql"
	"fmt"
	"time"

	"git.eeqj.de/sneak/vaultik/internal/log"
)

type BlobRepository struct {
@@ -36,7 +38,7 @@ func (r *BlobRepository) Create(ctx context.Context, tx *sql.Tx, blob *Blob) err
		_, err = tx.ExecContext(ctx, query, blob.ID, blob.Hash, blob.CreatedTS.Unix(),
			finishedTS, blob.UncompressedSize, blob.CompressedSize, uploadedTS)
	} else {
		_, err = r.db.ExecWithLog(ctx, query, blob.ID, blob.Hash, blob.CreatedTS.Unix(),
			finishedTS, blob.UncompressedSize, blob.CompressedSize, uploadedTS)
	}

@@ -75,13 +77,13 @@ func (r *BlobRepository) GetByHash(ctx context.Context, hash string) (*Blob, err
		return nil, fmt.Errorf("querying blob: %w", err)
	}

	blob.CreatedTS = time.Unix(createdTSUnix, 0).UTC()
	if finishedTSUnix.Valid {
		ts := time.Unix(finishedTSUnix.Int64, 0).UTC()
		blob.FinishedTS = &ts
	}
	if uploadedTSUnix.Valid {
		ts := time.Unix(uploadedTSUnix.Int64, 0).UTC()
		blob.UploadedTS = &ts
	}
	return &blob, nil
@@ -116,13 +118,13 @@ func (r *BlobRepository) GetByID(ctx context.Context, id string) (*Blob, error)
		return nil, fmt.Errorf("querying blob: %w", err)
	}

	blob.CreatedTS = time.Unix(createdTSUnix, 0).UTC()
	if finishedTSUnix.Valid {
		ts := time.Unix(finishedTSUnix.Int64, 0).UTC()
		blob.FinishedTS = &ts
	}
	if uploadedTSUnix.Valid {
		ts := time.Unix(uploadedTSUnix.Int64, 0).UTC()
		blob.UploadedTS = &ts
	}
	return &blob, nil
@@ -136,12 +138,12 @@ func (r *BlobRepository) UpdateFinished(ctx context.Context, tx *sql.Tx, id stri
		WHERE id = ?
	`

	now := time.Now().UTC().Unix()
	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, hash, now, uncompressedSize, compressedSize, id)
	} else {
		_, err = r.db.ExecWithLog(ctx, query, hash, now, uncompressedSize, compressedSize, id)
	}

	if err != nil {
@@ -159,12 +161,12 @@ func (r *BlobRepository) UpdateUploaded(ctx context.Context, tx *sql.Tx, id stri
		WHERE id = ?
	`

	now := time.Now().UTC().Unix()
	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, now, id)
	} else {
		_, err = r.db.ExecWithLog(ctx, query, now, id)
	}

	if err != nil {
@@ -173,3 +175,26 @@ func (r *BlobRepository) UpdateUploaded(ctx context.Context, tx *sql.Tx, id stri

	return nil
}

// DeleteOrphaned deletes blobs that are not referenced by any snapshot
func (r *BlobRepository) DeleteOrphaned(ctx context.Context) error {
	query := `
		DELETE FROM blobs
		WHERE NOT EXISTS (
			SELECT 1 FROM snapshot_blobs
			WHERE snapshot_blobs.blob_id = blobs.id
		)
	`

	result, err := r.db.ExecWithLog(ctx, query)
	if err != nil {
		return fmt.Errorf("deleting orphaned blobs: %w", err)
	}

	rowsAffected, _ := result.RowsAffected()
	if rowsAffected > 0 {
		log.Debug("Deleted orphaned blobs", "count", rowsAffected)
	}

	return nil
}
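
An illustrative call site for DeleteOrphaned (not part of this commit; it pairs with the blob-pruning TODO in Purge above, and removes only the database rows; deleting the corresponding S3 objects is a separate step):

	if err := repos.Blobs.DeleteOrphaned(ctx); err != nil {
		return fmt.Errorf("pruning orphaned blobs: %w", err)
	}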

internal/database/cascade_debug_test.go (new file, 124 lines)
@@ -0,0 +1,124 @@

package database

import (
	"context"
	"fmt"
	"testing"
	"time"
)

// TestCascadeDeleteDebug tests cascade delete with debug output
func TestCascadeDeleteDebug(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Check if foreign keys are enabled
	var fkEnabled int
	err := db.conn.QueryRow("PRAGMA foreign_keys").Scan(&fkEnabled)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Foreign keys enabled: %d", fkEnabled)

	// Create a file
	file := &File{
		Path:  "/cascade-test.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}
	err = repos.Files.Create(ctx, nil, file)
	if err != nil {
		t.Fatalf("failed to create file: %v", err)
	}
	t.Logf("Created file with ID: %s", file.ID)

	// Create chunks and file-chunk mappings
	for i := 0; i < 3; i++ {
		chunk := &Chunk{
			ChunkHash: fmt.Sprintf("cascade-chunk-%d", i),
			SHA256:    fmt.Sprintf("cascade-sha-%d", i),
			Size:      1024,
		}
		err = repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Fatalf("failed to create chunk: %v", err)
		}

		fc := &FileChunk{
			FileID:    file.ID,
			Idx:       i,
			ChunkHash: chunk.ChunkHash,
		}
		err = repos.FileChunks.Create(ctx, nil, fc)
		if err != nil {
			t.Fatalf("failed to create file chunk: %v", err)
		}
		t.Logf("Created file chunk mapping: file_id=%s, idx=%d, chunk=%s", fc.FileID, fc.Idx, fc.ChunkHash)
	}

	// Verify file chunks exist
	fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("File chunks before delete: %d", len(fileChunks))

	// Check the foreign key constraint
	var fkInfo string
	err = db.conn.QueryRow(`
		SELECT sql FROM sqlite_master
		WHERE type='table' AND name='file_chunks'
	`).Scan(&fkInfo)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("file_chunks table definition:\n%s", fkInfo)

	// Delete the file
	t.Log("Deleting file...")
	err = repos.Files.DeleteByID(ctx, nil, file.ID)
	if err != nil {
		t.Fatalf("failed to delete file: %v", err)
	}

	// Verify file is gone
	deletedFile, err := repos.Files.GetByID(ctx, file.ID)
	if err != nil {
		t.Fatal(err)
	}
	if deletedFile != nil {
		t.Error("file should have been deleted")
	} else {
		t.Log("File was successfully deleted")
	}

	// Check file chunks after delete
	fileChunks, err = repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("File chunks after delete: %d", len(fileChunks))

	// Manually check the database
	var count int
	err = db.conn.QueryRow("SELECT COUNT(*) FROM file_chunks WHERE file_id = ?", file.ID).Scan(&count)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Manual count of file_chunks for deleted file: %d", count)

	if len(fileChunks) != 0 {
		t.Errorf("expected 0 file chunks after cascade delete, got %d", len(fileChunks))
		// List the remaining chunks
		for _, fc := range fileChunks {
			t.Logf("Remaining chunk: file_id=%s, idx=%d, chunk=%s", fc.FileID, fc.Idx, fc.ChunkHash)
		}
	}
}
@@ -16,16 +16,16 @@ func NewChunkFileRepository(db *DB) *ChunkFileRepository {

func (r *ChunkFileRepository) Create(ctx context.Context, tx *sql.Tx, cf *ChunkFile) error {
	query := `
		INSERT INTO chunk_files (chunk_hash, file_id, file_offset, length)
		VALUES (?, ?, ?, ?)
		ON CONFLICT(chunk_hash, file_id) DO NOTHING
	`

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, cf.ChunkHash, cf.FileID, cf.FileOffset, cf.Length)
	} else {
		_, err = r.db.ExecWithLog(ctx, query, cf.ChunkHash, cf.FileID, cf.FileOffset, cf.Length)
	}

	if err != nil {
@@ -37,7 +37,7 @@ func (r *ChunkFileRepository) Create(ctx context.Context, tx *sql.Tx, cf *ChunkF

func (r *ChunkFileRepository) GetByChunkHash(ctx context.Context, chunkHash string) ([]*ChunkFile, error) {
	query := `
		SELECT chunk_hash, file_id, file_offset, length
		FROM chunk_files
		WHERE chunk_hash = ?
	`
@@ -51,7 +51,7 @@ func (r *ChunkFileRepository) GetByChunkHash(ctx context.Context, chunkHash stri
	var chunkFiles []*ChunkFile
	for rows.Next() {
		var cf ChunkFile
		err := rows.Scan(&cf.ChunkHash, &cf.FileID, &cf.FileOffset, &cf.Length)
		if err != nil {
			return nil, fmt.Errorf("scanning chunk file: %w", err)
		}
@@ -63,9 +63,10 @@ func (r *ChunkFileRepository) GetByChunkHash(ctx context.Context, chunkHash stri

func (r *ChunkFileRepository) GetByFilePath(ctx context.Context, filePath string) ([]*ChunkFile, error) {
	query := `
		SELECT cf.chunk_hash, cf.file_id, cf.file_offset, cf.length
		FROM chunk_files cf
		JOIN files f ON cf.file_id = f.id
		WHERE f.path = ?
	`

	rows, err := r.db.conn.QueryContext(ctx, query, filePath)
@@ -77,7 +78,34 @@ func (r *ChunkFileRepository) GetByFilePath(ctx context.Context, filePath string
	var chunkFiles []*ChunkFile
	for rows.Next() {
		var cf ChunkFile
		err := rows.Scan(&cf.ChunkHash, &cf.FileID, &cf.FileOffset, &cf.Length)
		if err != nil {
			return nil, fmt.Errorf("scanning chunk file: %w", err)
		}
		chunkFiles = append(chunkFiles, &cf)
	}

	return chunkFiles, rows.Err()
}

// GetByFileID retrieves chunk files by file ID
func (r *ChunkFileRepository) GetByFileID(ctx context.Context, fileID string) ([]*ChunkFile, error) {
	query := `
		SELECT chunk_hash, file_id, file_offset, length
		FROM chunk_files
		WHERE file_id = ?
	`

	rows, err := r.db.conn.QueryContext(ctx, query, fileID)
	if err != nil {
		return nil, fmt.Errorf("querying chunk files: %w", err)
	}
	defer CloseRows(rows)

	var chunkFiles []*ChunkFile
	for rows.Next() {
		var cf ChunkFile
		err := rows.Scan(&cf.ChunkHash, &cf.FileID, &cf.FileOffset, &cf.Length)
		if err != nil {
			return nil, fmt.Errorf("scanning chunk file: %w", err)
		}
@@ -3,6 +3,7 @@ package database
import (
	"context"
	"testing"
+	"time"
)

func TestChunkFileRepository(t *testing.T) {
@@ -11,16 +12,49 @@ func TestChunkFileRepository(t *testing.T) {

	ctx := context.Background()
	repo := NewChunkFileRepository(db)
+	fileRepo := NewFileRepository(db)
+
+	// Create test files first
+	testTime := time.Now().Truncate(time.Second)
+	file1 := &File{
+		Path:       "/file1.txt",
+		MTime:      testTime,
+		CTime:      testTime,
+		Size:       1024,
+		Mode:       0644,
+		UID:        1000,
+		GID:        1000,
+		LinkTarget: "",
+	}
+	err := fileRepo.Create(ctx, nil, file1)
+	if err != nil {
+		t.Fatalf("failed to create file1: %v", err)
+	}
+
+	file2 := &File{
+		Path:       "/file2.txt",
+		MTime:      testTime,
+		CTime:      testTime,
+		Size:       1024,
+		Mode:       0644,
+		UID:        1000,
+		GID:        1000,
+		LinkTarget: "",
+	}
+	err = fileRepo.Create(ctx, nil, file2)
+	if err != nil {
+		t.Fatalf("failed to create file2: %v", err)
+	}

	// Test Create
	cf1 := &ChunkFile{
		ChunkHash:  "chunk1",
-		FilePath:   "/file1.txt",
+		FileID:     file1.ID,
		FileOffset: 0,
		Length:     1024,
	}

-	err := repo.Create(ctx, nil, cf1)
+	err = repo.Create(ctx, nil, cf1)
	if err != nil {
		t.Fatalf("failed to create chunk file: %v", err)
	}
@@ -28,7 +62,7 @@ func TestChunkFileRepository(t *testing.T) {
	// Add same chunk in different file (deduplication scenario)
	cf2 := &ChunkFile{
		ChunkHash:  "chunk1",
-		FilePath:   "/file2.txt",
+		FileID:     file2.ID,
		FileOffset: 2048,
		Length:     1024,
	}
@@ -50,10 +84,10 @@ func TestChunkFileRepository(t *testing.T) {
	foundFile1 := false
	foundFile2 := false
	for _, cf := range chunkFiles {
-		if cf.FilePath == "/file1.txt" && cf.FileOffset == 0 {
+		if cf.FileID == file1.ID && cf.FileOffset == 0 {
			foundFile1 = true
		}
-		if cf.FilePath == "/file2.txt" && cf.FileOffset == 2048 {
+		if cf.FileID == file2.ID && cf.FileOffset == 2048 {
			foundFile2 = true
		}
	}
@@ -61,10 +95,10 @@ func TestChunkFileRepository(t *testing.T) {
		t.Error("not all expected files found")
	}

-	// Test GetByFilePath
-	chunkFiles, err = repo.GetByFilePath(ctx, "/file1.txt")
+	// Test GetByFileID
+	chunkFiles, err = repo.GetByFileID(ctx, file1.ID)
	if err != nil {
-		t.Fatalf("failed to get chunks by file path: %v", err)
+		t.Fatalf("failed to get chunks by file ID: %v", err)
	}
	if len(chunkFiles) != 1 {
		t.Errorf("expected 1 chunk for file, got %d", len(chunkFiles))
@@ -86,6 +120,23 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {

	ctx := context.Background()
	repo := NewChunkFileRepository(db)
+	fileRepo := NewFileRepository(db)
+
+	// Create test files
+	testTime := time.Now().Truncate(time.Second)
+	file1 := &File{Path: "/file1.txt", MTime: testTime, CTime: testTime, Size: 3072, Mode: 0644, UID: 1000, GID: 1000}
+	file2 := &File{Path: "/file2.txt", MTime: testTime, CTime: testTime, Size: 3072, Mode: 0644, UID: 1000, GID: 1000}
+	file3 := &File{Path: "/file3.txt", MTime: testTime, CTime: testTime, Size: 2048, Mode: 0644, UID: 1000, GID: 1000}
+
+	if err := fileRepo.Create(ctx, nil, file1); err != nil {
+		t.Fatalf("failed to create file1: %v", err)
+	}
+	if err := fileRepo.Create(ctx, nil, file2); err != nil {
+		t.Fatalf("failed to create file2: %v", err)
+	}
+	if err := fileRepo.Create(ctx, nil, file3); err != nil {
+		t.Fatalf("failed to create file3: %v", err)
+	}

	// Simulate a scenario where multiple files share chunks
	// File1: chunk1, chunk2, chunk3
@@ -94,16 +145,16 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {

	chunkFiles := []ChunkFile{
		// File1
-		{ChunkHash: "chunk1", FilePath: "/file1.txt", FileOffset: 0, Length: 1024},
-		{ChunkHash: "chunk2", FilePath: "/file1.txt", FileOffset: 1024, Length: 1024},
-		{ChunkHash: "chunk3", FilePath: "/file1.txt", FileOffset: 2048, Length: 1024},
+		{ChunkHash: "chunk1", FileID: file1.ID, FileOffset: 0, Length: 1024},
+		{ChunkHash: "chunk2", FileID: file1.ID, FileOffset: 1024, Length: 1024},
+		{ChunkHash: "chunk3", FileID: file1.ID, FileOffset: 2048, Length: 1024},
		// File2
-		{ChunkHash: "chunk2", FilePath: "/file2.txt", FileOffset: 0, Length: 1024},
-		{ChunkHash: "chunk3", FilePath: "/file2.txt", FileOffset: 1024, Length: 1024},
-		{ChunkHash: "chunk4", FilePath: "/file2.txt", FileOffset: 2048, Length: 1024},
+		{ChunkHash: "chunk2", FileID: file2.ID, FileOffset: 0, Length: 1024},
+		{ChunkHash: "chunk3", FileID: file2.ID, FileOffset: 1024, Length: 1024},
+		{ChunkHash: "chunk4", FileID: file2.ID, FileOffset: 2048, Length: 1024},
		// File3
-		{ChunkHash: "chunk1", FilePath: "/file3.txt", FileOffset: 0, Length: 1024},
-		{ChunkHash: "chunk4", FilePath: "/file3.txt", FileOffset: 1024, Length: 1024},
+		{ChunkHash: "chunk1", FileID: file3.ID, FileOffset: 0, Length: 1024},
+		{ChunkHash: "chunk4", FileID: file3.ID, FileOffset: 1024, Length: 1024},
	}

	for _, cf := range chunkFiles {
@@ -132,7 +183,7 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {
	}

	// Test file2 chunks
-	chunks, err := repo.GetByFilePath(ctx, "/file2.txt")
+	chunks, err := repo.GetByFileID(ctx, file2.ID)
	if err != nil {
		t.Fatalf("failed to get chunks for file2: %v", err)
	}
@@ -4,6 +4,8 @@ import (
	"context"
	"database/sql"
	"fmt"
+
+	"git.eeqj.de/sneak/vaultik/internal/log"
)

type ChunkRepository struct {
@@ -25,7 +27,7 @@ func (r *ChunkRepository) Create(ctx context.Context, tx *sql.Tx, chunk *Chunk)
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, chunk.ChunkHash, chunk.SHA256, chunk.Size)
	} else {
-		_, err = r.db.ExecWithLock(ctx, query, chunk.ChunkHash, chunk.SHA256, chunk.Size)
+		_, err = r.db.ExecWithLog(ctx, query, chunk.ChunkHash, chunk.SHA256, chunk.Size)
	}

	if err != nil {
@@ -139,3 +141,26 @@ func (r *ChunkRepository) ListUnpacked(ctx context.Context, limit int) ([]*Chunk

	return chunks, rows.Err()
}

+// DeleteOrphaned deletes chunks that are not referenced by any file
+func (r *ChunkRepository) DeleteOrphaned(ctx context.Context) error {
+	query := `
+		DELETE FROM chunks
+		WHERE NOT EXISTS (
+			SELECT 1 FROM file_chunks
+			WHERE file_chunks.chunk_hash = chunks.chunk_hash
+		)
+	`
+
+	result, err := r.db.ExecWithLog(ctx, query)
+	if err != nil {
+		return fmt.Errorf("deleting orphaned chunks: %w", err)
+	}
+
+	rowsAffected, _ := result.RowsAffected()
+	if rowsAffected > 0 {
+		log.Debug("Deleted orphaned chunks", "count", rowsAffected)
+	}
+
+	return nil
+}
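DeleteOrphaned here pairs with an equivalent helper on FileRepository further down in this commit. A hedged sketch of how a prune pass might chain the two (the sweep function and its ordering are assumptions, not part of the commit; it presumes the schema's ON DELETE CASCADE from file_chunks to files):

// Hypothetical orphan sweep after snapshot pruning (sketch, package database).
func sweepOrphans(ctx context.Context, repos *Repositories) error {
	// Files unreferenced by any snapshot go first; their file_chunks rows
	// disappear with them via cascade, which un-references their chunks.
	if err := repos.Files.DeleteOrphaned(ctx); err != nil {
		return fmt.Errorf("sweeping files: %w", err)
	}
	// Then delete chunks that no remaining file references.
	return repos.Chunks.DeleteOrphaned(ctx)
}

Running the sweeps in the opposite order would leave chunks orphaned until the next pass, since their file_chunks references would still exist when the chunk sweep ran.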
@@ -1,84 +1,158 @@
+// Package database provides the local SQLite index for Vaultik backup operations.
+// The database tracks files, chunks, and their associations with blobs.
+//
+// Blobs in Vaultik are the final storage units uploaded to S3. Each blob is a
+// large (up to 10GB) file containing many compressed and encrypted chunks from
+// multiple source files. Blobs are content-addressed, meaning their filename
+// is derived from their SHA256 hash after compression and encryption.
+//
+// The database does not support migrations. If the schema changes, delete
+// the local database and perform a full backup to recreate it.
package database

import (
	"context"
	"database/sql"
+	_ "embed"
	"fmt"
	"os"
	"strings"
-	"sync"

	"git.eeqj.de/sneak/vaultik/internal/log"
	_ "modernc.org/sqlite"
)

+//go:embed schema.sql
+var schemaSQL string

// DB represents the Vaultik local index database connection.
// It uses SQLite to track file metadata, content-defined chunks, and blob associations.
// The database enables incremental backups by detecting changed files and
// supports deduplication by tracking which chunks are already stored in blobs.
// Write operations are synchronized through a mutex to ensure thread safety.
type DB struct {
-	conn      *sql.DB
-	writeLock sync.Mutex
+	conn *sql.DB
+	path string
}

-// New creates a new database connection at the specified path.
-// It automatically handles database recovery, creates the schema if needed,
-// and configures SQLite with appropriate settings for performance and reliability.
-// The database uses WAL mode for better concurrency and sets a busy timeout
-// to handle concurrent access gracefully.
-//
-// If the database appears locked, it will attempt recovery by removing stale
-// lock files and switching temporarily to TRUNCATE journal mode.
-//
+// New creates a new database connection at the specified path.
+// It automatically handles recovery from stale locks, creates the schema if needed,
+// and configures SQLite with WAL mode for better concurrency.
+// The path parameter can be a file path for persistent storage or ":memory:"
+// for an in-memory database (useful for testing).
func New(ctx context.Context, path string) (*DB, error) {
	log.Debug("Opening database connection", "path", path)

	// First, try to recover from any stale locks
	if err := recoverDatabase(ctx, path); err != nil {
		log.Warn("Failed to recover database", "error", err)
	}

-	// First attempt with standard WAL mode
-	conn, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=10000&_locking_mode=NORMAL")
+	log.Debug("Attempting to open database with WAL mode", "path", path)
+	conn, err := sql.Open(
+		"sqlite",
+		path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=10000&_locking_mode=NORMAL&_foreign_keys=ON",
+	)
	if err == nil {
-		// Set connection pool settings to ensure proper cleanup
-		conn.SetMaxOpenConns(1) // SQLite only supports one writer
+		// Set connection pool settings
+		// SQLite can handle multiple readers but only one writer at a time.
+		// Setting MaxOpenConns to 1 ensures all writes are serialized through
+		// a single connection, preventing SQLITE_BUSY errors.
+		conn.SetMaxOpenConns(1)
		conn.SetMaxIdleConns(1)

		if err := conn.PingContext(ctx); err == nil {
			// Success on first try
-			db := &DB{conn: conn}
			log.Debug("Database opened successfully with WAL mode", "path", path)

+			// Enable foreign keys explicitly
+			if _, err := conn.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
+				log.Warn("Failed to enable foreign keys", "error", err)
+			}
+
+			db := &DB{conn: conn, path: path}
			if err := db.createSchema(ctx); err != nil {
				_ = conn.Close()
				return nil, fmt.Errorf("creating schema: %w", err)
			}
			return db, nil
		}
+		log.Debug("Failed to ping database, closing connection", "path", path, "error", err)
		_ = conn.Close()
	}

	// If first attempt failed, try with TRUNCATE mode to clear any locks
-	log.Info("Database appears locked, attempting recovery with TRUNCATE mode")
-	conn, err = sql.Open("sqlite", path+"?_journal_mode=TRUNCATE&_synchronous=NORMAL&_busy_timeout=10000")
+	log.Info(
+		"Database appears locked, attempting recovery with TRUNCATE mode",
+		"path", path,
+	)
+	conn, err = sql.Open(
+		"sqlite",
+		path+"?_journal_mode=TRUNCATE&_synchronous=NORMAL&_busy_timeout=10000&_foreign_keys=ON",
+	)
	if err != nil {
		return nil, fmt.Errorf("opening database in recovery mode: %w", err)
	}

	// Set connection pool settings
+	// SQLite can handle multiple readers but only one writer at a time.
+	// Setting MaxOpenConns to 1 ensures all writes are serialized through
+	// a single connection, preventing SQLITE_BUSY errors.
	conn.SetMaxOpenConns(1)
	conn.SetMaxIdleConns(1)

	if err := conn.PingContext(ctx); err != nil {
+		log.Debug("Failed to ping database in recovery mode, closing", "path", path, "error", err)
		_ = conn.Close()
-		return nil, fmt.Errorf("database still locked after recovery attempt: %w", err)
+		return nil, fmt.Errorf(
+			"database still locked after recovery attempt: %w",
+			err,
+		)
	}

+	log.Debug("Database opened in TRUNCATE mode", "path", path)
+
	// Switch back to WAL mode
+	log.Debug("Switching database back to WAL mode", "path", path)
	if _, err := conn.ExecContext(ctx, "PRAGMA journal_mode=WAL"); err != nil {
-		log.Warn("Failed to switch back to WAL mode", "error", err)
+		log.Warn("Failed to switch back to WAL mode", "path", path, "error", err)
	}

-	db := &DB{conn: conn}
+	// Ensure foreign keys are enabled
+	if _, err := conn.ExecContext(ctx, "PRAGMA foreign_keys=ON"); err != nil {
+		log.Warn("Failed to enable foreign keys", "path", path, "error", err)
+	}
+
+	db := &DB{conn: conn, path: path}
	if err := db.createSchema(ctx); err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("creating schema: %w", err)
	}

+	log.Debug("Database connection established successfully", "path", path)
	return db, nil
}

// Close closes the database connection.
+// It ensures all pending operations are completed before closing.
+// Returns an error if the database connection cannot be closed properly.
func (db *DB) Close() error {
-	log.Debug("Closing database connection")
+	log.Debug("Closing database connection", "path", db.path)
	if err := db.conn.Close(); err != nil {
-		log.Error("Failed to close database", "error", err)
+		log.Error("Failed to close database", "path", db.path, "error", err)
		return fmt.Errorf("failed to close database: %w", err)
	}
-	log.Debug("Database connection closed successfully")
+	log.Debug("Database connection closed successfully", "path", db.path)
	return nil
}
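A minimal usage sketch of the constructor and Close above. Hedged: this is only callable from within the module, since internal/database is an internal package, and the path is illustrative:

package main

import (
	"context"
	stdlog "log"

	"git.eeqj.de/sneak/vaultik/internal/database"
)

func main() {
	ctx := context.Background()
	// ":memory:" also works here, as the doc comment on New notes.
	db, err := database.New(ctx, "/var/lib/vaultik/index.db")
	if err != nil {
		stdlog.Fatal(err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			stdlog.Printf("closing index database: %v", err)
		}
	}()
}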
@@ -138,148 +212,79 @@ func recoverDatabase(ctx context.Context, path string) error {
	return nil
}

+// Conn returns the underlying *sql.DB connection.
+// This should be used sparingly and primarily for read operations.
+// For write operations, prefer using the ExecWithLog method.
func (db *DB) Conn() *sql.DB {
	return db.conn
}

-func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+// BeginTx starts a new database transaction with the given options.
+// The caller is responsible for committing or rolling back the transaction.
+// For write transactions, consider using the Repositories.WithTx method instead,
+// which handles locking and rollback automatically.
+func (db *DB) BeginTx(
+	ctx context.Context,
+	opts *sql.TxOptions,
+) (*sql.Tx, error) {
	return db.conn.BeginTx(ctx, opts)
}

-// LockForWrite acquires the write lock
-func (db *DB) LockForWrite() {
-	log.Debug("Attempting to acquire write lock")
-	db.writeLock.Lock()
-	log.Debug("Write lock acquired")
-}
-
-// UnlockWrite releases the write lock
-func (db *DB) UnlockWrite() {
-	log.Debug("Releasing write lock")
-	db.writeLock.Unlock()
-	log.Debug("Write lock released")
-}
-
-// ExecWithLock executes a write query with the write lock held
-func (db *DB) ExecWithLock(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
-	db.writeLock.Lock()
-	defer db.writeLock.Unlock()
+// Note: LockForWrite and UnlockWrite methods have been removed.
+// SQLite handles its own locking internally, so explicit locking is not needed.
+
+// ExecWithLog executes a write query with SQL logging.
+// SQLite handles its own locking internally, so we just pass through to ExecContext.
+// The query and args parameters follow the same format as sql.DB.ExecContext.
+func (db *DB) ExecWithLog(
+	ctx context.Context,
+	query string,
+	args ...interface{},
+) (sql.Result, error) {
	LogSQL("Execute", query, args...)
	return db.conn.ExecContext(ctx, query, args...)
}

-// QueryRowWithLock executes a write query that returns a row with the write lock held
-func (db *DB) QueryRowWithLock(ctx context.Context, query string, args ...interface{}) *sql.Row {
-	db.writeLock.Lock()
-	defer db.writeLock.Unlock()
+// QueryRowWithLog executes a query that returns at most one row with SQL logging.
+// This is useful for queries that modify data and return values (e.g., INSERT ... RETURNING).
+// SQLite handles its own locking internally.
+// The query and args parameters follow the same format as sql.DB.QueryRowContext.
+func (db *DB) QueryRowWithLog(
+	ctx context.Context,
+	query string,
+	args ...interface{},
+) *sql.Row {
	LogSQL("QueryRow", query, args...)
	return db.conn.QueryRowContext(ctx, query, args...)
}

func (db *DB) createSchema(ctx context.Context) error {
-	schema := `
-	CREATE TABLE IF NOT EXISTS files (
-		path TEXT PRIMARY KEY,
-		mtime INTEGER NOT NULL,
-		ctime INTEGER NOT NULL,
-		size INTEGER NOT NULL,
-		mode INTEGER NOT NULL,
-		uid INTEGER NOT NULL,
-		gid INTEGER NOT NULL,
-		link_target TEXT
-	);
-
-	CREATE TABLE IF NOT EXISTS file_chunks (
-		path TEXT NOT NULL,
-		idx INTEGER NOT NULL,
-		chunk_hash TEXT NOT NULL,
-		PRIMARY KEY (path, idx)
-	);
-
-	CREATE TABLE IF NOT EXISTS chunks (
-		chunk_hash TEXT PRIMARY KEY,
-		sha256 TEXT NOT NULL,
-		size INTEGER NOT NULL
-	);
-
-	CREATE TABLE IF NOT EXISTS blobs (
-		id TEXT PRIMARY KEY,
-		blob_hash TEXT UNIQUE,
-		created_ts INTEGER NOT NULL,
-		finished_ts INTEGER,
-		uncompressed_size INTEGER NOT NULL DEFAULT 0,
-		compressed_size INTEGER NOT NULL DEFAULT 0,
-		uploaded_ts INTEGER
-	);
-
-	CREATE TABLE IF NOT EXISTS blob_chunks (
-		blob_id TEXT NOT NULL,
-		chunk_hash TEXT NOT NULL,
-		offset INTEGER NOT NULL,
-		length INTEGER NOT NULL,
-		PRIMARY KEY (blob_id, chunk_hash),
-		FOREIGN KEY (blob_id) REFERENCES blobs(id)
-	);
-
-	CREATE TABLE IF NOT EXISTS chunk_files (
-		chunk_hash TEXT NOT NULL,
-		file_path TEXT NOT NULL,
-		file_offset INTEGER NOT NULL,
-		length INTEGER NOT NULL,
-		PRIMARY KEY (chunk_hash, file_path)
-	);
-
-	CREATE TABLE IF NOT EXISTS snapshots (
-		id TEXT PRIMARY KEY,
-		hostname TEXT NOT NULL,
-		vaultik_version TEXT NOT NULL,
-		started_at INTEGER NOT NULL,
-		completed_at INTEGER,
-		file_count INTEGER NOT NULL DEFAULT 0,
-		chunk_count INTEGER NOT NULL DEFAULT 0,
-		blob_count INTEGER NOT NULL DEFAULT 0,
-		total_size INTEGER NOT NULL DEFAULT 0,
-		blob_size INTEGER NOT NULL DEFAULT 0,
-		compression_ratio REAL NOT NULL DEFAULT 1.0
-	);
-
-	CREATE TABLE IF NOT EXISTS snapshot_files (
-		snapshot_id TEXT NOT NULL,
-		file_path TEXT NOT NULL,
-		PRIMARY KEY (snapshot_id, file_path),
-		FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
-		FOREIGN KEY (file_path) REFERENCES files(path) ON DELETE CASCADE
-	);
-
-	CREATE TABLE IF NOT EXISTS snapshot_blobs (
-		snapshot_id TEXT NOT NULL,
-		blob_id TEXT NOT NULL,
-		blob_hash TEXT NOT NULL,
-		PRIMARY KEY (snapshot_id, blob_id),
-		FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
-		FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE
-	);
-
-	CREATE TABLE IF NOT EXISTS uploads (
-		blob_hash TEXT PRIMARY KEY,
-		uploaded_at INTEGER NOT NULL,
-		size INTEGER NOT NULL,
-		duration_ms INTEGER NOT NULL
-	);
-	`
-
-	_, err := db.conn.ExecContext(ctx, schema)
+	_, err := db.conn.ExecContext(ctx, schemaSQL)
	return err
}

-// NewTestDB creates an in-memory SQLite database for testing
+// NewTestDB creates an in-memory SQLite database for testing purposes.
+// The database is automatically initialized with the schema and is ready for use.
+// Each call creates a new independent database instance.
func NewTestDB() (*DB, error) {
	return New(context.Background(), ":memory:")
}

-// LogSQL logs SQL queries if debug mode is enabled
+// LogSQL logs SQL queries and their arguments when debug mode is enabled.
+// Debug mode is activated by setting the GODEBUG environment variable to include "vaultik".
+// This is useful for troubleshooting database operations and understanding query patterns.
+//
+// The operation parameter describes the type of SQL operation (e.g., "Execute", "Query").
+// The query parameter is the SQL statement being executed.
+// The args parameter contains the query arguments that will be interpolated.
func LogSQL(operation, query string, args ...interface{}) {
	if strings.Contains(os.Getenv("GODEBUG"), "vaultik") {
-		log.Debug("SQL "+operation, "query", strings.TrimSpace(query), "args", fmt.Sprintf("%v", args))
+		log.Debug(
+			"SQL "+operation,
+			"query",
+			strings.TrimSpace(query),
+			"args",
+			fmt.Sprintf("%v", args),
+		)
	}
}
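A small sketch, not in the commit, of the GODEBUG gate on LogSQL; t.Setenv scopes the variable to the test:

package database

import "testing"

// Sketch: LogSQL prints only when GODEBUG contains "vaultik".
// From a shell the equivalent is: GODEBUG=vaultik vaultik snapshot create
func TestLogSQLGatingSketch(t *testing.T) {
	t.Setenv("GODEBUG", "vaultik")
	LogSQL("Query", "SELECT COUNT(*) FROM chunks WHERE size > ?", 1024)
	// With GODEBUG unset or lacking "vaultik", the same call emits nothing.
}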
@@ -67,21 +67,26 @@ func TestDatabaseConcurrentAccess(t *testing.T) {
	}()

	// Test concurrent writes
-	done := make(chan bool, 10)
+	type result struct {
+		index int
+		err   error
+	}
+	results := make(chan result, 10)

	for i := 0; i < 10; i++ {
		go func(i int) {
-			_, err := db.ExecWithLock(ctx, "INSERT INTO chunks (chunk_hash, sha256, size) VALUES (?, ?, ?)",
+			_, err := db.ExecWithLog(ctx, "INSERT INTO chunks (chunk_hash, sha256, size) VALUES (?, ?, ?)",
				fmt.Sprintf("hash%d", i), fmt.Sprintf("sha%d", i), i*1024)
-			if err != nil {
-				t.Errorf("concurrent insert failed: %v", err)
-			}
-			done <- true
+			results <- result{index: i, err: err}
		}(i)
	}

-	// Wait for all goroutines
+	// Wait for all goroutines and check results
	for i := 0; i < 10; i++ {
-		<-done
+		r := <-results
+		if r.err != nil {
+			t.Fatalf("concurrent insert %d failed: %v", r.index, r.err)
+		}
	}

	// Verify all inserts succeeded
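The channel-of-results rewrite above matters because testing.T's Fatalf/FailNow may only be called from the goroutine running the test; collecting errors and failing from the test goroutine is the safe pattern. An alternative sketch using errgroup (assumes golang.org/x/sync as an extra dependency; not part of this commit):

package database

import (
	"context"
	"fmt"
	"testing"

	"golang.org/x/sync/errgroup"
)

func TestConcurrentInsertsErrgroupSketch(t *testing.T) {
	db, err := NewTestDB()
	if err != nil {
		t.Fatal(err)
	}
	defer func() { _ = db.Close() }()
	ctx := context.Background()

	var g errgroup.Group
	for i := 0; i < 10; i++ {
		i := i // capture loop variable (pre-Go 1.22 idiom)
		g.Go(func() error {
			_, err := db.ExecWithLog(ctx,
				"INSERT INTO chunks (chunk_hash, sha256, size) VALUES (?, ?, ?)",
				fmt.Sprintf("hash%d", i), fmt.Sprintf("sha%d", i), i*1024)
			return err
		})
	}
	// Wait returns the first error; failing here stays on the test goroutine.
	if err := g.Wait(); err != nil {
		t.Fatalf("concurrent insert failed: %v", err)
	}
}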
@@ -16,16 +16,16 @@ func NewFileChunkRepository(db *DB) *FileChunkRepository {

func (r *FileChunkRepository) Create(ctx context.Context, tx *sql.Tx, fc *FileChunk) error {
	query := `
-		INSERT INTO file_chunks (path, idx, chunk_hash)
+		INSERT INTO file_chunks (file_id, idx, chunk_hash)
		VALUES (?, ?, ?)
-		ON CONFLICT(path, idx) DO NOTHING
+		ON CONFLICT(file_id, idx) DO NOTHING
	`

	var err error
	if tx != nil {
-		_, err = tx.ExecContext(ctx, query, fc.Path, fc.Idx, fc.ChunkHash)
+		_, err = tx.ExecContext(ctx, query, fc.FileID, fc.Idx, fc.ChunkHash)
	} else {
-		_, err = r.db.ExecWithLock(ctx, query, fc.Path, fc.Idx, fc.ChunkHash)
+		_, err = r.db.ExecWithLog(ctx, query, fc.FileID, fc.Idx, fc.ChunkHash)
	}

	if err != nil {
@@ -37,10 +37,11 @@ func (r *FileChunkRepository) Create(ctx context.Context, tx *sql.Tx, fc *FileCh

func (r *FileChunkRepository) GetByPath(ctx context.Context, path string) ([]*FileChunk, error) {
	query := `
-		SELECT path, idx, chunk_hash
-		FROM file_chunks
-		WHERE path = ?
-		ORDER BY idx
+		SELECT fc.file_id, fc.idx, fc.chunk_hash
+		FROM file_chunks fc
+		JOIN files f ON fc.file_id = f.id
+		WHERE f.path = ?
+		ORDER BY fc.idx
	`

	rows, err := r.db.conn.QueryContext(ctx, query, path)
@@ -52,7 +53,35 @@ func (r *FileChunkRepository) GetByPath(ctx context.Context, path string) ([]*Fi
	var fileChunks []*FileChunk
	for rows.Next() {
		var fc FileChunk
-		err := rows.Scan(&fc.Path, &fc.Idx, &fc.ChunkHash)
+		err := rows.Scan(&fc.FileID, &fc.Idx, &fc.ChunkHash)
		if err != nil {
			return nil, fmt.Errorf("scanning file chunk: %w", err)
		}
		fileChunks = append(fileChunks, &fc)
	}

	return fileChunks, rows.Err()
}

+// GetByFileID retrieves file chunks by file ID
+func (r *FileChunkRepository) GetByFileID(ctx context.Context, fileID string) ([]*FileChunk, error) {
+	query := `
+		SELECT file_id, idx, chunk_hash
+		FROM file_chunks
+		WHERE file_id = ?
+		ORDER BY idx
+	`
+
+	rows, err := r.db.conn.QueryContext(ctx, query, fileID)
+	if err != nil {
+		return nil, fmt.Errorf("querying file chunks: %w", err)
+	}
+	defer CloseRows(rows)
+
+	var fileChunks []*FileChunk
+	for rows.Next() {
+		var fc FileChunk
+		err := rows.Scan(&fc.FileID, &fc.Idx, &fc.ChunkHash)
+		if err != nil {
+			return nil, fmt.Errorf("scanning file chunk: %w", err)
+		}
@@ -65,10 +94,11 @@ func (r *FileChunkRepository) GetByPath(ctx context.Context, path string) ([]*Fi
// GetByPathTx retrieves file chunks within a transaction
func (r *FileChunkRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path string) ([]*FileChunk, error) {
	query := `
-		SELECT path, idx, chunk_hash
-		FROM file_chunks
-		WHERE path = ?
-		ORDER BY idx
+		SELECT fc.file_id, fc.idx, fc.chunk_hash
+		FROM file_chunks fc
+		JOIN files f ON fc.file_id = f.id
+		WHERE f.path = ?
+		ORDER BY fc.idx
	`

	LogSQL("GetByPathTx", query, path)
@@ -81,7 +111,7 @@ func (r *FileChunkRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path
	var fileChunks []*FileChunk
	for rows.Next() {
		var fc FileChunk
-		err := rows.Scan(&fc.Path, &fc.Idx, &fc.ChunkHash)
+		err := rows.Scan(&fc.FileID, &fc.Idx, &fc.ChunkHash)
		if err != nil {
			return nil, fmt.Errorf("scanning file chunk: %w", err)
		}
@@ -93,13 +123,31 @@ func (r *FileChunkRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path
}

func (r *FileChunkRepository) DeleteByPath(ctx context.Context, tx *sql.Tx, path string) error {
-	query := `DELETE FROM file_chunks WHERE path = ?`
+	query := `DELETE FROM file_chunks WHERE file_id = (SELECT id FROM files WHERE path = ?)`

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, path)
	} else {
-		_, err = r.db.ExecWithLock(ctx, query, path)
+		_, err = r.db.ExecWithLog(ctx, query, path)
	}

	if err != nil {
		return fmt.Errorf("deleting file chunks: %w", err)
	}

	return nil
}

+// DeleteByFileID deletes all chunks for a file by its UUID
+func (r *FileChunkRepository) DeleteByFileID(ctx context.Context, tx *sql.Tx, fileID string) error {
+	query := `DELETE FROM file_chunks WHERE file_id = ?`
+
+	var err error
+	if tx != nil {
+		_, err = tx.ExecContext(ctx, query, fileID)
+	} else {
+		_, err = r.db.ExecWithLog(ctx, query, fileID)
+	}
+
+	if err != nil {
@@ -4,6 +4,7 @@ import (
	"context"
	"fmt"
	"testing"
+	"time"
)

func TestFileChunkRepository(t *testing.T) {
@@ -12,22 +13,40 @@ func TestFileChunkRepository(t *testing.T) {

	ctx := context.Background()
	repo := NewFileChunkRepository(db)
+	fileRepo := NewFileRepository(db)
+
+	// Create test file first
+	testTime := time.Now().Truncate(time.Second)
+	file := &File{
+		Path:       "/test/file.txt",
+		MTime:      testTime,
+		CTime:      testTime,
+		Size:       3072,
+		Mode:       0644,
+		UID:        1000,
+		GID:        1000,
+		LinkTarget: "",
+	}
+	err := fileRepo.Create(ctx, nil, file)
+	if err != nil {
+		t.Fatalf("failed to create file: %v", err)
+	}

	// Test Create
	fc1 := &FileChunk{
-		Path:      "/test/file.txt",
+		FileID:    file.ID,
		Idx:       0,
		ChunkHash: "chunk1",
	}

-	err := repo.Create(ctx, nil, fc1)
+	err = repo.Create(ctx, nil, fc1)
	if err != nil {
		t.Fatalf("failed to create file chunk: %v", err)
	}

	// Add more chunks for the same file
	fc2 := &FileChunk{
-		Path:      "/test/file.txt",
+		FileID:    file.ID,
		Idx:       1,
		ChunkHash: "chunk2",
	}
@@ -37,7 +56,7 @@ func TestFileChunkRepository(t *testing.T) {
	}

	fc3 := &FileChunk{
-		Path:      "/test/file.txt",
+		FileID:    file.ID,
		Idx:       2,
		ChunkHash: "chunk3",
	}
@@ -46,8 +65,8 @@ func TestFileChunkRepository(t *testing.T) {
		t.Fatalf("failed to create third file chunk: %v", err)
	}

-	// Test GetByPath
-	chunks, err := repo.GetByPath(ctx, "/test/file.txt")
+	// Test GetByFile
+	chunks, err := repo.GetByFile(ctx, "/test/file.txt")
	if err != nil {
		t.Fatalf("failed to get file chunks: %v", err)
	}
@@ -68,13 +87,13 @@ func TestFileChunkRepository(t *testing.T) {
		t.Fatalf("failed to create duplicate file chunk: %v", err)
	}

-	// Test DeleteByPath
-	err = repo.DeleteByPath(ctx, nil, "/test/file.txt")
+	// Test DeleteByFileID
+	err = repo.DeleteByFileID(ctx, nil, file.ID)
	if err != nil {
		t.Fatalf("failed to delete file chunks: %v", err)
	}

-	chunks, err = repo.GetByPath(ctx, "/test/file.txt")
+	chunks, err = repo.GetByFileID(ctx, file.ID)
	if err != nil {
		t.Fatalf("failed to get deleted file chunks: %v", err)
	}
@@ -89,15 +108,38 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) {

	ctx := context.Background()
	repo := NewFileChunkRepository(db)
+	fileRepo := NewFileRepository(db)
+
+	// Create test files
+	testTime := time.Now().Truncate(time.Second)
+	filePaths := []string{"/file1.txt", "/file2.txt", "/file3.txt"}
+	files := make([]*File, len(filePaths))
+
+	for i, path := range filePaths {
+		file := &File{
+			Path:       path,
+			MTime:      testTime,
+			CTime:      testTime,
+			Size:       2048,
+			Mode:       0644,
+			UID:        1000,
+			GID:        1000,
+			LinkTarget: "",
+		}
+		err := fileRepo.Create(ctx, nil, file)
+		if err != nil {
+			t.Fatalf("failed to create file %s: %v", path, err)
+		}
+		files[i] = file
+	}

	// Create chunks for multiple files
-	files := []string{"/file1.txt", "/file2.txt", "/file3.txt"}
-	for _, path := range files {
-		for i := 0; i < 2; i++ {
+	for i, file := range files {
+		for j := 0; j < 2; j++ {
			fc := &FileChunk{
-				Path:      path,
-				Idx:       i,
-				ChunkHash: fmt.Sprintf("%s_chunk%d", path, i),
+				FileID:    file.ID,
+				Idx:       j,
+				ChunkHash: fmt.Sprintf("file%d_chunk%d", i, j),
			}
			err := repo.Create(ctx, nil, fc)
			if err != nil {
@@ -107,13 +149,13 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) {
	}

	// Verify each file has correct chunks
-	for _, path := range files {
-		chunks, err := repo.GetByPath(ctx, path)
+	for i, file := range files {
+		chunks, err := repo.GetByFileID(ctx, file.ID)
		if err != nil {
-			t.Fatalf("failed to get chunks for %s: %v", path, err)
+			t.Fatalf("failed to get chunks for file %d: %v", i, err)
		}
		if len(chunks) != 2 {
-			t.Errorf("expected 2 chunks for %s, got %d", path, len(chunks))
+			t.Errorf("expected 2 chunks for file %d, got %d", i, len(chunks))
		}
	}
}
@@ -5,6 +5,9 @@ import (
	"database/sql"
	"fmt"
	"time"
+
+	"git.eeqj.de/sneak/vaultik/internal/log"
+	"github.com/google/uuid"
)

type FileRepository struct {
@@ -16,10 +19,16 @@ func NewFileRepository(db *DB) *FileRepository {
}

func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) error {
+	// Generate UUID if not provided
+	if file.ID == "" {
+		file.ID = uuid.New().String()
+	}
+
	query := `
-		INSERT INTO files (path, mtime, ctime, size, mode, uid, gid, link_target)
-		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+		INSERT INTO files (id, path, mtime, ctime, size, mode, uid, gid, link_target)
+		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
		ON CONFLICT(path) DO UPDATE SET
+			id = excluded.id,
			mtime = excluded.mtime,
			ctime = excluded.ctime,
			size = excluded.size,
@@ -27,14 +36,15 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err
			uid = excluded.uid,
			gid = excluded.gid,
			link_target = excluded.link_target
+		RETURNING id
	`

	var err error
	if tx != nil {
-		LogSQL("Execute", query, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
-		_, err = tx.ExecContext(ctx, query, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
+		LogSQL("Execute", query, file.ID, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
+		err = tx.QueryRowContext(ctx, query, file.ID, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget).Scan(&file.ID)
	} else {
-		_, err = r.db.ExecWithLock(ctx, query, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
+		err = r.db.QueryRowWithLog(ctx, query, file.ID, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget).Scan(&file.ID)
	}

	if err != nil {
@@ -46,7 +56,7 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err

func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, error) {
	query := `
-		SELECT path, mtime, ctime, size, mode, uid, gid, link_target
+		SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
		FROM files
		WHERE path = ?
	`
@@ -56,6 +66,7 @@ func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, err
	var linkTarget sql.NullString

	err := r.db.conn.QueryRowContext(ctx, query, path).Scan(
+		&file.ID,
		&file.Path,
		&mtimeUnix,
		&ctimeUnix,
@@ -73,8 +84,48 @@ func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, err
		return nil, fmt.Errorf("querying file: %w", err)
	}

-	file.MTime = time.Unix(mtimeUnix, 0)
-	file.CTime = time.Unix(ctimeUnix, 0)
+	file.MTime = time.Unix(mtimeUnix, 0).UTC()
+	file.CTime = time.Unix(ctimeUnix, 0).UTC()
	if linkTarget.Valid {
		file.LinkTarget = linkTarget.String
	}

	return &file, nil
}

+// GetByID retrieves a file by its UUID
+func (r *FileRepository) GetByID(ctx context.Context, id string) (*File, error) {
+	query := `
+		SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
+		FROM files
+		WHERE id = ?
+	`
+
+	var file File
+	var mtimeUnix, ctimeUnix int64
+	var linkTarget sql.NullString
+
+	err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
+		&file.ID,
+		&file.Path,
+		&mtimeUnix,
+		&ctimeUnix,
+		&file.Size,
+		&file.Mode,
+		&file.UID,
+		&file.GID,
+		&linkTarget,
+	)
+
+	if err == sql.ErrNoRows {
+		return nil, nil
+	}
+	if err != nil {
+		return nil, fmt.Errorf("querying file: %w", err)
+	}
+
+	file.MTime = time.Unix(mtimeUnix, 0).UTC()
+	file.CTime = time.Unix(ctimeUnix, 0).UTC()
+	if linkTarget.Valid {
+		file.LinkTarget = linkTarget.String
+	}
@@ -84,7 +135,7 @@ func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, err

func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path string) (*File, error) {
	query := `
-		SELECT path, mtime, ctime, size, mode, uid, gid, link_target
+		SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
		FROM files
		WHERE path = ?
	`
@@ -95,6 +146,7 @@ func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path strin

	LogSQL("GetByPathTx QueryRowContext", query, path)
	err := tx.QueryRowContext(ctx, query, path).Scan(
+		&file.ID,
		&file.Path,
		&mtimeUnix,
		&ctimeUnix,
@@ -113,8 +165,8 @@ func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path strin
		return nil, fmt.Errorf("querying file: %w", err)
	}

-	file.MTime = time.Unix(mtimeUnix, 0)
-	file.CTime = time.Unix(ctimeUnix, 0)
+	file.MTime = time.Unix(mtimeUnix, 0).UTC()
+	file.CTime = time.Unix(ctimeUnix, 0).UTC()
	if linkTarget.Valid {
		file.LinkTarget = linkTarget.String
	}
@@ -124,7 +176,7 @@ func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path strin

func (r *FileRepository) ListModifiedSince(ctx context.Context, since time.Time) ([]*File, error) {
	query := `
-		SELECT path, mtime, ctime, size, mode, uid, gid, link_target
+		SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
		FROM files
		WHERE mtime >= ?
		ORDER BY path
@@ -143,6 +195,7 @@ func (r *FileRepository) ListModifiedSince(ctx context.Context, since time.Time)
		var linkTarget sql.NullString

		err := rows.Scan(
+			&file.ID,
			&file.Path,
			&mtimeUnix,
			&ctimeUnix,
@@ -175,7 +228,25 @@ func (r *FileRepository) Delete(ctx context.Context, tx *sql.Tx, path string) er
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, path)
	} else {
-		_, err = r.db.ExecWithLock(ctx, query, path)
+		_, err = r.db.ExecWithLog(ctx, query, path)
	}

	if err != nil {
		return fmt.Errorf("deleting file: %w", err)
	}

	return nil
}

+// DeleteByID deletes a file by its UUID
+func (r *FileRepository) DeleteByID(ctx context.Context, tx *sql.Tx, id string) error {
+	query := `DELETE FROM files WHERE id = ?`
+
+	var err error
+	if tx != nil {
+		_, err = tx.ExecContext(ctx, query, id)
+	} else {
+		_, err = r.db.ExecWithLog(ctx, query, id)
+	}
+
+	if err != nil {
@@ -187,7 +258,7 @@ func (r *FileRepository) Delete(ctx context.Context, tx *sql.Tx, path string) er

func (r *FileRepository) ListByPrefix(ctx context.Context, prefix string) ([]*File, error) {
	query := `
-		SELECT path, mtime, ctime, size, mode, uid, gid, link_target
+		SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
		FROM files
		WHERE path LIKE ? || '%'
		ORDER BY path
@@ -206,6 +277,7 @@ func (r *FileRepository) ListByPrefix(ctx context.Context, prefix string) ([]*Fi
		var linkTarget sql.NullString

		err := rows.Scan(
+			&file.ID,
			&file.Path,
			&mtimeUnix,
			&ctimeUnix,
@@ -230,3 +302,26 @@ func (r *FileRepository) ListByPrefix(ctx context.Context, prefix string) ([]*Fi

	return files, rows.Err()
}

+// DeleteOrphaned deletes files that are not referenced by any snapshot
+func (r *FileRepository) DeleteOrphaned(ctx context.Context) error {
+	query := `
+		DELETE FROM files
+		WHERE NOT EXISTS (
+			SELECT 1 FROM snapshot_files
+			WHERE snapshot_files.file_id = files.id
+		)
+	`
+
+	result, err := r.db.ExecWithLog(ctx, query)
+	if err != nil {
+		return fmt.Errorf("deleting orphaned files: %w", err)
+	}
+
+	rowsAffected, _ := result.RowsAffected()
+	if rowsAffected > 0 {
+		log.Debug("Deleted orphaned files", "count", rowsAffected)
+	}
+
+	return nil
+}
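One subtlety in Create above, if I read the upsert correctly: the ON CONFLICT clause sets id = excluded.id, so re-creating an existing path re-keys the row to the freshly generated UUID, and RETURNING id (available in SQLite 3.35 and later) reads back whatever ended up stored. A hedged test-style sketch of that observable behavior, not part of the commit:

package database

import (
	"context"
	"testing"
	"time"
)

func TestCreateRekeysOnConflictSketch(t *testing.T) {
	db, err := NewTestDB()
	if err != nil {
		t.Fatal(err)
	}
	defer func() { _ = db.Close() }()
	repo := NewFileRepository(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Second)

	f1 := &File{Path: "/same.txt", MTime: now, CTime: now, Size: 1, Mode: 0644, UID: 1000, GID: 1000}
	if err := repo.Create(ctx, nil, f1); err != nil {
		t.Fatal(err)
	}
	f2 := &File{Path: "/same.txt", MTime: now, CTime: now, Size: 2, Mode: 0644, UID: 1000, GID: 1000}
	if err := repo.Create(ctx, nil, f2); err != nil {
		t.Fatal(err)
	}

	// The row should now carry f2's UUID rather than f1's.
	if got, _ := repo.GetByPath(ctx, "/same.txt"); got != nil && got.ID != f2.ID {
		t.Errorf("expected row to be re-keyed to %s, got %s", f2.ID, got.ID)
	}
}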
@@ -1,9 +1,15 @@
+// Package database provides data models and repository interfaces for the Vaultik backup system.
+// It includes types for files, chunks, blobs, snapshots, and their relationships.
package database

import "time"

-// File represents a file record in the database
+// File represents a file or directory in the backup system.
+// It stores metadata about files including timestamps, permissions, ownership,
+// and symlink targets. This information is used to restore files with their
+// original attributes.
type File struct {
+	ID    string // UUID primary key
	Path  string
	MTime time.Time
	CTime time.Time
@@ -14,37 +20,52 @@ type File struct {
	LinkTarget string // empty for regular files, target path for symlinks
}

-// IsSymlink returns true if this file is a symbolic link
+// IsSymlink returns true if this file is a symbolic link.
+// A file is considered a symlink if it has a non-empty LinkTarget.
func (f *File) IsSymlink() bool {
	return f.LinkTarget != ""
}

-// FileChunk represents the mapping between files and chunks
+// FileChunk represents the mapping between files and their constituent chunks.
+// Large files are split into multiple chunks for efficient deduplication and storage.
+// The Idx field maintains the order of chunks within a file.
type FileChunk struct {
-	Path      string
+	FileID    string
	Idx       int
	ChunkHash string
}

-// Chunk represents a chunk record in the database
+// Chunk represents a data chunk in the deduplication system.
+// Files are split into chunks which are content-addressed by their hash.
+// The ChunkHash is used for deduplication, while SHA256 provides
+// an additional verification hash.
type Chunk struct {
	ChunkHash string
	SHA256    string
	Size      int64
}

-// Blob represents a blob record in the database
+// Blob represents a blob record in the database.
+// A blob is Vaultik's final storage unit - a large file (up to 10GB) containing
+// many compressed and encrypted chunks from multiple source files.
+// Blobs are content-addressed, meaning their filename in S3 is derived from
+// the SHA256 hash of their compressed and encrypted content.
+// The blob creation process is: chunks are accumulated -> compressed with zstd
+// -> encrypted with age -> hashed -> uploaded to S3 with the hash as filename.
type Blob struct {
-	ID               string
-	Hash             string // Can be empty until blob is finalized
-	CreatedTS        time.Time
-	FinishedTS       *time.Time // nil if not yet finalized
-	UncompressedSize int64
-	CompressedSize   int64
-	UploadedTS       *time.Time // nil if not yet uploaded
+	ID               string     // UUID assigned when blob creation starts
+	Hash             string     // SHA256 of final compressed+encrypted content (empty until finalized)
+	CreatedTS        time.Time  // When blob creation started
+	FinishedTS       *time.Time // When blob was finalized (nil if still packing)
+	UncompressedSize int64      // Total size of raw chunks before compression
+	CompressedSize   int64      // Size after compression and encryption
+	UploadedTS       *time.Time // When blob was uploaded to S3 (nil if not uploaded)
}

-// BlobChunk represents the mapping between blobs and chunks
+// BlobChunk represents the mapping between blobs and the chunks they contain.
+// This allows tracking which chunks are stored in which blobs, along with
+// their position and size within the blob. The offset and length fields
+// enable extracting specific chunks from a blob without processing the entire blob.
type BlobChunk struct {
	BlobID    string
	ChunkHash string
@@ -52,27 +73,34 @@ type BlobChunk struct {
	Length    int64
}

-// ChunkFile represents the reverse mapping of chunks to files
+// ChunkFile represents the reverse mapping showing which files contain a specific chunk.
+// This is used during deduplication to identify all files that share a chunk,
+// which is important for garbage collection and integrity verification.
type ChunkFile struct {
	ChunkHash  string
-	FilePath   string
+	FileID     string
	FileOffset int64
	Length     int64
}

// Snapshot represents a snapshot record in the database
type Snapshot struct {
-	ID               string
-	Hostname         string
-	VaultikVersion   string
-	StartedAt        time.Time
-	CompletedAt      *time.Time // nil if still in progress
-	FileCount        int64
-	ChunkCount       int64
-	BlobCount        int64
-	TotalSize        int64   // Total size of all referenced files
-	BlobSize         int64   // Total size of all referenced blobs (compressed and encrypted)
-	CompressionRatio float64 // Compression ratio (BlobSize / TotalSize)
+	ID                   string
+	Hostname             string
+	VaultikVersion       string
+	VaultikGitRevision   string
+	StartedAt            time.Time
+	CompletedAt          *time.Time // nil if still in progress
+	FileCount            int64
+	ChunkCount           int64
+	BlobCount            int64
+	TotalSize            int64   // Total size of all referenced files
+	BlobSize             int64   // Total size of all referenced blobs (compressed and encrypted)
+	BlobUncompressedSize int64   // Total uncompressed size of all referenced blobs
+	CompressionRatio     float64 // Compression ratio (BlobSize / BlobUncompressedSize)
+	CompressionLevel     int     // Compression level used for this snapshot
+	UploadBytes          int64   // Total bytes uploaded during this snapshot
+	UploadDurationMs     int64   // Total milliseconds spent uploading to S3
}

// IsComplete returns true if the snapshot has completed
@@ -83,7 +111,7 @@ func (s *Snapshot) IsComplete() bool {
// SnapshotFile represents the mapping between snapshots and files
type SnapshotFile struct {
	SnapshotID string
-	FilePath   string
+	FileID     string
}

// SnapshotBlob represents the mapping between snapshots and blobs
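The Blob comment above pins down the pipeline (chunks -> zstd -> age -> SHA256 -> S3). A standalone sketch of the content-addressing step only, not part of the commit; compression and encryption are elided and the key layout is an assumption:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

func main() {
	// Stand-in for the blob bytes *after* zstd compression and age
	// encryption; the hash is taken over the final uploaded content.
	encryptedBlob := []byte("compressed+encrypted blob bytes")

	sum := sha256.Sum256(encryptedBlob)
	key := hex.EncodeToString(sum[:]) // object name in S3; any prefix is assumed
	fmt.Println("content-addressed key:", key)
}

Because the name is a function of the bytes, identical blobs dedupe for free in object storage, and a corrupted download is detectable by re-hashing.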
@@ -6,6 +6,9 @@ import (
	"fmt"
)

+// Repositories provides access to all database repositories.
+// It serves as a centralized access point for all database operations
+// and manages transaction coordination across repositories.
type Repositories struct {
	db    *DB
	Files *FileRepository
@@ -18,6 +21,8 @@ type Repositories struct {
	Uploads *UploadRepository
}

+// NewRepositories creates a new Repositories instance with all repository types.
+// Each repository shares the same database connection for coordinated transactions.
func NewRepositories(db *DB) *Repositories {
	return &Repositories{
		db: db,
@@ -32,17 +37,16 @@ func NewRepositories(db *DB) *Repositories {
	}
}

+// TxFunc is a function that executes within a database transaction.
+// The transaction is automatically committed if the function returns nil,
+// or rolled back if it returns an error.
type TxFunc func(ctx context.Context, tx *sql.Tx) error

+// WithTx executes a function within a write transaction.
+// SQLite handles its own locking internally, so no explicit locking is needed.
+// The transaction is automatically committed on success or rolled back on error.
+// This method should be used for all write operations to ensure atomicity.
func (r *Repositories) WithTx(ctx context.Context, fn TxFunc) error {
-	// Acquire write lock for the entire transaction
-	LogSQL("WithTx", "Acquiring write lock", "")
-	r.db.LockForWrite()
-	defer func() {
-		LogSQL("WithTx", "Releasing write lock", "")
-		r.db.UnlockWrite()
-	}()
-
	LogSQL("WithTx", "Beginning transaction", "")
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
@@ -71,6 +75,10 @@ func (r *Repositories) WithTx(ctx context.Context, fn TxFunc) error {
	return tx.Commit()
}

+// WithReadTx executes a function within a read-only transaction.
+// Read transactions can run concurrently with other read transactions
+// but will be blocked by write transactions. The transaction is
+// automatically committed on success or rolled back on error.
func (r *Repositories) WithReadTx(ctx context.Context, fn TxFunc) error {
	opts := &sql.TxOptions{
		ReadOnly: true,
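A hedged usage sketch of the WithTx contract documented above: return nil to commit, any error to roll back the whole batch. Not part of the commit; it reuses the package's own test helpers and types:

package database

import (
	"context"
	"database/sql"
	"testing"
	"time"
)

func TestWithTxUsageSketch(t *testing.T) {
	db, err := NewTestDB()
	if err != nil {
		t.Fatal(err)
	}
	defer func() { _ = db.Close() }()
	repos := NewRepositories(db)
	ctx := context.Background()
	now := time.Now().Truncate(time.Second)

	file := &File{Path: "/tx.txt", MTime: now, CTime: now, Size: 4, Mode: 0644, UID: 1000, GID: 1000}
	chunk := &Chunk{ChunkHash: "c1", SHA256: "s1", Size: 4}

	err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
		if err := repos.Files.Create(ctx, tx, file); err != nil {
			return err // any error here rolls the transaction back
		}
		if err := repos.Chunks.Create(ctx, tx, chunk); err != nil {
			return err
		}
		// Returning nil commits file, chunk, and mapping atomically.
		return repos.FileChunks.Create(ctx, tx, &FileChunk{FileID: file.ID, Idx: 0, ChunkHash: chunk.ChunkHash})
	})
	if err != nil {
		t.Fatal(err)
	}
}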
@@ -52,7 +52,7 @@ func TestRepositoriesTransaction(t *testing.T) {

	// Map chunks to file
	fc1 := &FileChunk{
-		Path:      file.Path,
+		FileID:    file.ID,
		Idx:       0,
		ChunkHash: chunk1.ChunkHash,
	}
@@ -61,7 +61,7 @@ func TestRepositoriesTransaction(t *testing.T) {
	}

	fc2 := &FileChunk{
-		Path:      file.Path,
+		FileID:    file.ID,
		Idx:       1,
		ChunkHash: chunk2.ChunkHash,
	}
@@ -116,7 +116,7 @@ func TestRepositoriesTransaction(t *testing.T) {
		t.Error("expected file after transaction")
	}

-	chunks, err := repos.FileChunks.GetByPath(ctx, "/test/tx_file.txt")
+	chunks, err := repos.FileChunks.GetByFile(ctx, "/test/tx_file.txt")
	if err != nil {
		t.Fatalf("failed to get file chunks: %v", err)
	}
@@ -218,7 +218,7 @@ func TestRepositoriesReadTransaction(t *testing.T) {
	var retrievedFile *File
	err = repos.WithReadTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
		var err error
-		retrievedFile, err = repos.Files.GetByPath(ctx, "/test/read_file.txt")
+		retrievedFile, err = repos.Files.GetByPathTx(ctx, tx, "/test/read_file.txt")
		if err != nil {
			return err
		}
internal/database/repository_comprehensive_test.go (new file, 876 lines)
@@ -0,0 +1,876 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestFileRepositoryUUIDGeneration tests that files get unique UUIDs
|
||||
func TestFileRepositoryUUIDGeneration(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewFileRepository(db)
|
||||
|
||||
// Create multiple files
|
||||
files := []*File{
|
||||
{
|
||||
Path: "/file1.txt",
|
||||
MTime: time.Now().Truncate(time.Second),
|
||||
CTime: time.Now().Truncate(time.Second),
|
||||
Size: 1024,
|
||||
Mode: 0644,
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
},
|
||||
{
|
||||
Path: "/file2.txt",
|
||||
MTime: time.Now().Truncate(time.Second),
|
||||
CTime: time.Now().Truncate(time.Second),
|
||||
Size: 2048,
|
||||
Mode: 0644,
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
},
|
||||
}
|
||||
|
||||
uuids := make(map[string]bool)
|
||||
for _, file := range files {
|
||||
err := repo.Create(ctx, nil, file)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create file: %v", err)
|
||||
}
|
||||
|
||||
// Check UUID was generated
|
||||
if file.ID == "" {
|
||||
t.Error("file ID was not generated")
|
||||
}
|
||||
|
||||
// Check UUID is unique
|
||||
if uuids[file.ID] {
|
||||
t.Errorf("duplicate UUID generated: %s", file.ID)
|
||||
}
|
||||
uuids[file.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileRepositoryGetByID tests retrieving files by UUID
|
||||
func TestFileRepositoryGetByID(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
repo := NewFileRepository(db)
|
||||
|
||||
// Create a file
|
||||
file := &File{
|
||||
Path: "/test.txt",
|
||||
MTime: time.Now().Truncate(time.Second),
|
||||
CTime: time.Now().Truncate(time.Second),
|
||||
Size: 1024,
|
||||
Mode: 0644,
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
}
|
||||
|
||||
err := repo.Create(ctx, nil, file)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create file: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve by ID
|
||||
retrieved, err := repo.GetByID(ctx, file.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get file by ID: %v", err)
|
||||
}
|
||||
|
||||
if retrieved.ID != file.ID {
|
||||
t.Errorf("ID mismatch: expected %s, got %s", file.ID, retrieved.ID)
|
||||
}
|
||||
if retrieved.Path != file.Path {
|
||||
t.Errorf("Path mismatch: expected %s, got %s", file.Path, retrieved.Path)
|
||||
}
|
||||
|
||||
// Test non-existent ID
|
||||
nonExistent, err := repo.GetByID(ctx, "non-existent-uuid")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID should not return error for non-existent ID: %v", err)
|
||||
}
|
||||
if nonExistent != nil {
|
||||
t.Error("expected nil for non-existent ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOrphanedFileCleanup tests the cleanup of orphaned files
func TestOrphanedFileCleanup(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create files
	file1 := &File{
		Path:  "/orphaned.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}
	file2 := &File{
		Path:  "/referenced.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  2048,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}

	err := repos.Files.Create(ctx, nil, file1)
	if err != nil {
		t.Fatalf("failed to create file1: %v", err)
	}
	err = repos.Files.Create(ctx, nil, file2)
	if err != nil {
		t.Fatalf("failed to create file2: %v", err)
	}

	// Create a snapshot and reference only file2
	snapshot := &Snapshot{
		ID:        "test-snapshot",
		Hostname:  "test-host",
		StartedAt: time.Now(),
	}
	err = repos.Snapshots.Create(ctx, nil, snapshot)
	if err != nil {
		t.Fatalf("failed to create snapshot: %v", err)
	}

	// Add file2 to snapshot
	err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file2.ID)
	if err != nil {
		t.Fatalf("failed to add file to snapshot: %v", err)
	}

	// Run orphaned cleanup
	err = repos.Files.DeleteOrphaned(ctx)
	if err != nil {
		t.Fatalf("failed to delete orphaned files: %v", err)
	}

	// Check that orphaned file is gone
	orphanedFile, err := repos.Files.GetByID(ctx, file1.ID)
	if err != nil {
		t.Fatalf("error getting file: %v", err)
	}
	if orphanedFile != nil {
		t.Error("orphaned file should have been deleted")
	}

	// Check that referenced file still exists
	referencedFile, err := repos.Files.GetByID(ctx, file2.ID)
	if err != nil {
		t.Fatalf("error getting file: %v", err)
	}
	if referencedFile == nil {
		t.Error("referenced file should not have been deleted")
	}
}

// TestOrphanedChunkCleanup tests the cleanup of orphaned chunks
func TestOrphanedChunkCleanup(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create chunks
	chunk1 := &Chunk{
		ChunkHash: "orphaned-chunk",
		SHA256:    "orphaned-chunk-sha",
		Size:      1024,
	}
	chunk2 := &Chunk{
		ChunkHash: "referenced-chunk",
		SHA256:    "referenced-chunk-sha",
		Size:      1024,
	}

	err := repos.Chunks.Create(ctx, nil, chunk1)
	if err != nil {
		t.Fatalf("failed to create chunk1: %v", err)
	}
	err = repos.Chunks.Create(ctx, nil, chunk2)
	if err != nil {
		t.Fatalf("failed to create chunk2: %v", err)
	}

	// Create a file and reference only chunk2
	file := &File{
		Path:  "/test.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}
	err = repos.Files.Create(ctx, nil, file)
	if err != nil {
		t.Fatalf("failed to create file: %v", err)
	}

	// Create file-chunk mapping only for chunk2
	fc := &FileChunk{
		FileID:    file.ID,
		Idx:       0,
		ChunkHash: chunk2.ChunkHash,
	}
	err = repos.FileChunks.Create(ctx, nil, fc)
	if err != nil {
		t.Fatalf("failed to create file chunk: %v", err)
	}

	// Run orphaned cleanup
	err = repos.Chunks.DeleteOrphaned(ctx)
	if err != nil {
		t.Fatalf("failed to delete orphaned chunks: %v", err)
	}

	// Check that orphaned chunk is gone
	orphanedChunk, err := repos.Chunks.GetByHash(ctx, chunk1.ChunkHash)
	if err != nil {
		t.Fatalf("error getting chunk: %v", err)
	}
	if orphanedChunk != nil {
		t.Error("orphaned chunk should have been deleted")
	}

	// Check that referenced chunk still exists
	referencedChunk, err := repos.Chunks.GetByHash(ctx, chunk2.ChunkHash)
	if err != nil {
		t.Fatalf("error getting chunk: %v", err)
	}
	if referencedChunk == nil {
		t.Error("referenced chunk should not have been deleted")
	}
}

// TestOrphanedBlobCleanup tests the cleanup of orphaned blobs
func TestOrphanedBlobCleanup(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create blobs
	blob1 := &Blob{
		ID:        "orphaned-blob-id",
		Hash:      "orphaned-blob",
		CreatedTS: time.Now().Truncate(time.Second),
	}
	blob2 := &Blob{
		ID:        "referenced-blob-id",
		Hash:      "referenced-blob",
		CreatedTS: time.Now().Truncate(time.Second),
	}

	err := repos.Blobs.Create(ctx, nil, blob1)
	if err != nil {
		t.Fatalf("failed to create blob1: %v", err)
	}
	err = repos.Blobs.Create(ctx, nil, blob2)
	if err != nil {
		t.Fatalf("failed to create blob2: %v", err)
	}

	// Create a snapshot and reference only blob2
	snapshot := &Snapshot{
		ID:        "test-snapshot",
		Hostname:  "test-host",
		StartedAt: time.Now(),
	}
	err = repos.Snapshots.Create(ctx, nil, snapshot)
	if err != nil {
		t.Fatalf("failed to create snapshot: %v", err)
	}

	// Add blob2 to snapshot
	err = repos.Snapshots.AddBlob(ctx, nil, snapshot.ID, blob2.ID, blob2.Hash)
	if err != nil {
		t.Fatalf("failed to add blob to snapshot: %v", err)
	}

	// Run orphaned cleanup
	err = repos.Blobs.DeleteOrphaned(ctx)
	if err != nil {
		t.Fatalf("failed to delete orphaned blobs: %v", err)
	}

	// Check that orphaned blob is gone
	orphanedBlob, err := repos.Blobs.GetByID(ctx, blob1.ID)
	if err != nil {
		t.Fatalf("error getting blob: %v", err)
	}
	if orphanedBlob != nil {
		t.Error("orphaned blob should have been deleted")
	}

	// Check that referenced blob still exists
	referencedBlob, err := repos.Blobs.GetByID(ctx, blob2.ID)
	if err != nil {
		t.Fatalf("error getting blob: %v", err)
	}
	if referencedBlob == nil {
		t.Error("referenced blob should not have been deleted")
	}
}

// TestFileChunkRepositoryWithUUIDs tests file-chunk relationships with UUIDs
func TestFileChunkRepositoryWithUUIDs(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create a file
	file := &File{
		Path:  "/test.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  3072,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}
	err := repos.Files.Create(ctx, nil, file)
	if err != nil {
		t.Fatalf("failed to create file: %v", err)
	}

	// Create chunks
	chunks := []string{"chunk1", "chunk2", "chunk3"}
	for i, chunkHash := range chunks {
		chunk := &Chunk{
			ChunkHash: chunkHash,
			SHA256:    fmt.Sprintf("sha-%s", chunkHash),
			Size:      1024,
		}
		err = repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Fatalf("failed to create chunk: %v", err)
		}

		// Create file-chunk mapping
		fc := &FileChunk{
			FileID:    file.ID,
			Idx:       i,
			ChunkHash: chunkHash,
		}
		err = repos.FileChunks.Create(ctx, nil, fc)
		if err != nil {
			t.Fatalf("failed to create file chunk: %v", err)
		}
	}

	// Test GetByFileID
	fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		t.Fatalf("failed to get file chunks: %v", err)
	}
	if len(fileChunks) != 3 {
		t.Errorf("expected 3 chunks, got %d", len(fileChunks))
	}

	// Test DeleteByFileID
	err = repos.FileChunks.DeleteByFileID(ctx, nil, file.ID)
	if err != nil {
		t.Fatalf("failed to delete file chunks: %v", err)
	}

	fileChunks, err = repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		t.Fatalf("failed to get file chunks after delete: %v", err)
	}
	if len(fileChunks) != 0 {
		t.Errorf("expected 0 chunks after delete, got %d", len(fileChunks))
	}
}

// TestChunkFileRepositoryWithUUIDs tests chunk-file relationships with UUIDs
func TestChunkFileRepositoryWithUUIDs(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create files
	file1 := &File{
		Path:  "/file1.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}
	file2 := &File{
		Path:  "/file2.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}

	err := repos.Files.Create(ctx, nil, file1)
	if err != nil {
		t.Fatalf("failed to create file1: %v", err)
	}
	err = repos.Files.Create(ctx, nil, file2)
	if err != nil {
		t.Fatalf("failed to create file2: %v", err)
	}

	// Create a chunk that appears in both files (deduplication)
	chunk := &Chunk{
		ChunkHash: "shared-chunk",
		SHA256:    "shared-chunk-sha",
		Size:      1024,
	}
	err = repos.Chunks.Create(ctx, nil, chunk)
	if err != nil {
		t.Fatalf("failed to create chunk: %v", err)
	}

	// Create chunk-file mappings
	cf1 := &ChunkFile{
		ChunkHash:  chunk.ChunkHash,
		FileID:     file1.ID,
		FileOffset: 0,
		Length:     1024,
	}
	cf2 := &ChunkFile{
		ChunkHash:  chunk.ChunkHash,
		FileID:     file2.ID,
		FileOffset: 512,
		Length:     1024,
	}

	err = repos.ChunkFiles.Create(ctx, nil, cf1)
	if err != nil {
		t.Fatalf("failed to create chunk file 1: %v", err)
	}
	err = repos.ChunkFiles.Create(ctx, nil, cf2)
	if err != nil {
		t.Fatalf("failed to create chunk file 2: %v", err)
	}

	// Test GetByChunkHash
	chunkFiles, err := repos.ChunkFiles.GetByChunkHash(ctx, chunk.ChunkHash)
	if err != nil {
		t.Fatalf("failed to get chunk files: %v", err)
	}
	if len(chunkFiles) != 2 {
		t.Errorf("expected 2 files for chunk, got %d", len(chunkFiles))
	}

	// Test GetByFileID
	chunkFiles, err = repos.ChunkFiles.GetByFileID(ctx, file1.ID)
	if err != nil {
		t.Fatalf("failed to get chunks by file ID: %v", err)
	}
	if len(chunkFiles) != 1 {
		t.Errorf("expected 1 chunk for file, got %d", len(chunkFiles))
	}
}

// TestSnapshotRepositoryExtendedFields tests snapshot with version and git revision
func TestSnapshotRepositoryExtendedFields(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repo := NewSnapshotRepository(db)

	// Create snapshot with extended fields
	snapshot := &Snapshot{
		ID:                   "test-20250722-120000Z",
		Hostname:             "test-host",
		VaultikVersion:       "0.0.1",
		VaultikGitRevision:   "abc123def456",
		StartedAt:            time.Now(),
		CompletedAt:          nil,
		FileCount:            100,
		ChunkCount:           200,
		BlobCount:            50,
		TotalSize:            1024 * 1024,
		BlobSize:             512 * 1024,
		BlobUncompressedSize: 1024 * 1024,
		CompressionLevel:     6,
		CompressionRatio:     2.0,
		UploadDurationMs:     5000,
	}

	err := repo.Create(ctx, nil, snapshot)
	if err != nil {
		t.Fatalf("failed to create snapshot: %v", err)
	}

	// Retrieve and verify
	retrieved, err := repo.GetByID(ctx, snapshot.ID)
	if err != nil {
		t.Fatalf("failed to get snapshot: %v", err)
	}

	if retrieved.VaultikVersion != snapshot.VaultikVersion {
		t.Errorf("version mismatch: expected %s, got %s", snapshot.VaultikVersion, retrieved.VaultikVersion)
	}
	if retrieved.VaultikGitRevision != snapshot.VaultikGitRevision {
		t.Errorf("git revision mismatch: expected %s, got %s", snapshot.VaultikGitRevision, retrieved.VaultikGitRevision)
	}
	if retrieved.CompressionLevel != snapshot.CompressionLevel {
		t.Errorf("compression level mismatch: expected %d, got %d", snapshot.CompressionLevel, retrieved.CompressionLevel)
	}
	if retrieved.BlobUncompressedSize != snapshot.BlobUncompressedSize {
		t.Errorf("uncompressed size mismatch: expected %d, got %d", snapshot.BlobUncompressedSize, retrieved.BlobUncompressedSize)
	}
	if retrieved.UploadDurationMs != snapshot.UploadDurationMs {
		t.Errorf("upload duration mismatch: expected %d, got %d", snapshot.UploadDurationMs, retrieved.UploadDurationMs)
	}
}

// TestComplexOrphanedDataScenario tests a complex scenario with multiple relationships
func TestComplexOrphanedDataScenario(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create snapshots
	snapshot1 := &Snapshot{
		ID:        "snapshot1",
		Hostname:  "host1",
		StartedAt: time.Now(),
	}
	snapshot2 := &Snapshot{
		ID:        "snapshot2",
		Hostname:  "host1",
		StartedAt: time.Now(),
	}

	err := repos.Snapshots.Create(ctx, nil, snapshot1)
	if err != nil {
		t.Fatalf("failed to create snapshot1: %v", err)
	}
	err = repos.Snapshots.Create(ctx, nil, snapshot2)
	if err != nil {
		t.Fatalf("failed to create snapshot2: %v", err)
	}

	// Create files
	files := make([]*File, 3)
	for i := range files {
		files[i] = &File{
			Path:  fmt.Sprintf("/file%d.txt", i),
			MTime: time.Now().Truncate(time.Second),
			CTime: time.Now().Truncate(time.Second),
			Size:  1024,
			Mode:  0644,
			UID:   1000,
			GID:   1000,
		}
		err = repos.Files.Create(ctx, nil, files[i])
		if err != nil {
			t.Fatalf("failed to create file%d: %v", i, err)
		}
	}

	// Add files to snapshots
	// Snapshot1: file0, file1
	// Snapshot2: file1, file2
	// file0: only in snapshot1
	// file1: in both snapshots
	// file2: only in snapshot2
	err = repos.Snapshots.AddFileByID(ctx, nil, snapshot1.ID, files[0].ID)
	if err != nil {
		t.Fatal(err)
	}
	err = repos.Snapshots.AddFileByID(ctx, nil, snapshot1.ID, files[1].ID)
	if err != nil {
		t.Fatal(err)
	}
	err = repos.Snapshots.AddFileByID(ctx, nil, snapshot2.ID, files[1].ID)
	if err != nil {
		t.Fatal(err)
	}
	err = repos.Snapshots.AddFileByID(ctx, nil, snapshot2.ID, files[2].ID)
	if err != nil {
		t.Fatal(err)
	}

	// Delete snapshot1
	err = repos.Snapshots.DeleteSnapshotFiles(ctx, snapshot1.ID)
	if err != nil {
		t.Fatal(err)
	}
	err = repos.Snapshots.Delete(ctx, snapshot1.ID)
	if err != nil {
		t.Fatal(err)
	}

	// Run orphaned cleanup
	err = repos.Files.DeleteOrphaned(ctx)
	if err != nil {
		t.Fatal(err)
	}

	// Check results
	// file0 should be deleted (only in deleted snapshot)
	file0, err := repos.Files.GetByID(ctx, files[0].ID)
	if err != nil {
		t.Fatalf("error getting file0: %v", err)
	}
	if file0 != nil {
		t.Error("file0 should have been deleted")
	}

	// file1 should exist (still in snapshot2)
	file1, err := repos.Files.GetByID(ctx, files[1].ID)
	if err != nil {
		t.Fatalf("error getting file1: %v", err)
	}
	if file1 == nil {
		t.Error("file1 should still exist")
	}

	// file2 should exist (still in snapshot2)
	file2, err := repos.Files.GetByID(ctx, files[2].ID)
	if err != nil {
		t.Fatalf("error getting file2: %v", err)
	}
	if file2 == nil {
		t.Error("file2 should still exist")
	}
}

// TestCascadeDelete tests that cascade deletes work properly
func TestCascadeDelete(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create a file
	file := &File{
		Path:  "/cascade-test.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}
	err := repos.Files.Create(ctx, nil, file)
	if err != nil {
		t.Fatalf("failed to create file: %v", err)
	}

	// Create chunks and file-chunk mappings
	for i := 0; i < 3; i++ {
		chunk := &Chunk{
			ChunkHash: fmt.Sprintf("cascade-chunk-%d", i),
			SHA256:    fmt.Sprintf("cascade-sha-%d", i),
			Size:      1024,
		}
		err = repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Fatalf("failed to create chunk: %v", err)
		}

		fc := &FileChunk{
			FileID:    file.ID,
			Idx:       i,
			ChunkHash: chunk.ChunkHash,
		}
		err = repos.FileChunks.Create(ctx, nil, fc)
		if err != nil {
			t.Fatalf("failed to create file chunk: %v", err)
		}
	}

	// Verify file chunks exist
	fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		t.Fatal(err)
	}
	if len(fileChunks) != 3 {
		t.Errorf("expected 3 file chunks, got %d", len(fileChunks))
	}

	// Delete the file
	err = repos.Files.DeleteByID(ctx, nil, file.ID)
	if err != nil {
		t.Fatalf("failed to delete file: %v", err)
	}

	// Verify file chunks were cascade deleted
	fileChunks, err = repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		t.Fatal(err)
	}
	if len(fileChunks) != 0 {
		t.Errorf("expected 0 file chunks after cascade delete, got %d", len(fileChunks))
	}
}

// TestTransactionIsolation tests that transactions properly isolate changes
func TestTransactionIsolation(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Start a transaction
	err := repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
		// Create a file within the transaction
		file := &File{
			Path:  "/tx-test.txt",
			MTime: time.Now().Truncate(time.Second),
			CTime: time.Now().Truncate(time.Second),
			Size:  1024,
			Mode:  0644,
			UID:   1000,
			GID:   1000,
		}
		err := repos.Files.Create(ctx, tx, file)
		if err != nil {
			return err
		}

		// Within the same transaction, we should be able to query it.
		// Note: this would require modifying GetByPath to accept a tx parameter,
		// so for now we only test that rollback works.

		// Return an error to trigger rollback
		return fmt.Errorf("intentional rollback")
	})

	if err == nil {
		t.Fatal("expected error from transaction")
	}

	// Verify the file was not created (transaction rolled back)
	files, err := repos.Files.ListByPrefix(ctx, "/tx-test")
	if err != nil {
		t.Fatal(err)
	}
	if len(files) != 0 {
		t.Error("file should not exist after rollback")
	}
}

// TestConcurrentOrphanedCleanup tests that concurrent cleanup operations don't interfere
func TestConcurrentOrphanedCleanup(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Set a 5-second busy timeout to handle concurrent operations
	if _, err := db.conn.Exec("PRAGMA busy_timeout = 5000"); err != nil {
		t.Fatalf("failed to set busy timeout: %v", err)
	}

	// Create a snapshot
	snapshot := &Snapshot{
		ID:        "concurrent-test",
		Hostname:  "test-host",
		StartedAt: time.Now(),
	}
	err := repos.Snapshots.Create(ctx, nil, snapshot)
	if err != nil {
		t.Fatal(err)
	}

	// Create many files, some orphaned
	for i := 0; i < 20; i++ {
		file := &File{
			Path:  fmt.Sprintf("/concurrent-%d.txt", i),
			MTime: time.Now().Truncate(time.Second),
			CTime: time.Now().Truncate(time.Second),
			Size:  1024,
			Mode:  0644,
			UID:   1000,
			GID:   1000,
		}
		err = repos.Files.Create(ctx, nil, file)
		if err != nil {
			t.Fatal(err)
		}

		// Add even-numbered files to snapshot
		if i%2 == 0 {
			err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file.ID)
			if err != nil {
				t.Fatal(err)
			}
		}
	}

	// Run multiple cleanup operations concurrently.
	// SQLite serializes concurrent writes; the busy timeout set above lets
	// all three complete instead of failing with SQLITE_BUSY.
	done := make(chan error, 3)
	for i := 0; i < 3; i++ {
		go func() {
			done <- repos.Files.DeleteOrphaned(ctx)
		}()
	}

	// Wait for all to complete
	for i := 0; i < 3; i++ {
		err := <-done
		if err != nil {
			t.Errorf("cleanup %d failed: %v", i, err)
		}
	}

	// Verify correct files were deleted
	files, err := repos.Files.ListByPrefix(ctx, "/concurrent-")
	if err != nil {
		t.Fatal(err)
	}

	// Should have 10 files remaining (even numbered)
	if len(files) != 10 {
		t.Errorf("expected 10 files remaining, got %d", len(files))
	}

	// Verify all remaining files are even-numbered
	for _, file := range files {
		var num int
		_, err := fmt.Sscanf(file.Path, "/concurrent-%d.txt", &num)
		if err != nil {
			t.Logf("failed to parse file number from %s: %v", file.Path, err)
			continue // num would be zero here and wrongly pass the even check
		}
		if num%2 != 0 {
			t.Errorf("odd-numbered file %s should have been deleted", file.Path)
		}
	}
}
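
All of the tests above lean on a setupTestDB helper defined elsewhere in this package. As a reading aid only, here is a minimal sketch of what such a helper has to provide for these tests to pass; the function name, driver choice, and schema wiring here are assumptions, not the actual implementation:

package database

import (
	"database/sql"
	"testing"

	_ "github.com/mattn/go-sqlite3" // driver choice is an assumption
)

// setupTestDBSketch is a hypothetical stand-in for the real setupTestDB.
// Whatever its exact shape, the tests above depend on it (a) applying
// schema.sql and (b) enabling SQLite foreign key enforcement, which the
// cascade-delete and FOREIGN KEY error tests rely on.
func setupTestDBSketch(t *testing.T) (*sql.DB, func()) {
	t.Helper()
	conn, err := sql.Open("sqlite3", ":memory:") // in-memory DB isolates each test
	if err != nil {
		t.Fatalf("opening test database: %v", err)
	}
	if _, err := conn.Exec("PRAGMA foreign_keys = ON"); err != nil {
		t.Fatalf("enabling foreign keys: %v", err)
	}
	// Applying the schema would go here, e.g. conn.Exec(schemaSQL), with
	// schemaSQL holding the contents of schema.sql (assumed embedded).
	return conn, func() { _ = conn.Close() }
}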
165  internal/database/repository_debug_test.go  Normal file
@@ -0,0 +1,165 @@
package database

import (
	"context"
	"testing"
	"time"
)

// TestOrphanedFileCleanupDebug tests orphaned file cleanup with debug output
func TestOrphanedFileCleanupDebug(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create files
	file1 := &File{
		Path:  "/orphaned.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}
	file2 := &File{
		Path:  "/referenced.txt",
		MTime: time.Now().Truncate(time.Second),
		CTime: time.Now().Truncate(time.Second),
		Size:  2048,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}

	err := repos.Files.Create(ctx, nil, file1)
	if err != nil {
		t.Fatalf("failed to create file1: %v", err)
	}
	t.Logf("Created file1 with ID: %s", file1.ID)

	err = repos.Files.Create(ctx, nil, file2)
	if err != nil {
		t.Fatalf("failed to create file2: %v", err)
	}
	t.Logf("Created file2 with ID: %s", file2.ID)

	// Create a snapshot and reference only file2
	snapshot := &Snapshot{
		ID:        "test-snapshot",
		Hostname:  "test-host",
		StartedAt: time.Now(),
	}
	err = repos.Snapshots.Create(ctx, nil, snapshot)
	if err != nil {
		t.Fatalf("failed to create snapshot: %v", err)
	}
	t.Logf("Created snapshot: %s", snapshot.ID)

	// Check snapshot_files before adding
	var count int
	err = db.conn.QueryRow("SELECT COUNT(*) FROM snapshot_files").Scan(&count)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("snapshot_files count before add: %d", count)

	// Add file2 to snapshot
	err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file2.ID)
	if err != nil {
		t.Fatalf("failed to add file to snapshot: %v", err)
	}
	t.Logf("Added file2 to snapshot")

	// Check snapshot_files after adding
	err = db.conn.QueryRow("SELECT COUNT(*) FROM snapshot_files").Scan(&count)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("snapshot_files count after add: %d", count)

	// Check which files are referenced
	rows, err := db.conn.Query("SELECT file_id FROM snapshot_files")
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := rows.Close(); err != nil {
			t.Logf("failed to close rows: %v", err)
		}
	}()
	t.Log("Files in snapshot_files:")
	for rows.Next() {
		var fileID string
		if err := rows.Scan(&fileID); err != nil {
			t.Fatal(err)
		}
		t.Logf("  - %s", fileID)
	}

	// Check files before cleanup
	err = db.conn.QueryRow("SELECT COUNT(*) FROM files").Scan(&count)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Files count before cleanup: %d", count)

	// Run orphaned cleanup
	err = repos.Files.DeleteOrphaned(ctx)
	if err != nil {
		t.Fatalf("failed to delete orphaned files: %v", err)
	}
	t.Log("Ran orphaned cleanup")

	// Check files after cleanup
	err = db.conn.QueryRow("SELECT COUNT(*) FROM files").Scan(&count)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Files count after cleanup: %d", count)

	// List remaining files
	files, err := repos.Files.ListByPrefix(ctx, "/")
	if err != nil {
		t.Fatal(err)
	}
	t.Log("Remaining files:")
	for _, f := range files {
		t.Logf("  - ID: %s, Path: %s", f.ID, f.Path)
	}

	// Check that orphaned file is gone
	orphanedFile, err := repos.Files.GetByID(ctx, file1.ID)
	if err != nil {
		t.Fatalf("error getting file: %v", err)
	}
	if orphanedFile != nil {
		t.Error("orphaned file should have been deleted")
		// Let's check why it wasn't deleted
		var exists bool
		err = db.conn.QueryRow(`
			SELECT EXISTS(
				SELECT 1 FROM snapshot_files
				WHERE file_id = ?
			)`, file1.ID).Scan(&exists)
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("File1 exists in snapshot_files: %v", exists)
	} else {
		t.Log("Orphaned file was correctly deleted")
	}

	// Check that referenced file still exists
	referencedFile, err := repos.Files.GetByID(ctx, file2.ID)
	if err != nil {
		t.Fatalf("error getting file: %v", err)
	}
	if referencedFile == nil {
		t.Error("referenced file should not have been deleted")
	} else {
		t.Log("Referenced file correctly remains")
	}
}
543  internal/database/repository_edge_cases_test.go  Normal file
@@ -0,0 +1,543 @@
package database

import (
	"context"
	"fmt"
	"strings"
	"testing"
	"time"
)

// TestFileRepositoryEdgeCases tests edge cases for file repository
func TestFileRepositoryEdgeCases(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repo := NewFileRepository(db)

	tests := []struct {
		name    string
		file    *File
		wantErr bool
		errMsg  string
	}{
		{
			name: "empty path",
			file: &File{
				Path:  "",
				MTime: time.Now(),
				CTime: time.Now(),
				Size:  1024,
				Mode:  0644,
				UID:   1000,
				GID:   1000,
			},
			wantErr: false, // Empty strings are allowed; only NULL is rejected
		},
		{
			name: "very long path",
			file: &File{
				Path:  "/" + strings.Repeat("a", 4096),
				MTime: time.Now(),
				CTime: time.Now(),
				Size:  1024,
				Mode:  0644,
				UID:   1000,
				GID:   1000,
			},
			wantErr: false,
		},
		{
			name: "path with special characters",
			file: &File{
				Path:  "/test/file with spaces and 特殊文字.txt",
				MTime: time.Now(),
				CTime: time.Now(),
				Size:  1024,
				Mode:  0644,
				UID:   1000,
				GID:   1000,
			},
			wantErr: false,
		},
		{
			name: "zero size file",
			file: &File{
				Path:  "/empty.txt",
				MTime: time.Now(),
				CTime: time.Now(),
				Size:  0,
				Mode:  0644,
				UID:   1000,
				GID:   1000,
			},
			wantErr: false,
		},
		{
			name: "symlink with target",
			file: &File{
				Path:       "/link",
				MTime:      time.Now(),
				CTime:      time.Now(),
				Size:       0,
				Mode:       0777 | 0120000, // symlink mode
				UID:        1000,
				GID:        1000,
				LinkTarget: "/target",
			},
			wantErr: false,
		},
	}

	for i, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Add a unique suffix to paths to avoid UNIQUE constraint violations
			if tt.file.Path != "" {
				tt.file.Path = fmt.Sprintf("%s_%d_%d", tt.file.Path, i, time.Now().UnixNano())
			}

			err := repo.Create(ctx, nil, tt.file)
			if (err != nil) != tt.wantErr {
				t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
			}
			if err != nil && tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
				t.Errorf("Create() error = %v, want error containing %q", err, tt.errMsg)
			}
		})
	}
}

// TestDuplicateHandling tests handling of duplicate entries
func TestDuplicateHandling(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Test duplicate file paths - Create uses UPSERT logic
	t.Run("duplicate file paths", func(t *testing.T) {
		file1 := &File{
			Path:  "/duplicate.txt",
			MTime: time.Now(),
			CTime: time.Now(),
			Size:  1024,
			Mode:  0644,
			UID:   1000,
			GID:   1000,
		}
		file2 := &File{
			Path:  "/duplicate.txt", // Same path
			MTime: time.Now().Add(time.Hour),
			CTime: time.Now().Add(time.Hour),
			Size:  2048,
			Mode:  0644,
			UID:   1000,
			GID:   1000,
		}

		err := repos.Files.Create(ctx, nil, file1)
		if err != nil {
			t.Fatalf("failed to create file1: %v", err)
		}
		originalID := file1.ID

		// Create with same path should update the existing record (UPSERT behavior)
		err = repos.Files.Create(ctx, nil, file2)
		if err != nil {
			t.Fatalf("failed to create file2: %v", err)
		}

		// Verify the file was updated, not duplicated
		retrievedFile, err := repos.Files.GetByPath(ctx, "/duplicate.txt")
		if err != nil {
			t.Fatalf("failed to retrieve file: %v", err)
		}

		// The file should have been updated with file2's data
		if retrievedFile.Size != 2048 {
			t.Errorf("expected size 2048, got %d", retrievedFile.Size)
		}

		// ID might be different due to the UPSERT
		if retrievedFile.ID != file2.ID {
			t.Logf("File ID changed from %s to %s during upsert", originalID, retrievedFile.ID)
		}
	})

	// Test duplicate chunk hashes
	t.Run("duplicate chunk hashes", func(t *testing.T) {
		chunk := &Chunk{
			ChunkHash: "duplicate-chunk",
			SHA256:    "duplicate-sha",
			Size:      1024,
		}

		err := repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Fatalf("failed to create chunk: %v", err)
		}

		// Creating the same chunk again should be idempotent (ON CONFLICT DO NOTHING)
		err = repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Errorf("duplicate chunk creation should be idempotent, got error: %v", err)
		}
	})

	// Test duplicate file-chunk mappings
	t.Run("duplicate file-chunk mappings", func(t *testing.T) {
		file := &File{
			Path:  "/test-dup-fc.txt",
			MTime: time.Now(),
			CTime: time.Now(),
			Size:  1024,
			Mode:  0644,
			UID:   1000,
			GID:   1000,
		}
		err := repos.Files.Create(ctx, nil, file)
		if err != nil {
			t.Fatal(err)
		}

		chunk := &Chunk{
			ChunkHash: "test-chunk-dup",
			SHA256:    "test-sha-dup",
			Size:      1024,
		}
		err = repos.Chunks.Create(ctx, nil, chunk)
		if err != nil {
			t.Fatal(err)
		}

		fc := &FileChunk{
			FileID:    file.ID,
			Idx:       0,
			ChunkHash: chunk.ChunkHash,
		}

		err = repos.FileChunks.Create(ctx, nil, fc)
		if err != nil {
			t.Fatal(err)
		}

		// Creating the same mapping again should be idempotent
		err = repos.FileChunks.Create(ctx, nil, fc)
		if err != nil {
			t.Errorf("file-chunk creation should be idempotent, got error: %v", err)
		}
	})
}

// TestNullHandling tests handling of NULL values
func TestNullHandling(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Test file with no link target
	t.Run("file without link target", func(t *testing.T) {
		file := &File{
			Path:       "/regular.txt",
			MTime:      time.Now(),
			CTime:      time.Now(),
			Size:       1024,
			Mode:       0644,
			UID:        1000,
			GID:        1000,
			LinkTarget: "", // Should be stored as NULL
		}

		err := repos.Files.Create(ctx, nil, file)
		if err != nil {
			t.Fatal(err)
		}

		retrieved, err := repos.Files.GetByID(ctx, file.ID)
		if err != nil {
			t.Fatal(err)
		}

		if retrieved.LinkTarget != "" {
			t.Errorf("expected empty link target, got %q", retrieved.LinkTarget)
		}
	})

	// Test snapshot with NULL completed_at
	t.Run("incomplete snapshot", func(t *testing.T) {
		snapshot := &Snapshot{
			ID:          "incomplete-test",
			Hostname:    "test-host",
			StartedAt:   time.Now(),
			CompletedAt: nil, // Should remain NULL until completed
		}

		err := repos.Snapshots.Create(ctx, nil, snapshot)
		if err != nil {
			t.Fatal(err)
		}

		retrieved, err := repos.Snapshots.GetByID(ctx, snapshot.ID)
		if err != nil {
			t.Fatal(err)
		}

		if retrieved.CompletedAt != nil {
			t.Error("expected nil CompletedAt for incomplete snapshot")
		}
	})

	// Test blob with NULL uploaded_ts
	t.Run("blob not uploaded", func(t *testing.T) {
		blob := &Blob{
			ID:         "not-uploaded",
			Hash:       "test-hash",
			CreatedTS:  time.Now(),
			UploadedTS: nil, // Not uploaded yet
		}

		err := repos.Blobs.Create(ctx, nil, blob)
		if err != nil {
			t.Fatal(err)
		}

		retrieved, err := repos.Blobs.GetByID(ctx, blob.ID)
		if err != nil {
			t.Fatal(err)
		}

		if retrieved.UploadedTS != nil {
			t.Error("expected nil UploadedTS for non-uploaded blob")
		}
	})
}

// TestLargeDatasets tests operations with large amounts of data
func TestLargeDatasets(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping large dataset test in short mode")
	}

	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create a snapshot
	snapshot := &Snapshot{
		ID:        "large-dataset-test",
		Hostname:  "test-host",
		StartedAt: time.Now(),
	}
	err := repos.Snapshots.Create(ctx, nil, snapshot)
	if err != nil {
		t.Fatal(err)
	}

	// Create many files
	const fileCount = 1000
	fileIDs := make([]string, fileCount)

	t.Run("create many files", func(t *testing.T) {
		start := time.Now()
		for i := 0; i < fileCount; i++ {
			file := &File{
				Path:  fmt.Sprintf("/large/file%05d.txt", i),
				MTime: time.Now(),
				CTime: time.Now(),
				Size:  int64(i * 1024),
				Mode:  0644,
				UID:   uint32(1000 + (i % 10)),
				GID:   uint32(1000 + (i % 10)),
			}
			err := repos.Files.Create(ctx, nil, file)
			if err != nil {
				t.Fatalf("failed to create file %d: %v", i, err)
			}
			fileIDs[i] = file.ID

			// Add half to snapshot
			if i%2 == 0 {
				err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file.ID)
				if err != nil {
					t.Fatal(err)
				}
			}
		}
		t.Logf("Created %d files in %v", fileCount, time.Since(start))
	})

	// Test ListByPrefix performance
	t.Run("list by prefix performance", func(t *testing.T) {
		start := time.Now()
		files, err := repos.Files.ListByPrefix(ctx, "/large/")
		if err != nil {
			t.Fatal(err)
		}
		if len(files) != fileCount {
			t.Errorf("expected %d files, got %d", fileCount, len(files))
		}
		t.Logf("Listed %d files in %v", len(files), time.Since(start))
	})

	// Test orphaned cleanup performance
	t.Run("orphaned cleanup performance", func(t *testing.T) {
		start := time.Now()
		err := repos.Files.DeleteOrphaned(ctx)
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("Cleaned up orphaned files in %v", time.Since(start))

		// Verify the correct number remain
		files, err := repos.Files.ListByPrefix(ctx, "/large/")
		if err != nil {
			t.Fatal(err)
		}
		if len(files) != fileCount/2 {
			t.Errorf("expected %d files after cleanup, got %d", fileCount/2, len(files))
		}
	})
}

// TestErrorPropagation tests that errors are properly propagated
func TestErrorPropagation(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Test GetByID with non-existent ID
	t.Run("GetByID non-existent", func(t *testing.T) {
		file, err := repos.Files.GetByID(ctx, "non-existent-uuid")
		if err != nil {
			t.Errorf("GetByID should not return error for non-existent ID, got: %v", err)
		}
		if file != nil {
			t.Error("expected nil file for non-existent ID")
		}
	})

	// Test GetByPath with non-existent path
	t.Run("GetByPath non-existent", func(t *testing.T) {
		file, err := repos.Files.GetByPath(ctx, "/non/existent/path.txt")
		if err != nil {
			t.Errorf("GetByPath should not return error for non-existent path, got: %v", err)
		}
		if file != nil {
			t.Error("expected nil file for non-existent path")
		}
	})

	// Test invalid foreign key reference
	t.Run("invalid foreign key", func(t *testing.T) {
		fc := &FileChunk{
			FileID:    "non-existent-file-id",
			Idx:       0,
			ChunkHash: "some-chunk",
		}
		err := repos.FileChunks.Create(ctx, nil, fc)
		if err == nil {
			t.Fatal("expected error for invalid foreign key")
		}
		if !strings.Contains(err.Error(), "FOREIGN KEY") {
			t.Errorf("expected foreign key error, got: %v", err)
		}
	})
}

// TestQueryInjection tests that the system is safe from SQL injection
func TestQueryInjection(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Test various injection attempts
	injectionTests := []string{
		"'; DROP TABLE files; --",
		"' OR '1'='1",
		"'; DELETE FROM files WHERE '1'='1'; --",
		`test'); DROP TABLE files; --`,
	}

	for _, injection := range injectionTests {
		t.Run("injection attempt", func(t *testing.T) {
			// Try injection in file path
			file := &File{
				Path:  injection,
				MTime: time.Now(),
				CTime: time.Now(),
				Size:  1024,
				Mode:  0644,
				UID:   1000,
				GID:   1000,
			}
			_ = repos.Files.Create(ctx, nil, file)
			// Should either succeed (treating as normal string) or fail with constraint,
			// but should NOT execute the injected SQL

			// Verify tables still exist
			var count int
			err := db.conn.QueryRow("SELECT COUNT(*) FROM files").Scan(&count)
			if err != nil {
				t.Fatal("files table was damaged by injection")
			}
		})
	}
}

// TestTimezoneHandling tests that times are properly handled in UTC
func TestTimezoneHandling(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	repos := NewRepositories(db)

	// Create file with specific timezone
	loc, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Skip("timezone not available")
	}

	// Use Truncate to remove sub-second precision since we store as Unix timestamps
	nyTime := time.Now().In(loc).Truncate(time.Second)
	file := &File{
		Path:  "/timezone-test.txt",
		MTime: nyTime,
		CTime: nyTime,
		Size:  1024,
		Mode:  0644,
		UID:   1000,
		GID:   1000,
	}

	err = repos.Files.Create(ctx, nil, file)
	if err != nil {
		t.Fatal(err)
	}

	// Retrieve and verify times are in UTC
	retrieved, err := repos.Files.GetByID(ctx, file.ID)
	if err != nil {
		t.Fatal(err)
	}

	// Check that times are equivalent (same instant)
	if !retrieved.MTime.Equal(nyTime) {
		t.Error("time was not preserved correctly")
	}

	// Check that retrieved time is in UTC
	if retrieved.MTime.Location() != time.UTC {
		t.Error("retrieved time is not in UTC")
	}
}
113  internal/database/schema.sql  Normal file
@@ -0,0 +1,113 @@
-- Vaultik Database Schema
-- Note: This database does not support migrations. If the schema changes,
-- delete the local database and perform a full backup to recreate it.

-- Files table: stores metadata about files in the filesystem
CREATE TABLE IF NOT EXISTS files (
    id TEXT PRIMARY KEY, -- UUID
    path TEXT NOT NULL UNIQUE,
    mtime INTEGER NOT NULL,
    ctime INTEGER NOT NULL,
    size INTEGER NOT NULL,
    mode INTEGER NOT NULL,
    uid INTEGER NOT NULL,
    gid INTEGER NOT NULL,
    link_target TEXT
);

-- Create index on path for efficient lookups
CREATE INDEX IF NOT EXISTS idx_files_path ON files(path);

-- File chunks table: maps files to their constituent chunks
CREATE TABLE IF NOT EXISTS file_chunks (
    file_id TEXT NOT NULL,
    idx INTEGER NOT NULL,
    chunk_hash TEXT NOT NULL,
    PRIMARY KEY (file_id, idx),
    FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
);

-- Chunks table: stores unique content-defined chunks
CREATE TABLE IF NOT EXISTS chunks (
    chunk_hash TEXT PRIMARY KEY,
    sha256 TEXT NOT NULL,
    size INTEGER NOT NULL
);

-- Blobs table: stores packed, compressed, and encrypted blob information
CREATE TABLE IF NOT EXISTS blobs (
    id TEXT PRIMARY KEY,
    blob_hash TEXT UNIQUE,
    created_ts INTEGER NOT NULL,
    finished_ts INTEGER,
    uncompressed_size INTEGER NOT NULL DEFAULT 0,
    compressed_size INTEGER NOT NULL DEFAULT 0,
    uploaded_ts INTEGER
);

-- Blob chunks table: maps chunks to the blobs that contain them
CREATE TABLE IF NOT EXISTS blob_chunks (
    blob_id TEXT NOT NULL,
    chunk_hash TEXT NOT NULL,
    offset INTEGER NOT NULL,
    length INTEGER NOT NULL,
    PRIMARY KEY (blob_id, chunk_hash),
    FOREIGN KEY (blob_id) REFERENCES blobs(id)
);

-- Chunk files table: reverse mapping of chunks to files
CREATE TABLE IF NOT EXISTS chunk_files (
    chunk_hash TEXT NOT NULL,
    file_id TEXT NOT NULL,
    file_offset INTEGER NOT NULL,
    length INTEGER NOT NULL,
    PRIMARY KEY (chunk_hash, file_id),
    FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
);

-- Snapshots table: tracks backup snapshots
CREATE TABLE IF NOT EXISTS snapshots (
    id TEXT PRIMARY KEY,
    hostname TEXT NOT NULL,
    vaultik_version TEXT NOT NULL,
    vaultik_git_revision TEXT NOT NULL,
    started_at INTEGER NOT NULL,
    completed_at INTEGER,
    file_count INTEGER NOT NULL DEFAULT 0,
    chunk_count INTEGER NOT NULL DEFAULT 0,
    blob_count INTEGER NOT NULL DEFAULT 0,
    total_size INTEGER NOT NULL DEFAULT 0,
    blob_size INTEGER NOT NULL DEFAULT 0,
    blob_uncompressed_size INTEGER NOT NULL DEFAULT 0,
    compression_ratio REAL NOT NULL DEFAULT 1.0,
    compression_level INTEGER NOT NULL DEFAULT 3,
    upload_bytes INTEGER NOT NULL DEFAULT 0,
    upload_duration_ms INTEGER NOT NULL DEFAULT 0
);

-- Snapshot files table: maps snapshots to files
CREATE TABLE IF NOT EXISTS snapshot_files (
    snapshot_id TEXT NOT NULL,
    file_id TEXT NOT NULL,
    PRIMARY KEY (snapshot_id, file_id),
    FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
    FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
);

-- Snapshot blobs table: maps snapshots to blobs
CREATE TABLE IF NOT EXISTS snapshot_blobs (
    snapshot_id TEXT NOT NULL,
    blob_id TEXT NOT NULL,
    blob_hash TEXT NOT NULL,
    PRIMARY KEY (snapshot_id, blob_id),
    FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
    FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE
);

-- Uploads table: tracks blob upload metrics
CREATE TABLE IF NOT EXISTS uploads (
    blob_hash TEXT PRIMARY KEY,
    uploaded_at INTEGER NOT NULL,
    size INTEGER NOT NULL,
    duration_ms INTEGER NOT NULL
);
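
Given this schema, the DeleteOrphaned behavior exercised by the repository tests reduces to an anti-join delete on snapshot_files, with ON DELETE CASCADE sweeping up the dependent file_chunks and chunk_files rows. A minimal sketch, assuming direct access to the connection (the helper name and signature are illustrative, and the actual repository methods may be structured differently):

package database

import (
	"context"
	"database/sql"
)

// deleteOrphanedFilesSketch shows the shape of the cleanup query: a file
// is orphaned when no snapshot_files row references its id. Deleting it
// cascades to file_chunks and chunk_files via the schema's foreign keys.
func deleteOrphanedFilesSketch(ctx context.Context, conn *sql.DB) error {
	_, err := conn.ExecContext(ctx, `
		DELETE FROM files
		WHERE id NOT IN (SELECT file_id FROM snapshot_files)
	`)
	return err
}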
@@ -17,8 +17,10 @@ func NewSnapshotRepository(db *DB) *SnapshotRepository {

func (r *SnapshotRepository) Create(ctx context.Context, tx *sql.Tx, snapshot *Snapshot) error {
	query := `
		INSERT INTO snapshots (id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
		INSERT INTO snapshots (id, hostname, vaultik_version, vaultik_git_revision, started_at, completed_at,
			file_count, chunk_count, blob_count, total_size, blob_size, blob_uncompressed_size,
			compression_ratio, compression_level, upload_bytes, upload_duration_ms)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`

	var completedAt *int64
@@ -29,11 +31,13 @@ func (r *SnapshotRepository) Create(ctx context.Context, tx *sql.Tx, snapshot *S

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.StartedAt.Unix(),
			completedAt, snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.CompressionRatio)
		_, err = tx.ExecContext(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.VaultikGitRevision, snapshot.StartedAt.Unix(),
			completedAt, snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.BlobUncompressedSize,
			snapshot.CompressionRatio, snapshot.CompressionLevel, snapshot.UploadBytes, snapshot.UploadDurationMs)
	} else {
		_, err = r.db.ExecWithLock(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.StartedAt.Unix(),
			completedAt, snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.CompressionRatio)
		_, err = r.db.ExecWithLog(ctx, query, snapshot.ID, snapshot.Hostname, snapshot.VaultikVersion, snapshot.VaultikGitRevision, snapshot.StartedAt.Unix(),
			completedAt, snapshot.FileCount, snapshot.ChunkCount, snapshot.BlobCount, snapshot.TotalSize, snapshot.BlobSize, snapshot.BlobUncompressedSize,
			snapshot.CompressionRatio, snapshot.CompressionLevel, snapshot.UploadBytes, snapshot.UploadDurationMs)
	}

	if err != nil {
@@ -64,7 +68,7 @@ func (r *SnapshotRepository) UpdateCounts(ctx context.Context, tx *sql.Tx, snaps
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, fileCount, chunkCount, blobCount, totalSize, blobSize, compressionRatio, snapshotID)
	} else {
		_, err = r.db.ExecWithLock(ctx, query, fileCount, chunkCount, blobCount, totalSize, blobSize, compressionRatio, snapshotID)
		_, err = r.db.ExecWithLog(ctx, query, fileCount, chunkCount, blobCount, totalSize, blobSize, compressionRatio, snapshotID)
	}

	if err != nil {
@@ -74,9 +78,58 @@ func (r *SnapshotRepository) UpdateCounts(ctx context.Context, tx *sql.Tx, snaps
	return nil
}

// UpdateExtendedStats updates extended statistics for a snapshot
func (r *SnapshotRepository) UpdateExtendedStats(ctx context.Context, tx *sql.Tx, snapshotID string, blobUncompressedSize int64, compressionLevel int, uploadDurationMs int64) error {
	// Calculate compression ratio based on uncompressed vs compressed sizes
	var compressionRatio float64
	if blobUncompressedSize > 0 {
		// Get current blob_size from DB to calculate ratio
		var blobSize int64
		queryGet := `SELECT blob_size FROM snapshots WHERE id = ?`
		if tx != nil {
			err := tx.QueryRowContext(ctx, queryGet, snapshotID).Scan(&blobSize)
			if err != nil {
				return fmt.Errorf("getting blob size: %w", err)
			}
		} else {
			err := r.db.conn.QueryRowContext(ctx, queryGet, snapshotID).Scan(&blobSize)
			if err != nil {
				return fmt.Errorf("getting blob size: %w", err)
			}
		}
		compressionRatio = float64(blobSize) / float64(blobUncompressedSize)
	} else {
		compressionRatio = 1.0
	}

	query := `
		UPDATE snapshots
		SET blob_uncompressed_size = ?,
		    compression_ratio = ?,
		    compression_level = ?,
		    upload_bytes = blob_size,
		    upload_duration_ms = ?
		WHERE id = ?
	`

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, blobUncompressedSize, compressionRatio, compressionLevel, uploadDurationMs, snapshotID)
	} else {
		_, err = r.db.ExecWithLog(ctx, query, blobUncompressedSize, compressionRatio, compressionLevel, uploadDurationMs, snapshotID)
	}

	if err != nil {
		return fmt.Errorf("updating extended stats: %w", err)
	}
	return nil
}
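
Worth flagging when reading the hunk above: UpdateExtendedStats computes compression_ratio as blob_size divided by blob_uncompressed_size, so values below 1.0 indicate shrinkage, while TestSnapshotRepositoryExtendedFields stores a hand-written CompressionRatio of 2.0 for a 2:1 reduction, which is the inverse convention. A quick check of the stored form as computed here:

package main

import "fmt"

func main() {
	blobSize := int64(512 * 1024)      // compressed bytes, per blob_size
	uncompressed := int64(1024 * 1024) // input bytes, per blob_uncompressed_size
	// Same division as UpdateExtendedStats: 0.5 for a 2:1 reduction.
	fmt.Println(float64(blobSize) / float64(uncompressed))
}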
|
||||
func (r *SnapshotRepository) GetByID(ctx context.Context, snapshotID string) (*Snapshot, error) {
|
||||
query := `
|
||||
SELECT id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
|
||||
SELECT id, hostname, vaultik_version, vaultik_git_revision, started_at, completed_at,
|
||||
file_count, chunk_count, blob_count, total_size, blob_size, blob_uncompressed_size,
|
||||
compression_ratio, compression_level, upload_bytes, upload_duration_ms
|
||||
FROM snapshots
|
||||
WHERE id = ?
|
||||
`
|
||||
@@ -89,6 +142,7 @@ func (r *SnapshotRepository) GetByID(ctx context.Context, snapshotID string) (*S
|
||||
&snapshot.ID,
|
||||
&snapshot.Hostname,
|
||||
&snapshot.VaultikVersion,
|
||||
&snapshot.VaultikGitRevision,
|
||||
&startedAtUnix,
|
||||
&completedAtUnix,
|
||||
&snapshot.FileCount,
|
||||
@@ -96,7 +150,11 @@ func (r *SnapshotRepository) GetByID(ctx context.Context, snapshotID string) (*S
|
||||
&snapshot.BlobCount,
|
||||
&snapshot.TotalSize,
|
||||
&snapshot.BlobSize,
|
||||
&snapshot.BlobUncompressedSize,
|
||||
&snapshot.CompressionRatio,
|
||||
&snapshot.CompressionLevel,
|
||||
&snapshot.UploadBytes,
|
||||
&snapshot.UploadDurationMs,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -106,9 +164,9 @@ func (r *SnapshotRepository) GetByID(ctx context.Context, snapshotID string) (*S
|
||||
return nil, fmt.Errorf("querying snapshot: %w", err)
|
||||
}
|
||||
|
||||
snapshot.StartedAt = time.Unix(startedAtUnix, 0)
|
||||
snapshot.StartedAt = time.Unix(startedAtUnix, 0).UTC()
|
||||
if completedAtUnix != nil {
|
||||
t := time.Unix(*completedAtUnix, 0)
|
||||
t := time.Unix(*completedAtUnix, 0).UTC()
|
||||
snapshot.CompletedAt = &t
|
||||
}
|
||||
|
||||
@@ -117,7 +175,7 @@ func (r *SnapshotRepository) GetByID(ctx context.Context, snapshotID string) (*S
|
||||
|
||||
func (r *SnapshotRepository) ListRecent(ctx context.Context, limit int) ([]*Snapshot, error) {
|
||||
query := `
|
||||
SELECT id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
|
||||
SELECT id, hostname, vaultik_version, vaultik_git_revision, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
|
||||
FROM snapshots
|
||||
ORDER BY started_at DESC
|
||||
LIMIT ?
|
||||
@@ -139,6 +197,7 @@ func (r *SnapshotRepository) ListRecent(ctx context.Context, limit int) ([]*Snap
|
||||
&snapshot.ID,
|
||||
&snapshot.Hostname,
|
||||
&snapshot.VaultikVersion,
|
||||
&snapshot.VaultikGitRevision,
|
||||
&startedAtUnix,
|
||||
&completedAtUnix,
|
||||
&snapshot.FileCount,
|
||||
@@ -172,13 +231,13 @@ func (r *SnapshotRepository) MarkComplete(ctx context.Context, tx *sql.Tx, snaps
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
completedAt := time.Now().Unix()
|
||||
completedAt := time.Now().UTC().Unix()
|
||||
|
||||
var err error
|
||||
if tx != nil {
|
||||
_, err = tx.ExecContext(ctx, query, completedAt, snapshotID)
|
||||
} else {
|
||||
_, err = r.db.ExecWithLock(ctx, query, completedAt, snapshotID)
|
||||
_, err = r.db.ExecWithLog(ctx, query, completedAt, snapshotID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -191,15 +250,36 @@ func (r *SnapshotRepository) MarkComplete(ctx context.Context, tx *sql.Tx, snaps
|
||||
// AddFile adds a file to a snapshot
|
||||
func (r *SnapshotRepository) AddFile(ctx context.Context, tx *sql.Tx, snapshotID string, filePath string) error {
|
||||
query := `
|
||||
INSERT OR IGNORE INTO snapshot_files (snapshot_id, file_path)
|
||||
VALUES (?, ?)
|
||||
INSERT OR IGNORE INTO snapshot_files (snapshot_id, file_id)
|
||||
SELECT ?, id FROM files WHERE path = ?
|
||||
`
|
||||
|
||||
var err error
|
||||
if tx != nil {
|
||||
_, err = tx.ExecContext(ctx, query, snapshotID, filePath)
|
||||
} else {
|
||||
_, err = r.db.ExecWithLock(ctx, query, snapshotID, filePath)
|
||||
_, err = r.db.ExecWithLog(ctx, query, snapshotID, filePath)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding file to snapshot: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddFileByID adds a file to a snapshot by file ID
func (r *SnapshotRepository) AddFileByID(ctx context.Context, tx *sql.Tx, snapshotID string, fileID string) error {
	query := `
		INSERT OR IGNORE INTO snapshot_files (snapshot_id, file_id)
		VALUES (?, ?)
	`

	var err error
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, snapshotID, fileID)
	} else {
		_, err = r.db.ExecWithLog(ctx, query, snapshotID, fileID)
	}

	if err != nil {
@@ -220,7 +300,7 @@ func (r *SnapshotRepository) AddBlob(ctx context.Context, tx *sql.Tx, snapshotID
	if tx != nil {
		_, err = tx.ExecContext(ctx, query, snapshotID, blobID, blobHash)
	} else {
		_, err = r.db.ExecWithLock(ctx, query, snapshotID, blobID, blobHash)
		_, err = r.db.ExecWithLog(ctx, query, snapshotID, blobID, blobHash)
	}

	if err != nil {
@@ -260,7 +340,7 @@ func (r *SnapshotRepository) GetBlobHashes(ctx context.Context, snapshotID strin
// GetIncompleteSnapshots returns all snapshots that haven't been completed
func (r *SnapshotRepository) GetIncompleteSnapshots(ctx context.Context) ([]*Snapshot, error) {
	query := `
		SELECT id, hostname, vaultik_version, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
		SELECT id, hostname, vaultik_version, vaultik_git_revision, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
		FROM snapshots
		WHERE completed_at IS NULL
		ORDER BY started_at DESC
@@ -282,6 +362,7 @@ func (r *SnapshotRepository) GetIncompleteSnapshots(ctx context.Context) ([]*Sna
			&snapshot.ID,
			&snapshot.Hostname,
			&snapshot.VaultikVersion,
			&snapshot.VaultikGitRevision,
			&startedAtUnix,
			&completedAtUnix,
			&snapshot.FileCount,
@@ -306,3 +387,90 @@ func (r *SnapshotRepository) GetIncompleteSnapshots(ctx context.Context) ([]*Sna

	return snapshots, rows.Err()
}

// GetIncompleteByHostname returns all incomplete snapshots for a specific hostname
func (r *SnapshotRepository) GetIncompleteByHostname(ctx context.Context, hostname string) ([]*Snapshot, error) {
	query := `
		SELECT id, hostname, vaultik_version, vaultik_git_revision, started_at, completed_at, file_count, chunk_count, blob_count, total_size, blob_size, compression_ratio
		FROM snapshots
		WHERE completed_at IS NULL AND hostname = ?
		ORDER BY started_at DESC
	`

	rows, err := r.db.conn.QueryContext(ctx, query, hostname)
	if err != nil {
		return nil, fmt.Errorf("querying incomplete snapshots: %w", err)
	}
	defer CloseRows(rows)

	var snapshots []*Snapshot
	for rows.Next() {
		var snapshot Snapshot
		var startedAtUnix int64
		var completedAtUnix *int64

		err := rows.Scan(
			&snapshot.ID,
			&snapshot.Hostname,
			&snapshot.VaultikVersion,
			&snapshot.VaultikGitRevision,
			&startedAtUnix,
			&completedAtUnix,
			&snapshot.FileCount,
			&snapshot.ChunkCount,
			&snapshot.BlobCount,
			&snapshot.TotalSize,
			&snapshot.BlobSize,
			&snapshot.CompressionRatio,
		)
		if err != nil {
			return nil, fmt.Errorf("scanning snapshot: %w", err)
		}

		snapshot.StartedAt = time.Unix(startedAtUnix, 0).UTC()
		if completedAtUnix != nil {
			t := time.Unix(*completedAtUnix, 0).UTC()
			snapshot.CompletedAt = &t
		}

		snapshots = append(snapshots, &snapshot)
	}

	return snapshots, rows.Err()
}

// Delete removes a snapshot record
func (r *SnapshotRepository) Delete(ctx context.Context, snapshotID string) error {
	query := `DELETE FROM snapshots WHERE id = ?`

	_, err := r.db.ExecWithLog(ctx, query, snapshotID)
	if err != nil {
		return fmt.Errorf("deleting snapshot: %w", err)
	}

	return nil
}

// DeleteSnapshotFiles removes all snapshot_files entries for a snapshot
func (r *SnapshotRepository) DeleteSnapshotFiles(ctx context.Context, snapshotID string) error {
	query := `DELETE FROM snapshot_files WHERE snapshot_id = ?`

	_, err := r.db.ExecWithLog(ctx, query, snapshotID)
	if err != nil {
		return fmt.Errorf("deleting snapshot files: %w", err)
	}

	return nil
}

// DeleteSnapshotBlobs removes all snapshot_blobs entries for a snapshot
func (r *SnapshotRepository) DeleteSnapshotBlobs(ctx context.Context, snapshotID string) error {
	query := `DELETE FROM snapshot_blobs WHERE snapshot_id = ?`

	_, err := r.db.ExecWithLog(ctx, query, snapshotID)
	if err != nil {
		return fmt.Errorf("deleting snapshot blobs: %w", err)
	}

	return nil
}

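Taken together, the three deletion helpers above give snapshot purging a natural call order: clear the association tables first, then the snapshot row itself. A sketch against the signatures in this diff (the surrounding transaction and error policy are assumptions, not shown in the commit):

// purgeSnapshot removes a snapshot's association rows before the
// snapshot record itself; illustrative only.
func purgeSnapshot(ctx context.Context, repo *SnapshotRepository, snapshotID string) error {
	if err := repo.DeleteSnapshotFiles(ctx, snapshotID); err != nil {
		return err
	}
	if err := repo.DeleteSnapshotBlobs(ctx, snapshotID); err != nil {
		return err
	}
	return repo.Delete(ctx, snapshotID)
}
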
@@ -3,21 +3,29 @@ package s3
import (
	"context"
	"io"
	"sync/atomic"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

// Client wraps the AWS S3 client for vaultik operations
// Client wraps the AWS S3 client for vaultik operations.
// It provides a simplified interface for S3 operations with automatic
// prefix handling and connection management. All operations are performed
// within the configured bucket and prefix.
type Client struct {
	s3Client *s3.Client
	bucket   string
	prefix   string
	endpoint string
}

// Config contains S3 client configuration
// Config contains S3 client configuration.
// All fields are required except Prefix, which defaults to an empty string.
// The Endpoint field should include the protocol (http:// or https://).
type Config struct {
	Endpoint string
	Bucket   string
@@ -27,7 +35,10 @@ type Config struct {
	Region string
}

// NewClient creates a new S3 client
// NewClient creates a new S3 client with the provided configuration.
// It establishes a connection to the S3-compatible storage service and
// validates the credentials. The client uses static credentials and
// path-style URLs for compatibility with various S3-compatible services.
func NewClient(ctx context.Context, cfg Config) (*Client, error) {
	// Create AWS config
	awsCfg, err := config.LoadDefaultConfig(ctx,
@@ -56,10 +67,14 @@ func NewClient(ctx context.Context, cfg Config) (*Client, error) {
		s3Client: s3Client,
		bucket:   cfg.Bucket,
		prefix:   cfg.Prefix,
		endpoint: cfg.Endpoint,
	}, nil
}

// PutObject uploads an object to S3
// PutObject uploads an object to S3 with the specified key.
// The key is automatically prefixed with the configured prefix.
// The data parameter should be a reader containing the object data.
// Returns an error if the upload fails.
func (c *Client) PutObject(ctx context.Context, key string, data io.Reader) error {
	fullKey := c.prefix + key
	_, err := c.s3Client.PutObject(ctx, &s3.PutObjectInput{
@@ -70,7 +85,46 @@ func (c *Client) PutObject(ctx context.Context, key string, data io.Reader) erro
	return err
}

// GetObject downloads an object from S3
// ProgressCallback is called during upload progress with bytes uploaded so far.
// The callback should return an error to cancel the upload.
type ProgressCallback func(bytesUploaded int64) error

// PutObjectWithProgress uploads an object to S3 with progress tracking.
// The key is automatically prefixed with the configured prefix.
// The size parameter must be the exact size of the data to upload.
// The progress callback is called periodically with the number of bytes uploaded.
// Returns an error if the upload fails.
func (c *Client) PutObjectWithProgress(ctx context.Context, key string, data io.Reader, size int64, progress ProgressCallback) error {
	fullKey := c.prefix + key

	// Create an uploader with the S3 client
	uploader := manager.NewUploader(c.s3Client, func(u *manager.Uploader) {
		// Set part size to 10MB for better progress granularity
		u.PartSize = 10 * 1024 * 1024
	})

	// Create a progress reader that tracks upload progress
	pr := &progressReader{
		reader:   data,
		size:     size,
		callback: progress,
		read:     0,
	}

	// Upload the file
	_, err := uploader.Upload(ctx, &s3.PutObjectInput{
		Bucket: aws.String(c.bucket),
		Key:    aws.String(fullKey),
		Body:   pr,
	})

	return err
}

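A usage sketch for the progress API above: the caller must know the exact size up front, and the callback can abort the transfer by returning a non-nil error (the Client value and key here are placeholders):

// uploadWithProgress prints a running percentage while uploading.
// Returning an error from the callback cancels the upload.
func uploadWithProgress(ctx context.Context, client *Client, key string, data io.Reader, size int64) error {
	return client.PutObjectWithProgress(ctx, key, data, size, func(uploaded int64) error {
		if size > 0 {
			fmt.Printf("\r%s: %d%%", key, uploaded*100/size)
		}
		return nil
	})
}
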
// GetObject downloads an object from S3 with the specified key.
// The key is automatically prefixed with the configured prefix.
// Returns a ReadCloser containing the object data. The caller must
// close the returned reader when done to avoid resource leaks.
func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
	fullKey := c.prefix + key
	result, err := c.s3Client.GetObject(ctx, &s3.GetObjectInput{
@@ -83,7 +137,9 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, erro
	return result.Body, nil
}

// DeleteObject removes an object from S3
// DeleteObject removes an object from S3 with the specified key.
// The key is automatically prefixed with the configured prefix.
// No error is returned if the object doesn't exist.
func (c *Client) DeleteObject(ctx context.Context, key string) error {
	fullKey := c.prefix + key
	_, err := c.s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{
@@ -93,7 +149,11 @@ func (c *Client) DeleteObject(ctx context.Context, key string) error {
	return err
}

// ListObjects lists objects with the given prefix
// ListObjects lists all objects with the given prefix.
// The prefix is combined with the client's configured prefix.
// Returns a slice of object keys with the base prefix removed.
// This method loads all matching keys into memory, so use
// ListObjectsStream for large result sets.
func (c *Client) ListObjects(ctx context.Context, prefix string) ([]string, error) {
	fullPrefix := c.prefix + prefix

@@ -124,7 +184,10 @@ func (c *Client) ListObjects(ctx context.Context, prefix string) ([]string, erro
	return keys, nil
}

// HeadObject checks if an object exists
// HeadObject checks if an object exists in S3.
// Returns true if the object exists, false otherwise.
// The key is automatically prefixed with the configured prefix.
// Note: This method returns false for any error, not just "not found".
func (c *Client) HeadObject(ctx context.Context, key string) (bool, error) {
	fullKey := c.prefix + key
	_, err := c.s3Client.HeadObject(ctx, &s3.HeadObjectInput{
@@ -138,3 +201,126 @@ func (c *Client) HeadObject(ctx context.Context, key string) (bool, error) {
	}
	return true, nil
}

// ObjectInfo contains information about an S3 object.
// It is used by ListObjectsStream to return object metadata
// along with any errors encountered during listing.
type ObjectInfo struct {
	Key  string
	Size int64
	Err  error
}

// ListObjectsStream lists objects with the given prefix and returns a channel.
// This method is preferred for large result sets as it streams results
// instead of loading everything into memory. The channel is closed when
// listing is complete or an error occurs. If an error occurs, it will be
// sent as the last item with the Err field set. The recursive parameter
// is currently unused but reserved for future use.
func (c *Client) ListObjectsStream(ctx context.Context, prefix string, recursive bool) <-chan ObjectInfo {
	ch := make(chan ObjectInfo)

	go func() {
		defer close(ch)

		fullPrefix := c.prefix + prefix

		paginator := s3.NewListObjectsV2Paginator(c.s3Client, &s3.ListObjectsV2Input{
			Bucket: aws.String(c.bucket),
			Prefix: aws.String(fullPrefix),
		})

		for paginator.HasMorePages() {
			page, err := paginator.NextPage(ctx)
			if err != nil {
				ch <- ObjectInfo{Err: err}
				return
			}

			for _, obj := range page.Contents {
				if obj.Key != nil && obj.Size != nil {
					// Remove the base prefix from the key
					key := *obj.Key
					if len(key) > len(c.prefix) {
						key = key[len(c.prefix):]
					}
					ch <- ObjectInfo{
						Key:  key,
						Size: *obj.Size,
					}
				}
			}
		}
	}()

	return ch
}

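Consuming the stream follows the usual channel idiom; because a listing failure arrives as the final item with Err set, the loop must check it before trusting the metadata. A sketch against the API above (the "blobs/" prefix is an assumed layout, not taken from this commit):

// totalBlobSize sums object sizes under a prefix without holding
// the full key list in memory; illustrative only.
func totalBlobSize(ctx context.Context, client *Client) (int64, error) {
	var total int64
	for obj := range client.ListObjectsStream(ctx, "blobs/", true) {
		if obj.Err != nil {
			return 0, obj.Err
		}
		total += obj.Size
	}
	return total, nil
}
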
// StatObject returns information about an object without downloading it.
// The key is automatically prefixed with the configured prefix.
// Returns an ObjectInfo struct with the object's metadata.
// Returns an error if the object doesn't exist or if the operation fails.
func (c *Client) StatObject(ctx context.Context, key string) (*ObjectInfo, error) {
	fullKey := c.prefix + key
	result, err := c.s3Client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(c.bucket),
		Key:    aws.String(fullKey),
	})
	if err != nil {
		return nil, err
	}

	size := int64(0)
	if result.ContentLength != nil {
		size = *result.ContentLength
	}

	return &ObjectInfo{
		Key:  key,
		Size: size,
	}, nil
}

// RemoveObject deletes an object from S3 (alias for DeleteObject).
// This method exists for API compatibility and simply calls DeleteObject.
func (c *Client) RemoveObject(ctx context.Context, key string) error {
	return c.DeleteObject(ctx, key)
}

// BucketName returns the configured S3 bucket name.
// This is useful for displaying configuration information.
func (c *Client) BucketName() string {
	return c.bucket
}

// Endpoint returns the S3 endpoint URL.
// If no custom endpoint was configured, returns the default AWS S3 endpoint.
// This is useful for displaying configuration information.
func (c *Client) Endpoint() string {
	if c.endpoint == "" {
		return "s3.amazonaws.com"
	}
	return c.endpoint
}

// progressReader wraps an io.Reader to track reading progress
type progressReader struct {
	reader   io.Reader
	size     int64
	read     int64
	callback ProgressCallback
}

// Read implements io.Reader
func (pr *progressReader) Read(p []byte) (int, error) {
	n, err := pr.reader.Read(p)
	if n > 0 {
		atomic.AddInt64(&pr.read, int64(n))
		if pr.callback != nil {
			if callbackErr := pr.callback(atomic.LoadInt64(&pr.read)); callbackErr != nil {
				return n, callbackErr
			}
		}
	}
	return n, err
}

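The progressReader pattern is simple enough to demonstrate with the standard library alone: wrap any io.Reader, count bytes as they pass through, and hand the running total to a callback. A runnable toy (nothing vaultik-specific; the diff's version uses sync/atomic, presumably because the upload manager may read parts concurrently):

package main

import (
	"fmt"
	"io"
	"strings"
)

// countingReader reports the running byte total after every Read,
// mirroring how progressReader feeds the upload callback.
type countingReader struct {
	r     io.Reader
	total int64
	on    func(int64)
}

func (c *countingReader) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	if n > 0 {
		c.total += int64(n)
		c.on(c.total)
	}
	return n, err
}

func main() {
	src := strings.NewReader("hello, vaultik")
	cr := &countingReader{r: src, on: func(n int64) { fmt.Println("read so far:", n) }}
	if _, err := io.Copy(io.Discard, cr); err != nil {
		panic(err)
	}
}
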
@@ -7,7 +7,9 @@ import (
	"go.uber.org/fx"
)

// Module exports S3 functionality
// Module exports S3 functionality as an fx module.
// It provides automatic dependency injection for the S3 client,
// configuring it based on the application's configuration settings.
var Module = fx.Module("s3",
	fx.Provide(
		provideClient,