diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..0b09869 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +.git +.gitea +*.md +LICENSE +vaultik +coverage.out +coverage.html +.DS_Store diff --git a/.gitea/workflows/check.yml b/.gitea/workflows/check.yml new file mode 100644 index 0000000..fb6ef70 --- /dev/null +++ b/.gitea/workflows/check.yml @@ -0,0 +1,14 @@ +name: check +on: + push: + branches: [main] + pull_request: + branches: [main] +jobs: + check: + runs-on: ubuntu-latest + steps: + # actions/checkout v4, 2024-09-16 + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 + - name: Build and check + run: docker build . diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 4cdb844..a28f75f 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -54,7 +54,7 @@ The database tracks five primary entities and their relationships: #### File (`database.File`) Represents a file or directory in the backup system. Stores metadata needed for restoration: -- Path, timestamps (mtime, ctime) +- Path, mtime - Size, mode, ownership (uid, gid) - Symlink target (if applicable) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..9fa0aaa --- /dev/null +++ b/Dockerfile @@ -0,0 +1,61 @@ +# Lint stage +# golangci/golangci-lint:v2.11.3-alpine, 2026-03-17 +FROM golangci/golangci-lint:v2.11.3-alpine@sha256:b1c3de5862ad0a95b4e45a993b0f00415835d687e4f12c845c7493b86c13414e AS lint + +RUN apk add --no-cache make build-base + +WORKDIR /src + +# Copy go mod files first for better layer caching +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source code +COPY . . + +# Run formatting check and linter +RUN make fmt-check +RUN make lint + +# Build stage +# golang:1.26.1-alpine, 2026-03-17 +FROM golang:1.26.1-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder + +# Depend on lint stage passing +COPY --from=lint /src/go.sum /dev/null + +ARG VERSION=dev + +# Install build dependencies for CGO (mattn/go-sqlite3) and sqlite3 CLI (tests) +RUN apk add --no-cache make build-base sqlite + +WORKDIR /src + +# Copy go mod files first for better layer caching +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source code +COPY . . + +# Run tests +RUN make test + +# Build with CGO enabled (required for mattn/go-sqlite3) +RUN CGO_ENABLED=1 go build -ldflags "-X 'git.eeqj.de/sneak/vaultik/internal/globals.Version=${VERSION}' -X 'git.eeqj.de/sneak/vaultik/internal/globals.Commit=$(git rev-parse HEAD 2>/dev/null || echo unknown)'" -o /vaultik ./cmd/vaultik + +# Runtime stage +# alpine:3.21, 2026-02-25 +FROM alpine:3.21@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709 + +RUN apk add --no-cache ca-certificates sqlite + +# Copy binary from builder +COPY --from=builder /vaultik /usr/local/bin/vaultik + +# Create non-root user +RUN adduser -D -H -s /sbin/nologin vaultik + +USER vaultik + +ENTRYPOINT ["/usr/local/bin/vaultik"] diff --git a/Makefile b/Makefile index b84c13c..24490cf 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all check test lint fmt fmt-check build clean deps test-coverage test-integration local install release release-snapshot +.PHONY: all check test lint fmt fmt-check build clean deps test-coverage test-integration local install release release-snapshot docker hooks # Version number VERSION := 1.0.0-rc.1 @@ -18,20 +18,11 @@ check: lint fmt-check test # Run tests only. test: - @echo "Running tests..." - @if ! go test -v -timeout 10s ./... 2>&1; then \ - echo ""; \ - echo "TEST FAILURES DETECTED"; \ - echo "Run 'go test -v ./internal/database' to see database test details"; \ - exit 1; \ - fi + go test -race -timeout 30s ./... -# Check if code is formatted. +# Check if code is formatted (read-only). fmt-check: - @if [ -n "$$(go fmt ./...)" ]; then \ - echo "Error: Code is not formatted. Run 'make fmt' to fix."; \ - exit 1; \ - fi + @test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1) # Format code. fmt: @@ -39,7 +30,7 @@ fmt: # Run linter only. lint: - golangci-lint run + golangci-lint run ./... # Build binary. vaultik: internal/*/*.go cmd/vaultik/*.go @@ -78,3 +69,14 @@ release: # Dry-run a release build without publishing or tagging. release-snapshot: goreleaser release --clean --snapshot + +# Build Docker image. +docker: + docker build -t vaultik . + +# Install pre-commit hook. +hooks: + @printf '#!/bin/sh\nset -e\n' > .git/hooks/pre-commit + @printf 'go mod tidy\ngo fmt ./...\ngit diff --exit-code -- go.mod go.sum || { echo "go mod tidy changed files; please stage and retry"; exit 1; }\n' >> .git/hooks/pre-commit + @printf 'make check\n' >> .git/hooks/pre-commit + @chmod +x .git/hooks/pre-commit diff --git a/docs/DATAMODEL.md b/docs/DATAMODEL.md index 37f9480..3ff1c3f 100644 --- a/docs/DATAMODEL.md +++ b/docs/DATAMODEL.md @@ -23,7 +23,6 @@ Stores metadata about files in the filesystem being backed up. - `id` (TEXT PRIMARY KEY) - UUID for the file record - `path` (TEXT NOT NULL UNIQUE) - Absolute file path - `mtime` (INTEGER NOT NULL) - Modification time as Unix timestamp -- `ctime` (INTEGER NOT NULL) - Change time as Unix timestamp - `size` (INTEGER NOT NULL) - File size in bytes - `mode` (INTEGER NOT NULL) - Unix file permissions and type - `uid` (INTEGER NOT NULL) - User ID of file owner diff --git a/go.mod b/go.mod index e6eb5ae..f558044 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module git.eeqj.de/sneak/vaultik -go 1.24.4 +go 1.26.1 require ( filippo.io/age v1.2.1 @@ -23,6 +23,7 @@ require ( github.com/spf13/cobra v1.10.1 github.com/stretchr/testify v1.11.1 go.uber.org/fx v1.24.0 + golang.org/x/sync v0.18.0 golang.org/x/term v0.37.0 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.38.0 @@ -265,7 +266,6 @@ require ( golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.33.0 // indirect - golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.14.0 // indirect diff --git a/internal/blob/packer.go b/internal/blob/packer.go index c5284ec..7edf15b 100644 --- a/internal/blob/packer.go +++ b/internal/blob/packer.go @@ -361,101 +361,23 @@ func (p *Packer) finalizeCurrentBlob() error { return nil } - // Close blobgen writer to flush all data - if err := p.currentBlob.writer.Close(); err != nil { - p.cleanupTempFile() - return fmt.Errorf("closing blobgen writer: %w", err) - } - - // Sync file to ensure all data is written - if err := p.currentBlob.tempFile.Sync(); err != nil { - p.cleanupTempFile() - return fmt.Errorf("syncing temp file: %w", err) - } - - // Get the final size (encrypted if applicable) - finalSize, err := p.currentBlob.tempFile.Seek(0, io.SeekCurrent) + blobHash, finalSize, err := p.closeBlobWriter() if err != nil { - p.cleanupTempFile() - return fmt.Errorf("getting file size: %w", err) + return err } - // Reset to beginning for reading - if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { - p.cleanupTempFile() - return fmt.Errorf("seeking to start: %w", err) - } + chunkRefs := p.buildChunkRefs() - // Get hash from blobgen writer (of final encrypted data) - finalHash := p.currentBlob.writer.Sum256() - blobHash := hex.EncodeToString(finalHash) - - // Create chunk references with offsets - chunkRefs := make([]*BlobChunkRef, 0, len(p.currentBlob.chunks)) - - for _, chunk := range p.currentBlob.chunks { - chunkRefs = append(chunkRefs, &BlobChunkRef{ - ChunkHash: chunk.Hash, - Offset: chunk.Offset, - Length: chunk.Size, - }) - } - - // Get pending chunks (will be inserted to DB and reported to handler) chunksToInsert := p.pendingChunks - p.pendingChunks = nil // Clear pending list + p.pendingChunks = nil - // Insert pending chunks, blob_chunks, and update blob in a single transaction - if p.repos != nil { - blobIDTyped, parseErr := types.ParseBlobID(p.currentBlob.id) - if parseErr != nil { - p.cleanupTempFile() - return fmt.Errorf("parsing blob ID: %w", parseErr) - } - err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { - // First insert all pending chunks (required for blob_chunks FK) - for _, chunk := range chunksToInsert { - dbChunk := &database.Chunk{ - ChunkHash: types.ChunkHash(chunk.Hash), - Size: chunk.Size, - } - if err := p.repos.Chunks.Create(ctx, tx, dbChunk); err != nil { - return fmt.Errorf("creating chunk: %w", err) - } - } - - // Insert all blob_chunk records in batch - for _, chunk := range p.currentBlob.chunks { - blobChunk := &database.BlobChunk{ - BlobID: blobIDTyped, - ChunkHash: types.ChunkHash(chunk.Hash), - Offset: chunk.Offset, - Length: chunk.Size, - } - if err := p.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil { - return fmt.Errorf("creating blob_chunk: %w", err) - } - } - - // Update blob record with final hash and sizes - return p.repos.Blobs.UpdateFinished(ctx, tx, p.currentBlob.id, blobHash, - p.currentBlob.size, finalSize) - }) - if err != nil { - p.cleanupTempFile() - return fmt.Errorf("finalizing blob transaction: %w", err) - } - - log.Debug("Committed blob transaction", - "chunks_inserted", len(chunksToInsert), - "blob_chunks_inserted", len(p.currentBlob.chunks)) + if err := p.commitBlobToDatabase(blobHash, finalSize, chunksToInsert); err != nil { + return err } - // Create finished blob finished := &FinishedBlob{ ID: p.currentBlob.id, Hash: blobHash, - Data: nil, // We don't load data into memory anymore Chunks: chunkRefs, CreatedTS: p.currentBlob.startTime, Uncompressed: p.currentBlob.size, @@ -464,28 +386,105 @@ func (p *Packer) finalizeCurrentBlob() error { compressionRatio := float64(finished.Compressed) / float64(finished.Uncompressed) log.Info("Finalized blob (compressed and encrypted)", - "hash", blobHash, - "chunks", len(chunkRefs), - "uncompressed", finished.Uncompressed, - "compressed", finished.Compressed, + "hash", blobHash, "chunks", len(chunkRefs), + "uncompressed", finished.Uncompressed, "compressed", finished.Compressed, "ratio", fmt.Sprintf("%.2f", compressionRatio), "duration", time.Since(p.currentBlob.startTime)) - // Collect inserted chunk hashes for the scanner to track var insertedChunkHashes []string for _, chunk := range chunksToInsert { insertedChunkHashes = append(insertedChunkHashes, chunk.Hash) } - // Call blob handler if set + return p.deliverFinishedBlob(finished, insertedChunkHashes) +} + +// closeBlobWriter closes the writer, syncs to disk, and returns the blob hash and final size +func (p *Packer) closeBlobWriter() (string, int64, error) { + if err := p.currentBlob.writer.Close(); err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("closing blobgen writer: %w", err) + } + if err := p.currentBlob.tempFile.Sync(); err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("syncing temp file: %w", err) + } + + finalSize, err := p.currentBlob.tempFile.Seek(0, io.SeekCurrent) + if err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("getting file size: %w", err) + } + if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { + p.cleanupTempFile() + return "", 0, fmt.Errorf("seeking to start: %w", err) + } + + finalHash := p.currentBlob.writer.Sum256() + return hex.EncodeToString(finalHash), finalSize, nil +} + +// buildChunkRefs creates BlobChunkRef entries from the current blob's chunks +func (p *Packer) buildChunkRefs() []*BlobChunkRef { + refs := make([]*BlobChunkRef, 0, len(p.currentBlob.chunks)) + for _, chunk := range p.currentBlob.chunks { + refs = append(refs, &BlobChunkRef{ + ChunkHash: chunk.Hash, Offset: chunk.Offset, Length: chunk.Size, + }) + } + return refs +} + +// commitBlobToDatabase inserts pending chunks, blob_chunks, and updates the blob record +func (p *Packer) commitBlobToDatabase(blobHash string, finalSize int64, chunksToInsert []PendingChunk) error { + if p.repos == nil { + return nil + } + + blobIDTyped, parseErr := types.ParseBlobID(p.currentBlob.id) + if parseErr != nil { + p.cleanupTempFile() + return fmt.Errorf("parsing blob ID: %w", parseErr) + } + + err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error { + for _, chunk := range chunksToInsert { + dbChunk := &database.Chunk{ChunkHash: types.ChunkHash(chunk.Hash), Size: chunk.Size} + if err := p.repos.Chunks.Create(ctx, tx, dbChunk); err != nil { + return fmt.Errorf("creating chunk: %w", err) + } + } + + for _, chunk := range p.currentBlob.chunks { + blobChunk := &database.BlobChunk{ + BlobID: blobIDTyped, ChunkHash: types.ChunkHash(chunk.Hash), + Offset: chunk.Offset, Length: chunk.Size, + } + if err := p.repos.BlobChunks.Create(ctx, tx, blobChunk); err != nil { + return fmt.Errorf("creating blob_chunk: %w", err) + } + } + + return p.repos.Blobs.UpdateFinished(ctx, tx, p.currentBlob.id, blobHash, p.currentBlob.size, finalSize) + }) + if err != nil { + p.cleanupTempFile() + return fmt.Errorf("finalizing blob transaction: %w", err) + } + + log.Debug("Committed blob transaction", + "chunks_inserted", len(chunksToInsert), "blob_chunks_inserted", len(p.currentBlob.chunks)) + return nil +} + +// deliverFinishedBlob passes the blob to the handler or stores it internally +func (p *Packer) deliverFinishedBlob(finished *FinishedBlob, insertedChunkHashes []string) error { if p.blobHandler != nil { - // Reset file position for handler if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { p.cleanupTempFile() return fmt.Errorf("seeking for handler: %w", err) } - // Create a blob reader that includes the data stream blobWithReader := &BlobWithReader{ FinishedBlob: finished, Reader: p.currentBlob.tempFile, @@ -497,30 +496,26 @@ func (p *Packer) finalizeCurrentBlob() error { p.cleanupTempFile() return fmt.Errorf("blob handler failed: %w", err) } - // Note: blob handler is responsible for closing/cleaning up temp file - p.currentBlob = nil - } else { - log.Debug("No blob handler callback configured", "blob_hash", blobHash[:8]+"...") - // No handler, need to read data for legacy behavior - if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { - p.cleanupTempFile() - return fmt.Errorf("seeking to read data: %w", err) - } - - data, err := io.ReadAll(p.currentBlob.tempFile) - if err != nil { - p.cleanupTempFile() - return fmt.Errorf("reading blob data: %w", err) - } - finished.Data = data - - p.finishedBlobs = append(p.finishedBlobs, finished) - - // Cleanup - p.cleanupTempFile() p.currentBlob = nil + return nil } + // No handler - read data for legacy behavior + log.Debug("No blob handler callback configured", "blob_hash", finished.Hash[:8]+"...") + if _, err := p.currentBlob.tempFile.Seek(0, io.SeekStart); err != nil { + p.cleanupTempFile() + return fmt.Errorf("seeking to read data: %w", err) + } + + data, err := io.ReadAll(p.currentBlob.tempFile) + if err != nil { + p.cleanupTempFile() + return fmt.Errorf("reading blob data: %w", err) + } + finished.Data = data + p.finishedBlobs = append(p.finishedBlobs, finished) + p.cleanupTempFile() + p.currentBlob = nil return nil } diff --git a/internal/blobgen/compress.go b/internal/blobgen/compress.go index 1292fae..e8a8799 100644 --- a/internal/blobgen/compress.go +++ b/internal/blobgen/compress.go @@ -51,7 +51,13 @@ func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipien if err != nil { return 0, "", fmt.Errorf("creating writer: %w", err) } - defer func() { _ = w.Close() }() + + closed := false + defer func() { + if !closed { + _ = w.Close() + } + }() // Copy data if _, err := io.Copy(w, src); err != nil { @@ -62,6 +68,7 @@ func CompressStream(dst io.Writer, src io.Reader, compressionLevel int, recipien if err := w.Close(); err != nil { return 0, "", fmt.Errorf("closing writer: %w", err) } + closed = true return w.BytesWritten(), hex.EncodeToString(w.Sum256()), nil } diff --git a/internal/blobgen/compress_test.go b/internal/blobgen/compress_test.go new file mode 100644 index 0000000..6d1240c --- /dev/null +++ b/internal/blobgen/compress_test.go @@ -0,0 +1,64 @@ +package blobgen + +import ( + "bytes" + "crypto/rand" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testRecipient is a static age recipient for tests. +const testRecipient = "age1cplgrwj77ta54dnmydvvmzn64ltk83ankxl5sww04mrtmu62kv3s89gmvv" + +// TestCompressStreamNoDoubleClose is a regression test for issue #28. +// It verifies that CompressStream does not panic or return an error due to +// double-closing the underlying blobgen.Writer. Before the fix in PR #33, +// the explicit Close() on the happy path combined with defer Close() would +// cause a double close. +func TestCompressStreamNoDoubleClose(t *testing.T) { + input := []byte("regression test data for issue #28 double-close fix") + var buf bytes.Buffer + + written, hash, err := CompressStream(&buf, bytes.NewReader(input), 3, []string{testRecipient}) + require.NoError(t, err, "CompressStream should not return an error") + assert.True(t, written > 0, "expected bytes written > 0") + assert.NotEmpty(t, hash, "expected non-empty hash") + assert.True(t, buf.Len() > 0, "expected non-empty output") +} + +// TestCompressStreamLargeInput exercises CompressStream with a larger payload +// to ensure no double-close issues surface under heavier I/O. +func TestCompressStreamLargeInput(t *testing.T) { + data := make([]byte, 512*1024) // 512 KB + _, err := rand.Read(data) + require.NoError(t, err) + + var buf bytes.Buffer + written, hash, err := CompressStream(&buf, bytes.NewReader(data), 3, []string{testRecipient}) + require.NoError(t, err) + assert.True(t, written > 0) + assert.NotEmpty(t, hash) +} + +// TestCompressStreamEmptyInput verifies CompressStream handles empty input +// without double-close issues. +func TestCompressStreamEmptyInput(t *testing.T) { + var buf bytes.Buffer + _, hash, err := CompressStream(&buf, strings.NewReader(""), 3, []string{testRecipient}) + require.NoError(t, err) + assert.NotEmpty(t, hash) +} + +// TestCompressDataNoDoubleClose mirrors the stream test for CompressData, +// ensuring the explicit Close + error-path Close pattern is also safe. +func TestCompressDataNoDoubleClose(t *testing.T) { + input := []byte("CompressData regression test for double-close") + result, err := CompressData(input, 3, []string{testRecipient}) + require.NoError(t, err) + assert.True(t, result.CompressedSize > 0) + assert.True(t, result.UncompressedSize == int64(len(input))) + assert.NotEmpty(t, result.SHA256) +} diff --git a/internal/cli/restore.go b/internal/cli/restore.go index 62a1363..9e0b3f1 100644 --- a/internal/cli/restore.go +++ b/internal/cli/restore.go @@ -58,77 +58,7 @@ Examples: vaultik restore --verify myhost_docs_2025-01-01T12:00:00Z /restore`, Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - snapshotID := args[0] - opts.TargetDir = args[1] - if len(args) > 2 { - opts.Paths = args[2:] - } - - // Use unified config resolution - configPath, err := ResolveConfigPath() - if err != nil { - return err - } - - // Use the app framework like other commands - rootFlags := GetRootFlags() - return RunWithApp(cmd.Context(), AppOptions{ - ConfigPath: configPath, - LogOptions: log.LogOptions{ - Verbose: rootFlags.Verbose, - Debug: rootFlags.Debug, - Quiet: rootFlags.Quiet, - }, - Modules: []fx.Option{ - fx.Provide(fx.Annotate( - func(g *globals.Globals, cfg *config.Config, - storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp { - return &RestoreApp{ - Globals: g, - Config: cfg, - Storage: storer, - Vaultik: v, - Shutdowner: shutdowner, - } - }, - )), - }, - Invokes: []fx.Option{ - fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - // Start the restore operation in a goroutine - go func() { - // Run the restore operation - restoreOpts := &vaultik.RestoreOptions{ - SnapshotID: snapshotID, - TargetDir: opts.TargetDir, - Paths: opts.Paths, - Verify: opts.Verify, - } - if err := app.Vaultik.Restore(restoreOpts); err != nil { - if err != context.Canceled { - log.Error("Restore operation failed", "error", err) - os.Exit(1) - } - } - - // Shutdown the app when restore completes - if err := app.Shutdowner.Shutdown(); err != nil { - log.Error("Failed to shutdown", "error", err) - } - }() - return nil - }, - OnStop: func(ctx context.Context) error { - log.Debug("Stopping restore operation") - app.Vaultik.Cancel() - return nil - }, - }) - }), - }, - }) + return runRestore(cmd, args, opts) }, } @@ -136,3 +66,88 @@ Examples: return cmd } + +// runRestore parses arguments and runs the restore operation through the app framework +func runRestore(cmd *cobra.Command, args []string, opts *RestoreOptions) error { + snapshotID := args[0] + opts.TargetDir = args[1] + if len(args) > 2 { + opts.Paths = args[2:] + } + + // Use unified config resolution + configPath, err := ResolveConfigPath() + if err != nil { + return err + } + + // Use the app framework like other commands + rootFlags := GetRootFlags() + return RunWithApp(cmd.Context(), AppOptions{ + ConfigPath: configPath, + LogOptions: log.LogOptions{ + Verbose: rootFlags.Verbose, + Debug: rootFlags.Debug, + Quiet: rootFlags.Quiet, + }, + Modules: buildRestoreModules(), + Invokes: buildRestoreInvokes(snapshotID, opts), + }) +} + +// buildRestoreModules returns the fx.Options for dependency injection in restore +func buildRestoreModules() []fx.Option { + return []fx.Option{ + fx.Provide(fx.Annotate( + func(g *globals.Globals, cfg *config.Config, + storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp { + return &RestoreApp{ + Globals: g, + Config: cfg, + Storage: storer, + Vaultik: v, + Shutdowner: shutdowner, + } + }, + )), + } +} + +// buildRestoreInvokes returns the fx.Options that wire up the restore lifecycle +func buildRestoreInvokes(snapshotID string, opts *RestoreOptions) []fx.Option { + return []fx.Option{ + fx.Invoke(func(app *RestoreApp, lc fx.Lifecycle) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + // Start the restore operation in a goroutine + go func() { + // Run the restore operation + restoreOpts := &vaultik.RestoreOptions{ + SnapshotID: snapshotID, + TargetDir: opts.TargetDir, + Paths: opts.Paths, + Verify: opts.Verify, + } + if err := app.Vaultik.Restore(restoreOpts); err != nil { + if err != context.Canceled { + log.Error("Restore operation failed", "error", err) + os.Exit(1) + } + } + + // Shutdown the app when restore completes + if err := app.Shutdowner.Shutdown(); err != nil { + log.Error("Failed to shutdown", "error", err) + } + }() + return nil + }, + OnStop: func(ctx context.Context) error { + log.Debug("Stopping restore operation") + app.Vaultik.Cancel() + return nil + }, + }) + }), + } +} diff --git a/internal/cli/snapshot.go b/internal/cli/snapshot.go index 6a801b8..ec7e30b 100644 --- a/internal/cli/snapshot.go +++ b/internal/cli/snapshot.go @@ -167,10 +167,7 @@ func newSnapshotListCommand() *cobra.Command { // newSnapshotPurgeCommand creates the 'snapshot purge' subcommand func newSnapshotPurgeCommand() *cobra.Command { - var keepLatest bool - var olderThan string - var force bool - var names []string + opts := &vaultik.SnapshotPurgeOptions{} cmd := &cobra.Command{ Use: "purge", @@ -183,10 +180,10 @@ restrict the operation to specific snapshot names.`, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { // Validate flags - if !keepLatest && olderThan == "" { + if !opts.KeepLatest && opts.OlderThan == "" { return fmt.Errorf("must specify either --keep-latest or --older-than") } - if keepLatest && olderThan != "" { + if opts.KeepLatest && opts.OlderThan != "" { return fmt.Errorf("cannot specify both --keep-latest and --older-than") } @@ -210,13 +207,7 @@ restrict the operation to specific snapshot names.`, lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { go func() { - purgeOpts := &vaultik.PurgeOptions{ - KeepLatest: keepLatest, - OlderThan: olderThan, - Force: force, - Names: names, - } - if err := v.PurgeSnapshots(purgeOpts); err != nil { + if err := v.PurgeSnapshotsWithOptions(opts); err != nil { if err != context.Canceled { log.Error("Failed to purge snapshots", "error", err) os.Exit(1) @@ -239,10 +230,10 @@ restrict the operation to specific snapshot names.`, }, } - cmd.Flags().BoolVar(&keepLatest, "keep-latest", false, "Keep only the latest snapshot of each name") - cmd.Flags().StringVar(&olderThan, "older-than", "", "Remove snapshots older than duration (e.g., 30d, 6m, 1y)") - cmd.Flags().BoolVar(&force, "force", false, "Skip confirmation prompt") - cmd.Flags().StringArrayVar(&names, "snapshot", nil, "Restrict to snapshots with these names (repeat for multiple)") + cmd.Flags().BoolVar(&opts.KeepLatest, "keep-latest", false, "Keep only the latest snapshot of each name") + cmd.Flags().StringVar(&opts.OlderThan, "older-than", "", "Remove snapshots older than duration (e.g., 30d, 6m, 1y)") + cmd.Flags().BoolVar(&opts.Force, "force", false, "Skip confirmation prompt") + cmd.Flags().StringArrayVar(&opts.Names, "snapshot", nil, "Restrict to snapshots with these names (repeat for multiple)") return cmd } diff --git a/internal/database/cascade_debug_test.go b/internal/database/cascade_debug_test.go index b7a29f6..ed24746 100644 --- a/internal/database/cascade_debug_test.go +++ b/internal/database/cascade_debug_test.go @@ -29,7 +29,6 @@ func TestCascadeDeleteDebug(t *testing.T) { file := &File{ Path: "/cascade-test.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, diff --git a/internal/database/chunk_files_test.go b/internal/database/chunk_files_test.go index d99a06d..9c772c4 100644 --- a/internal/database/chunk_files_test.go +++ b/internal/database/chunk_files_test.go @@ -22,7 +22,6 @@ func TestChunkFileRepository(t *testing.T) { file1 := &File{ Path: "/file1.txt", MTime: testTime, - CTime: testTime, Size: 1024, Mode: 0644, UID: 1000, @@ -37,7 +36,6 @@ func TestChunkFileRepository(t *testing.T) { file2 := &File{ Path: "/file2.txt", MTime: testTime, - CTime: testTime, Size: 1024, Mode: 0644, UID: 1000, @@ -138,9 +136,9 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) { // Create test files testTime := time.Now().Truncate(time.Second) - file1 := &File{Path: "/file1.txt", MTime: testTime, CTime: testTime, Size: 3072, Mode: 0644, UID: 1000, GID: 1000} - file2 := &File{Path: "/file2.txt", MTime: testTime, CTime: testTime, Size: 3072, Mode: 0644, UID: 1000, GID: 1000} - file3 := &File{Path: "/file3.txt", MTime: testTime, CTime: testTime, Size: 2048, Mode: 0644, UID: 1000, GID: 1000} + file1 := &File{Path: "/file1.txt", MTime: testTime, Size: 3072, Mode: 0644, UID: 1000, GID: 1000} + file2 := &File{Path: "/file2.txt", MTime: testTime, Size: 3072, Mode: 0644, UID: 1000, GID: 1000} + file3 := &File{Path: "/file3.txt", MTime: testTime, Size: 2048, Mode: 0644, UID: 1000, GID: 1000} if err := fileRepo.Create(ctx, nil, file1); err != nil { t.Fatalf("failed to create file1: %v", err) diff --git a/internal/database/database.go b/internal/database/database.go index 06d611d..0cc7c4e 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -6,24 +6,32 @@ // multiple source files. Blobs are content-addressed, meaning their filename // is derived from their SHA256 hash after compression and encryption. // -// The database does not support migrations. If the schema changes, delete -// the local database and perform a full backup to recreate it. +// Schema is managed via numbered SQL migrations embedded in the schema/ +// directory. Migration 000.sql bootstraps the schema_migrations tracking +// table; subsequent migrations (001, 002, …) are applied in order. package database import ( "context" "database/sql" - _ "embed" + "embed" "fmt" "os" + "path/filepath" + "sort" + "strconv" "strings" "git.eeqj.de/sneak/vaultik/internal/log" _ "modernc.org/sqlite" ) -//go:embed schema.sql -var schemaSQL string +//go:embed schema/*.sql +var schemaFS embed.FS + +// bootstrapVersion is the migration that creates the schema_migrations +// table itself. It is applied before the normal migration loop. +const bootstrapVersion = 0 // DB represents the Vaultik local index database connection. // It uses SQLite to track file metadata, content-defined chunks, and blob associations. @@ -35,6 +43,46 @@ type DB struct { path string } +// ParseMigrationVersion extracts the numeric version prefix from a migration +// filename. Filenames must follow the pattern ".sql" or +// "_.sql", where version is a zero-padded numeric +// string (e.g. "001", "002"). Returns the version as an integer and an +// error if the filename does not match the expected pattern. +func ParseMigrationVersion(filename string) (int, error) { + name := strings.TrimSuffix(filename, filepath.Ext(filename)) + if name == "" { + return 0, fmt.Errorf("invalid migration filename %q: empty name", filename) + } + + // Split on underscore to separate version from description. + // If there's no underscore, the entire stem is the version. + versionStr := name + if idx := strings.IndexByte(name, '_'); idx >= 0 { + versionStr = name[:idx] + } + + if versionStr == "" { + return 0, fmt.Errorf("invalid migration filename %q: empty version prefix", filename) + } + + // Validate the version is purely numeric. + for _, ch := range versionStr { + if ch < '0' || ch > '9' { + return 0, fmt.Errorf( + "invalid migration filename %q: version %q contains non-numeric character %q", + filename, versionStr, string(ch), + ) + } + } + + version, err := strconv.Atoi(versionStr) + if err != nil { + return 0, fmt.Errorf("invalid migration filename %q: %w", filename, err) + } + + return version, nil +} + // New creates a new database connection at the specified path. // It creates the schema if needed and configures SQLite with WAL mode for // better concurrency. SQLite handles crash recovery automatically when @@ -72,9 +120,9 @@ func New(ctx context.Context, path string) (*DB, error) { } db := &DB{conn: conn, path: path} - if err := db.createSchema(ctx); err != nil { + if err := applyMigrations(ctx, conn); err != nil { _ = conn.Close() - return nil, fmt.Errorf("creating schema: %w", err) + return nil, fmt.Errorf("applying migrations: %w", err) } return db, nil } @@ -125,9 +173,9 @@ func New(ctx context.Context, path string) (*DB, error) { } db := &DB{conn: conn, path: path} - if err := db.createSchema(ctx); err != nil { + if err := applyMigrations(ctx, conn); err != nil { _ = conn.Close() - return nil, fmt.Errorf("creating schema: %w", err) + return nil, fmt.Errorf("applying migrations: %w", err) } log.Debug("Database connection established successfully", "path", path) @@ -198,9 +246,120 @@ func (db *DB) QueryRowWithLog( return db.conn.QueryRowContext(ctx, query, args...) } -func (db *DB) createSchema(ctx context.Context) error { - _, err := db.conn.ExecContext(ctx, schemaSQL) - return err +// collectMigrations reads the embedded schema directory and returns +// migration filenames sorted lexicographically. +func collectMigrations() ([]string, error) { + entries, err := schemaFS.ReadDir("schema") + if err != nil { + return nil, fmt.Errorf("failed to read schema directory: %w", err) + } + + var migrations []string + + for _, entry := range entries { + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { + migrations = append(migrations, entry.Name()) + } + } + + sort.Strings(migrations) + + return migrations, nil +} + +// bootstrapMigrationsTable ensures the schema_migrations table exists +// by applying 000.sql if the table is missing. +func bootstrapMigrationsTable(ctx context.Context, db *sql.DB) error { + var tableExists int + + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableExists) + if err != nil { + return fmt.Errorf("failed to check for migrations table: %w", err) + } + + if tableExists > 0 { + return nil + } + + content, err := schemaFS.ReadFile("schema/000.sql") + if err != nil { + return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err) + } + + log.Info("applying bootstrap migration", "version", bootstrapVersion) + + _, err = db.ExecContext(ctx, string(content)) + if err != nil { + return fmt.Errorf("failed to apply bootstrap migration: %w", err) + } + + return nil +} + +// applyMigrations applies all pending migrations to db. It first bootstraps +// the schema_migrations table via 000.sql, then iterates through remaining +// migration files in order. +func applyMigrations(ctx context.Context, db *sql.DB) error { + if err := bootstrapMigrationsTable(ctx, db); err != nil { + return err + } + + migrations, err := collectMigrations() + if err != nil { + return err + } + + for _, migration := range migrations { + version, parseErr := ParseMigrationVersion(migration) + if parseErr != nil { + return parseErr + } + + // Check if already applied. + var count int + + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM schema_migrations WHERE version = ?", + version, + ).Scan(&count) + if err != nil { + return fmt.Errorf("failed to check migration status: %w", err) + } + + if count > 0 { + log.Debug("migration already applied", "version", version) + + continue + } + + // Read and apply migration. + content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration)) + if readErr != nil { + return fmt.Errorf("failed to read migration %s: %w", migration, readErr) + } + + log.Info("applying migration", "version", version) + + _, execErr := db.ExecContext(ctx, string(content)) + if execErr != nil { + return fmt.Errorf("failed to apply migration %s: %w", migration, execErr) + } + + // Record migration as applied. + _, recErr := db.ExecContext(ctx, + "INSERT INTO schema_migrations (version) VALUES (?)", + version, + ) + if recErr != nil { + return fmt.Errorf("failed to record migration %s: %w", migration, recErr) + } + + log.Info("migration applied successfully", "version", version) + } + + return nil } // NewTestDB creates an in-memory SQLite database for testing purposes. diff --git a/internal/database/database_test.go b/internal/database/database_test.go index 65457d1..6d763a3 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test.go @@ -2,6 +2,7 @@ package database import ( "context" + "database/sql" "fmt" "path/filepath" "testing" @@ -26,9 +27,10 @@ func TestDatabase(t *testing.T) { t.Fatal("database connection is nil") } - // Test schema creation (already done in New) + // Test schema creation (already done in New via migrations) // Verify tables exist tables := []string{ + "schema_migrations", "files", "file_chunks", "chunks", "blobs", "blob_chunks", "chunk_files", "snapshots", } @@ -99,3 +101,139 @@ func TestDatabaseConcurrentAccess(t *testing.T) { t.Errorf("expected 10 chunks, got %d", count) } } + +func TestParseMigrationVersion(t *testing.T) { + tests := []struct { + name string + filename string + wantVer int + wantError bool + }{ + {name: "valid 000.sql", filename: "000.sql", wantVer: 0, wantError: false}, + {name: "valid 001.sql", filename: "001.sql", wantVer: 1, wantError: false}, + {name: "valid 099.sql", filename: "099.sql", wantVer: 99, wantError: false}, + {name: "valid with description", filename: "001_initial_schema.sql", wantVer: 1, wantError: false}, + {name: "valid large version", filename: "123_big_migration.sql", wantVer: 123, wantError: false}, + {name: "invalid alpha version", filename: "abc.sql", wantVer: 0, wantError: true}, + {name: "invalid mixed chars", filename: "12a.sql", wantVer: 0, wantError: true}, + {name: "invalid no extension", filename: "schema.sql", wantVer: 0, wantError: true}, + {name: "empty string", filename: "", wantVer: 0, wantError: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseMigrationVersion(tc.filename) + if tc.wantError { + if err == nil { + t.Errorf("ParseMigrationVersion(%q) = %d, nil; want error", tc.filename, got) + } + return + } + if err != nil { + t.Errorf("ParseMigrationVersion(%q) unexpected error: %v", tc.filename, err) + return + } + if got != tc.wantVer { + t.Errorf("ParseMigrationVersion(%q) = %d; want %d", tc.filename, got, tc.wantVer) + } + }) + } +} + +func TestApplyMigrations_Idempotent(t *testing.T) { + ctx := context.Background() + + conn, err := sql.Open("sqlite", ":memory:?_foreign_keys=ON") + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Errorf("failed to close database: %v", err) + } + }() + + conn.SetMaxOpenConns(1) + conn.SetMaxIdleConns(1) + + // First run: apply all migrations. + if err := applyMigrations(ctx, conn); err != nil { + t.Fatalf("first applyMigrations failed: %v", err) + } + + // Count rows in schema_migrations after first run. + var countBefore int + if err := conn.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations").Scan(&countBefore); err != nil { + t.Fatalf("failed to count schema_migrations after first run: %v", err) + } + + // Second run: must be a no-op. + if err := applyMigrations(ctx, conn); err != nil { + t.Fatalf("second applyMigrations failed: %v", err) + } + + // Count rows in schema_migrations after second run — must be unchanged. + var countAfter int + if err := conn.QueryRowContext(ctx, "SELECT COUNT(*) FROM schema_migrations").Scan(&countAfter); err != nil { + t.Fatalf("failed to count schema_migrations after second run: %v", err) + } + + if countBefore != countAfter { + t.Errorf("schema_migrations row count changed: before=%d, after=%d", countBefore, countAfter) + } +} + +func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) { + ctx := context.Background() + + conn, err := sql.Open("sqlite", ":memory:?_foreign_keys=ON") + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Errorf("failed to close database: %v", err) + } + }() + + conn.SetMaxOpenConns(1) + conn.SetMaxIdleConns(1) + + // Verify schema_migrations does NOT exist yet. + var tableBefore int + if err := conn.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableBefore); err != nil { + t.Fatalf("failed to check for table before bootstrap: %v", err) + } + if tableBefore != 0 { + t.Fatal("schema_migrations table should not exist before bootstrap") + } + + // Run bootstrap. + if err := bootstrapMigrationsTable(ctx, conn); err != nil { + t.Fatalf("bootstrapMigrationsTable failed: %v", err) + } + + // Verify schema_migrations now exists. + var tableAfter int + if err := conn.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'", + ).Scan(&tableAfter); err != nil { + t.Fatalf("failed to check for table after bootstrap: %v", err) + } + if tableAfter != 1 { + t.Fatalf("schema_migrations table should exist after bootstrap, got count=%d", tableAfter) + } + + // Verify version 0 row exists. + var version int + if err := conn.QueryRowContext(ctx, + "SELECT version FROM schema_migrations WHERE version = 0", + ).Scan(&version); err != nil { + t.Fatalf("version 0 row not found in schema_migrations: %v", err) + } + if version != 0 { + t.Errorf("expected version 0, got %d", version) + } +} diff --git a/internal/database/file_chunks_test.go b/internal/database/file_chunks_test.go index aad891b..c009e97 100644 --- a/internal/database/file_chunks_test.go +++ b/internal/database/file_chunks_test.go @@ -22,7 +22,6 @@ func TestFileChunkRepository(t *testing.T) { file := &File{ Path: "/test/file.txt", MTime: testTime, - CTime: testTime, Size: 3072, Mode: 0644, UID: 1000, @@ -135,7 +134,6 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) { file := &File{ Path: types.FilePath(path), MTime: testTime, - CTime: testTime, Size: 2048, Mode: 0644, UID: 1000, diff --git a/internal/database/files.go b/internal/database/files.go index da68e2d..21c0a0b 100644 --- a/internal/database/files.go +++ b/internal/database/files.go @@ -25,12 +25,11 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err } query := ` - INSERT INTO files (id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO files (id, path, source_path, mtime, size, mode, uid, gid, link_target) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(path) DO UPDATE SET source_path = excluded.source_path, mtime = excluded.mtime, - ctime = excluded.ctime, size = excluded.size, mode = excluded.mode, uid = excluded.uid, @@ -42,10 +41,10 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err var idStr string var err error if tx != nil { - LogSQL("Execute", query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()) - err = tx.QueryRowContext(ctx, query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()).Scan(&idStr) + LogSQL("Execute", query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()) + err = tx.QueryRowContext(ctx, query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()).Scan(&idStr) } else { - err = r.db.QueryRowWithLog(ctx, query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()).Scan(&idStr) + err = r.db.QueryRowWithLog(ctx, query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()).Scan(&idStr) } if err != nil { @@ -63,7 +62,7 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, error) { query := ` - SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target + SELECT id, path, source_path, mtime, size, mode, uid, gid, link_target FROM files WHERE path = ? ` @@ -82,7 +81,7 @@ func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, err // GetByID retrieves a file by its UUID func (r *FileRepository) GetByID(ctx context.Context, id types.FileID) (*File, error) { query := ` - SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target + SELECT id, path, source_path, mtime, size, mode, uid, gid, link_target FROM files WHERE id = ? ` @@ -100,7 +99,7 @@ func (r *FileRepository) GetByID(ctx context.Context, id types.FileID) (*File, e func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path string) (*File, error) { query := ` - SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target + SELECT id, path, source_path, mtime, size, mode, uid, gid, link_target FROM files WHERE path = ? ` @@ -123,7 +122,7 @@ func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path strin func (r *FileRepository) scanFile(row *sql.Row) (*File, error) { var file File var idStr, pathStr, sourcePathStr string - var mtimeUnix, ctimeUnix int64 + var mtimeUnix int64 var linkTarget sql.NullString err := row.Scan( @@ -131,7 +130,6 @@ func (r *FileRepository) scanFile(row *sql.Row) (*File, error) { &pathStr, &sourcePathStr, &mtimeUnix, - &ctimeUnix, &file.Size, &file.Mode, &file.UID, @@ -149,7 +147,6 @@ func (r *FileRepository) scanFile(row *sql.Row) (*File, error) { file.Path = types.FilePath(pathStr) file.SourcePath = types.SourcePath(sourcePathStr) file.MTime = time.Unix(mtimeUnix, 0).UTC() - file.CTime = time.Unix(ctimeUnix, 0).UTC() if linkTarget.Valid { file.LinkTarget = types.FilePath(linkTarget.String) } @@ -161,7 +158,7 @@ func (r *FileRepository) scanFile(row *sql.Row) (*File, error) { func (r *FileRepository) scanFileRows(rows *sql.Rows) (*File, error) { var file File var idStr, pathStr, sourcePathStr string - var mtimeUnix, ctimeUnix int64 + var mtimeUnix int64 var linkTarget sql.NullString err := rows.Scan( @@ -169,7 +166,6 @@ func (r *FileRepository) scanFileRows(rows *sql.Rows) (*File, error) { &pathStr, &sourcePathStr, &mtimeUnix, - &ctimeUnix, &file.Size, &file.Mode, &file.UID, @@ -187,7 +183,6 @@ func (r *FileRepository) scanFileRows(rows *sql.Rows) (*File, error) { file.Path = types.FilePath(pathStr) file.SourcePath = types.SourcePath(sourcePathStr) file.MTime = time.Unix(mtimeUnix, 0).UTC() - file.CTime = time.Unix(ctimeUnix, 0).UTC() if linkTarget.Valid { file.LinkTarget = types.FilePath(linkTarget.String) } @@ -197,7 +192,7 @@ func (r *FileRepository) scanFileRows(rows *sql.Rows) (*File, error) { func (r *FileRepository) ListModifiedSince(ctx context.Context, since time.Time) ([]*File, error) { query := ` - SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target + SELECT id, path, source_path, mtime, size, mode, uid, gid, link_target FROM files WHERE mtime >= ? ORDER BY path @@ -258,7 +253,7 @@ func (r *FileRepository) DeleteByID(ctx context.Context, tx *sql.Tx, id types.Fi func (r *FileRepository) ListByPrefix(ctx context.Context, prefix string) ([]*File, error) { query := ` - SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target + SELECT id, path, source_path, mtime, size, mode, uid, gid, link_target FROM files WHERE path LIKE ? || '%' ORDER BY path @@ -285,7 +280,7 @@ func (r *FileRepository) ListByPrefix(ctx context.Context, prefix string) ([]*Fi // ListAll returns all files in the database func (r *FileRepository) ListAll(ctx context.Context) ([]*File, error) { query := ` - SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target + SELECT id, path, source_path, mtime, size, mode, uid, gid, link_target FROM files ORDER BY path ` @@ -315,7 +310,7 @@ func (r *FileRepository) CreateBatch(ctx context.Context, tx *sql.Tx, files []*F return nil } - // Each File has 10 values, so batch at 100 to be safe with SQLite's variable limit + // Each File has 9 values, so batch at 100 to be safe with SQLite's variable limit const batchSize = 100 for i := 0; i < len(files); i += batchSize { @@ -325,19 +320,18 @@ func (r *FileRepository) CreateBatch(ctx context.Context, tx *sql.Tx, files []*F } batch := files[i:end] - query := `INSERT INTO files (id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target) VALUES ` - args := make([]interface{}, 0, len(batch)*10) + query := `INSERT INTO files (id, path, source_path, mtime, size, mode, uid, gid, link_target) VALUES ` + args := make([]interface{}, 0, len(batch)*9) for j, f := range batch { if j > 0 { query += ", " } - query += "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" - args = append(args, f.ID.String(), f.Path.String(), f.SourcePath.String(), f.MTime.Unix(), f.CTime.Unix(), f.Size, f.Mode, f.UID, f.GID, f.LinkTarget.String()) + query += "(?, ?, ?, ?, ?, ?, ?, ?, ?)" + args = append(args, f.ID.String(), f.Path.String(), f.SourcePath.String(), f.MTime.Unix(), f.Size, f.Mode, f.UID, f.GID, f.LinkTarget.String()) } query += ` ON CONFLICT(path) DO UPDATE SET source_path = excluded.source_path, mtime = excluded.mtime, - ctime = excluded.ctime, size = excluded.size, mode = excluded.mode, uid = excluded.uid, diff --git a/internal/database/files_test.go b/internal/database/files_test.go index 4b94519..8f16421 100644 --- a/internal/database/files_test.go +++ b/internal/database/files_test.go @@ -39,7 +39,6 @@ func TestFileRepository(t *testing.T) { file := &File{ Path: "/test/file.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -124,7 +123,6 @@ func TestFileRepositorySymlink(t *testing.T) { symlink := &File{ Path: "/test/link", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 0, Mode: uint32(0777 | os.ModeSymlink), UID: 1000, @@ -161,7 +159,6 @@ func TestFileRepositoryTransaction(t *testing.T) { file := &File{ Path: "/test/tx_file.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, diff --git a/internal/database/models.go b/internal/database/models.go index 729b576..14bc580 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -17,7 +17,6 @@ type File struct { Path types.FilePath // Absolute path of the file SourcePath types.SourcePath // The source directory this file came from (for restore path stripping) MTime time.Time - CTime time.Time Size int64 Mode uint32 UID uint32 diff --git a/internal/database/repositories_test.go b/internal/database/repositories_test.go index 14c7117..439f325 100644 --- a/internal/database/repositories_test.go +++ b/internal/database/repositories_test.go @@ -23,7 +23,6 @@ func TestRepositoriesTransaction(t *testing.T) { file := &File{ Path: "/test/tx_file.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -146,7 +145,6 @@ func TestRepositoriesTransactionRollback(t *testing.T) { file := &File{ Path: "/test/rollback_file.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -202,7 +200,6 @@ func TestRepositoriesReadTransaction(t *testing.T) { file := &File{ Path: "/test/read_file.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -226,7 +223,6 @@ func TestRepositoriesReadTransaction(t *testing.T) { _ = repos.Files.Create(ctx, tx, &File{ Path: "/test/should_fail.txt", MTime: time.Now(), - CTime: time.Now(), Size: 0, Mode: 0644, UID: 1000, diff --git a/internal/database/repository_comprehensive_test.go b/internal/database/repository_comprehensive_test.go index 9bea6e7..3bc25a6 100644 --- a/internal/database/repository_comprehensive_test.go +++ b/internal/database/repository_comprehensive_test.go @@ -23,7 +23,6 @@ func TestFileRepositoryUUIDGeneration(t *testing.T) { { Path: "/file1.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -32,7 +31,6 @@ func TestFileRepositoryUUIDGeneration(t *testing.T) { { Path: "/file2.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 2048, Mode: 0644, UID: 1000, @@ -72,7 +70,6 @@ func TestFileRepositoryGetByID(t *testing.T) { file := &File{ Path: "/test.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -120,7 +117,6 @@ func TestOrphanedFileCleanup(t *testing.T) { file1 := &File{ Path: "/orphaned.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -129,7 +125,6 @@ func TestOrphanedFileCleanup(t *testing.T) { file2 := &File{ Path: "/referenced.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 2048, Mode: 0644, UID: 1000, @@ -218,7 +213,6 @@ func TestOrphanedChunkCleanup(t *testing.T) { file := &File{ Path: "/test.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -348,7 +342,6 @@ func TestFileChunkRepositoryWithUUIDs(t *testing.T) { file := &File{ Path: "/test.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 3072, Mode: 0644, UID: 1000, @@ -419,7 +412,6 @@ func TestChunkFileRepositoryWithUUIDs(t *testing.T) { file1 := &File{ Path: "/file1.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -428,7 +420,6 @@ func TestChunkFileRepositoryWithUUIDs(t *testing.T) { file2 := &File{ Path: "/file2.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -586,7 +577,6 @@ func TestComplexOrphanedDataScenario(t *testing.T) { files[i] = &File{ Path: types.FilePath(fmt.Sprintf("/file%d.txt", i)), MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -678,7 +668,6 @@ func TestCascadeDelete(t *testing.T) { file := &File{ Path: "/cascade-test.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -750,7 +739,6 @@ func TestTransactionIsolation(t *testing.T) { file := &File{ Path: "/tx-test.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -812,7 +800,6 @@ func TestConcurrentOrphanedCleanup(t *testing.T) { file := &File{ Path: types.FilePath(fmt.Sprintf("/concurrent-%d.txt", i)), MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, diff --git a/internal/database/repository_debug_test.go b/internal/database/repository_debug_test.go index 92433d5..2bd9493 100644 --- a/internal/database/repository_debug_test.go +++ b/internal/database/repository_debug_test.go @@ -18,7 +18,6 @@ func TestOrphanedFileCleanupDebug(t *testing.T) { file1 := &File{ Path: "/orphaned.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 1024, Mode: 0644, UID: 1000, @@ -27,7 +26,6 @@ func TestOrphanedFileCleanupDebug(t *testing.T) { file2 := &File{ Path: "/referenced.txt", MTime: time.Now().Truncate(time.Second), - CTime: time.Now().Truncate(time.Second), Size: 2048, Mode: 0644, UID: 1000, diff --git a/internal/database/repository_edge_cases_test.go b/internal/database/repository_edge_cases_test.go index d701d38..4f9bb2b 100644 --- a/internal/database/repository_edge_cases_test.go +++ b/internal/database/repository_edge_cases_test.go @@ -29,7 +29,6 @@ func TestFileRepositoryEdgeCases(t *testing.T) { file: &File{ Path: "", MTime: time.Now(), - CTime: time.Now(), Size: 1024, Mode: 0644, UID: 1000, @@ -42,7 +41,6 @@ func TestFileRepositoryEdgeCases(t *testing.T) { file: &File{ Path: types.FilePath("/" + strings.Repeat("a", 4096)), MTime: time.Now(), - CTime: time.Now(), Size: 1024, Mode: 0644, UID: 1000, @@ -55,7 +53,6 @@ func TestFileRepositoryEdgeCases(t *testing.T) { file: &File{ Path: "/test/file with spaces and 特殊文字.txt", MTime: time.Now(), - CTime: time.Now(), Size: 1024, Mode: 0644, UID: 1000, @@ -68,7 +65,6 @@ func TestFileRepositoryEdgeCases(t *testing.T) { file: &File{ Path: "/empty.txt", MTime: time.Now(), - CTime: time.Now(), Size: 0, Mode: 0644, UID: 1000, @@ -81,7 +77,6 @@ func TestFileRepositoryEdgeCases(t *testing.T) { file: &File{ Path: "/link", MTime: time.Now(), - CTime: time.Now(), Size: 0, Mode: 0777 | 0120000, // symlink mode UID: 1000, @@ -123,7 +118,6 @@ func TestDuplicateHandling(t *testing.T) { file1 := &File{ Path: "/duplicate.txt", MTime: time.Now(), - CTime: time.Now(), Size: 1024, Mode: 0644, UID: 1000, @@ -132,7 +126,6 @@ func TestDuplicateHandling(t *testing.T) { file2 := &File{ Path: "/duplicate.txt", // Same path MTime: time.Now().Add(time.Hour), - CTime: time.Now().Add(time.Hour), Size: 2048, Mode: 0644, UID: 1000, @@ -192,7 +185,6 @@ func TestDuplicateHandling(t *testing.T) { file := &File{ Path: "/test-dup-fc.txt", MTime: time.Now(), - CTime: time.Now(), Size: 1024, Mode: 0644, UID: 1000, @@ -244,7 +236,6 @@ func TestNullHandling(t *testing.T) { file := &File{ Path: "/regular.txt", MTime: time.Now(), - CTime: time.Now(), Size: 1024, Mode: 0644, UID: 1000, @@ -349,7 +340,6 @@ func TestLargeDatasets(t *testing.T) { file := &File{ Path: types.FilePath(fmt.Sprintf("/large/file%05d.txt", i)), MTime: time.Now(), - CTime: time.Now(), Size: int64(i * 1024), Mode: 0644, UID: uint32(1000 + (i % 10)), @@ -474,7 +464,6 @@ func TestQueryInjection(t *testing.T) { file := &File{ Path: types.FilePath(injection), MTime: time.Now(), - CTime: time.Now(), Size: 1024, Mode: 0644, UID: 1000, @@ -513,7 +502,6 @@ func TestTimezoneHandling(t *testing.T) { file := &File{ Path: "/timezone-test.txt", MTime: nyTime, - CTime: nyTime, Size: 1024, Mode: 0644, UID: 1000, diff --git a/internal/database/schema/000.sql b/internal/database/schema/000.sql new file mode 100644 index 0000000..e06a2da --- /dev/null +++ b/internal/database/schema/000.sql @@ -0,0 +1,9 @@ +-- Migration 000: Schema migrations tracking table +-- Applied as a bootstrap step before the normal migration loop. + +CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +INSERT OR IGNORE INTO schema_migrations (version) VALUES (0); diff --git a/internal/database/schema.sql b/internal/database/schema/001.sql similarity index 92% rename from internal/database/schema.sql rename to internal/database/schema/001.sql index 64b03a0..5f54565 100644 --- a/internal/database/schema.sql +++ b/internal/database/schema/001.sql @@ -1,6 +1,5 @@ --- Vaultik Database Schema --- Note: This database does not support migrations. If the schema changes, --- delete the local database and perform a full backup to recreate it. +-- Migration 001: Initial Vaultik schema +-- All core tables for tracking files, chunks, blobs, snapshots, and uploads. -- Files table: stores metadata about files in the filesystem CREATE TABLE IF NOT EXISTS files ( @@ -8,7 +7,6 @@ CREATE TABLE IF NOT EXISTS files ( path TEXT NOT NULL UNIQUE, source_path TEXT NOT NULL DEFAULT '', -- The source directory this file came from (for restore path stripping) mtime INTEGER NOT NULL, - ctime INTEGER NOT NULL, size INTEGER NOT NULL, mode INTEGER NOT NULL, uid INTEGER NOT NULL, @@ -103,7 +101,7 @@ CREATE TABLE IF NOT EXISTS snapshot_files ( file_id TEXT NOT NULL, PRIMARY KEY (snapshot_id, file_id), FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE, - FOREIGN KEY (file_id) REFERENCES files(id) + FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE ); -- Index for efficient file lookups (used in orphan detection) @@ -116,7 +114,7 @@ CREATE TABLE IF NOT EXISTS snapshot_blobs ( blob_hash TEXT NOT NULL, PRIMARY KEY (snapshot_id, blob_id), FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE, - FOREIGN KEY (blob_id) REFERENCES blobs(id) + FOREIGN KEY (blob_id) REFERENCES blobs(id) ON DELETE CASCADE ); -- Index for efficient blob lookups (used in orphan detection) @@ -130,8 +128,8 @@ CREATE TABLE IF NOT EXISTS uploads ( size INTEGER NOT NULL, duration_ms INTEGER NOT NULL, FOREIGN KEY (blob_hash) REFERENCES blobs(blob_hash), - FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) + FOREIGN KEY (snapshot_id) REFERENCES snapshots(id) ON DELETE CASCADE ); -- Index for efficient snapshot lookups -CREATE INDEX IF NOT EXISTS idx_uploads_snapshot_id ON uploads(snapshot_id); \ No newline at end of file +CREATE INDEX IF NOT EXISTS idx_uploads_snapshot_id ON uploads(snapshot_id); diff --git a/internal/database/schema/008_uploads.sql b/internal/database/schema/008_uploads.sql deleted file mode 100644 index 49b5add..0000000 --- a/internal/database/schema/008_uploads.sql +++ /dev/null @@ -1,11 +0,0 @@ --- Track blob upload metrics -CREATE TABLE IF NOT EXISTS uploads ( - blob_hash TEXT PRIMARY KEY, - uploaded_at TIMESTAMP NOT NULL, - size INTEGER NOT NULL, - duration_ms INTEGER NOT NULL, - FOREIGN KEY (blob_hash) REFERENCES blobs(blob_hash) -); - -CREATE INDEX idx_uploads_uploaded_at ON uploads(uploaded_at); -CREATE INDEX idx_uploads_duration ON uploads(duration_ms); \ No newline at end of file diff --git a/internal/snapshot/backup_test.go b/internal/snapshot/backup_test.go index 09ad29c..05bd30c 100644 --- a/internal/snapshot/backup_test.go +++ b/internal/snapshot/backup_test.go @@ -345,9 +345,8 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str Size: info.Size(), Mode: uint32(info.Mode()), MTime: info.ModTime(), - CTime: info.ModTime(), // Use mtime as ctime for test - UID: 1000, // Default UID for test - GID: 1000, // Default GID for test + UID: 1000, // Default UID for test + GID: 1000, // Default GID for test } err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { return b.repos.Files.Create(ctx, tx, file) diff --git a/internal/snapshot/scanner.go b/internal/snapshot/scanner.go index ca403b4..6043133 100644 --- a/internal/snapshot/scanner.go +++ b/internal/snapshot/scanner.go @@ -180,18 +180,10 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc } // Phase 0: Load known files and chunks from database into memory for fast lookup - fmt.Println("Loading known files from database...") - knownFiles, err := s.loadKnownFiles(ctx, path) + knownFiles, err := s.loadDatabaseState(ctx, path) if err != nil { - return nil, fmt.Errorf("loading known files: %w", err) + return nil, err } - fmt.Printf("Loaded %s known files from database\n", formatNumber(len(knownFiles))) - - fmt.Println("Loading known chunks from database...") - if err := s.loadKnownChunks(ctx); err != nil { - return nil, fmt.Errorf("loading known chunks: %w", err) - } - fmt.Printf("Loaded %s known chunks from database\n", formatNumber(len(s.knownChunks))) // Phase 1: Scan directory, collect files to process, and track existing files // (builds existingFiles map during walk to avoid double traversal) @@ -216,36 +208,8 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc } } - // Calculate total size to process - var totalSizeToProcess int64 - for _, file := range filesToProcess { - totalSizeToProcess += file.FileInfo.Size() - } - - // Update progress with total size and file count - if s.progress != nil { - s.progress.SetTotalSize(totalSizeToProcess) - s.progress.GetStats().TotalFiles.Store(int64(len(filesToProcess))) - } - - log.Info("Phase 1 complete", - "total_files", len(filesToProcess), - "total_size", humanize.Bytes(uint64(totalSizeToProcess)), - "files_skipped", result.FilesSkipped, - "bytes_skipped", humanize.Bytes(uint64(result.BytesSkipped))) - - // Print scan summary - fmt.Printf("Scan complete: %s examined (%s), %s to process (%s)", - formatNumber(result.FilesScanned), - humanize.Bytes(uint64(totalSizeToProcess+result.BytesSkipped)), - formatNumber(len(filesToProcess)), - humanize.Bytes(uint64(totalSizeToProcess))) - if result.FilesDeleted > 0 { - fmt.Printf(", %s deleted (%s)", - formatNumber(result.FilesDeleted), - humanize.Bytes(uint64(result.BytesDeleted))) - } - fmt.Println() + // Summarize scan phase results and update progress + s.summarizeScanPhase(result, filesToProcess) // Phase 2: Process files and create chunks if len(filesToProcess) > 0 { @@ -259,7 +223,66 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc log.Info("Phase 2/3: Skipping (no files need processing, metadata-only snapshot)") } - // Get final stats from packer + // Finalize result with blob statistics + s.finalizeScanResult(ctx, result) + + return result, nil +} + +// loadDatabaseState loads known files and chunks from the database into memory for fast lookup +// This avoids per-file and per-chunk database queries during the scan and process phases +func (s *Scanner) loadDatabaseState(ctx context.Context, path string) (map[string]*database.File, error) { + fmt.Println("Loading known files from database...") + knownFiles, err := s.loadKnownFiles(ctx, path) + if err != nil { + return nil, fmt.Errorf("loading known files: %w", err) + } + fmt.Printf("Loaded %s known files from database\n", formatNumber(len(knownFiles))) + + fmt.Println("Loading known chunks from database...") + if err := s.loadKnownChunks(ctx); err != nil { + return nil, fmt.Errorf("loading known chunks: %w", err) + } + fmt.Printf("Loaded %s known chunks from database\n", formatNumber(len(s.knownChunks))) + + return knownFiles, nil +} + +// summarizeScanPhase calculates total size to process, updates progress tracking, +// and prints the scan phase summary with file counts and sizes +func (s *Scanner) summarizeScanPhase(result *ScanResult, filesToProcess []*FileToProcess) { + var totalSizeToProcess int64 + for _, file := range filesToProcess { + totalSizeToProcess += file.FileInfo.Size() + } + + if s.progress != nil { + s.progress.SetTotalSize(totalSizeToProcess) + s.progress.GetStats().TotalFiles.Store(int64(len(filesToProcess))) + } + + log.Info("Phase 1 complete", + "total_files", len(filesToProcess), + "total_size", humanize.Bytes(uint64(totalSizeToProcess)), + "files_skipped", result.FilesSkipped, + "bytes_skipped", humanize.Bytes(uint64(result.BytesSkipped))) + + fmt.Printf("Scan complete: %s examined (%s), %s to process (%s)", + formatNumber(result.FilesScanned), + humanize.Bytes(uint64(totalSizeToProcess+result.BytesSkipped)), + formatNumber(len(filesToProcess)), + humanize.Bytes(uint64(totalSizeToProcess))) + if result.FilesDeleted > 0 { + fmt.Printf(", %s deleted (%s)", + formatNumber(result.FilesDeleted), + humanize.Bytes(uint64(result.BytesDeleted))) + } + fmt.Println() +} + +// finalizeScanResult populates final blob statistics in the scan result +// by querying the packer and database for blob/upload counts +func (s *Scanner) finalizeScanResult(ctx context.Context, result *ScanResult) { blobs := s.packer.GetFinishedBlobs() result.BlobsCreated += len(blobs) @@ -276,7 +299,6 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc } result.EndTime = time.Now().UTC() - return result, nil } // loadKnownFiles loads all known files from the database into a map for fast lookup @@ -424,12 +446,38 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { flushStart := time.Now() log.Debug("flushCompletedPendingFiles: starting") + // Partition pending files into those ready to flush and those still waiting + canFlush, stillPendingCount := s.partitionPendingByChunkStatus() + + if len(canFlush) == 0 { + log.Debug("flushCompletedPendingFiles: nothing to flush") + return nil + } + + log.Debug("Flushing completed files after blob finalize", + "files_to_flush", len(canFlush), + "files_still_pending", stillPendingCount) + + // Collect all data for batch operations + allFiles, allFileIDs, allFileChunks, allChunkFiles := s.collectBatchFlushData(canFlush) + + // Execute the batch flush in a single transaction + log.Debug("flushCompletedPendingFiles: starting transaction") + txStart := time.Now() + err := s.executeBatchFileFlush(ctx, allFiles, allFileIDs, allFileChunks, allChunkFiles) + log.Debug("flushCompletedPendingFiles: transaction done", "duration", time.Since(txStart)) + log.Debug("flushCompletedPendingFiles: total duration", "duration", time.Since(flushStart)) + return err +} + +// partitionPendingByChunkStatus separates pending files into those whose chunks +// are all committed to DB (ready to flush) and those still waiting on pending chunks. +// Updates s.pendingFiles to contain only the still-pending files. +func (s *Scanner) partitionPendingByChunkStatus() (canFlush []pendingFileData, stillPendingCount int) { log.Debug("flushCompletedPendingFiles: acquiring pendingFilesMu lock") s.pendingFilesMu.Lock() log.Debug("flushCompletedPendingFiles: acquired lock", "pending_files", len(s.pendingFiles)) - // Separate files into complete (can flush) and incomplete (keep pending) - var canFlush []pendingFileData var stillPending []pendingFileData log.Debug("flushCompletedPendingFiles: checking which files can flush") @@ -454,18 +502,15 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { s.pendingFilesMu.Unlock() log.Debug("flushCompletedPendingFiles: released lock") - if len(canFlush) == 0 { - log.Debug("flushCompletedPendingFiles: nothing to flush") - return nil - } + return canFlush, len(stillPending) +} - log.Debug("Flushing completed files after blob finalize", - "files_to_flush", len(canFlush), - "files_still_pending", len(stillPending)) - - // Collect all data for batch operations +// collectBatchFlushData aggregates file records, IDs, file-chunk mappings, and chunk-file +// mappings from the given pending file data for efficient batch database operations +func (s *Scanner) collectBatchFlushData(canFlush []pendingFileData) ([]*database.File, []types.FileID, []database.FileChunk, []database.ChunkFile) { log.Debug("flushCompletedPendingFiles: collecting data for batch ops") collectStart := time.Now() + var allFileChunks []database.FileChunk var allChunkFiles []database.ChunkFile var allFileIDs []types.FileID @@ -477,16 +522,20 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { allFileIDs = append(allFileIDs, data.file.ID) allFiles = append(allFiles, data.file) } + log.Debug("flushCompletedPendingFiles: collected data", "duration", time.Since(collectStart), "file_chunks", len(allFileChunks), "chunk_files", len(allChunkFiles), "files", len(allFiles)) - // Flush the complete files using batch operations - log.Debug("flushCompletedPendingFiles: starting transaction") - txStart := time.Now() - err := s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { + return allFiles, allFileIDs, allFileChunks, allChunkFiles +} + +// executeBatchFileFlush writes all collected file data to the database in a single transaction, +// including deleting old mappings, creating file records, and adding snapshot associations +func (s *Scanner) executeBatchFileFlush(ctx context.Context, allFiles []*database.File, allFileIDs []types.FileID, allFileChunks []database.FileChunk, allChunkFiles []database.ChunkFile) error { + return s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { log.Debug("flushCompletedPendingFiles: inside transaction") // Batch delete old file_chunks and chunk_files @@ -539,9 +588,6 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error { log.Debug("flushCompletedPendingFiles: transaction complete") return nil }) - log.Debug("flushCompletedPendingFiles: transaction done", "duration", time.Since(txStart)) - log.Debug("flushCompletedPendingFiles: total duration", "duration", time.Since(flushStart)) - return err } // ScanPhaseResult contains the results of the scan phase @@ -623,62 +669,11 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult mu.Unlock() // Update result stats - if needsProcessing { - result.BytesScanned += info.Size() - if s.progress != nil { - s.progress.GetStats().BytesScanned.Add(info.Size()) - } - } else { - result.FilesSkipped++ - result.BytesSkipped += info.Size() - if s.progress != nil { - s.progress.GetStats().FilesSkipped.Add(1) - s.progress.GetStats().BytesSkipped.Add(info.Size()) - } - } - result.FilesScanned++ - if s.progress != nil { - s.progress.GetStats().FilesScanned.Add(1) - } + s.updateScanEntryStats(result, needsProcessing, info) // Output periodic status if time.Since(lastStatusTime) >= statusInterval { - elapsed := time.Since(startTime) - rate := float64(filesScanned) / elapsed.Seconds() - - // Build status line - use estimate if available (not first backup) - if estimatedTotal > 0 { - // Show actual scanned vs estimate (may exceed estimate if files were added) - pct := float64(filesScanned) / float64(estimatedTotal) * 100 - if pct > 100 { - pct = 100 // Cap at 100% for display - } - remaining := estimatedTotal - filesScanned - if remaining < 0 { - remaining = 0 - } - var eta time.Duration - if rate > 0 && remaining > 0 { - eta = time.Duration(float64(remaining)/rate) * time.Second - } - fmt.Printf("Scan: %s files (~%.0f%%), %s changed/new, %.0f files/sec, %s elapsed", - formatNumber(int(filesScanned)), - pct, - formatNumber(changedCount), - rate, - elapsed.Round(time.Second)) - if eta > 0 { - fmt.Printf(", ETA %s", eta.Round(time.Second)) - } - fmt.Println() - } else { - // First backup - no estimate available - fmt.Printf("Scan: %s files, %s changed/new, %.0f files/sec, %s elapsed\n", - formatNumber(int(filesScanned)), - formatNumber(changedCount), - rate, - elapsed.Round(time.Second)) - } + printScanProgressLine(filesScanned, changedCount, estimatedTotal, startTime) lastStatusTime = time.Now() } @@ -695,6 +690,68 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult }, nil } +// updateScanEntryStats updates the scan result and progress reporter statistics +// for a single scanned file entry based on whether it needs processing +func (s *Scanner) updateScanEntryStats(result *ScanResult, needsProcessing bool, info os.FileInfo) { + if needsProcessing { + result.BytesScanned += info.Size() + if s.progress != nil { + s.progress.GetStats().BytesScanned.Add(info.Size()) + } + } else { + result.FilesSkipped++ + result.BytesSkipped += info.Size() + if s.progress != nil { + s.progress.GetStats().FilesSkipped.Add(1) + s.progress.GetStats().BytesSkipped.Add(info.Size()) + } + } + result.FilesScanned++ + if s.progress != nil { + s.progress.GetStats().FilesScanned.Add(1) + } +} + +// printScanProgressLine prints a periodic progress line during the scan phase, +// showing files scanned, percentage complete (if estimate available), and ETA +func printScanProgressLine(filesScanned int64, changedCount int, estimatedTotal int64, startTime time.Time) { + elapsed := time.Since(startTime) + rate := float64(filesScanned) / elapsed.Seconds() + + if estimatedTotal > 0 { + // Show actual scanned vs estimate (may exceed estimate if files were added) + pct := float64(filesScanned) / float64(estimatedTotal) * 100 + if pct > 100 { + pct = 100 // Cap at 100% for display + } + remaining := estimatedTotal - filesScanned + if remaining < 0 { + remaining = 0 + } + var eta time.Duration + if rate > 0 && remaining > 0 { + eta = time.Duration(float64(remaining)/rate) * time.Second + } + fmt.Printf("Scan: %s files (~%.0f%%), %s changed/new, %.0f files/sec, %s elapsed", + formatNumber(int(filesScanned)), + pct, + formatNumber(changedCount), + rate, + elapsed.Round(time.Second)) + if eta > 0 { + fmt.Printf(", ETA %s", eta.Round(time.Second)) + } + fmt.Println() + } else { + // First backup - no estimate available + fmt.Printf("Scan: %s files, %s changed/new, %.0f files/sec, %s elapsed\n", + formatNumber(int(filesScanned)), + formatNumber(changedCount), + rate, + elapsed.Round(time.Second)) + } +} + // checkFileInMemory checks if a file needs processing using the in-memory map // No database access is performed - this is purely CPU/memory work func (s *Scanner) checkFileInMemory(path string, info os.FileInfo, knownFiles map[string]*database.File) (*database.File, bool) { @@ -728,7 +785,6 @@ func (s *Scanner) checkFileInMemory(path string, info os.FileInfo, knownFiles ma Path: types.FilePath(path), SourcePath: types.SourcePath(s.currentSourcePath), // Store source directory for restore path stripping MTime: info.ModTime(), - CTime: info.ModTime(), // afero doesn't provide ctime Size: info.Size(), Mode: uint32(info.Mode()), UID: uid, @@ -830,22 +886,13 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc s.progress.GetStats().CurrentFile.Store(fileToProcess.Path) } - // Process file in streaming fashion - if err := s.processFileStreaming(ctx, fileToProcess, result); err != nil { - // Handle files that were deleted between scan and process phases - if errors.Is(err, os.ErrNotExist) { - log.Warn("File was deleted during backup, skipping", "path", fileToProcess.Path) - result.FilesSkipped++ - continue - } - // Skip file read errors if --skip-errors is enabled - if s.skipErrors { - log.Error("ERROR: Failed to process file (skipping due to --skip-errors)", "path", fileToProcess.Path, "error", err) - fmt.Printf("ERROR: Failed to process %s: %v (skipping)\n", fileToProcess.Path, err) - result.FilesSkipped++ - continue - } - return fmt.Errorf("processing file %s: %w", fileToProcess.Path, err) + // Process file with error handling for deleted files and skip-errors mode + skipped, err := s.processFileWithErrorHandling(ctx, fileToProcess, result) + if err != nil { + return err + } + if skipped { + continue } // Update files processed counter @@ -858,36 +905,71 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc // Output periodic status if time.Since(lastStatusTime) >= statusInterval { - elapsed := time.Since(startTime) - pct := float64(bytesProcessed) / float64(totalBytes) * 100 - byteRate := float64(bytesProcessed) / elapsed.Seconds() - fileRate := float64(filesProcessed) / elapsed.Seconds() - - // Calculate ETA based on bytes (more accurate than files) - remainingBytes := totalBytes - bytesProcessed - var eta time.Duration - if byteRate > 0 { - eta = time.Duration(float64(remainingBytes)/byteRate) * time.Second - } - - // Format: Progress [5.7k/610k] 6.7 GB/44 GB (15.4%), 106MB/sec, 500 files/sec, running for 1m30s, ETA: 5m49s - fmt.Printf("Progress [%s/%s] %s/%s (%.1f%%), %s/sec, %.0f files/sec, running for %s", - formatCompact(filesProcessed), - formatCompact(totalFiles), - humanize.Bytes(uint64(bytesProcessed)), - humanize.Bytes(uint64(totalBytes)), - pct, - humanize.Bytes(uint64(byteRate)), - fileRate, - elapsed.Round(time.Second)) - if eta > 0 { - fmt.Printf(", ETA: %s", eta.Round(time.Second)) - } - fmt.Println() + printProcessingProgress(filesProcessed, totalFiles, bytesProcessed, totalBytes, startTime) lastStatusTime = time.Now() } } + // Finalize: flush packer, pending files, and handle local blobs + return s.finalizeProcessPhase(ctx, result) +} + +// processFileWithErrorHandling wraps processFileStreaming with error recovery for +// deleted files and skip-errors mode. Returns (skipped, error). +func (s *Scanner) processFileWithErrorHandling(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) (bool, error) { + if err := s.processFileStreaming(ctx, fileToProcess, result); err != nil { + // Handle files that were deleted between scan and process phases + if errors.Is(err, os.ErrNotExist) { + log.Warn("File was deleted during backup, skipping", "path", fileToProcess.Path) + result.FilesSkipped++ + return true, nil + } + // Skip file read errors if --skip-errors is enabled + if s.skipErrors { + log.Error("ERROR: Failed to process file (skipping due to --skip-errors)", "path", fileToProcess.Path, "error", err) + fmt.Printf("ERROR: Failed to process %s: %v (skipping)\n", fileToProcess.Path, err) + result.FilesSkipped++ + return true, nil + } + return false, fmt.Errorf("processing file %s: %w", fileToProcess.Path, err) + } + return false, nil +} + +// printProcessingProgress prints a periodic progress line during the process phase, +// showing files processed, bytes transferred, throughput, and ETA +func printProcessingProgress(filesProcessed, totalFiles int, bytesProcessed, totalBytes int64, startTime time.Time) { + elapsed := time.Since(startTime) + pct := float64(bytesProcessed) / float64(totalBytes) * 100 + byteRate := float64(bytesProcessed) / elapsed.Seconds() + fileRate := float64(filesProcessed) / elapsed.Seconds() + + // Calculate ETA based on bytes (more accurate than files) + remainingBytes := totalBytes - bytesProcessed + var eta time.Duration + if byteRate > 0 { + eta = time.Duration(float64(remainingBytes)/byteRate) * time.Second + } + + // Format: Progress [5.7k/610k] 6.7 GB/44 GB (15.4%), 106MB/sec, 500 files/sec, running for 1m30s, ETA: 5m49s + fmt.Printf("Progress [%s/%s] %s/%s (%.1f%%), %s/sec, %.0f files/sec, running for %s", + formatCompact(filesProcessed), + formatCompact(totalFiles), + humanize.Bytes(uint64(bytesProcessed)), + humanize.Bytes(uint64(totalBytes)), + pct, + humanize.Bytes(uint64(byteRate)), + fileRate, + elapsed.Round(time.Second)) + if eta > 0 { + fmt.Printf(", ETA: %s", eta.Round(time.Second)) + } + fmt.Println() +} + +// finalizeProcessPhase flushes the packer, writes remaining pending files to the database, +// and handles local blob storage when no remote storage is configured +func (s *Scanner) finalizeProcessPhase(ctx context.Context, result *ScanResult) error { // Final packer flush first - this commits remaining chunks to DB // and handleBlobReady will flush files whose chunks are now committed s.packerMu.Lock() @@ -931,40 +1013,103 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { startTime := time.Now().UTC() finishedBlob := blobWithReader.FinishedBlob - // Report upload start and increment blobs created if s.progress != nil { s.progress.ReportUploadStart(finishedBlob.Hash, finishedBlob.Compressed) s.progress.GetStats().BlobsCreated.Add(1) } - // Upload to storage first (without holding any locks) - // Use scan context for cancellation support ctx := s.scanCtx if ctx == nil { ctx = context.Background() } - // Track bytes uploaded for accurate speed calculation + blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash) + blobExists, err := s.uploadBlobIfNeeded(ctx, blobPath, blobWithReader, startTime) + if err != nil { + s.cleanupBlobTempFile(blobWithReader) + return fmt.Errorf("uploading blob %s: %w", finishedBlob.Hash, err) + } + + if err := s.recordBlobMetadata(ctx, finishedBlob, blobExists, startTime); err != nil { + s.cleanupBlobTempFile(blobWithReader) + return err + } + + s.cleanupBlobTempFile(blobWithReader) + + // Chunks from this blob are now committed to DB - remove from pending set + s.removePendingChunkHashes(blobWithReader.InsertedChunkHashes) + + // Flush files whose chunks are now all committed + if err := s.flushCompletedPendingFiles(ctx); err != nil { + return fmt.Errorf("flushing completed files: %w", err) + } + + return nil +} + +// uploadBlobIfNeeded uploads the blob to storage if it doesn't already exist, returns whether it existed +func (s *Scanner) uploadBlobIfNeeded(ctx context.Context, blobPath string, blobWithReader *blob.BlobWithReader, startTime time.Time) (bool, error) { + finishedBlob := blobWithReader.FinishedBlob + + // Check if blob already exists (deduplication after restart) + if _, err := s.storage.Stat(ctx, blobPath); err == nil { + log.Info("Blob already exists in storage, skipping upload", + "hash", finishedBlob.Hash, "size", humanize.Bytes(uint64(finishedBlob.Compressed))) + fmt.Printf("Blob exists: %s (%s, skipped upload)\n", + finishedBlob.Hash[:12]+"...", humanize.Bytes(uint64(finishedBlob.Compressed))) + return true, nil + } + + progressCallback := s.makeUploadProgressCallback(ctx, finishedBlob) + + if err := s.storage.PutWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil { + log.Error("Failed to upload blob", "hash", finishedBlob.Hash, "error", err) + return false, fmt.Errorf("uploading blob to storage: %w", err) + } + + uploadDuration := time.Since(startTime) + uploadSpeedBps := float64(finishedBlob.Compressed) / uploadDuration.Seconds() + + fmt.Printf("Blob stored: %s (%s, %s/sec, %s)\n", + finishedBlob.Hash[:12]+"...", + humanize.Bytes(uint64(finishedBlob.Compressed)), + humanize.Bytes(uint64(uploadSpeedBps)), + uploadDuration.Round(time.Millisecond)) + + log.Info("Successfully uploaded blob to storage", + "path", blobPath, + "size", humanize.Bytes(uint64(finishedBlob.Compressed)), + "duration", uploadDuration, + "speed", humanize.SI(uploadSpeedBps*8, "bps")) + + if s.progress != nil { + s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration) + stats := s.progress.GetStats() + stats.BlobsUploaded.Add(1) + stats.BytesUploaded.Add(finishedBlob.Compressed) + } + + return false, nil +} + +// makeUploadProgressCallback creates a progress callback for blob uploads +func (s *Scanner) makeUploadProgressCallback(ctx context.Context, finishedBlob *blob.FinishedBlob) func(int64) error { lastProgressTime := time.Now() lastProgressBytes := int64(0) - progressCallback := func(uploaded int64) error { - // Calculate instantaneous speed + return func(uploaded int64) error { now := time.Now() elapsed := now.Sub(lastProgressTime).Seconds() - if elapsed > 0.5 { // Update speed every 0.5 seconds + if elapsed > 0.5 { bytesSinceLastUpdate := uploaded - lastProgressBytes speed := float64(bytesSinceLastUpdate) / elapsed - if s.progress != nil { s.progress.ReportUploadProgress(finishedBlob.Hash, uploaded, finishedBlob.Compressed, speed) } - lastProgressTime = now lastProgressBytes = uploaded } - - // Check for cancellation select { case <-ctx.Done(): return ctx.Err() @@ -972,87 +1117,26 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { return nil } } +} - // Create sharded path: blobs/ca/fe/cafebabe... - blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash) - - // Check if blob already exists in remote storage (deduplication after restart) - blobExists := false - if _, err := s.storage.Stat(ctx, blobPath); err == nil { - blobExists = true - log.Info("Blob already exists in storage, skipping upload", - "hash", finishedBlob.Hash, - "size", humanize.Bytes(uint64(finishedBlob.Compressed))) - fmt.Printf("Blob exists: %s (%s, skipped upload)\n", - finishedBlob.Hash[:12]+"...", - humanize.Bytes(uint64(finishedBlob.Compressed))) - } - - if !blobExists { - if err := s.storage.PutWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil { - return fmt.Errorf("uploading blob %s to storage: %w", finishedBlob.Hash, err) - } - - uploadDuration := time.Since(startTime) - - // Calculate upload speed - uploadSpeedBps := float64(finishedBlob.Compressed) / uploadDuration.Seconds() - - // Print blob stored message - fmt.Printf("Blob stored: %s (%s, %s/sec, %s)\n", - finishedBlob.Hash[:12]+"...", - humanize.Bytes(uint64(finishedBlob.Compressed)), - humanize.Bytes(uint64(uploadSpeedBps)), - uploadDuration.Round(time.Millisecond)) - - // Log upload stats - uploadSpeedBits := uploadSpeedBps * 8 // bits per second - log.Info("Successfully uploaded blob to storage", - "path", blobPath, - "size", humanize.Bytes(uint64(finishedBlob.Compressed)), - "duration", uploadDuration, - "speed", humanize.SI(uploadSpeedBits, "bps")) - - // Report upload complete - if s.progress != nil { - s.progress.ReportUploadComplete(finishedBlob.Hash, finishedBlob.Compressed, uploadDuration) - } - - // Update progress after upload completes - if s.progress != nil { - stats := s.progress.GetStats() - stats.BlobsUploaded.Add(1) - stats.BytesUploaded.Add(finishedBlob.Compressed) - } - } - - // Store metadata in database (after upload is complete) - dbCtx := s.scanCtx - if dbCtx == nil { - dbCtx = context.Background() - } - - // Parse blob ID for typed operations +// recordBlobMetadata stores blob upload metadata in the database +func (s *Scanner) recordBlobMetadata(ctx context.Context, finishedBlob *blob.FinishedBlob, blobExists bool, startTime time.Time) error { finishedBlobID, err := types.ParseBlobID(finishedBlob.ID) if err != nil { return fmt.Errorf("parsing finished blob ID: %w", err) } - // Track upload duration (0 if blob already existed) uploadDuration := time.Since(startTime) - err = s.repos.WithTx(dbCtx, func(ctx context.Context, tx *sql.Tx) error { - // Update blob upload timestamp - if err := s.repos.Blobs.UpdateUploaded(ctx, tx, finishedBlob.ID); err != nil { + return s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error { + if err := s.repos.Blobs.UpdateUploaded(txCtx, tx, finishedBlob.ID); err != nil { return fmt.Errorf("updating blob upload timestamp: %w", err) } - // Add the blob to the snapshot - if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, finishedBlobID, types.BlobHash(finishedBlob.Hash)); err != nil { + if err := s.repos.Snapshots.AddBlob(txCtx, tx, s.snapshotID, finishedBlobID, types.BlobHash(finishedBlob.Hash)); err != nil { return fmt.Errorf("adding blob to snapshot: %w", err) } - // Record upload metrics (only for actual uploads, not deduplicated blobs) if !blobExists { upload := &database.Upload{ BlobHash: finishedBlob.Hash, @@ -1061,15 +1145,17 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { Size: finishedBlob.Compressed, DurationMs: uploadDuration.Milliseconds(), } - if err := s.repos.Uploads.Create(ctx, tx, upload); err != nil { + if err := s.repos.Uploads.Create(txCtx, tx, upload); err != nil { return fmt.Errorf("recording upload metrics: %w", err) } } return nil }) +} - // Cleanup temp file if needed +// cleanupBlobTempFile closes and removes the blob's temporary file +func (s *Scanner) cleanupBlobTempFile(blobWithReader *blob.BlobWithReader) { if blobWithReader.TempFile != nil { tempName := blobWithReader.TempFile.Name() if err := blobWithReader.TempFile.Close(); err != nil { @@ -1079,77 +1165,41 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error { log.Fatal("Failed to remove temp file", "file", tempName, "error", err) } } +} - if err != nil { - return err - } - - // Chunks from this blob are now committed to DB - remove from pending set - log.Debug("handleBlobReady: removing pending chunk hashes") - s.removePendingChunkHashes(blobWithReader.InsertedChunkHashes) - log.Debug("handleBlobReady: removed pending chunk hashes") - - // Flush files whose chunks are now all committed - // This maintains database consistency after each blob - log.Debug("handleBlobReady: calling flushCompletedPendingFiles") - if err := s.flushCompletedPendingFiles(dbCtx); err != nil { - return fmt.Errorf("flushing completed files: %w", err) - } - log.Debug("handleBlobReady: flushCompletedPendingFiles returned") - - log.Debug("handleBlobReady: complete") - return nil +// streamingChunkInfo tracks chunk metadata collected during streaming +type streamingChunkInfo struct { + fileChunk database.FileChunk + offset int64 + size int64 } // processFileStreaming processes a file by streaming chunks directly to the packer func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileToProcess, result *ScanResult) error { - // Open the file file, err := s.fs.Open(fileToProcess.Path) if err != nil { return fmt.Errorf("opening file: %w", err) } defer func() { _ = file.Close() }() - // We'll collect file chunks for database storage - // but process them for packing as we go - type chunkInfo struct { - fileChunk database.FileChunk - offset int64 - size int64 - } - var chunks []chunkInfo + var chunks []streamingChunkInfo chunkIndex := 0 - // Process chunks in streaming fashion and get full file hash fileHash, err := s.chunker.ChunkReaderStreaming(file, func(chunk chunker.Chunk) error { - // Check for cancellation select { case <-ctx.Done(): return ctx.Err() default: } - log.Debug("Processing content-defined chunk from file", - "file", fileToProcess.Path, - "chunk_index", chunkIndex, - "hash", chunk.Hash, - "size", chunk.Size) - - // Check if chunk already exists (fast in-memory lookup) chunkExists := s.chunkExists(chunk.Hash) - - // Queue new chunks for batch insert when blob finalizes - // This dramatically reduces transaction overhead if !chunkExists { s.packer.AddPendingChunk(chunk.Hash, chunk.Size) - // Add to in-memory cache immediately for fast duplicate detection s.addKnownChunk(chunk.Hash) - // Track as pending until blob finalizes and commits to DB s.addPendingChunkHash(chunk.Hash) } - // Track file chunk association for later storage - chunks = append(chunks, chunkInfo{ + chunks = append(chunks, streamingChunkInfo{ fileChunk: database.FileChunk{ FileID: fileToProcess.File.ID, Idx: chunkIndex, @@ -1159,55 +1209,15 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT size: chunk.Size, }) - // Update stats - if chunkExists { - result.FilesSkipped++ // Track as skipped for now - result.BytesSkipped += chunk.Size - if s.progress != nil { - s.progress.GetStats().BytesSkipped.Add(chunk.Size) - } - } else { - result.ChunksCreated++ - result.BytesScanned += chunk.Size - if s.progress != nil { - s.progress.GetStats().ChunksCreated.Add(1) - s.progress.GetStats().BytesProcessed.Add(chunk.Size) - s.progress.UpdateChunkingActivity() - } - } + s.updateChunkStats(chunkExists, chunk.Size, result) - // Add chunk to packer immediately (streaming) - // This happens outside the database transaction if !chunkExists { - s.packerMu.Lock() - err := s.packer.AddChunk(&blob.ChunkRef{ - Hash: chunk.Hash, - Data: chunk.Data, - }) - if err == blob.ErrBlobSizeLimitExceeded { - // Finalize current blob and retry - if err := s.packer.FinalizeBlob(); err != nil { - s.packerMu.Unlock() - return fmt.Errorf("finalizing blob: %w", err) - } - // Retry adding the chunk - if err := s.packer.AddChunk(&blob.ChunkRef{ - Hash: chunk.Hash, - Data: chunk.Data, - }); err != nil { - s.packerMu.Unlock() - return fmt.Errorf("adding chunk after finalize: %w", err) - } - } else if err != nil { - s.packerMu.Unlock() - return fmt.Errorf("adding chunk to packer: %w", err) + if err := s.addChunkToPacker(chunk); err != nil { + return err } - s.packerMu.Unlock() } - // Clear chunk data from memory immediately after use chunk.Data = nil - chunkIndex++ return nil }) @@ -1217,12 +1227,54 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT } log.Debug("Completed snapshotting file", - "path", fileToProcess.Path, - "file_hash", fileHash, - "chunks", len(chunks)) + "path", fileToProcess.Path, "file_hash", fileHash, "chunks", len(chunks)) - // Build file data for batch insertion - // Update chunk associations with the file ID + s.queueFileForBatchInsert(ctx, fileToProcess, chunks) + return nil +} + +// updateChunkStats updates scan result and progress stats for a processed chunk +func (s *Scanner) updateChunkStats(chunkExists bool, chunkSize int64, result *ScanResult) { + if chunkExists { + result.FilesSkipped++ + result.BytesSkipped += chunkSize + if s.progress != nil { + s.progress.GetStats().BytesSkipped.Add(chunkSize) + } + } else { + result.ChunksCreated++ + result.BytesScanned += chunkSize + if s.progress != nil { + s.progress.GetStats().ChunksCreated.Add(1) + s.progress.GetStats().BytesProcessed.Add(chunkSize) + s.progress.UpdateChunkingActivity() + } + } +} + +// addChunkToPacker adds a chunk to the blob packer, finalizing the current blob if needed +func (s *Scanner) addChunkToPacker(chunk chunker.Chunk) error { + s.packerMu.Lock() + err := s.packer.AddChunk(&blob.ChunkRef{Hash: chunk.Hash, Data: chunk.Data}) + if err == blob.ErrBlobSizeLimitExceeded { + if err := s.packer.FinalizeBlob(); err != nil { + s.packerMu.Unlock() + return fmt.Errorf("finalizing blob: %w", err) + } + if err := s.packer.AddChunk(&blob.ChunkRef{Hash: chunk.Hash, Data: chunk.Data}); err != nil { + s.packerMu.Unlock() + return fmt.Errorf("adding chunk after finalize: %w", err) + } + } else if err != nil { + s.packerMu.Unlock() + return fmt.Errorf("adding chunk to packer: %w", err) + } + s.packerMu.Unlock() + return nil +} + +// queueFileForBatchInsert builds file/chunk associations and queues the file for batch DB insert +func (s *Scanner) queueFileForBatchInsert(ctx context.Context, fileToProcess *FileToProcess, chunks []streamingChunkInfo) { fileChunks := make([]database.FileChunk, len(chunks)) chunkFiles := make([]database.ChunkFile, len(chunks)) for i, ci := range chunks { @@ -1239,14 +1291,11 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT } } - // Queue file for batch insertion - // Files will be flushed when their chunks are committed (after blob finalize) s.addPendingFile(ctx, pendingFileData{ file: fileToProcess.File, fileChunks: fileChunks, chunkFiles: chunkFiles, }) - return nil } // GetProgress returns the progress reporter for this scanner diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index bb01ea1..883c572 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -227,12 +227,39 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st } }() + // Steps 1-5: Copy, clean, vacuum, compress, and read the database + finalData, tempDBPath, err := sm.prepareExportDB(ctx, dbPath, snapshotID, tempDir) + if err != nil { + return err + } + + // Step 6: Generate blob manifest (before closing temp DB) + blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) + if err != nil { + return fmt.Errorf("generating blob manifest: %w", err) + } + + // Step 7: Upload to S3 in snapshot subdirectory + if err := sm.uploadSnapshotArtifacts(ctx, snapshotID, finalData, blobManifest); err != nil { + return err + } + + log.Info("Uploaded snapshot metadata", + "snapshot_id", snapshotID, + "db_size", len(finalData), + "manifest_size", len(blobManifest)) + return nil +} + +// prepareExportDB copies, cleans, vacuums, and compresses the snapshot database for export. +// Returns the compressed data and the path to the temporary database (needed for manifest generation). +func (sm *SnapshotManager) prepareExportDB(ctx context.Context, dbPath, snapshotID, tempDir string) ([]byte, string, error) { // Step 1: Copy database to temp file // The main database should be closed at this point tempDBPath := filepath.Join(tempDir, "snapshot.db") log.Debug("Copying database to temporary location", "source", dbPath, "destination", tempDBPath) if err := sm.copyFile(dbPath, tempDBPath); err != nil { - return fmt.Errorf("copying database: %w", err) + return nil, "", fmt.Errorf("copying database: %w", err) } log.Debug("Database copy complete", "size", sm.getFileSize(tempDBPath)) @@ -240,7 +267,7 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st log.Debug("Cleaning temporary database", "snapshot_id", snapshotID) stats, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshotID) if err != nil { - return fmt.Errorf("cleaning snapshot database: %w", err) + return nil, "", fmt.Errorf("cleaning snapshot database: %w", err) } log.Info("Temporary database cleanup complete", "db_path", tempDBPath, @@ -255,14 +282,14 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st // Step 3: VACUUM the database to remove deleted data and compact // This is critical for security - ensures no stale/deleted data is uploaded if err := sm.vacuumDatabase(tempDBPath); err != nil { - return fmt.Errorf("vacuuming database: %w", err) + return nil, "", fmt.Errorf("vacuuming database: %w", err) } log.Debug("Database vacuumed", "size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath)))) // Step 4: Compress and encrypt the binary database file compressedPath := filepath.Join(tempDir, "db.zst.age") if err := sm.compressFile(tempDBPath, compressedPath); err != nil { - return fmt.Errorf("compressing database: %w", err) + return nil, "", fmt.Errorf("compressing database: %w", err) } log.Debug("Compression complete", "original_size", humanize.Bytes(uint64(sm.getFileSize(tempDBPath))), @@ -271,49 +298,43 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st // Step 5: Read compressed and encrypted data for upload finalData, err := afero.ReadFile(sm.fs, compressedPath) if err != nil { - return fmt.Errorf("reading compressed dump: %w", err) + return nil, "", fmt.Errorf("reading compressed dump: %w", err) } - // Step 6: Generate blob manifest (before closing temp DB) - blobManifest, err := sm.generateBlobManifest(ctx, tempDBPath, snapshotID) - if err != nil { - return fmt.Errorf("generating blob manifest: %w", err) - } + return finalData, tempDBPath, nil +} - // Step 7: Upload to S3 in snapshot subdirectory +// uploadSnapshotArtifacts uploads the database backup and blob manifest to S3 +func (sm *SnapshotManager) uploadSnapshotArtifacts(ctx context.Context, snapshotID string, dbData, manifestData []byte) error { // Upload database backup (compressed and encrypted) dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) dbUploadStart := time.Now() - if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(finalData)); err != nil { + if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(dbData)); err != nil { return fmt.Errorf("uploading snapshot database: %w", err) } dbUploadDuration := time.Since(dbUploadStart) - dbUploadSpeed := float64(len(finalData)) * 8 / dbUploadDuration.Seconds() // bits per second + dbUploadSpeed := float64(len(dbData)) * 8 / dbUploadDuration.Seconds() // bits per second log.Info("Uploaded snapshot database", "path", dbKey, - "size", humanize.Bytes(uint64(len(finalData))), + "size", humanize.Bytes(uint64(len(dbData))), "duration", dbUploadDuration, "speed", humanize.SI(dbUploadSpeed, "bps")) // Upload blob manifest (compressed only, not encrypted) manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) manifestUploadStart := time.Now() - if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil { + if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(manifestData)); err != nil { return fmt.Errorf("uploading blob manifest: %w", err) } manifestUploadDuration := time.Since(manifestUploadStart) - manifestUploadSpeed := float64(len(blobManifest)) * 8 / manifestUploadDuration.Seconds() // bits per second + manifestUploadSpeed := float64(len(manifestData)) * 8 / manifestUploadDuration.Seconds() // bits per second log.Info("Uploaded blob manifest", "path", manifestKey, - "size", humanize.Bytes(uint64(len(blobManifest))), + "size", humanize.Bytes(uint64(len(manifestData))), "duration", manifestUploadDuration, "speed", humanize.SI(manifestUploadSpeed, "bps")) - log.Info("Uploaded snapshot metadata", - "snapshot_id", snapshotID, - "db_size", len(finalData), - "manifest_size", len(blobManifest)) return nil } diff --git a/internal/vaultik/blob_fetch.go b/internal/vaultik/blob_fetch.go new file mode 100644 index 0000000..b440492 --- /dev/null +++ b/internal/vaultik/blob_fetch.go @@ -0,0 +1,93 @@ +package vaultik + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + + "filippo.io/age" + "git.eeqj.de/sneak/vaultik/internal/blobgen" +) + +// hashVerifyReader wraps a blobgen.Reader and verifies the double-SHA-256 hash +// of decrypted plaintext when Close is called. It reuses the hash that +// blobgen.Reader already computes internally via its TeeReader, avoiding +// redundant SHA-256 computation. +type hashVerifyReader struct { + reader *blobgen.Reader // underlying decrypted blob reader (has internal hasher) + fetcher io.ReadCloser // raw fetched stream (closed on Close) + blobHash string // expected double-SHA-256 hex + done bool // EOF reached +} + +func (h *hashVerifyReader) Read(p []byte) (int, error) { + n, err := h.reader.Read(p) + if err == io.EOF { + h.done = true + } + return n, err +} + +// Close verifies the hash (if the stream was fully read) and closes underlying readers. +func (h *hashVerifyReader) Close() error { + readerErr := h.reader.Close() + fetcherErr := h.fetcher.Close() + + if h.done { + firstHash := h.reader.Sum256() + secondHasher := sha256.New() + secondHasher.Write(firstHash) + actualHashHex := hex.EncodeToString(secondHasher.Sum(nil)) + if actualHashHex != h.blobHash { + return fmt.Errorf("blob hash mismatch: expected %s, got %s", h.blobHash[:16], actualHashHex[:16]) + } + } + + if readerErr != nil { + return readerErr + } + return fetcherErr +} + +// FetchAndDecryptBlob downloads a blob, decrypts and decompresses it, and +// returns a streaming reader that computes the double-SHA-256 hash on the fly. +// The hash is verified when the returned reader is closed (after fully reading). +// This avoids buffering the entire blob in memory. +func (v *Vaultik) FetchAndDecryptBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) (io.ReadCloser, error) { + rc, _, err := v.FetchBlob(ctx, blobHash, expectedSize) + if err != nil { + return nil, err + } + + reader, err := blobgen.NewReader(rc, identity) + if err != nil { + _ = rc.Close() + return nil, fmt.Errorf("creating blob reader: %w", err) + } + + return &hashVerifyReader{ + reader: reader, + fetcher: rc, + blobHash: blobHash, + }, nil +} + +// FetchBlob downloads a blob and returns a reader for the encrypted data. +func (v *Vaultik) FetchBlob(ctx context.Context, blobHash string, expectedSize int64) (io.ReadCloser, int64, error) { + blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash) + + rc, err := v.Storage.Get(ctx, blobPath) + if err != nil { + return nil, 0, fmt.Errorf("downloading blob %s: %w", blobHash[:16], err) + } + + info, err := v.Storage.Stat(ctx, blobPath) + if err != nil { + _ = rc.Close() + return nil, 0, fmt.Errorf("stat blob %s: %w", blobHash[:16], err) + } + + return rc, info.Size, nil +} diff --git a/internal/vaultik/blob_fetch_hash_test.go b/internal/vaultik/blob_fetch_hash_test.go new file mode 100644 index 0000000..192ec78 --- /dev/null +++ b/internal/vaultik/blob_fetch_hash_test.go @@ -0,0 +1,100 @@ +package vaultik_test + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "io" + "strings" + "testing" + + "filippo.io/age" + "git.eeqj.de/sneak/vaultik/internal/blobgen" + "git.eeqj.de/sneak/vaultik/internal/vaultik" +) + +// TestFetchAndDecryptBlobVerifiesHash verifies that FetchAndDecryptBlob checks +// the double-SHA-256 hash of the decrypted plaintext against the expected blob hash. +func TestFetchAndDecryptBlobVerifiesHash(t *testing.T) { + identity, err := age.GenerateX25519Identity() + if err != nil { + t.Fatalf("generating identity: %v", err) + } + + // Create test data and encrypt it using blobgen.Writer + plaintext := []byte("hello world test data for blob hash verification") + var encBuf bytes.Buffer + writer, err := blobgen.NewWriter(&encBuf, 1, []string{identity.Recipient().String()}) + if err != nil { + t.Fatalf("creating blobgen writer: %v", err) + } + if _, err := writer.Write(plaintext); err != nil { + t.Fatalf("writing plaintext: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("closing writer: %v", err) + } + encryptedData := encBuf.Bytes() + + // Compute correct double-SHA-256 hash of the plaintext (matches blobgen.Writer.Sum256) + firstHash := sha256.Sum256(plaintext) + secondHash := sha256.Sum256(firstHash[:]) + correctHash := hex.EncodeToString(secondHash[:]) + + // Verify our hash matches what blobgen.Writer produces + writerHash := hex.EncodeToString(writer.Sum256()) + if correctHash != writerHash { + t.Fatalf("hash computation mismatch: manual=%s, writer=%s", correctHash, writerHash) + } + + // Set up mock storage with the blob at the correct path + mockStorage := NewMockStorer() + blobPath := "blobs/" + correctHash[:2] + "/" + correctHash[2:4] + "/" + correctHash + mockStorage.mu.Lock() + mockStorage.data[blobPath] = encryptedData + mockStorage.mu.Unlock() + + tv := vaultik.NewForTesting(mockStorage) + ctx := context.Background() + + t.Run("correct hash succeeds", func(t *testing.T) { + rc, err := tv.FetchAndDecryptBlob(ctx, correctHash, int64(len(encryptedData)), identity) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + data, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("reading stream: %v", err) + } + if err := rc.Close(); err != nil { + t.Fatalf("close (hash verification) failed: %v", err) + } + if !bytes.Equal(data, plaintext) { + t.Fatalf("decrypted data mismatch: got %q, want %q", data, plaintext) + } + }) + + t.Run("wrong hash fails", func(t *testing.T) { + // Use a fake hash that doesn't match the actual plaintext + fakeHash := strings.Repeat("ab", 32) // 64 hex chars + fakePath := "blobs/" + fakeHash[:2] + "/" + fakeHash[2:4] + "/" + fakeHash + mockStorage.mu.Lock() + mockStorage.data[fakePath] = encryptedData + mockStorage.mu.Unlock() + + rc, err := tv.FetchAndDecryptBlob(ctx, fakeHash, int64(len(encryptedData)), identity) + if err != nil { + t.Fatalf("unexpected error opening stream: %v", err) + } + // Read all data — hash is verified on Close + _, _ = io.ReadAll(rc) + err = rc.Close() + if err == nil { + t.Fatal("expected error for mismatched hash, got nil") + } + if !strings.Contains(err.Error(), "hash mismatch") { + t.Fatalf("expected hash mismatch error, got: %v", err) + } + }) +} diff --git a/internal/vaultik/blobcache.go b/internal/vaultik/blobcache.go new file mode 100644 index 0000000..cdcee69 --- /dev/null +++ b/internal/vaultik/blobcache.go @@ -0,0 +1,207 @@ +package vaultik + +import ( + "fmt" + "os" + "path/filepath" + "sync" +) + +// blobDiskCacheEntry tracks a cached blob on disk. +type blobDiskCacheEntry struct { + key string + size int64 + prev *blobDiskCacheEntry + next *blobDiskCacheEntry +} + +// blobDiskCache is an LRU cache that stores blobs on disk instead of in memory. +// Blobs are written to a temp directory keyed by their hash. When total size +// exceeds maxBytes, the least-recently-used entries are evicted (deleted from disk). +type blobDiskCache struct { + mu sync.Mutex + dir string + maxBytes int64 + curBytes int64 + items map[string]*blobDiskCacheEntry + head *blobDiskCacheEntry // most recent + tail *blobDiskCacheEntry // least recent +} + +// newBlobDiskCache creates a new disk-based blob cache with the given max size. +func newBlobDiskCache(maxBytes int64) (*blobDiskCache, error) { + dir, err := os.MkdirTemp("", "vaultik-blobcache-*") + if err != nil { + return nil, fmt.Errorf("creating blob cache dir: %w", err) + } + return &blobDiskCache{ + dir: dir, + maxBytes: maxBytes, + items: make(map[string]*blobDiskCacheEntry), + }, nil +} + +func (c *blobDiskCache) path(key string) string { + return filepath.Join(c.dir, key) +} + +func (c *blobDiskCache) unlink(e *blobDiskCacheEntry) { + if e.prev != nil { + e.prev.next = e.next + } else { + c.head = e.next + } + if e.next != nil { + e.next.prev = e.prev + } else { + c.tail = e.prev + } + e.prev = nil + e.next = nil +} + +func (c *blobDiskCache) pushFront(e *blobDiskCacheEntry) { + e.prev = nil + e.next = c.head + if c.head != nil { + c.head.prev = e + } + c.head = e + if c.tail == nil { + c.tail = e + } +} + +func (c *blobDiskCache) evictLRU() { + if c.tail == nil { + return + } + victim := c.tail + c.unlink(victim) + delete(c.items, victim.key) + c.curBytes -= victim.size + _ = os.Remove(c.path(victim.key)) +} + +// Put writes blob data to disk cache. Entries larger than maxBytes are silently skipped. +func (c *blobDiskCache) Put(key string, data []byte) error { + entrySize := int64(len(data)) + + c.mu.Lock() + defer c.mu.Unlock() + + if entrySize > c.maxBytes { + return nil + } + + // Remove old entry if updating + if e, ok := c.items[key]; ok { + c.unlink(e) + c.curBytes -= e.size + _ = os.Remove(c.path(key)) + delete(c.items, key) + } + + if err := os.WriteFile(c.path(key), data, 0600); err != nil { + return fmt.Errorf("writing blob to cache: %w", err) + } + + e := &blobDiskCacheEntry{key: key, size: entrySize} + c.pushFront(e) + c.items[key] = e + c.curBytes += entrySize + + for c.curBytes > c.maxBytes && c.tail != nil { + c.evictLRU() + } + + return nil +} + +// Get reads a cached blob from disk. Returns data and true on hit. +func (c *blobDiskCache) Get(key string) ([]byte, bool) { + c.mu.Lock() + e, ok := c.items[key] + if !ok { + c.mu.Unlock() + return nil, false + } + c.unlink(e) + c.pushFront(e) + c.mu.Unlock() + + data, err := os.ReadFile(c.path(key)) + if err != nil { + c.mu.Lock() + if e2, ok2 := c.items[key]; ok2 && e2 == e { + c.unlink(e) + delete(c.items, key) + c.curBytes -= e.size + } + c.mu.Unlock() + return nil, false + } + return data, true +} + +// ReadAt reads a slice of a cached blob without loading the entire blob into memory. +func (c *blobDiskCache) ReadAt(key string, offset, length int64) ([]byte, error) { + c.mu.Lock() + e, ok := c.items[key] + if !ok { + c.mu.Unlock() + return nil, fmt.Errorf("key %q not in cache", key) + } + if offset+length > e.size { + c.mu.Unlock() + return nil, fmt.Errorf("read beyond blob size: offset=%d length=%d size=%d", offset, length, e.size) + } + c.unlink(e) + c.pushFront(e) + c.mu.Unlock() + + f, err := os.Open(c.path(key)) + if err != nil { + return nil, err + } + defer func() { _ = f.Close() }() + + buf := make([]byte, length) + if _, err := f.ReadAt(buf, offset); err != nil { + return nil, err + } + return buf, nil +} + +// Has returns whether a key exists in the cache. +func (c *blobDiskCache) Has(key string) bool { + c.mu.Lock() + defer c.mu.Unlock() + _, ok := c.items[key] + return ok +} + +// Size returns current total cached bytes. +func (c *blobDiskCache) Size() int64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.curBytes +} + +// Len returns number of cached entries. +func (c *blobDiskCache) Len() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.items) +} + +// Close removes the cache directory and all cached blobs. +func (c *blobDiskCache) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.items = nil + c.head = nil + c.tail = nil + c.curBytes = 0 + return os.RemoveAll(c.dir) +} diff --git a/internal/vaultik/blobcache_test.go b/internal/vaultik/blobcache_test.go new file mode 100644 index 0000000..778aadd --- /dev/null +++ b/internal/vaultik/blobcache_test.go @@ -0,0 +1,189 @@ +package vaultik + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" +) + +func TestBlobDiskCache_BasicGetPut(t *testing.T) { + cache, err := newBlobDiskCache(1 << 20) + if err != nil { + t.Fatal(err) + } + defer func() { _ = cache.Close() }() + + data := []byte("hello world") + if err := cache.Put("key1", data); err != nil { + t.Fatal(err) + } + + got, ok := cache.Get("key1") + if !ok { + t.Fatal("expected cache hit") + } + if !bytes.Equal(got, data) { + t.Fatalf("got %q, want %q", got, data) + } + + _, ok = cache.Get("nonexistent") + if ok { + t.Fatal("expected cache miss") + } +} + +func TestBlobDiskCache_EvictionUnderPressure(t *testing.T) { + maxBytes := int64(1000) + cache, err := newBlobDiskCache(maxBytes) + if err != nil { + t.Fatal(err) + } + defer func() { _ = cache.Close() }() + + for i := 0; i < 5; i++ { + data := make([]byte, 300) + if err := cache.Put(fmt.Sprintf("key%d", i), data); err != nil { + t.Fatal(err) + } + } + + if cache.Size() > maxBytes { + t.Fatalf("cache size %d exceeds max %d", cache.Size(), maxBytes) + } + + if !cache.Has("key4") { + t.Fatal("expected key4 to be cached") + } + if cache.Has("key0") { + t.Fatal("expected key0 to be evicted") + } +} + +func TestBlobDiskCache_OversizedEntryRejected(t *testing.T) { + cache, err := newBlobDiskCache(100) + if err != nil { + t.Fatal(err) + } + defer func() { _ = cache.Close() }() + + data := make([]byte, 200) + if err := cache.Put("big", data); err != nil { + t.Fatal(err) + } + + if cache.Has("big") { + t.Fatal("oversized entry should not be cached") + } +} + +func TestBlobDiskCache_UpdateInPlace(t *testing.T) { + cache, err := newBlobDiskCache(1 << 20) + if err != nil { + t.Fatal(err) + } + defer func() { _ = cache.Close() }() + + if err := cache.Put("key1", []byte("v1")); err != nil { + t.Fatal(err) + } + if err := cache.Put("key1", []byte("version2")); err != nil { + t.Fatal(err) + } + + got, ok := cache.Get("key1") + if !ok { + t.Fatal("expected hit") + } + if string(got) != "version2" { + t.Fatalf("got %q, want %q", got, "version2") + } + if cache.Len() != 1 { + t.Fatalf("expected 1 entry, got %d", cache.Len()) + } + if cache.Size() != int64(len("version2")) { + t.Fatalf("expected size %d, got %d", len("version2"), cache.Size()) + } +} + +func TestBlobDiskCache_ReadAt(t *testing.T) { + cache, err := newBlobDiskCache(1 << 20) + if err != nil { + t.Fatal(err) + } + defer func() { _ = cache.Close() }() + + data := make([]byte, 1024) + if _, err := rand.Read(data); err != nil { + t.Fatal(err) + } + if err := cache.Put("blob1", data); err != nil { + t.Fatal(err) + } + + chunk, err := cache.ReadAt("blob1", 100, 200) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(chunk, data[100:300]) { + t.Fatal("ReadAt returned wrong data") + } + + _, err = cache.ReadAt("blob1", 900, 200) + if err == nil { + t.Fatal("expected error for out-of-bounds read") + } + + _, err = cache.ReadAt("missing", 0, 10) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestBlobDiskCache_Close(t *testing.T) { + cache, err := newBlobDiskCache(1 << 20) + if err != nil { + t.Fatal(err) + } + + if err := cache.Put("key1", []byte("data")); err != nil { + t.Fatal(err) + } + if err := cache.Close(); err != nil { + t.Fatal(err) + } +} + +func TestBlobDiskCache_LRUOrder(t *testing.T) { + cache, err := newBlobDiskCache(200) + if err != nil { + t.Fatal(err) + } + defer func() { _ = cache.Close() }() + + d := make([]byte, 100) + if err := cache.Put("a", d); err != nil { + t.Fatal(err) + } + if err := cache.Put("b", d); err != nil { + t.Fatal(err) + } + + // Access "a" to make it most recently used + cache.Get("a") + + // Adding "c" should evict "b" (LRU), not "a" + if err := cache.Put("c", d); err != nil { + t.Fatal(err) + } + + if !cache.Has("a") { + t.Fatal("expected 'a' to survive") + } + if !cache.Has("c") { + t.Fatal("expected 'c' to be present") + } + if cache.Has("b") { + t.Fatal("expected 'b' to be evicted") + } +} diff --git a/internal/vaultik/helpers.go b/internal/vaultik/helpers.go index 9947caa..16c1ed4 100644 --- a/internal/vaultik/helpers.go +++ b/internal/vaultik/helpers.go @@ -79,31 +79,20 @@ func parseSnapshotTimestamp(snapshotID string) (time.Time, error) { return timestamp.UTC(), nil } -// snapshotNameFromID extracts the snapshot name from a snapshot ID. -// Snapshot IDs are formatted as `__` (or -// `_` if no name was given). The hostname argument -// is used to disambiguate cases where the hostname itself contains -// underscores. Returns "" if the ID has no name component. -func snapshotNameFromID(snapshotID, hostname string) string { - // Strip the trailing `_` suffix. - idx := strings.LastIndex(snapshotID, "_") - if idx <= 0 { +// parseSnapshotName extracts the snapshot name from a snapshot ID. +// Format: hostname_snapshotname_timestamp — the middle part(s) between hostname +// and the RFC3339 timestamp are the snapshot name (may contain underscores). +// Returns the snapshot name, or empty string if the ID is malformed. +func parseSnapshotName(snapshotID string) string { + parts := strings.Split(snapshotID, "_") + if len(parts) < 3 { + // Format: hostname_timestamp — no snapshot name return "" } - prefix := snapshotID[:idx] - - // Strip the leading hostname prefix. - if !strings.HasPrefix(prefix, hostname) { - return "" - } - rest := prefix[len(hostname):] - if rest == "" { - return "" // No name component - } - if rest[0] == '_' { - return rest[1:] - } - return "" + // Format: hostname_name_timestamp — middle parts are the name. + // The last part is the RFC3339 timestamp, the first part is the hostname, + // everything in between is the snapshot name (which may itself contain underscores). + return strings.Join(parts[1:len(parts)-1], "_") } // parseDuration parses a duration string with support for days diff --git a/internal/vaultik/helpers_test.go b/internal/vaultik/helpers_test.go new file mode 100644 index 0000000..ef7bf5b --- /dev/null +++ b/internal/vaultik/helpers_test.go @@ -0,0 +1,76 @@ +package vaultik + +import ( + "testing" +) + +func TestParseSnapshotName(t *testing.T) { + tests := []struct { + name string + snapshotID string + want string + }{ + { + name: "standard format with name", + snapshotID: "myhost_home_2026-01-12T14:41:15Z", + want: "home", + }, + { + name: "standard format with different name", + snapshotID: "server1_system_2026-02-15T09:30:00Z", + want: "system", + }, + { + name: "name with underscores", + snapshotID: "myhost_my_special_backup_2026-03-01T00:00:00Z", + want: "my_special_backup", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseSnapshotName(tt.snapshotID) + if got != tt.want { + t.Errorf("parseSnapshotName(%q) = %q, want %q", tt.snapshotID, got, tt.want) + } + }) + } +} + +func TestParseSnapshotTimestamp(t *testing.T) { + tests := []struct { + name string + snapshotID string + wantErr bool + }{ + { + name: "valid with name", + snapshotID: "myhost_home_2026-01-12T14:41:15Z", + wantErr: false, + }, + { + name: "valid without name", + snapshotID: "myhost_2026-01-12T14:41:15Z", + wantErr: false, + }, + { + name: "invalid - single part", + snapshotID: "nounderscore", + wantErr: true, + }, + { + name: "invalid - bad timestamp", + snapshotID: "myhost_home_notadate", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := parseSnapshotTimestamp(tt.snapshotID) + if (err != nil) != tt.wantErr { + t.Errorf("parseSnapshotTimestamp(%q) error = %v, wantErr %v", tt.snapshotID, err, tt.wantErr) + } + }) + } +} diff --git a/internal/vaultik/info.go b/internal/vaultik/info.go index 7640ad1..2d1517b 100644 --- a/internal/vaultik/info.go +++ b/internal/vaultik/info.go @@ -15,87 +15,87 @@ import ( // ShowInfo displays system and configuration information func (v *Vaultik) ShowInfo() error { // System Information - fmt.Printf("=== System Information ===\n") - fmt.Printf("OS/Architecture: %s/%s\n", runtime.GOOS, runtime.GOARCH) - fmt.Printf("Version: %s\n", v.Globals.Version) - fmt.Printf("Commit: %s\n", v.Globals.Commit) - fmt.Printf("Go Version: %s\n", runtime.Version()) - fmt.Println() + v.printfStdout("=== System Information ===\n") + v.printfStdout("OS/Architecture: %s/%s\n", runtime.GOOS, runtime.GOARCH) + v.printfStdout("Version: %s\n", v.Globals.Version) + v.printfStdout("Commit: %s\n", v.Globals.Commit) + v.printfStdout("Go Version: %s\n", runtime.Version()) + v.printlnStdout() // Storage Configuration - fmt.Printf("=== Storage Configuration ===\n") - fmt.Printf("S3 Bucket: %s\n", v.Config.S3.Bucket) + v.printfStdout("=== Storage Configuration ===\n") + v.printfStdout("S3 Bucket: %s\n", v.Config.S3.Bucket) if v.Config.S3.Prefix != "" { - fmt.Printf("S3 Prefix: %s\n", v.Config.S3.Prefix) + v.printfStdout("S3 Prefix: %s\n", v.Config.S3.Prefix) } - fmt.Printf("S3 Endpoint: %s\n", v.Config.S3.Endpoint) - fmt.Printf("S3 Region: %s\n", v.Config.S3.Region) - fmt.Println() + v.printfStdout("S3 Endpoint: %s\n", v.Config.S3.Endpoint) + v.printfStdout("S3 Region: %s\n", v.Config.S3.Region) + v.printlnStdout() // Backup Settings - fmt.Printf("=== Backup Settings ===\n") + v.printfStdout("=== Backup Settings ===\n") // Show configured snapshots - fmt.Printf("Snapshots:\n") + v.printfStdout("Snapshots:\n") for _, name := range v.Config.SnapshotNames() { snap := v.Config.Snapshots[name] - fmt.Printf(" %s:\n", name) + v.printfStdout(" %s:\n", name) for _, path := range snap.Paths { - fmt.Printf(" - %s\n", path) + v.printfStdout(" - %s\n", path) } if len(snap.Exclude) > 0 { - fmt.Printf(" exclude: %s\n", strings.Join(snap.Exclude, ", ")) + v.printfStdout(" exclude: %s\n", strings.Join(snap.Exclude, ", ")) } } // Global exclude patterns if len(v.Config.Exclude) > 0 { - fmt.Printf("Global Exclude: %s\n", strings.Join(v.Config.Exclude, ", ")) + v.printfStdout("Global Exclude: %s\n", strings.Join(v.Config.Exclude, ", ")) } - fmt.Printf("Compression: zstd level %d\n", v.Config.CompressionLevel) - fmt.Printf("Chunk Size: %s\n", humanize.Bytes(uint64(v.Config.ChunkSize))) - fmt.Printf("Blob Size Limit: %s\n", humanize.Bytes(uint64(v.Config.BlobSizeLimit))) - fmt.Println() + v.printfStdout("Compression: zstd level %d\n", v.Config.CompressionLevel) + v.printfStdout("Chunk Size: %s\n", humanize.Bytes(uint64(v.Config.ChunkSize))) + v.printfStdout("Blob Size Limit: %s\n", humanize.Bytes(uint64(v.Config.BlobSizeLimit))) + v.printlnStdout() // Encryption Configuration - fmt.Printf("=== Encryption Configuration ===\n") - fmt.Printf("Recipients:\n") + v.printfStdout("=== Encryption Configuration ===\n") + v.printfStdout("Recipients:\n") for _, recipient := range v.Config.AgeRecipients { - fmt.Printf(" - %s\n", recipient) + v.printfStdout(" - %s\n", recipient) } - fmt.Println() + v.printlnStdout() // Local Database - fmt.Printf("=== Local Database ===\n") - fmt.Printf("Index Path: %s\n", v.Config.IndexPath) + v.printfStdout("=== Local Database ===\n") + v.printfStdout("Index Path: %s\n", v.Config.IndexPath) // Check if index file exists and get its size if info, err := v.Fs.Stat(v.Config.IndexPath); err == nil { - fmt.Printf("Index Size: %s\n", humanize.Bytes(uint64(info.Size()))) + v.printfStdout("Index Size: %s\n", humanize.Bytes(uint64(info.Size()))) // Get snapshot count from database query := `SELECT COUNT(*) FROM snapshots WHERE completed_at IS NOT NULL` var snapshotCount int if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&snapshotCount); err == nil { - fmt.Printf("Snapshots: %d\n", snapshotCount) + v.printfStdout("Snapshots: %d\n", snapshotCount) } // Get blob count from database query = `SELECT COUNT(*) FROM blobs` var blobCount int if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&blobCount); err == nil { - fmt.Printf("Blobs: %d\n", blobCount) + v.printfStdout("Blobs: %d\n", blobCount) } // Get file count from database query = `SELECT COUNT(*) FROM files` var fileCount int if err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&fileCount); err == nil { - fmt.Printf("Files: %d\n", fileCount) + v.printfStdout("Files: %d\n", fileCount) } } else { - fmt.Printf("Index Size: (not created)\n") + v.printfStdout("Index Size: (not created)\n") } return nil @@ -137,35 +137,64 @@ type RemoteInfoResult struct { // RemoteInfo displays information about remote storage func (v *Vaultik) RemoteInfo(jsonOutput bool) error { + log.Info("Starting remote storage info gathering") result := &RemoteInfoResult{} - // Get storage info storageInfo := v.Storage.Info() result.StorageType = storageInfo.Type result.StorageLocation = storageInfo.Location if !jsonOutput { - fmt.Printf("=== Remote Storage ===\n") - fmt.Printf("Type: %s\n", storageInfo.Type) - fmt.Printf("Location: %s\n", storageInfo.Location) - fmt.Println() + v.printfStdout("=== Remote Storage ===\n") + v.printfStdout("Type: %s\n", storageInfo.Type) + v.printfStdout("Location: %s\n", storageInfo.Location) + v.printlnStdout() + v.printfStdout("Scanning snapshot metadata...\n") + } + + snapshotMetadata, snapshotIDs, err := v.collectSnapshotMetadata() + if err != nil { + return err } - // List all snapshot metadata if !jsonOutput { - fmt.Printf("Scanning snapshot metadata...\n") + v.printfStdout("Downloading %d manifest(s)...\n", len(snapshotIDs)) } + referencedBlobs := v.collectReferencedBlobsFromManifests(snapshotIDs, snapshotMetadata) + + v.populateRemoteInfoResult(result, snapshotMetadata, snapshotIDs, referencedBlobs) + + if err := v.scanRemoteBlobStorage(result, referencedBlobs, jsonOutput); err != nil { + return err + } + + log.Info("Remote info complete", + "snapshots", result.TotalMetadataCount, + "total_blobs", result.TotalBlobCount, + "referenced_blobs", result.ReferencedBlobCount, + "orphaned_blobs", result.OrphanedBlobCount) + + if jsonOutput { + enc := json.NewEncoder(v.Stdout) + enc.SetIndent("", " ") + return enc.Encode(result) + } + + v.printRemoteInfoTable(result) + return nil +} + +// collectSnapshotMetadata scans remote metadata and returns per-snapshot info and sorted IDs +func (v *Vaultik) collectSnapshotMetadata() (map[string]*SnapshotMetadataInfo, []string, error) { snapshotMetadata := make(map[string]*SnapshotMetadataInfo) - // Collect metadata files metadataCh := v.Storage.ListStream(v.ctx, "metadata/") for obj := range metadataCh { if obj.Err != nil { - return fmt.Errorf("listing metadata: %w", obj.Err) + return nil, nil, fmt.Errorf("listing metadata: %w", obj.Err) } - // Parse key: metadata// parts := strings.Split(obj.Key, "/") if len(parts) < 3 { continue @@ -173,14 +202,11 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { snapshotID := parts[1] if _, exists := snapshotMetadata[snapshotID]; !exists { - snapshotMetadata[snapshotID] = &SnapshotMetadataInfo{ - SnapshotID: snapshotID, - } + snapshotMetadata[snapshotID] = &SnapshotMetadataInfo{SnapshotID: snapshotID} } info := snapshotMetadata[snapshotID] filename := parts[2] - if strings.HasPrefix(filename, "manifest") { info.ManifestSize = obj.Size } else if strings.HasPrefix(filename, "db") { @@ -189,19 +215,18 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { info.TotalSize = info.ManifestSize + info.DatabaseSize } - // Sort snapshots by ID for consistent output var snapshotIDs []string for id := range snapshotMetadata { snapshotIDs = append(snapshotIDs, id) } sort.Strings(snapshotIDs) - // Download and parse all manifests to get referenced blobs - if !jsonOutput { - fmt.Printf("Downloading %d manifest(s)...\n", len(snapshotIDs)) - } + return snapshotMetadata, snapshotIDs, nil +} - referencedBlobs := make(map[string]int64) // hash -> compressed size +// collectReferencedBlobsFromManifests downloads manifests and returns referenced blob hashes with sizes +func (v *Vaultik) collectReferencedBlobsFromManifests(snapshotIDs []string, snapshotMetadata map[string]*SnapshotMetadataInfo) map[string]int64 { + referencedBlobs := make(map[string]int64) for _, snapshotID := range snapshotIDs { manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) @@ -218,10 +243,8 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { continue } - // Record blob info from manifest info := snapshotMetadata[snapshotID] info.BlobCount = manifest.BlobCount - var blobsSize int64 for _, blob := range manifest.Blobs { referencedBlobs[blob.Hash] = blob.CompressedSize @@ -230,7 +253,11 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { info.BlobsSize = blobsSize } - // Build result snapshots + return referencedBlobs +} + +// populateRemoteInfoResult fills in the result's snapshot and referenced blob stats +func (v *Vaultik) populateRemoteInfoResult(result *RemoteInfoResult, snapshotMetadata map[string]*SnapshotMetadataInfo, snapshotIDs []string, referencedBlobs map[string]int64) { var totalMetadataSize int64 for _, id := range snapshotIDs { info := snapshotMetadata[id] @@ -240,26 +267,25 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { result.TotalMetadataSize = totalMetadataSize result.TotalMetadataCount = len(snapshotIDs) - // Calculate referenced blob stats for _, size := range referencedBlobs { result.ReferencedBlobCount++ result.ReferencedBlobSize += size } +} - // List all blobs on remote +// scanRemoteBlobStorage lists all blobs on remote and computes orphan stats +func (v *Vaultik) scanRemoteBlobStorage(result *RemoteInfoResult, referencedBlobs map[string]int64, jsonOutput bool) error { if !jsonOutput { - fmt.Printf("Scanning blobs...\n") + v.printfStdout("Scanning blobs...\n") } - allBlobs := make(map[string]int64) // hash -> size from storage - blobCh := v.Storage.ListStream(v.ctx, "blobs/") + allBlobs := make(map[string]int64) + for obj := range blobCh { if obj.Err != nil { return fmt.Errorf("listing blobs: %w", obj.Err) } - - // Extract hash from key: blobs/xx/yy/hash parts := strings.Split(obj.Key, "/") if len(parts) < 4 { continue @@ -270,7 +296,6 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { result.TotalBlobSize += obj.Size } - // Calculate orphaned blobs for hash, size := range allBlobs { if _, referenced := referencedBlobs[hash]; !referenced { result.OrphanedBlobCount++ @@ -278,22 +303,19 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { } } - // Output results - if jsonOutput { - enc := json.NewEncoder(v.Stdout) - enc.SetIndent("", " ") - return enc.Encode(result) - } + return nil +} - // Human-readable output - fmt.Printf("\n=== Snapshot Metadata ===\n") +// printRemoteInfoTable renders the human-readable remote info output +func (v *Vaultik) printRemoteInfoTable(result *RemoteInfoResult) { + v.printfStdout("\n=== Snapshot Metadata ===\n") if len(result.Snapshots) == 0 { - fmt.Printf("No snapshots found\n") + v.printfStdout("No snapshots found\n") } else { - fmt.Printf("%-45s %12s %12s %12s %10s %12s\n", "SNAPSHOT", "MANIFEST", "DATABASE", "TOTAL", "BLOBS", "BLOB SIZE") - fmt.Printf("%-45s %12s %12s %12s %10s %12s\n", strings.Repeat("-", 45), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 10), strings.Repeat("-", 12)) + v.printfStdout("%-45s %12s %12s %12s %10s %12s\n", "SNAPSHOT", "MANIFEST", "DATABASE", "TOTAL", "BLOBS", "BLOB SIZE") + v.printfStdout("%-45s %12s %12s %12s %10s %12s\n", strings.Repeat("-", 45), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 10), strings.Repeat("-", 12)) for _, info := range result.Snapshots { - fmt.Printf("%-45s %12s %12s %12s %10s %12s\n", + v.printfStdout("%-45s %12s %12s %12s %10s %12s\n", truncateString(info.SnapshotID, 45), humanize.Bytes(uint64(info.ManifestSize)), humanize.Bytes(uint64(info.DatabaseSize)), @@ -302,26 +324,21 @@ func (v *Vaultik) RemoteInfo(jsonOutput bool) error { humanize.Bytes(uint64(info.BlobsSize)), ) } - fmt.Printf("%-45s %12s %12s %12s %10s %12s\n", strings.Repeat("-", 45), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 10), strings.Repeat("-", 12)) - fmt.Printf("%-45s %12s %12s %12s\n", fmt.Sprintf("Total (%d snapshots)", result.TotalMetadataCount), "", "", humanize.Bytes(uint64(result.TotalMetadataSize))) + v.printfStdout("%-45s %12s %12s %12s %10s %12s\n", strings.Repeat("-", 45), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 12), strings.Repeat("-", 10), strings.Repeat("-", 12)) + v.printfStdout("%-45s %12s %12s %12s\n", fmt.Sprintf("Total (%d snapshots)", result.TotalMetadataCount), "", "", humanize.Bytes(uint64(result.TotalMetadataSize))) } - fmt.Printf("\n=== Blob Storage ===\n") - fmt.Printf("Total blobs on remote: %s (%s)\n", - humanize.Comma(int64(result.TotalBlobCount)), - humanize.Bytes(uint64(result.TotalBlobSize))) - fmt.Printf("Referenced by snapshots: %s (%s)\n", - humanize.Comma(int64(result.ReferencedBlobCount)), - humanize.Bytes(uint64(result.ReferencedBlobSize))) - fmt.Printf("Orphaned (unreferenced): %s (%s)\n", - humanize.Comma(int64(result.OrphanedBlobCount)), - humanize.Bytes(uint64(result.OrphanedBlobSize))) + v.printfStdout("\n=== Blob Storage ===\n") + v.printfStdout("Total blobs on remote: %s (%s)\n", + humanize.Comma(int64(result.TotalBlobCount)), humanize.Bytes(uint64(result.TotalBlobSize))) + v.printfStdout("Referenced by snapshots: %s (%s)\n", + humanize.Comma(int64(result.ReferencedBlobCount)), humanize.Bytes(uint64(result.ReferencedBlobSize))) + v.printfStdout("Orphaned (unreferenced): %s (%s)\n", + humanize.Comma(int64(result.OrphanedBlobCount)), humanize.Bytes(uint64(result.OrphanedBlobSize))) if result.OrphanedBlobCount > 0 { - fmt.Printf("\nRun 'vaultik prune --remote' to remove orphaned blobs.\n") + v.printfStdout("\nRun 'vaultik prune --remote' to remove orphaned blobs.\n") } - - return nil } // truncateString truncates a string to maxLen, adding "..." if truncated diff --git a/internal/vaultik/prune.go b/internal/vaultik/prune.go index 946461e..2fb1a35 100644 --- a/internal/vaultik/prune.go +++ b/internal/vaultik/prune.go @@ -3,7 +3,6 @@ package vaultik import ( "encoding/json" "fmt" - "os" "strings" "git.eeqj.de/sneak/vaultik/internal/log" @@ -28,54 +27,80 @@ type PruneBlobsResult struct { func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { log.Info("Starting prune operation") - // Get all remote snapshots and their manifests - allBlobsReferenced := make(map[string]bool) - manifestCount := 0 + allBlobsReferenced, err := v.collectReferencedBlobs() + if err != nil { + return err + } - // List all snapshots in storage - log.Info("Listing remote snapshots") - objectCh := v.Storage.ListStream(v.ctx, "metadata/") + allBlobs, err := v.listAllRemoteBlobs() + if err != nil { + return err + } - var snapshotIDs []string - for object := range objectCh { - if object.Err != nil { - return fmt.Errorf("listing remote snapshots: %w", object.Err) + unreferencedBlobs, totalSize := v.findUnreferencedBlobs(allBlobs, allBlobsReferenced) + + result := &PruneBlobsResult{BlobsFound: len(unreferencedBlobs)} + + if len(unreferencedBlobs) == 0 { + log.Info("No unreferenced blobs found") + if opts.JSON { + return v.outputPruneBlobsJSON(result) } + v.printlnStdout("No unreferenced blobs to remove.") + return nil + } - // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/ - parts := strings.Split(object.Key, "/") - if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { - // Check if this is a directory by looking for trailing slash - if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") { - snapshotID := parts[1] - // Only add unique snapshot IDs - found := false - for _, id := range snapshotIDs { - if id == snapshotID { - found = true - break - } - } - if !found { - snapshotIDs = append(snapshotIDs, snapshotID) - } - } + log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize))) + if !opts.JSON { + v.printfStdout("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize))) + } + + if !opts.Force && !opts.JSON { + v.printfStdout("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs)) + var confirm string + if _, err := v.scanStdin(&confirm); err != nil { + v.printlnStdout("Cancelled") + return nil + } + if strings.ToLower(confirm) != "y" { + v.printlnStdout("Cancelled") + return nil } } + v.deleteUnreferencedBlobs(unreferencedBlobs, allBlobs, result) + + if opts.JSON { + return v.outputPruneBlobsJSON(result) + } + + v.printfStdout("\nDeleted %d blob(s) totaling %s\n", result.BlobsDeleted, humanize.Bytes(uint64(result.BytesFreed))) + if result.BlobsFailed > 0 { + v.printfStdout("Failed to delete %d blob(s)\n", result.BlobsFailed) + } + + return nil +} + +// collectReferencedBlobs downloads all manifests and returns the set of referenced blob hashes +func (v *Vaultik) collectReferencedBlobs() (map[string]bool, error) { + log.Info("Listing remote snapshots") + snapshotIDs, err := v.listUniqueSnapshotIDs() + if err != nil { + return nil, fmt.Errorf("listing snapshot IDs: %w", err) + } log.Info("Found manifests in remote storage", "count", len(snapshotIDs)) - // Download and parse each manifest to get referenced blobs + allBlobsReferenced := make(map[string]bool) + manifestCount := 0 + for _, snapshotID := range snapshotIDs { log.Debug("Processing manifest", "snapshot_id", snapshotID) - manifest, err := v.downloadManifest(snapshotID) if err != nil { log.Error("Failed to download manifest", "snapshot_id", snapshotID, "error", err) continue } - - // Add all blobs from this manifest to our referenced set for _, blob := range manifest.Blobs { allBlobsReferenced[blob.Hash] = true } @@ -83,75 +108,69 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { } log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced)) + return allBlobsReferenced, nil +} - // List all blobs in storage +// listUniqueSnapshotIDs returns deduplicated snapshot IDs from remote metadata +func (v *Vaultik) listUniqueSnapshotIDs() ([]string, error) { + objectCh := v.Storage.ListStream(v.ctx, "metadata/") + seen := make(map[string]bool) + var snapshotIDs []string + + for object := range objectCh { + if object.Err != nil { + return nil, fmt.Errorf("listing metadata objects: %w", object.Err) + } + parts := strings.Split(object.Key, "/") + if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { + if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") { + snapshotID := parts[1] + if !seen[snapshotID] { + seen[snapshotID] = true + snapshotIDs = append(snapshotIDs, snapshotID) + } + } + } + } + return snapshotIDs, nil +} + +// listAllRemoteBlobs returns a map of all blob hashes to their sizes in remote storage +func (v *Vaultik) listAllRemoteBlobs() (map[string]int64, error) { log.Info("Listing all blobs in storage") - allBlobs := make(map[string]int64) // hash -> size + allBlobs := make(map[string]int64) blobObjectCh := v.Storage.ListStream(v.ctx, "blobs/") for object := range blobObjectCh { if object.Err != nil { - return fmt.Errorf("listing blobs: %w", object.Err) + return nil, fmt.Errorf("listing blobs: %w", object.Err) } - - // Extract hash from path like blobs/ab/cd/abcdef123456... parts := strings.Split(object.Key, "/") if len(parts) == 4 && parts[0] == "blobs" { - hash := parts[3] - allBlobs[hash] = object.Size + allBlobs[parts[3]] = object.Size } } log.Info("Found blobs in storage", "count", len(allBlobs)) + return allBlobs, nil +} - // Find unreferenced blobs - var unreferencedBlobs []string +// findUnreferencedBlobs returns blob hashes not referenced by any manifest and their total size +func (v *Vaultik) findUnreferencedBlobs(allBlobs map[string]int64, referenced map[string]bool) ([]string, int64) { + var unreferenced []string var totalSize int64 for hash, size := range allBlobs { - if !allBlobsReferenced[hash] { - unreferencedBlobs = append(unreferencedBlobs, hash) + if !referenced[hash] { + unreferenced = append(unreferenced, hash) totalSize += size } } + return unreferenced, totalSize +} - result := &PruneBlobsResult{ - BlobsFound: len(unreferencedBlobs), - } - - if len(unreferencedBlobs) == 0 { - log.Info("No unreferenced blobs found") - if opts.JSON { - return outputPruneBlobsJSON(result) - } - fmt.Println("No unreferenced blobs to remove.") - return nil - } - - // Show what will be deleted - log.Info("Found unreferenced blobs", "count", len(unreferencedBlobs), "total_size", humanize.Bytes(uint64(totalSize))) - if !opts.JSON { - fmt.Printf("Found %d unreferenced blob(s) totaling %s\n", len(unreferencedBlobs), humanize.Bytes(uint64(totalSize))) - } - - // Confirm unless --force is used (skip in JSON mode - require --force) - if !opts.Force && !opts.JSON { - fmt.Printf("\nDelete %d unreferenced blob(s)? [y/N] ", len(unreferencedBlobs)) - var confirm string - if _, err := fmt.Scanln(&confirm); err != nil { - // Treat EOF or error as "no" - fmt.Println("Cancelled") - return nil - } - if strings.ToLower(confirm) != "y" { - fmt.Println("Cancelled") - return nil - } - } - - // Delete unreferenced blobs +// deleteUnreferencedBlobs deletes the given blobs from storage and populates the result +func (v *Vaultik) deleteUnreferencedBlobs(unreferencedBlobs []string, allBlobs map[string]int64, result *PruneBlobsResult) { log.Info("Deleting unreferenced blobs") - deletedCount := 0 - deletedSize := int64(0) for i, hash := range unreferencedBlobs { blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash) @@ -161,10 +180,9 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { continue } - deletedCount++ - deletedSize += allBlobs[hash] + result.BlobsDeleted++ + result.BytesFreed += allBlobs[hash] - // Progress update every 100 blobs if (i+1)%100 == 0 || i == len(unreferencedBlobs)-1 { log.Info("Deletion progress", "deleted", i+1, @@ -174,31 +192,18 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error { } } - result.BlobsDeleted = deletedCount - result.BlobsFailed = len(unreferencedBlobs) - deletedCount - result.BytesFreed = deletedSize + result.BlobsFailed = len(unreferencedBlobs) - result.BlobsDeleted log.Info("Prune complete", - "deleted_count", deletedCount, - "deleted_size", humanize.Bytes(uint64(deletedSize)), - "failed", len(unreferencedBlobs)-deletedCount, + "deleted_count", result.BlobsDeleted, + "deleted_size", humanize.Bytes(uint64(result.BytesFreed)), + "failed", result.BlobsFailed, ) - - if opts.JSON { - return outputPruneBlobsJSON(result) - } - - fmt.Printf("\nDeleted %d blob(s) totaling %s\n", deletedCount, humanize.Bytes(uint64(deletedSize))) - if deletedCount < len(unreferencedBlobs) { - fmt.Printf("Failed to delete %d blob(s)\n", len(unreferencedBlobs)-deletedCount) - } - - return nil } // outputPruneBlobsJSON outputs the prune result as JSON -func outputPruneBlobsJSON(result *PruneBlobsResult) error { - encoder := json.NewEncoder(os.Stdout) +func (v *Vaultik) outputPruneBlobsJSON(result *PruneBlobsResult) error { + encoder := json.NewEncoder(v.Stdout) encoder.SetIndent("", " ") return encoder.Encode(result) } diff --git a/internal/vaultik/purge_per_name_test.go b/internal/vaultik/purge_per_name_test.go new file mode 100644 index 0000000..e76c0bc --- /dev/null +++ b/internal/vaultik/purge_per_name_test.go @@ -0,0 +1,256 @@ +package vaultik_test + +import ( + "bytes" + "context" + "database/sql" + "strings" + "testing" + "time" + + "git.eeqj.de/sneak/vaultik/internal/database" + "git.eeqj.de/sneak/vaultik/internal/log" + "git.eeqj.de/sneak/vaultik/internal/types" + "git.eeqj.de/sneak/vaultik/internal/vaultik" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupPurgeTest creates a Vaultik instance with an in-memory database and mock +// storage pre-populated with the given snapshot IDs. Each snapshot is marked as +// completed. Remote metadata stubs are created so syncWithRemote keeps them. +func setupPurgeTest(t *testing.T, snapshotIDs []string) *vaultik.Vaultik { + t.Helper() + log.Initialize(log.Config{}) + + ctx := context.Background() + db, err := database.New(ctx, ":memory:") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + repos := database.NewRepositories(db) + mockStorage := NewMockStorer() + + // Insert each snapshot into the DB and create remote metadata stubs. + // Use timestamps parsed from snapshot IDs for realistic ordering. + for _, id := range snapshotIDs { + // Parse timestamp from the snapshot ID + parts := strings.Split(id, "_") + timestampStr := parts[len(parts)-1] + startedAt, err := time.Parse(time.RFC3339, timestampStr) + require.NoError(t, err, "parsing timestamp from snapshot ID %q", id) + + completedAt := startedAt.Add(5 * time.Minute) + snap := &database.Snapshot{ + ID: types.SnapshotID(id), + Hostname: "testhost", + VaultikVersion: "test", + StartedAt: startedAt, + CompletedAt: &completedAt, + } + err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error { + return repos.Snapshots.Create(ctx, tx, snap) + }) + require.NoError(t, err, "creating snapshot %s", id) + + // Create remote metadata stub so syncWithRemote keeps it + metadataKey := "metadata/" + id + "/manifest.json.zst" + err = mockStorage.Put(ctx, metadataKey, strings.NewReader("stub")) + require.NoError(t, err) + } + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + stdin := &bytes.Buffer{} + + v := &vaultik.Vaultik{ + Storage: mockStorage, + Repositories: repos, + DB: db, + Stdout: stdout, + Stderr: stderr, + Stdin: stdin, + } + v.SetContext(ctx) + + return v +} + +// listRemainingSnapshots returns IDs of all completed snapshots in the database. +func listRemainingSnapshots(t *testing.T, v *vaultik.Vaultik) []string { + t.Helper() + ctx := context.Background() + dbSnaps, err := v.Repositories.Snapshots.ListRecent(ctx, 10000) + require.NoError(t, err) + + var ids []string + for _, s := range dbSnaps { + if s.CompletedAt != nil { + ids = append(ids, s.ID.String()) + } + } + return ids +} + +func TestPurgeKeepLatest_PerName(t *testing.T) { + // Create snapshots for two different names: "home" and "system". + // With per-name --keep-latest, the latest of each should be kept. + snapshotIDs := []string{ + "testhost_system_2026-01-01T00:00:00Z", + "testhost_home_2026-01-01T01:00:00Z", + "testhost_system_2026-01-01T02:00:00Z", + "testhost_home_2026-01-01T03:00:00Z", + "testhost_system_2026-01-01T04:00:00Z", + } + + v := setupPurgeTest(t, snapshotIDs) + + err := v.PurgeSnapshotsWithOptions(&vaultik.SnapshotPurgeOptions{ + KeepLatest: true, + Force: true, + }) + require.NoError(t, err) + + remaining := listRemainingSnapshots(t, v) + + // Should keep the latest of each name + assert.Len(t, remaining, 2, "should keep exactly 2 snapshots (one per name)") + assert.Contains(t, remaining, "testhost_system_2026-01-01T04:00:00Z", "should keep latest system") + assert.Contains(t, remaining, "testhost_home_2026-01-01T03:00:00Z", "should keep latest home") +} + +func TestPurgeKeepLatest_SingleName(t *testing.T) { + // All snapshots have the same name — keep-latest should keep exactly one. + snapshotIDs := []string{ + "testhost_home_2026-01-01T00:00:00Z", + "testhost_home_2026-01-01T01:00:00Z", + "testhost_home_2026-01-01T02:00:00Z", + } + + v := setupPurgeTest(t, snapshotIDs) + + err := v.PurgeSnapshotsWithOptions(&vaultik.SnapshotPurgeOptions{ + KeepLatest: true, + Force: true, + }) + require.NoError(t, err) + + remaining := listRemainingSnapshots(t, v) + assert.Len(t, remaining, 1) + assert.Contains(t, remaining, "testhost_home_2026-01-01T02:00:00Z", "should keep the newest") +} + +func TestPurgeKeepLatest_WithNameFilter(t *testing.T) { + // Use --name to filter purge to only "home" snapshots. + // "system" snapshots should be untouched. + snapshotIDs := []string{ + "testhost_system_2026-01-01T00:00:00Z", + "testhost_home_2026-01-01T01:00:00Z", + "testhost_system_2026-01-01T02:00:00Z", + "testhost_home_2026-01-01T03:00:00Z", + "testhost_home_2026-01-01T04:00:00Z", + } + + v := setupPurgeTest(t, snapshotIDs) + + err := v.PurgeSnapshotsWithOptions(&vaultik.SnapshotPurgeOptions{ + KeepLatest: true, + Force: true, + Names: []string{"home"}, + }) + require.NoError(t, err) + + remaining := listRemainingSnapshots(t, v) + + // 2 system snapshots untouched + 1 latest home = 3 + assert.Len(t, remaining, 3) + assert.Contains(t, remaining, "testhost_system_2026-01-01T00:00:00Z") + assert.Contains(t, remaining, "testhost_system_2026-01-01T02:00:00Z") + assert.Contains(t, remaining, "testhost_home_2026-01-01T04:00:00Z") +} + +func TestPurgeKeepLatest_NoSnapshots(t *testing.T) { + v := setupPurgeTest(t, nil) + + err := v.PurgeSnapshotsWithOptions(&vaultik.SnapshotPurgeOptions{ + KeepLatest: true, + Force: true, + }) + require.NoError(t, err) +} + +func TestPurgeKeepLatest_NameFilterNoMatch(t *testing.T) { + snapshotIDs := []string{ + "testhost_system_2026-01-01T00:00:00Z", + "testhost_system_2026-01-01T01:00:00Z", + } + + v := setupPurgeTest(t, snapshotIDs) + + err := v.PurgeSnapshotsWithOptions(&vaultik.SnapshotPurgeOptions{ + KeepLatest: true, + Force: true, + Names: []string{"nonexistent"}, + }) + require.NoError(t, err) + + // All snapshots should remain — the name filter matched nothing + remaining := listRemainingSnapshots(t, v) + assert.Len(t, remaining, 2) +} + +func TestPurgeOlderThan_WithNameFilter(t *testing.T) { + // Snapshots with different names and timestamps. + // --older-than should apply only to the named subset when --name is used. + snapshotIDs := []string{ + "testhost_system_2020-01-01T00:00:00Z", + "testhost_home_2020-01-01T00:00:00Z", + "testhost_system_2026-01-01T00:00:00Z", + "testhost_home_2026-01-01T00:00:00Z", + } + + v := setupPurgeTest(t, snapshotIDs) + + // Purge only "home" snapshots older than 365 days + err := v.PurgeSnapshotsWithOptions(&vaultik.SnapshotPurgeOptions{ + OlderThan: "365d", + Force: true, + Names: []string{"home"}, + }) + require.NoError(t, err) + + remaining := listRemainingSnapshots(t, v) + + // Old system stays (not filtered by name), old home deleted, recent ones stay + assert.Len(t, remaining, 3) + assert.Contains(t, remaining, "testhost_system_2020-01-01T00:00:00Z") + assert.Contains(t, remaining, "testhost_system_2026-01-01T00:00:00Z") + assert.Contains(t, remaining, "testhost_home_2026-01-01T00:00:00Z") +} + +func TestPurgeKeepLatest_ThreeNames(t *testing.T) { + // Three different snapshot names with multiple snapshots each. + snapshotIDs := []string{ + "testhost_home_2026-01-01T00:00:00Z", + "testhost_system_2026-01-01T01:00:00Z", + "testhost_media_2026-01-01T02:00:00Z", + "testhost_home_2026-01-01T03:00:00Z", + "testhost_system_2026-01-01T04:00:00Z", + "testhost_media_2026-01-01T05:00:00Z", + "testhost_home_2026-01-01T06:00:00Z", + } + + v := setupPurgeTest(t, snapshotIDs) + + err := v.PurgeSnapshotsWithOptions(&vaultik.SnapshotPurgeOptions{ + KeepLatest: true, + Force: true, + }) + require.NoError(t, err) + + remaining := listRemainingSnapshots(t, v) + assert.Len(t, remaining, 3, "should keep one per name") + assert.Contains(t, remaining, "testhost_home_2026-01-01T06:00:00Z") + assert.Contains(t, remaining, "testhost_system_2026-01-01T04:00:00Z") + assert.Contains(t, remaining, "testhost_media_2026-01-01T05:00:00Z") +} diff --git a/internal/vaultik/restore.go b/internal/vaultik/restore.go index 015c533..5797fc8 100644 --- a/internal/vaultik/restore.go +++ b/internal/vaultik/restore.go @@ -22,6 +22,13 @@ import ( "golang.org/x/term" ) +const ( + // progressBarWidth is the character width of the progress bar display. + progressBarWidth = 40 + // progressBarThrottle is the minimum interval between progress bar redraws. + progressBarThrottle = 100 * time.Millisecond +) + // RestoreOptions contains options for the restore operation type RestoreOptions struct { SnapshotID string @@ -48,15 +55,9 @@ type RestoreResult struct { func (v *Vaultik) Restore(opts *RestoreOptions) error { startTime := time.Now() - // Check for age_secret_key - if v.Config.AgeSecretKey == "" { - return fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'") - } - - // Parse the age identity - identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey) + identity, err := v.prepareRestoreIdentity() if err != nil { - return fmt.Errorf("parsing age secret key: %w", err) + return err } log.Info("Starting restore operation", @@ -108,27 +109,9 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { } // Step 5: Restore files - result := &RestoreResult{} - blobCache := make(map[string][]byte) // Cache downloaded and decrypted blobs - - for i, file := range files { - if v.ctx.Err() != nil { - return v.ctx.Err() - } - - if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil { - log.Error("Failed to restore file", "path", file.Path, "error", err) - // Continue with other files - continue - } - - // Progress logging - if (i+1)%100 == 0 || i+1 == len(files) { - log.Info("Restore progress", - "files", fmt.Sprintf("%d/%d", i+1, len(files)), - "bytes", humanize.Bytes(uint64(result.BytesRestored)), - ) - } + result, err := v.restoreAllFiles(files, repos, opts, identity, chunkToBlobMap) + if err != nil { + return err } result.Duration = time.Since(startTime) @@ -141,32 +124,130 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error { "duration", result.Duration, ) - _, _ = fmt.Fprintf(v.Stdout, "Restored %d files (%s) in %s\n", + v.printfStdout("Restored %d files (%s) in %s\n", result.FilesRestored, humanize.Bytes(uint64(result.BytesRestored)), result.Duration.Round(time.Second), ) - // Run verification if requested - if opts.Verify { - if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil { - return fmt.Errorf("verification failed: %w", err) + if result.FilesFailed > 0 { + _, _ = fmt.Fprintf(v.Stdout, "\nWARNING: %d file(s) failed to restore:\n", result.FilesFailed) + for _, path := range result.FailedFiles { + _, _ = fmt.Fprintf(v.Stdout, " - %s\n", path) } - - if result.FilesFailed > 0 { - _, _ = fmt.Fprintf(v.Stdout, "\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed) - for _, path := range result.FailedFiles { - _, _ = fmt.Fprintf(v.Stdout, " - %s\n", path) - } - return fmt.Errorf("%d files failed verification", result.FilesFailed) - } - - _, _ = fmt.Fprintf(v.Stdout, "Verified %d files (%s)\n", - result.FilesVerified, - humanize.Bytes(uint64(result.BytesVerified)), - ) } + // Run verification if requested + if opts.Verify { + if err := v.handleRestoreVerification(repos, files, opts, result); err != nil { + return err + } + } + + if result.FilesFailed > 0 { + return fmt.Errorf("%d file(s) failed to restore", result.FilesFailed) + } + + return nil +} + +// prepareRestoreIdentity validates that an age secret key is configured and parses it +func (v *Vaultik) prepareRestoreIdentity() (age.Identity, error) { + if v.Config.AgeSecretKey == "" { + return nil, fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'") + } + + identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey) + if err != nil { + return nil, fmt.Errorf("parsing age secret key: %w", err) + } + return identity, nil +} + +// restoreAllFiles iterates over files and restores each one, tracking progress and failures +func (v *Vaultik) restoreAllFiles( + files []*database.File, + repos *database.Repositories, + opts *RestoreOptions, + identity age.Identity, + chunkToBlobMap map[string]*database.BlobChunk, +) (*RestoreResult, error) { + result := &RestoreResult{} + blobCache, err := newBlobDiskCache(4 * v.Config.BlobSizeLimit.Int64()) + if err != nil { + return nil, fmt.Errorf("creating blob cache: %w", err) + } + defer func() { _ = blobCache.Close() }() + + // Calculate total bytes for progress bar + var totalBytesExpected int64 + for _, file := range files { + totalBytesExpected += file.Size + } + + // Create progress bar if output is a terminal + bar := v.newProgressBar("Restoring", totalBytesExpected) + + for i, file := range files { + if v.ctx.Err() != nil { + return nil, v.ctx.Err() + } + + if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil { + log.Error("Failed to restore file", "path", file.Path, "error", err) + result.FilesFailed++ + result.FailedFiles = append(result.FailedFiles, file.Path.String()) + // Update progress bar even on failure + if bar != nil { + _ = bar.Add64(file.Size) + } + continue + } + + // Update progress bar + if bar != nil { + _ = bar.Add64(file.Size) + } + + // Progress logging (for non-terminal or structured logs) + if (i+1)%100 == 0 || i+1 == len(files) { + log.Info("Restore progress", + "files", fmt.Sprintf("%d/%d", i+1, len(files)), + "bytes", humanize.Bytes(uint64(result.BytesRestored)), + ) + } + } + + if bar != nil { + _ = bar.Finish() + } + + return result, nil +} + +// handleRestoreVerification runs post-restore verification if requested +func (v *Vaultik) handleRestoreVerification( + repos *database.Repositories, + files []*database.File, + opts *RestoreOptions, + result *RestoreResult, +) error { + if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil { + return fmt.Errorf("verification failed: %w", err) + } + + if result.FilesFailed > 0 { + v.printfStdout("\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed) + for _, path := range result.FailedFiles { + v.printfStdout(" - %s\n", path) + } + return fmt.Errorf("%d files failed verification", result.FilesFailed) + } + + v.printfStdout("Verified %d files (%s)\n", + result.FilesVerified, + humanize.Bytes(uint64(result.BytesVerified)), + ) return nil } @@ -299,7 +380,7 @@ func (v *Vaultik) restoreFile( targetDir string, identity age.Identity, chunkToBlobMap map[string]*database.BlobChunk, - blobCache map[string][]byte, + blobCache *blobDiskCache, result *RestoreResult, ) error { // Calculate target path - use full original path under target directory @@ -383,7 +464,7 @@ func (v *Vaultik) restoreRegularFile( targetPath string, identity age.Identity, chunkToBlobMap map[string]*database.BlobChunk, - blobCache map[string][]byte, + blobCache *blobDiskCache, result *RestoreResult, ) error { // Get file chunks in order @@ -417,13 +498,15 @@ func (v *Vaultik) restoreRegularFile( // Download and decrypt blob if not cached blobHashStr := blob.Hash.String() - blobData, ok := blobCache[blobHashStr] + blobData, ok := blobCache.Get(blobHashStr) if !ok { blobData, err = v.downloadBlob(ctx, blobHashStr, blob.CompressedSize, identity) if err != nil { return fmt.Errorf("downloading blob %s: %w", blobHashStr[:16], err) } - blobCache[blobHashStr] = blobData + if putErr := blobCache.Put(blobHashStr, blobData); putErr != nil { + log.Debug("Failed to cache blob on disk", "hash", blobHashStr[:16], "error", putErr) + } result.BlobsDownloaded++ result.BytesDownloaded += blob.CompressedSize } @@ -475,11 +558,23 @@ func (v *Vaultik) restoreRegularFile( // downloadBlob downloads and decrypts a blob func (v *Vaultik) downloadBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) ([]byte, error) { - result, err := v.FetchAndDecryptBlob(ctx, blobHash, expectedSize, identity) + rc, err := v.FetchAndDecryptBlob(ctx, blobHash, expectedSize, identity) if err != nil { return nil, err } - return result.Data, nil + + data, err := io.ReadAll(rc) + if err != nil { + _ = rc.Close() + return nil, fmt.Errorf("reading blob data: %w", err) + } + + // Close triggers hash verification + if err := rc.Close(); err != nil { + return nil, err + } + + return data, nil } // verifyRestoredFiles verifies that all restored files match their expected chunk hashes @@ -511,28 +606,13 @@ func (v *Vaultik) verifyRestoredFiles( "files", len(regularFiles), "bytes", humanize.Bytes(uint64(totalBytes)), ) - _, _ = fmt.Fprintf(v.Stdout, "\nVerifying %d files (%s)...\n", + v.printfStdout("\nVerifying %d files (%s)...\n", len(regularFiles), humanize.Bytes(uint64(totalBytes)), ) // Create progress bar if output is a terminal - var bar *progressbar.ProgressBar - if isTerminal() { - bar = progressbar.NewOptions64( - totalBytes, - progressbar.OptionSetDescription("Verifying"), - progressbar.OptionSetWriter(os.Stderr), - progressbar.OptionShowBytes(true), - progressbar.OptionShowCount(), - progressbar.OptionSetWidth(40), - progressbar.OptionThrottle(100*time.Millisecond), - progressbar.OptionOnCompletion(func() { - fmt.Fprint(os.Stderr, "\n") - }), - progressbar.OptionSetRenderBlankState(true), - ) - } + bar := v.newProgressBar("Verifying", totalBytes) // Verify each file for _, file := range regularFiles { @@ -626,7 +706,37 @@ func (v *Vaultik) verifyFile( return bytesVerified, nil } -// isTerminal returns true if stdout is a terminal -func isTerminal() bool { - return term.IsTerminal(int(os.Stdout.Fd())) +// newProgressBar creates a terminal-aware progress bar with standard options. +// It returns nil if stdout is not a terminal. +func (v *Vaultik) newProgressBar(description string, total int64) *progressbar.ProgressBar { + if !v.isTerminal() { + return nil + } + return progressbar.NewOptions64( + total, + progressbar.OptionSetDescription(description), + progressbar.OptionSetWriter(v.Stderr), + progressbar.OptionShowBytes(true), + progressbar.OptionShowCount(), + progressbar.OptionSetWidth(progressBarWidth), + progressbar.OptionThrottle(progressBarThrottle), + progressbar.OptionOnCompletion(func() { + v.printfStderr("\n") + }), + progressbar.OptionSetRenderBlankState(true), + ) +} + +// isTerminal returns true if stdout is a terminal. +// It checks whether v.Stdout implements Fd() (i.e. is an *os.File), +// and falls back to false for non-file writers (e.g. in tests). +func (v *Vaultik) isTerminal() bool { + type fder interface { + Fd() uintptr + } + f, ok := v.Stdout.(fder) + if !ok { + return false + } + return term.IsTerminal(int(f.Fd())) } diff --git a/internal/vaultik/snapshot.go b/internal/vaultik/snapshot.go index b9d1142..8fd6a46 100644 --- a/internal/vaultik/snapshot.go +++ b/internal/vaultik/snapshot.go @@ -5,8 +5,10 @@ import ( "fmt" "os" "path/filepath" + "regexp" "sort" "strings" + "sync" "text/tabwriter" "time" @@ -15,6 +17,7 @@ import ( "git.eeqj.de/sneak/vaultik/internal/snapshot" "git.eeqj.de/sneak/vaultik/internal/types" "github.com/dustin/go-humanize" + "golang.org/x/sync/errgroup" ) // SnapshotCreateOptions contains options for the snapshot create command @@ -79,7 +82,7 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error { // Print overall summary if multiple snapshots if len(snapshotNames) > 1 { - _, _ = fmt.Fprintf(v.Stdout, "\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second)) + v.printfStdout("\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second)) } if opts.Prune { @@ -96,60 +99,53 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error { // when `snapshot create --prune` is used. func (v *Vaultik) runPostBackupPrune(snapshotNames []string) error { log.Info("Running post-backup prune", "snapshots", snapshotNames) - _, _ = fmt.Fprintln(v.Stdout, "\n=== Post-backup prune ===") + v.printlnStdout("\n=== Post-backup prune ===") - purgeOpts := &PurgeOptions{ + purgeOpts := &SnapshotPurgeOptions{ KeepLatest: true, Force: true, Names: snapshotNames, Quiet: true, } - if err := v.PurgeSnapshots(purgeOpts); err != nil { + if err := v.PurgeSnapshotsWithOptions(purgeOpts); err != nil { return fmt.Errorf("purging old snapshots: %w", err) } - pruneOpts := &PruneOptions{Force: true} - if err := v.PruneBlobs(pruneOpts); err != nil { + if err := v.PruneBlobs(&PruneOptions{Force: true}); err != nil { return fmt.Errorf("pruning orphaned blobs: %w", err) } return nil } +// snapshotStats tracks aggregate statistics across directory scans +type snapshotStats struct { + totalFiles int + totalBytes int64 + totalChunks int + totalBlobs int + totalBytesSkipped int64 + totalFilesSkipped int + totalFilesDeleted int + totalBytesDeleted int64 + totalBytesUploaded int64 + totalBlobsUploaded int + uploadDuration time.Duration +} + // createNamedSnapshot creates a single named snapshot func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, snapName string, idx, total int) error { snapshotStartTime := time.Now() - snapConfig := v.Config.Snapshots[snapName] - if total > 1 { - _, _ = fmt.Fprintf(v.Stdout, "\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName) + v.printfStdout("\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName) } - // Resolve source directories to absolute paths - resolvedDirs := make([]string, 0, len(snapConfig.Paths)) - for _, dir := range snapConfig.Paths { - absPath, err := filepath.Abs(dir) - if err != nil { - return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err) - } - - // Resolve symlinks - resolvedPath, err := filepath.EvalSymlinks(absPath) - if err != nil { - // If the path doesn't exist yet, use the absolute path - if os.IsNotExist(err) { - resolvedPath = absPath - } else { - return fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err) - } - } - - resolvedDirs = append(resolvedDirs, resolvedPath) + resolvedDirs, err := v.resolveSnapshotPaths(snapName) + if err != nil { + return err } - // Create scanner with progress enabled (unless in cron mode) - // Pass the combined excludes for this snapshot scanner := v.ScannerFactory(snapshot.ScannerParams{ EnableProgress: !opts.Cron, Fs: v.Fs, @@ -157,51 +153,89 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna SkipErrors: opts.SkipErrors, }) - // Statistics tracking - totalFiles := 0 - totalBytes := int64(0) - totalChunks := 0 - totalBlobs := 0 - totalBytesSkipped := int64(0) - totalFilesSkipped := 0 - totalFilesDeleted := 0 - totalBytesDeleted := int64(0) - totalBytesUploaded := int64(0) - totalBlobsUploaded := 0 - uploadDuration := time.Duration(0) - - // Create a new snapshot at the beginning (with snapshot name in ID) snapshotID, err := v.SnapshotManager.CreateSnapshotWithName(v.ctx, hostname, snapName, v.Globals.Version, v.Globals.Commit) if err != nil { return fmt.Errorf("creating snapshot: %w", err) } log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName) - _, _ = fmt.Fprintf(v.Stdout, "Beginning snapshot: %s\n", snapshotID) + v.printfStdout("Beginning snapshot: %s\n", snapshotID) + + stats, err := v.scanAllDirectories(scanner, resolvedDirs, snapshotID) + if err != nil { + return err + } + + v.collectUploadStats(scanner, stats) + + if err := v.finalizeSnapshotMetadata(snapshotID, stats); err != nil { + return err + } + + log.Info("Snapshot complete", + "snapshot_id", snapshotID, + "name", snapName, + "files", stats.totalFiles, + "blobs_uploaded", stats.totalBlobsUploaded, + "bytes_uploaded", stats.totalBytesUploaded, + "duration", time.Since(snapshotStartTime)) + + v.printSnapshotSummary(snapshotID, snapshotStartTime, stats) + return nil +} + +// resolveSnapshotPaths resolves source directories to absolute paths with symlink resolution +func (v *Vaultik) resolveSnapshotPaths(snapName string) ([]string, error) { + snapConfig := v.Config.Snapshots[snapName] + resolvedDirs := make([]string, 0, len(snapConfig.Paths)) + + for _, dir := range snapConfig.Paths { + absPath, err := filepath.Abs(dir) + if err != nil { + return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err) + } + + resolvedPath, err := filepath.EvalSymlinks(absPath) + if err != nil { + if os.IsNotExist(err) { + resolvedPath = absPath + } else { + return nil, fmt.Errorf("failed to resolve symlinks for %s: %w", absPath, err) + } + } + + resolvedDirs = append(resolvedDirs, resolvedPath) + } + + return resolvedDirs, nil +} + +// scanAllDirectories runs the scanner on each resolved directory and accumulates stats +func (v *Vaultik) scanAllDirectories(scanner *snapshot.Scanner, resolvedDirs []string, snapshotID string) (*snapshotStats, error) { + stats := &snapshotStats{} for i, dir := range resolvedDirs { - // Check if context is cancelled select { case <-v.ctx.Done(): log.Info("Snapshot creation cancelled") - return v.ctx.Err() + return nil, v.ctx.Err() default: } log.Info("Scanning directory", "path", dir) - _, _ = fmt.Fprintf(v.Stdout, "Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir) + v.printfStdout("Beginning directory scan (%d/%d): %s\n", i+1, len(resolvedDirs), dir) result, err := scanner.Scan(v.ctx, dir, snapshotID) if err != nil { - return fmt.Errorf("failed to scan %s: %w", dir, err) + return nil, fmt.Errorf("failed to scan %s: %w", dir, err) } - totalFiles += result.FilesScanned - totalBytes += result.BytesScanned - totalChunks += result.ChunksCreated - totalBlobs += result.BlobsCreated - totalFilesSkipped += result.FilesSkipped - totalBytesSkipped += result.BytesSkipped - totalFilesDeleted += result.FilesDeleted - totalBytesDeleted += result.BytesDeleted + stats.totalFiles += result.FilesScanned + stats.totalBytes += result.BytesScanned + stats.totalChunks += result.ChunksCreated + stats.totalBlobs += result.BlobsCreated + stats.totalFilesSkipped += result.FilesSkipped + stats.totalBytesSkipped += result.BytesSkipped + stats.totalFilesDeleted += result.FilesDeleted + stats.totalBytesDeleted += result.BytesDeleted log.Info("Directory scan complete", "path", dir, @@ -212,85 +246,79 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna "chunks", result.ChunksCreated, "blobs", result.BlobsCreated, "duration", result.EndTime.Sub(result.StartTime)) - - // Remove per-directory summary - the scanner already prints its own summary } - // Get upload statistics from scanner progress if available + return stats, nil +} + +// collectUploadStats gathers upload statistics from the scanner's progress reporter +func (v *Vaultik) collectUploadStats(scanner *snapshot.Scanner, stats *snapshotStats) { if s := scanner.GetProgress(); s != nil { - stats := s.GetStats() - totalBytesUploaded = stats.BytesUploaded.Load() - totalBlobsUploaded = int(stats.BlobsUploaded.Load()) - uploadDuration = time.Duration(stats.UploadDurationMs.Load()) * time.Millisecond + progressStats := s.GetStats() + stats.totalBytesUploaded = progressStats.BytesUploaded.Load() + stats.totalBlobsUploaded = int(progressStats.BlobsUploaded.Load()) + stats.uploadDuration = time.Duration(progressStats.UploadDurationMs.Load()) * time.Millisecond } +} - // Update snapshot statistics with extended fields +// finalizeSnapshotMetadata updates stats, marks complete, and exports metadata +func (v *Vaultik) finalizeSnapshotMetadata(snapshotID string, stats *snapshotStats) error { extStats := snapshot.ExtendedBackupStats{ BackupStats: snapshot.BackupStats{ - FilesScanned: totalFiles, - BytesScanned: totalBytes, - ChunksCreated: totalChunks, - BlobsCreated: totalBlobs, - BytesUploaded: totalBytesUploaded, + FilesScanned: stats.totalFiles, + BytesScanned: stats.totalBytes, + ChunksCreated: stats.totalChunks, + BlobsCreated: stats.totalBlobs, + BytesUploaded: stats.totalBytesUploaded, }, - BlobUncompressedSize: 0, // Will be set from database query below + BlobUncompressedSize: 0, CompressionLevel: v.Config.CompressionLevel, - UploadDurationMs: uploadDuration.Milliseconds(), + UploadDurationMs: stats.uploadDuration.Milliseconds(), } if err := v.SnapshotManager.UpdateSnapshotStatsExtended(v.ctx, snapshotID, extStats); err != nil { return fmt.Errorf("updating snapshot stats: %w", err) } - // Mark snapshot as complete if err := v.SnapshotManager.CompleteSnapshot(v.ctx, snapshotID); err != nil { return fmt.Errorf("completing snapshot: %w", err) } - // Export snapshot metadata - // Export snapshot metadata without closing the database - // The export function should handle its own database connection if err := v.SnapshotManager.ExportSnapshotMetadata(v.ctx, v.Config.IndexPath, snapshotID); err != nil { return fmt.Errorf("exporting snapshot metadata: %w", err) } - // Calculate final statistics - snapshotDuration := time.Since(snapshotStartTime) - totalFilesChanged := totalFiles - totalFilesSkipped - totalBytesChanged := totalBytes - totalBytesAll := totalBytes + totalBytesSkipped + return nil +} - // Calculate upload speed - var avgUploadSpeed string - if totalBytesUploaded > 0 && uploadDuration > 0 { - bytesPerSec := float64(totalBytesUploaded) / uploadDuration.Seconds() - bitsPerSec := bytesPerSec * 8 - if bitsPerSec >= 1e9 { - avgUploadSpeed = fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9) - } else if bitsPerSec >= 1e6 { - avgUploadSpeed = fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6) - } else if bitsPerSec >= 1e3 { - avgUploadSpeed = fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3) - } else { - avgUploadSpeed = fmt.Sprintf("%.0f bit/s", bitsPerSec) - } - } else { - avgUploadSpeed = "N/A" +// formatUploadSpeed formats bytes uploaded and duration into a human-readable speed string +func formatUploadSpeed(bytesUploaded int64, duration time.Duration) string { + if bytesUploaded <= 0 || duration <= 0 { + return "N/A" } + bytesPerSec := float64(bytesUploaded) / duration.Seconds() + bitsPerSec := bytesPerSec * 8 + switch { + case bitsPerSec >= 1e9: + return fmt.Sprintf("%.1f Gbit/s", bitsPerSec/1e9) + case bitsPerSec >= 1e6: + return fmt.Sprintf("%.0f Mbit/s", bitsPerSec/1e6) + case bitsPerSec >= 1e3: + return fmt.Sprintf("%.0f Kbit/s", bitsPerSec/1e3) + default: + return fmt.Sprintf("%.0f bit/s", bitsPerSec) + } +} + +// printSnapshotSummary prints the comprehensive snapshot completion summary +func (v *Vaultik) printSnapshotSummary(snapshotID string, startTime time.Time, stats *snapshotStats) { + snapshotDuration := time.Since(startTime) + totalFilesChanged := stats.totalFiles - stats.totalFilesSkipped + totalBytesAll := stats.totalBytes + stats.totalBytesSkipped // Get total blob sizes from database - totalBlobSizeCompressed := int64(0) - totalBlobSizeUncompressed := int64(0) - if blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID); err == nil { - for _, hash := range blobHashes { - if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil { - totalBlobSizeCompressed += blob.CompressedSize - totalBlobSizeUncompressed += blob.UncompressedSize - } - } - } + totalBlobSizeCompressed, totalBlobSizeUncompressed := v.getSnapshotBlobSizes(snapshotID) - // Calculate compression ratio var compressionRatio float64 if totalBlobSizeUncompressed > 0 { compressionRatio = float64(totalBlobSizeCompressed) / float64(totalBlobSizeUncompressed) @@ -298,55 +326,96 @@ func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, sna compressionRatio = 1.0 } - // Print comprehensive summary - _, _ = fmt.Fprintf(v.Stdout, "=== Snapshot Complete ===\n") - _, _ = fmt.Fprintf(v.Stdout, "ID: %s\n", snapshotID) - _, _ = fmt.Fprintf(v.Stdout, "Files: %s examined, %s to process, %s unchanged", - formatNumber(totalFiles), + v.printfStdout("=== Snapshot Complete ===\n") + v.printfStdout("ID: %s\n", snapshotID) + v.printfStdout("Files: %s examined, %s to process, %s unchanged", + formatNumber(stats.totalFiles), formatNumber(totalFilesChanged), - formatNumber(totalFilesSkipped)) - if totalFilesDeleted > 0 { - _, _ = fmt.Fprintf(v.Stdout, ", %s deleted", formatNumber(totalFilesDeleted)) + formatNumber(stats.totalFilesSkipped)) + if stats.totalFilesDeleted > 0 { + v.printfStdout(", %s deleted", formatNumber(stats.totalFilesDeleted)) } - _, _ = fmt.Fprintln(v.Stdout) - _, _ = fmt.Fprintf(v.Stdout, "Data: %s total (%s to process)", + v.printlnStdout() + v.printfStdout("Data: %s total (%s to process)", humanize.Bytes(uint64(totalBytesAll)), - humanize.Bytes(uint64(totalBytesChanged))) - if totalBytesDeleted > 0 { - _, _ = fmt.Fprintf(v.Stdout, ", %s deleted", humanize.Bytes(uint64(totalBytesDeleted))) + humanize.Bytes(uint64(stats.totalBytes))) + if stats.totalBytesDeleted > 0 { + v.printfStdout(", %s deleted", humanize.Bytes(uint64(stats.totalBytesDeleted))) } - _, _ = fmt.Fprintln(v.Stdout) - if totalBlobsUploaded > 0 { - _, _ = fmt.Fprintf(v.Stdout, "Storage: %s compressed from %s (%.2fx)\n", + v.printlnStdout() + if stats.totalBlobsUploaded > 0 { + v.printfStdout("Storage: %s compressed from %s (%.2fx)\n", humanize.Bytes(uint64(totalBlobSizeCompressed)), humanize.Bytes(uint64(totalBlobSizeUncompressed)), compressionRatio) - _, _ = fmt.Fprintf(v.Stdout, "Upload: %d blobs, %s in %s (%s)\n", - totalBlobsUploaded, - humanize.Bytes(uint64(totalBytesUploaded)), - formatDuration(uploadDuration), - avgUploadSpeed) + v.printfStdout("Upload: %d blobs, %s in %s (%s)\n", + stats.totalBlobsUploaded, + humanize.Bytes(uint64(stats.totalBytesUploaded)), + formatDuration(stats.uploadDuration), + formatUploadSpeed(stats.totalBytesUploaded, stats.uploadDuration)) } - _, _ = fmt.Fprintf(v.Stdout, "Duration: %s\n", formatDuration(snapshotDuration)) + v.printfStdout("Duration: %s\n", formatDuration(snapshotDuration)) +} - return nil +// getSnapshotBlobSizes returns total compressed and uncompressed blob sizes for a snapshot +func (v *Vaultik) getSnapshotBlobSizes(snapshotID string) (compressed int64, uncompressed int64) { + blobHashes, err := v.Repositories.Snapshots.GetBlobHashes(v.ctx, snapshotID) + if err != nil { + return 0, 0 + } + for _, hash := range blobHashes { + if blob, err := v.Repositories.Blobs.GetByHash(v.ctx, hash); err == nil && blob != nil { + compressed += blob.CompressedSize + uncompressed += blob.UncompressedSize + } + } + return compressed, uncompressed } // ListSnapshots lists all snapshots func (v *Vaultik) ListSnapshots(jsonOutput bool) error { - // Get all remote snapshots + log.Info("Listing snapshots") + remoteSnapshots, err := v.listRemoteSnapshotIDs() + if err != nil { + return err + } + + localSnapshotMap, err := v.reconcileLocalWithRemote(remoteSnapshots) + if err != nil { + return err + } + + snapshots, err := v.buildSnapshotInfoList(remoteSnapshots, localSnapshotMap) + if err != nil { + return err + } + + // Sort by timestamp (newest first) + sort.Slice(snapshots, func(i, j int) bool { + return snapshots[i].Timestamp.After(snapshots[j].Timestamp) + }) + + if jsonOutput { + encoder := json.NewEncoder(v.Stdout) + encoder.SetIndent("", " ") + return encoder.Encode(snapshots) + } + + return v.printSnapshotTable(snapshots) +} + +// listRemoteSnapshotIDs returns a set of snapshot IDs found in remote storage +func (v *Vaultik) listRemoteSnapshotIDs() (map[string]bool, error) { remoteSnapshots := make(map[string]bool) objectCh := v.Storage.ListStream(v.ctx, "metadata/") for object := range objectCh { if object.Err != nil { - return fmt.Errorf("listing remote snapshots: %w", object.Err) + return nil, fmt.Errorf("listing remote snapshots: %w", object.Err) } - // Extract snapshot ID from paths like metadata/hostname-20240115-143052Z/ parts := strings.Split(object.Key, "/") if len(parts) >= 2 && parts[0] == "metadata" && parts[1] != "" { - // Skip macOS resource fork files (._*) and other hidden files if strings.HasPrefix(parts[1], ".") { continue } @@ -354,56 +423,36 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { } } - // Get all local snapshots + return remoteSnapshots, nil +} + +// reconcileLocalWithRemote builds a map of local snapshots keyed by ID for cross-referencing with remote +func (v *Vaultik) reconcileLocalWithRemote(remoteSnapshots map[string]bool) (map[string]*database.Snapshot, error) { localSnapshots, err := v.Repositories.Snapshots.ListRecent(v.ctx, 10000) if err != nil { - return fmt.Errorf("listing local snapshots: %w", err) + return nil, fmt.Errorf("listing local snapshots: %w", err) } - // Build a map of local snapshots for quick lookup localSnapshotMap := make(map[string]*database.Snapshot) for _, s := range localSnapshots { localSnapshotMap[s.ID.String()] = s } - // Remove local snapshots that don't exist remotely - for _, snapshot := range localSnapshots { - snapshotIDStr := snapshot.ID.String() - if !remoteSnapshots[snapshotIDStr] { - log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID) + return localSnapshotMap, nil +} - // Delete related records first to avoid foreign key constraints - if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err) - } - if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err) - } - if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err) - } - - // Now delete the snapshot itself - if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotIDStr); err != nil { - log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err) - } else { - log.Info("Deleted local snapshot not found in remote", "snapshot_id", snapshot.ID) - delete(localSnapshotMap, snapshotIDStr) - } - } - } - - // Build final snapshot list +// buildSnapshotInfoList constructs SnapshotInfo entries from remote IDs and local data +func (v *Vaultik) buildSnapshotInfoList(remoteSnapshots map[string]bool, localSnapshotMap map[string]*database.Snapshot) ([]SnapshotInfo, error) { snapshots := make([]SnapshotInfo, 0, len(remoteSnapshots)) + // remoteOnly collects snapshot IDs that need a manifest download. + var remoteOnly []string + for snapshotID := range remoteSnapshots { - // Check if we have this snapshot locally if localSnap, exists := localSnapshotMap[snapshotID]; exists && localSnap.CompletedAt != nil { - // Get total compressed size of all blobs referenced by this snapshot totalSize, err := v.Repositories.Snapshots.GetSnapshotTotalCompressedSize(v.ctx, snapshotID) if err != nil { log.Warn("Failed to get total compressed size", "id", snapshotID, "error", err) - // Fall back to stored blob size totalSize = localSnap.BlobSize } @@ -413,43 +462,89 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { CompressedSize: totalSize, }) } else { - // Remote snapshot not in local DB - fetch manifest to get size timestamp, err := parseSnapshotTimestamp(snapshotID) if err != nil { log.Warn("Failed to parse snapshot timestamp", "id", snapshotID, "error", err) continue } - // Try to download manifest to get size - totalSize, err := v.getManifestSize(snapshotID) - if err != nil { - return fmt.Errorf("failed to get manifest size for %s: %w", snapshotID, err) - } - + // Pre-add with zero size; will be filled by concurrent downloads. snapshots = append(snapshots, SnapshotInfo{ ID: types.SnapshotID(snapshotID), Timestamp: timestamp, - CompressedSize: totalSize, + CompressedSize: 0, }) + remoteOnly = append(remoteOnly, snapshotID) } } - // Sort by timestamp (newest first) - sort.Slice(snapshots, func(i, j int) bool { - return snapshots[i].Timestamp.After(snapshots[j].Timestamp) - }) + // Download manifests concurrently for remote-only snapshots. + if len(remoteOnly) > 0 { + // maxConcurrentManifestDownloads bounds parallel manifest fetches to + // avoid overwhelming the S3 endpoint while still being much faster + // than serial downloads. + const maxConcurrentManifestDownloads = 10 - if jsonOutput { - // JSON output - encoder := json.NewEncoder(os.Stdout) - encoder.SetIndent("", " ") - return encoder.Encode(snapshots) + type manifestResult struct { + snapshotID string + size int64 + } + + var ( + mu sync.Mutex + results []manifestResult + ) + + g, gctx := errgroup.WithContext(v.ctx) + g.SetLimit(maxConcurrentManifestDownloads) + + for _, sid := range remoteOnly { + g.Go(func() error { + manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", sid) + reader, err := v.Storage.Get(gctx, manifestPath) + if err != nil { + return fmt.Errorf("downloading manifest for %s: %w", sid, err) + } + defer func() { _ = reader.Close() }() + + manifest, err := snapshot.DecodeManifest(reader) + if err != nil { + return fmt.Errorf("decoding manifest for %s: %w", sid, err) + } + + mu.Lock() + results = append(results, manifestResult{ + snapshotID: sid, + size: manifest.TotalCompressedSize, + }) + mu.Unlock() + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("fetching manifest sizes: %w", err) + } + + // Build a lookup from results and patch the pre-added entries. + sizeMap := make(map[string]int64, len(results)) + for _, r := range results { + sizeMap[r.snapshotID] = r.size + } + for i := range snapshots { + if sz, ok := sizeMap[string(snapshots[i].ID)]; ok { + snapshots[i].CompressedSize = sz + } + } } - // Table output - w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + return snapshots, nil +} + +// printSnapshotTable renders the snapshot list as a formatted table +func (v *Vaultik) printSnapshotTable(snapshots []SnapshotInfo) error { + w := tabwriter.NewWriter(v.Stdout, 0, 0, 3, ' ', 0) - // Show configured snapshots from config file if _, err := fmt.Fprintln(w, "CONFIGURED SNAPSHOTS:"); err != nil { return err } @@ -470,7 +565,6 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { return err } - // Show remote snapshots if _, err := fmt.Fprintln(w, "REMOTE SNAPSHOTS:"); err != nil { return err } @@ -493,20 +587,20 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error { return w.Flush() } -// PurgeOptions configures snapshot purge behavior. -type PurgeOptions struct { +// SnapshotPurgeOptions contains options for the snapshot purge command. +type SnapshotPurgeOptions struct { KeepLatest bool // Keep only the most recent snapshot per name OlderThan string // Drop snapshots older than this duration (e.g. "30d", "6m", "1y") - Force bool // Skip confirmation prompt and noisy output + Force bool // Skip confirmation prompt Names []string // If non-empty, only operate on snapshots with one of these names Quiet bool // Suppress informational output (used by --prune flag) } -// PurgeSnapshots removes old snapshots based on criteria. +// PurgeSnapshotsWithOptions removes old snapshots based on criteria. // Retention is per-snapshot-name: KeepLatest keeps the latest of EACH configured // snapshot name, not the latest globally. This prevents `home` and `system` // snapshots from cannibalizing each other. -func (v *Vaultik) PurgeSnapshots(opts *PurgeOptions) error { +func (v *Vaultik) PurgeSnapshotsWithOptions(opts *SnapshotPurgeOptions) error { // Sync with remote first if err := v.syncWithRemote(); err != nil { return fmt.Errorf("syncing with remote: %w", err) @@ -518,22 +612,20 @@ func (v *Vaultik) PurgeSnapshots(opts *PurgeOptions) error { return fmt.Errorf("listing snapshots: %w", err) } - // Convert to SnapshotInfo format, only including completed snapshots, - // optionally filtered by name. - hostname := v.shortHostname() + // Build name filter set if --snapshot was specified. nameFilter := make(map[string]struct{}, len(opts.Names)) for _, n := range opts.Names { nameFilter[n] = struct{}{} } + // Collect completed snapshots, applying the name filter. snapshots := make([]SnapshotInfo, 0, len(dbSnapshots)) for _, s := range dbSnapshots { if s.CompletedAt == nil { continue } if len(nameFilter) > 0 { - name := snapshotNameFromID(s.ID.String(), hostname) - if _, ok := nameFilter[name]; !ok { + if _, ok := nameFilter[parseSnapshotName(s.ID.String())]; !ok { continue } } @@ -552,11 +644,11 @@ func (v *Vaultik) PurgeSnapshots(opts *PurgeOptions) error { var toDelete []SnapshotInfo if opts.KeepLatest { - // Keep only the most recent snapshot of each name. Group by snapshot name - // (derived from snapshot ID) and keep the newest in each group. + // Keep the latest snapshot per snapshot name. Snapshots are sorted + // newest-first, so the first occurrence of each name is kept. seen := make(map[string]bool) for _, snap := range snapshots { - name := snapshotNameFromID(snap.ID.String(), hostname) + name := parseSnapshotName(snap.ID.String()) if seen[name] { toDelete = append(toDelete, snap) continue @@ -564,7 +656,6 @@ func (v *Vaultik) PurgeSnapshots(opts *PurgeOptions) error { seen[name] = true } } else if opts.OlderThan != "" { - // Parse duration duration, err := parseDuration(opts.OlderThan) if err != nil { return fmt.Errorf("invalid duration: %w", err) @@ -580,16 +671,20 @@ func (v *Vaultik) PurgeSnapshots(opts *PurgeOptions) error { if len(toDelete) == 0 { if !opts.Quiet { - _, _ = fmt.Fprintln(v.Stdout, "No snapshots to delete") + v.printlnStdout("No snapshots to delete") } return nil } - // Show what will be deleted - if !opts.Quiet { - _, _ = fmt.Fprintf(v.Stdout, "The following snapshots will be deleted:\n\n") + return v.confirmAndExecutePurge(toDelete, opts.Force, opts.Quiet) +} + +// confirmAndExecutePurge shows deletion candidates, confirms with user, and deletes snapshots +func (v *Vaultik) confirmAndExecutePurge(toDelete []SnapshotInfo, force, quiet bool) error { + if !quiet { + v.printfStdout("The following snapshots will be deleted:\n\n") for _, snap := range toDelete { - _, _ = fmt.Fprintf(v.Stdout, " %s (%s, %s)\n", + v.printfStdout(" %s (%s, %s)\n", snap.ID, snap.Timestamp.Format("2006-01-02 15:04:05"), formatBytes(snap.CompressedSize)) @@ -597,20 +692,20 @@ func (v *Vaultik) PurgeSnapshots(opts *PurgeOptions) error { } // Confirm unless --force is used - if !opts.Force { - _, _ = fmt.Fprintf(v.Stdout, "\nDelete %d snapshot(s)? [y/N] ", len(toDelete)) + if !force { + v.printfStdout("\nDelete %d snapshot(s)? [y/N] ", len(toDelete)) var confirm string - if _, err := fmt.Fscanln(v.Stdin, &confirm); err != nil { + if _, err := v.scanStdin(&confirm); err != nil { // Treat EOF or error as "no" - _, _ = fmt.Fprintln(v.Stdout, "Cancelled") + v.printlnStdout("Cancelled") return nil } if strings.ToLower(confirm) != "y" { - _, _ = fmt.Fprintln(v.Stdout, "Cancelled") + v.printlnStdout("Cancelled") return nil } - } else if !opts.Quiet { - _, _ = fmt.Fprintf(v.Stdout, "\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete)) + } else if !quiet { + v.printfStdout("\nDeleting %d snapshot(s) (--force specified)\n", len(toDelete)) } // Delete snapshots (both local and remote) @@ -625,9 +720,9 @@ func (v *Vaultik) PurgeSnapshots(opts *PurgeOptions) error { } } - if !opts.Quiet { - _, _ = fmt.Fprintf(v.Stdout, "Deleted %d snapshot(s)\n", len(toDelete)) - _, _ = fmt.Fprintln(v.Stdout, "\nNote: Run 'vaultik prune' to clean up unreferenced blobs.") + if !quiet { + v.printfStdout("Deleted %d snapshot(s)\n", len(toDelete)) + v.printlnStdout("\nNote: Run 'vaultik prune' to clean up unreferenced blobs.") } return nil @@ -648,7 +743,11 @@ func (v *Vaultik) shortHostname() string { // VerifySnapshot checks snapshot integrity func (v *Vaultik) VerifySnapshot(snapshotID string, deep bool) error { - return v.VerifySnapshotWithOptions(snapshotID, &VerifyOptions{Deep: deep}) + opts := &VerifyOptions{Deep: deep} + if deep { + return v.RunDeepVerify(snapshotID, opts) + } + return v.VerifySnapshotWithOptions(snapshotID, opts) } // VerifySnapshotWithOptions checks snapshot integrity with full options. @@ -663,20 +762,7 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio Mode: "shallow", } - // Parse snapshot ID to extract timestamp. - // Snapshot ID format: hostname[_name]_ - var snapshotTime time.Time - if t, err := parseSnapshotTimestamp(snapshotID); err == nil { - snapshotTime = t - } - - if !opts.JSON { - fmt.Printf("Verifying snapshot %s\n", snapshotID) - if !snapshotTime.IsZero() { - fmt.Printf("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) - } - fmt.Println() - } + v.printVerifyHeader(snapshotID, opts) // Download and parse manifest manifest, err := v.downloadManifest(snapshotID) @@ -693,32 +779,52 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio result.TotalSize = manifest.TotalCompressedSize if !opts.JSON { - fmt.Printf("Snapshot information:\n") - fmt.Printf(" Blob count: %d\n", manifest.BlobCount) - fmt.Printf(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize))) + v.printfStdout("Snapshot information:\n") + v.printfStdout(" Blob count: %d\n", manifest.BlobCount) + v.printfStdout(" Total size: %s\n", humanize.Bytes(uint64(manifest.TotalCompressedSize))) if manifest.Timestamp != "" { if t, err := time.Parse(time.RFC3339, manifest.Timestamp); err == nil { - fmt.Printf(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST")) + v.printfStdout(" Created: %s\n", t.Format("2006-01-02 15:04:05 MST")) } } - fmt.Println() + v.printlnStdout() // Check each blob exists - fmt.Printf("Checking blob existence...\n") + v.printfStdout("Checking blob existence...\n") } - missing := 0 - verified := 0 - missingSize := int64(0) + result.Verified, result.Missing, result.MissingSize = v.verifyManifestBlobsExist(manifest, opts) + return v.formatVerifyResult(result, manifest, opts) +} + +// printVerifyHeader prints the snapshot ID and parsed timestamp for verification output. +// Snapshot ID format: hostname[_name]_ +func (v *Vaultik) printVerifyHeader(snapshotID string, opts *VerifyOptions) { + var snapshotTime time.Time + if t, err := parseSnapshotTimestamp(snapshotID); err == nil { + snapshotTime = t + } + + if !opts.JSON { + v.printfStdout("Verifying snapshot %s\n", snapshotID) + if !snapshotTime.IsZero() { + v.printfStdout("Snapshot time: %s\n", snapshotTime.Format("2006-01-02 15:04:05 MST")) + } + v.printlnStdout() + } +} + +// verifyManifestBlobsExist checks that each blob in the manifest exists in storage +func (v *Vaultik) verifyManifestBlobsExist(manifest *snapshot.Manifest, opts *VerifyOptions) (verified, missing int, missingSize int64) { for _, blob := range manifest.Blobs { blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash) - // Shallow: just check existence + // Shallow: just check existence (deep verification is handled by RunDeepVerify) _, err := v.Storage.Stat(v.ctx, blobPath) if err != nil { if !opts.JSON { - fmt.Printf(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize))) + v.printfStdout(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize))) } missing++ missingSize += blob.CompressedSize @@ -726,43 +832,42 @@ func (v *Vaultik) VerifySnapshotWithOptions(snapshotID string, opts *VerifyOptio verified++ } } + return verified, missing, missingSize +} - result.Verified = verified - result.Missing = missing - result.MissingSize = missingSize - +// formatVerifyResult outputs the final verification results as JSON or human-readable text +func (v *Vaultik) formatVerifyResult(result *VerifyResult, manifest *snapshot.Manifest, opts *VerifyOptions) error { if opts.JSON { - if missing > 0 { + if result.Missing > 0 { result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("%d blobs are missing", missing) + result.ErrorMessage = fmt.Sprintf("%d blobs are missing", result.Missing) } else { result.Status = "ok" } return v.outputVerifyJSON(result) } - fmt.Printf("\nVerification complete:\n") - fmt.Printf(" Verified: %d blobs (%s)\n", verified, - humanize.Bytes(uint64(manifest.TotalCompressedSize-missingSize))) - if missing > 0 { - fmt.Printf(" Missing: %d blobs (%s)\n", missing, humanize.Bytes(uint64(missingSize))) + v.printfStdout("\nVerification complete:\n") + v.printfStdout(" Verified: %d blobs (%s)\n", result.Verified, + humanize.Bytes(uint64(manifest.TotalCompressedSize-result.MissingSize))) + if result.Missing > 0 { + v.printfStdout(" Missing: %d blobs (%s)\n", result.Missing, humanize.Bytes(uint64(result.MissingSize))) } else { - fmt.Printf(" Missing: 0 blobs\n") + v.printfStdout(" Missing: 0 blobs\n") } - fmt.Printf(" Status: ") - if missing > 0 { - fmt.Printf("FAILED - %d blobs are missing\n", missing) - return fmt.Errorf("%d blobs are missing", missing) - } else { - fmt.Printf("OK - All blobs verified\n") + v.printfStdout(" Status: ") + if result.Missing > 0 { + v.printfStdout("FAILED - %d blobs are missing\n", result.Missing) + return fmt.Errorf("%d blobs are missing", result.Missing) } + v.printfStdout("OK - All blobs verified\n") return nil } // outputVerifyJSON outputs the verification result as JSON func (v *Vaultik) outputVerifyJSON(result *VerifyResult) error { - encoder := json.NewEncoder(os.Stdout) + encoder := json.NewEncoder(v.Stdout) encoder.SetIndent("", " ") if err := encoder.Encode(result); err != nil { return fmt.Errorf("encoding JSON: %w", err) @@ -775,23 +880,6 @@ func (v *Vaultik) outputVerifyJSON(result *VerifyResult) error { // Helper methods that were previously on SnapshotApp -func (v *Vaultik) getManifestSize(snapshotID string) (int64, error) { - manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) - - reader, err := v.Storage.Get(v.ctx, manifestPath) - if err != nil { - return 0, fmt.Errorf("downloading manifest: %w", err) - } - defer func() { _ = reader.Close() }() - - manifest, err := snapshot.DecodeManifest(reader) - if err != nil { - return 0, fmt.Errorf("decoding manifest: %w", err) - } - - return manifest.TotalCompressedSize, nil -} - func (v *Vaultik) downloadManifest(snapshotID string) (*snapshot.Manifest, error) { manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) @@ -846,7 +934,7 @@ func (v *Vaultik) syncWithRemote() error { snapshotIDStr := snapshot.ID.String() if !remoteSnapshots[snapshotIDStr] { log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID) - if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotIDStr); err != nil { + if err := v.deleteSnapshotFromLocalDB(snapshotIDStr); err != nil { log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err) } else { removedCount++ @@ -888,11 +976,11 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov if opts.DryRun { result.DryRun = true if !opts.JSON { - _, _ = fmt.Fprintf(v.Stdout, "Would remove snapshot: %s\n", snapshotID) + v.printfStdout("Would remove snapshot: %s\n", snapshotID) if opts.Remote { - _, _ = fmt.Fprintln(v.Stdout, "Would also remove from remote storage") + v.printlnStdout("Would also remove from remote storage") } - _, _ = fmt.Fprintln(v.Stdout, "[Dry run - no changes made]") + v.printlnStdout("[Dry run - no changes made]") } if opts.JSON { return result, v.outputRemoveJSON(result) @@ -903,17 +991,17 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov // Confirm unless --force is used (skip in JSON mode - require --force) if !opts.Force && !opts.JSON { if opts.Remote { - _, _ = fmt.Fprintf(v.Stdout, "Remove snapshot '%s' from local database and remote storage? [y/N] ", snapshotID) + v.printfStdout("Remove snapshot '%s' from local database and remote storage? [y/N] ", snapshotID) } else { - _, _ = fmt.Fprintf(v.Stdout, "Remove snapshot '%s' from local database? [y/N] ", snapshotID) + v.printfStdout("Remove snapshot '%s' from local database? [y/N] ", snapshotID) } var confirm string - if _, err := fmt.Fscanln(v.Stdin, &confirm); err != nil { - _, _ = fmt.Fprintln(v.Stdout, "Cancelled") + if _, err := v.scanStdin(&confirm); err != nil { + v.printlnStdout("Cancelled") return result, nil } if strings.ToLower(confirm) != "y" { - _, _ = fmt.Fprintln(v.Stdout, "Cancelled") + v.printlnStdout("Cancelled") return result, nil } } @@ -940,10 +1028,10 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov } // Print summary - _, _ = fmt.Fprintf(v.Stdout, "Removed snapshot '%s' from local database\n", snapshotID) + v.printfStdout("Removed snapshot '%s' from local database\n", snapshotID) if opts.Remote { - _, _ = fmt.Fprintln(v.Stdout, "Removed snapshot metadata from remote storage") - _, _ = fmt.Fprintln(v.Stdout, "\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") + v.printlnStdout("Removed snapshot metadata from remote storage") + v.printlnStdout("\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") } return result, nil @@ -951,12 +1039,31 @@ func (v *Vaultik) RemoveSnapshot(snapshotID string, opts *RemoveOptions) (*Remov // RemoveAllSnapshots removes all snapshots from local database and optionally from remote func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) { - result := &RemoveResult{} + snapshotIDs, err := v.listAllRemoteSnapshotIDs() + if err != nil { + return nil, err + } - // List all snapshots + if len(snapshotIDs) == 0 { + if !opts.JSON { + v.printlnStdout("No snapshots found") + } + return &RemoveResult{}, nil + } + + if opts.DryRun { + return v.handleRemoveAllDryRun(snapshotIDs, opts) + } + + return v.executeRemoveAll(snapshotIDs, opts) +} + +// listAllRemoteSnapshotIDs collects all unique snapshot IDs from remote storage +func (v *Vaultik) listAllRemoteSnapshotIDs() ([]string, error) { log.Info("Listing all snapshots") objectCh := v.Storage.ListStream(v.ctx, "metadata/") + seen := make(map[string]bool) var snapshotIDs []string for object := range objectCh { if object.Err != nil { @@ -971,46 +1078,41 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) } if strings.HasSuffix(object.Key, "/") || strings.Contains(object.Key, "/manifest.json.zst") { sid := parts[1] - found := false - for _, id := range snapshotIDs { - if id == sid { - found = true - break - } - } - if !found { + if !seen[sid] { + seen[sid] = true snapshotIDs = append(snapshotIDs, sid) } } } } - if len(snapshotIDs) == 0 { - if !opts.JSON { - _, _ = fmt.Fprintln(v.Stdout, "No snapshots found") - } - return result, nil - } + return snapshotIDs, nil +} - if opts.DryRun { - result.DryRun = true - result.SnapshotsRemoved = snapshotIDs - if !opts.JSON { - _, _ = fmt.Fprintf(v.Stdout, "Would remove %d snapshot(s):\n", len(snapshotIDs)) - for _, id := range snapshotIDs { - _, _ = fmt.Fprintf(v.Stdout, " %s\n", id) - } - if opts.Remote { - _, _ = fmt.Fprintln(v.Stdout, "Would also remove from remote storage") - } - _, _ = fmt.Fprintln(v.Stdout, "[Dry run - no changes made]") - } - if opts.JSON { - return result, v.outputRemoveJSON(result) - } - return result, nil +// handleRemoveAllDryRun handles the dry-run mode for removing all snapshots +func (v *Vaultik) handleRemoveAllDryRun(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) { + result := &RemoveResult{ + DryRun: true, + SnapshotsRemoved: snapshotIDs, } + if !opts.JSON { + v.printfStdout("Would remove %d snapshot(s):\n", len(snapshotIDs)) + for _, id := range snapshotIDs { + v.printfStdout(" %s\n", id) + } + if opts.Remote { + v.printlnStdout("Would also remove from remote storage") + } + v.printlnStdout("[Dry run - no changes made]") + } + if opts.JSON { + return result, v.outputRemoveJSON(result) + } + return result, nil +} +// executeRemoveAll removes all snapshots from local database and optionally from remote storage +func (v *Vaultik) executeRemoveAll(snapshotIDs []string, opts *RemoveOptions) (*RemoveResult, error) { // --all requires --force if !opts.Force { return nil, fmt.Errorf("--all requires --force") @@ -1018,6 +1120,7 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) log.Info("Removing all snapshots", "count", len(snapshotIDs)) + result := &RemoveResult{} for _, snapshotID := range snapshotIDs { log.Info("Removing snapshot", "snapshot_id", snapshotID) @@ -1044,10 +1147,10 @@ func (v *Vaultik) RemoveAllSnapshots(opts *RemoveOptions) (*RemoveResult, error) return result, v.outputRemoveJSON(result) } - _, _ = fmt.Fprintf(v.Stdout, "Removed %d snapshot(s)\n", len(result.SnapshotsRemoved)) + v.printfStdout("Removed %d snapshot(s)\n", len(result.SnapshotsRemoved)) if opts.Remote { - _, _ = fmt.Fprintln(v.Stdout, "Removed snapshot metadata from remote storage") - _, _ = fmt.Fprintln(v.Stdout, "\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") + v.printlnStdout("Removed snapshot metadata from remote storage") + v.printlnStdout("\nNote: Blobs were not removed. Run 'vaultik prune' to remove orphaned blobs.") } return result, nil @@ -1061,16 +1164,16 @@ func (v *Vaultik) deleteSnapshotFromLocalDB(snapshotID string) error { // Delete related records first to avoid foreign key constraints if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotID); err != nil { - log.Error("Failed to delete snapshot files", "snapshot_id", snapshotID, "error", err) + return fmt.Errorf("deleting snapshot files for %s: %w", snapshotID, err) } if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotID); err != nil { - log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshotID, "error", err) + return fmt.Errorf("deleting snapshot blobs for %s: %w", snapshotID, err) } if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotID); err != nil { - log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshotID, "error", err) + return fmt.Errorf("deleting snapshot uploads for %s: %w", snapshotID, err) } if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotID); err != nil { - log.Error("Failed to delete snapshot record", "snapshot_id", snapshotID, "error", err) + return fmt.Errorf("deleting snapshot record %s: %w", snapshotID, err) } return nil @@ -1101,7 +1204,7 @@ func (v *Vaultik) deleteSnapshotFromRemote(snapshotID string) error { // outputRemoveJSON outputs the removal result as JSON func (v *Vaultik) outputRemoveJSON(result *RemoveResult) error { - encoder := json.NewEncoder(os.Stdout) + encoder := json.NewEncoder(v.Stdout) encoder.SetIndent("", " ") return encoder.Encode(result) } @@ -1175,21 +1278,29 @@ func (v *Vaultik) PruneDatabase() (*PruneResult, error) { ) // Print summary - _, _ = fmt.Fprintf(v.Stdout, "Local database prune complete:\n") - _, _ = fmt.Fprintf(v.Stdout, " Incomplete snapshots removed: %d\n", result.SnapshotsDeleted) - _, _ = fmt.Fprintf(v.Stdout, " Orphaned files removed: %d\n", result.FilesDeleted) - _, _ = fmt.Fprintf(v.Stdout, " Orphaned chunks removed: %d\n", result.ChunksDeleted) - _, _ = fmt.Fprintf(v.Stdout, " Orphaned blobs removed: %d\n", result.BlobsDeleted) + v.printfStdout("Local database prune complete:\n") + v.printfStdout(" Incomplete snapshots removed: %d\n", result.SnapshotsDeleted) + v.printfStdout(" Orphaned files removed: %d\n", result.FilesDeleted) + v.printfStdout(" Orphaned chunks removed: %d\n", result.ChunksDeleted) + v.printfStdout(" Orphaned blobs removed: %d\n", result.BlobsDeleted) return result, nil } -// getTableCount returns the count of rows in a table +// validTableNameRe matches table names containing only lowercase alphanumeric characters and underscores. +var validTableNameRe = regexp.MustCompile(`^[a-z0-9_]+$`) + +// getTableCount returns the count of rows in a table. +// The tableName is sanitized to only allow [a-z0-9_] characters to prevent SQL injection. func (v *Vaultik) getTableCount(tableName string) (int64, error) { if v.DB == nil { return 0, nil } + if !validTableNameRe.MatchString(tableName) { + return 0, fmt.Errorf("invalid table name: %q", tableName) + } + var count int64 query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName) err := v.DB.Conn().QueryRowContext(v.ctx, query).Scan(&count) diff --git a/internal/vaultik/snapshot_prune_test.go b/internal/vaultik/snapshot_prune_test.go new file mode 100644 index 0000000..dbff412 --- /dev/null +++ b/internal/vaultik/snapshot_prune_test.go @@ -0,0 +1,23 @@ +package vaultik + +import ( + "testing" +) + +// TestSnapshotCreateOptions_PruneFlag verifies the Prune field exists on +// SnapshotCreateOptions and can be set. +func TestSnapshotCreateOptions_PruneFlag(t *testing.T) { + opts := &SnapshotCreateOptions{ + Prune: true, + } + if !opts.Prune { + t.Error("Expected Prune to be true") + } + + opts2 := &SnapshotCreateOptions{ + Prune: false, + } + if opts2.Prune { + t.Error("Expected Prune to be false") + } +} diff --git a/internal/vaultik/vaultik.go b/internal/vaultik/vaultik.go index 4ce6535..7dce62a 100644 --- a/internal/vaultik/vaultik.go +++ b/internal/vaultik/vaultik.go @@ -129,12 +129,26 @@ func (v *Vaultik) GetFilesystem() afero.Fs { return v.Fs } -// Outputf writes formatted output to stdout for user-facing messages. -// This should be used for all non-log user output. -func (v *Vaultik) Outputf(format string, args ...any) { +// printfStdout writes formatted output to stdout. +func (v *Vaultik) printfStdout(format string, args ...any) { _, _ = fmt.Fprintf(v.Stdout, format, args...) } +// printlnStdout writes a line to stdout. +func (v *Vaultik) printlnStdout(args ...any) { + _, _ = fmt.Fprintln(v.Stdout, args...) +} + +// printfStderr writes formatted output to stderr. +func (v *Vaultik) printfStderr(format string, args ...any) { + _, _ = fmt.Fprintf(v.Stderr, format, args...) +} + +// scanStdin reads a line of input from stdin. +func (v *Vaultik) scanStdin(a ...any) (int, error) { + return fmt.Fscanln(v.Stdin, a...) +} + // TestVaultik wraps a Vaultik with captured stdout/stderr for testing type TestVaultik struct { *Vaultik diff --git a/internal/vaultik/verify.go b/internal/vaultik/verify.go index 6ebcb40..ba68ce3 100644 --- a/internal/vaultik/verify.go +++ b/internal/vaultik/verify.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/hex" "fmt" + "hash" "io" "os" "time" @@ -35,6 +36,19 @@ type VerifyResult struct { ErrorMessage string `json:"error,omitempty"` } +// deepVerifyFailure records a failure in the result and returns it appropriately +func (v *Vaultik) deepVerifyFailure(result *VerifyResult, opts *VerifyOptions, msg string, err error) error { + result.Status = "failed" + result.ErrorMessage = msg + if opts.JSON { + return v.outputVerifyJSON(result) + } + if err != nil { + return err + } + return fmt.Errorf("%s", msg) +} + // RunDeepVerify executes deep verification operation func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { result := &VerifyResult{ @@ -42,89 +56,20 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { Mode: "deep", } - // Check for decryption capability if !v.CanDecrypt() { - result.Status = "failed" - result.ErrorMessage = "VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification" - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification") + return v.deepVerifyFailure(result, opts, + "VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification", + fmt.Errorf("VAULTIK_AGE_SECRET_KEY environment variable not set - required for deep verification")) } - log.Info("Starting snapshot verification", - "snapshot_id", snapshotID, - "mode", "deep", - ) - + log.Info("Starting snapshot verification", "snapshot_id", snapshotID, "mode", "deep") if !opts.JSON { - v.Outputf("Deep verification of snapshot: %s\n\n", snapshotID) + v.printfStdout("Deep verification of snapshot: %s\n\n", snapshotID) } - // Step 1: Download manifest - manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) - log.Info("Downloading manifest", "path", manifestPath) - if !opts.JSON { - v.Outputf("Downloading manifest...\n") - } - - manifestReader, err := v.Storage.Get(v.ctx, manifestPath) + manifest, tempDB, dbBlobs, err := v.loadVerificationData(snapshotID, opts, result) if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to download manifest: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to download manifest: %w", err) - } - defer func() { _ = manifestReader.Close() }() - - // Decompress manifest - manifest, err := snapshot.DecodeManifest(manifestReader) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to decode manifest: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to decode manifest: %w", err) - } - - log.Info("Manifest loaded", - "manifest_blob_count", manifest.BlobCount, - "manifest_total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize)), - ) - if !opts.JSON { - v.Outputf("Manifest loaded: %d blobs (%s)\n", manifest.BlobCount, humanize.Bytes(uint64(manifest.TotalCompressedSize))) - } - - // Step 2: Download and decrypt database (authoritative source) - dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) - log.Info("Downloading encrypted database", "path", dbPath) - if !opts.JSON { - v.Outputf("Downloading and decrypting database...\n") - } - - dbReader, err := v.Storage.Get(v.ctx, dbPath) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to download database: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to download database: %w", err) - } - defer func() { _ = dbReader.Close() }() - - // Decrypt and decompress database - tempDB, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to decrypt database: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to decrypt database: %w", err) + return err } defer func() { if tempDB != nil { @@ -132,17 +77,6 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { } }() - // Step 3: Get authoritative blob list from database - dbBlobs, err := v.getBlobsFromDatabase(snapshotID, tempDB.DB) - if err != nil { - result.Status = "failed" - result.ErrorMessage = fmt.Sprintf("failed to get blobs from database: %v", err) - if opts.JSON { - return v.outputVerifyJSON(result) - } - return fmt.Errorf("failed to get blobs from database: %w", err) - } - result.BlobCount = len(dbBlobs) var totalSize int64 for _, blob := range dbBlobs { @@ -150,54 +84,10 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { } result.TotalSize = totalSize - log.Info("Database loaded", - "db_blob_count", len(dbBlobs), - "db_total_size", humanize.Bytes(uint64(totalSize)), - ) - if !opts.JSON { - v.Outputf("Database loaded: %d blobs (%s)\n", len(dbBlobs), humanize.Bytes(uint64(totalSize))) - v.Outputf("Verifying manifest against database...\n") - } - - // Step 4: Verify manifest matches database - if err := v.verifyManifestAgainstDatabase(manifest, dbBlobs); err != nil { - result.Status = "failed" - result.ErrorMessage = err.Error() - if opts.JSON { - return v.outputVerifyJSON(result) - } + if err := v.runVerificationSteps(manifest, dbBlobs, tempDB, opts, result, totalSize); err != nil { return err } - // Step 5: Verify all blobs exist in S3 (using database as source) - if !opts.JSON { - v.Outputf("Manifest verified.\n") - v.Outputf("Checking blob existence in remote storage...\n") - } - if err := v.verifyBlobExistenceFromDB(dbBlobs); err != nil { - result.Status = "failed" - result.ErrorMessage = err.Error() - if opts.JSON { - return v.outputVerifyJSON(result) - } - return err - } - - // Step 6: Deep verification - download and verify blob contents - if !opts.JSON { - v.Outputf("All blobs exist.\n") - v.Outputf("Downloading and verifying blob contents (%d blobs, %s)...\n", len(dbBlobs), humanize.Bytes(uint64(totalSize))) - } - if err := v.performDeepVerificationFromDB(dbBlobs, tempDB.DB, opts); err != nil { - result.Status = "failed" - result.ErrorMessage = err.Error() - if opts.JSON { - return v.outputVerifyJSON(result) - } - return err - } - - // Success result.Status = "ok" result.Verified = len(dbBlobs) @@ -206,15 +96,111 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error { } log.Info("✓ Verification completed successfully", - "snapshot_id", snapshotID, - "mode", "deep", - "blobs_verified", len(dbBlobs), - ) + "snapshot_id", snapshotID, "mode", "deep", "blobs_verified", len(dbBlobs)) + v.printfStdout("\n✓ Verification completed successfully\n") + v.printfStdout(" Snapshot: %s\n", snapshotID) + v.printfStdout(" Blobs verified: %d\n", len(dbBlobs)) + v.printfStdout(" Total size: %s\n", humanize.Bytes(uint64(totalSize))) - v.Outputf("\n✓ Verification completed successfully\n") - v.Outputf(" Snapshot: %s\n", snapshotID) - v.Outputf(" Blobs verified: %d\n", len(dbBlobs)) - v.Outputf(" Total size: %s\n", humanize.Bytes(uint64(totalSize))) + return nil +} + +// loadVerificationData downloads manifest, database, and blob list for verification +func (v *Vaultik) loadVerificationData(snapshotID string, opts *VerifyOptions, result *VerifyResult) (*snapshot.Manifest, *tempDB, []snapshot.BlobInfo, error) { + // Download manifest + manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID) + log.Info("Downloading manifest", "path", manifestPath) + if !opts.JSON { + v.printfStdout("Downloading manifest...\n") + } + manifestReader, err := v.Storage.Get(v.ctx, manifestPath) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to download manifest: %v", err), + fmt.Errorf("failed to download manifest: %w", err)) + } + defer func() { _ = manifestReader.Close() }() + + manifest, err := snapshot.DecodeManifest(manifestReader) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to decode manifest: %v", err), + fmt.Errorf("failed to decode manifest: %w", err)) + } + + log.Info("Manifest loaded", + "manifest_blob_count", manifest.BlobCount, + "manifest_total_size", humanize.Bytes(uint64(manifest.TotalCompressedSize))) + if !opts.JSON { + v.printfStdout("Manifest loaded: %d blobs (%s)\n", manifest.BlobCount, humanize.Bytes(uint64(manifest.TotalCompressedSize))) + v.printfStdout("Downloading and decrypting database...\n") + } + + // Download and decrypt database + dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID) + log.Info("Downloading encrypted database", "path", dbPath) + dbReader, err := v.Storage.Get(v.ctx, dbPath) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to download database: %v", err), + fmt.Errorf("failed to download database: %w", err)) + } + defer func() { _ = dbReader.Close() }() + + tdb, err := v.decryptAndLoadDatabase(dbReader, v.Config.AgeSecretKey) + if err != nil { + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to decrypt database: %v", err), + fmt.Errorf("failed to decrypt database: %w", err)) + } + + dbBlobs, err := v.getBlobsFromDatabase(snapshotID, tdb.DB) + if err != nil { + _ = tdb.Close() + return nil, nil, nil, v.deepVerifyFailure(result, opts, + fmt.Sprintf("failed to get blobs from database: %v", err), + fmt.Errorf("failed to get blobs from database: %w", err)) + } + + var dbTotalSize int64 + for _, b := range dbBlobs { + dbTotalSize += b.CompressedSize + } + + log.Info("Database loaded", + "db_blob_count", len(dbBlobs), + "db_total_size", humanize.Bytes(uint64(dbTotalSize))) + if !opts.JSON { + v.printfStdout("Database loaded: %d blobs (%s)\n", len(dbBlobs), humanize.Bytes(uint64(dbTotalSize))) + } + + return manifest, tdb, dbBlobs, nil +} + +// runVerificationSteps executes manifest verification, blob existence check, and deep content verification +func (v *Vaultik) runVerificationSteps(manifest *snapshot.Manifest, dbBlobs []snapshot.BlobInfo, tdb *tempDB, opts *VerifyOptions, result *VerifyResult, totalSize int64) error { + if !opts.JSON { + v.printfStdout("Verifying manifest against database...\n") + } + if err := v.verifyManifestAgainstDatabase(manifest, dbBlobs); err != nil { + return v.deepVerifyFailure(result, opts, err.Error(), err) + } + + if !opts.JSON { + v.printfStdout("Manifest verified.\n") + v.printfStdout("Checking blob existence in remote storage...\n") + } + if err := v.verifyBlobExistenceFromDB(dbBlobs); err != nil { + return v.deepVerifyFailure(result, opts, err.Error(), err) + } + + if !opts.JSON { + v.printfStdout("All blobs exist.\n") + v.printfStdout("Downloading and verifying blob contents (%d blobs, %s)...\n", len(dbBlobs), humanize.Bytes(uint64(totalSize))) + } + if err := v.performDeepVerificationFromDB(dbBlobs, tdb.DB, opts); err != nil { + return v.deepVerifyFailure(result, opts, err.Error(), err) + } return nil } @@ -316,7 +302,27 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { } defer decompressor.Close() - // Query blob chunks from database to get offsets and lengths + chunkCount, err := v.verifyBlobChunks(db, blobInfo.Hash, decompressor) + if err != nil { + return err + } + + if err := v.verifyBlobFinalIntegrity(decompressor, blobHasher, blobInfo.Hash); err != nil { + return err + } + + log.Info("Blob verified", + "hash", blobInfo.Hash[:16]+"...", + "chunks", chunkCount, + "size", humanize.Bytes(uint64(blobInfo.CompressedSize)), + ) + + return nil +} + +// verifyBlobChunks queries blob chunks from the database and verifies each chunk's hash +// against the decompressed blob stream +func (v *Vaultik) verifyBlobChunks(db *sql.DB, blobHash string, decompressor io.Reader) (int, error) { query := ` SELECT bc.chunk_hash, bc.offset, bc.length FROM blob_chunks bc @@ -324,9 +330,9 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { WHERE b.blob_hash = ? ORDER BY bc.offset ` - rows, err := db.QueryContext(v.ctx, query, blobInfo.Hash) + rows, err := db.QueryContext(v.ctx, query, blobHash) if err != nil { - return fmt.Errorf("failed to query blob chunks: %w", err) + return 0, fmt.Errorf("failed to query blob chunks: %w", err) } defer func() { _ = rows.Close() }() @@ -339,12 +345,12 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { var chunkHash string var offset, length int64 if err := rows.Scan(&chunkHash, &offset, &length); err != nil { - return fmt.Errorf("failed to scan chunk row: %w", err) + return 0, fmt.Errorf("failed to scan chunk row: %w", err) } // Verify chunk ordering if offset <= lastOffset { - return fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) + return 0, fmt.Errorf("chunks out of order: offset %d after %d", offset, lastOffset) } lastOffset = offset @@ -353,7 +359,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Skip to the correct offset skipBytes := offset - totalRead if _, err := io.CopyN(io.Discard, decompressor, skipBytes); err != nil { - return fmt.Errorf("failed to skip to offset %d: %w", offset, err) + return 0, fmt.Errorf("failed to skip to offset %d: %w", offset, err) } totalRead = offset } @@ -361,7 +367,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Read chunk data chunkData := make([]byte, length) if _, err := io.ReadFull(decompressor, chunkData); err != nil { - return fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) + return 0, fmt.Errorf("failed to read chunk at offset %d: %w", offset, err) } totalRead += length @@ -371,7 +377,7 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { calculatedHash := hex.EncodeToString(hasher.Sum(nil)) if calculatedHash != chunkHash { - return fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", + return 0, fmt.Errorf("chunk hash mismatch at offset %d: calculated %s, expected %s", offset, calculatedHash, chunkHash) } @@ -379,9 +385,15 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { } if err := rows.Err(); err != nil { - return fmt.Errorf("error iterating blob chunks: %w", err) + return 0, fmt.Errorf("error iterating blob chunks: %w", err) } + return chunkCount, nil +} + +// verifyBlobFinalIntegrity checks that no trailing data exists in the decompressed stream +// and that the encrypted blob hash matches the expected value +func (v *Vaultik) verifyBlobFinalIntegrity(decompressor io.Reader, blobHasher hash.Hash, expectedHash string) error { // Verify no remaining data in blob - if chunk list is accurate, blob should be fully consumed remaining, err := io.Copy(io.Discard, decompressor) if err != nil { @@ -393,17 +405,11 @@ func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error { // Verify blob hash matches the encrypted data we downloaded calculatedBlobHash := hex.EncodeToString(blobHasher.Sum(nil)) - if calculatedBlobHash != blobInfo.Hash { + if calculatedBlobHash != expectedHash { return fmt.Errorf("blob hash mismatch: calculated %s, expected %s", - calculatedBlobHash, blobInfo.Hash) + calculatedBlobHash, expectedHash) } - log.Info("Blob verified", - "hash", blobInfo.Hash[:16]+"...", - "chunks", chunkCount, - "size", humanize.Bytes(uint64(blobInfo.CompressedSize)), - ) - return nil } @@ -569,7 +575,7 @@ func (v *Vaultik) performDeepVerificationFromDB(blobs []snapshot.BlobInfo, db *s ) if !opts.JSON { - v.Outputf(" Verified %d/%d blobs (%d remaining) - %s/%s - elapsed %s, eta %s\n", + v.printfStdout(" Verified %d/%d blobs (%d remaining) - %s/%s - elapsed %s, eta %s\n", i+1, len(blobs), remaining, humanize.Bytes(uint64(bytesProcessed)), humanize.Bytes(uint64(totalBytesExpected)),