Fix foreign key constraints and improve snapshot tracking

- Add unified compression/encryption package in internal/blobgen
- Update DATAMODEL.md to reflect current schema implementation
- Refactor snapshot cleanup into well-named methods for clarity
- Add snapshot_id to uploads table to track new blobs per snapshot
- Fix blob count reporting for incremental backups
- Add DeleteOrphaned method to BlobChunkRepository
- Fix cleanup order to respect foreign key constraints
- Update tests to reflect schema changes
Author: Jeffrey Paul
Date:   2025-07-26 02:22:25 +02:00
Parent: 78af626759
Commit: d3afa65420

28 changed files with 994 additions and 534 deletions

View File

@@ -15,14 +15,17 @@ Stores metadata about files in the filesystem being backed up.
 **Columns:**
 - `id` (TEXT PRIMARY KEY) - UUID for the file record
-- `path` (TEXT UNIQUE) - Absolute file path
-- `mtime` (INTEGER) - Modification time as Unix timestamp
-- `ctime` (INTEGER) - Change time as Unix timestamp
-- `size` (INTEGER) - File size in bytes
-- `mode` (INTEGER) - Unix file permissions and type
-- `uid` (INTEGER) - User ID of file owner
-- `gid` (INTEGER) - Group ID of file owner
-- `link_target` (TEXT) - Symlink target path (empty for regular files)
+- `path` (TEXT NOT NULL UNIQUE) - Absolute file path
+- `mtime` (INTEGER NOT NULL) - Modification time as Unix timestamp
+- `ctime` (INTEGER NOT NULL) - Change time as Unix timestamp
+- `size` (INTEGER NOT NULL) - File size in bytes
+- `mode` (INTEGER NOT NULL) - Unix file permissions and type
+- `uid` (INTEGER NOT NULL) - User ID of file owner
+- `gid` (INTEGER NOT NULL) - Group ID of file owner
+- `link_target` (TEXT) - Symlink target path (NULL for regular files)
+
+**Indexes:**
+- `idx_files_path` on `path` for efficient lookups
 
 **Purpose:** Tracks file metadata to detect changes between backup runs. Used for incremental backup decisions. The UUID primary key provides stable references that don't change if files are moved.
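For orientation, the schema implied by the column list above can be sketched as the following SQLite DDL; this is inferred from the description in this document, and the actual migration may differ in details:

```sql
-- Sketch only: inferred from the documented columns, not copied from the migration.
CREATE TABLE IF NOT EXISTS files (
    id          TEXT PRIMARY KEY,       -- UUID for the file record
    path        TEXT NOT NULL UNIQUE,   -- absolute file path
    mtime       INTEGER NOT NULL,       -- modification time (Unix timestamp)
    ctime       INTEGER NOT NULL,       -- change time (Unix timestamp)
    size        INTEGER NOT NULL,       -- file size in bytes
    mode        INTEGER NOT NULL,       -- Unix permissions and file type
    uid         INTEGER NOT NULL,       -- owner user ID
    gid         INTEGER NOT NULL,       -- owner group ID
    link_target TEXT                    -- symlink target, NULL for regular files
);

CREATE INDEX IF NOT EXISTS idx_files_path ON files (path);
```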
@@ -31,8 +34,7 @@ Stores information about content-defined chunks created from files.
 **Columns:**
 - `chunk_hash` (TEXT PRIMARY KEY) - SHA256 hash of chunk content
-- `sha256` (TEXT) - SHA256 hash (currently same as chunk_hash)
-- `size` (INTEGER) - Chunk size in bytes
+- `size` (INTEGER NOT NULL) - Chunk size in bytes
 
 **Purpose:** Enables deduplication by tracking unique chunks across all files.
@@ -64,11 +66,11 @@ Stores information about packed, compressed, and encrypted blob files.
 **Columns:**
 - `id` (TEXT PRIMARY KEY) - UUID assigned when blob creation starts
-- `hash` (TEXT) - SHA256 hash of final blob (empty until finalized)
-- `created_ts` (INTEGER) - Creation timestamp
+- `blob_hash` (TEXT UNIQUE) - SHA256 hash of final blob (NULL until finalized)
+- `created_ts` (INTEGER NOT NULL) - Creation timestamp
 - `finished_ts` (INTEGER) - Finalization timestamp (NULL if in progress)
-- `uncompressed_size` (INTEGER) - Total size of chunks before compression
-- `compressed_size` (INTEGER) - Size after compression and encryption
+- `uncompressed_size` (INTEGER NOT NULL DEFAULT 0) - Total size of chunks before compression
+- `compressed_size` (INTEGER NOT NULL DEFAULT 0) - Size after compression and encryption
 - `uploaded_ts` (INTEGER) - Upload completion timestamp (NULL if not uploaded)
 
 **Purpose:** Tracks blob lifecycle from creation through upload. The UUID primary key allows immediate association of chunks with blobs.
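The `blobs` table implied above can be sketched similarly (again inferred rather than copied from the migration); the key point is that `blob_hash` stays NULL until the blob is finalized:

```sql
-- Sketch only: inferred from the documented columns.
CREATE TABLE IF NOT EXISTS blobs (
    id                TEXT PRIMARY KEY,            -- UUID assigned at creation
    blob_hash         TEXT UNIQUE,                 -- SHA256 of the final blob, NULL until finalized
    created_ts        INTEGER NOT NULL,            -- creation timestamp
    finished_ts       INTEGER,                     -- NULL while packing is in progress
    uncompressed_size INTEGER NOT NULL DEFAULT 0,  -- total chunk bytes before compression
    compressed_size   INTEGER NOT NULL DEFAULT 0,  -- size after compression and encryption
    uploaded_ts       INTEGER                      -- NULL until uploaded
);
```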
@@ -134,11 +136,12 @@ Tracks blob upload metrics.
 **Columns:**
 - `blob_hash` (TEXT PRIMARY KEY) - Hash of uploaded blob
+- `snapshot_id` (TEXT NOT NULL) - The snapshot that triggered this upload (FK to snapshots.id)
 - `uploaded_at` (INTEGER) - Upload timestamp
 - `size` (INTEGER) - Size of uploaded blob
 - `duration_ms` (INTEGER) - Upload duration in milliseconds
 
-**Purpose:** Performance monitoring and upload tracking.
+**Purpose:** Performance monitoring and tracking which blobs were newly created (uploaded) during each snapshot.
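Since this commit adds `snapshot_id` to `uploads`, here is a sketch of the resulting table and of the count query that the new blob-count reporting implies; the shapes are inferred from this document and the real migration and repository method may differ:

```sql
-- Sketch only: illustrates the new snapshot_id foreign key on uploads.
CREATE TABLE IF NOT EXISTS uploads (
    blob_hash   TEXT PRIMARY KEY,                        -- hash of the uploaded blob
    snapshot_id TEXT NOT NULL REFERENCES snapshots(id),  -- snapshot that triggered the upload
    uploaded_at INTEGER,                                 -- upload timestamp
    size        INTEGER,                                 -- uploaded blob size in bytes
    duration_ms INTEGER                                  -- upload duration in milliseconds
);

-- Counting rows per snapshot yields only the blobs newly created (uploaded) by that
-- snapshot, which is how the blob count for incremental backups is now reported.
SELECT COUNT(*) FROM uploads WHERE snapshot_id = ?;
```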
 ## Data Flow and Operations
 
@@ -155,13 +158,13 @@ Tracks blob upload metrics.
    - `INSERT INTO chunk_files` - Create reverse mapping
 
 3. **Blob Packing**
-   - `INSERT INTO blobs` - Create blob record with UUID (hash empty)
+   - `INSERT INTO blobs` - Create blob record with UUID (blob_hash NULL)
    - `INSERT INTO blob_chunks` - Associate chunks with blob immediately
-   - `UPDATE blobs SET hash = ?, finished_ts = ?` - Finalize blob after packing
+   - `UPDATE blobs SET blob_hash = ?, finished_ts = ?` - Finalize blob after packing
 
 4. **Upload**
    - `UPDATE blobs SET uploaded_ts = ?` - Mark blob as uploaded
-   - `INSERT INTO uploads` - Record upload metrics
+   - `INSERT INTO uploads` - Record upload metrics with snapshot_id
    - `INSERT INTO snapshot_blobs` - Associate blob with snapshot
 
 5. **Snapshot Completion**
@@ -179,37 +182,56 @@ Tracks blob upload metrics.
    - `SELECT * FROM blob_chunks WHERE chunk_hash = ?` - Find existing chunks
    - `INSERT INTO snapshot_blobs` - Reference existing blobs for unchanged files
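A sketch of how an unchanged file's data can be re-referenced without creating or uploading new blobs; this is illustrative only, with column names inferred from queries elsewhere in this diff and `file_chunks.file_id` assumed:

```sql
-- Illustrative only: for an unchanged file, reference the blobs that already
-- hold its chunks instead of packing or uploading anything new.
INSERT OR IGNORE INTO snapshot_blobs (snapshot_id, blob_hash)
SELECT ?, b.blob_hash
FROM file_chunks fc
JOIN blob_chunks bc ON bc.chunk_hash = fc.chunk_hash
JOIN blobs b        ON b.id = bc.blob_id
WHERE fc.file_id = ?;
```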
-### 3. Restore Process
+### 3. Snapshot Metadata Export
+
+After a snapshot is completed:
+
+1. Copy database to temporary file
+2. Clean temporary database to contain only current snapshot data
+3. Export to SQL dump using sqlite3
+4. Compress with zstd and encrypt with age
+5. Upload to S3 as `metadata/{snapshot-id}/db.zst.age`
+6. Generate blob manifest and upload as `metadata/{snapshot-id}/manifest.json.zst.age`
+
+### 4. Restore Process
 
 The restore process doesn't use the local database. Instead:
 1. Downloads snapshot metadata from S3
 2. Downloads required blobs based on manifest
 3. Reconstructs files from decrypted and decompressed chunks
 
-### 4. Pruning
+### 5. Pruning
 
 1. **Identify Unreferenced Blobs**
    - Query blobs not referenced by any remaining snapshot
    - Delete from S3 and local database
+
+### 6. Incomplete Snapshot Cleanup
+
+Before each backup:
+
+1. Query incomplete snapshots (where `completed_at IS NULL`)
+2. Check if metadata exists in S3
+3. If no metadata, delete snapshot and all associations
+4. Clean up orphaned files, chunks, and blobs
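A sketch of the query behind step 1 (assumed shape; the `GetIncompleteByHostname()` repository method mentioned below presumably adds the hostname filter):

```sql
-- Illustrative: find snapshots that were started but never completed.
SELECT id
FROM snapshots
WHERE completed_at IS NULL
  AND hostname = ?;
```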
 ## Repository Pattern
 
 Vaultik uses a repository pattern for database access:
-- `FileRepository` - CRUD operations for files
-- `ChunkRepository` - CRUD operations for chunks
-- `FileChunkRepository` - Manage file-chunk mappings
-- `BlobRepository` - Manage blob lifecycle
-- `BlobChunkRepository` - Manage blob-chunk associations
-- `SnapshotRepository` - Manage snapshots
-- `UploadRepository` - Track upload metrics
+- `FileRepository` - CRUD operations for files and file metadata
+- `ChunkRepository` - CRUD operations for content chunks
+- `FileChunkRepository` - Manage file-to-chunk mappings
+- `ChunkFileRepository` - Manage chunk-to-file reverse mappings
+- `BlobRepository` - Manage blob lifecycle (creation, finalization, upload)
+- `BlobChunkRepository` - Manage blob-to-chunk associations
+- `SnapshotRepository` - Manage snapshots and their relationships
+- `UploadRepository` - Track blob upload metrics
 
 Each repository provides methods like:
 - `Create()` - Insert new record
 - `GetByID()` / `GetByPath()` / `GetByHash()` - Retrieve records
 - `Update()` - Update existing records
 - `Delete()` - Remove records
-- Specialized queries for each entity type
+- Specialized queries for each entity type (e.g., `DeleteOrphaned()`, `GetIncompleteByHostname()`)
 
 ## Transaction Management
@@ -228,9 +250,9 @@ This ensures consistency, especially important for operations like:
 
 ## Performance Considerations
 
-1. **Indexes**: Primary keys are automatically indexed. Additional indexes may be needed for:
-   - `blobs.hash` for lookup performance
-   - `blob_chunks.chunk_hash` for chunk location queries
+1. **Indexes**:
+   - Primary keys are automatically indexed
+   - `idx_files_path` on `files(path)` for efficient file lookups
 
 2. **Prepared Statements**: All queries use prepared statements for performance and security
 
@@ -240,7 +262,7 @@ This ensures consistency, especially important for operations like:
 
 ## Data Integrity
 
-1. **Foreign Keys**: Enforced at the application level through repository methods
-2. **Unique Constraints**: Chunk hashes and file paths are unique
+1. **Foreign Keys**: Enforced through CASCADE DELETE and application-level repository methods
+2. **Unique Constraints**: Chunk hashes, file paths, and blob hashes are unique
 3. **Null Handling**: Nullable fields clearly indicate in-progress operations
 4. **Timestamp Tracking**: All major operations record timestamps for auditing
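SQLite only applies the CASCADE DELETE rules mentioned above when foreign key enforcement is enabled on each connection. A sketch of what that implies for one of the relationships described in this document (the exact `file_chunks` definition is assumed, not taken from the migration):

```sql
-- Foreign keys are off by default in SQLite and must be enabled per connection.
PRAGMA foreign_keys = ON;

-- Sketch of the cascade described above: deleting a row from files
-- automatically removes its file_chunks rows.
CREATE TABLE IF NOT EXISTS file_chunks (
    file_id    TEXT NOT NULL REFERENCES files(id) ON DELETE CASCADE,
    idx        INTEGER NOT NULL,                          -- assumed: chunk order within the file
    chunk_hash TEXT NOT NULL REFERENCES chunks(chunk_hash),
    PRIMARY KEY (file_id, idx)
);
```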

View File

@@ -393,7 +393,6 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
 	err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
 		chunk := &database.Chunk{
 			ChunkHash: chunkHash,
-			SHA256:    chunkHash,
 			Size:      int64(n),
 		}
 		return b.repos.Chunks.Create(ctx, tx, chunk)

View File

@@ -19,6 +19,7 @@ type ScannerParams struct {
 var Module = fx.Module("backup",
 	fx.Provide(
 		provideScannerFactory,
+		NewSnapshotManager,
 	),
 )

View File

@@ -12,7 +12,6 @@ import (
 	"git.eeqj.de/sneak/vaultik/internal/blob"
 	"git.eeqj.de/sneak/vaultik/internal/chunker"
-	"git.eeqj.de/sneak/vaultik/internal/crypto"
 	"git.eeqj.de/sneak/vaultik/internal/database"
 	"git.eeqj.de/sneak/vaultik/internal/log"
 	"git.eeqj.de/sneak/vaultik/internal/s3"
@@ -86,17 +85,11 @@ func NewScanner(cfg ScannerConfig) *Scanner {
 		return nil
 	}
 
-	enc, err := crypto.NewEncryptor(cfg.AgeRecipients)
-	if err != nil {
-		log.Error("Failed to create encryptor", "error", err)
-		return nil
-	}
-
 	// Create blob packer with encryption
 	packerCfg := blob.PackerConfig{
 		MaxBlobSize:      cfg.MaxBlobSize,
 		CompressionLevel: cfg.CompressionLevel,
-		Encryptor:        enc,
+		Recipients:       cfg.AgeRecipients,
 		Repositories:     cfg.Repositories,
 	}
 	packer, err := blob.NewPacker(packerCfg)
@@ -182,6 +175,18 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
 	blobs := s.packer.GetFinishedBlobs()
 	result.BlobsCreated += len(blobs)
 
+	// Query database for actual blob count created during this snapshot
+	// The database is authoritative, especially for concurrent blob uploads
+	// We count uploads rather than all snapshot_blobs to get only NEW blobs
+	if s.snapshotID != "" {
+		uploadCount, err := s.repos.Uploads.GetCountBySnapshot(ctx, s.snapshotID)
+		if err != nil {
+			log.Warn("Failed to get upload count from database", "error", err)
+		} else {
+			result.BlobsCreated = int(uploadCount)
+		}
+	}
+
 	result.EndTime = time.Now().UTC()
 	return result, nil
 }
@@ -341,24 +346,22 @@ func (s *Scanner) checkFile(ctx context.Context, path string, info os.FileInfo,
 	fileChanged := existingFile == nil || s.hasFileChanged(existingFile, file)
 
-	// Update file metadata in a short transaction
-	log.Debug("Updating file metadata", "path", path, "changed", fileChanged)
+	// Update file metadata and add to snapshot in a single transaction
+	log.Debug("Updating file metadata and adding to snapshot", "path", path, "changed", fileChanged, "snapshot", s.snapshotID)
 	err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
-		return s.repos.Files.Create(ctx, tx, file)
+		// First create/update the file
+		if err := s.repos.Files.Create(ctx, tx, file); err != nil {
+			return fmt.Errorf("creating file: %w", err)
+		}
+		// Then add it to the snapshot using the file ID
+		if err := s.repos.Snapshots.AddFileByID(ctx, tx, s.snapshotID, file.ID); err != nil {
+			return fmt.Errorf("adding file to snapshot: %w", err)
+		}
+		return nil
 	})
 	if err != nil {
 		return nil, false, err
 	}
-	log.Debug("File metadata updated", "path", path)
-
-	// Add file to snapshot in a short transaction
-	log.Debug("Adding file to snapshot", "path", path, "snapshot", s.snapshotID)
-	err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
-		return s.repos.Snapshots.AddFile(ctx, tx, s.snapshotID, path)
-	})
-	if err != nil {
-		return nil, false, fmt.Errorf("adding file to snapshot: %w", err)
-	}
 	log.Debug("File added to snapshot", "path", path)
 
 	result.FilesScanned++
@@ -542,6 +545,14 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
 	uploadDuration := time.Since(startTime)
 
+	// Log upload stats
+	uploadSpeed := float64(finishedBlob.Compressed) * 8 / uploadDuration.Seconds() // bits per second
+	log.Info("Uploaded blob to S3",
+		"path", blobPath,
+		"size", humanize.Bytes(uint64(finishedBlob.Compressed)),
+		"duration", uploadDuration,
+		"speed", humanize.SI(uploadSpeed, "bps"))
+
 	// Report upload complete
 	if s.progress != nil {
 		s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration)
@@ -574,6 +585,7 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
 	// Record upload metrics
 	upload := &database.Upload{
 		BlobHash:   finishedBlob.Hash,
+		SnapshotID: s.snapshotID,
 		UploadedAt: startTime,
 		Size:       finishedBlob.Compressed,
 		DurationMs: uploadDuration.Milliseconds(),
@@ -645,7 +657,6 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
 	err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
 		dbChunk := &database.Chunk{
 			ChunkHash: chunk.Hash,
-			SHA256:    chunk.Hash,
 			Size:      chunk.Size,
 		}
 		if err := s.repos.Chunks.Create(txCtx, tx, dbChunk); err != nil {

View File

@@ -48,32 +48,39 @@ import (
 	"os"
 	"os/exec"
 	"path/filepath"
-	"runtime"
 	"time"
 
+	"git.eeqj.de/sneak/vaultik/internal/blobgen"
+	"git.eeqj.de/sneak/vaultik/internal/config"
 	"git.eeqj.de/sneak/vaultik/internal/database"
 	"git.eeqj.de/sneak/vaultik/internal/log"
-	"github.com/klauspost/compress/zstd"
+	"git.eeqj.de/sneak/vaultik/internal/s3"
+	"github.com/dustin/go-humanize"
+	"go.uber.org/fx"
 )
 
 // SnapshotManager handles snapshot creation and metadata export
 type SnapshotManager struct {
 	repos    *database.Repositories
 	s3Client S3Client
-	encryptor Encryptor
+	config   *config.Config
 }
 
-// Encryptor interface for snapshot encryption
-type Encryptor interface {
-	Encrypt(data []byte) ([]byte, error)
+// SnapshotManagerParams holds dependencies for NewSnapshotManager
+type SnapshotManagerParams struct {
+	fx.In
+
+	Repos    *database.Repositories
+	S3Client *s3.Client
+	Config   *config.Config
 }
 
-// NewSnapshotManager creates a new snapshot manager
-func NewSnapshotManager(repos *database.Repositories, s3Client S3Client, encryptor Encryptor) *SnapshotManager {
+// NewSnapshotManager creates a new snapshot manager for dependency injection
+func NewSnapshotManager(params SnapshotManagerParams) *SnapshotManager {
 	return &SnapshotManager{
-		repos:     repos,
-		s3Client:  s3Client,
-		encryptor: encryptor,
+		repos:    params.Repos,
+		s3Client: params.S3Client,
+		config:   params.Config,
 	}
 }
@ -208,11 +215,20 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
log.Debug("Database copy complete", "size", getFileSize(tempDBPath)) log.Debug("Database copy complete", "size", getFileSize(tempDBPath))
// Step 2: Clean the temp database to only contain current snapshot data // Step 2: Clean the temp database to only contain current snapshot data
log.Debug("Cleaning snapshot database to contain only current snapshot", "snapshot_id", snapshotID) log.Debug("Cleaning temporary snapshot database to contain only current snapshot", "snapshot_id", snapshotID, "db_path", tempDBPath)
if err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID); err != nil { stats, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID)
if err != nil {
return fmt.Errorf("cleaning snapshot database: %w", err) return fmt.Errorf("cleaning snapshot database: %w", err)
} }
log.Debug("Database cleaning complete", "size_after_clean", getFileSize(tempDBPath)) log.Info("Snapshot database cleanup complete",
"db_path", tempDBPath,
"size_after_clean", humanize.Bytes(uint64(getFileSize(tempDBPath))),
"files", stats.FileCount,
"chunks", stats.ChunkCount,
"blobs", stats.BlobCount,
"total_compressed_size", humanize.Bytes(uint64(stats.CompressedSize)),
"total_uncompressed_size", humanize.Bytes(uint64(stats.UncompressedSize)),
"compression_ratio", fmt.Sprintf("%.2fx", float64(stats.UncompressedSize)/float64(stats.CompressedSize)))
// Step 3: Dump the cleaned database to SQL // Step 3: Dump the cleaned database to SQL
dumpPath := filepath.Join(tempDir, "snapshot.sql") dumpPath := filepath.Join(tempDir, "snapshot.sql")
@ -222,62 +238,59 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
} }
log.Debug("SQL dump complete", "size", getFileSize(dumpPath)) log.Debug("SQL dump complete", "size", getFileSize(dumpPath))
// Step 4: Compress the SQL dump // Step 4: Compress and encrypt the SQL dump
compressedPath := filepath.Join(tempDir, "snapshot.sql.zst") compressedPath := filepath.Join(tempDir, "snapshot.sql.zst.age")
log.Debug("Compressing SQL dump", "source", dumpPath, "destination", compressedPath) log.Debug("Compressing and encrypting SQL dump", "source", dumpPath, "destination", compressedPath)
if err := sm.compressDump(dumpPath, compressedPath); err != nil { if err := sm.compressDump(dumpPath, compressedPath); err != nil {
return fmt.Errorf("compressing dump: %w", err) return fmt.Errorf("compressing dump: %w", err)
} }
log.Debug("Compression complete", "original_size", getFileSize(dumpPath), "compressed_size", getFileSize(compressedPath)) log.Debug("Compression complete", "original_size", getFileSize(dumpPath), "compressed_size", getFileSize(compressedPath))
// Step 5: Read compressed data for encryption/upload // Step 5: Read compressed and encrypted data for upload
log.Debug("Reading compressed data for upload", "path", compressedPath) log.Debug("Reading compressed and encrypted data for upload", "path", compressedPath)
compressedData, err := os.ReadFile(compressedPath) finalData, err := os.ReadFile(compressedPath)
if err != nil { if err != nil {
return fmt.Errorf("reading compressed dump: %w", err) return fmt.Errorf("reading compressed dump: %w", err)
} }
// Step 6: Encrypt if encryptor is available // Step 6: Generate blob manifest (before closing temp DB)
finalData := compressedData
if sm.encryptor != nil {
log.Debug("Encrypting snapshot data", "size_before", len(compressedData))
encrypted, err := sm.encryptor.Encrypt(compressedData)
if err != nil {
return fmt.Errorf("encrypting snapshot: %w", err)
}
finalData = encrypted
log.Debug("Encryption complete", "size_after", len(encrypted))
} else {
log.Debug("No encryption configured, using compressed data as-is")
}
// Step 7: Generate blob manifest (before closing temp DB)
log.Debug("Generating blob manifest from temporary database", "db_path", tempDBPath) log.Debug("Generating blob manifest from temporary database", "db_path", tempDBPath)
blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID)
if err != nil { if err != nil {
return fmt.Errorf("generating blob manifest: %w", err) return fmt.Errorf("generating blob manifest: %w", err)
} }
// Step 8: Upload to S3 in snapshot subdirectory // Step 7: Upload to S3 in snapshot subdirectory
// Upload database backup (encrypted) // Upload database backup (compressed and encrypted)
dbKey := fmt.Sprintf("metadata/%s/db.zst", snapshotID) dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
if sm.encryptor != nil {
dbKey += ".age"
}
log.Debug("Uploading snapshot database to S3", "key", dbKey, "size", len(finalData)) log.Debug("Uploading snapshot database to S3", "key", dbKey, "size", len(finalData))
dbUploadStart := time.Now()
if err := sm.s3Client.PutObject(ctx, dbKey, bytes.NewReader(finalData)); err != nil { if err := sm.s3Client.PutObject(ctx, dbKey, bytes.NewReader(finalData)); err != nil {
return fmt.Errorf("uploading snapshot database: %w", err) return fmt.Errorf("uploading snapshot database: %w", err)
} }
log.Debug("Database upload complete", "key", dbKey) dbUploadDuration := time.Since(dbUploadStart)
dbUploadSpeed := float64(len(finalData)) * 8 / dbUploadDuration.Seconds() // bits per second
log.Info("Uploaded snapshot database to S3",
"path", dbKey,
"size", humanize.Bytes(uint64(len(finalData))),
"duration", dbUploadDuration,
"speed", humanize.SI(dbUploadSpeed, "bps"))
// Upload blob manifest (unencrypted, compressed) // Upload blob manifest (compressed and encrypted)
manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst.age", snapshotID)
log.Debug("Uploading blob manifest to S3", "key", manifestKey, "size", len(blobManifest)) log.Debug("Uploading blob manifest to S3", "key", manifestKey, "size", len(blobManifest))
manifestUploadStart := time.Now()
if err := sm.s3Client.PutObject(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil { if err := sm.s3Client.PutObject(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil {
return fmt.Errorf("uploading blob manifest: %w", err) return fmt.Errorf("uploading blob manifest: %w", err)
} }
log.Debug("Manifest upload complete", "key", manifestKey) manifestUploadDuration := time.Since(manifestUploadStart)
manifestUploadSpeed := float64(len(blobManifest)) * 8 / manifestUploadDuration.Seconds() // bits per second
log.Info("Uploaded blob manifest to S3",
"path", manifestKey,
"size", humanize.Bytes(uint64(len(blobManifest))),
"duration", manifestUploadDuration,
"speed", humanize.SI(manifestUploadSpeed, "bps"))
log.Info("Uploaded snapshot metadata", log.Info("Uploaded snapshot metadata",
"snapshot_id", snapshotID, "snapshot_id", snapshotID,
@ -286,43 +299,32 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
return nil return nil
} }
// CleanupStats contains statistics about cleaned snapshot database
type CleanupStats struct {
FileCount int
ChunkCount int
BlobCount int
CompressedSize int64
UncompressedSize int64
}
// cleanSnapshotDB removes all data except for the specified snapshot // cleanSnapshotDB removes all data except for the specified snapshot
// //
// Current implementation: // The cleanup is performed in a specific order to maintain referential integrity:
// Since we don't yet have snapshot-file relationships, this currently only // 1. Delete other snapshots
// removes other snapshots. In a complete implementation, it would: // 2. Delete orphaned snapshot associations (snapshot_files, snapshot_blobs) for deleted snapshots
// 3. Delete orphaned files (not in the current snapshot)
// 4. Delete orphaned chunk-to-file mappings (references to deleted files)
// 5. Delete orphaned blobs (not in the current snapshot)
// 6. Delete orphaned blob-to-chunk mappings (references to deleted chunks)
// 7. Delete orphaned chunks (not referenced by any file)
// //
// 1. Delete all snapshots except the current one // Each step is implemented as a separate method for clarity and maintainability.
// 2. Delete files not belonging to the current snapshot func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, snapshotID string) (*CleanupStats, error) {
// 3. Delete file_chunks for deleted files (CASCADE)
// 4. Delete chunk_files for deleted files
// 5. Delete chunks with no remaining file references
// 6. Delete blob_chunks for deleted chunks
// 7. Delete blobs with no remaining chunks
//
// The order is important to maintain referential integrity.
//
// Future implementation when we have snapshot_files table:
//
// DELETE FROM snapshots WHERE id != ?;
// DELETE FROM files WHERE NOT EXISTS (
// SELECT 1 FROM snapshot_files
// WHERE snapshot_files.file_id = files.id
// AND snapshot_files.snapshot_id = ?
// );
// DELETE FROM chunks WHERE NOT EXISTS (
// SELECT 1 FROM file_chunks
// WHERE file_chunks.chunk_hash = chunks.chunk_hash
// );
// DELETE FROM blobs WHERE NOT EXISTS (
// SELECT 1 FROM blob_chunks
// WHERE blob_chunks.blob_hash = blobs.blob_hash
// );
func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, snapshotID string) error {
// Open the temp database // Open the temp database
db, err := database.New(ctx, dbPath) db, err := database.New(ctx, dbPath)
if err != nil { if err != nil {
return fmt.Errorf("opening temp database: %w", err) return nil, fmt.Errorf("opening temp database: %w", err)
} }
defer func() { defer func() {
if err := db.Close(); err != nil { if err := db.Close(); err != nil {
@ -333,7 +335,7 @@ func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, s
// Start a transaction // Start a transaction
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return fmt.Errorf("beginning transaction: %w", err) return nil, fmt.Errorf("beginning transaction: %w", err)
} }
defer func() { defer func() {
if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone { if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
@ -341,123 +343,77 @@ func (sm *SnapshotManager) cleanSnapshotDB(ctx context.Context, dbPath string, s
} }
}() }()
// Step 1: Delete all other snapshots // Execute cleanup steps in order
log.Debug("Deleting other snapshots", "keeping", snapshotID) if err := sm.deleteOtherSnapshots(ctx, tx, snapshotID); err != nil {
database.LogSQL("Execute", "DELETE FROM snapshots WHERE id != ?", snapshotID) return nil, fmt.Errorf("step 1 - delete other snapshots: %w", err)
result, err := tx.ExecContext(ctx, "DELETE FROM snapshots WHERE id != ?", snapshotID)
if err != nil {
return fmt.Errorf("deleting other snapshots: %w", err)
} }
rowsAffected, _ := result.RowsAffected()
log.Debug("Deleted snapshots", "count", rowsAffected)
// Step 2: Delete files not in this snapshot if err := sm.deleteOrphanedSnapshotAssociations(ctx, tx, snapshotID); err != nil {
log.Debug("Deleting files not in current snapshot") return nil, fmt.Errorf("step 2 - delete orphaned snapshot associations: %w", err)
database.LogSQL("Execute", `DELETE FROM files WHERE NOT EXISTS (SELECT 1 FROM snapshot_files WHERE snapshot_files.file_id = files.id AND snapshot_files.snapshot_id = ?)`, snapshotID)
result, err = tx.ExecContext(ctx, `
DELETE FROM files
WHERE NOT EXISTS (
SELECT 1 FROM snapshot_files
WHERE snapshot_files.file_id = files.id
AND snapshot_files.snapshot_id = ?
)`, snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned files: %w", err)
} }
rowsAffected, _ = result.RowsAffected()
log.Debug("Deleted files", "count", rowsAffected)
// Step 3: file_chunks will be deleted via CASCADE from files if err := sm.deleteOrphanedFiles(ctx, tx, snapshotID); err != nil {
log.Debug("file_chunks will be deleted via CASCADE") return nil, fmt.Errorf("step 3 - delete orphaned files: %w", err)
// Step 4: Delete chunk_files for deleted files
log.Debug("Deleting orphaned chunk_files")
database.LogSQL("Execute", `DELETE FROM chunk_files WHERE NOT EXISTS (SELECT 1 FROM files WHERE files.id = chunk_files.file_id)`)
result, err = tx.ExecContext(ctx, `
DELETE FROM chunk_files
WHERE NOT EXISTS (
SELECT 1 FROM files
WHERE files.id = chunk_files.file_id
)`)
if err != nil {
return fmt.Errorf("deleting orphaned chunk_files: %w", err)
} }
rowsAffected, _ = result.RowsAffected()
log.Debug("Deleted chunk_files", "count", rowsAffected)
// Step 5: Delete chunks with no remaining file references if err := sm.deleteOrphanedChunkToFileMappings(ctx, tx); err != nil {
log.Debug("Deleting orphaned chunks") return nil, fmt.Errorf("step 4 - delete orphaned chunk-to-file mappings: %w", err)
database.LogSQL("Execute", `DELETE FROM chunks WHERE NOT EXISTS (SELECT 1 FROM file_chunks WHERE file_chunks.chunk_hash = chunks.chunk_hash)`)
result, err = tx.ExecContext(ctx, `
DELETE FROM chunks
WHERE NOT EXISTS (
SELECT 1 FROM file_chunks
WHERE file_chunks.chunk_hash = chunks.chunk_hash
)`)
if err != nil {
return fmt.Errorf("deleting orphaned chunks: %w", err)
} }
rowsAffected, _ = result.RowsAffected()
log.Debug("Deleted chunks", "count", rowsAffected)
// Step 6: Delete blob_chunks for deleted chunks if err := sm.deleteOrphanedBlobs(ctx, tx, snapshotID); err != nil {
log.Debug("Deleting orphaned blob_chunks") return nil, fmt.Errorf("step 5 - delete orphaned blobs: %w", err)
database.LogSQL("Execute", `DELETE FROM blob_chunks WHERE NOT EXISTS (SELECT 1 FROM chunks WHERE chunks.chunk_hash = blob_chunks.chunk_hash)`)
result, err = tx.ExecContext(ctx, `
DELETE FROM blob_chunks
WHERE NOT EXISTS (
SELECT 1 FROM chunks
WHERE chunks.chunk_hash = blob_chunks.chunk_hash
)`)
if err != nil {
return fmt.Errorf("deleting orphaned blob_chunks: %w", err)
} }
rowsAffected, _ = result.RowsAffected()
log.Debug("Deleted blob_chunks", "count", rowsAffected)
// Step 7: Delete blobs not in this snapshot if err := sm.deleteOrphanedBlobToChunkMappings(ctx, tx); err != nil {
log.Debug("Deleting blobs not in current snapshot") return nil, fmt.Errorf("step 6 - delete orphaned blob-to-chunk mappings: %w", err)
database.LogSQL("Execute", `DELETE FROM blobs WHERE NOT EXISTS (SELECT 1 FROM snapshot_blobs WHERE snapshot_blobs.blob_hash = blobs.blob_hash AND snapshot_blobs.snapshot_id = ?)`, snapshotID)
result, err = tx.ExecContext(ctx, `
DELETE FROM blobs
WHERE NOT EXISTS (
SELECT 1 FROM snapshot_blobs
WHERE snapshot_blobs.blob_hash = blobs.blob_hash
AND snapshot_blobs.snapshot_id = ?
)`, snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned blobs: %w", err)
} }
rowsAffected, _ = result.RowsAffected()
log.Debug("Deleted blobs not in snapshot", "count", rowsAffected)
// Step 8: Delete orphaned snapshot_files and snapshot_blobs if err := sm.deleteOrphanedChunks(ctx, tx); err != nil {
log.Debug("Deleting orphaned snapshot_files") return nil, fmt.Errorf("step 7 - delete orphaned chunks: %w", err)
database.LogSQL("Execute", "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID)
result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_files WHERE snapshot_id != ?", snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned snapshot_files: %w", err)
} }
rowsAffected, _ = result.RowsAffected()
log.Debug("Deleted snapshot_files", "count", rowsAffected)
log.Debug("Deleting orphaned snapshot_blobs")
database.LogSQL("Execute", "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID)
result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", snapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned snapshot_blobs: %w", err)
}
rowsAffected, _ = result.RowsAffected()
log.Debug("Deleted snapshot_blobs", "count", rowsAffected)
// Commit transaction // Commit transaction
log.Debug("Committing cleanup transaction") log.Debug("[Temp DB Cleanup] Committing cleanup transaction")
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err) return nil, fmt.Errorf("committing transaction: %w", err)
} }
log.Debug("Database cleanup complete") // Collect statistics about the cleaned database
return nil stats := &CleanupStats{}
// Count files
var fileCount int
err = db.QueryRowWithLog(ctx, "SELECT COUNT(*) FROM files").Scan(&fileCount)
if err != nil {
return nil, fmt.Errorf("counting files: %w", err)
}
stats.FileCount = fileCount
// Count chunks
var chunkCount int
err = db.QueryRowWithLog(ctx, "SELECT COUNT(*) FROM chunks").Scan(&chunkCount)
if err != nil {
return nil, fmt.Errorf("counting chunks: %w", err)
}
stats.ChunkCount = chunkCount
// Count blobs and get sizes
var blobCount int
var compressedSize, uncompressedSize sql.NullInt64
err = db.QueryRowWithLog(ctx, `
SELECT COUNT(*), COALESCE(SUM(compressed_size), 0), COALESCE(SUM(uncompressed_size), 0)
FROM blobs
WHERE blob_hash IN (SELECT blob_hash FROM snapshot_blobs WHERE snapshot_id = ?)
`, snapshotID).Scan(&blobCount, &compressedSize, &uncompressedSize)
if err != nil {
return nil, fmt.Errorf("counting blobs and sizes: %w", err)
}
stats.BlobCount = blobCount
stats.CompressedSize = compressedSize.Int64
stats.UncompressedSize = uncompressedSize.Int64
log.Debug("[Temp DB Cleanup] Database cleanup complete", "stats", stats)
return stats, nil
} }
// dumpDatabase creates a SQL dump of the database // dumpDatabase creates a SQL dump of the database
@ -492,7 +448,7 @@ func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error {
} }
}() }()
log.Debug("Creating output file for compressed data", "path", outputPath) log.Debug("Creating output file for compressed and encrypted data", "path", outputPath)
output, err := os.Create(outputPath) output, err := os.Create(outputPath)
if err != nil { if err != nil {
return fmt.Errorf("creating output file: %w", err) return fmt.Errorf("creating output file: %w", err)
@ -504,27 +460,30 @@ func (sm *SnapshotManager) compressDump(inputPath, outputPath string) error {
} }
}() }()
// Create zstd encoder with good compression and multithreading // Use blobgen for compression and encryption
log.Debug("Creating zstd compressor", "level", "SpeedBetterCompression", "concurrency", runtime.NumCPU()) log.Debug("Creating compressor/encryptor", "level", sm.config.CompressionLevel)
zstdWriter, err := zstd.NewWriter(output, writer, err := blobgen.NewWriter(output, sm.config.CompressionLevel, sm.config.AgeRecipients)
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
zstd.WithEncoderConcurrency(runtime.NumCPU()),
zstd.WithWindowSize(4<<20), // 4MB window for metadata files
)
if err != nil { if err != nil {
return fmt.Errorf("creating zstd writer: %w", err) return fmt.Errorf("creating blobgen writer: %w", err)
} }
defer func() { defer func() {
if err := zstdWriter.Close(); err != nil { if err := writer.Close(); err != nil {
log.Debug("Failed to close zstd writer", "error", err) log.Debug("Failed to close writer", "error", err)
} }
}() }()
log.Debug("Compressing data") log.Debug("Compressing and encrypting data")
if _, err := io.Copy(zstdWriter, input); err != nil { if _, err := io.Copy(writer, input); err != nil {
return fmt.Errorf("compressing data: %w", err) return fmt.Errorf("compressing data: %w", err)
} }
// Close writer to flush all data
if err := writer.Close(); err != nil {
return fmt.Errorf("closing writer: %w", err)
}
log.Debug("Compression complete", "hash", fmt.Sprintf("%x", writer.Sum256()))
return nil return nil
} }
@ -607,44 +566,28 @@ func (sm *SnapshotManager) generateBlobManifest(ctx context.Context, dbPath stri
} }
log.Debug("JSON manifest created", "size", len(jsonData)) log.Debug("JSON manifest created", "size", len(jsonData))
// Compress with zstd // Compress and encrypt with blobgen
log.Debug("Compressing manifest with zstd") log.Debug("Compressing and encrypting manifest")
compressed, err := compressData(jsonData)
result, err := blobgen.CompressData(jsonData, sm.config.CompressionLevel, sm.config.AgeRecipients)
if err != nil { if err != nil {
return nil, fmt.Errorf("compressing manifest: %w", err) return nil, fmt.Errorf("compressing manifest: %w", err)
} }
log.Debug("Manifest compressed", "original_size", len(jsonData), "compressed_size", len(compressed)) log.Debug("Manifest compressed and encrypted",
"original_size", len(jsonData),
"compressed_size", result.CompressedSize,
"hash", result.SHA256)
log.Info("Generated blob manifest", log.Info("Generated blob manifest",
"snapshot_id", snapshotID, "snapshot_id", snapshotID,
"blob_count", len(blobs), "blob_count", len(blobs),
"json_size", len(jsonData), "json_size", len(jsonData),
"compressed_size", len(compressed)) "compressed_size", result.CompressedSize)
return compressed, nil return result.Data, nil
} }
// compressData compresses data using zstd // compressData compresses data using zstd
func compressData(data []byte) ([]byte, error) {
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf,
zstd.WithEncoderLevel(zstd.SpeedBetterCompression),
)
if err != nil {
return nil, err
}
if _, err := w.Write(data); err != nil {
_ = w.Close()
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// getFileSize returns the size of a file in bytes, or -1 if error // getFileSize returns the size of a file in bytes, or -1 if error
func getFileSize(path string) int64 { func getFileSize(path string) int64 {
@@ -738,7 +681,7 @@ func (sm *SnapshotManager) deleteSnapshot(ctx context.Context, snapshotID string
 	}
 
 	// Clean up orphaned data
-	log.Debug("Cleaning up orphaned data")
+	log.Debug("Cleaning up orphaned data in main database")
 	if err := sm.cleanupOrphanedData(ctx); err != nil {
 		return fmt.Errorf("cleaning up orphaned data: %w", err)
 	}
@@ -748,23 +691,170 @@ func (sm *SnapshotManager) deleteSnapshot(ctx context.Context, snapshotID string
 // cleanupOrphanedData removes files, chunks, and blobs that are no longer referenced by any snapshot
 func (sm *SnapshotManager) cleanupOrphanedData(ctx context.Context) error {
+	// Order is important to respect foreign key constraints:
+	// 1. Delete orphaned files (will cascade delete file_chunks)
+	// 2. Delete orphaned blobs (will cascade delete blob_chunks for deleted blobs)
+	// 3. Delete orphaned blob_chunks (where blob exists but chunk doesn't)
+	// 4. Delete orphaned chunks (now safe after all blob_chunks are gone)
+
 	// Delete orphaned files (files not in any snapshot)
 	log.Debug("Deleting orphaned files")
 	if err := sm.repos.Files.DeleteOrphaned(ctx); err != nil {
 		return fmt.Errorf("deleting orphaned files: %w", err)
 	}
 
-	// Delete orphaned chunks (chunks not referenced by any file)
-	log.Debug("Deleting orphaned chunks")
-	if err := sm.repos.Chunks.DeleteOrphaned(ctx); err != nil {
-		return fmt.Errorf("deleting orphaned chunks: %w", err)
-	}
-
 	// Delete orphaned blobs (blobs not in any snapshot)
+	// This will cascade delete blob_chunks for deleted blobs
 	log.Debug("Deleting orphaned blobs")
 	if err := sm.repos.Blobs.DeleteOrphaned(ctx); err != nil {
 		return fmt.Errorf("deleting orphaned blobs: %w", err)
 	}
 
+	// Delete orphaned blob_chunks entries
+	// This handles cases where the blob still exists but chunks were deleted
+	log.Debug("Deleting orphaned blob_chunks")
+	if err := sm.repos.BlobChunks.DeleteOrphaned(ctx); err != nil {
+		return fmt.Errorf("deleting orphaned blob_chunks: %w", err)
+	}
+
+	// Delete orphaned chunks (chunks not referenced by any file)
+	// This must come after cleaning up blob_chunks to avoid foreign key violations
+	log.Debug("Deleting orphaned chunks")
+	if err := sm.repos.Chunks.DeleteOrphaned(ctx); err != nil {
+		return fmt.Errorf("deleting orphaned chunks: %w", err)
+	}
+
+	return nil
+}
// deleteOtherSnapshots deletes all snapshots except the current one
func (sm *SnapshotManager) deleteOtherSnapshots(ctx context.Context, tx *sql.Tx, currentSnapshotID string) error {
log.Debug("[Temp DB Cleanup] Deleting other snapshots", "keeping", currentSnapshotID)
database.LogSQL("Execute", "DELETE FROM snapshots WHERE id != ?", currentSnapshotID)
result, err := tx.ExecContext(ctx, "DELETE FROM snapshots WHERE id != ?", currentSnapshotID)
if err != nil {
return fmt.Errorf("deleting other snapshots: %w", err)
}
rowsAffected, _ := result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted snapshots", "count", rowsAffected)
return nil
}
// deleteOrphanedSnapshotAssociations deletes snapshot_files and snapshot_blobs for deleted snapshots
func (sm *SnapshotManager) deleteOrphanedSnapshotAssociations(ctx context.Context, tx *sql.Tx, currentSnapshotID string) error {
// Delete orphaned snapshot_files
log.Debug("[Temp DB Cleanup] Deleting orphaned snapshot_files")
database.LogSQL("Execute", "DELETE FROM snapshot_files WHERE snapshot_id != ?", currentSnapshotID)
result, err := tx.ExecContext(ctx, "DELETE FROM snapshot_files WHERE snapshot_id != ?", currentSnapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned snapshot_files: %w", err)
}
rowsAffected, _ := result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted snapshot_files", "count", rowsAffected)
// Delete orphaned snapshot_blobs
log.Debug("[Temp DB Cleanup] Deleting orphaned snapshot_blobs")
database.LogSQL("Execute", "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", currentSnapshotID)
result, err = tx.ExecContext(ctx, "DELETE FROM snapshot_blobs WHERE snapshot_id != ?", currentSnapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned snapshot_blobs: %w", err)
}
rowsAffected, _ = result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted snapshot_blobs", "count", rowsAffected)
return nil
}
// deleteOrphanedFiles deletes files not in the current snapshot
func (sm *SnapshotManager) deleteOrphanedFiles(ctx context.Context, tx *sql.Tx, currentSnapshotID string) error {
log.Debug("[Temp DB Cleanup] Deleting files not in current snapshot")
database.LogSQL("Execute", `DELETE FROM files WHERE NOT EXISTS (SELECT 1 FROM snapshot_files WHERE snapshot_files.file_id = files.id AND snapshot_files.snapshot_id = ?)`, currentSnapshotID)
result, err := tx.ExecContext(ctx, `
DELETE FROM files
WHERE NOT EXISTS (
SELECT 1 FROM snapshot_files
WHERE snapshot_files.file_id = files.id
AND snapshot_files.snapshot_id = ?
)`, currentSnapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned files: %w", err)
}
rowsAffected, _ := result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted files", "count", rowsAffected)
// Note: file_chunks will be deleted via CASCADE
log.Debug("[Temp DB Cleanup] file_chunks will be deleted via CASCADE")
return nil
}
// deleteOrphanedChunkToFileMappings deletes chunk_files entries for deleted files
func (sm *SnapshotManager) deleteOrphanedChunkToFileMappings(ctx context.Context, tx *sql.Tx) error {
log.Debug("[Temp DB Cleanup] Deleting orphaned chunk_files")
database.LogSQL("Execute", `DELETE FROM chunk_files WHERE NOT EXISTS (SELECT 1 FROM files WHERE files.id = chunk_files.file_id)`)
result, err := tx.ExecContext(ctx, `
DELETE FROM chunk_files
WHERE NOT EXISTS (
SELECT 1 FROM files
WHERE files.id = chunk_files.file_id
)`)
if err != nil {
return fmt.Errorf("deleting orphaned chunk_files: %w", err)
}
rowsAffected, _ := result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted chunk_files", "count", rowsAffected)
return nil
}
// deleteOrphanedBlobs deletes blobs not in the current snapshot
func (sm *SnapshotManager) deleteOrphanedBlobs(ctx context.Context, tx *sql.Tx, currentSnapshotID string) error {
log.Debug("[Temp DB Cleanup] Deleting blobs not in current snapshot")
database.LogSQL("Execute", `DELETE FROM blobs WHERE NOT EXISTS (SELECT 1 FROM snapshot_blobs WHERE snapshot_blobs.blob_hash = blobs.blob_hash AND snapshot_blobs.snapshot_id = ?)`, currentSnapshotID)
result, err := tx.ExecContext(ctx, `
DELETE FROM blobs
WHERE NOT EXISTS (
SELECT 1 FROM snapshot_blobs
WHERE snapshot_blobs.blob_hash = blobs.blob_hash
AND snapshot_blobs.snapshot_id = ?
)`, currentSnapshotID)
if err != nil {
return fmt.Errorf("deleting orphaned blobs: %w", err)
}
rowsAffected, _ := result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted blobs not in snapshot", "count", rowsAffected)
return nil
}
// deleteOrphanedBlobToChunkMappings deletes blob_chunks entries for deleted blobs
func (sm *SnapshotManager) deleteOrphanedBlobToChunkMappings(ctx context.Context, tx *sql.Tx) error {
log.Debug("[Temp DB Cleanup] Deleting orphaned blob_chunks")
database.LogSQL("Execute", `DELETE FROM blob_chunks WHERE NOT EXISTS (SELECT 1 FROM blobs WHERE blobs.id = blob_chunks.blob_id)`)
result, err := tx.ExecContext(ctx, `
DELETE FROM blob_chunks
WHERE NOT EXISTS (
SELECT 1 FROM blobs
WHERE blobs.id = blob_chunks.blob_id
)`)
if err != nil {
return fmt.Errorf("deleting orphaned blob_chunks: %w", err)
}
rowsAffected, _ := result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted blob_chunks", "count", rowsAffected)
return nil
}
// deleteOrphanedChunks deletes chunks not referenced by any file
func (sm *SnapshotManager) deleteOrphanedChunks(ctx context.Context, tx *sql.Tx) error {
log.Debug("[Temp DB Cleanup] Deleting orphaned chunks")
database.LogSQL("Execute", `DELETE FROM chunks WHERE NOT EXISTS (SELECT 1 FROM file_chunks WHERE file_chunks.chunk_hash = chunks.chunk_hash)`)
result, err := tx.ExecContext(ctx, `
DELETE FROM chunks
WHERE NOT EXISTS (
SELECT 1 FROM file_chunks
WHERE file_chunks.chunk_hash = chunks.chunk_hash
)`)
if err != nil {
return fmt.Errorf("deleting orphaned chunks: %w", err)
}
rowsAffected, _ := result.RowsAffected()
log.Debug("[Temp DB Cleanup] Deleted chunks", "count", rowsAffected)
	return nil
}

View File

@ -6,10 +6,16 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log" "git.eeqj.de/sneak/vaultik/internal/log"
) )
const (
// Test age public key for encryption
testAgeRecipient = "age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"
)
func TestCleanSnapshotDBEmptySnapshot(t *testing.T) { func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
// Initialize logger // Initialize logger
log.Initialize(log.Config{}) log.Initialize(log.Config{})
@ -41,7 +47,7 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
// Create some files and chunks not associated with any snapshot // Create some files and chunks not associated with any snapshot
file := &database.File{Path: "/orphan/file.txt", Size: 1000} file := &database.File{Path: "/orphan/file.txt", Size: 1000}
chunk := &database.Chunk{ChunkHash: "orphan-chunk", SHA256: "orphan-chunk", Size: 500} chunk := &database.Chunk{ChunkHash: "orphan-chunk", Size: 500}
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
if err := repos.Files.Create(ctx, tx, file); err != nil { if err := repos.Files.Create(ctx, tx, file); err != nil {
@ -64,9 +70,14 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
t.Fatalf("failed to copy database: %v", err) t.Fatalf("failed to copy database: %v", err)
} }
// Create a mock config for testing
cfg := &config.Config{
CompressionLevel: 3,
AgeRecipients: []string{testAgeRecipient},
}
// Clean the database // Clean the database
sm := &SnapshotManager{} sm := &SnapshotManager{config: cfg}
if err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshot.ID); err != nil { if _, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshot.ID); err != nil {
t.Fatalf("failed to clean snapshot database: %v", err) t.Fatalf("failed to clean snapshot database: %v", err)
} }
@ -136,9 +147,14 @@ func TestCleanSnapshotDBNonExistentSnapshot(t *testing.T) {
t.Fatalf("failed to copy database: %v", err) t.Fatalf("failed to copy database: %v", err)
} }
// Create a mock config for testing
cfg := &config.Config{
CompressionLevel: 3,
AgeRecipients: []string{testAgeRecipient},
}
// Try to clean with non-existent snapshot // Try to clean with non-existent snapshot
sm := &SnapshotManager{} sm := &SnapshotManager{config: cfg}
err = sm.cleanSnapshotDB(ctx, tempDBPath, "non-existent-snapshot") _, err = sm.cleanSnapshotDB(ctx, tempDBPath, "non-existent-snapshot")
// Should not error - it will just delete everything // Should not error - it will just delete everything
if err != nil { if err != nil {

View File

@@ -16,22 +16,18 @@ package blob
 
 import (
 	"context"
-	"crypto/sha256"
 	"database/sql"
 	"encoding/hex"
 	"fmt"
-	"hash"
 	"io"
-	"math/bits"
 	"os"
-	"runtime"
 	"sync"
 	"time"
 
+	"git.eeqj.de/sneak/vaultik/internal/blobgen"
 	"git.eeqj.de/sneak/vaultik/internal/database"
 	"git.eeqj.de/sneak/vaultik/internal/log"
 	"github.com/google/uuid"
-	"github.com/klauspost/compress/zstd"
 )
 // BlobHandler is a callback function invoked when a blob is finalized and ready for upload.
@@ -45,7 +41,7 @@ type BlobHandler func(blob *BlobWithReader) error
 type PackerConfig struct {
 	MaxBlobSize      int64 // Maximum size of a blob before forcing finalization
 	CompressionLevel int   // Zstd compression level (1-19, higher = better compression)
-	Encryptor        Encryptor // Age encryptor for blob encryption (required)
+	Recipients       []string // Age recipients for encryption
 	Repositories     *database.Repositories // Database repositories for tracking blob metadata
 	BlobHandler      BlobHandler // Optional callback when blob is ready for upload
 }
@@ -56,7 +52,7 @@ type PackerConfig struct {
 type Packer struct {
 	maxBlobSize      int64
 	compressionLevel int
-	encryptor        Encryptor // Required - blobs are always encrypted
+	recipients       []string // Age recipients for encryption
 	blobHandler      BlobHandler // Called when blob is ready
 	repos            *database.Repositories // For creating blob records
@@ -68,25 +64,15 @@ type Packer struct {
 	finishedBlobs []*FinishedBlob // Only used if no handler provided
 }
 
-// Encryptor interface for encryption support
-type Encryptor interface {
-	Encrypt(data []byte) ([]byte, error)
-	EncryptWriter(dst io.Writer) (io.WriteCloser, error)
-}
-
 // blobInProgress represents a blob being assembled
 type blobInProgress struct {
 	id       string          // UUID of the blob
 	chunks   []*chunkInfo    // Track chunk metadata
 	chunkSet map[string]bool // Track unique chunks in this blob
 	tempFile *os.File        // Temporary file for encrypted compressed data
-	hasher         hash.Hash      // For computing hash of final encrypted data
-	compressor     io.WriteCloser // Compression writer
-	encryptor      io.WriteCloser // Encryption writer (if encryption enabled)
-	finalWriter    io.Writer      // The final writer in the chain
-	startTime      time.Time
-	size           int64 // Current uncompressed size
-	compressedSize int64 // Current compressed size (estimated)
+	writer    *blobgen.Writer // Unified compression/encryption/hashing writer
+	startTime time.Time
+	size      int64 // Current uncompressed size
 }
 // ChunkRef represents a chunk to be added to a blob.
@@ -134,8 +120,8 @@ type BlobWithReader struct {
 // The packer will automatically finalize blobs when they reach MaxBlobSize.
 // Returns an error if required configuration fields are missing or invalid.
 func NewPacker(cfg PackerConfig) (*Packer, error) {
-	if cfg.Encryptor == nil {
-		return nil, fmt.Errorf("encryptor is required - blobs must be encrypted")
+	if len(cfg.Recipients) == 0 {
+		return nil, fmt.Errorf("recipients are required - blobs must be encrypted")
 	}
 	if cfg.MaxBlobSize <= 0 {
 		return nil, fmt.Errorf("max blob size must be positive")
@@ -143,7 +129,7 @@ func NewPacker(cfg PackerConfig) (*Packer, error) {
 	return &Packer{
 		maxBlobSize:      cfg.MaxBlobSize,
 		compressionLevel: cfg.CompressionLevel,
-		encryptor:        cfg.Encryptor,
+		recipients:       cfg.Recipients,
 		blobHandler:      cfg.BlobHandler,
 		repos:            cfg.Repositories,
 		finishedBlobs:    make([]*FinishedBlob, 0),
@ -274,66 +260,24 @@ func (p *Packer) startNewBlob() error {
return fmt.Errorf("creating temp file: %w", err) return fmt.Errorf("creating temp file: %w", err)
} }
// Create blobgen writer for unified compression/encryption/hashing
writer, err := blobgen.NewWriter(tempFile, p.compressionLevel, p.recipients)
if err != nil {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
return fmt.Errorf("creating blobgen writer: %w", err)
}
p.currentBlob = &blobInProgress{ p.currentBlob = &blobInProgress{
id: blobID, id: blobID,
chunks: make([]*chunkInfo, 0), chunks: make([]*chunkInfo, 0),
chunkSet: make(map[string]bool), chunkSet: make(map[string]bool),
startTime: time.Now().UTC(), startTime: time.Now().UTC(),
tempFile: tempFile, tempFile: tempFile,
hasher: sha256.New(), writer: writer,
size: 0, size: 0,
compressedSize: 0,
} }
// Build writer chain: compressor -> [encryptor ->] hasher+file
// This ensures only encrypted data touches disk
// Final destination: write to both file and hasher
finalWriter := io.MultiWriter(tempFile, p.currentBlob.hasher)
// Set up encryption (required - closest to disk)
encWriter, err := p.encryptor.EncryptWriter(finalWriter)
if err != nil {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
return fmt.Errorf("creating encryption writer: %w", err)
}
p.currentBlob.encryptor = encWriter
currentWriter := encWriter
// Set up compression (processes data before encryption)
encoderLevel := zstd.EncoderLevel(p.compressionLevel)
if p.compressionLevel < 1 {
encoderLevel = zstd.SpeedDefault
} else if p.compressionLevel > 9 {
encoderLevel = zstd.SpeedBestCompression
}
// Calculate window size based on blob size
windowSize := p.maxBlobSize / 100
if windowSize < (1 << 20) { // Min 1MB
windowSize = 1 << 20
} else if windowSize > (128 << 20) { // Max 128MB
windowSize = 128 << 20
}
windowSize = 1 << uint(63-bits.LeadingZeros64(uint64(windowSize)))
compWriter, err := zstd.NewWriter(currentWriter,
zstd.WithEncoderLevel(encoderLevel),
zstd.WithEncoderConcurrency(runtime.NumCPU()),
zstd.WithWindowSize(int(windowSize)),
)
if err != nil {
if p.currentBlob.encryptor != nil {
_ = p.currentBlob.encryptor.Close()
}
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
return fmt.Errorf("creating compression writer: %w", err)
}
p.currentBlob.compressor = compWriter
p.currentBlob.finalWriter = compWriter
log.Debug("Started new blob", "blob_id", blobID, "temp_file", tempFile.Name()) log.Debug("Started new blob", "blob_id", blobID, "temp_file", tempFile.Name())
return nil return nil
} }
@ -349,8 +293,8 @@ func (p *Packer) addChunkToCurrentBlob(chunk *ChunkRef) error {
// Track offset before writing // Track offset before writing
offset := p.currentBlob.size offset := p.currentBlob.size
// Write to the final writer (compression -> encryption -> disk) // Write to the blobgen writer (compression -> encryption -> disk)
if _, err := p.currentBlob.finalWriter.Write(chunk.Data); err != nil { if _, err := p.currentBlob.writer.Write(chunk.Data); err != nil {
return fmt.Errorf("writing to blob stream: %w", err) return fmt.Errorf("writing to blob stream: %w", err)
} }
@ -402,16 +346,10 @@ func (p *Packer) finalizeCurrentBlob() error {
return nil return nil
} }
// Close compression writer to flush all data // Close blobgen writer to flush all data
if err := p.currentBlob.compressor.Close(); err != nil { if err := p.currentBlob.writer.Close(); err != nil {
p.cleanupTempFile() p.cleanupTempFile()
return fmt.Errorf("closing compression writer: %w", err) return fmt.Errorf("closing blobgen writer: %w", err)
}
// Close encryption writer
if err := p.currentBlob.encryptor.Close(); err != nil {
p.cleanupTempFile()
return fmt.Errorf("closing encryption writer: %w", err)
} }
// Sync file to ensure all data is written // Sync file to ensure all data is written
@ -433,8 +371,8 @@ func (p *Packer) finalizeCurrentBlob() error {
return fmt.Errorf("seeking to start: %w", err) return fmt.Errorf("seeking to start: %w", err)
} }
// Get hash from hasher (of final encrypted data) // Get hash from blobgen writer (of final encrypted data)
finalHash := p.currentBlob.hasher.Sum(nil) finalHash := p.currentBlob.writer.Sum256()
blobHash := hex.EncodeToString(finalHash) blobHash := hex.EncodeToString(finalHash)
// Create chunk references with offsets // Create chunk references with offsets
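For orientation, the refactor above replaces the hand-assembled hasher → zstd → age writer chain with the single blobgen.Writer added later in this commit. A minimal sketch of the resulting write path (the package name, function, and variable names here are illustrative, not the packer's actual code):

```go
package example

import (
	"encoding/hex"
	"fmt"
	"os"

	"git.eeqj.de/sneak/vaultik/internal/blobgen"
)

// writeBlob streams chunks through the unified writer and returns the
// hex-encoded hash the packer would record for the finished blob.
func writeBlob(recipients []string, compressionLevel int, chunks [][]byte) (string, error) {
	tempFile, err := os.CreateTemp("", "blob-*.tmp")
	if err != nil {
		return "", fmt.Errorf("creating temp file: %w", err)
	}
	defer func() {
		_ = tempFile.Close()
		_ = os.Remove(tempFile.Name())
	}()

	// One writer replaces the separate compression, encryption, and hashing stages.
	w, err := blobgen.NewWriter(tempFile, compressionLevel, recipients)
	if err != nil {
		return "", fmt.Errorf("creating blobgen writer: %w", err)
	}

	for _, chunk := range chunks {
		if _, err := w.Write(chunk); err != nil {
			return "", fmt.Errorf("writing to blob stream: %w", err)
		}
	}

	// Close flushes the compressor and then the encryptor.
	if err := w.Close(); err != nil {
		return "", fmt.Errorf("closing blobgen writer: %w", err)
	}
	return hex.EncodeToString(w.Sum256()), nil
}
```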


@ -2,13 +2,14 @@ package blob
import ( import (
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"database/sql"
"encoding/hex" "encoding/hex"
"io" "io"
"testing" "testing"
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log" "git.eeqj.de/sneak/vaultik/internal/log"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
@ -30,12 +31,6 @@ func TestPacker(t *testing.T) {
t.Fatalf("failed to parse test identity: %v", err) t.Fatalf("failed to parse test identity: %v", err)
} }
// Create test encryptor using the public key
enc, err := crypto.NewEncryptor([]string{testPublicKey})
if err != nil {
t.Fatalf("failed to create encryptor: %v", err)
}
t.Run("single chunk creates single blob", func(t *testing.T) { t.Run("single chunk creates single blob", func(t *testing.T) {
// Create test database // Create test database
db, err := database.NewTestDB() db, err := database.NewTestDB()
@ -48,7 +43,7 @@ func TestPacker(t *testing.T) {
cfg := PackerConfig{ cfg := PackerConfig{
MaxBlobSize: 10 * 1024 * 1024, // 10MB MaxBlobSize: 10 * 1024 * 1024, // 10MB
CompressionLevel: 3, CompressionLevel: 3,
Encryptor: enc, Recipients: []string{testPublicKey},
Repositories: repos, Repositories: repos,
} }
packer, err := NewPacker(cfg) packer, err := NewPacker(cfg)
@ -59,8 +54,22 @@ func TestPacker(t *testing.T) {
// Create a chunk // Create a chunk
data := []byte("Hello, World!") data := []byte("Hello, World!")
hash := sha256.Sum256(data) hash := sha256.Sum256(data)
hashStr := hex.EncodeToString(hash[:])
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return repos.Chunks.Create(ctx, tx, dbChunk)
})
if err != nil {
t.Fatalf("failed to create chunk in db: %v", err)
}
chunk := &ChunkRef{ chunk := &ChunkRef{
Hash: hex.EncodeToString(hash[:]), Hash: hashStr,
Data: data, Data: data,
} }
@ -123,7 +132,7 @@ func TestPacker(t *testing.T) {
cfg := PackerConfig{ cfg := PackerConfig{
MaxBlobSize: 10 * 1024 * 1024, // 10MB MaxBlobSize: 10 * 1024 * 1024, // 10MB
CompressionLevel: 3, CompressionLevel: 3,
Encryptor: enc, Recipients: []string{testPublicKey},
Repositories: repos, Repositories: repos,
} }
packer, err := NewPacker(cfg) packer, err := NewPacker(cfg)
@ -136,8 +145,22 @@ func TestPacker(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
data := bytes.Repeat([]byte{byte(i)}, 1000) data := bytes.Repeat([]byte{byte(i)}, 1000)
hash := sha256.Sum256(data) hash := sha256.Sum256(data)
hashStr := hex.EncodeToString(hash[:])
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return repos.Chunks.Create(ctx, tx, dbChunk)
})
if err != nil {
t.Fatalf("failed to create chunk in db: %v", err)
}
chunks[i] = &ChunkRef{ chunks[i] = &ChunkRef{
Hash: hex.EncodeToString(hash[:]), Hash: hashStr,
Data: data, Data: data,
} }
} }
@ -191,7 +214,7 @@ func TestPacker(t *testing.T) {
cfg := PackerConfig{ cfg := PackerConfig{
MaxBlobSize: 5000, // 5KB max MaxBlobSize: 5000, // 5KB max
CompressionLevel: 3, CompressionLevel: 3,
Encryptor: enc, Recipients: []string{testPublicKey},
Repositories: repos, Repositories: repos,
} }
packer, err := NewPacker(cfg) packer, err := NewPacker(cfg)
@ -204,8 +227,22 @@ func TestPacker(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
data := bytes.Repeat([]byte{byte(i)}, 1000) // 1KB each data := bytes.Repeat([]byte{byte(i)}, 1000) // 1KB each
hash := sha256.Sum256(data) hash := sha256.Sum256(data)
hashStr := hex.EncodeToString(hash[:])
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return repos.Chunks.Create(ctx, tx, dbChunk)
})
if err != nil {
t.Fatalf("failed to create chunk in db: %v", err)
}
chunks[i] = &ChunkRef{ chunks[i] = &ChunkRef{
Hash: hex.EncodeToString(hash[:]), Hash: hashStr,
Data: data, Data: data,
} }
} }
@ -265,7 +302,7 @@ func TestPacker(t *testing.T) {
cfg := PackerConfig{ cfg := PackerConfig{
MaxBlobSize: 10 * 1024 * 1024, // 10MB MaxBlobSize: 10 * 1024 * 1024, // 10MB
CompressionLevel: 3, CompressionLevel: 3,
Encryptor: enc, Recipients: []string{testPublicKey},
Repositories: repos, Repositories: repos,
} }
packer, err := NewPacker(cfg) packer, err := NewPacker(cfg)
@ -276,8 +313,22 @@ func TestPacker(t *testing.T) {
// Create test data // Create test data
data := bytes.Repeat([]byte("Test data for encryption!"), 100) data := bytes.Repeat([]byte("Test data for encryption!"), 100)
hash := sha256.Sum256(data) hash := sha256.Sum256(data)
hashStr := hex.EncodeToString(hash[:])
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return repos.Chunks.Create(ctx, tx, dbChunk)
})
if err != nil {
t.Fatalf("failed to create chunk in db: %v", err)
}
chunk := &ChunkRef{ chunk := &ChunkRef{
Hash: hex.EncodeToString(hash[:]), Hash: hashStr,
Data: data, Data: data,
} }


@ -0,0 +1,67 @@
package blobgen
import (
"bytes"
"encoding/hex"
"fmt"
"io"
)
// CompressResult contains the results of compression
type CompressResult struct {
Data []byte
UncompressedSize int64
CompressedSize int64
SHA256 string
}
// CompressData compresses and encrypts data, returning the result with hash
func CompressData(data []byte, compressionLevel int, recipients []string) (*CompressResult, error) {
var buf bytes.Buffer
// Create writer
w, err := NewWriter(&buf, compressionLevel, recipients)
if err != nil {
return nil, fmt.Errorf("creating writer: %w", err)
}
// Write data
if _, err := w.Write(data); err != nil {
_ = w.Close()
return nil, fmt.Errorf("writing data: %w", err)
}
// Close to flush
if err := w.Close(); err != nil {
return nil, fmt.Errorf("closing writer: %w", err)
}
return &CompressResult{
Data: buf.Bytes(),
UncompressedSize: int64(len(data)),
CompressedSize: int64(buf.Len()),
SHA256: hex.EncodeToString(w.Sum256()),
}, nil
}
// CompressStream compresses and encrypts from reader to writer, returning hash
func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipients []string) (written int64, hash string, err error) {
// Create writer
w, err := NewWriter(dst, compressionLevel, recipients)
if err != nil {
return 0, "", fmt.Errorf("creating writer: %w", err)
}
defer func() { _ = w.Close() }()
// Copy data
if _, err := io.Copy(w, src); err != nil {
return 0, "", fmt.Errorf("copying data: %w", err)
}
// Close to flush
if err := w.Close(); err != nil {
return 0, "", fmt.Errorf("closing writer: %w", err)
}
return w.BytesWritten(), hex.EncodeToString(w.Sum256()), nil
}
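A quick usage sketch for the one-shot helper, generating a throwaway age key pair so the snippet is self-contained (nothing below is the application's own code):

```go
package example

import (
	"fmt"
	"log"

	"filippo.io/age"

	"git.eeqj.de/sneak/vaultik/internal/blobgen"
)

func compressDataExample() {
	// Generate a throwaway identity; real callers pass configured recipients.
	id, err := age.GenerateX25519Identity()
	if err != nil {
		log.Fatal(err)
	}

	res, err := blobgen.CompressData([]byte("hello vaultik"), 3, []string{id.Recipient().String()})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%d bytes in, %d bytes out, sha256=%s\n",
		res.UncompressedSize, res.CompressedSize, res.SHA256)
}
```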


@ -0,0 +1,73 @@
package blobgen
import (
"crypto/sha256"
"fmt"
"hash"
"io"
"filippo.io/age"
"github.com/klauspost/compress/zstd"
)
// Reader wraps decompression and decryption with SHA256 verification
type Reader struct {
reader io.Reader
decompressor *zstd.Decoder
decryptor io.Reader
hasher hash.Hash
teeReader io.Reader
bytesRead int64
}
// NewReader creates a new Reader that decrypts, decompresses, and verifies data
func NewReader(r io.Reader, identity age.Identity) (*Reader, error) {
// Create decryption reader
decReader, err := age.Decrypt(r, identity)
if err != nil {
return nil, fmt.Errorf("creating decryption reader: %w", err)
}
// Create decompression reader
decompressor, err := zstd.NewReader(decReader)
if err != nil {
return nil, fmt.Errorf("creating decompression reader: %w", err)
}
// Create SHA256 hasher
hasher := sha256.New()
// Create tee reader that reads from decompressor and writes to hasher
teeReader := io.TeeReader(decompressor, hasher)
return &Reader{
reader: r,
decompressor: decompressor,
decryptor: decReader,
hasher: hasher,
teeReader: teeReader,
}, nil
}
// Read implements io.Reader
func (r *Reader) Read(p []byte) (n int, err error) {
n, err = r.teeReader.Read(p)
r.bytesRead += int64(n)
return n, err
}
// Close closes the decompressor
func (r *Reader) Close() error {
r.decompressor.Close()
return nil
}
// Sum256 returns the SHA256 hash of all data read
func (r *Reader) Sum256() []byte {
return r.hasher.Sum(nil)
}
// BytesRead returns the number of uncompressed bytes read
func (r *Reader) BytesRead() int64 {
return r.bytesRead
}

internal/blobgen/writer.go (new file)

@ -0,0 +1,112 @@
package blobgen
import (
"crypto/sha256"
"fmt"
"hash"
"io"
"filippo.io/age"
"github.com/klauspost/compress/zstd"
)
// Writer wraps compression and encryption with SHA256 hashing
type Writer struct {
writer io.Writer // Final destination
compressor *zstd.Encoder // Compression layer
encryptor io.WriteCloser // Encryption layer
hasher hash.Hash // SHA256 hasher
teeWriter io.Writer // Tees data to hasher
compressionLevel int
bytesWritten int64
}
// NewWriter creates a new Writer that compresses, encrypts, and hashes data
func NewWriter(w io.Writer, compressionLevel int, recipients []string) (*Writer, error) {
// Validate compression level
if err := validateCompressionLevel(compressionLevel); err != nil {
return nil, err
}
// Create SHA256 hasher
hasher := sha256.New()
// Parse recipients
var ageRecipients []age.Recipient
for _, recipient := range recipients {
r, err := age.ParseX25519Recipient(recipient)
if err != nil {
return nil, fmt.Errorf("parsing recipient %s: %w", recipient, err)
}
ageRecipients = append(ageRecipients, r)
}
// Create encryption writer
encWriter, err := age.Encrypt(w, ageRecipients...)
if err != nil {
return nil, fmt.Errorf("creating encryption writer: %w", err)
}
// Create compression writer with encryption as destination
compressor, err := zstd.NewWriter(encWriter,
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(compressionLevel)),
zstd.WithEncoderConcurrency(1), // Use single thread for streaming
)
if err != nil {
_ = encWriter.Close()
return nil, fmt.Errorf("creating compression writer: %w", err)
}
// Create tee writer that writes to both compressor and hasher
teeWriter := io.MultiWriter(compressor, hasher)
return &Writer{
writer: w,
compressor: compressor,
encryptor: encWriter,
hasher: hasher,
teeWriter: teeWriter,
compressionLevel: compressionLevel,
}, nil
}
// Write implements io.Writer
func (w *Writer) Write(p []byte) (n int, err error) {
n, err = w.teeWriter.Write(p)
w.bytesWritten += int64(n)
return n, err
}
// Close closes all layers and returns any errors
func (w *Writer) Close() error {
// Close compressor first
if err := w.compressor.Close(); err != nil {
return fmt.Errorf("closing compressor: %w", err)
}
// Then close encryptor
if err := w.encryptor.Close(); err != nil {
return fmt.Errorf("closing encryptor: %w", err)
}
return nil
}
// Sum256 returns the SHA256 hash of all data written
func (w *Writer) Sum256() []byte {
return w.hasher.Sum(nil)
}
// BytesWritten returns the number of uncompressed bytes written
func (w *Writer) BytesWritten() int64 {
return w.bytesWritten
}
func validateCompressionLevel(level int) error {
// Zstd compression levels: 1-19 (default is 3)
// SpeedFastest = 1, SpeedDefault = 3, SpeedBetterCompression = 7, SpeedBestCompression = 11
if level < 1 || level > 19 {
return fmt.Errorf("invalid compression level %d: must be between 1 and 19", level)
}
return nil
}
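Putting the two halves together, a round-trip sketch: data written through Writer is decrypted and decompressed by Reader, and the two SHA256 sums agree because both sides hash the plaintext stream. Again a standalone example, not code from this commit:

```go
package example

import (
	"bytes"
	"fmt"
	"io"
	"log"

	"filippo.io/age"

	"git.eeqj.de/sneak/vaultik/internal/blobgen"
)

func roundTripExample() {
	id, err := age.GenerateX25519Identity()
	if err != nil {
		log.Fatal(err)
	}

	payload := bytes.Repeat([]byte("vaultik round trip "), 1000)

	// Compress + encrypt into an in-memory buffer.
	var sealed bytes.Buffer
	w, err := blobgen.NewWriter(&sealed, 3, []string{id.Recipient().String()})
	if err != nil {
		log.Fatal(err)
	}
	if _, err := w.Write(payload); err != nil {
		log.Fatal(err)
	}
	if err := w.Close(); err != nil {
		log.Fatal(err)
	}

	// Decrypt + decompress and verify the content hash.
	r, err := blobgen.NewReader(&sealed, id)
	if err != nil {
		log.Fatal(err)
	}
	plain, err := io.ReadAll(r)
	if err != nil {
		log.Fatal(err)
	}
	_ = r.Close()

	fmt.Println("bytes match: ", bytes.Equal(plain, payload))
	fmt.Println("hashes match:", bytes.Equal(w.Sum256(), r.Sum256()))
}
```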


@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
@ -13,7 +14,6 @@ import (
"git.eeqj.de/sneak/vaultik/internal/backup" "git.eeqj.de/sneak/vaultik/internal/backup"
"git.eeqj.de/sneak/vaultik/internal/config" "git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database" "git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals" "git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log" "git.eeqj.de/sneak/vaultik/internal/log"
@ -33,14 +33,18 @@ type SnapshotCreateOptions struct {
// SnapshotCreateApp contains all dependencies needed for creating snapshots // SnapshotCreateApp contains all dependencies needed for creating snapshots
type SnapshotCreateApp struct { type SnapshotCreateApp struct {
Globals *globals.Globals Globals *globals.Globals
Config *config.Config Config *config.Config
Repositories *database.Repositories Repositories *database.Repositories
ScannerFactory backup.ScannerFactory ScannerFactory backup.ScannerFactory
S3Client *s3.Client SnapshotManager *backup.SnapshotManager
DB *database.DB S3Client *s3.Client
Lifecycle fx.Lifecycle DB *database.DB
Shutdowner fx.Shutdowner Lifecycle fx.Lifecycle
Shutdowner fx.Shutdowner
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader
} }
// SnapshotApp contains dependencies for snapshot commands // SnapshotApp contains dependencies for snapshot commands
@ -106,17 +110,22 @@ specifying a path using --config or by setting VAULTIK_CONFIG to a path.`,
s3.Module, s3.Module,
fx.Provide(fx.Annotate( fx.Provide(fx.Annotate(
func(g *globals.Globals, cfg *config.Config, repos *database.Repositories, func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
scannerFactory backup.ScannerFactory, s3Client *s3.Client, db *database.DB, scannerFactory backup.ScannerFactory, snapshotManager *backup.SnapshotManager,
s3Client *s3.Client, db *database.DB,
lc fx.Lifecycle, shutdowner fx.Shutdowner) *SnapshotCreateApp { lc fx.Lifecycle, shutdowner fx.Shutdowner) *SnapshotCreateApp {
return &SnapshotCreateApp{ return &SnapshotCreateApp{
Globals: g, Globals: g,
Config: cfg, Config: cfg,
Repositories: repos, Repositories: repos,
ScannerFactory: scannerFactory, ScannerFactory: scannerFactory,
S3Client: s3Client, SnapshotManager: snapshotManager,
DB: db, S3Client: s3Client,
Lifecycle: lc, DB: db,
Shutdowner: shutdowner, Lifecycle: lc,
Shutdowner: shutdowner,
Stdout: os.Stdout,
Stderr: os.Stderr,
Stdin: os.Stdin,
} }
}, },
)), )),
@ -181,21 +190,10 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre
hostname, _ = os.Hostname() hostname, _ = os.Hostname()
} }
// Create encryptor if needed for snapshot manager
var encryptor backup.Encryptor
if len(app.Config.AgeRecipients) > 0 {
cryptoEncryptor, err := crypto.NewEncryptor(app.Config.AgeRecipients)
if err != nil {
return fmt.Errorf("creating encryptor: %w", err)
}
encryptor = cryptoEncryptor
}
snapshotManager := backup.NewSnapshotManager(app.Repositories, app.S3Client, encryptor)
// CRITICAL: This MUST succeed. If we fail to clean up incomplete snapshots, // CRITICAL: This MUST succeed. If we fail to clean up incomplete snapshots,
// the deduplication logic will think files from the incomplete snapshot were // the deduplication logic will think files from the incomplete snapshot were
// already backed up and skip them, resulting in data loss. // already backed up and skip them, resulting in data loss.
if err := snapshotManager.CleanupIncompleteSnapshots(ctx, hostname); err != nil { if err := app.SnapshotManager.CleanupIncompleteSnapshots(ctx, hostname); err != nil {
return fmt.Errorf("cleanup incomplete snapshots: %w", err) return fmt.Errorf("cleanup incomplete snapshots: %w", err)
} }
@ -234,8 +232,10 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre
// Perform a single snapshot run // Perform a single snapshot run
log.Notice("Starting snapshot", "source_dirs", len(resolvedDirs)) log.Notice("Starting snapshot", "source_dirs", len(resolvedDirs))
_, _ = fmt.Fprintf(app.Stdout, "Starting snapshot with %d source directories\n", len(resolvedDirs))
for i, dir := range resolvedDirs { for i, dir := range resolvedDirs {
log.Info("Source directory", "index", i+1, "path", dir) log.Info("Source directory", "index", i+1, "path", dir)
_, _ = fmt.Fprintf(app.Stdout, "Source directory %d: %s\n", i+1, dir)
} }
// Statistics tracking // Statistics tracking
@ -250,12 +250,12 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre
uploadDuration := time.Duration(0) uploadDuration := time.Duration(0)
// Create a new snapshot at the beginning // Create a new snapshot at the beginning
// (hostname, encryptor, and snapshotManager already created above for cleanup) snapshotID, err := app.SnapshotManager.CreateSnapshot(ctx, hostname, app.Globals.Version, app.Globals.Commit)
snapshotID, err := snapshotManager.CreateSnapshot(ctx, hostname, app.Globals.Version, app.Globals.Commit)
if err != nil { if err != nil {
return fmt.Errorf("creating snapshot: %w", err) return fmt.Errorf("creating snapshot: %w", err)
} }
log.Info("Created snapshot", "snapshot_id", snapshotID) log.Info("Created snapshot", "snapshot_id", snapshotID)
_, _ = fmt.Fprintf(app.Stdout, "\nCreated snapshot: %s\n", snapshotID)
for _, dir := range resolvedDirs { for _, dir := range resolvedDirs {
// Check if context is cancelled // Check if context is cancelled
@ -288,6 +288,13 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre
"chunks", result.ChunksCreated, "chunks", result.ChunksCreated,
"blobs", result.BlobsCreated, "blobs", result.BlobsCreated,
"duration", result.EndTime.Sub(result.StartTime)) "duration", result.EndTime.Sub(result.StartTime))
// Human-friendly output
_, _ = fmt.Fprintf(app.Stdout, "\nDirectory: %s\n", dir)
_, _ = fmt.Fprintf(app.Stdout, " Scanned: %d files (%s)\n", result.FilesScanned, humanize.Bytes(uint64(result.BytesScanned)))
_, _ = fmt.Fprintf(app.Stdout, " Skipped: %d files (%s) - already backed up\n", result.FilesSkipped, humanize.Bytes(uint64(result.BytesSkipped)))
_, _ = fmt.Fprintf(app.Stdout, " Created: %d chunks, %d blobs\n", result.ChunksCreated, result.BlobsCreated)
_, _ = fmt.Fprintf(app.Stdout, " Duration: %s\n", result.EndTime.Sub(result.StartTime).Round(time.Millisecond))
} }
// Get upload statistics from scanner progress if available // Get upload statistics from scanner progress if available
@ -312,19 +319,19 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre
UploadDurationMs: uploadDuration.Milliseconds(), UploadDurationMs: uploadDuration.Milliseconds(),
} }
if err := snapshotManager.UpdateSnapshotStatsExtended(ctx, snapshotID, extStats); err != nil { if err := app.SnapshotManager.UpdateSnapshotStatsExtended(ctx, snapshotID, extStats); err != nil {
return fmt.Errorf("updating snapshot stats: %w", err) return fmt.Errorf("updating snapshot stats: %w", err)
} }
// Mark snapshot as complete // Mark snapshot as complete
if err := snapshotManager.CompleteSnapshot(ctx, snapshotID); err != nil { if err := app.SnapshotManager.CompleteSnapshot(ctx, snapshotID); err != nil {
return fmt.Errorf("completing snapshot: %w", err) return fmt.Errorf("completing snapshot: %w", err)
} }
// Export snapshot metadata // Export snapshot metadata
// Export snapshot metadata without closing the database // Export snapshot metadata without closing the database
// The export function should handle its own database connection // The export function should handle its own database connection
if err := snapshotManager.ExportSnapshotMetadata(ctx, app.Config.IndexPath, snapshotID); err != nil { if err := app.SnapshotManager.ExportSnapshotMetadata(ctx, app.Config.IndexPath, snapshotID); err != nil {
return fmt.Errorf("exporting snapshot metadata: %w", err) return fmt.Errorf("exporting snapshot metadata: %w", err)
} }
@ -373,29 +380,29 @@ func (app *SnapshotCreateApp) runSnapshot(ctx context.Context, opts *SnapshotCre
} }
// Print comprehensive summary // Print comprehensive summary
log.Notice("=== Snapshot Summary ===") _, _ = fmt.Fprintln(app.Stdout, "\n=== Snapshot Summary ===")
log.Info("Snapshot ID", "id", snapshotID) _, _ = fmt.Fprintf(app.Stdout, "Snapshot ID: %s\n", snapshotID)
log.Info("Source files", _, _ = fmt.Fprintf(app.Stdout, "Source files: %s files, %s total\n",
"total_count", formatNumber(totalFiles), formatNumber(totalFiles),
"total_size", humanize.Bytes(uint64(totalBytesAll))) humanize.Bytes(uint64(totalBytesAll)))
log.Info("Changed files", _, _ = fmt.Fprintf(app.Stdout, "Changed files: %s files, %s\n",
"count", formatNumber(totalFilesChanged), formatNumber(totalFilesChanged),
"size", humanize.Bytes(uint64(totalBytesChanged))) humanize.Bytes(uint64(totalBytesChanged)))
log.Info("Unchanged files", _, _ = fmt.Fprintf(app.Stdout, "Unchanged files: %s files, %s\n",
"count", formatNumber(totalFilesSkipped), formatNumber(totalFilesSkipped),
"size", humanize.Bytes(uint64(totalBytesSkipped))) humanize.Bytes(uint64(totalBytesSkipped)))
log.Info("Blob storage", _, _ = fmt.Fprintf(app.Stdout, "Blob storage: %s uncompressed, %s compressed (%.2fx ratio, level %d)\n",
"total_uncompressed", humanize.Bytes(uint64(totalBlobSizeUncompressed)), humanize.Bytes(uint64(totalBlobSizeUncompressed)),
"total_compressed", humanize.Bytes(uint64(totalBlobSizeCompressed)), humanize.Bytes(uint64(totalBlobSizeCompressed)),
"compression_ratio", fmt.Sprintf("%.2fx", compressionRatio), compressionRatio,
"compression_level", app.Config.CompressionLevel) app.Config.CompressionLevel)
log.Info("Upload activity", _, _ = fmt.Fprintf(app.Stdout, "Upload activity: %s uploaded, %d blobs, %s duration, %s avg speed\n",
"bytes_uploaded", humanize.Bytes(uint64(totalBytesUploaded)), humanize.Bytes(uint64(totalBytesUploaded)),
"blobs_uploaded", totalBlobsUploaded, totalBlobsUploaded,
"upload_time", formatDuration(uploadDuration), formatDuration(uploadDuration),
"avg_speed", avgUploadSpeed) avgUploadSpeed)
log.Info("Total time", "duration", formatDuration(snapshotDuration)) _, _ = fmt.Fprintf(app.Stdout, "Total time: %s\n", formatDuration(snapshotDuration))
log.Notice("==========================") _, _ = fmt.Fprintln(app.Stdout, "==========================")
if opts.Prune { if opts.Prune {
log.Info("Pruning enabled - will delete old snapshots after snapshot") log.Info("Pruning enabled - will delete old snapshots after snapshot")
@ -729,13 +736,18 @@ func (app *SnapshotApp) downloadManifest(ctx context.Context, snapshotID string)
} }
defer zr.Close() defer zr.Close()
// Decode JSON // Decode JSON - manifest is an object with a "blobs" field
var manifest []string var manifest struct {
SnapshotID string `json:"snapshot_id"`
Timestamp string `json:"timestamp"`
BlobCount int `json:"blob_count"`
Blobs []string `json:"blobs"`
}
if err := json.NewDecoder(zr).Decode(&manifest); err != nil { if err := json.NewDecoder(zr).Decode(&manifest); err != nil {
return nil, fmt.Errorf("decoding manifest: %w", err) return nil, fmt.Errorf("decoding manifest: %w", err)
} }
return manifest, nil return manifest.Blobs, nil
} }
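For reference, the manifest is now a JSON object rather than a bare array of blob hashes; a sketch of the expected shape (all field values below are made-up placeholders):

```go
package example

import (
	"encoding/json"
	"fmt"
	"log"
	"strings"
)

// Illustrative manifest document matching the struct decoded above.
const manifestJSON = `{
  "snapshot_id": "myhost-20240115-143052Z",
  "timestamp": "2024-01-15T14:30:52Z",
  "blob_count": 2,
  "blobs": ["blob-hash-1", "blob-hash-2"]
}`

func decodeManifestExample() {
	var manifest struct {
		SnapshotID string   `json:"snapshot_id"`
		Timestamp  string   `json:"timestamp"`
		BlobCount  int      `json:"blob_count"`
		Blobs      []string `json:"blobs"`
	}
	if err := json.NewDecoder(strings.NewReader(manifestJSON)).Decode(&manifest); err != nil {
		log.Fatal(err)
	}
	fmt.Println(manifest.SnapshotID, len(manifest.Blobs)) // myhost-20240115-143052Z 2
}
```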
// deleteSnapshot removes a snapshot and its metadata // deleteSnapshot removes a snapshot and its metadata
@ -765,29 +777,21 @@ func (app *SnapshotApp) deleteSnapshot(ctx context.Context, snapshotID string) e
// parseSnapshotTimestamp extracts timestamp from snapshot ID // parseSnapshotTimestamp extracts timestamp from snapshot ID
// Format: hostname-20240115-143052Z // Format: hostname-20240115-143052Z
func parseSnapshotTimestamp(snapshotID string) (time.Time, error) { func parseSnapshotTimestamp(snapshotID string) (time.Time, error) {
// Find the last hyphen to separate hostname from timestamp // The snapshot ID format is: hostname-YYYYMMDD-HHMMSSZ
lastHyphen := strings.LastIndex(snapshotID, "-") // We need to find the timestamp part which starts after the hostname
if lastHyphen == -1 {
return time.Time{}, fmt.Errorf("invalid snapshot ID format") // Split by hyphen
parts := strings.Split(snapshotID, "-")
if len(parts) < 3 {
return time.Time{}, fmt.Errorf("invalid snapshot ID format: expected hostname-YYYYMMDD-HHMMSSZ")
} }
// Extract timestamp part (everything after hostname) // The last two parts should be the date and time with Z suffix
timestampPart := snapshotID[lastHyphen+1:] dateStr := parts[len(parts)-2]
timeStr := parts[len(parts)-1]
// The timestamp format is YYYYMMDD-HHMMSSZ // Reconstruct the full timestamp
// We need to find where the date ends and time begins fullTimestamp := dateStr + "-" + timeStr
if len(timestampPart) < 8 {
return time.Time{}, fmt.Errorf("invalid snapshot ID format: timestamp too short")
}
// Find where the hostname ends by looking for pattern YYYYMMDD
hostnameEnd := strings.LastIndex(snapshotID[:lastHyphen], "-")
if hostnameEnd == -1 {
return time.Time{}, fmt.Errorf("invalid snapshot ID format: missing date separator")
}
// Get the full timestamp including date from before the last hyphen
fullTimestamp := snapshotID[hostnameEnd+1:]
// Parse the timestamp with Z suffix // Parse the timestamp with Z suffix
return time.Parse("20060102-150405Z", fullTimestamp) return time.Parse("20060102-150405Z", fullTimestamp)
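The rewritten parser tolerates hostnames that themselves contain hyphens, since only the last two hyphen-separated fields are read as the timestamp. A standalone copy of the logic, for illustration:

```go
package example

import (
	"fmt"
	"strings"
	"time"
)

// Same parsing rule as above, reproduced so the example compiles on its own.
func parseSnapshotTimestamp(snapshotID string) (time.Time, error) {
	parts := strings.Split(snapshotID, "-")
	if len(parts) < 3 {
		return time.Time{}, fmt.Errorf("invalid snapshot ID format: expected hostname-YYYYMMDD-HHMMSSZ")
	}
	fullTimestamp := parts[len(parts)-2] + "-" + parts[len(parts)-1]
	return time.Parse("20060102-150405Z", fullTimestamp)
}

func parseExample() {
	// A hyphenated hostname still parses correctly.
	ts, err := parseSnapshotTimestamp("build-host-01-20240115-143052Z")
	fmt.Println(ts, err) // 2024-01-15 14:30:52 +0000 UTC <nil>
}
```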


@ -121,3 +121,32 @@ func (r *BlobChunkRepository) GetByChunkHashTx(ctx context.Context, tx *sql.Tx,
LogSQL("GetByChunkHashTx", "Found blob", chunkHash, "blob", bc.BlobID) LogSQL("GetByChunkHashTx", "Found blob", chunkHash, "blob", bc.BlobID)
return &bc, nil return &bc, nil
} }
// DeleteOrphaned deletes blob_chunks entries where either the blob or chunk no longer exists
func (r *BlobChunkRepository) DeleteOrphaned(ctx context.Context) error {
// Delete blob_chunks where the blob doesn't exist
query1 := `
DELETE FROM blob_chunks
WHERE NOT EXISTS (
SELECT 1 FROM blobs
WHERE blobs.id = blob_chunks.blob_id
)
`
if _, err := r.db.ExecWithLog(ctx, query1); err != nil {
return fmt.Errorf("deleting blob_chunks with missing blobs: %w", err)
}
// Delete blob_chunks where the chunk doesn't exist
query2 := `
DELETE FROM blob_chunks
WHERE NOT EXISTS (
SELECT 1 FROM chunks
WHERE chunks.chunk_hash = blob_chunks.chunk_hash
)
`
if _, err := r.db.ExecWithLog(ctx, query2); err != nil {
return fmt.Errorf("deleting blob_chunks with missing chunks: %w", err)
}
return nil
}
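A hedged sketch of where this might be called during cleanup; the Repositories field name used here is an assumption, not taken from this diff:

```go
package example

import (
	"context"
	"fmt"

	"git.eeqj.de/sneak/vaultik/internal/database"
)

// pruneAssociations clears dangling blob_chunks rows after blobs or chunks
// have been removed, before any further pruning touches the parent tables.
func pruneAssociations(ctx context.Context, repos *database.Repositories) error {
	// Assumes the BlobChunkRepository is exposed as repos.BlobChunks.
	if err := repos.BlobChunks.DeleteOrphaned(ctx); err != nil {
		return fmt.Errorf("pruning orphaned blob_chunks: %w", err)
	}
	return nil
}
```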


@ -30,7 +30,6 @@ func TestBlobChunkRepository(t *testing.T) {
for _, chunkHash := range chunks { for _, chunkHash := range chunks {
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: chunkHash, ChunkHash: chunkHash,
SHA256: chunkHash + "-sha",
Size: 1024, Size: 1024,
} }
err = repos.Chunks.Create(ctx, nil, chunk) err = repos.Chunks.Create(ctx, nil, chunk)
@ -159,7 +158,6 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
for _, chunkHash := range chunkHashes { for _, chunkHash := range chunkHashes {
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: chunkHash, ChunkHash: chunkHash,
SHA256: chunkHash + "-sha",
Size: 1024, Size: 1024,
} }
err = repos.Chunks.Create(ctx, nil, chunk) err = repos.Chunks.Create(ctx, nil, chunk)


@ -43,7 +43,6 @@ func TestCascadeDeleteDebug(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: fmt.Sprintf("cascade-chunk-%d", i), ChunkHash: fmt.Sprintf("cascade-chunk-%d", i),
SHA256: fmt.Sprintf("cascade-sha-%d", i),
Size: 1024, Size: 1024,
} }
err = repos.Chunks.Create(ctx, nil, chunk) err = repos.Chunks.Create(ctx, nil, chunk)


@ -13,6 +13,7 @@ func TestChunkFileRepository(t *testing.T) {
ctx := context.Background() ctx := context.Background()
repo := NewChunkFileRepository(db) repo := NewChunkFileRepository(db)
fileRepo := NewFileRepository(db) fileRepo := NewFileRepository(db)
chunksRepo := NewChunkRepository(db)
// Create test files first // Create test files first
testTime := time.Now().Truncate(time.Second) testTime := time.Now().Truncate(time.Second)
@ -46,6 +47,16 @@ func TestChunkFileRepository(t *testing.T) {
t.Fatalf("failed to create file2: %v", err) t.Fatalf("failed to create file2: %v", err)
} }
// Create chunk first
chunk := &Chunk{
ChunkHash: "chunk1",
Size: 1024,
}
err = chunksRepo.Create(ctx, nil, chunk)
if err != nil {
t.Fatalf("failed to create chunk: %v", err)
}
// Test Create // Test Create
cf1 := &ChunkFile{ cf1 := &ChunkFile{
ChunkHash: "chunk1", ChunkHash: "chunk1",
@ -121,6 +132,7 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {
ctx := context.Background() ctx := context.Background()
repo := NewChunkFileRepository(db) repo := NewChunkFileRepository(db)
fileRepo := NewFileRepository(db) fileRepo := NewFileRepository(db)
chunksRepo := NewChunkRepository(db)
// Create test files // Create test files
testTime := time.Now().Truncate(time.Second) testTime := time.Now().Truncate(time.Second)
@ -138,6 +150,19 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {
t.Fatalf("failed to create file3: %v", err) t.Fatalf("failed to create file3: %v", err)
} }
// Create chunks first
chunks := []string{"chunk1", "chunk2", "chunk3", "chunk4"}
for _, chunkHash := range chunks {
chunk := &Chunk{
ChunkHash: chunkHash,
Size: 1024,
}
err := chunksRepo.Create(ctx, nil, chunk)
if err != nil {
t.Fatalf("failed to create chunk %s: %v", chunkHash, err)
}
}
// Simulate a scenario where multiple files share chunks // Simulate a scenario where multiple files share chunks
// File1: chunk1, chunk2, chunk3 // File1: chunk1, chunk2, chunk3
// File2: chunk2, chunk3, chunk4 // File2: chunk2, chunk3, chunk4
@ -183,11 +208,11 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {
} }
// Test file2 chunks // Test file2 chunks
chunks, err := repo.GetByFileID(ctx, file2.ID) file2Chunks, err := repo.GetByFileID(ctx, file2.ID)
if err != nil { if err != nil {
t.Fatalf("failed to get chunks for file2: %v", err) t.Fatalf("failed to get chunks for file2: %v", err)
} }
if len(chunks) != 3 { if len(file2Chunks) != 3 {
t.Errorf("expected 3 chunks for file2, got %d", len(chunks)) t.Errorf("expected 3 chunks for file2, got %d", len(file2Chunks))
} }
} }


@ -18,16 +18,16 @@ func NewChunkRepository(db *DB) *ChunkRepository {
func (r *ChunkRepository) Create(ctx context.Context, tx *sql.Tx, chunk *Chunk) error { func (r *ChunkRepository) Create(ctx context.Context, tx *sql.Tx, chunk *Chunk) error {
query := ` query := `
INSERT INTO chunks (chunk_hash, sha256, size) INSERT INTO chunks (chunk_hash, size)
VALUES (?, ?, ?) VALUES (?, ?)
ON CONFLICT(chunk_hash) DO NOTHING ON CONFLICT(chunk_hash) DO NOTHING
` `
var err error var err error
if tx != nil { if tx != nil {
_, err = tx.ExecContext(ctx, query, chunk.ChunkHash, chunk.SHA256, chunk.Size) _, err = tx.ExecContext(ctx, query, chunk.ChunkHash, chunk.Size)
} else { } else {
_, err = r.db.ExecWithLog(ctx, query, chunk.ChunkHash, chunk.SHA256, chunk.Size) _, err = r.db.ExecWithLog(ctx, query, chunk.ChunkHash, chunk.Size)
} }
if err != nil { if err != nil {
@ -39,7 +39,7 @@ func (r *ChunkRepository) Create(ctx context.Context, tx *sql.Tx, chunk *Chunk)
func (r *ChunkRepository) GetByHash(ctx context.Context, hash string) (*Chunk, error) { func (r *ChunkRepository) GetByHash(ctx context.Context, hash string) (*Chunk, error) {
query := ` query := `
SELECT chunk_hash, sha256, size SELECT chunk_hash, size
FROM chunks FROM chunks
WHERE chunk_hash = ? WHERE chunk_hash = ?
` `
@ -48,7 +48,6 @@ func (r *ChunkRepository) GetByHash(ctx context.Context, hash string) (*Chunk, e
err := r.db.conn.QueryRowContext(ctx, query, hash).Scan( err := r.db.conn.QueryRowContext(ctx, query, hash).Scan(
&chunk.ChunkHash, &chunk.ChunkHash,
&chunk.SHA256,
&chunk.Size, &chunk.Size,
) )
@ -68,7 +67,7 @@ func (r *ChunkRepository) GetByHashes(ctx context.Context, hashes []string) ([]*
} }
query := ` query := `
SELECT chunk_hash, sha256, size SELECT chunk_hash, size
FROM chunks FROM chunks
WHERE chunk_hash IN (` WHERE chunk_hash IN (`
@ -94,7 +93,6 @@ func (r *ChunkRepository) GetByHashes(ctx context.Context, hashes []string) ([]*
err := rows.Scan( err := rows.Scan(
&chunk.ChunkHash, &chunk.ChunkHash,
&chunk.SHA256,
&chunk.Size, &chunk.Size,
) )
if err != nil { if err != nil {
@ -109,7 +107,7 @@ func (r *ChunkRepository) GetByHashes(ctx context.Context, hashes []string) ([]*
func (r *ChunkRepository) ListUnpacked(ctx context.Context, limit int) ([]*Chunk, error) { func (r *ChunkRepository) ListUnpacked(ctx context.Context, limit int) ([]*Chunk, error) {
query := ` query := `
SELECT c.chunk_hash, c.sha256, c.size SELECT c.chunk_hash, c.size
FROM chunks c FROM chunks c
LEFT JOIN blob_chunks bc ON c.chunk_hash = bc.chunk_hash LEFT JOIN blob_chunks bc ON c.chunk_hash = bc.chunk_hash
WHERE bc.chunk_hash IS NULL WHERE bc.chunk_hash IS NULL
@ -129,7 +127,6 @@ func (r *ChunkRepository) ListUnpacked(ctx context.Context, limit int) ([]*Chunk
err := rows.Scan( err := rows.Scan(
&chunk.ChunkHash, &chunk.ChunkHash,
&chunk.SHA256,
&chunk.Size, &chunk.Size,
) )
if err != nil { if err != nil {


@ -7,7 +7,7 @@ import (
func (r *ChunkRepository) List(ctx context.Context) ([]*Chunk, error) { func (r *ChunkRepository) List(ctx context.Context) ([]*Chunk, error) {
query := ` query := `
SELECT chunk_hash, sha256, size SELECT chunk_hash, size
FROM chunks FROM chunks
ORDER BY chunk_hash ORDER BY chunk_hash
` `
@ -24,7 +24,6 @@ func (r *ChunkRepository) List(ctx context.Context) ([]*Chunk, error) {
err := rows.Scan( err := rows.Scan(
&chunk.ChunkHash, &chunk.ChunkHash,
&chunk.SHA256,
&chunk.Size, &chunk.Size,
) )
if err != nil { if err != nil {


@ -15,7 +15,6 @@ func TestChunkRepository(t *testing.T) {
// Test Create // Test Create
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: "chunkhash123", ChunkHash: "chunkhash123",
SHA256: "sha256hash123",
Size: 4096, Size: 4096,
} }
@ -35,9 +34,6 @@ func TestChunkRepository(t *testing.T) {
if retrieved.ChunkHash != chunk.ChunkHash { if retrieved.ChunkHash != chunk.ChunkHash {
t.Errorf("chunk hash mismatch: got %s, want %s", retrieved.ChunkHash, chunk.ChunkHash) t.Errorf("chunk hash mismatch: got %s, want %s", retrieved.ChunkHash, chunk.ChunkHash)
} }
if retrieved.SHA256 != chunk.SHA256 {
t.Errorf("sha256 mismatch: got %s, want %s", retrieved.SHA256, chunk.SHA256)
}
if retrieved.Size != chunk.Size { if retrieved.Size != chunk.Size {
t.Errorf("size mismatch: got %d, want %d", retrieved.Size, chunk.Size) t.Errorf("size mismatch: got %d, want %d", retrieved.Size, chunk.Size)
} }
@ -51,7 +47,6 @@ func TestChunkRepository(t *testing.T) {
// Test GetByHashes // Test GetByHashes
chunk2 := &Chunk{ chunk2 := &Chunk{
ChunkHash: "chunkhash456", ChunkHash: "chunkhash456",
SHA256: "sha256hash456",
Size: 8192, Size: 8192,
} }
err = repo.Create(ctx, nil, chunk2) err = repo.Create(ctx, nil, chunk2)


@ -75,8 +75,8 @@ func TestDatabaseConcurrentAccess(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
go func(i int) { go func(i int) {
_, err := db.ExecWithLog(ctx, "INSERT INTO chunks (chunk_hash, sha256, size) VALUES (?, ?, ?)", _, err := db.ExecWithLog(ctx, "INSERT INTO chunks (chunk_hash, size) VALUES (?, ?)",
fmt.Sprintf("hash%d", i), fmt.Sprintf("sha%d", i), i*1024) fmt.Sprintf("hash%d", i), i*1024)
results <- result{index: i, err: err} results <- result{index: i, err: err}
}(i) }(i)
} }


@ -32,6 +32,20 @@ func TestFileChunkRepository(t *testing.T) {
t.Fatalf("failed to create file: %v", err) t.Fatalf("failed to create file: %v", err)
} }
// Create chunks first
chunks := []string{"chunk1", "chunk2", "chunk3"}
chunkRepo := NewChunkRepository(db)
for _, chunkHash := range chunks {
chunk := &Chunk{
ChunkHash: chunkHash,
Size: 1024,
}
err = chunkRepo.Create(ctx, nil, chunk)
if err != nil {
t.Fatalf("failed to create chunk %s: %v", chunkHash, err)
}
}
// Test Create // Test Create
fc1 := &FileChunk{ fc1 := &FileChunk{
FileID: file.ID, FileID: file.ID,
@ -66,16 +80,16 @@ func TestFileChunkRepository(t *testing.T) {
} }
// Test GetByFile // Test GetByFile
chunks, err := repo.GetByFile(ctx, "/test/file.txt") fileChunks, err := repo.GetByFile(ctx, "/test/file.txt")
if err != nil { if err != nil {
t.Fatalf("failed to get file chunks: %v", err) t.Fatalf("failed to get file chunks: %v", err)
} }
if len(chunks) != 3 { if len(fileChunks) != 3 {
t.Errorf("expected 3 chunks, got %d", len(chunks)) t.Errorf("expected 3 chunks, got %d", len(fileChunks))
} }
// Verify order // Verify order
for i, chunk := range chunks { for i, chunk := range fileChunks {
if chunk.Idx != i { if chunk.Idx != i {
t.Errorf("wrong chunk order: expected idx %d, got %d", i, chunk.Idx) t.Errorf("wrong chunk order: expected idx %d, got %d", i, chunk.Idx)
} }
@ -93,12 +107,12 @@ func TestFileChunkRepository(t *testing.T) {
t.Fatalf("failed to delete file chunks: %v", err) t.Fatalf("failed to delete file chunks: %v", err)
} }
chunks, err = repo.GetByFileID(ctx, file.ID) fileChunks, err = repo.GetByFileID(ctx, file.ID)
if err != nil { if err != nil {
t.Fatalf("failed to get deleted file chunks: %v", err) t.Fatalf("failed to get deleted file chunks: %v", err)
} }
if len(chunks) != 0 { if len(fileChunks) != 0 {
t.Errorf("expected 0 chunks after delete, got %d", len(chunks)) t.Errorf("expected 0 chunks after delete, got %d", len(fileChunks))
} }
} }
@ -133,6 +147,22 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) {
files[i] = file files[i] = file
} }
// Create all chunks first
chunkRepo := NewChunkRepository(db)
for i := range files {
for j := 0; j < 2; j++ {
chunkHash := fmt.Sprintf("file%d_chunk%d", i, j)
chunk := &Chunk{
ChunkHash: chunkHash,
Size: 1024,
}
err := chunkRepo.Create(ctx, nil, chunk)
if err != nil {
t.Fatalf("failed to create chunk %s: %v", chunkHash, err)
}
}
}
// Create chunks for multiple files // Create chunks for multiple files
for i, file := range files { for i, file := range files {
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {


@ -28,7 +28,6 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err
INSERT INTO files (id, path, mtime, ctime, size, mode, uid, gid, link_target) INSERT INTO files (id, path, mtime, ctime, size, mode, uid, gid, link_target)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(path) DO UPDATE SET ON CONFLICT(path) DO UPDATE SET
id = excluded.id,
mtime = excluded.mtime, mtime = excluded.mtime,
ctime = excluded.ctime, ctime = excluded.ctime,
size = excluded.size, size = excluded.size,


@ -37,11 +37,9 @@ type FileChunk struct {
// Chunk represents a data chunk in the deduplication system. // Chunk represents a data chunk in the deduplication system.
// Files are split into chunks which are content-addressed by their hash. // Files are split into chunks which are content-addressed by their hash.
// The ChunkHash is used for deduplication, while SHA256 provides // The ChunkHash is the SHA256 hash of the chunk content, used for deduplication.
// an additional verification hash.
type Chunk struct { type Chunk struct {
ChunkHash string ChunkHash string
SHA256 string
Size int64 Size int64
} }


@ -34,7 +34,6 @@ func TestRepositoriesTransaction(t *testing.T) {
// Create chunks // Create chunks
chunk1 := &Chunk{ chunk1 := &Chunk{
ChunkHash: "tx_chunk1", ChunkHash: "tx_chunk1",
SHA256: "tx_sha1",
Size: 512, Size: 512,
} }
if err := repos.Chunks.Create(ctx, tx, chunk1); err != nil { if err := repos.Chunks.Create(ctx, tx, chunk1); err != nil {
@ -43,7 +42,6 @@ func TestRepositoriesTransaction(t *testing.T) {
chunk2 := &Chunk{ chunk2 := &Chunk{
ChunkHash: "tx_chunk2", ChunkHash: "tx_chunk2",
SHA256: "tx_sha2",
Size: 512, Size: 512,
} }
if err := repos.Chunks.Create(ctx, tx, chunk2); err != nil { if err := repos.Chunks.Create(ctx, tx, chunk2); err != nil {
@ -159,7 +157,6 @@ func TestRepositoriesTransactionRollback(t *testing.T) {
// Create a chunk // Create a chunk
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: "rollback_chunk", ChunkHash: "rollback_chunk",
SHA256: "rollback_sha",
Size: 1024, Size: 1024,
} }
if err := repos.Chunks.Create(ctx, tx, chunk); err != nil { if err := repos.Chunks.Create(ctx, tx, chunk); err != nil {


@ -195,12 +195,10 @@ func TestOrphanedChunkCleanup(t *testing.T) {
// Create chunks // Create chunks
chunk1 := &Chunk{ chunk1 := &Chunk{
ChunkHash: "orphaned-chunk", ChunkHash: "orphaned-chunk",
SHA256: "orphaned-chunk-sha",
Size: 1024, Size: 1024,
} }
chunk2 := &Chunk{ chunk2 := &Chunk{
ChunkHash: "referenced-chunk", ChunkHash: "referenced-chunk",
SHA256: "referenced-chunk-sha",
Size: 1024, Size: 1024,
} }
@ -363,7 +361,6 @@ func TestFileChunkRepositoryWithUUIDs(t *testing.T) {
for i, chunkHash := range chunks { for i, chunkHash := range chunks {
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: chunkHash, ChunkHash: chunkHash,
SHA256: fmt.Sprintf("sha-%s", chunkHash),
Size: 1024, Size: 1024,
} }
err = repos.Chunks.Create(ctx, nil, chunk) err = repos.Chunks.Create(ctx, nil, chunk)
@ -447,7 +444,6 @@ func TestChunkFileRepositoryWithUUIDs(t *testing.T) {
// Create a chunk that appears in both files (deduplication) // Create a chunk that appears in both files (deduplication)
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: "shared-chunk", ChunkHash: "shared-chunk",
SHA256: "shared-chunk-sha",
Size: 1024, Size: 1024,
} }
err = repos.Chunks.Create(ctx, nil, chunk) err = repos.Chunks.Create(ctx, nil, chunk)
@ -694,7 +690,6 @@ func TestCascadeDelete(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: fmt.Sprintf("cascade-chunk-%d", i), ChunkHash: fmt.Sprintf("cascade-chunk-%d", i),
SHA256: fmt.Sprintf("cascade-sha-%d", i),
Size: 1024, Size: 1024,
} }
err = repos.Chunks.Create(ctx, nil, chunk) err = repos.Chunks.Create(ctx, nil, chunk)


@ -170,7 +170,6 @@ func TestDuplicateHandling(t *testing.T) {
t.Run("duplicate chunk hashes", func(t *testing.T) { t.Run("duplicate chunk hashes", func(t *testing.T) {
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: "duplicate-chunk", ChunkHash: "duplicate-chunk",
SHA256: "duplicate-sha",
Size: 1024, Size: 1024,
} }
@ -204,7 +203,6 @@ func TestDuplicateHandling(t *testing.T) {
chunk := &Chunk{ chunk := &Chunk{
ChunkHash: "test-chunk-dup", ChunkHash: "test-chunk-dup",
SHA256: "test-sha-dup",
Size: 1024, Size: 1024,
} }
err = repos.Chunks.Create(ctx, nil, chunk) err = repos.Chunks.Create(ctx, nil, chunk)


@ -24,13 +24,13 @@ CREATE TABLE IF NOT EXISTS file_chunks (
idx INTEGER NOT NULL, idx INTEGER NOT NULL,
chunk_hash TEXT NOT NULL, chunk_hash TEXT NOT NULL,
PRIMARY KEY (file_id, idx), PRIMARY KEY (file_id, idx),
FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE,
FOREIGN KEY (chunk_hash) REFERENCES chunks(chunk_hash)
); );
-- Chunks table: stores unique content-defined chunks -- Chunks table: stores unique content-defined chunks
CREATE TABLE IF NOT EXISTS chunks ( CREATE TABLE IF NOT EXISTS chunks (
chunk_hash TEXT PRIMARY KEY, chunk_hash TEXT PRIMARY KEY,
sha256 TEXT NOT NULL,
size INTEGER NOT NULL size INTEGER NOT NULL
); );
@ -52,7 +52,8 @@ CREATE TABLE IF NOT EXISTS blob_chunks (
offset INTEGER NOT NULL, offset INTEGER NOT NULL,
length INTEGER NOT NULL, length INTEGER NOT NULL,
PRIMARY KEY (blob_id, chunk_hash), PRIMARY KEY (blob_id, chunk_hash),
FOREIGN KEY (blob_id) REFERENCES blobs(id) FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE,
FOREIGN KEY (chunk_hash) REFERENCES chunks(chunk_hash)
); );
-- Chunk files table: reverse mapping of chunks to files -- Chunk files table: reverse mapping of chunks to files
@ -62,6 +63,7 @@ CREATE TABLE IF NOT EXISTS chunk_files (
file_offset INTEGER NOT NULL, file_offset INTEGER NOT NULL,
length INTEGER NOT NULL, length INTEGER NOT NULL,
PRIMARY KEY (chunk_hash, file_id), PRIMARY KEY (chunk_hash, file_id),
FOREIGN KEY (chunk_hash) REFERENCES chunks(chunk_hash),
FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
); );
@ -91,7 +93,7 @@ CREATE TABLE IF NOT EXISTS snapshot_files (
file_id TEXT NOT NULL, file_id TEXT NOT NULL,
PRIMARY KEY (snapshot_id, file_id), PRIMARY KEY (snapshot_id, file_id),
FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE, FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE FOREIGN KEY (file_id) REFERENCES files(id)
); );
-- Snapshot blobs table: maps snapshots to blobs -- Snapshot blobs table: maps snapshots to blobs
@ -101,13 +103,16 @@ CREATE TABLE IF NOT EXISTS snapshot_blobs (
blob_hash TEXT NOT NULL, blob_hash TEXT NOT NULL,
PRIMARY KEY (snapshot_id, blob_id), PRIMARY KEY (snapshot_id, blob_id),
FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE, FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE,
FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE FOREIGN KEY (blob_id) REFERENCES blobs(id)
); );
-- Uploads table: tracks blob upload metrics -- Uploads table: tracks blob upload metrics
CREATE TABLE IF NOT EXISTS uploads ( CREATE TABLE IF NOT EXISTS uploads (
blob_hash TEXT PRIMARY KEY, blob_hash TEXT PRIMARY KEY,
snapshot_id TEXT NOT NULL,
uploaded_at INTEGER NOT NULL, uploaded_at INTEGER NOT NULL,
size INTEGER NOT NULL, size INTEGER NOT NULL,
duration_ms INTEGER NOT NULL duration_ms INTEGER NOT NULL,
FOREIGN KEY (blob_hash) REFERENCES blobs(blob_hash),
FOREIGN KEY (snapshot_id) REFERENCES snapshots(id)
); );
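Given the tightened foreign keys, one deletion order that satisfies them is to clear every referencing table before the rows it points at. A sketch in plain database/sql, not the application's actual cleanup routine:

```go
package example

import (
	"context"
	"database/sql"
	"fmt"
)

// deleteAll empties the schema in an order consistent with the foreign keys
// above: association and tracking tables first, then blobs, chunks, files,
// and snapshots.
func deleteAll(ctx context.Context, db *sql.DB) error {
	stmts := []string{
		`DELETE FROM uploads`,        // references blobs, snapshots
		`DELETE FROM snapshot_blobs`, // references snapshots, blobs
		`DELETE FROM snapshot_files`, // references snapshots, files
		`DELETE FROM blob_chunks`,    // references blobs, chunks
		`DELETE FROM file_chunks`,    // references files, chunks
		`DELETE FROM chunk_files`,    // references chunks, files
		`DELETE FROM blobs`,
		`DELETE FROM chunks`,
		`DELETE FROM files`,
		`DELETE FROM snapshots`,
	}
	for _, stmt := range stmts {
		if _, err := db.ExecContext(ctx, stmt); err != nil {
			return fmt.Errorf("%s: %w", stmt, err)
		}
	}
	return nil
}
```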


@ -11,6 +11,7 @@ import (
// Upload represents a blob upload record // Upload represents a blob upload record
type Upload struct { type Upload struct {
BlobHash string BlobHash string
SnapshotID string
UploadedAt time.Time UploadedAt time.Time
Size int64 Size int64
DurationMs int64 DurationMs int64
@ -29,15 +30,15 @@ func NewUploadRepository(conn *sql.DB) *UploadRepository {
// Create inserts a new upload record // Create inserts a new upload record
func (r *UploadRepository) Create(ctx context.Context, tx *sql.Tx, upload *Upload) error { func (r *UploadRepository) Create(ctx context.Context, tx *sql.Tx, upload *Upload) error {
query := ` query := `
INSERT INTO uploads (blob_hash, uploaded_at, size, duration_ms) INSERT INTO uploads (blob_hash, snapshot_id, uploaded_at, size, duration_ms)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
` `
var err error var err error
if tx != nil { if tx != nil {
_, err = tx.ExecContext(ctx, query, upload.BlobHash, upload.UploadedAt, upload.Size, upload.DurationMs) _, err = tx.ExecContext(ctx, query, upload.BlobHash, upload.SnapshotID, upload.UploadedAt, upload.Size, upload.DurationMs)
} else { } else {
_, err = r.conn.ExecContext(ctx, query, upload.BlobHash, upload.UploadedAt, upload.Size, upload.DurationMs) _, err = r.conn.ExecContext(ctx, query, upload.BlobHash, upload.SnapshotID, upload.UploadedAt, upload.Size, upload.DurationMs)
} }
return err return err
@ -133,3 +134,14 @@ type UploadStats struct {
MinDurationMs int64 MinDurationMs int64
MaxDurationMs int64 MaxDurationMs int64
} }
// GetCountBySnapshot returns the count of uploads for a specific snapshot
func (r *UploadRepository) GetCountBySnapshot(ctx context.Context, snapshotID string) (int64, error) {
query := `SELECT COUNT(*) FROM uploads WHERE snapshot_id = ?`
var count int64
err := r.conn.QueryRowContext(ctx, query, snapshotID).Scan(&count)
if err != nil {
return 0, err
}
return count, nil
}
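A usage sketch tying the new snapshot_id column to per-snapshot blob counting; the package qualifier and variable names are assumptions for illustration:

```go
package example

import (
	"context"
	"fmt"
	"time"

	"git.eeqj.de/sneak/vaultik/internal/database"
)

// recordAndCount stores one upload record for the current snapshot and then
// reports how many blobs that snapshot has uploaded so far.
func recordAndCount(ctx context.Context, uploads *database.UploadRepository,
	snapshotID, blobHash string, size int64, took time.Duration) error {
	u := &database.Upload{
		BlobHash:   blobHash,
		SnapshotID: snapshotID,
		UploadedAt: time.Now().UTC(),
		Size:       size,
		DurationMs: took.Milliseconds(),
	}
	if err := uploads.Create(ctx, nil, u); err != nil {
		return fmt.Errorf("recording upload: %w", err)
	}

	count, err := uploads.GetCountBySnapshot(ctx, snapshotID)
	if err != nil {
		return fmt.Errorf("counting uploads: %w", err)
	}
	fmt.Printf("snapshot %s: %d new blobs uploaded\n", snapshotID, count)
	return nil
}
```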