2 Commits

Author SHA1 Message Date
3e282af516 Merge branch 'main' into fix/sql-injection-whitelist 2026-02-20 11:16:27 +01:00
user
bb4b9b5bc9 fix: use whitelist for SQL table names in getTableCount (closes #7)
Replace regex-based validation with a strict whitelist of allowed table
names (files, chunks, blobs). The whitelist check now runs before the
nil-DB early return so invalid names are always rejected.

Removes unused regexp import.
2026-02-20 02:09:40 -08:00
16 changed files with 1247 additions and 1601 deletions

View File

@@ -1,8 +0,0 @@
.git
.gitea
*.md
LICENSE
vaultik
coverage.out
coverage.html
.DS_Store

View File

@@ -1,14 +0,0 @@
name: check
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4, 2024-09-16
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5
- name: Build and check
run: docker build .

View File

@@ -1,61 +0,0 @@
# Lint stage
# golangci/golangci-lint:v2.11.3-alpine, 2026-03-17
FROM golangci/golangci-lint:v2.11.3-alpine@sha256:b1c3de5862ad0a95b4e45a993b0f00415835d687e4f12c845c7493b86c13414e AS lint
RUN apk add --no-cache make build-base
WORKDIR /src
# Copy go mod files first for better layer caching
COPY go.mod go.sum ./
RUN go mod download
# Copy source code
COPY . .
# Run formatting check and linter
RUN make fmt-check
RUN make lint
# Build stage
# golang:1.26.1-alpine, 2026-03-17
FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder
# Depend on lint stage passing
COPY --from=lint /src/go.sum /dev/null
ARG VERSION=dev
# Install build dependencies for CGO (mattn/go-sqlite3) and sqlite3 CLI (tests)
RUN apk add --no-cache make build-base sqlite
WORKDIR /src
# Copy go mod files first for better layer caching
COPY go.mod go.sum ./
RUN go mod download
# Copy source code
COPY . .
# Run tests
RUN make test
# Build with CGO enabled (required for mattn/go-sqlite3)
RUN CGO_ENABLED=1 go build -ldflags "-X 'git.eeqj.de/sneak/vaultik/internal/globals.Version=${VERSION}' -X 'git.eeqj.de/sneak/vaultik/internal/globals.Commit=$(git rev-parse HEAD 2>/dev/null || echo unknown)'" -o /vaultik ./cmd/vaultik
# Runtime stage
# alpine:3.21, 2026-02-25
FROM alpine:3.21@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk add --no-cache ca-certificates sqlite
# Copy binary from builder
COPY --from=builder /vaultik /usr/local/bin/vaultik
# Create non-root user
RUN adduser -D -H -s /sbin/nologin vaultik
USER vaultik
ENTRYPOINT ["/usr/local/bin/vaultik"]

View File

@@ -1,4 +1,4 @@
.PHONY: test fmt lint fmt-check check build clean all docker hooks .PHONY: test fmt lint build clean all
# Version number # Version number
VERSION := 0.0.1 VERSION := 0.0.1
@@ -14,12 +14,21 @@ LDFLAGS := -X 'git.eeqj.de/sneak/vaultik/internal/globals.Version=$(VERSION)' \
all: vaultik all: vaultik
# Run tests # Run tests
test: test: lint fmt-check
go test -race -timeout 30s ./... @echo "Running tests..."
@if ! go test -v -timeout 10s ./... 2>&1; then \
echo ""; \
echo "TEST FAILURES DETECTED"; \
echo "Run 'go test -v ./internal/database' to see database test details"; \
exit 1; \
fi
# Check if code is formatted (read-only) # Check if code is formatted
fmt-check: fmt-check:
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1) @if [ -n "$$(go fmt ./...)" ]; then \
echo "Error: Code is not formatted. Run 'make fmt' to fix."; \
exit 1; \
fi
# Format code # Format code
fmt: fmt:
@@ -27,7 +36,7 @@ fmt:
# Run linter # Run linter
lint: lint:
golangci-lint run ./... golangci-lint run
# Build binary # Build binary
vaultik: internal/*/*.go cmd/vaultik/*.go vaultik: internal/*/*.go cmd/vaultik/*.go
@@ -38,6 +47,11 @@ clean:
rm -f vaultik rm -f vaultik
go clean go clean
# Install dependencies
deps:
go mod download
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Run tests with coverage # Run tests with coverage
test-coverage: test-coverage:
go test -v -coverprofile=coverage.out ./... go test -v -coverprofile=coverage.out ./...
@@ -53,17 +67,3 @@ local:
install: vaultik install: vaultik
cp ./vaultik $(HOME)/bin/ cp ./vaultik $(HOME)/bin/
# Run all checks (formatting, linting, tests) without modifying files
check: fmt-check lint test
# Build Docker image
docker:
docker build -t vaultik .
# Install pre-commit hook
hooks:
@printf '#!/bin/sh\nset -e\n' > .git/hooks/pre-commit
@printf 'go mod tidy\ngo fmt ./...\ngit diff --exit-code -- go.mod go.sum || { echo "go mod tidy changed files; please stage and retry"; exit 1; }\n' >> .git/hooks/pre-commit
@printf 'make check\n' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit

2
go.mod
View File

@@ -1,6 +1,6 @@
module git.eeqj.de/sneak/vaultik module git.eeqj.de/sneak/vaultik
go 1.26.1 go 1.24.4
require ( require (
filippo.io/age v1.2.1 filippo.io/age v1.2.1

View File

@@ -361,23 +361,101 @@ func (p *Packer) finalizeCurrentBlob() error {
return nil return nil
} }
blobHash, finalSize, err := p.closeBlobWriter() // Close blobgen writer to flush all data
if err := p.currentBlob.writer.Close(); err != nil {
p.cleanupTempFile()
return fmt.Errorf("closing blobgen writer: %w", err)
}
// Sync file to ensure all data is written
if err := p.currentBlob.tempFile.Sync(); err != nil {
p.cleanupTempFile()
return fmt.Errorf("syncing temp file: %w", err)
}
// Get the final size (encrypted if applicable)
finalSize, err := p.currentBlob.tempFile.Seek(0, io.SeekCurrent)
if err != nil { if err != nil {
return err p.cleanupTempFile()
return fmt.Errorf("getting file size: %w", err)
} }
chunkRefs := p.buildChunkRefs() // Reset to beginning for reading
if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil {
p.cleanupTempFile()
return fmt.Errorf("seeking to start: %w", err)
}
// Get hash from blobgen writer (of final encrypted data)
finalHash := p.currentBlob.writer.Sum256()
blobHash := hex.EncodeToString(finalHash)
// Create chunk references with offsets
chunkRefs := make([]*BlobChunkRef, 0, len(p.currentBlob.chunks))
for _, chunk := range p.currentBlob.chunks {
chunkRefs = append(chunkRefs, &BlobChunkRef{
ChunkHash: chunk.Hash,
Offset: chunk.Offset,
Length: chunk.Size,
})
}
// Get pending chunks (will be inserted to DB and reported to handler)
chunksToInsert := p.pendingChunks chunksToInsert := p.pendingChunks
p.pendingChunks = nil p.pendingChunks = nil // Clear pending list
if err := p.commitBlobToDatabase(blobHash, finalSize, chunksToInsert); err != nil { // Insert pending chunks, blob_chunks, and update blob in a single transaction
return err if p.repos != nil {
blobIDTyped, parseErr := types.ParseBlobID(p.currentBlob.id)
if parseErr != nil {
p.cleanupTempFile()
return fmt.Errorf("parsing blob ID: %w", parseErr)
}
err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
// First insert all pending chunks (required for blob_chunks FK)
for _, chunk := range chunksToInsert {
dbChunk := &database.Chunk{
ChunkHash: types.ChunkHash(chunk.Hash),
Size: chunk.Size,
}
if err := p.repos.Chunks.Create(ctx, tx, dbChunk); err != nil {
return fmt.Errorf("creating chunk: %w", err)
}
} }
// Insert all blob_chunk records in batch
for _, chunk := range p.currentBlob.chunks {
blobChunk := &database.BlobChunk{
BlobID: blobIDTyped,
ChunkHash: types.ChunkHash(chunk.Hash),
Offset: chunk.Offset,
Length: chunk.Size,
}
if err := p.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil {
return fmt.Errorf("creating blob_chunk: %w", err)
}
}
// Update blob record with final hash and sizes
return p.repos.Blobs.UpdateFinished(ctx, tx, p.currentBlob.id, blobHash,
p.currentBlob.size, finalSize)
})
if err != nil {
p.cleanupTempFile()
return fmt.Errorf("finalizing blob transaction: %w", err)
}
log.Debug("Committed blob transaction",
"chunks_inserted", len(chunksToInsert),
"blob_chunks_inserted", len(p.currentBlob.chunks))
}
// Create finished blob
finished := &FinishedBlob{ finished := &FinishedBlob{
ID: p.currentBlob.id, ID: p.currentBlob.id,
Hash: blobHash, Hash: blobHash,
Data: nil, // We don't load data into memory anymore
Chunks: chunkRefs, Chunks: chunkRefs,
CreatedTS: p.currentBlob.startTime, CreatedTS: p.currentBlob.startTime,
Uncompressed: p.currentBlob.size, Uncompressed: p.currentBlob.size,
@@ -386,105 +464,28 @@ func (p *Packer) finalizeCurrentBlob() error {
compressionRatio := float64(finished.Compressed) / float64(finished.Uncompressed) compressionRatio := float64(finished.Compressed) / float64(finished.Uncompressed)
log.Info("Finalized blob (compressed and encrypted)", log.Info("Finalized blob (compressed and encrypted)",
"hash", blobHash, "chunks", len(chunkRefs), "hash", blobHash,
"uncompressed", finished.Uncompressed, "compressed", finished.Compressed, "chunks", len(chunkRefs),
"uncompressed", finished.Uncompressed,
"compressed", finished.Compressed,
"ratio", fmt.Sprintf("%.2f", compressionRatio), "ratio", fmt.Sprintf("%.2f", compressionRatio),
"duration", time.Since(p.currentBlob.startTime)) "duration", time.Since(p.currentBlob.startTime))
// Collect inserted chunk hashes for the scanner to track
var insertedChunkHashes []string var insertedChunkHashes []string
for _, chunk := range chunksToInsert { for _, chunk := range chunksToInsert {
insertedChunkHashes = append(insertedChunkHashes, chunk.Hash) insertedChunkHashes = append(insertedChunkHashes, chunk.Hash)
} }
return p.deliverFinishedBlob(finished, insertedChunkHashes) // Call blob handler if set
}
// closeBlobWriter closes the writer, syncs to disk, and returns the blob hash and final size
func (p *Packer) closeBlobWriter() (string, int64, error) {
if err := p.currentBlob.writer.Close(); err != nil {
p.cleanupTempFile()
return "", 0, fmt.Errorf("closing blobgen writer: %w", err)
}
if err := p.currentBlob.tempFile.Sync(); err != nil {
p.cleanupTempFile()
return "", 0, fmt.Errorf("syncing temp file: %w", err)
}
finalSize, err := p.currentBlob.tempFile.Seek(0, io.SeekCurrent)
if err != nil {
p.cleanupTempFile()
return "", 0, fmt.Errorf("getting file size: %w", err)
}
if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil {
p.cleanupTempFile()
return "", 0, fmt.Errorf("seeking to start: %w", err)
}
finalHash := p.currentBlob.writer.Sum256()
return hex.EncodeToString(finalHash), finalSize, nil
}
// buildChunkRefs creates BlobChunkRef entries from the current blob's chunks
func (p *Packer) buildChunkRefs() []*BlobChunkRef {
refs := make([]*BlobChunkRef, 0, len(p.currentBlob.chunks))
for _, chunk := range p.currentBlob.chunks {
refs = append(refs, &BlobChunkRef{
ChunkHash: chunk.Hash, Offset: chunk.Offset, Length: chunk.Size,
})
}
return refs
}
// commitBlobToDatabase inserts pending chunks, blob_chunks, and updates the blob record
func (p *Packer) commitBlobToDatabase(blobHash string, finalSize int64, chunksToInsert []PendingChunk) error {
if p.repos == nil {
return nil
}
blobIDTyped, parseErr := types.ParseBlobID(p.currentBlob.id)
if parseErr != nil {
p.cleanupTempFile()
return fmt.Errorf("parsing blob ID: %w", parseErr)
}
err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
for _, chunk := range chunksToInsert {
dbChunk := &database.Chunk{ChunkHash: types.ChunkHash(chunk.Hash), Size: chunk.Size}
if err := p.repos.Chunks.Create(ctx, tx, dbChunk); err != nil {
return fmt.Errorf("creating chunk: %w", err)
}
}
for _, chunk := range p.currentBlob.chunks {
blobChunk := &database.BlobChunk{
BlobID: blobIDTyped, ChunkHash: types.ChunkHash(chunk.Hash),
Offset: chunk.Offset, Length: chunk.Size,
}
if err := p.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil {
return fmt.Errorf("creating blob_chunk: %w", err)
}
}
return p.repos.Blobs.UpdateFinished(ctx, tx, p.currentBlob.id, blobHash, p.currentBlob.size, finalSize)
})
if err != nil {
p.cleanupTempFile()
return fmt.Errorf("finalizing blob transaction: %w", err)
}
log.Debug("Committed blob transaction",
"chunks_inserted", len(chunksToInsert), "blob_chunks_inserted", len(p.currentBlob.chunks))
return nil
}
// deliverFinishedBlob passes the blob to the handler or stores it internally
func (p *Packer) deliverFinishedBlob(finished *FinishedBlob, insertedChunkHashes []string) error {
if p.blobHandler != nil { if p.blobHandler != nil {
// Reset file position for handler
if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil {
p.cleanupTempFile() p.cleanupTempFile()
return fmt.Errorf("seeking for handler: %w", err) return fmt.Errorf("seeking for handler: %w", err)
} }
// Create a blob reader that includes the data stream
blobWithReader := &BlobWithReader{ blobWithReader := &BlobWithReader{
FinishedBlob: finished, FinishedBlob: finished,
Reader: p.currentBlob.tempFile, Reader: p.currentBlob.tempFile,
@@ -496,12 +497,11 @@ func (p *Packer) deliverFinishedBlob(finished *FinishedBlob, insertedChunkHashes
p.cleanupTempFile() p.cleanupTempFile()
return fmt.Errorf("blob handler failed: %w", err) return fmt.Errorf("blob handler failed: %w", err)
} }
// Note: blob handler is responsible for closing/cleaning up temp file
p.currentBlob = nil p.currentBlob = nil
return nil } else {
} log.Debug("No blob handler callback configured", "blob_hash", blobHash[:8]+"...")
// No handler, need to read data for legacy behavior
// No handler - read data for legacy behavior
log.Debug("No blob handler callback configured", "blob_hash", finished.Hash[:8]+"...")
if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil {
p.cleanupTempFile() p.cleanupTempFile()
return fmt.Errorf("seeking to read data: %w", err) return fmt.Errorf("seeking to read data: %w", err)
@@ -513,9 +513,14 @@ func (p *Packer) deliverFinishedBlob(finished *FinishedBlob, insertedChunkHashes
return fmt.Errorf("reading blob data: %w", err) return fmt.Errorf("reading blob data: %w", err)
} }
finished.Data = data finished.Data = data
p.finishedBlobs = append(p.finishedBlobs, finished) p.finishedBlobs = append(p.finishedBlobs, finished)
// Cleanup
p.cleanupTempFile() p.cleanupTempFile()
p.currentBlob = nil p.currentBlob = nil
}
return nil return nil
} }

View File

@@ -57,17 +57,6 @@ Examples:
vaultik restore --verify myhost_docs_2025-01-01T12:00:00Z /restore`, vaultik restore --verify myhost_docs_2025-01-01T12:00:00Z /restore`,
Args: cobra.MinimumNArgs(2), Args: cobra.MinimumNArgs(2),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return runRestore(cmd, args, opts)
},
}
cmd.Flags().BoolVar(&opts.Verify, "verify", false, "Verify restored files by checking chunk hashes")
return cmd
}
// runRestore parses arguments and runs the restore operation through the app framework
func runRestore(cmd *cobra.Command, args []string, opts *RestoreOptions) error {
snapshotID := args[0] snapshotID := args[0]
opts.TargetDir = args[1] opts.TargetDir = args[1]
if len(args) > 2 { if len(args) > 2 {
@@ -89,14 +78,7 @@ func runRestore(cmd *cobra.Command, args []string, opts *RestoreOptions) error {
Debug: rootFlags.Debug, Debug: rootFlags.Debug,
Quiet: rootFlags.Quiet, Quiet: rootFlags.Quiet,
}, },
Modules: buildRestoreModules(), Modules: []fx.Option{
Invokes: buildRestoreInvokes(snapshotID, opts),
})
}
// buildRestoreModules returns the fx.Options for dependency injection in restore
func buildRestoreModules() []fx.Option {
return []fx.Option{
fx.Provide(fx.Annotate( fx.Provide(fx.Annotate(
func(g *globals.Globals, cfg *config.Config, func(g *globals.Globals, cfg *config.Config,
storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp { storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp {
@@ -109,12 +91,8 @@ func buildRestoreModules() []fx.Option {
} }
}, },
)), )),
} },
} Invokes: []fx.Option{
// buildRestoreInvokes returns the fx.Options that wire up the restore lifecycle
func buildRestoreInvokes(snapshotID string, opts *RestoreOptions) []fx.Option {
return []fx.Option{
fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) { fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) {
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error { OnStart: func(ctx context.Context) error {
@@ -147,5 +125,12 @@ func buildRestoreInvokes(snapshotID string, opts *RestoreOptions) []fx.Option {
}, },
}) })
}), }),
},
})
},
} }
cmd.Flags().BoolVar(&opts.Verify, "verify", false, "Verify restored files by checking chunk hashes")
return cmd
} }

View File

@@ -180,10 +180,18 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
} }
// Phase 0: Load known files and chunks from database into memory for fast lookup // Phase 0: Load known files and chunks from database into memory for fast lookup
knownFiles, err := s.loadDatabaseState(ctx, path) fmt.Println("Loading known files from database...")
knownFiles, err := s.loadKnownFiles(ctx, path)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("loading known files: %w", err)
} }
fmt.Printf("Loaded %s known files from database\n", formatNumber(len(knownFiles)))
fmt.Println("Loading known chunks from database...")
if err := s.loadKnownChunks(ctx); err != nil {
return nil, fmt.Errorf("loading known chunks: %w", err)
}
fmt.Printf("Loaded %s known chunks from database\n", formatNumber(len(s.knownChunks)))
// Phase 1: Scan directory, collect files to process, and track existing files // Phase 1: Scan directory, collect files to process, and track existing files
// (builds existingFiles map during walk to avoid double traversal) // (builds existingFiles map during walk to avoid double traversal)
@@ -208,8 +216,36 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
} }
} }
// Summarize scan phase results and update progress // Calculate total size to process
s.summarizeScanPhase(result, filesToProcess) var totalSizeToProcess int64
for _, file := range filesToProcess {
totalSizeToProcess += file.FileInfo.Size()
}
// Update progress with total size and file count
if s.progress != nil {
s.progress.SetTotalSize(totalSizeToProcess)
s.progress.GetStats().TotalFiles.Store(int64(len(filesToProcess)))
}
log.Info("Phase 1 complete",
"total_files", len(filesToProcess),
"total_size", humanize.Bytes(uint64(totalSizeToProcess)),
"files_skipped", result.FilesSkipped,
"bytes_skipped", humanize.Bytes(uint64(result.BytesSkipped)))
// Print scan summary
fmt.Printf("Scan complete: %s examined (%s), %s to process (%s)",
formatNumber(result.FilesScanned),
humanize.Bytes(uint64(totalSizeToProcess+result.BytesSkipped)),
formatNumber(len(filesToProcess)),
humanize.Bytes(uint64(totalSizeToProcess)))
if result.FilesDeleted > 0 {
fmt.Printf(", %s deleted (%s)",
formatNumber(result.FilesDeleted),
humanize.Bytes(uint64(result.BytesDeleted)))
}
fmt.Println()
// Phase 2: Process files and create chunks // Phase 2: Process files and create chunks
if len(filesToProcess) > 0 { if len(filesToProcess) > 0 {
@@ -223,66 +259,7 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
log.Info("Phase 2/3: Skipping (no files need processing, metadata-only snapshot)") log.Info("Phase 2/3: Skipping (no files need processing, metadata-only snapshot)")
} }
// Finalize result with blob statistics // Get final stats from packer
s.finalizeScanResult(ctx, result)
return result, nil
}
// loadDatabaseState loads known files and chunks from the database into memory for fast lookup
// This avoids per-file and per-chunk database queries during the scan and process phases
func (s *Scanner) loadDatabaseState(ctx context.Context, path string) (map[string]*database.File, error) {
fmt.Println("Loading known files from database...")
knownFiles, err := s.loadKnownFiles(ctx, path)
if err != nil {
return nil, fmt.Errorf("loading known files: %w", err)
}
fmt.Printf("Loaded %s known files from database\n", formatNumber(len(knownFiles)))
fmt.Println("Loading known chunks from database...")
if err := s.loadKnownChunks(ctx); err != nil {
return nil, fmt.Errorf("loading known chunks: %w", err)
}
fmt.Printf("Loaded %s known chunks from database\n", formatNumber(len(s.knownChunks)))
return knownFiles, nil
}
// summarizeScanPhase calculates total size to process, updates progress tracking,
// and prints the scan phase summary with file counts and sizes
func (s *Scanner) summarizeScanPhase(result *ScanResult, filesToProcess []*FileToProcess) {
var totalSizeToProcess int64
for _, file := range filesToProcess {
totalSizeToProcess += file.FileInfo.Size()
}
if s.progress != nil {
s.progress.SetTotalSize(totalSizeToProcess)
s.progress.GetStats().TotalFiles.Store(int64(len(filesToProcess)))
}
log.Info("Phase 1 complete",
"total_files", len(filesToProcess),
"total_size", humanize.Bytes(uint64(totalSizeToProcess)),
"files_skipped", result.FilesSkipped,
"bytes_skipped", humanize.Bytes(uint64(result.BytesSkipped)))
fmt.Printf("Scan complete: %s examined (%s), %s to process (%s)",
formatNumber(result.FilesScanned),
humanize.Bytes(uint64(totalSizeToProcess+result.BytesSkipped)),
formatNumber(len(filesToProcess)),
humanize.Bytes(uint64(totalSizeToProcess)))
if result.FilesDeleted > 0 {
fmt.Printf(", %s deleted (%s)",
formatNumber(result.FilesDeleted),
humanize.Bytes(uint64(result.BytesDeleted)))
}
fmt.Println()
}
// finalizeScanResult populates final blob statistics in the scan result
// by querying the packer and database for blob/upload counts
func (s *Scanner) finalizeScanResult(ctx context.Context, result *ScanResult) {
blobs := s.packer.GetFinishedBlobs() blobs := s.packer.GetFinishedBlobs()
result.BlobsCreated += len(blobs) result.BlobsCreated += len(blobs)
@@ -299,6 +276,7 @@ func (s *Scanner) finalizeScanResult(ctx context.Context, result *ScanResult) {
} }
result.EndTime = time.Now().UTC() result.EndTime = time.Now().UTC()
return result, nil
} }
// loadKnownFiles loads all known files from the database into a map for fast lookup // loadKnownFiles loads all known files from the database into a map for fast lookup
@@ -446,38 +424,12 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error {
flushStart := time.Now() flushStart := time.Now()
log.Debug("flushCompletedPendingFiles: starting") log.Debug("flushCompletedPendingFiles: starting")
// Partition pending files into those ready to flush and those still waiting
canFlush, stillPendingCount := s.partitionPendingByChunkStatus()
if len(canFlush) == 0 {
log.Debug("flushCompletedPendingFiles: nothing to flush")
return nil
}
log.Debug("Flushing completed files after blob finalize",
"files_to_flush", len(canFlush),
"files_still_pending", stillPendingCount)
// Collect all data for batch operations
allFiles, allFileIDs, allFileChunks, allChunkFiles := s.collectBatchFlushData(canFlush)
// Execute the batch flush in a single transaction
log.Debug("flushCompletedPendingFiles: starting transaction")
txStart := time.Now()
err := s.executeBatchFileFlush(ctx, allFiles, allFileIDs, allFileChunks, allChunkFiles)
log.Debug("flushCompletedPendingFiles: transaction done", "duration", time.Since(txStart))
log.Debug("flushCompletedPendingFiles: total duration", "duration", time.Since(flushStart))
return err
}
// partitionPendingByChunkStatus separates pending files into those whose chunks
// are all committed to DB (ready to flush) and those still waiting on pending chunks.
// Updates s.pendingFiles to contain only the still-pending files.
func (s *Scanner) partitionPendingByChunkStatus() (canFlush []pendingFileData, stillPendingCount int) {
log.Debug("flushCompletedPendingFiles: acquiring pendingFilesMu lock") log.Debug("flushCompletedPendingFiles: acquiring pendingFilesMu lock")
s.pendingFilesMu.Lock() s.pendingFilesMu.Lock()
log.Debug("flushCompletedPendingFiles: acquired lock", "pending_files", len(s.pendingFiles)) log.Debug("flushCompletedPendingFiles: acquired lock", "pending_files", len(s.pendingFiles))
// Separate files into complete (can flush) and incomplete (keep pending)
var canFlush []pendingFileData
var stillPending []pendingFileData var stillPending []pendingFileData
log.Debug("flushCompletedPendingFiles: checking which files can flush") log.Debug("flushCompletedPendingFiles: checking which files can flush")
@@ -502,15 +454,18 @@ func (s *Scanner) partitionPendingByChunkStatus() (canFlush []pendingFileData, s
s.pendingFilesMu.Unlock() s.pendingFilesMu.Unlock()
log.Debug("flushCompletedPendingFiles: released lock") log.Debug("flushCompletedPendingFiles: released lock")
return canFlush, len(stillPending) if len(canFlush) == 0 {
log.Debug("flushCompletedPendingFiles: nothing to flush")
return nil
} }
// collectBatchFlushData aggregates file records, IDs, file-chunk mappings, and chunk-file log.Debug("Flushing completed files after blob finalize",
// mappings from the given pending file data for efficient batch database operations "files_to_flush", len(canFlush),
func (s *Scanner) collectBatchFlushData(canFlush []pendingFileData) ([]*database.File, []types.FileID, []database.FileChunk, []database.ChunkFile) { "files_still_pending", len(stillPending))
// Collect all data for batch operations
log.Debug("flushCompletedPendingFiles: collecting data for batch ops") log.Debug("flushCompletedPendingFiles: collecting data for batch ops")
collectStart := time.Now() collectStart := time.Now()
var allFileChunks []database.FileChunk var allFileChunks []database.FileChunk
var allChunkFiles []database.ChunkFile var allChunkFiles []database.ChunkFile
var allFileIDs []types.FileID var allFileIDs []types.FileID
@@ -522,20 +477,16 @@ func (s *Scanner) collectBatchFlushData(canFlush []pendingFileData) ([]*database
allFileIDs = append(allFileIDs, data.file.ID) allFileIDs = append(allFileIDs, data.file.ID)
allFiles = append(allFiles, data.file) allFiles = append(allFiles, data.file)
} }
log.Debug("flushCompletedPendingFiles: collected data", log.Debug("flushCompletedPendingFiles: collected data",
"duration", time.Since(collectStart), "duration", time.Since(collectStart),
"file_chunks", len(allFileChunks), "file_chunks", len(allFileChunks),
"chunk_files", len(allChunkFiles), "chunk_files", len(allChunkFiles),
"files", len(allFiles)) "files", len(allFiles))
return allFiles, allFileIDs, allFileChunks, allChunkFiles // Flush the complete files using batch operations
} log.Debug("flushCompletedPendingFiles: starting transaction")
txStart := time.Now()
// executeBatchFileFlush writes all collected file data to the database in a single transaction, err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
// including deleting old mappings, creating file records, and adding snapshot associations
func (s *Scanner) executeBatchFileFlush(ctx context.Context, allFiles []*database.File, allFileIDs []types.FileID, allFileChunks []database.FileChunk, allChunkFiles []database.ChunkFile) error {
return s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
log.Debug("flushCompletedPendingFiles: inside transaction") log.Debug("flushCompletedPendingFiles: inside transaction")
// Batch delete old file_chunks and chunk_files // Batch delete old file_chunks and chunk_files
@@ -588,6 +539,9 @@ func (s *Scanner) executeBatchFileFlush(ctx context.Context, allFiles []*databas
log.Debug("flushCompletedPendingFiles: transaction complete") log.Debug("flushCompletedPendingFiles: transaction complete")
return nil return nil
}) })
log.Debug("flushCompletedPendingFiles: transaction done", "duration", time.Since(txStart))
log.Debug("flushCompletedPendingFiles: total duration", "duration", time.Since(flushStart))
return err
} }
// ScanPhaseResult contains the results of the scan phase // ScanPhaseResult contains the results of the scan phase
@@ -669,30 +623,6 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult
mu.Unlock() mu.Unlock()
// Update result stats // Update result stats
s.updateScanEntryStats(result, needsProcessing, info)
// Output periodic status
if time.Since(lastStatusTime) >= statusInterval {
printScanProgressLine(filesScanned, changedCount, estimatedTotal, startTime)
lastStatusTime = time.Now()
}
return nil
})
if err != nil {
return nil, err
}
return &ScanPhaseResult{
FilesToProcess: filesToProcess,
UnchangedFileIDs: unchangedFileIDs,
}, nil
}
// updateScanEntryStats updates the scan result and progress reporter statistics
// for a single scanned file entry based on whether it needs processing
func (s *Scanner) updateScanEntryStats(result *ScanResult, needsProcessing bool, info os.FileInfo) {
if needsProcessing { if needsProcessing {
result.BytesScanned += info.Size() result.BytesScanned += info.Size()
if s.progress != nil { if s.progress != nil {
@@ -710,14 +640,13 @@ func (s *Scanner) updateScanEntryStats(result *ScanResult, needsProcessing bool,
if s.progress != nil { if s.progress != nil {
s.progress.GetStats().FilesScanned.Add(1) s.progress.GetStats().FilesScanned.Add(1)
} }
}
// printScanProgressLine prints a periodic progress line during the scan phase, // Output periodic status
// showing files scanned, percentage complete (if estimate available), and ETA if time.Since(lastStatusTime) >= statusInterval {
func printScanProgressLine(filesScanned int64, changedCount int, estimatedTotal int64, startTime time.Time) {
elapsed := time.Since(startTime) elapsed := time.Since(startTime)
rate := float64(filesScanned) / elapsed.Seconds() rate := float64(filesScanned) / elapsed.Seconds()
// Build status line - use estimate if available (not first backup)
if estimatedTotal > 0 { if estimatedTotal > 0 {
// Show actual scanned vs estimate (may exceed estimate if files were added) // Show actual scanned vs estimate (may exceed estimate if files were added)
pct := float64(filesScanned) / float64(estimatedTotal) * 100 pct := float64(filesScanned) / float64(estimatedTotal) * 100
@@ -750,6 +679,20 @@ func printScanProgressLine(filesScanned int64, changedCount int, estimatedTotal
rate, rate,
elapsed.Round(time.Second)) elapsed.Round(time.Second))
} }
lastStatusTime = time.Now()
}
return nil
})
if err != nil {
return nil, err
}
return &ScanPhaseResult{
FilesToProcess: filesToProcess,
UnchangedFileIDs: unchangedFileIDs,
}, nil
} }
// checkFileInMemory checks if a file needs processing using the in-memory map // checkFileInMemory checks if a file needs processing using the in-memory map
@@ -887,14 +830,23 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc
s.progress.GetStats().CurrentFile.Store(fileToProcess.Path) s.progress.GetStats().CurrentFile.Store(fileToProcess.Path)
} }
// Process file with error handling for deleted files and skip-errors mode // Process file in streaming fashion
skipped, err := s.processFileWithErrorHandling(ctx, fileToProcess, result) if err := s.processFileStreaming(ctx, fileToProcess, result); err != nil {
if err != nil { // Handle files that were deleted between scan and process phases
return err if errors.Is(err, os.ErrNotExist) {
} log.Warn("File was deleted during backup, skipping", "path", fileToProcess.Path)
if skipped { result.FilesSkipped++
continue continue
} }
// Skip file read errors if --skip-errors is enabled
if s.skipErrors {
log.Error("ERROR: Failed to process file (skipping due to --skip-errors)", "path", fileToProcess.Path, "error", err)
fmt.Printf("ERROR: Failed to process %s: %v (skipping)\n", fileToProcess.Path, err)
result.FilesSkipped++
continue
}
return fmt.Errorf("processing file %s: %w", fileToProcess.Path, err)
}
// Update files processed counter // Update files processed counter
if s.progress != nil { if s.progress != nil {
@@ -906,40 +858,6 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc
// Output periodic status // Output periodic status
if time.Since(lastStatusTime) >= statusInterval { if time.Since(lastStatusTime) >= statusInterval {
printProcessingProgress(filesProcessed, totalFiles, bytesProcessed, totalBytes, startTime)
lastStatusTime = time.Now()
}
}
// Finalize: flush packer, pending files, and handle local blobs
return s.finalizeProcessPhase(ctx, result)
}
// processFileWithErrorHandling wraps processFileStreaming with error recovery for
// deleted files and skip-errors mode. Returns (skipped, error).
func (s *Scanner) processFileWithErrorHandling(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) (bool, error) {
if err := s.processFileStreaming(ctx, fileToProcess, result); err != nil {
// Handle files that were deleted between scan and process phases
if errors.Is(err, os.ErrNotExist) {
log.Warn("File was deleted during backup, skipping", "path", fileToProcess.Path)
result.FilesSkipped++
return true, nil
}
// Skip file read errors if --skip-errors is enabled
if s.skipErrors {
log.Error("ERROR: Failed to process file (skipping due to --skip-errors)", "path", fileToProcess.Path, "error", err)
fmt.Printf("ERROR: Failed to process %s: %v (skipping)\n", fileToProcess.Path, err)
result.FilesSkipped++
return true, nil
}
return false, fmt.Errorf("processing file %s: %w", fileToProcess.Path, err)
}
return false, nil
}
// printProcessingProgress prints a periodic progress line during the process phase,
// showing files processed, bytes transferred, throughput, and ETA
func printProcessingProgress(filesProcessed, totalFiles int, bytesProcessed, totalBytes int64, startTime time.Time) {
elapsed := time.Since(startTime) elapsed := time.Since(startTime)
pct := float64(bytesProcessed) / float64(totalBytes) * 100 pct := float64(bytesProcessed) / float64(totalBytes) * 100
byteRate := float64(bytesProcessed) / elapsed.Seconds() byteRate := float64(bytesProcessed) / elapsed.Seconds()
@@ -966,11 +884,10 @@ func printProcessingProgress(filesProcessed, totalFiles int, bytesProcessed, tot
fmt.Printf(", ETA: %s", eta.Round(time.Second)) fmt.Printf(", ETA: %s", eta.Round(time.Second))
} }
fmt.Println() fmt.Println()
lastStatusTime = time.Now()
}
} }
// finalizeProcessPhase flushes the packer, writes remaining pending files to the database,
// and handles local blob storage when no remote storage is configured
func (s *Scanner) finalizeProcessPhase(ctx context.Context, result *ScanResult) error {
// Final packer flush first - this commits remaining chunks to DB // Final packer flush first - this commits remaining chunks to DB
// and handleBlobReady will flush files whose chunks are now committed // and handleBlobReady will flush files whose chunks are now committed
s.packerMu.Lock() s.packerMu.Lock()
@@ -1014,103 +931,40 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
startTime := time.Now().UTC() startTime := time.Now().UTC()
finishedBlob := blobWithReader.FinishedBlob finishedBlob := blobWithReader.FinishedBlob
// Report upload start and increment blobs created
if s.progress != nil { if s.progress != nil {
s.progress.ReportUploadStart(finishedBlob.Hash, finishedBlob.Compressed) s.progress.ReportUploadStart(finishedBlob.Hash, finishedBlob.Compressed)
s.progress.GetStats().BlobsCreated.Add(1) s.progress.GetStats().BlobsCreated.Add(1)
} }
// Upload to storage first (without holding any locks)
// Use scan context for cancellation support
ctx := s.scanCtx ctx := s.scanCtx
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash) // Track bytes uploaded for accurate speed calculation
blobExists, err := s.uploadBlobIfNeeded(ctx, blobPath, blobWithReader, startTime)
if err != nil {
s.cleanupBlobTempFile(blobWithReader)
return fmt.Errorf("uploading blob %s: %w", finishedBlob.Hash, err)
}
if err := s.recordBlobMetadata(ctx, finishedBlob, blobExists, startTime); err != nil {
s.cleanupBlobTempFile(blobWithReader)
return err
}
s.cleanupBlobTempFile(blobWithReader)
// Chunks from this blob are now committed to DB - remove from pending set
s.removePendingChunkHashes(blobWithReader.InsertedChunkHashes)
// Flush files whose chunks are now all committed
if err := s.flushCompletedPendingFiles(ctx); err != nil {
return fmt.Errorf("flushing completed files: %w", err)
}
return nil
}
// uploadBlobIfNeeded uploads the blob to storage if it doesn't already exist, returns whether it existed
func (s *Scanner) uploadBlobIfNeeded(ctx context.Context, blobPath string, blobWithReader *blob.BlobWithReader, startTime time.Time) (bool, error) {
finishedBlob := blobWithReader.FinishedBlob
// Check if blob already exists (deduplication after restart)
if _, err := s.storage.Stat(ctx, blobPath); err == nil {
log.Info("Blob already exists in storage, skipping upload",
"hash", finishedBlob.Hash, "size", humanize.Bytes(uint64(finishedBlob.Compressed)))
fmt.Printf("Blob exists: %s (%s, skipped upload)\n",
finishedBlob.Hash[:12]+"...", humanize.Bytes(uint64(finishedBlob.Compressed)))
return true, nil
}
progressCallback := s.makeUploadProgressCallback(ctx, finishedBlob)
if err := s.storage.PutWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil {
log.Error("Failed to upload blob", "hash", finishedBlob.Hash, "error", err)
return false, fmt.Errorf("uploading blob to storage: %w", err)
}
uploadDuration := time.Since(startTime)
uploadSpeedBps := float64(finishedBlob.Compressed) / uploadDuration.Seconds()
fmt.Printf("Blob stored: %s (%s, %s/sec, %s)\n",
finishedBlob.Hash[:12]+"...",
humanize.Bytes(uint64(finishedBlob.Compressed)),
humanize.Bytes(uint64(uploadSpeedBps)),
uploadDuration.Round(time.Millisecond))
log.Info("Successfully uploaded blob to storage",
"path", blobPath,
"size", humanize.Bytes(uint64(finishedBlob.Compressed)),
"duration", uploadDuration,
"speed", humanize.SI(uploadSpeedBps*8, "bps"))
if s.progress != nil {
s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration)
stats := s.progress.GetStats()
stats.BlobsUploaded.Add(1)
stats.BytesUploaded.Add(finishedBlob.Compressed)
}
return false, nil
}
// makeUploadProgressCallback creates a progress callback for blob uploads
func (s *Scanner) makeUploadProgressCallback(ctx context.Context, finishedBlob *blob.FinishedBlob) func(int64) error {
lastProgressTime := time.Now() lastProgressTime := time.Now()
lastProgressBytes := int64(0) lastProgressBytes := int64(0)
return func(uploaded int64) error { progressCallback := func(uploaded int64) error {
// Calculate instantaneous speed
now := time.Now() now := time.Now()
elapsed := now.Sub(lastProgressTime).Seconds() elapsed := now.Sub(lastProgressTime).Seconds()
if elapsed > 0.5 { if elapsed > 0.5 { // Update speed every 0.5 seconds
bytesSinceLastUpdate := uploaded - lastProgressBytes bytesSinceLastUpdate := uploaded - lastProgressBytes
speed := float64(bytesSinceLastUpdate) / elapsed speed := float64(bytesSinceLastUpdate) / elapsed
if s.progress != nil { if s.progress != nil {
s.progress.ReportUploadProgress(finishedBlob.Hash, uploaded, finishedBlob.Compressed, speed) s.progress.ReportUploadProgress(finishedBlob.Hash, uploaded, finishedBlob.Compressed, speed)
} }
lastProgressTime = now lastProgressTime = now
lastProgressBytes = uploaded lastProgressBytes = uploaded
} }
// Check for cancellation
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@@ -1118,26 +972,87 @@ func (s *Scanner) makeUploadProgressCallback(ctx context.Context, finishedBlob *
return nil return nil
} }
} }
// Create sharded path: blobs/ca/fe/cafebabe...
blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash)
// Check if blob already exists in remote storage (deduplication after restart)
blobExists := false
if _, err := s.storage.Stat(ctx, blobPath); err == nil {
blobExists = true
log.Info("Blob already exists in storage, skipping upload",
"hash", finishedBlob.Hash,
"size", humanize.Bytes(uint64(finishedBlob.Compressed)))
fmt.Printf("Blob exists: %s (%s, skipped upload)\n",
finishedBlob.Hash[:12]+"...",
humanize.Bytes(uint64(finishedBlob.Compressed)))
} }
// recordBlobMetadata stores blob upload metadata in the database if !blobExists {
func (s *Scanner) recordBlobMetadata(ctx context.Context, finishedBlob *blob.FinishedBlob, blobExists bool, startTime time.Time) error { if err := s.storage.PutWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil {
return fmt.Errorf("uploading blob %s to storage: %w", finishedBlob.Hash, err)
}
uploadDuration := time.Since(startTime)
// Calculate upload speed
uploadSpeedBps := float64(finishedBlob.Compressed) / uploadDuration.Seconds()
// Print blob stored message
fmt.Printf("Blob stored: %s (%s, %s/sec, %s)\n",
finishedBlob.Hash[:12]+"...",
humanize.Bytes(uint64(finishedBlob.Compressed)),
humanize.Bytes(uint64(uploadSpeedBps)),
uploadDuration.Round(time.Millisecond))
// Log upload stats
uploadSpeedBits := uploadSpeedBps * 8 // bits per second
log.Info("Successfully uploaded blob to storage",
"path", blobPath,
"size", humanize.Bytes(uint64(finishedBlob.Compressed)),
"duration", uploadDuration,
"speed", humanize.SI(uploadSpeedBits, "bps"))
// Report upload complete
if s.progress != nil {
s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration)
}
// Update progress after upload completes
if s.progress != nil {
stats := s.progress.GetStats()
stats.BlobsUploaded.Add(1)
stats.BytesUploaded.Add(finishedBlob.Compressed)
}
}
// Store metadata in database (after upload is complete)
dbCtx := s.scanCtx
if dbCtx == nil {
dbCtx = context.Background()
}
// Parse blob ID for typed operations
finishedBlobID, err := types.ParseBlobID(finishedBlob.ID) finishedBlobID, err := types.ParseBlobID(finishedBlob.ID)
if err != nil { if err != nil {
return fmt.Errorf("parsing finished blob ID: %w", err) return fmt.Errorf("parsing finished blob ID: %w", err)
} }
// Track upload duration (0 if blob already existed)
uploadDuration := time.Since(startTime) uploadDuration := time.Since(startTime)
return s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { err = s.repos.WithTx(dbCtx, func(ctx context.Context, tx *sql.Tx) error {
if err := s.repos.Blobs.UpdateUploaded(txCtx, tx, finishedBlob.ID); err != nil { // Update blob upload timestamp
if err := s.repos.Blobs.UpdateUploaded(ctx, tx, finishedBlob.ID); err != nil {
return fmt.Errorf("updating blob upload timestamp: %w", err) return fmt.Errorf("updating blob upload timestamp: %w", err)
} }
if err := s.repos.Snapshots.AddBlob(txCtx, tx, s.snapshotID, finishedBlobID, types.BlobHash(finishedBlob.Hash)); err != nil { // Add the blob to the snapshot
if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, finishedBlobID, types.BlobHash(finishedBlob.Hash)); err != nil {
return fmt.Errorf("adding blob to snapshot: %w", err) return fmt.Errorf("adding blob to snapshot: %w", err)
} }
// Record upload metrics (only for actual uploads, not deduplicated blobs)
if !blobExists { if !blobExists {
upload := &database.Upload{ upload := &database.Upload{
BlobHash: finishedBlob.Hash, BlobHash: finishedBlob.Hash,
@@ -1146,17 +1061,15 @@ func (s *Scanner) recordBlobMetadata(ctx context.Context, finishedBlob *blob.Fin
Size: finishedBlob.Compressed, Size: finishedBlob.Compressed,
DurationMs: uploadDuration.Milliseconds(), DurationMs: uploadDuration.Milliseconds(),
} }
if err := s.repos.Uploads.Create(txCtx, tx, upload); err != nil { if err := s.repos.Uploads.Create(ctx, tx, upload); err != nil {
return fmt.Errorf("recording upload metrics: %w", err) return fmt.Errorf("recording upload metrics: %w", err)
} }
} }
return nil return nil
}) })
}
// cleanupBlobTempFile closes and removes the blob's temporary file // Cleanup temp file if needed
func (s *Scanner) cleanupBlobTempFile(blobWithReader *blob.BlobWithReader) {
if blobWithReader.TempFile != nil { if blobWithReader.TempFile != nil {
tempName := blobWithReader.TempFile.Name() tempName := blobWithReader.TempFile.Name()
if err := blobWithReader.TempFile.Close(); err != nil { if err := blobWithReader.TempFile.Close(); err != nil {
@@ -1166,41 +1079,77 @@ func (s *Scanner) cleanupBlobTempFile(blobWithReader *blob.BlobWithReader) {
log.Fatal("Failed to remove temp file", "file", tempName, "error", err) log.Fatal("Failed to remove temp file", "file", tempName, "error", err)
} }
} }
if err != nil {
return err
} }
// streamingChunkInfo tracks chunk metadata collected during streaming // Chunks from this blob are now committed to DB - remove from pending set
type streamingChunkInfo struct { log.Debug("handleBlobReady: removing pending chunk hashes")
fileChunk database.FileChunk s.removePendingChunkHashes(blobWithReader.InsertedChunkHashes)
offset int64 log.Debug("handleBlobReady: removed pending chunk hashes")
size int64
// Flush files whose chunks are now all committed
// This maintains database consistency after each blob
log.Debug("handleBlobReady: calling flushCompletedPendingFiles")
if err := s.flushCompletedPendingFiles(dbCtx); err != nil {
return fmt.Errorf("flushing completed files: %w", err)
}
log.Debug("handleBlobReady: flushCompletedPendingFiles returned")
log.Debug("handleBlobReady: complete")
return nil
} }
// processFileStreaming processes a file by streaming chunks directly to the packer // processFileStreaming processes a file by streaming chunks directly to the packer
func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) error { func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) error {
// Open the file
file, err := s.fs.Open(fileToProcess.Path) file, err := s.fs.Open(fileToProcess.Path)
if err != nil { if err != nil {
return fmt.Errorf("opening file: %w", err) return fmt.Errorf("opening file: %w", err)
} }
defer func() { _ = file.Close() }() defer func() { _ = file.Close() }()
var chunks []streamingChunkInfo // We'll collect file chunks for database storage
// but process them for packing as we go
type chunkInfo struct {
fileChunk database.FileChunk
offset int64
size int64
}
var chunks []chunkInfo
chunkIndex := 0 chunkIndex := 0
// Process chunks in streaming fashion and get full file hash
fileHash, err := s.chunker.ChunkReaderStreaming(file, func(chunk chunker.Chunk) error { fileHash, err := s.chunker.ChunkReaderStreaming(file, func(chunk chunker.Chunk) error {
// Check for cancellation
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
default: default:
} }
log.Debug("Processing content-defined chunk from file",
"file", fileToProcess.Path,
"chunk_index", chunkIndex,
"hash", chunk.Hash,
"size", chunk.Size)
// Check if chunk already exists (fast in-memory lookup)
chunkExists := s.chunkExists(chunk.Hash) chunkExists := s.chunkExists(chunk.Hash)
// Queue new chunks for batch insert when blob finalizes
// This dramatically reduces transaction overhead
if !chunkExists { if !chunkExists {
s.packer.AddPendingChunk(chunk.Hash, chunk.Size) s.packer.AddPendingChunk(chunk.Hash, chunk.Size)
// Add to in-memory cache immediately for fast duplicate detection
s.addKnownChunk(chunk.Hash) s.addKnownChunk(chunk.Hash)
// Track as pending until blob finalizes and commits to DB
s.addPendingChunkHash(chunk.Hash) s.addPendingChunkHash(chunk.Hash)
} }
chunks = append(chunks, streamingChunkInfo{ // Track file chunk association for later storage
chunks = append(chunks, chunkInfo{
fileChunk: database.FileChunk{ fileChunk: database.FileChunk{
FileID: fileToProcess.File.ID, FileID: fileToProcess.File.ID,
Idx: chunkIndex, Idx: chunkIndex,
@@ -1210,15 +1159,55 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
size: chunk.Size, size: chunk.Size,
}) })
s.updateChunkStats(chunkExists, chunk.Size, result) // Update stats
if chunkExists {
result.FilesSkipped++ // Track as skipped for now
result.BytesSkipped += chunk.Size
if s.progress != nil {
s.progress.GetStats().BytesSkipped.Add(chunk.Size)
}
} else {
result.ChunksCreated++
result.BytesScanned += chunk.Size
if s.progress != nil {
s.progress.GetStats().ChunksCreated.Add(1)
s.progress.GetStats().BytesProcessed.Add(chunk.Size)
s.progress.UpdateChunkingActivity()
}
}
// Add chunk to packer immediately (streaming)
// This happens outside the database transaction
if !chunkExists { if !chunkExists {
if err := s.addChunkToPacker(chunk); err != nil { s.packerMu.Lock()
return err err := s.packer.AddChunk(&blob.ChunkRef{
Hash: chunk.Hash,
Data: chunk.Data,
})
if err == blob.ErrBlobSizeLimitExceeded {
// Finalize current blob and retry
if err := s.packer.FinalizeBlob(); err != nil {
s.packerMu.Unlock()
return fmt.Errorf("finalizing blob: %w", err)
} }
// Retry adding the chunk
if err := s.packer.AddChunk(&blob.ChunkRef{
Hash: chunk.Hash,
Data: chunk.Data,
}); err != nil {
s.packerMu.Unlock()
return fmt.Errorf("adding chunk after finalize: %w", err)
}
} else if err != nil {
s.packerMu.Unlock()
return fmt.Errorf("adding chunk to packer: %w", err)
}
s.packerMu.Unlock()
} }
// Clear chunk data from memory immediately after use
chunk.Data = nil chunk.Data = nil
chunkIndex++ chunkIndex++
return nil return nil
}) })
@@ -1228,54 +1217,12 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
} }
log.Debug("Completed snapshotting file", log.Debug("Completed snapshotting file",
"path", fileToProcess.Path, "file_hash", fileHash, "chunks", len(chunks)) "path", fileToProcess.Path,
"file_hash", fileHash,
"chunks", len(chunks))
s.queueFileForBatchInsert(ctx, fileToProcess, chunks) // Build file data for batch insertion
return nil // Update chunk associations with the file ID
}
// updateChunkStats updates scan result and progress stats for a processed chunk
func (s *Scanner) updateChunkStats(chunkExists bool, chunkSize int64, result *ScanResult) {
if chunkExists {
result.FilesSkipped++
result.BytesSkipped += chunkSize
if s.progress != nil {
s.progress.GetStats().BytesSkipped.Add(chunkSize)
}
} else {
result.ChunksCreated++
result.BytesScanned += chunkSize
if s.progress != nil {
s.progress.GetStats().ChunksCreated.Add(1)
s.progress.GetStats().BytesProcessed.Add(chunkSize)
s.progress.UpdateChunkingActivity()
}
}
}
// addChunkToPacker adds a chunk to the blob packer, finalizing the current blob if needed
func (s *Scanner) addChunkToPacker(chunk chunker.Chunk) error {
s.packerMu.Lock()
err := s.packer.AddChunk(&blob.ChunkRef{Hash: chunk.Hash, Data: chunk.Data})
if err == blob.ErrBlobSizeLimitExceeded {
if err := s.packer.FinalizeBlob(); err != nil {
s.packerMu.Unlock()
return fmt.Errorf("finalizing blob: %w", err)
}
if err := s.packer.AddChunk(&blob.ChunkRef{Hash: chunk.Hash, Data: chunk.Data}); err != nil {
s.packerMu.Unlock()
return fmt.Errorf("adding chunk after finalize: %w", err)
}
} else if err != nil {
s.packerMu.Unlock()
return fmt.Errorf("adding chunk to packer: %w", err)
}
s.packerMu.Unlock()
return nil
}
// queueFileForBatchInsert builds file/chunk associations and queues the file for batch DB insert
func (s *Scanner) queueFileForBatchInsert(ctx context.Context, fileToProcess *FileToProcess, chunks []streamingChunkInfo) {
fileChunks := make([]database.FileChunk, len(chunks)) fileChunks := make([]database.FileChunk, len(chunks))
chunkFiles := make([]database.ChunkFile, len(chunks)) chunkFiles := make([]database.ChunkFile, len(chunks))
for i, ci := range chunks { for i, ci := range chunks {
@@ -1292,11 +1239,14 @@ func (s *Scanner) queueFileForBatchInsert(ctx context.Context, fileToProcess *Fi
} }
} }
// Queue file for batch insertion
// Files will be flushed when their chunks are committed (after blob finalize)
s.addPendingFile(ctx, pendingFileData{ s.addPendingFile(ctx, pendingFileData{
file: fileToProcess.File, file: fileToProcess.File,
fileChunks: fileChunks, fileChunks: fileChunks,
chunkFiles: chunkFiles, chunkFiles: chunkFiles,
}) })
return nil
} }
// GetProgress returns the progress reporter for this scanner // GetProgress returns the progress reporter for this scanner

View File

@@ -227,39 +227,12 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
} }
}() }()
// Steps 1-5: Copy, clean, vacuum, compress, and read the database
finalData, tempDBPath, err := sm.prepareExportDB(ctx, dbPath, snapshotID, tempDir)
if err != nil {
return err
}
// Step 6: Generate blob manifest (before closing temp DB)
blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID)
if err != nil {
return fmt.Errorf("generating blob manifest: %w", err)
}
// Step 7: Upload to S3 in snapshot subdirectory
if err := sm.uploadSnapshotArtifacts(ctx, snapshotID, finalData, blobManifest); err != nil {
return err
}
log.Info("Uploaded snapshot metadata",
"snapshot_id", snapshotID,
"db_size", len(finalData),
"manifest_size", len(blobManifest))
return nil
}
// prepareExportDB copies, cleans, vacuums, and compresses the snapshot database for export.
// Returns the compressed data and the path to the temporary database (needed for manifest generation).
func (sm *SnapshotManager) prepareExportDB(ctx context.Context, dbPath, snapshotID, tempDir string) ([]byte, string, error) {
// Step 1: Copy database to temp file // Step 1: Copy database to temp file
// The main database should be closed at this point // The main database should be closed at this point
tempDBPath := filepath.Join(tempDir, "snapshot.db") tempDBPath := filepath.Join(tempDir, "snapshot.db")
log.Debug("Copying database to temporary location", "source", dbPath, "destination", tempDBPath) log.Debug("Copying database to temporary location", "source", dbPath, "destination", tempDBPath)
if err := sm.copyFile(dbPath, tempDBPath); err != nil { if err := sm.copyFile(dbPath, tempDBPath); err != nil {
return nil, "", fmt.Errorf("copying database: %w", err) return fmt.Errorf("copying database: %w", err)
} }
log.Debug("Database copy complete", "size", sm.getFileSize(tempDBPath)) log.Debug("Database copy complete", "size", sm.getFileSize(tempDBPath))
@@ -267,7 +240,7 @@ func (sm *SnapshotManager) prepareExportDB(ctx context.Context, dbPath, snapshot
log.Debug("Cleaning temporary database", "snapshot_id", snapshotID) log.Debug("Cleaning temporary database", "snapshot_id", snapshotID)
stats, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID) stats, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("cleaning snapshot database: %w", err) return fmt.Errorf("cleaning snapshot database: %w", err)
} }
log.Info("Temporary database cleanup complete", log.Info("Temporary database cleanup complete",
"db_path", tempDBPath, "db_path", tempDBPath,
@@ -282,14 +255,14 @@ func (sm *SnapshotManager) prepareExportDB(ctx context.Context, dbPath, snapshot
// Step 3: VACUUM the database to remove deleted data and compact // Step 3: VACUUM the database to remove deleted data and compact
// This is critical for security - ensures no stale/deleted data is uploaded // This is critical for security - ensures no stale/deleted data is uploaded
if err := sm.vacuumDatabase(tempDBPath); err != nil { if err := sm.vacuumDatabase(tempDBPath); err != nil {
return nil, "", fmt.Errorf("vacuuming database: %w", err) return fmt.Errorf("vacuuming database: %w", err)
} }
log.Debug("Database vacuumed", "size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath)))) log.Debug("Database vacuumed", "size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath))))
// Step 4: Compress and encrypt the binary database file // Step 4: Compress and encrypt the binary database file
compressedPath := filepath.Join(tempDir, "db.zst.age") compressedPath := filepath.Join(tempDir, "db.zst.age")
if err := sm.compressFile(tempDBPath, compressedPath); err != nil { if err := sm.compressFile(tempDBPath, compressedPath); err != nil {
return nil, "", fmt.Errorf("compressing database: %w", err) return fmt.Errorf("compressing database: %w", err)
} }
log.Debug("Compression complete", log.Debug("Compression complete",
"original_size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath))), "original_size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath))),
@@ -298,43 +271,49 @@ func (sm *SnapshotManager) prepareExportDB(ctx context.Context, dbPath, snapshot
// Step 5: Read compressed and encrypted data for upload // Step 5: Read compressed and encrypted data for upload
finalData, err := afero.ReadFile(sm.fs, compressedPath) finalData, err := afero.ReadFile(sm.fs, compressedPath)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("reading compressed dump: %w", err) return fmt.Errorf("reading compressed dump: %w", err)
} }
return finalData, tempDBPath, nil // Step 6: Generate blob manifest (before closing temp DB)
blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID)
if err != nil {
return fmt.Errorf("generating blob manifest: %w", err)
} }
// uploadSnapshotArtifacts uploads the database backup and blob manifest to S3 // Step 7: Upload to S3 in snapshot subdirectory
func (sm *SnapshotManager) uploadSnapshotArtifacts(ctx context.Context, snapshotID string, dbData, manifestData []byte) error {
// Upload database backup (compressed and encrypted) // Upload database backup (compressed and encrypted)
dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
dbUploadStart := time.Now() dbUploadStart := time.Now()
if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(dbData)); err != nil { if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(finalData)); err != nil {
return fmt.Errorf("uploading snapshot database: %w", err) return fmt.Errorf("uploading snapshot database: %w", err)
} }
dbUploadDuration := time.Since(dbUploadStart) dbUploadDuration := time.Since(dbUploadStart)
dbUploadSpeed := float64(len(dbData)) * 8 / dbUploadDuration.Seconds() // bits per second dbUploadSpeed := float64(len(finalData)) * 8 / dbUploadDuration.Seconds() // bits per second
log.Info("Uploaded snapshot database", log.Info("Uploaded snapshot database",
"path", dbKey, "path", dbKey,
"size", humanize.Bytes(uint64(len(dbData))), "size", humanize.Bytes(uint64(len(finalData))),
"duration", dbUploadDuration, "duration", dbUploadDuration,
"speed", humanize.SI(dbUploadSpeed, "bps")) "speed", humanize.SI(dbUploadSpeed, "bps"))
// Upload blob manifest (compressed only, not encrypted) // Upload blob manifest (compressed only, not encrypted)
manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
manifestUploadStart := time.Now() manifestUploadStart := time.Now()
if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(manifestData)); err != nil { if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil {
return fmt.Errorf("uploading blob manifest: %w", err) return fmt.Errorf("uploading blob manifest: %w", err)
} }
manifestUploadDuration := time.Since(manifestUploadStart) manifestUploadDuration := time.Since(manifestUploadStart)
manifestUploadSpeed := float64(len(manifestData)) * 8 / manifestUploadDuration.Seconds() // bits per second manifestUploadSpeed := float64(len(blobManifest)) * 8 / manifestUploadDuration.Seconds() // bits per second
log.Info("Uploaded blob manifest", log.Info("Uploaded blob manifest",
"path", manifestKey, "path", manifestKey,
"size", humanize.Bytes(uint64(len(manifestData))), "size", humanize.Bytes(uint64(len(blobManifest))),
"duration", manifestUploadDuration, "duration", manifestUploadDuration,
"speed", humanize.SI(manifestUploadSpeed, "bps")) "speed", humanize.SI(manifestUploadSpeed, "bps"))
log.Info("Uploaded snapshot metadata",
"snapshot_id", snapshotID,
"db_size", len(finalData),
"manifest_size", len(blobManifest))
return nil return nil
} }

View File

@@ -149,9 +149,9 @@ type RemoteInfoResult struct {
// RemoteInfo displays information about remote storage // RemoteInfo displays information about remote storage
func (v *Vaultik) RemoteInfo(jsonOutput bool) error { func (v *Vaultik) RemoteInfo(jsonOutput bool) error {
log.Info("Starting remote storage info gathering")
result := &RemoteInfoResult{} result := &RemoteInfoResult{}
// Get storage info
storageInfo := v.Storage.Info() storageInfo := v.Storage.Info()
result.StorageType = storageInfo.Type result.StorageType = storageInfo.Type
result.StorageLocation = storageInfo.Location result.StorageLocation = storageInfo.Location
@@ -161,52 +161,23 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error {
v.printfStdout("Type: %s\n", storageInfo.Type) v.printfStdout("Type: %s\n", storageInfo.Type)
v.printfStdout("Location: %s\n", storageInfo.Location) v.printfStdout("Location: %s\n", storageInfo.Location)
v.printlnStdout() v.printlnStdout()
}
// List all snapshot metadata
if !jsonOutput {
v.printfStdout("Scanning snapshot metadata...\n") v.printfStdout("Scanning snapshot metadata...\n")
} }
snapshotMetadata, snapshotIDs, err := v.collectSnapshotMetadata()
if err != nil {
return err
}
if !jsonOutput {
v.printfStdout("Downloading %d manifest(s)...\n", len(snapshotIDs))
}
referencedBlobs := v.collectReferencedBlobsFromManifests(snapshotIDs, snapshotMetadata)
v.populateRemoteInfoResult(result, snapshotMetadata, snapshotIDs, referencedBlobs)
if err := v.scanRemoteBlobStorage(result, referencedBlobs, jsonOutput); err != nil {
return err
}
log.Info("Remote info complete",
"snapshots", result.TotalMetadataCount,
"total_blobs", result.TotalBlobCount,
"referenced_blobs", result.ReferencedBlobCount,
"orphaned_blobs", result.OrphanedBlobCount)
if jsonOutput {
enc := json.NewEncoder(v.Stdout)
enc.SetIndent("", " ")
return enc.Encode(result)
}
v.printRemoteInfoTable(result)
return nil
}
// collectSnapshotMetadata scans remote metadata and returns per-snapshot info and sorted IDs
func (v *Vaultik) collectSnapshotMetadata() (map[string]*SnapshotMetadataInfo, []string, error) {
snapshotMetadata := make(map[string]*SnapshotMetadataInfo) snapshotMetadata := make(map[string]*SnapshotMetadataInfo)
// Collect metadata files
metadataCh := v.Storage.ListStream(v.ctx, "metadata/") metadataCh := v.Storage.ListStream(v.ctx, "metadata/")
for obj := range metadataCh { for obj := range metadataCh {
if obj.Err != nil { if obj.Err != nil {
return nil, nil, fmt.Errorf("listing metadata: %w", obj.Err) return fmt.Errorf("listing metadata: %w", obj.Err)
} }
// Parse key: metadata/<snapshot-id>/<filename>
parts := strings.Split(obj.Key, "/") parts := strings.Split(obj.Key, "/")
if len(parts) < 3 { if len(parts) < 3 {
continue continue
@@ -214,11 +185,14 @@ func (v *Vaultik) collectSnapshotMetadata() (map[string]*SnapshotMetadataInfo, [
snapshotID := parts[1] snapshotID := parts[1]
if _, exists := snapshotMetadata[snapshotID]; !exists { if _, exists := snapshotMetadata[snapshotID]; !exists {
snapshotMetadata[snapshotID] = &SnapshotMetadataInfo{SnapshotID: snapshotID} snapshotMetadata[snapshotID] = &SnapshotMetadataInfo{
SnapshotID: snapshotID,
}
} }
info := snapshotMetadata[snapshotID] info := snapshotMetadata[snapshotID]
filename := parts[2] filename := parts[2]
if strings.HasPrefix(filename, "manifest") { if strings.HasPrefix(filename, "manifest") {
info.ManifestSize = obj.Size info.ManifestSize = obj.Size
} else if strings.HasPrefix(filename, "db") { } else if strings.HasPrefix(filename, "db") {
@@ -227,18 +201,19 @@ func (v *Vaultik) collectSnapshotMetadata() (map[string]*SnapshotMetadataInfo, [
info.TotalSize = info.ManifestSize + info.DatabaseSize info.TotalSize = info.ManifestSize + info.DatabaseSize
} }
// Sort snapshots by ID for consistent output
var snapshotIDs []string var snapshotIDs []string
for id := range snapshotMetadata { for id := range snapshotMetadata {
snapshotIDs = append(snapshotIDs, id) snapshotIDs = append(snapshotIDs, id)
} }
sort.Strings(snapshotIDs) sort.Strings(snapshotIDs)
return snapshotMetadata, snapshotIDs, nil // Download and parse all manifests to get referenced blobs
if !jsonOutput {
v.printfStdout("Downloading %d manifest(s)...\n", len(snapshotIDs))
} }
// collectReferencedBlobsFromManifests downloads manifests and returns referenced blob hashes with sizes referencedBlobs := make(map[string]int64) // hash -> compressed size
func (v *Vaultik) collectReferencedBlobsFromManifests(snapshotIDs []string, snapshotMetadata map[string]*SnapshotMetadataInfo) map[string]int64 {
referencedBlobs := make(map[string]int64)
for _, snapshotID := range snapshotIDs { for _, snapshotID := range snapshotIDs {
manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
@@ -255,8 +230,10 @@ func (v *Vaultik) collectReferencedBlobsFromManifests(snapshotIDs []string, snap
continue continue
} }
// Record blob info from manifest
info := snapshotMetadata[snapshotID] info := snapshotMetadata[snapshotID]
info.BlobCount = manifest.BlobCount info.BlobCount = manifest.BlobCount
var blobsSize int64 var blobsSize int64
for _, blob := range manifest.Blobs { for _, blob := range manifest.Blobs {
referencedBlobs[blob.Hash] = blob.CompressedSize referencedBlobs[blob.Hash] = blob.CompressedSize
@@ -265,11 +242,7 @@ func (v *Vaultik) collectReferencedBlobsFromManifests(snapshotIDs []string, snap
info.BlobsSize = blobsSize info.BlobsSize = blobsSize
} }
return referencedBlobs // Build result snapshots
}
// populateRemoteInfoResult fills in the result's snapshot and referenced blob stats
func (v *Vaultik) populateRemoteInfoResult(result *RemoteInfoResult, snapshotMetadata map[string]*SnapshotMetadataInfo, snapshotIDs []string, referencedBlobs map[string]int64) {
var totalMetadataSize int64 var totalMetadataSize int64
for _, id := range snapshotIDs { for _, id := range snapshotIDs {
info := snapshotMetadata[id] info := snapshotMetadata[id]
@@ -279,25 +252,26 @@ func (v *Vaultik) populateRemoteInfoResult(result *RemoteInfoResult, snapshotMet
result.TotalMetadataSize = totalMetadataSize result.TotalMetadataSize = totalMetadataSize
result.TotalMetadataCount = len(snapshotIDs) result.TotalMetadataCount = len(snapshotIDs)
// Calculate referenced blob stats
for _, size := range referencedBlobs { for _, size := range referencedBlobs {
result.ReferencedBlobCount++ result.ReferencedBlobCount++
result.ReferencedBlobSize += size result.ReferencedBlobSize += size
} }
}
// scanRemoteBlobStorage lists all blobs on remote and computes orphan stats // List all blobs on remote
func (v *Vaultik) scanRemoteBlobStorage(result *RemoteInfoResult, referencedBlobs map[string]int64, jsonOutput bool) error {
if !jsonOutput { if !jsonOutput {
v.printfStdout("Scanning blobs...\n") v.printfStdout("Scanning blobs...\n")
} }
blobCh := v.Storage.ListStream(v.ctx, "blobs/") allBlobs := make(map[string]int64) // hash -> size from storage
allBlobs := make(map[string]int64)
blobCh := v.Storage.ListStream(v.ctx, "blobs/")
for obj := range blobCh { for obj := range blobCh {
if obj.Err != nil { if obj.Err != nil {
return fmt.Errorf("listing blobs: %w", obj.Err) return fmt.Errorf("listing blobs: %w", obj.Err)
} }
// Extract hash from key: blobs/xx/yy/hash
parts := strings.Split(obj.Key, "/") parts := strings.Split(obj.Key, "/")
if len(parts) < 4 { if len(parts) < 4 {
continue continue
@@ -308,6 +282,7 @@ func (v *Vaultik) scanRemoteBlobStorage(result *RemoteInfoResult, referencedBlob
result.TotalBlobSize += obj.Size result.TotalBlobSize += obj.Size
} }
// Calculate orphaned blobs
for hash, size := range allBlobs { for hash, size := range allBlobs {
if _, referenced := referencedBlobs[hash]; !referenced { if _, referenced := referencedBlobs[hash]; !referenced {
result.OrphanedBlobCount++ result.OrphanedBlobCount++
@@ -315,11 +290,14 @@ func (v *Vaultik) scanRemoteBlobStorage(result *RemoteInfoResult, referencedBlob
} }
} }
return nil // Output results
if jsonOutput {
enc := json.NewEncoder(v.Stdout)
enc.SetIndent("", " ")
return enc.Encode(result)
} }
// printRemoteInfoTable renders the human-readable remote info output // Human-readable output
func (v *Vaultik) printRemoteInfoTable(result *RemoteInfoResult) {
v.printfStdout("\n=== Snapshot Metadata ===\n") v.printfStdout("\n=== Snapshot Metadata ===\n")
if len(result.Snapshots) == 0 { if len(result.Snapshots) == 0 {
v.printfStdout("No snapshots found\n") v.printfStdout("No snapshots found\n")
@@ -342,15 +320,20 @@ func (v *Vaultik) printRemoteInfoTable(result *RemoteInfoResult) {
v.printfStdout("\n=== Blob Storage ===\n") v.printfStdout("\n=== Blob Storage ===\n")
v.printfStdout("Total blobs on remote: %s (%s)\n", v.printfStdout("Total blobs on remote: %s (%s)\n",
humanize.Comma(int64(result.TotalBlobCount)), humanize.Bytes(uint64(result.TotalBlobSize))) humanize.Comma(int64(result.TotalBlobCount)),
humanize.Bytes(uint64(result.TotalBlobSize)))
v.printfStdout("Referenced by snapshots: %s (%s)\n", v.printfStdout("Referenced by snapshots: %s (%s)\n",
humanize.Comma(int64(result.ReferencedBlobCount)), humanize.Bytes(uint64(result.ReferencedBlobSize))) humanize.Comma(int64(result.ReferencedBlobCount)),
humanize.Bytes(uint64(result.ReferencedBlobSize)))
v.printfStdout("Orphaned (unreferenced): %s (%s)\n", v.printfStdout("Orphaned (unreferenced): %s (%s)\n",
humanize.Comma(int64(result.OrphanedBlobCount)), humanize.Bytes(uint64(result.OrphanedBlobSize))) humanize.Comma(int64(result.OrphanedBlobCount)),
humanize.Bytes(uint64(result.OrphanedBlobSize)))
if result.OrphanedBlobCount > 0 { if result.OrphanedBlobCount > 0 {
v.printfStdout("\nRun 'vaultik prune --remote' to remove orphaned blobs.\n") v.printfStdout("\nRun 'vaultik prune --remote' to remove orphaned blobs.\n")
} }
return nil
} }
// truncateString truncates a string to maxLen, adding "..." if truncated // truncateString truncates a string to maxLen, adding "..." if truncated

View File

@@ -27,19 +27,95 @@ type PruneBlobsResult struct {
func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
log.Info("Starting prune operation") log.Info("Starting prune operation")
allBlobsReferenced, err := v.collectReferencedBlobs() // Get all remote snapshots and their manifests
if err != nil { allBlobsReferenced := make(map[string]bool)
return err manifestCount := 0
// List all snapshots in storage
log.Info("Listing remote snapshots")
objectCh := v.Storage.ListStream(v.ctx, "metadata/")
var snapshotIDs []string
for object := range objectCh {
if object.Err != nil {
return fmt.Errorf("listing remote snapshots: %w", object.Err)
} }
allBlobs, err := v.listAllRemoteBlobs() // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
if err != nil { parts := strings.Split(object.Key, "/")
return err if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
// Check if this is a directory by looking for trailing slash
if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") {
snapshotID := parts[1]
// Only add unique snapshot IDs
found := false
for _, id := range snapshotIDs {
if id == snapshotID {
found = true
break
}
}
if !found {
snapshotIDs = append(snapshotIDs, snapshotID)
}
}
}
} }
unreferencedBlobs, totalSize := v.findUnreferencedBlobs(allBlobs, allBlobsReferenced) log.Info("Found manifests in remote storage", "count", len(snapshotIDs))
result := &PruneBlobsResult{BlobsFound: len(unreferencedBlobs)} // Download and parse each manifest to get referenced blobs
for _, snapshotID := range snapshotIDs {
log.Debug("Processing manifest", "snapshot_id", snapshotID)
manifest, err := v.downloadManifest(snapshotID)
if err != nil {
log.Error("Failed to download manifest", "snapshot_id", snapshotID, "error", err)
continue
}
// Add all blobs from this manifest to our referenced set
for _, blob := range manifest.Blobs {
allBlobsReferenced[blob.Hash] = true
}
manifestCount++
}
log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced))
// List all blobs in storage
log.Info("Listing all blobs in storage")
allBlobs := make(map[string]int64) // hash -> size
blobObjectCh := v.Storage.ListStream(v.ctx, "blobs/")
for object := range blobObjectCh {
if object.Err != nil {
return fmt.Errorf("listing blobs: %w", object.Err)
}
// Extract hash from path like blobs/ab/cd/abcdef123456...
parts := strings.Split(object.Key, "/")
if len(parts) == 4 && parts[0] == "blobs" {
hash := parts[3]
allBlobs[hash] = object.Size
}
}
log.Info("Found blobs in storage", "count", len(allBlobs))
// Find unreferenced blobs
var unreferencedBlobs []string
var totalSize int64
for hash, size := range allBlobs {
if !allBlobsReferenced[hash] {
unreferencedBlobs = append(unreferencedBlobs, hash)
totalSize += size
}
}
result := &PruneBlobsResult{
BlobsFound: len(unreferencedBlobs),
}
if len(unreferencedBlobs) == 0 { if len(unreferencedBlobs) == 0 {
log.Info("No unreferenced blobs found") log.Info("No unreferenced blobs found")
@@ -50,15 +126,18 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
return nil return nil
} }
// Show what will be deleted
log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize))) log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize)))
if !opts.JSON { if !opts.JSON {
v.printfStdout("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize))) v.printfStdout("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize)))
} }
// Confirm unless --force is used (skip in JSON mode - require --force)
if !opts.Force && !opts.JSON { if !opts.Force && !opts.JSON {
v.printfStdout("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs)) v.printfStdout("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs))
var confirm string var confirm string
if _, err := v.scanStdin(&confirm); err != nil { if _, err := v.scanStdin(&confirm); err != nil {
// Treat EOF or error as "no"
v.printlnStdout("Cancelled") v.printlnStdout("Cancelled")
return nil return nil
} }
@@ -68,109 +147,10 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
} }
} }
v.deleteUnreferencedBlobs(unreferencedBlobs, allBlobs, result) // Delete unreferenced blobs
if opts.JSON {
return v.outputPruneBlobsJSON(result)
}
v.printfStdout("\nDeleted %d blob(s) totaling %s\n", result.BlobsDeleted, humanize.Bytes(uint64(result.BytesFreed)))
if result.BlobsFailed > 0 {
v.printfStdout("Failed to delete %d blob(s)\n", result.BlobsFailed)
}
return nil
}
// collectReferencedBlobs downloads all manifests and returns the set of referenced blob hashes
func (v *Vaultik) collectReferencedBlobs() (map[string]bool, error) {
log.Info("Listing remote snapshots")
snapshotIDs, err := v.listUniqueSnapshotIDs()
if err != nil {
return nil, fmt.Errorf("listing snapshot IDs: %w", err)
}
log.Info("Found manifests in remote storage", "count", len(snapshotIDs))
allBlobsReferenced := make(map[string]bool)
manifestCount := 0
for _, snapshotID := range snapshotIDs {
log.Debug("Processing manifest", "snapshot_id", snapshotID)
manifest, err := v.downloadManifest(snapshotID)
if err != nil {
log.Error("Failed to download manifest", "snapshot_id", snapshotID, "error", err)
continue
}
for _, blob := range manifest.Blobs {
allBlobsReferenced[blob.Hash] = true
}
manifestCount++
}
log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced))
return allBlobsReferenced, nil
}
// listUniqueSnapshotIDs returns deduplicated snapshot IDs from remote metadata
func (v *Vaultik) listUniqueSnapshotIDs() ([]string, error) {
objectCh := v.Storage.ListStream(v.ctx, "metadata/")
seen := make(map[string]bool)
var snapshotIDs []string
for object := range objectCh {
if object.Err != nil {
return nil, fmt.Errorf("listing metadata objects: %w", object.Err)
}
parts := strings.Split(object.Key, "/")
if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") {
snapshotID := parts[1]
if !seen[snapshotID] {
seen[snapshotID] = true
snapshotIDs = append(snapshotIDs, snapshotID)
}
}
}
}
return snapshotIDs, nil
}
// listAllRemoteBlobs returns a map of all blob hashes to their sizes in remote storage
func (v *Vaultik) listAllRemoteBlobs() (map[string]int64, error) {
log.Info("Listing all blobs in storage")
allBlobs := make(map[string]int64)
blobObjectCh := v.Storage.ListStream(v.ctx, "blobs/")
for object := range blobObjectCh {
if object.Err != nil {
return nil, fmt.Errorf("listing blobs: %w", object.Err)
}
parts := strings.Split(object.Key, "/")
if len(parts) == 4 && parts[0] == "blobs" {
allBlobs[parts[3]] = object.Size
}
}
log.Info("Found blobs in storage", "count", len(allBlobs))
return allBlobs, nil
}
// findUnreferencedBlobs returns blob hashes not referenced by any manifest and their total size
func (v *Vaultik) findUnreferencedBlobs(allBlobs map[string]int64, referenced map[string]bool) ([]string, int64) {
var unreferenced []string
var totalSize int64
for hash, size := range allBlobs {
if !referenced[hash] {
unreferenced = append(unreferenced, hash)
totalSize += size
}
}
return unreferenced, totalSize
}
// deleteUnreferencedBlobs deletes the given blobs from storage and populates the result
func (v *Vaultik) deleteUnreferencedBlobs(unreferencedBlobs []string, allBlobs map[string]int64, result *PruneBlobsResult) {
log.Info("Deleting unreferenced blobs") log.Info("Deleting unreferenced blobs")
deletedCount := 0
deletedSize := int64(0)
for i, hash := range unreferencedBlobs { for i, hash := range unreferencedBlobs {
blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash) blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash)
@@ -180,9 +160,10 @@ func (v *Vaultik) deleteUnreferencedBlobs(unreferencedBlobs []string, allBlobs m
continue continue
} }
result.BlobsDeleted++ deletedCount++
result.BytesFreed += allBlobs[hash] deletedSize += allBlobs[hash]
// Progress update every 100 blobs
if (i+1)%100 == 0 || i == len(unreferencedBlobs)-1 { if (i+1)%100 == 0 || i == len(unreferencedBlobs)-1 {
log.Info("Deletion progress", log.Info("Deletion progress",
"deleted", i+1, "deleted", i+1,
@@ -192,13 +173,26 @@ func (v *Vaultik) deleteUnreferencedBlobs(unreferencedBlobs []string, allBlobs m
} }
} }
result.BlobsFailed = len(unreferencedBlobs) - result.BlobsDeleted result.BlobsDeleted = deletedCount
result.BlobsFailed = len(unreferencedBlobs) - deletedCount
result.BytesFreed = deletedSize
log.Info("Prune complete", log.Info("Prune complete",
"deleted_count", result.BlobsDeleted, "deleted_count", deletedCount,
"deleted_size", humanize.Bytes(uint64(result.BytesFreed)), "deleted_size", humanize.Bytes(uint64(deletedSize)),
"failed", result.BlobsFailed, "failed", len(unreferencedBlobs)-deletedCount,
) )
if opts.JSON {
return v.outputPruneBlobsJSON(result)
}
v.printfStdout("\nDeleted %d blob(s) totaling %s\n", deletedCount, humanize.Bytes(uint64(deletedSize)))
if deletedCount < len(unreferencedBlobs) {
v.printfStdout("Failed to delete %d blob(s)\n", len(unreferencedBlobs)-deletedCount)
}
return nil
} }
// outputPruneBlobsJSON outputs the prune result as JSON // outputPruneBlobsJSON outputs the prune result as JSON

View File

@@ -22,13 +22,6 @@ import (
"golang.org/x/term" "golang.org/x/term"
) )
const (
// progressBarWidth is the character width of the progress bar display.
progressBarWidth = 40
// progressBarThrottle is the minimum interval between progress bar redraws.
progressBarThrottle = 100 * time.Millisecond
)
// RestoreOptions contains options for the restore operation // RestoreOptions contains options for the restore operation
type RestoreOptions struct { type RestoreOptions struct {
SnapshotID string SnapshotID string
@@ -55,9 +48,15 @@ type RestoreResult struct {
func (v *Vaultik) Restore(opts *RestoreOptions) error { func (v *Vaultik) Restore(opts *RestoreOptions) error {
startTime := time.Now() startTime := time.Now()
identity, err := v.prepareRestoreIdentity() // Check for age_secret_key
if v.Config.AgeSecretKey == "" {
return fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'")
}
// Parse the age identity
identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey)
if err != nil { if err != nil {
return err return fmt.Errorf("parsing age secret key: %w", err)
} }
log.Info("Starting restore operation", log.Info("Starting restore operation",
@@ -109,9 +108,31 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error {
} }
// Step 5: Restore files // Step 5: Restore files
result, err := v.restoreAllFiles(files, repos, opts, identity, chunkToBlobMap) result := &RestoreResult{}
blobCache, err := newBlobDiskCache(4 * v.Config.BlobSizeLimit.Int64())
if err != nil { if err != nil {
return err return fmt.Errorf("creating blob cache: %w", err)
}
defer func() { _ = blobCache.Close() }()
for i, file := range files {
if v.ctx.Err() != nil {
return v.ctx.Err()
}
if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil {
log.Error("Failed to restore file", "path", file.Path, "error", err)
// Continue with other files
continue
}
// Progress logging
if (i+1)%100 == 0 || i+1 == len(files) {
log.Info("Restore progress",
"files", fmt.Sprintf("%d/%d", i+1, len(files)),
"bytes", humanize.Bytes(uint64(result.BytesRestored)),
)
}
} }
result.Duration = time.Since(startTime) result.Duration = time.Since(startTime)
@@ -130,108 +151,8 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error {
result.Duration.Round(time.Second), result.Duration.Round(time.Second),
) )
if result.FilesFailed > 0 {
_, _ = fmt.Fprintf(v.Stdout, "\nWARNING: %d file(s) failed to restore:\n", result.FilesFailed)
for _, path := range result.FailedFiles {
_, _ = fmt.Fprintf(v.Stdout, " - %s\n", path)
}
}
// Run verification if requested // Run verification if requested
if opts.Verify { if opts.Verify {
if err := v.handleRestoreVerification(repos, files, opts, result); err != nil {
return err
}
}
if result.FilesFailed > 0 {
return fmt.Errorf("%d file(s) failed to restore", result.FilesFailed)
}
return nil
}
// prepareRestoreIdentity validates that an age secret key is configured and parses it
func (v *Vaultik) prepareRestoreIdentity() (age.Identity, error) {
if v.Config.AgeSecretKey == "" {
return nil, fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'")
}
identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey)
if err != nil {
return nil, fmt.Errorf("parsing age secret key: %w", err)
}
return identity, nil
}
// restoreAllFiles iterates over files and restores each one, tracking progress and failures
func (v *Vaultik) restoreAllFiles(
files []*database.File,
repos *database.Repositories,
opts *RestoreOptions,
identity age.Identity,
chunkToBlobMap map[string]*database.BlobChunk,
) (*RestoreResult, error) {
result := &RestoreResult{}
blobCache, err := newBlobDiskCache(4 * v.Config.BlobSizeLimit.Int64())
if err != nil {
return nil, fmt.Errorf("creating blob cache: %w", err)
}
defer func() { _ = blobCache.Close() }()
// Calculate total bytes for progress bar
var totalBytesExpected int64
for _, file := range files {
totalBytesExpected += file.Size
}
// Create progress bar if output is a terminal
bar := v.newProgressBar("Restoring", totalBytesExpected)
for i, file := range files {
if v.ctx.Err() != nil {
return nil, v.ctx.Err()
}
if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil {
log.Error("Failed to restore file", "path", file.Path, "error", err)
result.FilesFailed++
result.FailedFiles = append(result.FailedFiles, file.Path.String())
// Update progress bar even on failure
if bar != nil {
_ = bar.Add64(file.Size)
}
continue
}
// Update progress bar
if bar != nil {
_ = bar.Add64(file.Size)
}
// Progress logging (for non-terminal or structured logs)
if (i+1)%100 == 0 || i+1 == len(files) {
log.Info("Restore progress",
"files", fmt.Sprintf("%d/%d", i+1, len(files)),
"bytes", humanize.Bytes(uint64(result.BytesRestored)),
)
}
}
if bar != nil {
_ = bar.Finish()
}
return result, nil
}
// handleRestoreVerification runs post-restore verification if requested
func (v *Vaultik) handleRestoreVerification(
repos *database.Repositories,
files []*database.File,
opts *RestoreOptions,
result *RestoreResult,
) error {
if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil { if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil {
return fmt.Errorf("verification failed: %w", err) return fmt.Errorf("verification failed: %w", err)
} }
@@ -248,6 +169,8 @@ func (v *Vaultik) handleRestoreVerification(
result.FilesVerified, result.FilesVerified,
humanize.Bytes(uint64(result.BytesVerified)), humanize.Bytes(uint64(result.BytesVerified)),
) )
}
return nil return nil
} }
@@ -600,7 +523,22 @@ func (v *Vaultik) verifyRestoredFiles(
) )
// Create progress bar if output is a terminal // Create progress bar if output is a terminal
bar := v.newProgressBar("Verifying", totalBytes) var bar *progressbar.ProgressBar
if isTerminal() {
bar = progressbar.NewOptions64(
totalBytes,
progressbar.OptionSetDescription("Verifying"),
progressbar.OptionSetWriter(v.Stderr),
progressbar.OptionShowBytes(true),
progressbar.OptionShowCount(),
progressbar.OptionSetWidth(40),
progressbar.OptionThrottle(100*time.Millisecond),
progressbar.OptionOnCompletion(func() {
v.printfStderr("\n")
}),
progressbar.OptionSetRenderBlankState(true),
)
}
// Verify each file // Verify each file
for _, file := range regularFiles { for _, file := range regularFiles {
@@ -694,37 +632,7 @@ func (v *Vaultik) verifyFile(
return bytesVerified, nil return bytesVerified, nil
} }
// newProgressBar creates a terminal-aware progress bar with standard options. // isTerminal returns true if stdout is a terminal
// It returns nil if stdout is not a terminal. func isTerminal() bool {
func (v *Vaultik) newProgressBar(description string, total int64) *progressbar.ProgressBar { return term.IsTerminal(int(os.Stdout.Fd()))
if !v.isTerminal() {
return nil
}
return progressbar.NewOptions64(
total,
progressbar.OptionSetDescription(description),
progressbar.OptionSetWriter(v.Stderr),
progressbar.OptionShowBytes(true),
progressbar.OptionShowCount(),
progressbar.OptionSetWidth(progressBarWidth),
progressbar.OptionThrottle(progressBarThrottle),
progressbar.OptionOnCompletion(func() {
v.printfStderr("\n")
}),
progressbar.OptionSetRenderBlankState(true),
)
}
// isTerminal returns true if stdout is a terminal.
// It checks whether v.Stdout implements Fd() (i.e. is an *os.File),
// and falls back to false for non-file writers (e.g. in tests).
func (v *Vaultik) isTerminal() bool {
type fder interface {
Fd() uintptr
}
f, ok := v.Stdout.(fder)
if !ok {
return false
}
return term.IsTerminal(int(f.Fd()))
} }

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strings" "strings"
"text/tabwriter" "text/tabwriter"
@@ -90,55 +89,43 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
v.printfStdout("\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second)) v.printfStdout("\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second))
} }
// Prune old snapshots and unreferenced blobs if --prune was specified
if opts.Prune {
log.Info("Pruning enabled - deleting old snapshots and unreferenced blobs")
v.printlnStdout("\nPruning old snapshots (keeping latest)...")
if err := v.PurgeSnapshots(true, "", true); err != nil {
return fmt.Errorf("prune: purging old snapshots: %w", err)
}
v.printlnStdout("Pruning unreferenced blobs...")
if err := v.PruneBlobs(&PruneOptions{Force: true}); err != nil {
return fmt.Errorf("prune: removing unreferenced blobs: %w", err)
}
log.Info("Pruning complete")
}
return nil return nil
} }
// snapshotStats tracks aggregate statistics across directory scans
type snapshotStats struct {
totalFiles int
totalBytes int64
totalChunks int
totalBlobs int
totalBytesSkipped int64
totalFilesSkipped int
totalFilesDeleted int
totalBytesDeleted int64
totalBytesUploaded int64
totalBlobsUploaded int
uploadDuration time.Duration
}
// createNamedSnapshot creates a single named snapshot // createNamedSnapshot creates a single named snapshot
func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, snapName string, idx, total int) error { func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, snapName string, idx, total int) error {
snapshotStartTime := time.Now() snapshotStartTime := time.Now()
snapConfig := v.Config.Snapshots[snapName]
if total > 1 { if total > 1 {
v.printfStdout("\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName) v.printfStdout("\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName)
} }
resolvedDirs, err := v.resolveSnapshotPaths(snapName) // Resolve source directories to absolute paths
resolvedDirs := make([]string, 0, len(snapConfig.Paths))
for _, dir := range snapConfig.Paths {
absPath, err := filepath.Abs(dir)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
} }
// Resolve symlinks
resolvedPath, err := filepath.EvalSymlinks(absPath)
if err != nil {
// If the path doesn't exist yet, use the absolute path
if os.IsNotExist(err) {
resolvedPath = absPath
} else {
return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err)
}
}
resolvedDirs = append(resolvedDirs, resolvedPath)
}
// Create scanner with progress enabled (unless in cron mode)
// Pass the combined excludes for this snapshot
scanner := v.ScannerFactory(snapshot.ScannerParams{ scanner := v.ScannerFactory(snapshot.ScannerParams{
EnableProgress: !opts.Cron, EnableProgress: !opts.Cron,
Fs: v.Fs, Fs: v.Fs,
@@ -146,6 +133,20 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
SkipErrors: opts.SkipErrors, SkipErrors: opts.SkipErrors,
}) })
// Statistics tracking
totalFiles := 0
totalBytes := int64(0)
totalChunks := 0
totalBlobs := 0
totalBytesSkipped := int64(0)
totalFilesSkipped := 0
totalFilesDeleted := 0
totalBytesDeleted := int64(0)
totalBytesUploaded := int64(0)
totalBlobsUploaded := 0
uploadDuration := time.Duration(0)
// Create a new snapshot at the beginning (with snapshot name in ID)
snapshotID, err := v.SnapshotManager.CreateSnapshotWithName(v.ctx, hostname, snapName, v.Globals.Version, v.Globals.Commit) snapshotID, err := v.SnapshotManager.CreateSnapshotWithName(v.ctx, hostname, snapName, v.Globals.Version, v.Globals.Commit)
if err != nil { if err != nil {
return fmt.Errorf("creating snapshot: %w", err) return fmt.Errorf("creating snapshot: %w", err)
@@ -153,64 +154,12 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna
log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName) log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName)
v.printfStdout("Beginning snapshot: %s\n", snapshotID) v.printfStdout("Beginning snapshot: %s\n", snapshotID)
stats, err := v.scanAllDirectories(scanner, resolvedDirs, snapshotID)
if err != nil {
return err
}
v.collectUploadStats(scanner, stats)
if err := v.finalizeSnapshotMetadata(snapshotID, stats); err != nil {
return err
}
log.Info("Snapshot complete",
"snapshot_id", snapshotID,
"name", snapName,
"files", stats.totalFiles,
"blobs_uploaded", stats.totalBlobsUploaded,
"bytes_uploaded", stats.totalBytesUploaded,
"duration", time.Since(snapshotStartTime))
v.printSnapshotSummary(snapshotID, snapshotStartTime, stats)
return nil
}
// resolveSnapshotPaths resolves source directories to absolute paths with symlink resolution
func (v *Vaultik) resolveSnapshotPaths(snapName string) ([]string, error) {
snapConfig := v.Config.Snapshots[snapName]
resolvedDirs := make([]string, 0, len(snapConfig.Paths))
for _, dir := range snapConfig.Paths {
absPath, err := filepath.Abs(dir)
if err != nil {
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
}
resolvedPath, err := filepath.EvalSymlinks(absPath)
if err != nil {
if os.IsNotExist(err) {
resolvedPath = absPath
} else {
return nil, fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err)
}
}
resolvedDirs = append(resolvedDirs, resolvedPath)
}
return resolvedDirs, nil
}
// scanAllDirectories runs the scanner on each resolved directory and accumulates stats
func (v *Vaultik) scanAllDirectories(scanner *snapshot.Scanner, resolvedDirs []string, snapshotID string) (*snapshotStats, error) {
stats := &snapshotStats{}
for i, dir := range resolvedDirs { for i, dir := range resolvedDirs {
// Check if context is cancelled
select { select {
case <-v.ctx.Done(): case <-v.ctx.Done():
log.Info("Snapshot creation cancelled") log.Info("Snapshot creation cancelled")
return nil, v.ctx.Err() return v.ctx.Err()
default: default:
} }
@@ -218,17 +167,17 @@ func (v *Vaultik) scanAllDirectories(scanner *snapshot.Scanner, resolvedDirs []s
v.printfStdout("Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir) v.printfStdout("Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir)
result, err := scanner.Scan(v.ctx, dir, snapshotID) result, err := scanner.Scan(v.ctx, dir, snapshotID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to scan %s: %w", dir, err) return fmt.Errorf("failed to scan %s: %w", dir, err)
} }
stats.totalFiles += result.FilesScanned totalFiles += result.FilesScanned
stats.totalBytes += result.BytesScanned totalBytes += result.BytesScanned
stats.totalChunks += result.ChunksCreated totalChunks += result.ChunksCreated
stats.totalBlobs += result.BlobsCreated totalBlobs += result.BlobsCreated
stats.totalFilesSkipped += result.FilesSkipped totalFilesSkipped += result.FilesSkipped
stats.totalBytesSkipped += result.BytesSkipped totalBytesSkipped += result.BytesSkipped
stats.totalFilesDeleted += result.FilesDeleted totalFilesDeleted += result.FilesDeleted
stats.totalBytesDeleted += result.BytesDeleted totalBytesDeleted += result.BytesDeleted
log.Info("Directory scan complete", log.Info("Directory scan complete",
"path", dir, "path", dir,
@@ -239,79 +188,85 @@ func (v *Vaultik) scanAllDirectories(scanner *snapshot.Scanner, resolvedDirs []s
"chunks", result.ChunksCreated, "chunks", result.ChunksCreated,
"blobs", result.BlobsCreated, "blobs", result.BlobsCreated,
"duration", result.EndTime.Sub(result.StartTime)) "duration", result.EndTime.Sub(result.StartTime))
// Remove per-directory summary - the scanner already prints its own summary
} }
return stats, nil // Get upload statistics from scanner progress if available
}
// collectUploadStats gathers upload statistics from the scanner's progress reporter
func (v *Vaultik) collectUploadStats(scanner *snapshot.Scanner, stats *snapshotStats) {
if s := scanner.GetProgress(); s != nil { if s := scanner.GetProgress(); s != nil {
progressStats := s.GetStats() stats := s.GetStats()
stats.totalBytesUploaded = progressStats.BytesUploaded.Load() totalBytesUploaded = stats.BytesUploaded.Load()
stats.totalBlobsUploaded = int(progressStats.BlobsUploaded.Load()) totalBlobsUploaded = int(stats.BlobsUploaded.Load())
stats.uploadDuration = time.Duration(progressStats.UploadDurationMs.Load()) * time.Millisecond uploadDuration = time.Duration(stats.UploadDurationMs.Load()) * time.Millisecond
}
} }
// finalizeSnapshotMetadata updates stats, marks complete, and exports metadata // Update snapshot statistics with extended fields
func (v *Vaultik) finalizeSnapshotMetadata(snapshotID string, stats *snapshotStats) error {
extStats := snapshot.ExtendedBackupStats{ extStats := snapshot.ExtendedBackupStats{
BackupStats: snapshot.BackupStats{ BackupStats: snapshot.BackupStats{
FilesScanned: stats.totalFiles, FilesScanned: totalFiles,
BytesScanned: stats.totalBytes, BytesScanned: totalBytes,
ChunksCreated: stats.totalChunks, ChunksCreated: totalChunks,
BlobsCreated: stats.totalBlobs, BlobsCreated: totalBlobs,
BytesUploaded: stats.totalBytesUploaded, BytesUploaded: totalBytesUploaded,
}, },
BlobUncompressedSize: 0, BlobUncompressedSize: 0, // Will be set from database query below
CompressionLevel: v.Config.CompressionLevel, CompressionLevel: v.Config.CompressionLevel,
UploadDurationMs: stats.uploadDuration.Milliseconds(), UploadDurationMs: uploadDuration.Milliseconds(),
} }
if err := v.SnapshotManager.UpdateSnapshotStatsExtended(v.ctx, snapshotID, extStats); err != nil { if err := v.SnapshotManager.UpdateSnapshotStatsExtended(v.ctx, snapshotID, extStats); err != nil {
return fmt.Errorf("updating snapshot stats: %w", err) return fmt.Errorf("updating snapshot stats: %w", err)
} }
// Mark snapshot as complete
if err := v.SnapshotManager.CompleteSnapshot(v.ctx, snapshotID); err != nil { if err := v.SnapshotManager.CompleteSnapshot(v.ctx, snapshotID); err != nil {
return fmt.Errorf("completing snapshot: %w", err) return fmt.Errorf("completing snapshot: %w", err)
} }
// Export snapshot metadata
// Export snapshot metadata without closing the database
// The export function should handle its own database connection
if err := v.SnapshotManager.ExportSnapshotMetadata(v.ctx, v.Config.IndexPath, snapshotID); err != nil { if err := v.SnapshotManager.ExportSnapshotMetadata(v.ctx, v.Config.IndexPath, snapshotID); err != nil {
return fmt.Errorf("exporting snapshot metadata: %w", err) return fmt.Errorf("exporting snapshot metadata: %w", err)
} }
return nil // Calculate final statistics
} snapshotDuration := time.Since(snapshotStartTime)
totalFilesChanged := totalFiles - totalFilesSkipped
totalBytesChanged := totalBytes
totalBytesAll := totalBytes + totalBytesSkipped
// formatUploadSpeed formats bytes uploaded and duration into a human-readable speed string // Calculate upload speed
func formatUploadSpeed(bytesUploaded int64, duration time.Duration) string { var avgUploadSpeed string
if bytesUploaded <= 0 || duration <= 0 { if totalBytesUploaded > 0 && uploadDuration > 0 {
return "N/A" bytesPerSec := float64(totalBytesUploaded) / uploadDuration.Seconds()
}
bytesPerSec := float64(bytesUploaded) / duration.Seconds()
bitsPerSec := bytesPerSec * 8 bitsPerSec := bytesPerSec * 8
switch { if bitsPerSec >= 1e9 {
case bitsPerSec >= 1e9: avgUploadSpeed = fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9)
return fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9) } else if bitsPerSec >= 1e6 {
case bitsPerSec >= 1e6: avgUploadSpeed = fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6)
return fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6) } else if bitsPerSec >= 1e3 {
case bitsPerSec >= 1e3: avgUploadSpeed = fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3)
return fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3) } else {
default: avgUploadSpeed = fmt.Sprintf("%.0f bit/s", bitsPerSec)
return fmt.Sprintf("%.0f bit/s", bitsPerSec)
} }
} else {
avgUploadSpeed = "N/A"
} }
// printSnapshotSummary prints the comprehensive snapshot completion summary
func (v *Vaultik) printSnapshotSummary(snapshotID string, startTime time.Time, stats *snapshotStats) {
snapshotDuration := time.Since(startTime)
totalFilesChanged := stats.totalFiles - stats.totalFilesSkipped
totalBytesAll := stats.totalBytes + stats.totalBytesSkipped
// Get total blob sizes from database // Get total blob sizes from database
totalBlobSizeCompressed, totalBlobSizeUncompressed := v.getSnapshotBlobSizes(snapshotID) totalBlobSizeCompressed := int64(0)
totalBlobSizeUncompressed := int64(0)
if blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID); err == nil {
for _, hash := range blobHashes {
if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil {
totalBlobSizeCompressed += blob.CompressedSize
totalBlobSizeUncompressed += blob.UncompressedSize
}
}
}
// Calculate compression ratio
var compressionRatio float64 var compressionRatio float64
if totalBlobSizeUncompressed > 0 { if totalBlobSizeUncompressed > 0 {
compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed) compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed)
@@ -319,96 +274,60 @@ func (v *Vaultik) printSnapshotSummary(snapshotID string, startTime time.Time, s
compressionRatio = 1.0 compressionRatio = 1.0
} }
// Print comprehensive summary
v.printfStdout("=== Snapshot Complete ===\n") v.printfStdout("=== Snapshot Complete ===\n")
v.printfStdout("ID: %s\n", snapshotID) v.printfStdout("ID: %s\n", snapshotID)
v.printfStdout("Files: %s examined, %s to process, %s unchanged", v.printfStdout("Files: %s examined, %s to process, %s unchanged",
formatNumber(stats.totalFiles), formatNumber(totalFiles),
formatNumber(totalFilesChanged), formatNumber(totalFilesChanged),
formatNumber(stats.totalFilesSkipped)) formatNumber(totalFilesSkipped))
if stats.totalFilesDeleted > 0 { if totalFilesDeleted > 0 {
v.printfStdout(", %s deleted", formatNumber(stats.totalFilesDeleted)) v.printfStdout(", %s deleted", formatNumber(totalFilesDeleted))
} }
v.printlnStdout() v.printlnStdout()
v.printfStdout("Data: %s total (%s to process)", v.printfStdout("Data: %s total (%s to process)",
humanize.Bytes(uint64(totalBytesAll)), humanize.Bytes(uint64(totalBytesAll)),
humanize.Bytes(uint64(stats.totalBytes))) humanize.Bytes(uint64(totalBytesChanged)))
if stats.totalBytesDeleted > 0 { if totalBytesDeleted > 0 {
v.printfStdout(", %s deleted", humanize.Bytes(uint64(stats.totalBytesDeleted))) v.printfStdout(", %s deleted", humanize.Bytes(uint64(totalBytesDeleted)))
} }
v.printlnStdout() v.printlnStdout()
if stats.totalBlobsUploaded > 0 { if totalBlobsUploaded > 0 {
v.printfStdout("Storage: %s compressed from %s (%.2fx)\n", v.printfStdout("Storage: %s compressed from %s (%.2fx)\n",
humanize.Bytes(uint64(totalBlobSizeCompressed)), humanize.Bytes(uint64(totalBlobSizeCompressed)),
humanize.Bytes(uint64(totalBlobSizeUncompressed)), humanize.Bytes(uint64(totalBlobSizeUncompressed)),
compressionRatio) compressionRatio)
v.printfStdout("Upload: %d blobs, %s in %s (%s)\n", v.printfStdout("Upload: %d blobs, %s in %s (%s)\n",
stats.totalBlobsUploaded, totalBlobsUploaded,
humanize.Bytes(uint64(stats.totalBytesUploaded)), humanize.Bytes(uint64(totalBytesUploaded)),
formatDuration(stats.uploadDuration), formatDuration(uploadDuration),
formatUploadSpeed(stats.totalBytesUploaded, stats.uploadDuration)) avgUploadSpeed)
} }
v.printfStdout("Duration: %s\n", formatDuration(snapshotDuration)) v.printfStdout("Duration: %s\n", formatDuration(snapshotDuration))
if opts.Prune {
log.Info("Pruning enabled - will delete old snapshots after snapshot")
// TODO: Implement pruning
} }
// getSnapshotBlobSizes returns total compressed and uncompressed blob sizes for a snapshot return nil
func (v *Vaultik) getSnapshotBlobSizes(snapshotID string) (compressed int64, uncompressed int64) {
blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID)
if err != nil {
return 0, 0
}
for _, hash := range blobHashes {
if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil {
compressed += blob.CompressedSize
uncompressed += blob.UncompressedSize
}
}
return compressed, uncompressed
} }
// ListSnapshots lists all snapshots // ListSnapshots lists all snapshots
func (v *Vaultik) ListSnapshots(jsonOutput bool) error { func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
log.Info("Listing snapshots") // Get all remote snapshots
remoteSnapshots, err := v.listRemoteSnapshotIDs()
if err != nil {
return err
}
localSnapshotMap, err := v.reconcileLocalWithRemote(remoteSnapshots)
if err != nil {
return err
}
snapshots, err := v.buildSnapshotInfoList(remoteSnapshots, localSnapshotMap)
if err != nil {
return err
}
// Sort by timestamp (newest first)
sort.Slice(snapshots, func(i, j int) bool {
return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
})
if jsonOutput {
encoder := json.NewEncoder(v.Stdout)
encoder.SetIndent("", " ")
return encoder.Encode(snapshots)
}
return v.printSnapshotTable(snapshots)
}
// listRemoteSnapshotIDs returns a set of snapshot IDs found in remote storage
func (v *Vaultik) listRemoteSnapshotIDs() (map[string]bool, error) {
remoteSnapshots := make(map[string]bool) remoteSnapshots := make(map[string]bool)
objectCh := v.Storage.ListStream(v.ctx, "metadata/") objectCh := v.Storage.ListStream(v.ctx, "metadata/")
for object := range objectCh { for object := range objectCh {
if object.Err != nil { if object.Err != nil {
return nil, fmt.Errorf("listing remote snapshots: %w", object.Err) return fmt.Errorf("listing remote snapshots: %w", object.Err)
} }
// Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/
parts := strings.Split(object.Key, "/") parts := strings.Split(object.Key, "/")
if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" {
// Skip macOS resource fork files (._*) and other hidden files
if strings.HasPrefix(parts[1], ".") { if strings.HasPrefix(parts[1], ".") {
continue continue
} }
@@ -416,46 +335,56 @@ func (v *Vaultik) listRemoteSnapshotIDs() (map[string]bool, error) {
} }
} }
return remoteSnapshots, nil // Get all local snapshots
}
// reconcileLocalWithRemote removes local snapshots not in remote and returns the surviving local map
func (v *Vaultik) reconcileLocalWithRemote(remoteSnapshots map[string]bool) (map[string]*database.Snapshot, error) {
localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000) localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000)
if err != nil { if err != nil {
return nil, fmt.Errorf("listing local snapshots: %w", err) return fmt.Errorf("listing local snapshots: %w", err)
} }
// Build a map of local snapshots for quick lookup
localSnapshotMap := make(map[string]*database.Snapshot) localSnapshotMap := make(map[string]*database.Snapshot)
for _, s := range localSnapshots { for _, s := range localSnapshots {
localSnapshotMap[s.ID.String()] = s localSnapshotMap[s.ID.String()] = s
} }
for _, snap := range localSnapshots { // Remove local snapshots that don't exist remotely
snapshotIDStr := snap.ID.String() for _, snapshot := range localSnapshots {
snapshotIDStr := snapshot.ID.String()
if !remoteSnapshots[snapshotIDStr] { if !remoteSnapshots[snapshotIDStr] {
log.Info("Removing local snapshot not found in remote", "snapshot_id", snap.ID) log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID)
if err := v.deleteSnapshotFromLocalDB(snapshotIDStr); err != nil {
log.Error("Failed to delete local snapshot", "snapshot_id", snap.ID, "error", err) // Delete related records first to avoid foreign key constraints
if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err)
}
if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err)
}
if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err)
}
// Now delete the snapshot itself
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err)
} else { } else {
log.Info("Deleted local snapshot not found in remote", "snapshot_id", snap.ID) log.Info("Deleted local snapshot not found in remote", "snapshot_id", snapshot.ID)
delete(localSnapshotMap, snapshotIDStr) delete(localSnapshotMap, snapshotIDStr)
} }
} }
} }
return localSnapshotMap, nil // Build final snapshot list
}
// buildSnapshotInfoList constructs SnapshotInfo entries from remote IDs and local data
func (v *Vaultik) buildSnapshotInfoList(remoteSnapshots map[string]bool, localSnapshotMap map[string]*database.Snapshot) ([]SnapshotInfo, error) {
snapshots := make([]SnapshotInfo, 0, len(remoteSnapshots)) snapshots := make([]SnapshotInfo, 0, len(remoteSnapshots))
for snapshotID := range remoteSnapshots { for snapshotID := range remoteSnapshots {
// Check if we have this snapshot locally
if localSnap, exists := localSnapshotMap[snapshotID]; exists && localSnap.CompletedAt != nil { if localSnap, exists := localSnapshotMap[snapshotID]; exists && localSnap.CompletedAt != nil {
// Get total compressed size of all blobs referenced by this snapshot
totalSize, err := v.Repositories.Snapshots.GetSnapshotTotalCompressedSize(v.ctx, snapshotID) totalSize, err := v.Repositories.Snapshots.GetSnapshotTotalCompressedSize(v.ctx, snapshotID)
if err != nil { if err != nil {
log.Warn("Failed to get total compressed size", "id", snapshotID, "error", err) log.Warn("Failed to get total compressed size", "id", snapshotID, "error", err)
// Fall back to stored blob size
totalSize = localSnap.BlobSize totalSize = localSnap.BlobSize
} }
@@ -465,15 +394,17 @@ func (v *Vaultik) buildSnapshotInfoList(remoteSnapshots map[string]bool, localSn
CompressedSize: totalSize, CompressedSize: totalSize,
}) })
} else { } else {
// Remote snapshot not in local DB - fetch manifest to get size
timestamp, err := parseSnapshotTimestamp(snapshotID) timestamp, err := parseSnapshotTimestamp(snapshotID)
if err != nil { if err != nil {
log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err) log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err)
continue continue
} }
// Try to download manifest to get size
totalSize, err := v.getManifestSize(snapshotID) totalSize, err := v.getManifestSize(snapshotID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err) return fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err)
} }
snapshots = append(snapshots, SnapshotInfo{ snapshots = append(snapshots, SnapshotInfo{
@@ -484,13 +415,22 @@ func (v *Vaultik) buildSnapshotInfoList(remoteSnapshots map[string]bool, localSn
} }
} }
return snapshots, nil // Sort by timestamp (newest first)
sort.Slice(snapshots, func(i, j int) bool {
return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
})
if jsonOutput {
// JSON output
encoder := json.NewEncoder(v.Stdout)
encoder.SetIndent("", " ")
return encoder.Encode(snapshots)
} }
// printSnapshotTable renders the snapshot list as a formatted table // Table output
func (v *Vaultik) printSnapshotTable(snapshots []SnapshotInfo) error {
w := tabwriter.NewWriter(v.Stdout, 0, 0, 3, ' ', 0) w := tabwriter.NewWriter(v.Stdout, 0, 0, 3, ' ', 0)
// Show configured snapshots from config file
if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil { if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil {
return err return err
} }
@@ -511,6 +451,7 @@ func (v *Vaultik) printSnapshotTable(snapshots []SnapshotInfo) error {
return err return err
} }
// Show remote snapshots
if _, err := fmt.Fprintln(w, "REMOTE SNAPSHOTS:"); err != nil { if _, err := fmt.Fprintln(w, "REMOTE SNAPSHOTS:"); err != nil {
return err return err
} }
@@ -563,9 +504,26 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
return snapshots[i].Timestamp.After(snapshots[j].Timestamp) return snapshots[i].Timestamp.After(snapshots[j].Timestamp)
}) })
toDelete, err := v.collectSnapshotsToPurge(snapshots, keepLatest, olderThan) var toDelete []SnapshotInfo
if keepLatest {
// Keep only the most recent snapshot
if len(snapshots) > 1 {
toDelete = snapshots[1:]
}
} else if olderThan != "" {
// Parse duration
duration, err := parseDuration(olderThan)
if err != nil { if err != nil {
return err return fmt.Errorf("invalid duration: %w", err)
}
cutoff := time.Now().UTC().Add(-duration)
for _, snap := range snapshots {
if snap.Timestamp.Before(cutoff) {
toDelete = append(toDelete, snap)
}
}
} }
if len(toDelete) == 0 { if len(toDelete) == 0 {
@@ -573,41 +531,6 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
return nil return nil
} }
return v.confirmAndExecutePurge(toDelete, force)
}
// collectSnapshotsToPurge selects snapshots for deletion according to the
// retention criteria. With keepLatest set, every snapshot except the newest
// is selected. With a non-empty olderThan duration, only snapshots whose
// timestamp falls before now-duration are selected. If neither criterion is
// given, nothing is selected. snapshots is assumed to be sorted newest-first.
func (v *Vaultik) collectSnapshotsToPurge(snapshots []SnapshotInfo, keepLatest bool, olderThan string) ([]SnapshotInfo, error) {
	if keepLatest {
		// Keep only the most recent snapshot; everything after index 0 goes.
		if len(snapshots) <= 1 {
			return nil, nil
		}
		return snapshots[1:], nil
	}
	if olderThan == "" {
		return nil, nil
	}
	duration, err := parseDuration(olderThan)
	if err != nil {
		return nil, fmt.Errorf("invalid duration: %w", err)
	}
	cutoff := time.Now().UTC().Add(-duration)
	var expired []SnapshotInfo
	for _, snap := range snapshots {
		if snap.Timestamp.Before(cutoff) {
			expired = append(expired, snap)
		}
	}
	return expired, nil
}
// confirmAndExecutePurge shows deletion candidates, confirms with user, and deletes snapshots
func (v *Vaultik) confirmAndExecutePurge(toDelete []SnapshotInfo, force bool) error {
// Show what will be deleted // Show what will be deleted
v.printfStdout("The following snapshots will be deleted:\n\n") v.printfStdout("The following snapshots will be deleted:\n\n")
for _, snap := range toDelete { for _, snap := range toDelete {
@@ -673,7 +596,29 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
result.Mode = "deep" result.Mode = "deep"
} }
v.printVerifyHeader(snapshotID, opts) // Parse snapshot ID to extract timestamp
parts := strings.Split(snapshotID, "-")
var snapshotTime time.Time
if len(parts) >= 3 {
// Format: hostname-YYYYMMDD-HHMMSSZ
dateStr := parts[len(parts)-2]
timeStr := parts[len(parts)-1]
if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") {
timeStr = timeStr[:6] // Remove Z
timestamp, err := time.Parse("20060102150405", dateStr+timeStr)
if err == nil {
snapshotTime = timestamp
}
}
}
if !opts.JSON {
v.printfStdout("Verifying snapshot %s\n", snapshotID)
if !snapshotTime.IsZero() {
v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST"))
}
v.printlnStdout()
}
// Download and parse manifest // Download and parse manifest
manifest, err := v.downloadManifest(snapshotID) manifest, err := v.downloadManifest(snapshotID)
@@ -704,40 +649,10 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio
v.printfStdout("Checking blob existence...\n") v.printfStdout("Checking blob existence...\n")
} }
result.Verified, result.Missing, result.MissingSize = v.verifyManifestBlobsExist(manifest, opts) missing := 0
verified := 0
missingSize := int64(0)
return v.formatVerifyResult(result, manifest, opts)
}
// printVerifyHeader prints the snapshot ID being verified and, when it can
// be recovered from the ID, the snapshot timestamp. Output is suppressed
// entirely in JSON mode.
func (v *Vaultik) printVerifyHeader(snapshotID string, opts *VerifyOptions) {
	if opts.JSON {
		return
	}
	// Snapshot IDs look like hostname-YYYYMMDD-HHMMSSZ; try to parse the
	// timestamp out of the last two dash-separated fields. Parse failures
	// are non-fatal: we simply omit the timestamp line.
	var snapshotTime time.Time
	if parts := strings.Split(snapshotID, "-"); len(parts) >= 3 {
		dateStr := parts[len(parts)-2]
		timeStr := parts[len(parts)-1]
		if len(dateStr) == 8 && len(timeStr) == 7 && strings.HasSuffix(timeStr, "Z") {
			// Drop the trailing Z before parsing.
			if ts, err := time.Parse("20060102150405", dateStr+timeStr[:6]); err == nil {
				snapshotTime = ts
			}
		}
	}
	v.printfStdout("Verifying snapshot %s\n", snapshotID)
	if !snapshotTime.IsZero() {
		v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST"))
	}
	v.printlnStdout()
}
// verifyManifestBlobsExist checks that each blob in the manifest exists in storage
func (v *Vaultik) verifyManifestBlobsExist(manifest *snapshot.Manifest, opts *VerifyOptions) (verified, missing int, missingSize int64) {
for _, blob := range manifest.Blobs { for _, blob := range manifest.Blobs {
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash) blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash)
@@ -753,15 +668,15 @@ func (v *Vaultik) verifyManifestBlobsExist(manifest *snapshot.Manifest, opts *Ve
verified++ verified++
} }
} }
return verified, missing, missingSize
}
// formatVerifyResult outputs the final verification results as JSON or human-readable text result.Verified = verified
func (v *Vaultik) formatVerifyResult(result *VerifyResult, manifest *snapshot.Manifest, opts *VerifyOptions) error { result.Missing = missing
result.MissingSize = missingSize
if opts.JSON { if opts.JSON {
if result.Missing > 0 { if missing > 0 {
result.Status = "failed" result.Status = "failed"
result.ErrorMessage = fmt.Sprintf("%d blobs are missing", result.Missing) result.ErrorMessage = fmt.Sprintf("%d blobs are missing", missing)
} else { } else {
result.Status = "ok" result.Status = "ok"
} }
@@ -769,19 +684,20 @@ func (v *Vaultik) formatVerifyResult(result *VerifyResult, manifest *snapshot.Ma
} }
v.printfStdout("\nVerification complete:\n") v.printfStdout("\nVerification complete:\n")
v.printfStdout(" Verified: %d blobs (%s)\n", result.Verified, v.printfStdout(" Verified: %d blobs (%s)\n", verified,
humanize.Bytes(uint64(manifest.TotalCompressedSize-result.MissingSize))) humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize)))
if result.Missing > 0 { if missing > 0 {
v.printfStdout(" Missing: %d blobs (%s)\n", result.Missing, humanize.Bytes(uint64(result.MissingSize))) v.printfStdout(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize)))
} else { } else {
v.printfStdout(" Missing: 0 blobs\n") v.printfStdout(" Missing: 0 blobs\n")
} }
v.printfStdout(" Status: ") v.printfStdout(" Status: ")
if result.Missing > 0 { if missing > 0 {
v.printfStdout("FAILED - %d blobs are missing\n", result.Missing) v.printfStdout("FAILED - %d blobs are missing\n", missing)
return fmt.Errorf("%d blobs are missing", result.Missing) return fmt.Errorf("%d blobs are missing", missing)
} } else {
v.printfStdout("OK - All blobs verified\n") v.printfStdout("OK - All blobs verified\n")
}
return nil return nil
} }
@@ -977,27 +893,9 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov
// RemoveAllSnapshots removes all snapshots from local database and optionally from remote // RemoveAllSnapshots removes all snapshots from local database and optionally from remote
func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) { func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) {
snapshotIDs, err := v.listAllRemoteSnapshotIDs() result := &RemoveResult{}
if err != nil {
return nil, err
}
if len(snapshotIDs) == 0 { // List all snapshots
if !opts.JSON {
v.printlnStdout("No snapshots found")
}
return &RemoveResult{}, nil
}
if opts.DryRun {
return v.handleRemoveAllDryRun(snapshotIDs, opts)
}
return v.executeRemoveAll(snapshotIDs, opts)
}
// listAllRemoteSnapshotIDs collects all unique snapshot IDs from remote storage
func (v *Vaultik) listAllRemoteSnapshotIDs() ([]string, error) {
log.Info("Listing all snapshots") log.Info("Listing all snapshots")
objectCh := v.Storage.ListStream(v.ctx, "metadata/") objectCh := v.Storage.ListStream(v.ctx, "metadata/")
@@ -1029,15 +927,16 @@ func (v *Vaultik) listAllRemoteSnapshotIDs() ([]string, error) {
} }
} }
return snapshotIDs, nil if len(snapshotIDs) == 0 {
if !opts.JSON {
v.printlnStdout("No snapshots found")
}
return result, nil
} }
// handleRemoveAllDryRun handles the dry-run mode for removing all snapshots if opts.DryRun {
func (v *Vaultik) handleRemoveAllDryRun(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) { result.DryRun = true
result := &RemoveResult{ result.SnapshotsRemoved = snapshotIDs
DryRun: true,
SnapshotsRemoved: snapshotIDs,
}
if !opts.JSON { if !opts.JSON {
v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs)) v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs))
for _, id := range snapshotIDs { for _, id := range snapshotIDs {
@@ -1054,8 +953,6 @@ func (v *Vaultik) handleRemoveAllDryRun(snapshotIDs []string, opts *RemoveOption
return result, nil return result, nil
} }
// executeRemoveAll removes all snapshots from local database and optionally from remote storage
func (v *Vaultik) executeRemoveAll(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) {
// --all requires --force // --all requires --force
if !opts.Force { if !opts.Force {
return nil, fmt.Errorf("--all requires --force") return nil, fmt.Errorf("--all requires --force")
@@ -1063,7 +960,6 @@ func (v *Vaultik) executeRemoveAll(snapshotIDs []string, opts *RemoveOptions) (*
log.Info("Removing all snapshots", "count", len(snapshotIDs)) log.Info("Removing all snapshots", "count", len(snapshotIDs))
result := &RemoveResult{}
for _, snapshotID := range snapshotIDs { for _, snapshotID := range snapshotIDs {
log.Info("Removing snapshot", "snapshot_id", snapshotID) log.Info("Removing snapshot", "snapshot_id", snapshotID)
@@ -1107,16 +1003,16 @@ func (v *Vaultik) deleteSnapshotFromLocalDB(snapshotID string) error {
// Delete related records first to avoid foreign key constraints // Delete related records first to avoid foreign key constraints
if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotID); err != nil { if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotID); err != nil {
return fmt.Errorf("deleting snapshot files for %s: %w", snapshotID, err) log.Error("Failed to delete snapshot files", "snapshot_id", snapshotID, "error", err)
} }
if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotID); err != nil { if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotID); err != nil {
return fmt.Errorf("deleting snapshot blobs for %s: %w", snapshotID, err) log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshotID, "error", err)
} }
if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotID); err != nil { if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotID); err != nil {
return fmt.Errorf("deleting snapshot uploads for %s: %w", snapshotID, err) log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshotID, "error", err)
} }
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotID); err != nil { if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotID); err != nil {
return fmt.Errorf("deleting snapshot record %s: %w", snapshotID, err) log.Error("Failed to delete snapshot record", "snapshot_id", snapshotID, "error", err)
} }
return nil return nil
@@ -1230,18 +1126,25 @@ func (v *Vaultik) PruneDatabase() (*PruneResult, error) {
return result, nil return result, nil
} }
// validTableNameRe matches table names containing only lowercase alphanumeric characters and underscores. // allowedTableNames is the exhaustive whitelist of table names that may be
var validTableNameRe = regexp.MustCompile(`^[a-z0-9_]+$`) // passed to getTableCount. Any name not in this set is rejected, preventing
// SQL injection even if caller-controlled input is accidentally supplied.
// getTableCount returns the count of rows in a table. var allowedTableNames = map[string]struct{}{
// The tableName is sanitized to only allow [a-z0-9_] characters to prevent SQL injection. "files": {},
func (v *Vaultik) getTableCount(tableName string) (int64, error) { "chunks": {},
if v.DB == nil { "blobs": {},
return 0, nil
} }
if !validTableNameRe.MatchString(tableName) { // getTableCount returns the number of rows in the given table.
return 0, fmt.Errorf("invalid table name: %q", tableName) // tableName must appear in the allowedTableNames whitelist; all other values
// are rejected with an error, preventing SQL injection.
func (v *Vaultik) getTableCount(tableName string) (int64, error) {
if _, ok := allowedTableNames[tableName]; !ok {
return 0, fmt.Errorf("table name not allowed: %q", tableName)
}
if v.DB == nil {
return 0, nil
} }
var count int64 var count int64

View File

@@ -1,23 +0,0 @@
package vaultik
import (
"testing"
)
// TestSnapshotCreateOptions_PruneFlag verifies that the Prune field exists
// on SnapshotCreateOptions and holds both true and false as assigned.
func TestSnapshotCreateOptions_PruneFlag(t *testing.T) {
	enabled := &SnapshotCreateOptions{Prune: true}
	if !enabled.Prune {
		t.Error("Expected Prune to be true")
	}

	disabled := &SnapshotCreateOptions{Prune: false}
	if disabled.Prune {
		t.Error("Expected Prune to be false")
	}
}

View File

@@ -0,0 +1,51 @@
package vaultik
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestAllowedTableNames asserts that the whitelist holds exactly the three
// expected table names and nothing more.
func TestAllowedTableNames(t *testing.T) {
	want := []string{"files", "chunks", "blobs"}
	assert.Len(t, allowedTableNames, len(want))
	for _, table := range want {
		_, present := allowedTableNames[table]
		assert.True(t, present, "expected %q in allowedTableNames", table)
	}
}
// TestGetTableCount_RejectsInvalidNames verifies that getTableCount enforces
// the table-name whitelist before any database access: whitelisted names
// succeed (returning 0 because DB is nil), while all other names — including
// injection attempts and case/whitespace variants — are rejected with a
// "not allowed" error.
func TestGetTableCount_RejectsInvalidNames(t *testing.T) {
	// Zero-value Vaultik already has DB == nil; rejection must happen before
	// the DB is touched, so no database setup is required.
	v := &Vaultik{}

	tests := []struct {
		name      string
		tableName string
		wantErr   bool
	}{
		{"allowed files", "files", false},
		{"allowed chunks", "chunks", false},
		{"allowed blobs", "blobs", false},
		{"sql injection attempt", "files; DROP TABLE files--", true},
		{"unknown table", "users", true},
		{"empty string", "", true},
		{"uppercase", "FILES", true},
		{"leading whitespace", " files", true},
		{"trailing whitespace", "files ", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			count, err := v.getTableCount(tt.tableName)
			if tt.wantErr {
				// Guard the err.Error() call: if err were unexpectedly nil,
				// the test should fail cleanly rather than panic.
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), "not allowed")
				}
				assert.Equal(t, int64(0), count)
			} else {
				// DB is nil, so allowed names take the early return of (0, nil).
				assert.NoError(t, err)
				assert.Equal(t, int64(0), count)
			}
		})
	}
}

View File

@@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"hash"
"io" "io"
"os" "os"
"time" "time"
@@ -36,19 +35,6 @@ type VerifyResult struct {
ErrorMessage string `json:"error,omitempty"` ErrorMessage string `json:"error,omitempty"`
} }
// deepVerifyFailure marks result as failed with msg and returns the value the
// caller should propagate: the JSON-encoded result when JSON output was
// requested, the supplied err when non-nil, otherwise an error built from msg.
func (v *Vaultik) deepVerifyFailure(result *VerifyResult, opts *VerifyOptions, msg string, err error) error {
	result.Status = "failed"
	result.ErrorMessage = msg
	switch {
	case opts.JSON:
		return v.outputVerifyJSON(result)
	case err != nil:
		return err
	default:
		return fmt.Errorf("%s", msg)
	}
}
// RunDeepVerify executes deep verification operation // RunDeepVerify executes deep verification operation
func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
result := &VerifyResult{ result := &VerifyResult{
@@ -56,20 +42,89 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
Mode: "deep", Mode: "deep",
} }
// Check for decryption capability
if !v.CanDecrypt() { if !v.CanDecrypt() {
return v.deepVerifyFailure(result, opts, result.Status = "failed"
"VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification", result.ErrorMessage = "VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification"
fmt.Errorf("VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification")) if opts.JSON {
return v.outputVerifyJSON(result)
}
return fmt.Errorf("VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification")
} }
log.Info("Starting snapshot verification", "snapshot_id", snapshotID, "mode", "deep") log.Info("Starting snapshot verification",
"snapshot_id", snapshotID,
"mode", "deep",
)
if !opts.JSON { if !opts.JSON {
v.printfStdout("Deep verification of snapshot: %s\n\n", snapshotID) v.printfStdout("Deep verification of snapshot: %s\n\n", snapshotID)
} }
manifest, tempDB, dbBlobs, err := v.loadVerificationData(snapshotID, opts, result) // Step 1: Download manifest
manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
log.Info("Downloading manifest", "path", manifestPath)
if !opts.JSON {
v.printfStdout("Downloading manifest...\n")
}
manifestReader, err := v.Storage.Get(v.ctx, manifestPath)
if err != nil { if err != nil {
return err result.Status = "failed"
result.ErrorMessage = fmt.Sprintf("failed to download manifest: %v", err)
if opts.JSON {
return v.outputVerifyJSON(result)
}
return fmt.Errorf("failed to download manifest: %w", err)
}
defer func() { _ = manifestReader.Close() }()
// Decompress manifest
manifest, err := snapshot.DecodeManifest(manifestReader)
if err != nil {
result.Status = "failed"
result.ErrorMessage = fmt.Sprintf("failed to decode manifest: %v", err)
if opts.JSON {
return v.outputVerifyJSON(result)
}
return fmt.Errorf("failed to decode manifest: %w", err)
}
log.Info("Manifest loaded",
"manifest_blob_count", manifest.BlobCount,
"manifest_total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize)),
)
if !opts.JSON {
v.printfStdout("Manifest loaded: %d blobs (%s)\n", manifest.BlobCount, humanize.Bytes(uint64(manifest.TotalCompressedSize)))
}
// Step 2: Download and decrypt database (authoritative source)
dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
log.Info("Downloading encrypted database", "path", dbPath)
if !opts.JSON {
v.printfStdout("Downloading and decrypting database...\n")
}
dbReader, err := v.Storage.Get(v.ctx, dbPath)
if err != nil {
result.Status = "failed"
result.ErrorMessage = fmt.Sprintf("failed to download database: %v", err)
if opts.JSON {
return v.outputVerifyJSON(result)
}
return fmt.Errorf("failed to download database: %w", err)
}
defer func() { _ = dbReader.Close() }()
// Decrypt and decompress database
tempDB, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey)
if err != nil {
result.Status = "failed"
result.ErrorMessage = fmt.Sprintf("failed to decrypt database: %v", err)
if opts.JSON {
return v.outputVerifyJSON(result)
}
return fmt.Errorf("failed to decrypt database: %w", err)
} }
defer func() { defer func() {
if tempDB != nil { if tempDB != nil {
@@ -77,6 +132,17 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
} }
}() }()
// Step 3: Get authoritative blob list from database
dbBlobs, err := v.getBlobsFromDatabase(snapshotID, tempDB.DB)
if err != nil {
result.Status = "failed"
result.ErrorMessage = fmt.Sprintf("failed to get blobs from database: %v", err)
if opts.JSON {
return v.outputVerifyJSON(result)
}
return fmt.Errorf("failed to get blobs from database: %w", err)
}
result.BlobCount = len(dbBlobs) result.BlobCount = len(dbBlobs)
var totalSize int64 var totalSize int64
for _, blob := range dbBlobs { for _, blob := range dbBlobs {
@@ -84,10 +150,54 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
} }
result.TotalSize = totalSize result.TotalSize = totalSize
if err := v.runVerificationSteps(manifest, dbBlobs, tempDB, opts, result, totalSize); err != nil { log.Info("Database loaded",
"db_blob_count", len(dbBlobs),
"db_total_size", humanize.Bytes(uint64(totalSize)),
)
if !opts.JSON {
v.printfStdout("Database loaded: %d blobs (%s)\n", len(dbBlobs), humanize.Bytes(uint64(totalSize)))
v.printfStdout("Verifying manifest against database...\n")
}
// Step 4: Verify manifest matches database
if err := v.verifyManifestAgainstDatabase(manifest, dbBlobs); err != nil {
result.Status = "failed"
result.ErrorMessage = err.Error()
if opts.JSON {
return v.outputVerifyJSON(result)
}
return err return err
} }
// Step 5: Verify all blobs exist in S3 (using database as source)
if !opts.JSON {
v.printfStdout("Manifest verified.\n")
v.printfStdout("Checking blob existence in remote storage...\n")
}
if err := v.verifyBlobExistenceFromDB(dbBlobs); err != nil {
result.Status = "failed"
result.ErrorMessage = err.Error()
if opts.JSON {
return v.outputVerifyJSON(result)
}
return err
}
// Step 6: Deep verification - download and verify blob contents
if !opts.JSON {
v.printfStdout("All blobs exist.\n")
v.printfStdout("Downloading and verifying blob contents (%d blobs, %s)...\n", len(dbBlobs), humanize.Bytes(uint64(totalSize)))
}
if err := v.performDeepVerificationFromDB(dbBlobs, tempDB.DB, opts); err != nil {
result.Status = "failed"
result.ErrorMessage = err.Error()
if opts.JSON {
return v.outputVerifyJSON(result)
}
return err
}
// Success
result.Status = "ok" result.Status = "ok"
result.Verified = len(dbBlobs) result.Verified = len(dbBlobs)
@@ -96,7 +206,11 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
} }
log.Info("✓ Verification completed successfully", log.Info("✓ Verification completed successfully",
"snapshot_id", snapshotID, "mode", "deep", "blobs_verified", len(dbBlobs)) "snapshot_id", snapshotID,
"mode", "deep",
"blobs_verified", len(dbBlobs),
)
v.printfStdout("\n✓ Verification completed successfully\n") v.printfStdout("\n✓ Verification completed successfully\n")
v.printfStdout(" Snapshot: %s\n", snapshotID) v.printfStdout(" Snapshot: %s\n", snapshotID)
v.printfStdout(" Blobs verified: %d\n", len(dbBlobs)) v.printfStdout(" Blobs verified: %d\n", len(dbBlobs))
@@ -105,106 +219,6 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
return nil return nil
} }
// loadVerificationData downloads manifest, database, and blob list for verification
func (v *Vaultik) loadVerificationData(snapshotID string, opts *VerifyOptions, result *VerifyResult) (*snapshot.Manifest, *tempDB, []snapshot.BlobInfo, error) {
// Download manifest
manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
log.Info("Downloading manifest", "path", manifestPath)
if !opts.JSON {
v.printfStdout("Downloading manifest...\n")
}
manifestReader, err := v.Storage.Get(v.ctx, manifestPath)
if err != nil {
return nil, nil, nil, v.deepVerifyFailure(result, opts,
fmt.Sprintf("failed to download manifest: %v", err),
fmt.Errorf("failed to download manifest: %w", err))
}
defer func() { _ = manifestReader.Close() }()
manifest, err := snapshot.DecodeManifest(manifestReader)
if err != nil {
return nil, nil, nil, v.deepVerifyFailure(result, opts,
fmt.Sprintf("failed to decode manifest: %v", err),
fmt.Errorf("failed to decode manifest: %w", err))
}
log.Info("Manifest loaded",
"manifest_blob_count", manifest.BlobCount,
"manifest_total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize)))
if !opts.JSON {
v.printfStdout("Manifest loaded: %d blobs (%s)\n", manifest.BlobCount, humanize.Bytes(uint64(manifest.TotalCompressedSize)))
v.printfStdout("Downloading and decrypting database...\n")
}
// Download and decrypt database
dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
log.Info("Downloading encrypted database", "path", dbPath)
dbReader, err := v.Storage.Get(v.ctx, dbPath)
if err != nil {
return nil, nil, nil, v.deepVerifyFailure(result, opts,
fmt.Sprintf("failed to download database: %v", err),
fmt.Errorf("failed to download database: %w", err))
}
defer func() { _ = dbReader.Close() }()
tdb, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey)
if err != nil {
return nil, nil, nil, v.deepVerifyFailure(result, opts,
fmt.Sprintf("failed to decrypt database: %v", err),
fmt.Errorf("failed to decrypt database: %w", err))
}
dbBlobs, err := v.getBlobsFromDatabase(snapshotID, tdb.DB)
if err != nil {
_ = tdb.Close()
return nil, nil, nil, v.deepVerifyFailure(result, opts,
fmt.Sprintf("failed to get blobs from database: %v", err),
fmt.Errorf("failed to get blobs from database: %w", err))
}
var dbTotalSize int64
for _, b := range dbBlobs {
dbTotalSize += b.CompressedSize
}
log.Info("Database loaded",
"db_blob_count", len(dbBlobs),
"db_total_size", humanize.Bytes(uint64(dbTotalSize)))
if !opts.JSON {
v.printfStdout("Database loaded: %d blobs (%s)\n", len(dbBlobs), humanize.Bytes(uint64(dbTotalSize)))
}
return manifest, tdb, dbBlobs, nil
}
// runVerificationSteps executes manifest verification, blob existence check, and deep content verification
func (v *Vaultik) runVerificationSteps(manifest *snapshot.Manifest, dbBlobs []snapshot.BlobInfo, tdb *tempDB, opts *VerifyOptions, result *VerifyResult, totalSize int64) error {
if !opts.JSON {
v.printfStdout("Verifying manifest against database...\n")
}
if err := v.verifyManifestAgainstDatabase(manifest, dbBlobs); err != nil {
return v.deepVerifyFailure(result, opts, err.Error(), err)
}
if !opts.JSON {
v.printfStdout("Manifest verified.\n")
v.printfStdout("Checking blob existence in remote storage...\n")
}
if err := v.verifyBlobExistenceFromDB(dbBlobs); err != nil {
return v.deepVerifyFailure(result, opts, err.Error(), err)
}
if !opts.JSON {
v.printfStdout("All blobs exist.\n")
v.printfStdout("Downloading and verifying blob contents (%d blobs, %s)...\n", len(dbBlobs), humanize.Bytes(uint64(totalSize)))
}
if err := v.performDeepVerificationFromDB(dbBlobs, tdb.DB, opts); err != nil {
return v.deepVerifyFailure(result, opts, err.Error(), err)
}
return nil
}
// tempDB wraps sql.DB with cleanup // tempDB wraps sql.DB with cleanup
type tempDB struct { type tempDB struct {
*sql.DB *sql.DB
@@ -302,27 +316,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error {
} }
defer decompressor.Close() defer decompressor.Close()
chunkCount, err := v.verifyBlobChunks(db, blobInfo.Hash, decompressor) // Query blob chunks from database to get offsets and lengths
if err != nil {
return err
}
if err := v.verifyBlobFinalIntegrity(decompressor, blobHasher, blobInfo.Hash); err != nil {
return err
}
log.Info("Blob verified",
"hash", blobInfo.Hash[:16]+"...",
"chunks", chunkCount,
"size", humanize.Bytes(uint64(blobInfo.CompressedSize)),
)
return nil
}
// verifyBlobChunks queries blob chunks from the database and verifies each chunk's hash
// against the decompressed blob stream
func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.Reader) (int, error) {
query := ` query := `
SELECT bc.chunk_hash, bc.offset, bc.length SELECT bc.chunk_hash, bc.offset, bc.length
FROM blob_chunks bc FROM blob_chunks bc
@@ -330,9 +324,9 @@ func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.
WHERE b.blob_hash = ? WHERE b.blob_hash = ?
ORDER BY bc.offset ORDER BY bc.offset
` `
rows, err := db.QueryContext(v.ctx, query, blobHash) rows, err := db.QueryContext(v.ctx, query, blobInfo.Hash)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to query blob chunks: %w", err) return fmt.Errorf("failed to query blob chunks: %w", err)
} }
defer func() { _ = rows.Close() }() defer func() { _ = rows.Close() }()
@@ -345,12 +339,12 @@ func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.
var chunkHash string var chunkHash string
var offset, length int64 var offset, length int64
if err := rows.Scan(&chunkHash, &offset, &length); err != nil { if err := rows.Scan(&chunkHash, &offset, &length); err != nil {
return 0, fmt.Errorf("failed to scan chunk row: %w", err) return fmt.Errorf("failed to scan chunk row: %w", err)
} }
// Verify chunk ordering // Verify chunk ordering
if offset <= lastOffset { if offset <= lastOffset {
return 0, fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) return fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset)
} }
lastOffset = offset lastOffset = offset
@@ -359,7 +353,7 @@ func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.
// Skip to the correct offset // Skip to the correct offset
skipBytes := offset - totalRead skipBytes := offset - totalRead
if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil { if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil {
return 0, fmt.Errorf("failed to skip to offset %d: %w", offset, err) return fmt.Errorf("failed to skip to offset %d: %w", offset, err)
} }
totalRead = offset totalRead = offset
} }
@@ -367,7 +361,7 @@ func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.
// Read chunk data // Read chunk data
chunkData := make([]byte, length) chunkData := make([]byte, length)
if _, err := io.ReadFull(decompressor, chunkData); err != nil { if _, err := io.ReadFull(decompressor, chunkData); err != nil {
return 0, fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) return fmt.Errorf("failed to read chunk at offset %d: %w", offset, err)
} }
totalRead += length totalRead += length
@@ -377,7 +371,7 @@ func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.
calculatedHash := hex.EncodeToString(hasher.Sum(nil)) calculatedHash := hex.EncodeToString(hasher.Sum(nil))
if calculatedHash != chunkHash { if calculatedHash != chunkHash {
return 0, fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", return fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s",
offset, calculatedHash, chunkHash) offset, calculatedHash, chunkHash)
} }
@@ -385,15 +379,9 @@ func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return 0, fmt.Errorf("error iterating blob chunks: %w", err) return fmt.Errorf("error iterating blob chunks: %w", err)
} }
return chunkCount, nil
}
// verifyBlobFinalIntegrity checks that no trailing data exists in the decompressed stream
// and that the encrypted blob hash matches the expected value
func (v *Vaultik) verifyBlobFinalIntegrity(decompressor io.Reader, blobHasher hash.Hash, expectedHash string) error {
// Verify no remaining data in blob - if chunk list is accurate, blob should be fully consumed // Verify no remaining data in blob - if chunk list is accurate, blob should be fully consumed
remaining, err := io.Copy(io.Discard, decompressor) remaining, err := io.Copy(io.Discard, decompressor)
if err != nil { if err != nil {
@@ -405,11 +393,17 @@ func (v *Vaultik) verifyBlobFinalIntegrity(decompressor io.Reader, blobHasher ha
// Verify blob hash matches the encrypted data we downloaded // Verify blob hash matches the encrypted data we downloaded
calculatedBlobHash := hex.EncodeToString(blobHasher.Sum(nil)) calculatedBlobHash := hex.EncodeToString(blobHasher.Sum(nil))
if calculatedBlobHash != expectedHash { if calculatedBlobHash != blobInfo.Hash {
return fmt.Errorf("blob hash mismatch: calculated %s, expected %s", return fmt.Errorf("blob hash mismatch: calculated %s, expected %s",
calculatedBlobHash, expectedHash) calculatedBlobHash, blobInfo.Hash)
} }
log.Info("Blob verified",
"hash", blobInfo.Hash[:16]+"...",
"chunks", chunkCount,
"size", humanize.Bytes(uint64(blobInfo.CompressedSize)),
)
return nil return nil
} }