From d3afa654207bb1de52e8cf05921ace8353903a78 Mon Sep 17 00:00:00 2001 From: sneak Date: Sat, 26 Jul 2025 02:22:25 +0200 Subject: [PATCH] Fix foreign key constraints and improve snapshot tracking - Add unified compression/encryption package in internal/blobgen - Update DATAMODEL.md to reflect current schema implementation - Refactor snapshot cleanup into well-named methods for clarity - Add snapshot_id to uploads table to track new blobs per snapshot - Fix blob count reporting for incremental backups - Add DeleteOrphaned method to BlobChunkRepository - Fix cleanup order to respect foreign key constraints - Update tests to reflect schema changes --- DATAMODEL.md | 88 +-- internal/backup/backup_test.go | 1 - internal/backup/module.go | 1 + internal/backup/scanner.go | 55 +- internal/backup/snapshot.go | 538 ++++++++++-------- internal/backup/snapshot_test.go | 26 +- internal/blob/packer.go | 132 ++--- internal/blob/packer_test.go | 81 ++- internal/blobgen/compress.go | 67 +++ internal/blobgen/reader.go | 73 +++ internal/blobgen/writer.go | 112 ++++ internal/cli/snapshot.go | 166 +++--- internal/database/blob_chunks.go | 29 + internal/database/blob_chunks_test.go | 2 - internal/database/cascade_debug_test.go | 1 - internal/database/chunk_files_test.go | 31 +- internal/database/chunks.go | 17 +- internal/database/chunks_ext.go | 3 +- internal/database/chunks_test.go | 5 - internal/database/database_test.go | 4 +- internal/database/file_chunks_test.go | 44 +- internal/database/files.go | 1 - internal/database/models.go | 4 +- internal/database/repositories_test.go | 3 - .../database/repository_comprehensive_test.go | 5 - .../database/repository_edge_cases_test.go | 2 - internal/database/schema.sql | 17 +- internal/database/uploads.go | 20 +- 28 files changed, 994 insertions(+), 534 deletions(-) create mode 100644 internal/blobgen/compress.go create mode 100644 internal/blobgen/reader.go create mode 100644 internal/blobgen/writer.go diff --git a/DATAMODEL.md b/DATAMODEL.md index c825507..2111570 100644 --- a/DATAMODEL.md +++ b/DATAMODEL.md @@ -15,14 +15,17 @@ Stores metadata about files in the filesystem being backed up. **Columns:** - `id` (TEXT PRIMARY KEY) - UUID for the file record -- `path` (TEXT UNIQUE) - Absolute file path -- `mtime` (INTEGER) - Modification time as Unix timestamp -- `ctime` (INTEGER) - Change time as Unix timestamp -- `size` (INTEGER) - File size in bytes -- `mode` (INTEGER) - Unix file permissions and type -- `uid` (INTEGER) - User ID of file owner -- `gid` (INTEGER) - Group ID of file owner -- `link_target` (TEXT) - Symlink target path (empty for regular files) +- `path` (TEXT NOT NULL UNIQUE) - Absolute file path +- `mtime` (INTEGER NOT NULL) - Modification time as Unix timestamp +- `ctime` (INTEGER NOT NULL) - Change time as Unix timestamp +- `size` (INTEGER NOT NULL) - File size in bytes +- `mode` (INTEGER NOT NULL) - Unix file permissions and type +- `uid` (INTEGER NOT NULL) - User ID of file owner +- `gid` (INTEGER NOT NULL) - Group ID of file owner +- `link_target` (TEXT) - Symlink target path (NULL for regular files) + +**Indexes:** +- `idx_files_path` on `path` for efficient lookups **Purpose:** Tracks file metadata to detect changes between backup runs. Used for incremental backup decisions. The UUID primary key provides stable references that don't change if files are moved. @@ -31,8 +34,7 @@ Stores information about content-defined chunks created from files. 
**Columns:** - `chunk_hash` (TEXT PRIMARY KEY) - SHA256 hash of chunk content -- `sha256` (TEXT) - SHA256 hash (currently same as chunk_hash) -- `size` (INTEGER) - Chunk size in bytes +- `size` (INTEGER NOT NULL) - Chunk size in bytes **Purpose:** Enables deduplication by tracking unique chunks across all files. @@ -64,11 +66,11 @@ Stores information about packed, compressed, and encrypted blob files. **Columns:** - `id` (TEXT PRIMARY KEY) - UUID assigned when blob creation starts -- `hash` (TEXT) - SHA256 hash of final blob (empty until finalized) -- `created_ts` (INTEGER) - Creation timestamp +- `blob_hash` (TEXT UNIQUE) - SHA256 hash of final blob (NULL until finalized) +- `created_ts` (INTEGER NOT NULL) - Creation timestamp - `finished_ts` (INTEGER) - Finalization timestamp (NULL if in progress) -- `uncompressed_size` (INTEGER) - Total size of chunks before compression -- `compressed_size` (INTEGER) - Size after compression and encryption +- `uncompressed_size` (INTEGER NOT NULL DEFAULT 0) - Total size of chunks before compression +- `compressed_size` (INTEGER NOT NULL DEFAULT 0) - Size after compression and encryption - `uploaded_ts` (INTEGER) - Upload completion timestamp (NULL if not uploaded) **Purpose:** Tracks blob lifecycle from creation through upload. The UUID primary key allows immediate association of chunks with blobs. @@ -134,11 +136,12 @@ Tracks blob upload metrics. **Columns:** - `blob_hash` (TEXT PRIMARY KEY) - Hash of uploaded blob +- `snapshot_id` (TEXT NOT NULL) - The snapshot that triggered this upload (FK to snapshots.id) - `uploaded_at` (INTEGER) - Upload timestamp - `size` (INTEGER) - Size of uploaded blob - `duration_ms` (INTEGER) - Upload duration in milliseconds -**Purpose:** Performance monitoring and upload tracking. +**Purpose:** Performance monitoring and tracking which blobs were newly created (uploaded) during each snapshot. ## Data Flow and Operations @@ -155,13 +158,13 @@ Tracks blob upload metrics. - `INSERT INTO chunk_files` - Create reverse mapping 3. **Blob Packing** - - `INSERT INTO blobs` - Create blob record with UUID (hash empty) + - `INSERT INTO blobs` - Create blob record with UUID (blob_hash NULL) - `INSERT INTO blob_chunks` - Associate chunks with blob immediately - - `UPDATE blobs SET hash = ?, finished_ts = ?` - Finalize blob after packing + - `UPDATE blobs SET blob_hash = ?, finished_ts = ?` - Finalize blob after packing 4. **Upload** - `UPDATE blobs SET uploaded_ts = ?` - Mark blob as uploaded - - `INSERT INTO uploads` - Record upload metrics + - `INSERT INTO uploads` - Record upload metrics with snapshot_id - `INSERT INTO snapshot_blobs` - Associate blob with snapshot 5. **Snapshot Completion** @@ -179,37 +182,56 @@ Tracks blob upload metrics. - `SELECT * FROM blob_chunks WHERE chunk_hash = ?` - Find existing chunks - `INSERT INTO snapshot_blobs` - Reference existing blobs for unchanged files -### 3. Restore Process +### 3. Snapshot Metadata Export + +After a snapshot is completed: +1. Copy database to temporary file +2. Clean temporary database to contain only current snapshot data +3. Export to SQL dump using sqlite3 +4. Compress with zstd and encrypt with age +5. Upload to S3 as `metadata/{snapshot-id}/db.zst.age` +6. Generate blob manifest and upload as `metadata/{snapshot-id}/manifest.json.zst.age` + +### 4. Restore Process The restore process doesn't use the local database. Instead: 1. Downloads snapshot metadata from S3 2. Downloads required blobs based on manifest 3. 
Reconstructs files from decrypted and decompressed chunks -### 4. Pruning +### 5. Pruning 1. **Identify Unreferenced Blobs** - Query blobs not referenced by any remaining snapshot - Delete from S3 and local database +### 6. Incomplete Snapshot Cleanup + +Before each backup: +1. Query incomplete snapshots (where `completed_at IS NULL`) +2. Check if metadata exists in S3 +3. If no metadata, delete snapshot and all associations +4. Clean up orphaned files, chunks, and blobs + ## Repository Pattern Vaultik uses a repository pattern for database access: -- `FileRepository` - CRUD operations for files -- `ChunkRepository` - CRUD operations for chunks -- `FileChunkRepository` - Manage file-chunk mappings -- `BlobRepository` - Manage blob lifecycle -- `BlobChunkRepository` - Manage blob-chunk associations -- `SnapshotRepository` - Manage snapshots -- `UploadRepository` - Track upload metrics +- `FileRepository` - CRUD operations for files and file metadata +- `ChunkRepository` - CRUD operations for content chunks +- `FileChunkRepository` - Manage file-to-chunk mappings +- `ChunkFileRepository` - Manage chunk-to-file reverse mappings +- `BlobRepository` - Manage blob lifecycle (creation, finalization, upload) +- `BlobChunkRepository` - Manage blob-to-chunk associations +- `SnapshotRepository` - Manage snapshots and their relationships +- `UploadRepository` - Track blob upload metrics Each repository provides methods like: - `Create()` - Insert new record - `GetByID()` / `GetByPath()` / `GetByHash()` - Retrieve records - `Update()` - Update existing records - `Delete()` - Remove records -- Specialized queries for each entity type +- Specialized queries for each entity type (e.g., `DeleteOrphaned()`, `GetIncompleteByHostname()`) ## Transaction Management @@ -228,9 +250,9 @@ This ensures consistency, especially important for operations like: ## Performance Considerations -1. **Indexes**: Primary keys are automatically indexed. Additional indexes may be needed for: - - `blobs.hash` for lookup performance - - `blob_chunks.chunk_hash` for chunk location queries +1. **Indexes**: + - Primary keys are automatically indexed + - `idx_files_path` on `files(path)` for efficient file lookups 2. **Prepared Statements**: All queries use prepared statements for performance and security @@ -240,7 +262,7 @@ This ensures consistency, especially important for operations like: ## Data Integrity -1. **Foreign Keys**: Enforced at the application level through repository methods -2. **Unique Constraints**: Chunk hashes and file paths are unique +1. **Foreign Keys**: Enforced through CASCADE DELETE and application-level repository methods +2. **Unique Constraints**: Chunk hashes, file paths, and blob hashes are unique 3. **Null Handling**: Nullable fields clearly indicate in-progress operations 4. 
**Timestamp Tracking**: All major operations record timestamps for auditing \ No newline at end of file diff --git a/internal/backup/backup_test.go b/internal/backup/backup_test.go index f51ee04..77c46ed 100644 --- a/internal/backup/backup_test.go +++ b/internal/backup/backup_test.go @@ -393,7 +393,6 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { chunk := &database.Chunk{ ChunkHash: chunkHash, - SHA256: chunkHash, Size: int64(n), } return b.repos.Chunks.Create(ctx, tx, chunk) diff --git a/internal/backup/module.go b/internal/backup/module.go index 6109031..6e6161a 100644 --- a/internal/backup/module.go +++ b/internal/backup/module.go @@ -19,6 +19,7 @@ type ScannerParams struct { var Module = fx.Module("backup", fx.Provide( provideScannerFactory, + NewSnapshotManager, ), ) diff --git a/internal/backup/scanner.go b/internal/backup/scanner.go index 80c21a4..1fc8e03 100644 --- a/internal/backup/scanner.go +++ b/internal/backup/scanner.go @@ -12,7 +12,6 @@ import ( "git.eeqj.de/sneak/vaultik/internal/blob" "git.eeqj.de/sneak/vaultik/internal/chunker" - "git.eeqj.de/sneak/vaultik/internal/crypto" "git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/log" "git.eeqj.de/sneak/vaultik/internal/s3" @@ -86,17 +85,11 @@ func NewScanner(cfg ScannerConfig) *Scanner { return nil } - enc, err := crypto.NewEncryptor(cfg.AgeRecipients) - if err != nil { - log.Error("Failed to create encryptor", "error", err) - return nil - } - // Create blob packer with encryption packerCfg := blob.PackerConfig{ MaxBlobSize: cfg.MaxBlobSize, CompressionLevel: cfg.CompressionLevel, - Encryptor: enc, + Recipients: cfg.AgeRecipients, Repositories: cfg.Repositories, } packer, err := blob.NewPacker(packerCfg) @@ -182,6 +175,18 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc blobs := s.packer.GetFinishedBlobs() result.BlobsCreated += len(blobs) + // Query database for actual blob count created during this snapshot + // The database is authoritative, especially for concurrent blob uploads + // We count uploads rather than all snapshot_blobs to get only NEW blobs + if s.snapshotID != "" { + uploadCount, err := s.repos.Uploads.GetCountBySnapshot(ctx, s.snapshotID) + if err != nil { + log.Warn("Failed to get upload count from database", "error", err) + } else { + result.BlobsCreated = int(uploadCount) + } + } + result.EndTime = time.Now().UTC() return result, nil } @@ -341,24 +346,22 @@ func (s *Scanner) checkFile(ctx context.Context, path string, info os.FileInfo, fileChanged := existingFile == nil || s.hasFileChanged(existingFile, file) - // Update file metadata in a short transaction - log.Debug("Updating file metadata", "path", path, "changed", fileChanged) + // Update file metadata and add to snapshot in a single transaction + log.Debug("Updating file metadata and adding to snapshot", "path", path, "changed", fileChanged, "snapshot", s.snapshotID) err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { - return s.repos.Files.Create(ctx, tx, file) + // First create/update the file + if err := s.repos.Files.Create(ctx, tx, file); err != nil { + return fmt.Errorf("creating file: %w", err) + } + // Then add it to the snapshot using the file ID + if err := s.repos.Snapshots.AddFileByID(ctx, tx, s.snapshotID, file.ID); err != nil { + return fmt.Errorf("adding file to snapshot: %w", err) + } + return nil }) if err != nil { return nil, false, err } - 
log.Debug("File metadata updated", "path", path) - - // Add file to snapshot in a short transaction - log.Debug("Adding file to snapshot", "path", path, "snapshot", s.snapshotID) - err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { - return s.repos.Snapshots.AddFile(ctx, tx, s.snapshotID, path) - }) - if err != nil { - return nil, false, fmt.Errorf("adding file to snapshot: %w", err) - } log.Debug("File added to snapshot", "path", path) result.FilesScanned++ @@ -542,6 +545,14 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { uploadDuration := time.Since(startTime) + // Log upload stats + uploadSpeed := float64(finishedBlob.Compressed) * 8 / uploadDuration.Seconds() // bits per second + log.Info("Uploaded blob to S3", + "path", blobPath, + "size", humanize.Bytes(uint64(finishedBlob.Compressed)), + "duration", uploadDuration, + "speed", humanize.SI(uploadSpeed, "bps")) + // Report upload complete if s.progress != nil { s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration) @@ -574,6 +585,7 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { // Record upload metrics upload := &database.Upload{ BlobHash: finishedBlob.Hash, + SnapshotID: s.snapshotID, UploadedAt: startTime, Size: finishedBlob.Compressed, DurationMs: uploadDuration.Milliseconds(), @@ -645,7 +657,6 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { dbChunk := &database.Chunk{ ChunkHash: chunk.Hash, - SHA256: chunk.Hash, Size: chunk.Size, } if err := s.repos.Chunks.Create(txCtx, tx, dbChunk); err != nil { diff --git a/internal/backup/snapshot.go b/internal/backup/snapshot.go index 4f95d1b..5a15189 100644 --- a/internal/backup/snapshot.go +++ b/internal/backup/snapshot.go @@ -48,32 +48,39 @@ import ( "os" "os/exec" "path/filepath" - "runtime" "time" + "git.eeqj.de/sneak/vaultik/internal/blobgen" + "git.eeqj.de/sneak/vaultik/internal/config" "git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/log" - "github.com/klauspost/compress/zstd" + "git.eeqj.de/sneak/vaultik/internal/s3" + "github.com/dustin/go-humanize" + "go.uber.org/fx" ) // SnapshotManager handles snapshot creation and metadata export type SnapshotManager struct { - repos *database.Repositories - s3Client S3Client - encryptor Encryptor + repos *database.Repositories + s3Client S3Client + config *config.Config } -// Encryptor interface for snapshot encryption -type Encryptor interface { - Encrypt(data []byte) ([]byte, error) +// SnapshotManagerParams holds dependencies for NewSnapshotManager +type SnapshotManagerParams struct { + fx.In + + Repos *database.Repositories + S3Client *s3.Client + Config *config.Config } -// NewSnapshotManager creates a new snapshot manager -func NewSnapshotManager(repos *database.Repositories, s3Client S3Client, encryptor Encryptor) *SnapshotManager { +// NewSnapshotManager creates a new snapshot manager for dependency injection +func NewSnapshotManager(params SnapshotManagerParams) *SnapshotManager { return &SnapshotManager{ - repos: repos, - s3Client: s3Client, - encryptor: encryptor, + repos: params.Repos, + s3Client: params.S3Client, + config: params.Config, } } @@ -208,11 +215,20 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st log.Debug("Database copy complete", "size", getFileSize(tempDBPath)) // Step 2: Clean the temp database to only contain current 
snapshot data - log.Debug("Cleaning snapshot database to contain only current snapshot", "snapshot_id", snapshotID) - if err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID); err != nil { + log.Debug("Cleaning temporary snapshot database to contain only current snapshot", "snapshot_id", snapshotID, "db_path", tempDBPath) + stats, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID) + if err != nil { return fmt.Errorf("cleaning snapshot database: %w", err) } - log.Debug("Database cleaning complete", "size_after_clean", getFileSize(tempDBPath)) + log.Info("Snapshot database cleanup complete", + "db_path", tempDBPath, + "size_after_clean", humanize.Bytes(uint64(getFileSize(tempDBPath))), + "files", stats.FileCount, + "chunks", stats.ChunkCount, + "blobs", stats.BlobCount, + "total_compressed_size", humanize.Bytes(uint64(stats.CompressedSize)), + "total_uncompressed_size", humanize.Bytes(uint64(stats.UncompressedSize)), + "compression_ratio", fmt.Sprintf("%.2fx", float64(stats.UncompressedSize)/float64(stats.CompressedSize))) // Step 3: Dump the cleaned database to SQL dumpPath := filepath.Join(tempDir, "snapshot.sql") @@ -222,62 +238,59 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st } log.Debug("SQL dump complete", "size", getFileSize(dumpPath)) - // Step 4: Compress the SQL dump - compressedPath := filepath.Join(tempDir, "snapshot.sql.zst") - log.Debug("Compressing SQL dump", "source", dumpPath, "destination", compressedPath) + // Step 4: Compress and encrypt the SQL dump + compressedPath := filepath.Join(tempDir, "snapshot.sql.zst.age") + log.Debug("Compressing and encrypting SQL dump", "source", dumpPath, "destination", compressedPath) if err := sm.compressDump(dumpPath, compressedPath); err != nil { return fmt.Errorf("compressing dump: %w", err) } log.Debug("Compression complete", "original_size", getFileSize(dumpPath), "compressed_size", getFileSize(compressedPath)) - // Step 5: Read compressed data for encryption/upload - log.Debug("Reading compressed data for upload", "path", compressedPath) - compressedData, err := os.ReadFile(compressedPath) + // Step 5: Read compressed and encrypted data for upload + log.Debug("Reading compressed and encrypted data for upload", "path", compressedPath) + finalData, err := os.ReadFile(compressedPath) if err != nil { return fmt.Errorf("reading compressed dump: %w", err) } - // Step 6: Encrypt if encryptor is available - finalData := compressedData - if sm.encryptor != nil { - log.Debug("Encrypting snapshot data", "size_before", len(compressedData)) - encrypted, err := sm.encryptor.Encrypt(compressedData) - if err != nil { - return fmt.Errorf("encrypting snapshot: %w", err) - } - finalData = encrypted - log.Debug("Encryption complete", "size_after", len(encrypted)) - } else { - log.Debug("No encryption configured, using compressed data as-is") - } - - // Step 7: Generate blob manifest (before closing temp DB) + // Step 6: Generate blob manifest (before closing temp DB) log.Debug("Generating blob manifest from temporary database", "db_path", tempDBPath) blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) if err != nil { return fmt.Errorf("generating blob manifest: %w", err) } - // Step 8: Upload to S3 in snapshot subdirectory - // Upload database backup (encrypted) - dbKey := fmt.Sprintf("metadata/%s/db.zst", snapshotID) - if sm.encryptor != nil { - dbKey += ".age" - } + // Step 7: Upload to S3 in snapshot subdirectory + // Upload database backup (compressed and encrypted) + dbKey := 
fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) log.Debug("Uploading snapshot database to S3", "key", dbKey, "size", len(finalData)) + dbUploadStart := time.Now() if err := sm.s3Client.PutObject(ctx, dbKey, bytes.NewReader(finalData)); err != nil { return fmt.Errorf("uploading snapshot database: %w", err) } - log.Debug("Database upload complete", "key", dbKey) + dbUploadDuration := time.Since(dbUploadStart) + dbUploadSpeed := float64(len(finalData)) * 8 / dbUploadDuration.Seconds() // bits per second + log.Info("Uploaded snapshot database to S3", + "path", dbKey, + "size", humanize.Bytes(uint64(len(finalData))), + "duration", dbUploadDuration, + "speed", humanize.SI(dbUploadSpeed, "bps")) - // Upload blob manifest (unencrypted, compressed) - manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) + // Upload blob manifest (compressed and encrypted) + manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst.age", snapshotID) log.Debug("Uploading blob manifest to S3", "key", manifestKey, "size", len(blobManifest)) + manifestUploadStart := time.Now() if err := sm.s3Client.PutObject(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil { return fmt.Errorf("uploading blob manifest: %w", err) } - log.Debug("Manifest upload complete", "key", manifestKey) + manifestUploadDuration := time.Since(manifestUploadStart) + manifestUploadSpeed := float64(len(blobManifest)) * 8 / manifestUploadDuration.Seconds() // bits per second + log.Info("Uploaded blob manifest to S3", + "path", manifestKey, + "size", humanize.Bytes(uint64(len(blobManifest))), + "duration", manifestUploadDuration, + "speed", humanize.SI(manifestUploadSpeed, "bps")) log.Info("Uploaded snapshot metadata", "snapshot_id", snapshotID, @@ -286,43 +299,32 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st return nil } +// CleanupStats contains statistics about cleaned snapshot database +type CleanupStats struct { + FileCount int + ChunkCount int + BlobCount int + CompressedSize int64 + UncompressedSize int64 +} + // cleanSnapshotDB removes all data except for the specified snapshot // -// Current implementation: -// Since we don't yet have snapshot-file relationships, this currently only -// removes other snapshots. In a complete implementation, it would: +// The cleanup is performed in a specific order to maintain referential integrity: +// 1. Delete other snapshots +// 2. Delete orphaned snapshot associations (snapshot_files, snapshot_blobs) for deleted snapshots +// 3. Delete orphaned files (not in the current snapshot) +// 4. Delete orphaned chunk-to-file mappings (references to deleted files) +// 5. Delete orphaned blobs (not in the current snapshot) +// 6. Delete orphaned blob-to-chunk mappings (references to deleted chunks) +// 7. Delete orphaned chunks (not referenced by any file) // -// 1. Delete all snapshots except the current one -// 2. Delete files not belonging to the current snapshot -// 3. Delete file_chunks for deleted files (CASCADE) -// 4. Delete chunk_files for deleted files -// 5. Delete chunks with no remaining file references -// 6. Delete blob_chunks for deleted chunks -// 7. Delete blobs with no remaining chunks -// -// The order is important to maintain referential integrity. -// -// Future implementation when we have snapshot_files table: -// -// DELETE FROM snapshots WHERE id != ?; -// DELETE FROM files WHERE NOT EXISTS ( -// SELECT 1 FROM snapshot_files -// WHERE snapshot_files.file_id = files.id -// AND snapshot_files.snapshot_id = ? 
-// ); -// DELETE FROM chunks WHERE NOT EXISTS ( -// SELECT 1 FROM file_chunks -// WHERE file_chunks.chunk_hash = chunks.chunk_hash -// ); -// DELETE FROM blobs WHERE NOT EXISTS ( -// SELECT 1 FROM blob_chunks -// WHERE blob_chunks.blob_hash = blobs.blob_hash -// ); -func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, snapshotID string) error { +// Each step is implemented as a separate method for clarity and maintainability. +func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, snapshotID string) (*CleanupStats, error) { // Open the temp database db, err := database.New(ctx, dbPath) if err != nil { - return fmt.Errorf("opening temp database: %w", err) + return nil, fmt.Errorf("opening temp database: %w", err) } defer func() { if err := db.Close(); err != nil { @@ -333,7 +335,7 @@ func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, s // Start a transaction tx, err := db.BeginTx(ctx, nil) if err != nil { - return fmt.Errorf("beginning transaction: %w", err) + return nil, fmt.Errorf("beginning transaction: %w", err) } defer func() { if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone { @@ -341,123 +343,77 @@ func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, s } }() - // Step 1: Delete all other snapshots - log.Debug("Deleting other snapshots", "keeping", snapshotID) - database.LogSQL("Execute", "DELETE FROM snapshots WHERE id != ?", snapshotID) - result, err := tx.ExecContext(ctx, "DELETE FROM snapshots WHERE id != ?", snapshotID) - if err != nil { - return fmt.Errorf("deleting other snapshots: %w", err) + // Execute cleanup steps in order + if err := sm.deleteOtherSnapshots(ctx, tx, snapshotID); err != nil { + return nil, fmt.Errorf("step 1 - delete other snapshots: %w", err) } - rowsAffected, _ := result.RowsAffected() - log.Debug("Deleted snapshots", "count", rowsAffected) - // Step 2: Delete files not in this snapshot - log.Debug("Deleting files not in current snapshot") - database.LogSQL("Execute", `DELETE FROM files WHERE NOT EXISTS (SELECT 1 FROM snapshot_files WHERE snapshot_files.file_id = files.id AND snapshot_files.snapshot_id = ?)`, snapshotID) - result, err = tx.ExecContext(ctx, ` - DELETE FROM files - WHERE NOT EXISTS ( - SELECT 1 FROM snapshot_files - WHERE snapshot_files.file_id = files.id - AND snapshot_files.snapshot_id = ? 
- )`, snapshotID) - if err != nil { - return fmt.Errorf("deleting orphaned files: %w", err) + if err := sm.deleteOrphanedSnapshotAssociations(ctx, tx, snapshotID); err != nil { + return nil, fmt.Errorf("step 2 - delete orphaned snapshot associations: %w", err) } - rowsAffected, _ = result.RowsAffected() - log.Debug("Deleted files", "count", rowsAffected) - // Step 3: file_chunks will be deleted via CASCADE from files - log.Debug("file_chunks will be deleted via CASCADE") - - // Step 4: Delete chunk_files for deleted files - log.Debug("Deleting orphaned chunk_files") - database.LogSQL("Execute", `DELETE FROM chunk_files WHERE NOT EXISTS (SELECT 1 FROM files WHERE files.id = chunk_files.file_id)`) - result, err = tx.ExecContext(ctx, ` - DELETE FROM chunk_files - WHERE NOT EXISTS ( - SELECT 1 FROM files - WHERE files.id = chunk_files.file_id - )`) - if err != nil { - return fmt.Errorf("deleting orphaned chunk_files: %w", err) + if err := sm.deleteOrphanedFiles(ctx, tx, snapshotID); err != nil { + return nil, fmt.Errorf("step 3 - delete orphaned files: %w", err) } - rowsAffected, _ = result.RowsAffected() - log.Debug("Deleted chunk_files", "count", rowsAffected) - // Step 5: Delete chunks with no remaining file references - log.Debug("Deleting orphaned chunks") - database.LogSQL("Execute", `DELETE FROM chunks WHERE NOT EXISTS (SELECT 1 FROM file_chunks WHERE file_chunks.chunk_hash = chunks.chunk_hash)`) - result, err = tx.ExecContext(ctx, ` - DELETE FROM chunks - WHERE NOT EXISTS ( - SELECT 1 FROM file_chunks - WHERE file_chunks.chunk_hash = chunks.chunk_hash - )`) - if err != nil { - return fmt.Errorf("deleting orphaned chunks: %w", err) + if err := sm.deleteOrphanedChunkToFileMappings(ctx, tx); err != nil { + return nil, fmt.Errorf("step 4 - delete orphaned chunk-to-file mappings: %w", err) } - rowsAffected, _ = result.RowsAffected() - log.Debug("Deleted chunks", "count", rowsAffected) - // Step 6: Delete blob_chunks for deleted chunks - log.Debug("Deleting orphaned blob_chunks") - database.LogSQL("Execute", `DELETE FROM blob_chunks WHERE NOT EXISTS (SELECT 1 FROM chunks WHERE chunks.chunk_hash = blob_chunks.chunk_hash)`) - result, err = tx.ExecContext(ctx, ` - DELETE FROM blob_chunks - WHERE NOT EXISTS ( - SELECT 1 FROM chunks - WHERE chunks.chunk_hash = blob_chunks.chunk_hash - )`) - if err != nil { - return fmt.Errorf("deleting orphaned blob_chunks: %w", err) + if err := sm.deleteOrphanedBlobs(ctx, tx, snapshotID); err != nil { + return nil, fmt.Errorf("step 5 - delete orphaned blobs: %w", err) } - rowsAffected, _ = result.RowsAffected() - log.Debug("Deleted blob_chunks", "count", rowsAffected) - // Step 7: Delete blobs not in this snapshot - log.Debug("Deleting blobs not in current snapshot") - database.LogSQL("Execute", `DELETE FROM blobs WHERE NOT EXISTS (SELECT 1 FROM snapshot_blobs WHERE snapshot_blobs.blob_hash = blobs.blob_hash AND snapshot_blobs.snapshot_id = ?)`, snapshotID) - result, err = tx.ExecContext(ctx, ` - DELETE FROM blobs - WHERE NOT EXISTS ( - SELECT 1 FROM snapshot_blobs - WHERE snapshot_blobs.blob_hash = blobs.blob_hash - AND snapshot_blobs.snapshot_id = ? 
- )`, snapshotID) - if err != nil { - return fmt.Errorf("deleting orphaned blobs: %w", err) + if err := sm.deleteOrphanedBlobToChunkMappings(ctx, tx); err != nil { + return nil, fmt.Errorf("step 6 - delete orphaned blob-to-chunk mappings: %w", err) } - rowsAffected, _ = result.RowsAffected() - log.Debug("Deleted blobs not in snapshot", "count", rowsAffected) - // Step 8: Delete orphaned snapshot_files and snapshot_blobs - log.Debug("Deleting orphaned snapshot_files") - database.LogSQL("Execute", "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID) - result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID) - if err != nil { - return fmt.Errorf("deleting orphaned snapshot_files: %w", err) + if err := sm.deleteOrphanedChunks(ctx, tx); err != nil { + return nil, fmt.Errorf("step 7 - delete orphaned chunks: %w", err) } - rowsAffected, _ = result.RowsAffected() - log.Debug("Deleted snapshot_files", "count", rowsAffected) - - log.Debug("Deleting orphaned snapshot_blobs") - database.LogSQL("Execute", "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID) - result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID) - if err != nil { - return fmt.Errorf("deleting orphaned snapshot_blobs: %w", err) - } - rowsAffected, _ = result.RowsAffected() - log.Debug("Deleted snapshot_blobs", "count", rowsAffected) // Commit transaction - log.Debug("Committing cleanup transaction") + log.Debug("[Temp DB Cleanup] Committing cleanup transaction") if err := tx.Commit(); err != nil { - return fmt.Errorf("committing transaction: %w", err) + return nil, fmt.Errorf("committing transaction: %w", err) } - log.Debug("Database cleanup complete") - return nil + // Collect statistics about the cleaned database + stats := &CleanupStats{} + + // Count files + var fileCount int + err = db.QueryRowWithLog(ctx, "SELECT COUNT(*) FROM files").Scan(&fileCount) + if err != nil { + return nil, fmt.Errorf("counting files: %w", err) + } + stats.FileCount = fileCount + + // Count chunks + var chunkCount int + err = db.QueryRowWithLog(ctx, "SELECT COUNT(*) FROM chunks").Scan(&chunkCount) + if err != nil { + return nil, fmt.Errorf("counting chunks: %w", err) + } + stats.ChunkCount = chunkCount + + // Count blobs and get sizes + var blobCount int + var compressedSize, uncompressedSize sql.NullInt64 + err = db.QueryRowWithLog(ctx, ` + SELECT COUNT(*), COALESCE(SUM(compressed_size), 0), COALESCE(SUM(uncompressed_size), 0) + FROM blobs + WHERE blob_hash IN (SELECT blob_hash FROM snapshot_blobs WHERE snapshot_id = ?) 
+ `, snapshotID).Scan(&blobCount, &compressedSize, &uncompressedSize) + if err != nil { + return nil, fmt.Errorf("counting blobs and sizes: %w", err) + } + stats.BlobCount = blobCount + stats.CompressedSize = compressedSize.Int64 + stats.UncompressedSize = uncompressedSize.Int64 + + log.Debug("[Temp DB Cleanup] Database cleanup complete", "stats", stats) + return stats, nil } // dumpDatabase creates a SQL dump of the database @@ -492,7 +448,7 @@ func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error { } }() - log.Debug("Creating output file for compressed data", "path", outputPath) + log.Debug("Creating output file for compressed and encrypted data", "path", outputPath) output, err := os.Create(outputPath) if err != nil { return fmt.Errorf("creating output file: %w", err) @@ -504,27 +460,30 @@ func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error { } }() - // Create zstd encoder with good compression and multithreading - log.Debug("Creating zstd compressor", "level", "SpeedBetterCompression", "concurrency", runtime.NumCPU()) - zstdWriter, err := zstd.NewWriter(output, - zstd.WithEncoderLevel(zstd.SpeedBetterCompression), - zstd.WithEncoderConcurrency(runtime.NumCPU()), - zstd.WithWindowSize(4<<20), // 4MB window for metadata files - ) + // Use blobgen for compression and encryption + log.Debug("Creating compressor/encryptor", "level", sm.config.CompressionLevel) + writer, err := blobgen.NewWriter(output, sm.config.CompressionLevel, sm.config.AgeRecipients) if err != nil { - return fmt.Errorf("creating zstd writer: %w", err) + return fmt.Errorf("creating blobgen writer: %w", err) } defer func() { - if err := zstdWriter.Close(); err != nil { - log.Debug("Failed to close zstd writer", "error", err) + if err := writer.Close(); err != nil { + log.Debug("Failed to close writer", "error", err) } }() - log.Debug("Compressing data") - if _, err := io.Copy(zstdWriter, input); err != nil { + log.Debug("Compressing and encrypting data") + if _, err := io.Copy(writer, input); err != nil { return fmt.Errorf("compressing data: %w", err) } + // Close writer to flush all data + if err := writer.Close(); err != nil { + return fmt.Errorf("closing writer: %w", err) + } + + log.Debug("Compression complete", "hash", fmt.Sprintf("%x", writer.Sum256())) + return nil } @@ -607,44 +566,28 @@ func (sm *SnapshotManager) generateBlobManifest(ctx context.Context, dbPath stri } log.Debug("JSON manifest created", "size", len(jsonData)) - // Compress with zstd - log.Debug("Compressing manifest with zstd") - compressed, err := compressData(jsonData) + // Compress and encrypt with blobgen + log.Debug("Compressing and encrypting manifest") + + result, err := blobgen.CompressData(jsonData, sm.config.CompressionLevel, sm.config.AgeRecipients) if err != nil { return nil, fmt.Errorf("compressing manifest: %w", err) } - log.Debug("Manifest compressed", "original_size", len(jsonData), "compressed_size", len(compressed)) + log.Debug("Manifest compressed and encrypted", + "original_size", len(jsonData), + "compressed_size", result.CompressedSize, + "hash", result.SHA256) log.Info("Generated blob manifest", "snapshot_id", snapshotID, "blob_count", len(blobs), "json_size", len(jsonData), - "compressed_size", len(compressed)) + "compressed_size", result.CompressedSize) - return compressed, nil + return result.Data, nil } // compressData compresses data using zstd -func compressData(data []byte) ([]byte, error) { - var buf bytes.Buffer - w, err := zstd.NewWriter(&buf, - 
zstd.WithEncoderLevel(zstd.SpeedBetterCompression), - ) - if err != nil { - return nil, err - } - - if _, err := w.Write(data); err != nil { - _ = w.Close() - return nil, err - } - - if err := w.Close(); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} // getFileSize returns the size of a file in bytes, or -1 if error func getFileSize(path string) int64 { @@ -738,7 +681,7 @@ func (sm *SnapshotManager) deleteSnapshot(ctx context.Context, snapshotID string } // Clean up orphaned data - log.Debug("Cleaning up orphaned data") + log.Debug("Cleaning up orphaned data in main database") if err := sm.cleanupOrphanedData(ctx); err != nil { return fmt.Errorf("cleaning up orphaned data: %w", err) } @@ -748,23 +691,170 @@ func (sm *SnapshotManager) deleteSnapshot(ctx context.Context, snapshotID string // cleanupOrphanedData removes files, chunks, and blobs that are no longer referenced by any snapshot func (sm *SnapshotManager) cleanupOrphanedData(ctx context.Context) error { + // Order is important to respect foreign key constraints: + // 1. Delete orphaned files (will cascade delete file_chunks) + // 2. Delete orphaned blobs (will cascade delete blob_chunks for deleted blobs) + // 3. Delete orphaned blob_chunks (where blob exists but chunk doesn't) + // 4. Delete orphaned chunks (now safe after all blob_chunks are gone) + // Delete orphaned files (files not in any snapshot) log.Debug("Deleting orphaned files") if err := sm.repos.Files.DeleteOrphaned(ctx); err != nil { return fmt.Errorf("deleting orphaned files: %w", err) } - // Delete orphaned chunks (chunks not referenced by any file) - log.Debug("Deleting orphaned chunks") - if err := sm.repos.Chunks.DeleteOrphaned(ctx); err != nil { - return fmt.Errorf("deleting orphaned chunks: %w", err) - } - // Delete orphaned blobs (blobs not in any snapshot) + // This will cascade delete blob_chunks for deleted blobs log.Debug("Deleting orphaned blobs") if err := sm.repos.Blobs.DeleteOrphaned(ctx); err != nil { return fmt.Errorf("deleting orphaned blobs: %w", err) } + // Delete orphaned blob_chunks entries + // This handles cases where the blob still exists but chunks were deleted + log.Debug("Deleting orphaned blob_chunks") + if err := sm.repos.BlobChunks.DeleteOrphaned(ctx); err != nil { + return fmt.Errorf("deleting orphaned blob_chunks: %w", err) + } + + // Delete orphaned chunks (chunks not referenced by any file) + // This must come after cleaning up blob_chunks to avoid foreign key violations + log.Debug("Deleting orphaned chunks") + if err := sm.repos.Chunks.DeleteOrphaned(ctx); err != nil { + return fmt.Errorf("deleting orphaned chunks: %w", err) + } + + return nil +} + +// deleteOtherSnapshots deletes all snapshots except the current one +func (sm *SnapshotManager) deleteOtherSnapshots(ctx context.Context, tx *sql.Tx, currentSnapshotID string) error { + log.Debug("[Temp DB Cleanup] Deleting other snapshots", "keeping", currentSnapshotID) + database.LogSQL("Execute", "DELETE FROM snapshots WHERE id != ?", currentSnapshotID) + result, err := tx.ExecContext(ctx, "DELETE FROM snapshots WHERE id != ?", currentSnapshotID) + if err != nil { + return fmt.Errorf("deleting other snapshots: %w", err) + } + rowsAffected, _ := result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted snapshots", "count", rowsAffected) + return nil +} + +// deleteOrphanedSnapshotAssociations deletes snapshot_files and snapshot_blobs for deleted snapshots +func (sm *SnapshotManager) deleteOrphanedSnapshotAssociations(ctx context.Context, tx *sql.Tx, 
currentSnapshotID string) error { + // Delete orphaned snapshot_files + log.Debug("[Temp DB Cleanup] Deleting orphaned snapshot_files") + database.LogSQL("Execute", "DELETE FROM snapshot_files WHERE snapshot_id != ?", currentSnapshotID) + result, err := tx.ExecContext(ctx, "DELETE FROM snapshot_files WHERE snapshot_id != ?", currentSnapshotID) + if err != nil { + return fmt.Errorf("deleting orphaned snapshot_files: %w", err) + } + rowsAffected, _ := result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted snapshot_files", "count", rowsAffected) + + // Delete orphaned snapshot_blobs + log.Debug("[Temp DB Cleanup] Deleting orphaned snapshot_blobs") + database.LogSQL("Execute", "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", currentSnapshotID) + result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", currentSnapshotID) + if err != nil { + return fmt.Errorf("deleting orphaned snapshot_blobs: %w", err) + } + rowsAffected, _ = result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted snapshot_blobs", "count", rowsAffected) + return nil +} + +// deleteOrphanedFiles deletes files not in the current snapshot +func (sm *SnapshotManager) deleteOrphanedFiles(ctx context.Context, tx *sql.Tx, currentSnapshotID string) error { + log.Debug("[Temp DB Cleanup] Deleting files not in current snapshot") + database.LogSQL("Execute", `DELETE FROM files WHERE NOT EXISTS (SELECT 1 FROM snapshot_files WHERE snapshot_files.file_id = files.id AND snapshot_files.snapshot_id = ?)`, currentSnapshotID) + result, err := tx.ExecContext(ctx, ` + DELETE FROM files + WHERE NOT EXISTS ( + SELECT 1 FROM snapshot_files + WHERE snapshot_files.file_id = files.id + AND snapshot_files.snapshot_id = ? + )`, currentSnapshotID) + if err != nil { + return fmt.Errorf("deleting orphaned files: %w", err) + } + rowsAffected, _ := result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted files", "count", rowsAffected) + + // Note: file_chunks will be deleted via CASCADE + log.Debug("[Temp DB Cleanup] file_chunks will be deleted via CASCADE") + return nil +} + +// deleteOrphanedChunkToFileMappings deletes chunk_files entries for deleted files +func (sm *SnapshotManager) deleteOrphanedChunkToFileMappings(ctx context.Context, tx *sql.Tx) error { + log.Debug("[Temp DB Cleanup] Deleting orphaned chunk_files") + database.LogSQL("Execute", `DELETE FROM chunk_files WHERE NOT EXISTS (SELECT 1 FROM files WHERE files.id = chunk_files.file_id)`) + result, err := tx.ExecContext(ctx, ` + DELETE FROM chunk_files + WHERE NOT EXISTS ( + SELECT 1 FROM files + WHERE files.id = chunk_files.file_id + )`) + if err != nil { + return fmt.Errorf("deleting orphaned chunk_files: %w", err) + } + rowsAffected, _ := result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted chunk_files", "count", rowsAffected) + return nil +} + +// deleteOrphanedBlobs deletes blobs not in the current snapshot +func (sm *SnapshotManager) deleteOrphanedBlobs(ctx context.Context, tx *sql.Tx, currentSnapshotID string) error { + log.Debug("[Temp DB Cleanup] Deleting blobs not in current snapshot") + database.LogSQL("Execute", `DELETE FROM blobs WHERE NOT EXISTS (SELECT 1 FROM snapshot_blobs WHERE snapshot_blobs.blob_hash = blobs.blob_hash AND snapshot_blobs.snapshot_id = ?)`, currentSnapshotID) + result, err := tx.ExecContext(ctx, ` + DELETE FROM blobs + WHERE NOT EXISTS ( + SELECT 1 FROM snapshot_blobs + WHERE snapshot_blobs.blob_hash = blobs.blob_hash + AND snapshot_blobs.snapshot_id = ? 
+ )`, currentSnapshotID) + if err != nil { + return fmt.Errorf("deleting orphaned blobs: %w", err) + } + rowsAffected, _ := result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted blobs not in snapshot", "count", rowsAffected) + return nil +} + +// deleteOrphanedBlobToChunkMappings deletes blob_chunks entries for deleted blobs +func (sm *SnapshotManager) deleteOrphanedBlobToChunkMappings(ctx context.Context, tx *sql.Tx) error { + log.Debug("[Temp DB Cleanup] Deleting orphaned blob_chunks") + database.LogSQL("Execute", `DELETE FROM blob_chunks WHERE NOT EXISTS (SELECT 1 FROM blobs WHERE blobs.id = blob_chunks.blob_id)`) + result, err := tx.ExecContext(ctx, ` + DELETE FROM blob_chunks + WHERE NOT EXISTS ( + SELECT 1 FROM blobs + WHERE blobs.id = blob_chunks.blob_id + )`) + if err != nil { + return fmt.Errorf("deleting orphaned blob_chunks: %w", err) + } + rowsAffected, _ := result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted blob_chunks", "count", rowsAffected) + return nil +} + +// deleteOrphanedChunks deletes chunks not referenced by any file +func (sm *SnapshotManager) deleteOrphanedChunks(ctx context.Context, tx *sql.Tx) error { + log.Debug("[Temp DB Cleanup] Deleting orphaned chunks") + database.LogSQL("Execute", `DELETE FROM chunks WHERE NOT EXISTS (SELECT 1 FROM file_chunks WHERE file_chunks.chunk_hash = chunks.chunk_hash)`) + result, err := tx.ExecContext(ctx, ` + DELETE FROM chunks + WHERE NOT EXISTS ( + SELECT 1 FROM file_chunks + WHERE file_chunks.chunk_hash = chunks.chunk_hash + )`) + if err != nil { + return fmt.Errorf("deleting orphaned chunks: %w", err) + } + rowsAffected, _ := result.RowsAffected() + log.Debug("[Temp DB Cleanup] Deleted chunks", "count", rowsAffected) return nil } diff --git a/internal/backup/snapshot_test.go b/internal/backup/snapshot_test.go index 6e9a413..6859946 100644 --- a/internal/backup/snapshot_test.go +++ b/internal/backup/snapshot_test.go @@ -6,10 +6,16 @@ import ( "path/filepath" "testing" + "git.eeqj.de/sneak/vaultik/internal/config" "git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/log" ) +const ( + // Test age public key for encryption + testAgeRecipient = "age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg" +) + func TestCleanSnapshotDBEmptySnapshot(t *testing.T) { // Initialize logger log.Initialize(log.Config{}) @@ -41,7 +47,7 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) { // Create some files and chunks not associated with any snapshot file := &database.File{Path: "/orphan/file.txt", Size: 1000} - chunk := &database.Chunk{ChunkHash: "orphan-chunk", SHA256: "orphan-chunk", Size: 500} + chunk := &database.Chunk{ChunkHash: "orphan-chunk", Size: 500} err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { if err := repos.Files.Create(ctx, tx, file); err != nil { @@ -64,9 +70,14 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) { t.Fatalf("failed to copy database: %v", err) } + // Create a mock config for testing + cfg := &config.Config{ + CompressionLevel: 3, + AgeRecipients: []string{testAgeRecipient}, + } // Clean the database - sm := &SnapshotManager{} - if err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshot.ID); err != nil { + sm := &SnapshotManager{config: cfg} + if _, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshot.ID); err != nil { t.Fatalf("failed to clean snapshot database: %v", err) } @@ -136,9 +147,14 @@ func TestCleanSnapshotDBNonExistentSnapshot(t *testing.T) { t.Fatalf("failed to copy database: %v", err) } + // Create a mock config for 
testing + cfg := &config.Config{ + CompressionLevel: 3, + AgeRecipients: []string{testAgeRecipient}, + } // Try to clean with non-existent snapshot - sm := &SnapshotManager{} - err = sm.cleanSnapshotDB(ctx, tempDBPath, "non-existent-snapshot") + sm := &SnapshotManager{config: cfg} + _, err = sm.cleanSnapshotDB(ctx, tempDBPath, "non-existent-snapshot") // Should not error - it will just delete everything if err != nil { diff --git a/internal/blob/packer.go b/internal/blob/packer.go index 11b58a2..75b6ab4 100644 --- a/internal/blob/packer.go +++ b/internal/blob/packer.go @@ -16,22 +16,18 @@ package blob import ( "context" - "crypto/sha256" "database/sql" "encoding/hex" "fmt" - "hash" "io" - "math/bits" "os" - "runtime" "sync" "time" + "git.eeqj.de/sneak/vaultik/internal/blobgen" "git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/log" "github.com/google/uuid" - "github.com/klauspost/compress/zstd" ) // BlobHandler is a callback function invoked when a blob is finalized and ready for upload. @@ -45,7 +41,7 @@ type BlobHandler func(blob *BlobWithReader) error type PackerConfig struct { MaxBlobSize int64 // Maximum size of a blob before forcing finalization CompressionLevel int // Zstd compression level (1-19, higher = better compression) - Encryptor Encryptor // Age encryptor for blob encryption (required) + Recipients []string // Age recipients for encryption Repositories *database.Repositories // Database repositories for tracking blob metadata BlobHandler BlobHandler // Optional callback when blob is ready for upload } @@ -56,7 +52,7 @@ type PackerConfig struct { type Packer struct { maxBlobSize int64 compressionLevel int - encryptor Encryptor // Required - blobs are always encrypted + recipients []string // Age recipients for encryption blobHandler BlobHandler // Called when blob is ready repos *database.Repositories // For creating blob records @@ -68,25 +64,15 @@ type Packer struct { finishedBlobs []*FinishedBlob // Only used if no handler provided } -// Encryptor interface for encryption support -type Encryptor interface { - Encrypt(data []byte) ([]byte, error) - EncryptWriter(dst io.Writer) (io.WriteCloser, error) -} - // blobInProgress represents a blob being assembled type blobInProgress struct { - id string // UUID of the blob - chunks []*chunkInfo // Track chunk metadata - chunkSet map[string]bool // Track unique chunks in this blob - tempFile *os.File // Temporary file for encrypted compressed data - hasher hash.Hash // For computing hash of final encrypted data - compressor io.WriteCloser // Compression writer - encryptor io.WriteCloser // Encryption writer (if encryption enabled) - finalWriter io.Writer // The final writer in the chain - startTime time.Time - size int64 // Current uncompressed size - compressedSize int64 // Current compressed size (estimated) + id string // UUID of the blob + chunks []*chunkInfo // Track chunk metadata + chunkSet map[string]bool // Track unique chunks in this blob + tempFile *os.File // Temporary file for encrypted compressed data + writer *blobgen.Writer // Unified compression/encryption/hashing writer + startTime time.Time + size int64 // Current uncompressed size } // ChunkRef represents a chunk to be added to a blob. @@ -134,8 +120,8 @@ type BlobWithReader struct { // The packer will automatically finalize blobs when they reach MaxBlobSize. // Returns an error if required configuration fields are missing or invalid. 
func NewPacker(cfg PackerConfig) (*Packer, error) { - if cfg.Encryptor == nil { - return nil, fmt.Errorf("encryptor is required - blobs must be encrypted") + if len(cfg.Recipients) == 0 { + return nil, fmt.Errorf("recipients are required - blobs must be encrypted") } if cfg.MaxBlobSize <= 0 { return nil, fmt.Errorf("max blob size must be positive") @@ -143,7 +129,7 @@ func NewPacker(cfg PackerConfig) (*Packer, error) { return &Packer{ maxBlobSize: cfg.MaxBlobSize, compressionLevel: cfg.CompressionLevel, - encryptor: cfg.Encryptor, + recipients: cfg.Recipients, blobHandler: cfg.BlobHandler, repos: cfg.Repositories, finishedBlobs: make([]*FinishedBlob, 0), @@ -274,66 +260,24 @@ func (p *Packer) startNewBlob() error { return fmt.Errorf("creating temp file: %w", err) } + // Create blobgen writer for unified compression/encryption/hashing + writer, err := blobgen.NewWriter(tempFile, p.compressionLevel, p.recipients) + if err != nil { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + return fmt.Errorf("creating blobgen writer: %w", err) + } + p.currentBlob = &blobInProgress{ - id: blobID, - chunks: make([]*chunkInfo, 0), - chunkSet: make(map[string]bool), - startTime: time.Now().UTC(), - tempFile: tempFile, - hasher: sha256.New(), - size: 0, - compressedSize: 0, + id: blobID, + chunks: make([]*chunkInfo, 0), + chunkSet: make(map[string]bool), + startTime: time.Now().UTC(), + tempFile: tempFile, + writer: writer, + size: 0, } - // Build writer chain: compressor -> [encryptor ->] hasher+file - // This ensures only encrypted data touches disk - - // Final destination: write to both file and hasher - finalWriter := io.MultiWriter(tempFile, p.currentBlob.hasher) - - // Set up encryption (required - closest to disk) - encWriter, err := p.encryptor.EncryptWriter(finalWriter) - if err != nil { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) - return fmt.Errorf("creating encryption writer: %w", err) - } - p.currentBlob.encryptor = encWriter - currentWriter := encWriter - - // Set up compression (processes data before encryption) - encoderLevel := zstd.EncoderLevel(p.compressionLevel) - if p.compressionLevel < 1 { - encoderLevel = zstd.SpeedDefault - } else if p.compressionLevel > 9 { - encoderLevel = zstd.SpeedBestCompression - } - - // Calculate window size based on blob size - windowSize := p.maxBlobSize / 100 - if windowSize < (1 << 20) { // Min 1MB - windowSize = 1 << 20 - } else if windowSize > (128 << 20) { // Max 128MB - windowSize = 128 << 20 - } - windowSize = 1 << uint(63-bits.LeadingZeros64(uint64(windowSize))) - - compWriter, err := zstd.NewWriter(currentWriter, - zstd.WithEncoderLevel(encoderLevel), - zstd.WithEncoderConcurrency(runtime.NumCPU()), - zstd.WithWindowSize(int(windowSize)), - ) - if err != nil { - if p.currentBlob.encryptor != nil { - _ = p.currentBlob.encryptor.Close() - } - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) - return fmt.Errorf("creating compression writer: %w", err) - } - p.currentBlob.compressor = compWriter - p.currentBlob.finalWriter = compWriter - log.Debug("Started new blob", "blob_id", blobID, "temp_file", tempFile.Name()) return nil } @@ -349,8 +293,8 @@ func (p *Packer) addChunkToCurrentBlob(chunk *ChunkRef) error { // Track offset before writing offset := p.currentBlob.size - // Write to the final writer (compression -> encryption -> disk) - if _, err := p.currentBlob.finalWriter.Write(chunk.Data); err != nil { + // Write to the blobgen writer (compression -> encryption -> disk) + if _, err := p.currentBlob.writer.Write(chunk.Data); 
err != nil { return fmt.Errorf("writing to blob stream: %w", err) } @@ -402,16 +346,10 @@ func (p *Packer) finalizeCurrentBlob() error { return nil } - // Close compression writer to flush all data - if err := p.currentBlob.compressor.Close(); err != nil { + // Close blobgen writer to flush all data + if err := p.currentBlob.writer.Close(); err != nil { p.cleanupTempFile() - return fmt.Errorf("closing compression writer: %w", err) - } - - // Close encryption writer - if err := p.currentBlob.encryptor.Close(); err != nil { - p.cleanupTempFile() - return fmt.Errorf("closing encryption writer: %w", err) + return fmt.Errorf("closing blobgen writer: %w", err) } // Sync file to ensure all data is written @@ -433,8 +371,8 @@ func (p *Packer) finalizeCurrentBlob() error { return fmt.Errorf("seeking to start: %w", err) } - // Get hash from hasher (of final encrypted data) - finalHash := p.currentBlob.hasher.Sum(nil) + // Get hash from blobgen writer (of final encrypted data) + finalHash := p.currentBlob.writer.Sum256() blobHash := hex.EncodeToString(finalHash) // Create chunk references with offsets diff --git a/internal/blob/packer_test.go b/internal/blob/packer_test.go index 40ea3c6..3518901 100644 --- a/internal/blob/packer_test.go +++ b/internal/blob/packer_test.go @@ -2,13 +2,14 @@ package blob import ( "bytes" + "context" "crypto/sha256" + "database/sql" "encoding/hex" "io" "testing" "filippo.io/age" - "git.eeqj.de/sneak/vaultik/internal/crypto" "git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/log" "github.com/klauspost/compress/zstd" @@ -30,12 +31,6 @@ func TestPacker(t *testing.T) { t.Fatalf("failed to parse test identity: %v", err) } - // Create test encryptor using the public key - enc, err := crypto.NewEncryptor([]string{testPublicKey}) - if err != nil { - t.Fatalf("failed to create encryptor: %v", err) - } - t.Run("single chunk creates single blob", func(t *testing.T) { // Create test database db, err := database.NewTestDB() @@ -48,7 +43,7 @@ func TestPacker(t *testing.T) { cfg := PackerConfig{ MaxBlobSize: 10 * 1024 * 1024, // 10MB CompressionLevel: 3, - Encryptor: enc, + Recipients: []string{testPublicKey}, Repositories: repos, } packer, err := NewPacker(cfg) @@ -59,8 +54,22 @@ func TestPacker(t *testing.T) { // Create a chunk data := []byte("Hello, World!") hash := sha256.Sum256(data) + hashStr := hex.EncodeToString(hash[:]) + + // Create chunk in database first + dbChunk := &database.Chunk{ + ChunkHash: hashStr, + Size: int64(len(data)), + } + err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return repos.Chunks.Create(ctx, tx, dbChunk) + }) + if err != nil { + t.Fatalf("failed to create chunk in db: %v", err) + } + chunk := &ChunkRef{ - Hash: hex.EncodeToString(hash[:]), + Hash: hashStr, Data: data, } @@ -123,7 +132,7 @@ func TestPacker(t *testing.T) { cfg := PackerConfig{ MaxBlobSize: 10 * 1024 * 1024, // 10MB CompressionLevel: 3, - Encryptor: enc, + Recipients: []string{testPublicKey}, Repositories: repos, } packer, err := NewPacker(cfg) @@ -136,8 +145,22 @@ func TestPacker(t *testing.T) { for i := 0; i < 10; i++ { data := bytes.Repeat([]byte{byte(i)}, 1000) hash := sha256.Sum256(data) + hashStr := hex.EncodeToString(hash[:]) + + // Create chunk in database first + dbChunk := &database.Chunk{ + ChunkHash: hashStr, + Size: int64(len(data)), + } + err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return repos.Chunks.Create(ctx, tx, dbChunk) + }) + if err != nil { + 
t.Fatalf("failed to create chunk in db: %v", err) + } + chunks[i] = &ChunkRef{ - Hash: hex.EncodeToString(hash[:]), + Hash: hashStr, Data: data, } } @@ -191,7 +214,7 @@ func TestPacker(t *testing.T) { cfg := PackerConfig{ MaxBlobSize: 5000, // 5KB max CompressionLevel: 3, - Encryptor: enc, + Recipients: []string{testPublicKey}, Repositories: repos, } packer, err := NewPacker(cfg) @@ -204,8 +227,22 @@ func TestPacker(t *testing.T) { for i := 0; i < 10; i++ { data := bytes.Repeat([]byte{byte(i)}, 1000) // 1KB each hash := sha256.Sum256(data) + hashStr := hex.EncodeToString(hash[:]) + + // Create chunk in database first + dbChunk := &database.Chunk{ + ChunkHash: hashStr, + Size: int64(len(data)), + } + err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return repos.Chunks.Create(ctx, tx, dbChunk) + }) + if err != nil { + t.Fatalf("failed to create chunk in db: %v", err) + } + chunks[i] = &ChunkRef{ - Hash: hex.EncodeToString(hash[:]), + Hash: hashStr, Data: data, } } @@ -265,7 +302,7 @@ func TestPacker(t *testing.T) { cfg := PackerConfig{ MaxBlobSize: 10 * 1024 * 1024, // 10MB CompressionLevel: 3, - Encryptor: enc, + Recipients: []string{testPublicKey}, Repositories: repos, } packer, err := NewPacker(cfg) @@ -276,8 +313,22 @@ func TestPacker(t *testing.T) { // Create test data data := bytes.Repeat([]byte("Test data for encryption!"), 100) hash := sha256.Sum256(data) + hashStr := hex.EncodeToString(hash[:]) + + // Create chunk in database first + dbChunk := &database.Chunk{ + ChunkHash: hashStr, + Size: int64(len(data)), + } + err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + return repos.Chunks.Create(ctx, tx, dbChunk) + }) + if err != nil { + t.Fatalf("failed to create chunk in db: %v", err) + } + chunk := &ChunkRef{ - Hash: hex.EncodeToString(hash[:]), + Hash: hashStr, Data: data, } diff --git a/internal/blobgen/compress.go b/internal/blobgen/compress.go new file mode 100644 index 0000000..1292fae --- /dev/null +++ b/internal/blobgen/compress.go @@ -0,0 +1,67 @@ +package blobgen + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" +) + +// CompressResult contains the results of compression +type CompressResult struct { + Data []byte + UncompressedSize int64 + CompressedSize int64 + SHA256 string +} + +// CompressData compresses and encrypts data, returning the result with hash +func CompressData(data []byte, compressionLevel int, recipients []string) (*CompressResult, error) { + var buf bytes.Buffer + + // Create writer + w, err := NewWriter(&buf, compressionLevel, recipients) + if err != nil { + return nil, fmt.Errorf("creating writer: %w", err) + } + + // Write data + if _, err := w.Write(data); err != nil { + _ = w.Close() + return nil, fmt.Errorf("writing data: %w", err) + } + + // Close to flush + if err := w.Close(); err != nil { + return nil, fmt.Errorf("closing writer: %w", err) + } + + return &CompressResult{ + Data: buf.Bytes(), + UncompressedSize: int64(len(data)), + CompressedSize: int64(buf.Len()), + SHA256: hex.EncodeToString(w.Sum256()), + }, nil +} + +// CompressStream compresses and encrypts from reader to writer, returning hash +func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipients []string) (written int64, hash string, err error) { + // Create writer + w, err := NewWriter(dst, compressionLevel, recipients) + if err != nil { + return 0, "", fmt.Errorf("creating writer: %w", err) + } + defer func() { _ = w.Close() }() + + // Copy data + if _, err := io.Copy(w, src); 
err != nil {
+		return 0, "", fmt.Errorf("copying data: %w", err)
+	}
+
+	// Close to flush
+	if err := w.Close(); err != nil {
+		return 0, "", fmt.Errorf("closing writer: %w", err)
+	}
+
+	return w.BytesWritten(), hex.EncodeToString(w.Sum256()), nil
+}
diff --git a/internal/blobgen/reader.go b/internal/blobgen/reader.go
new file mode 100644
index 0000000..cc3edd4
--- /dev/null
+++ b/internal/blobgen/reader.go
@@ -0,0 +1,73 @@
+package blobgen
+
+import (
+	"crypto/sha256"
+	"fmt"
+	"hash"
+	"io"
+
+	"filippo.io/age"
+	"github.com/klauspost/compress/zstd"
+)
+
+// Reader wraps decompression and decryption with SHA256 verification
+type Reader struct {
+	reader       io.Reader
+	decompressor *zstd.Decoder
+	decryptor    io.Reader
+	hasher       hash.Hash
+	teeReader    io.Reader
+	bytesRead    int64
+}
+
+// NewReader creates a new Reader that decrypts, decompresses, and verifies data
+func NewReader(r io.Reader, identity age.Identity) (*Reader, error) {
+	// Create decryption reader
+	decReader, err := age.Decrypt(r, identity)
+	if err != nil {
+		return nil, fmt.Errorf("creating decryption reader: %w", err)
+	}
+
+	// Create decompression reader
+	decompressor, err := zstd.NewReader(decReader)
+	if err != nil {
+		return nil, fmt.Errorf("creating decompression reader: %w", err)
+	}
+
+	// Create SHA256 hasher
+	hasher := sha256.New()
+
+	// Create tee reader that reads from decompressor and writes to hasher
+	teeReader := io.TeeReader(decompressor, hasher)
+
+	return &Reader{
+		reader:       r,
+		decompressor: decompressor,
+		decryptor:    decReader,
+		hasher:       hasher,
+		teeReader:    teeReader,
+	}, nil
+}
+
+// Read implements io.Reader
+func (r *Reader) Read(p []byte) (n int, err error) {
+	n, err = r.teeReader.Read(p)
+	r.bytesRead += int64(n)
+	return n, err
+}
+
+// Close closes the decompressor
+func (r *Reader) Close() error {
+	r.decompressor.Close()
+	return nil
+}
+
+// Sum256 returns the SHA256 hash of all data read
+func (r *Reader) Sum256() []byte {
+	return r.hasher.Sum(nil)
+}
+
+// BytesRead returns the number of uncompressed bytes read
+func (r *Reader) BytesRead() int64 {
+	return r.bytesRead
+}
diff --git a/internal/blobgen/writer.go b/internal/blobgen/writer.go
new file mode 100644
index 0000000..3d64847
--- /dev/null
+++ b/internal/blobgen/writer.go
@@ -0,0 +1,112 @@
+package blobgen
+
+import (
+	"crypto/sha256"
+	"fmt"
+	"hash"
+	"io"
+
+	"filippo.io/age"
+	"github.com/klauspost/compress/zstd"
+)
+
+// Writer wraps compression and encryption with SHA256 hashing of the encrypted output
+type Writer struct {
+	writer           io.Writer      // Final destination
+	compressor       *zstd.Encoder  // Compression layer
+	encryptor        io.WriteCloser // Encryption layer
+	hasher           hash.Hash      // SHA256 of the encrypted output
+	teeWriter        io.Writer      // Plaintext write destination (the compressor)
+	compressionLevel int
+	bytesWritten     int64
+}
+
+// NewWriter creates a new Writer that compresses, encrypts, and hashes the encrypted output
+func NewWriter(w io.Writer, compressionLevel int, recipients []string) (*Writer, error) {
+	// Validate compression level
+	if err := validateCompressionLevel(compressionLevel); err != nil {
+		return nil, err
+	}
+
+	// Create SHA256 hasher
+	hasher := sha256.New()
+
+	// Parse recipients
+	var ageRecipients []age.Recipient
+	for _, recipient := range recipients {
+		r, err := age.ParseX25519Recipient(recipient)
+		if err != nil {
+			return nil, fmt.Errorf("parsing recipient %s: %w", recipient, err)
+		}
+		ageRecipients = append(ageRecipients, r)
+	}
+
+	// Create encryption writer, teeing its output through the hasher so Sum256 covers the final encrypted bytes
+	encWriter, err := age.Encrypt(io.MultiWriter(w, hasher), ageRecipients...)
+	if err != nil {
+		return nil, fmt.Errorf("creating encryption writer: %w", err)
+	}
+
+	// Create compression writer with encryption as destination
+	compressor, err := zstd.NewWriter(encWriter,
+		zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(compressionLevel)),
+		zstd.WithEncoderConcurrency(1), // Use single thread for streaming
+	)
+	if err != nil {
+		_ = encWriter.Close()
+		return nil, fmt.Errorf("creating compression writer: %w", err)
+	}
+
+	// Plaintext writes go straight to the compressor; the hasher sees the encrypted output via the tee above
+	teeWriter := io.Writer(compressor)
+
+	return &Writer{
+		writer:           w,
+		compressor:       compressor,
+		encryptor:        encWriter,
+		hasher:           hasher,
+		teeWriter:        teeWriter,
+		compressionLevel: compressionLevel,
+	}, nil
+}
+
+// Write implements io.Writer
+func (w *Writer) Write(p []byte) (n int, err error) {
+	n, err = w.teeWriter.Write(p)
+	w.bytesWritten += int64(n)
+	return n, err
+}
+
+// Close closes all layers and returns any errors
+func (w *Writer) Close() error {
+	// Close compressor first
+	if err := w.compressor.Close(); err != nil {
+		return fmt.Errorf("closing compressor: %w", err)
+	}
+
+	// Then close encryptor
+	if err := w.encryptor.Close(); err != nil {
+		return fmt.Errorf("closing encryptor: %w", err)
+	}
+
+	return nil
+}
+
+// Sum256 returns the SHA256 hash of the compressed, encrypted data written to the destination
+func (w *Writer) Sum256() []byte {
+	return w.hasher.Sum(nil)
+}
+
+// BytesWritten returns the number of uncompressed bytes written
+func (w *Writer) BytesWritten() int64 {
+	return w.bytesWritten
+}
+
+func validateCompressionLevel(level int) error {
+	// Zstd compression levels: 1-19 (default is 3)
+	// SpeedFastest = 1, SpeedDefault = 3, SpeedBetterCompression = 7, SpeedBestCompression = 11
+	if level < 1 || level > 19 {
+		return fmt.Errorf("invalid compression level %d: must be between 1 and 19", level)
+	}
+	return nil
+}
diff --git a/internal/cli/snapshot.go b/internal/cli/snapshot.go
index 267d655..013a039 100644
--- a/internal/cli/snapshot.go
+++ b/internal/cli/snapshot.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"os"
 	"path/filepath"
 	"sort"
@@ -13,7 +14,6 @@ import (
 
 	"git.eeqj.de/sneak/vaultik/internal/backup"
 	"git.eeqj.de/sneak/vaultik/internal/config"
-	"git.eeqj.de/sneak/vaultik/internal/crypto"
 	"git.eeqj.de/sneak/vaultik/internal/database"
 	"git.eeqj.de/sneak/vaultik/internal/globals"
 	"git.eeqj.de/sneak/vaultik/internal/log"
@@ -33,14 +33,18 @@ type SnapshotCreateOptions struct {
 
 // SnapshotCreateApp contains all dependencies needed for creating snapshots
 type SnapshotCreateApp struct {
-	Globals        *globals.Globals
-	Config         *config.Config
-	Repositories   *database.Repositories
-	ScannerFactory backup.ScannerFactory
-	S3Client       *s3.Client
-	DB             *database.DB
-	Lifecycle      fx.Lifecycle
-	Shutdowner     fx.Shutdowner
+	Globals         *globals.Globals
+	Config          *config.Config
+	Repositories    *database.Repositories
+	ScannerFactory  backup.ScannerFactory
+	SnapshotManager *backup.SnapshotManager
+	S3Client        *s3.Client
+	DB              *database.DB
+	Lifecycle       fx.Lifecycle
+	Shutdowner      fx.Shutdowner
+	Stdout          io.Writer
+	Stderr          io.Writer
+	Stdin           io.Reader
 }
 
 // SnapshotApp contains dependencies for snapshot commands
@@ -106,17 +110,22 @@ specifying a path using --config or by setting VAULTIK_CONFIG to a path.`,
 			s3.Module,
 			fx.Provide(fx.Annotate(
 				func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
-					scannerFactory backup.ScannerFactory, s3Client *s3.Client, db *database.DB,
+					scannerFactory backup.ScannerFactory, snapshotManager *backup.SnapshotManager,
+					s3Client 
*s3.Client, db *database.DB, lc fx.Lifecycle, shutdowner fx.Shutdowner) *SnapshotCreateApp { return &SnapshotCreateApp{ - Globals: g, - Config: cfg, - Repositories: repos, - ScannerFactory: scannerFactory, - S3Client: s3Client, - DB: db, - Lifecycle: lc, - Shutdowner: shutdowner, + Globals: g, + Config: cfg, + Repositories: repos, + ScannerFactory: scannerFactory, + SnapshotManager: snapshotManager, + S3Client: s3Client, + DB: db, + Lifecycle: lc, + Shutdowner: shutdowner, + Stdout: os.Stdout, + Stderr: os.Stderr, + Stdin: os.Stdin, } }, )), @@ -181,21 +190,10 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre hostname, _ = os.Hostname() } - // Create encryptor if needed for snapshot manager - var encryptor backup.Encryptor - if len(app.Config.AgeRecipients) > 0 { - cryptoEncryptor, err := crypto.NewEncryptor(app.Config.AgeRecipients) - if err != nil { - return fmt.Errorf("creating encryptor: %w", err) - } - encryptor = cryptoEncryptor - } - - snapshotManager := backup.NewSnapshotManager(app.Repositories, app.S3Client, encryptor) // CRITICAL: This MUST succeed. If we fail to clean up incomplete snapshots, // the deduplication logic will think files from the incomplete snapshot were // already backed up and skip them, resulting in data loss. - if err := snapshotManager.CleanupIncompleteSnapshots(ctx, hostname); err != nil { + if err := app.SnapshotManager.CleanupIncompleteSnapshots(ctx, hostname); err != nil { return fmt.Errorf("cleanup incomplete snapshots: %w", err) } @@ -234,8 +232,10 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre // Perform a single snapshot run log.Notice("Starting snapshot", "source_dirs", len(resolvedDirs)) + _, _ = fmt.Fprintf(app.Stdout, "Starting snapshot with %d source directories\n", len(resolvedDirs)) for i, dir := range resolvedDirs { log.Info("Source directory", "index", i+1, "path", dir) + _, _ = fmt.Fprintf(app.Stdout, "Source directory %d: %s\n", i+1, dir) } // Statistics tracking @@ -250,12 +250,12 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre uploadDuration := time.Duration(0) // Create a new snapshot at the beginning - // (hostname, encryptor, and snapshotManager already created above for cleanup) - snapshotID, err := snapshotManager.CreateSnapshot(ctx, hostname, app.Globals.Version, app.Globals.Commit) + snapshotID, err := app.SnapshotManager.CreateSnapshot(ctx, hostname, app.Globals.Version, app.Globals.Commit) if err != nil { return fmt.Errorf("creating snapshot: %w", err) } log.Info("Created snapshot", "snapshot_id", snapshotID) + _, _ = fmt.Fprintf(app.Stdout, "\nCreated snapshot: %s\n", snapshotID) for _, dir := range resolvedDirs { // Check if context is cancelled @@ -288,6 +288,13 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre "chunks", result.ChunksCreated, "blobs", result.BlobsCreated, "duration", result.EndTime.Sub(result.StartTime)) + + // Human-friendly output + _, _ = fmt.Fprintf(app.Stdout, "\nDirectory: %s\n", dir) + _, _ = fmt.Fprintf(app.Stdout, " Scanned: %d files (%s)\n", result.FilesScanned, humanize.Bytes(uint64(result.BytesScanned))) + _, _ = fmt.Fprintf(app.Stdout, " Skipped: %d files (%s) - already backed up\n", result.FilesSkipped, humanize.Bytes(uint64(result.BytesSkipped))) + _, _ = fmt.Fprintf(app.Stdout, " Created: %d chunks, %d blobs\n", result.ChunksCreated, result.BlobsCreated) + _, _ = fmt.Fprintf(app.Stdout, " Duration: %s\n", 
result.EndTime.Sub(result.StartTime).Round(time.Millisecond)) } // Get upload statistics from scanner progress if available @@ -312,19 +319,19 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre UploadDurationMs: uploadDuration.Milliseconds(), } - if err := snapshotManager.UpdateSnapshotStatsExtended(ctx, snapshotID, extStats); err != nil { + if err := app.SnapshotManager.UpdateSnapshotStatsExtended(ctx, snapshotID, extStats); err != nil { return fmt.Errorf("updating snapshot stats: %w", err) } // Mark snapshot as complete - if err := snapshotManager.CompleteSnapshot(ctx, snapshotID); err != nil { + if err := app.SnapshotManager.CompleteSnapshot(ctx, snapshotID); err != nil { return fmt.Errorf("completing snapshot: %w", err) } // Export snapshot metadata // Export snapshot metadata without closing the database // The export function should handle its own database connection - if err := snapshotManager.ExportSnapshotMetadata(ctx, app.Config.IndexPath, snapshotID); err != nil { + if err := app.SnapshotManager.ExportSnapshotMetadata(ctx, app.Config.IndexPath, snapshotID); err != nil { return fmt.Errorf("exporting snapshot metadata: %w", err) } @@ -373,29 +380,29 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre } // Print comprehensive summary - log.Notice("=== Snapshot Summary ===") - log.Info("Snapshot ID", "id", snapshotID) - log.Info("Source files", - "total_count", formatNumber(totalFiles), - "total_size", humanize.Bytes(uint64(totalBytesAll))) - log.Info("Changed files", - "count", formatNumber(totalFilesChanged), - "size", humanize.Bytes(uint64(totalBytesChanged))) - log.Info("Unchanged files", - "count", formatNumber(totalFilesSkipped), - "size", humanize.Bytes(uint64(totalBytesSkipped))) - log.Info("Blob storage", - "total_uncompressed", humanize.Bytes(uint64(totalBlobSizeUncompressed)), - "total_compressed", humanize.Bytes(uint64(totalBlobSizeCompressed)), - "compression_ratio", fmt.Sprintf("%.2fx", compressionRatio), - "compression_level", app.Config.CompressionLevel) - log.Info("Upload activity", - "bytes_uploaded", humanize.Bytes(uint64(totalBytesUploaded)), - "blobs_uploaded", totalBlobsUploaded, - "upload_time", formatDuration(uploadDuration), - "avg_speed", avgUploadSpeed) - log.Info("Total time", "duration", formatDuration(snapshotDuration)) - log.Notice("==========================") + _, _ = fmt.Fprintln(app.Stdout, "\n=== Snapshot Summary ===") + _, _ = fmt.Fprintf(app.Stdout, "Snapshot ID: %s\n", snapshotID) + _, _ = fmt.Fprintf(app.Stdout, "Source files: %s files, %s total\n", + formatNumber(totalFiles), + humanize.Bytes(uint64(totalBytesAll))) + _, _ = fmt.Fprintf(app.Stdout, "Changed files: %s files, %s\n", + formatNumber(totalFilesChanged), + humanize.Bytes(uint64(totalBytesChanged))) + _, _ = fmt.Fprintf(app.Stdout, "Unchanged files: %s files, %s\n", + formatNumber(totalFilesSkipped), + humanize.Bytes(uint64(totalBytesSkipped))) + _, _ = fmt.Fprintf(app.Stdout, "Blob storage: %s uncompressed, %s compressed (%.2fx ratio, level %d)\n", + humanize.Bytes(uint64(totalBlobSizeUncompressed)), + humanize.Bytes(uint64(totalBlobSizeCompressed)), + compressionRatio, + app.Config.CompressionLevel) + _, _ = fmt.Fprintf(app.Stdout, "Upload activity: %s uploaded, %d blobs, %s duration, %s avg speed\n", + humanize.Bytes(uint64(totalBytesUploaded)), + totalBlobsUploaded, + formatDuration(uploadDuration), + avgUploadSpeed) + _, _ = fmt.Fprintf(app.Stdout, "Total time: %s\n", formatDuration(snapshotDuration)) + _, _ = 
fmt.Fprintln(app.Stdout, "==========================") if opts.Prune { log.Info("Pruning enabled - will delete old snapshots after snapshot") @@ -729,13 +736,18 @@ func (app *SnapshotApp) downloadManifest(ctx context.Context, snapshotID string) } defer zr.Close() - // Decode JSON - var manifest []string + // Decode JSON - manifest is an object with a "blobs" field + var manifest struct { + SnapshotID string `json:"snapshot_id"` + Timestamp string `json:"timestamp"` + BlobCount int `json:"blob_count"` + Blobs []string `json:"blobs"` + } if err := json.NewDecoder(zr).Decode(&manifest); err != nil { return nil, fmt.Errorf("decoding manifest: %w", err) } - return manifest, nil + return manifest.Blobs, nil } // deleteSnapshot removes a snapshot and its metadata @@ -765,29 +777,21 @@ func (app *SnapshotApp) deleteSnapshot(ctx context.Context, snapshotID string) e // parseSnapshotTimestamp extracts timestamp from snapshot ID // Format: hostname-20240115-143052Z func parseSnapshotTimestamp(snapshotID string) (time.Time, error) { - // Find the last hyphen to separate hostname from timestamp - lastHyphen := strings.LastIndex(snapshotID, "-") - if lastHyphen == -1 { - return time.Time{}, fmt.Errorf("invalid snapshot ID format") + // The snapshot ID format is: hostname-YYYYMMDD-HHMMSSZ + // We need to find the timestamp part which starts after the hostname + + // Split by hyphen + parts := strings.Split(snapshotID, "-") + if len(parts) < 3 { + return time.Time{}, fmt.Errorf("invalid snapshot ID format: expected hostname-YYYYMMDD-HHMMSSZ") } - // Extract timestamp part (everything after hostname) - timestampPart := snapshotID[lastHyphen+1:] + // The last two parts should be the date and time with Z suffix + dateStr := parts[len(parts)-2] + timeStr := parts[len(parts)-1] - // The timestamp format is YYYYMMDD-HHMMSSZ - // We need to find where the date ends and time begins - if len(timestampPart) < 8 { - return time.Time{}, fmt.Errorf("invalid snapshot ID format: timestamp too short") - } - - // Find where the hostname ends by looking for pattern YYYYMMDD - hostnameEnd := strings.LastIndex(snapshotID[:lastHyphen], "-") - if hostnameEnd == -1 { - return time.Time{}, fmt.Errorf("invalid snapshot ID format: missing date separator") - } - - // Get the full timestamp including date from before the last hyphen - fullTimestamp := snapshotID[hostnameEnd+1:] + // Reconstruct the full timestamp + fullTimestamp := dateStr + "-" + timeStr // Parse the timestamp with Z suffix return time.Parse("20060102-150405Z", fullTimestamp) diff --git a/internal/database/blob_chunks.go b/internal/database/blob_chunks.go index 1fc8f1a..13e6b27 100644 --- a/internal/database/blob_chunks.go +++ b/internal/database/blob_chunks.go @@ -121,3 +121,32 @@ func (r *BlobChunkRepository) GetByChunkHashTx(ctx context.Context, tx *sql.Tx, LogSQL("GetByChunkHashTx", "Found blob", chunkHash, "blob", bc.BlobID) return &bc, nil } + +// DeleteOrphaned deletes blob_chunks entries where either the blob or chunk no longer exists +func (r *BlobChunkRepository) DeleteOrphaned(ctx context.Context) error { + // Delete blob_chunks where the blob doesn't exist + query1 := ` + DELETE FROM blob_chunks + WHERE NOT EXISTS ( + SELECT 1 FROM blobs + WHERE blobs.id = blob_chunks.blob_id + ) + ` + if _, err := r.db.ExecWithLog(ctx, query1); err != nil { + return fmt.Errorf("deleting blob_chunks with missing blobs: %w", err) + } + + // Delete blob_chunks where the chunk doesn't exist + query2 := ` + DELETE FROM blob_chunks + WHERE NOT EXISTS ( + SELECT 1 FROM chunks 
+ WHERE chunks.chunk_hash = blob_chunks.chunk_hash + ) + ` + if _, err := r.db.ExecWithLog(ctx, query2); err != nil { + return fmt.Errorf("deleting blob_chunks with missing chunks: %w", err) + } + + return nil +} diff --git a/internal/database/blob_chunks_test.go b/internal/database/blob_chunks_test.go index e371f2e..a3321db 100644 --- a/internal/database/blob_chunks_test.go +++ b/internal/database/blob_chunks_test.go @@ -30,7 +30,6 @@ func TestBlobChunkRepository(t *testing.T) { for _, chunkHash := range chunks { chunk := &Chunk{ ChunkHash: chunkHash, - SHA256: chunkHash + "-sha", Size: 1024, } err = repos.Chunks.Create(ctx, nil, chunk) @@ -159,7 +158,6 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) { for _, chunkHash := range chunkHashes { chunk := &Chunk{ ChunkHash: chunkHash, - SHA256: chunkHash + "-sha", Size: 1024, } err = repos.Chunks.Create(ctx, nil, chunk) diff --git a/internal/database/cascade_debug_test.go b/internal/database/cascade_debug_test.go index c0591c8..a01aa48 100644 --- a/internal/database/cascade_debug_test.go +++ b/internal/database/cascade_debug_test.go @@ -43,7 +43,6 @@ func TestCascadeDeleteDebug(t *testing.T) { for i := 0; i < 3; i++ { chunk := &Chunk{ ChunkHash: fmt.Sprintf("cascade-chunk-%d", i), - SHA256: fmt.Sprintf("cascade-sha-%d", i), Size: 1024, } err = repos.Chunks.Create(ctx, nil, chunk) diff --git a/internal/database/chunk_files_test.go b/internal/database/chunk_files_test.go index 5bf9299..a0fae5e 100644 --- a/internal/database/chunk_files_test.go +++ b/internal/database/chunk_files_test.go @@ -13,6 +13,7 @@ func TestChunkFileRepository(t *testing.T) { ctx := context.Background() repo := NewChunkFileRepository(db) fileRepo := NewFileRepository(db) + chunksRepo := NewChunkRepository(db) // Create test files first testTime := time.Now().Truncate(time.Second) @@ -46,6 +47,16 @@ func TestChunkFileRepository(t *testing.T) { t.Fatalf("failed to create file2: %v", err) } + // Create chunk first + chunk := &Chunk{ + ChunkHash: "chunk1", + Size: 1024, + } + err = chunksRepo.Create(ctx, nil, chunk) + if err != nil { + t.Fatalf("failed to create chunk: %v", err) + } + // Test Create cf1 := &ChunkFile{ ChunkHash: "chunk1", @@ -121,6 +132,7 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) { ctx := context.Background() repo := NewChunkFileRepository(db) fileRepo := NewFileRepository(db) + chunksRepo := NewChunkRepository(db) // Create test files testTime := time.Now().Truncate(time.Second) @@ -138,6 +150,19 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) { t.Fatalf("failed to create file3: %v", err) } + // Create chunks first + chunks := []string{"chunk1", "chunk2", "chunk3", "chunk4"} + for _, chunkHash := range chunks { + chunk := &Chunk{ + ChunkHash: chunkHash, + Size: 1024, + } + err := chunksRepo.Create(ctx, nil, chunk) + if err != nil { + t.Fatalf("failed to create chunk %s: %v", chunkHash, err) + } + } + // Simulate a scenario where multiple files share chunks // File1: chunk1, chunk2, chunk3 // File2: chunk2, chunk3, chunk4 @@ -183,11 +208,11 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) { } // Test file2 chunks - chunks, err := repo.GetByFileID(ctx, file2.ID) + file2Chunks, err := repo.GetByFileID(ctx, file2.ID) if err != nil { t.Fatalf("failed to get chunks for file2: %v", err) } - if len(chunks) != 3 { - t.Errorf("expected 3 chunks for file2, got %d", len(chunks)) + if len(file2Chunks) != 3 { + t.Errorf("expected 3 chunks for file2, got %d", len(file2Chunks)) } } diff --git 
a/internal/database/chunks.go b/internal/database/chunks.go index f70ecd8..ed9e25e 100644 --- a/internal/database/chunks.go +++ b/internal/database/chunks.go @@ -18,16 +18,16 @@ func NewChunkRepository(db *DB) *ChunkRepository { func (r *ChunkRepository) Create(ctx context.Context, tx *sql.Tx, chunk *Chunk) error { query := ` - INSERT INTO chunks (chunk_hash, sha256, size) - VALUES (?, ?, ?) + INSERT INTO chunks (chunk_hash, size) + VALUES (?, ?) ON CONFLICT(chunk_hash) DO NOTHING ` var err error if tx != nil { - _, err = tx.ExecContext(ctx, query, chunk.ChunkHash, chunk.SHA256, chunk.Size) + _, err = tx.ExecContext(ctx, query, chunk.ChunkHash, chunk.Size) } else { - _, err = r.db.ExecWithLog(ctx, query, chunk.ChunkHash, chunk.SHA256, chunk.Size) + _, err = r.db.ExecWithLog(ctx, query, chunk.ChunkHash, chunk.Size) } if err != nil { @@ -39,7 +39,7 @@ func (r *ChunkRepository) Create(ctx context.Context, tx *sql.Tx, chunk *Chunk) func (r *ChunkRepository) GetByHash(ctx context.Context, hash string) (*Chunk, error) { query := ` - SELECT chunk_hash, sha256, size + SELECT chunk_hash, size FROM chunks WHERE chunk_hash = ? ` @@ -48,7 +48,6 @@ func (r *ChunkRepository) GetByHash(ctx context.Context, hash string) (*Chunk, e err := r.db.conn.QueryRowContext(ctx, query, hash).Scan( &chunk.ChunkHash, - &chunk.SHA256, &chunk.Size, ) @@ -68,7 +67,7 @@ func (r *ChunkRepository) GetByHashes(ctx context.Context, hashes []string) ([]* } query := ` - SELECT chunk_hash, sha256, size + SELECT chunk_hash, size FROM chunks WHERE chunk_hash IN (` @@ -94,7 +93,6 @@ func (r *ChunkRepository) GetByHashes(ctx context.Context, hashes []string) ([]* err := rows.Scan( &chunk.ChunkHash, - &chunk.SHA256, &chunk.Size, ) if err != nil { @@ -109,7 +107,7 @@ func (r *ChunkRepository) GetByHashes(ctx context.Context, hashes []string) ([]* func (r *ChunkRepository) ListUnpacked(ctx context.Context, limit int) ([]*Chunk, error) { query := ` - SELECT c.chunk_hash, c.sha256, c.size + SELECT c.chunk_hash, c.size FROM chunks c LEFT JOIN blob_chunks bc ON c.chunk_hash = bc.chunk_hash WHERE bc.chunk_hash IS NULL @@ -129,7 +127,6 @@ func (r *ChunkRepository) ListUnpacked(ctx context.Context, limit int) ([]*Chunk err := rows.Scan( &chunk.ChunkHash, - &chunk.SHA256, &chunk.Size, ) if err != nil { diff --git a/internal/database/chunks_ext.go b/internal/database/chunks_ext.go index 57fcaf4..b38c170 100644 --- a/internal/database/chunks_ext.go +++ b/internal/database/chunks_ext.go @@ -7,7 +7,7 @@ import ( func (r *ChunkRepository) List(ctx context.Context) ([]*Chunk, error) { query := ` - SELECT chunk_hash, sha256, size + SELECT chunk_hash, size FROM chunks ORDER BY chunk_hash ` @@ -24,7 +24,6 @@ func (r *ChunkRepository) List(ctx context.Context) ([]*Chunk, error) { err := rows.Scan( &chunk.ChunkHash, - &chunk.SHA256, &chunk.Size, ) if err != nil { diff --git a/internal/database/chunks_test.go b/internal/database/chunks_test.go index 2991d8f..230965c 100644 --- a/internal/database/chunks_test.go +++ b/internal/database/chunks_test.go @@ -15,7 +15,6 @@ func TestChunkRepository(t *testing.T) { // Test Create chunk := &Chunk{ ChunkHash: "chunkhash123", - SHA256: "sha256hash123", Size: 4096, } @@ -35,9 +34,6 @@ func TestChunkRepository(t *testing.T) { if retrieved.ChunkHash != chunk.ChunkHash { t.Errorf("chunk hash mismatch: got %s, want %s", retrieved.ChunkHash, chunk.ChunkHash) } - if retrieved.SHA256 != chunk.SHA256 { - t.Errorf("sha256 mismatch: got %s, want %s", retrieved.SHA256, chunk.SHA256) - } if retrieved.Size != chunk.Size { 
t.Errorf("size mismatch: got %d, want %d", retrieved.Size, chunk.Size) } @@ -51,7 +47,6 @@ func TestChunkRepository(t *testing.T) { // Test GetByHashes chunk2 := &Chunk{ ChunkHash: "chunkhash456", - SHA256: "sha256hash456", Size: 8192, } err = repo.Create(ctx, nil, chunk2) diff --git a/internal/database/database_test.go b/internal/database/database_test.go index 8c0baa9..65457d1 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test.go @@ -75,8 +75,8 @@ func TestDatabaseConcurrentAccess(t *testing.T) { for i := 0; i < 10; i++ { go func(i int) { - _, err := db.ExecWithLog(ctx, "INSERT INTO chunks (chunk_hash, sha256, size) VALUES (?, ?, ?)", - fmt.Sprintf("hash%d", i), fmt.Sprintf("sha%d", i), i*1024) + _, err := db.ExecWithLog(ctx, "INSERT INTO chunks (chunk_hash, size) VALUES (?, ?)", + fmt.Sprintf("hash%d", i), i*1024) results <- result{index: i, err: err} }(i) } diff --git a/internal/database/file_chunks_test.go b/internal/database/file_chunks_test.go index 6efc1fa..514471f 100644 --- a/internal/database/file_chunks_test.go +++ b/internal/database/file_chunks_test.go @@ -32,6 +32,20 @@ func TestFileChunkRepository(t *testing.T) { t.Fatalf("failed to create file: %v", err) } + // Create chunks first + chunks := []string{"chunk1", "chunk2", "chunk3"} + chunkRepo := NewChunkRepository(db) + for _, chunkHash := range chunks { + chunk := &Chunk{ + ChunkHash: chunkHash, + Size: 1024, + } + err = chunkRepo.Create(ctx, nil, chunk) + if err != nil { + t.Fatalf("failed to create chunk %s: %v", chunkHash, err) + } + } + // Test Create fc1 := &FileChunk{ FileID: file.ID, @@ -66,16 +80,16 @@ func TestFileChunkRepository(t *testing.T) { } // Test GetByFile - chunks, err := repo.GetByFile(ctx, "/test/file.txt") + fileChunks, err := repo.GetByFile(ctx, "/test/file.txt") if err != nil { t.Fatalf("failed to get file chunks: %v", err) } - if len(chunks) != 3 { - t.Errorf("expected 3 chunks, got %d", len(chunks)) + if len(fileChunks) != 3 { + t.Errorf("expected 3 chunks, got %d", len(fileChunks)) } // Verify order - for i, chunk := range chunks { + for i, chunk := range fileChunks { if chunk.Idx != i { t.Errorf("wrong chunk order: expected idx %d, got %d", i, chunk.Idx) } @@ -93,12 +107,12 @@ func TestFileChunkRepository(t *testing.T) { t.Fatalf("failed to delete file chunks: %v", err) } - chunks, err = repo.GetByFileID(ctx, file.ID) + fileChunks, err = repo.GetByFileID(ctx, file.ID) if err != nil { t.Fatalf("failed to get deleted file chunks: %v", err) } - if len(chunks) != 0 { - t.Errorf("expected 0 chunks after delete, got %d", len(chunks)) + if len(fileChunks) != 0 { + t.Errorf("expected 0 chunks after delete, got %d", len(fileChunks)) } } @@ -133,6 +147,22 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) { files[i] = file } + // Create all chunks first + chunkRepo := NewChunkRepository(db) + for i := range files { + for j := 0; j < 2; j++ { + chunkHash := fmt.Sprintf("file%d_chunk%d", i, j) + chunk := &Chunk{ + ChunkHash: chunkHash, + Size: 1024, + } + err := chunkRepo.Create(ctx, nil, chunk) + if err != nil { + t.Fatalf("failed to create chunk %s: %v", chunkHash, err) + } + } + } + // Create chunks for multiple files for i, file := range files { for j := 0; j < 2; j++ { diff --git a/internal/database/files.go b/internal/database/files.go index 95a5de9..c0b1584 100644 --- a/internal/database/files.go +++ b/internal/database/files.go @@ -28,7 +28,6 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err INSERT INTO files (id, path, 
mtime, ctime, size, mode, uid, gid, link_target) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(path) DO UPDATE SET - id = excluded.id, mtime = excluded.mtime, ctime = excluded.ctime, size = excluded.size, diff --git a/internal/database/models.go b/internal/database/models.go index c7eaac6..d7e0c5e 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -37,11 +37,9 @@ type FileChunk struct { // Chunk represents a data chunk in the deduplication system. // Files are split into chunks which are content-addressed by their hash. -// The ChunkHash is used for deduplication, while SHA256 provides -// an additional verification hash. +// The ChunkHash is the SHA256 hash of the chunk content, used for deduplication. type Chunk struct { ChunkHash string - SHA256 string Size int64 } diff --git a/internal/database/repositories_test.go b/internal/database/repositories_test.go index bf677b8..bbb76b8 100644 --- a/internal/database/repositories_test.go +++ b/internal/database/repositories_test.go @@ -34,7 +34,6 @@ func TestRepositoriesTransaction(t *testing.T) { // Create chunks chunk1 := &Chunk{ ChunkHash: "tx_chunk1", - SHA256: "tx_sha1", Size: 512, } if err := repos.Chunks.Create(ctx, tx, chunk1); err != nil { @@ -43,7 +42,6 @@ func TestRepositoriesTransaction(t *testing.T) { chunk2 := &Chunk{ ChunkHash: "tx_chunk2", - SHA256: "tx_sha2", Size: 512, } if err := repos.Chunks.Create(ctx, tx, chunk2); err != nil { @@ -159,7 +157,6 @@ func TestRepositoriesTransactionRollback(t *testing.T) { // Create a chunk chunk := &Chunk{ ChunkHash: "rollback_chunk", - SHA256: "rollback_sha", Size: 1024, } if err := repos.Chunks.Create(ctx, tx, chunk); err != nil { diff --git a/internal/database/repository_comprehensive_test.go b/internal/database/repository_comprehensive_test.go index 61e1169..ab0328a 100644 --- a/internal/database/repository_comprehensive_test.go +++ b/internal/database/repository_comprehensive_test.go @@ -195,12 +195,10 @@ func TestOrphanedChunkCleanup(t *testing.T) { // Create chunks chunk1 := &Chunk{ ChunkHash: "orphaned-chunk", - SHA256: "orphaned-chunk-sha", Size: 1024, } chunk2 := &Chunk{ ChunkHash: "referenced-chunk", - SHA256: "referenced-chunk-sha", Size: 1024, } @@ -363,7 +361,6 @@ func TestFileChunkRepositoryWithUUIDs(t *testing.T) { for i, chunkHash := range chunks { chunk := &Chunk{ ChunkHash: chunkHash, - SHA256: fmt.Sprintf("sha-%s", chunkHash), Size: 1024, } err = repos.Chunks.Create(ctx, nil, chunk) @@ -447,7 +444,6 @@ func TestChunkFileRepositoryWithUUIDs(t *testing.T) { // Create a chunk that appears in both files (deduplication) chunk := &Chunk{ ChunkHash: "shared-chunk", - SHA256: "shared-chunk-sha", Size: 1024, } err = repos.Chunks.Create(ctx, nil, chunk) @@ -694,7 +690,6 @@ func TestCascadeDelete(t *testing.T) { for i := 0; i < 3; i++ { chunk := &Chunk{ ChunkHash: fmt.Sprintf("cascade-chunk-%d", i), - SHA256: fmt.Sprintf("cascade-sha-%d", i), Size: 1024, } err = repos.Chunks.Create(ctx, nil, chunk) diff --git a/internal/database/repository_edge_cases_test.go b/internal/database/repository_edge_cases_test.go index 8fcbd23..d2999be 100644 --- a/internal/database/repository_edge_cases_test.go +++ b/internal/database/repository_edge_cases_test.go @@ -170,7 +170,6 @@ func TestDuplicateHandling(t *testing.T) { t.Run("duplicate chunk hashes", func(t *testing.T) { chunk := &Chunk{ ChunkHash: "duplicate-chunk", - SHA256: "duplicate-sha", Size: 1024, } @@ -204,7 +203,6 @@ func TestDuplicateHandling(t *testing.T) { chunk := &Chunk{ ChunkHash: "test-chunk-dup", - 
SHA256: "test-sha-dup", Size: 1024, } err = repos.Chunks.Create(ctx, nil, chunk) diff --git a/internal/database/schema.sql b/internal/database/schema.sql index 2b0a6de..f28f791 100644 --- a/internal/database/schema.sql +++ b/internal/database/schema.sql @@ -24,13 +24,13 @@ CREATE TABLE IF NOT EXISTS file_chunks ( idx INTEGER NOT NULL, chunk_hash TEXT NOT NULL, PRIMARY KEY (file_id, idx), - FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE + FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE, + FOREIGN KEY (chunk_hash) REFERENCES chunks(chunk_hash) ); -- Chunks table: stores unique content-defined chunks CREATE TABLE IF NOT EXISTS chunks ( chunk_hash TEXT PRIMARY KEY, - sha256 TEXT NOT NULL, size INTEGER NOT NULL ); @@ -52,7 +52,8 @@ CREATE TABLE IF NOT EXISTS blob_chunks ( offset INTEGER NOT NULL, length INTEGER NOT NULL, PRIMARY KEY (blob_id, chunk_hash), - FOREIGN KEY (blob_id) REFERENCES blobs(id) + FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE, + FOREIGN KEY (chunk_hash) REFERENCES chunks(chunk_hash) ); -- Chunk files table: reverse mapping of chunks to files @@ -62,6 +63,7 @@ CREATE TABLE IF NOT EXISTS chunk_files ( file_offset INTEGER NOT NULL, length INTEGER NOT NULL, PRIMARY KEY (chunk_hash, file_id), + FOREIGN KEY (chunk_hash) REFERENCES chunks(chunk_hash), FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE ); @@ -91,7 +93,7 @@ CREATE TABLE IF NOT EXISTS snapshot_files ( file_id TEXT NOT NULL, PRIMARY KEY (snapshot_id, file_id), FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE, - FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE + FOREIGN KEY (file_id) REFERENCES files(id) ); -- Snapshot blobs table: maps snapshots to blobs @@ -101,13 +103,16 @@ CREATE TABLE IF NOT EXISTS snapshot_blobs ( blob_hash TEXT NOT NULL, PRIMARY KEY (snapshot_id, blob_id), FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE, - FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE + FOREIGN KEY (blob_id) REFERENCES blobs(id) ); -- Uploads table: tracks blob upload metrics CREATE TABLE IF NOT EXISTS uploads ( blob_hash TEXT PRIMARY KEY, + snapshot_id TEXT NOT NULL, uploaded_at INTEGER NOT NULL, size INTEGER NOT NULL, - duration_ms INTEGER NOT NULL + duration_ms INTEGER NOT NULL, + FOREIGN KEY (blob_hash) REFERENCES blobs(blob_hash), + FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ); \ No newline at end of file diff --git a/internal/database/uploads.go b/internal/database/uploads.go index e0dcb58..b3db507 100644 --- a/internal/database/uploads.go +++ b/internal/database/uploads.go @@ -11,6 +11,7 @@ import ( // Upload represents a blob upload record type Upload struct { BlobHash string + SnapshotID string UploadedAt time.Time Size int64 DurationMs int64 @@ -29,15 +30,15 @@ func NewUploadRepository(conn *sql.DB) *UploadRepository { // Create inserts a new upload record func (r *UploadRepository) Create(ctx context.Context, tx *sql.Tx, upload *Upload) error { query := ` - INSERT INTO uploads (blob_hash, uploaded_at, size, duration_ms) - VALUES (?, ?, ?, ?) + INSERT INTO uploads (blob_hash, snapshot_id, uploaded_at, size, duration_ms) + VALUES (?, ?, ?, ?, ?) 
` var err error if tx != nil { - _, err = tx.ExecContext(ctx, query, upload.BlobHash, upload.UploadedAt, upload.Size, upload.DurationMs) + _, err = tx.ExecContext(ctx, query, upload.BlobHash, upload.SnapshotID, upload.UploadedAt, upload.Size, upload.DurationMs) } else { - _, err = r.conn.ExecContext(ctx, query, upload.BlobHash, upload.UploadedAt, upload.Size, upload.DurationMs) + _, err = r.conn.ExecContext(ctx, query, upload.BlobHash, upload.SnapshotID, upload.UploadedAt, upload.Size, upload.DurationMs) } return err @@ -133,3 +134,14 @@ type UploadStats struct { MinDurationMs int64 MaxDurationMs int64 } + +// GetCountBySnapshot returns the count of uploads for a specific snapshot +func (r *UploadRepository) GetCountBySnapshot(ctx context.Context, snapshotID string) (int64, error) { + query := `SELECT COUNT(*) FROM uploads WHERE snapshot_id = ?` + var count int64 + err := r.conn.QueryRowContext(ctx, query, snapshotID).Scan(&count) + if err != nil { + return 0, err + } + return count, nil +}
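
For reference, a minimal round-trip sketch of the new blobgen package (illustrative only; it assumes the NewWriter/NewReader signatures introduced in this patch, and the throwaway age identity and sample data are placeholders):

package main

import (
	"bytes"
	"encoding/hex"
	"fmt"
	"io"
	"log"

	"filippo.io/age"

	"git.eeqj.de/sneak/vaultik/internal/blobgen"
)

func main() {
	// Throwaway identity; real callers pass the recipients from config.
	id, err := age.GenerateX25519Identity()
	if err != nil {
		log.Fatal(err)
	}

	plaintext := []byte("hello vaultik")

	// Compress and encrypt into buf.
	var buf bytes.Buffer
	w, err := blobgen.NewWriter(&buf, 3, []string{id.Recipient().String()})
	if err != nil {
		log.Fatal(err)
	}
	if _, err := w.Write(plaintext); err != nil {
		log.Fatal(err)
	}
	if err := w.Close(); err != nil {
		log.Fatal(err)
	}
	// Sum256 (after Close) is the hash the packer records as the blob hash.
	fmt.Println("blob hash:", hex.EncodeToString(w.Sum256()))

	// Decrypt, decompress, and check the round trip.
	r, err := blobgen.NewReader(&buf, id)
	if err != nil {
		log.Fatal(err)
	}
	out, err := io.ReadAll(r)
	if err != nil {
		log.Fatal(err)
	}
	_ = r.Close()
	fmt.Println("round trip ok:", bytes.Equal(out, plaintext))
}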
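
Similarly, a small sketch of the new per-snapshot upload tracking (the function name and wiring are hypothetical; only the Upload fields, Create, and GetCountBySnapshot shown in this patch are assumed):

package example

import (
	"context"
	"time"

	"git.eeqj.de/sneak/vaultik/internal/database"
)

// recordAndCountUploads records one blob upload against the snapshot that
// produced it and returns how many blobs that snapshot has uploaded so far.
func recordAndCountUploads(ctx context.Context, repo *database.UploadRepository,
	snapshotID, blobHash string, size int64, took time.Duration) (int64, error) {
	up := &database.Upload{
		BlobHash:   blobHash,
		SnapshotID: snapshotID, // ties the upload to the snapshot that created the blob
		UploadedAt: time.Now(),
		Size:       size,
		DurationMs: took.Milliseconds(),
	}
	if err := repo.Create(ctx, nil, up); err != nil {
		return 0, err
	}
	// Blobs uploaded during a snapshot are the blobs newly created by it,
	// which is the per-snapshot blob count the summary reports.
	return repo.GetCountBySnapshot(ctx, snapshotID)
}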