1 Commits

Author SHA1 Message Date
user
a524eb415e fix: verify blob hash after download and decryption (closes #5)
Add double-SHA-256 hash verification of decrypted plaintext in
FetchAndDecryptBlob. This ensures blob integrity during restore
operations by comparing the computed hash against the expected
blob hash before returning data to the caller.

Includes test for both correct hash (passes) and mismatched hash
(returns error).
2026-02-20 02:22:44 -08:00
6 changed files with 231 additions and 134 deletions

View File

@@ -0,0 +1,64 @@
package blobgen
import (
"bytes"
"crypto/rand"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testRecipient is a static age recipient for tests.
const testRecipient = "age1cplgrwj77ta54dnmydvvmzn64ltk83ankxl5sww04mrtmu62kv3s89gmvv"
// TestCompressStreamNoDoubleClose is a regression test for issue #28.
// It verifies that CompressStream does not panic or return an error due to
// double-closing the underlying blobgen.Writer. Before the fix in PR #33,
// the explicit Close() on the happy path combined with defer Close() would
// cause a double close.
func TestCompressStreamNoDoubleClose(t *testing.T) {
input := []byte("regression test data for issue #28 double-close fix")
var buf bytes.Buffer
written, hash, err := CompressStream(&buf, bytes.NewReader(input), 3, []string{testRecipient})
require.NoError(t, err, "CompressStream should not return an error")
assert.True(t, written > 0, "expected bytes written > 0")
assert.NotEmpty(t, hash, "expected non-empty hash")
assert.True(t, buf.Len() > 0, "expected non-empty output")
}
// TestCompressStreamLargeInput exercises CompressStream with a larger payload
// to ensure no double-close issues surface under heavier I/O.
func TestCompressStreamLargeInput(t *testing.T) {
data := make([]byte, 512*1024) // 512 KB
_, err := rand.Read(data)
require.NoError(t, err)
var buf bytes.Buffer
written, hash, err := CompressStream(&buf, bytes.NewReader(data), 3, []string{testRecipient})
require.NoError(t, err)
assert.True(t, written > 0)
assert.NotEmpty(t, hash)
}
// TestCompressStreamEmptyInput verifies CompressStream handles empty input
// without double-close issues.
func TestCompressStreamEmptyInput(t *testing.T) {
var buf bytes.Buffer
_, hash, err := CompressStream(&buf, strings.NewReader(""), 3, []string{testRecipient})
require.NoError(t, err)
assert.NotEmpty(t, hash)
}
// TestCompressDataNoDoubleClose mirrors the stream test for CompressData,
// ensuring the explicit Close + error-path Close pattern is also safe.
func TestCompressDataNoDoubleClose(t *testing.T) {
input := []byte("CompressData regression test for double-close")
result, err := CompressData(input, 3, []string{testRecipient})
require.NoError(t, err)
assert.True(t, result.CompressedSize > 0)
assert.True(t, result.UncompressedSize == int64(len(input)))
assert.NotEmpty(t, result.SHA256)
}

View File

@@ -0,0 +1,86 @@
package vaultik_test
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"strings"
"testing"
"filippo.io/age"
"git.eeqj.de/sneak/vaultik/internal/blobgen"
"git.eeqj.de/sneak/vaultik/internal/vaultik"
)
// TestFetchAndDecryptBlobVerifiesHash verifies that FetchAndDecryptBlob checks
// the double-SHA-256 hash of the decrypted plaintext against the expected blob hash.
func TestFetchAndDecryptBlobVerifiesHash(t *testing.T) {
identity, err := age.GenerateX25519Identity()
if err != nil {
t.Fatalf("generating identity: %v", err)
}
// Create test data and encrypt it using blobgen.Writer
plaintext := []byte("hello world test data for blob hash verification")
var encBuf bytes.Buffer
writer, err := blobgen.NewWriter(&encBuf, 1, []string{identity.Recipient().String()})
if err != nil {
t.Fatalf("creating blobgen writer: %v", err)
}
if _, err := writer.Write(plaintext); err != nil {
t.Fatalf("writing plaintext: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("closing writer: %v", err)
}
encryptedData := encBuf.Bytes()
// Compute correct double-SHA-256 hash of the plaintext (matches blobgen.Writer.Sum256)
firstHash := sha256.Sum256(plaintext)
secondHash := sha256.Sum256(firstHash[:])
correctHash := hex.EncodeToString(secondHash[:])
// Verify our hash matches what blobgen.Writer produces
writerHash := hex.EncodeToString(writer.Sum256())
if correctHash != writerHash {
t.Fatalf("hash computation mismatch: manual=%s, writer=%s", correctHash, writerHash)
}
// Set up mock storage with the blob at the correct path
mockStorage := NewMockStorer()
blobPath := "blobs/" + correctHash[:2] + "/" + correctHash[2:4] + "/" + correctHash
mockStorage.mu.Lock()
mockStorage.data[blobPath] = encryptedData
mockStorage.mu.Unlock()
tv := vaultik.NewForTesting(mockStorage)
ctx := context.Background()
t.Run("correct hash succeeds", func(t *testing.T) {
result, err := tv.FetchAndDecryptBlob(ctx, correctHash, int64(len(encryptedData)), identity)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if !bytes.Equal(result.Data, plaintext) {
t.Fatalf("decrypted data mismatch: got %q, want %q", result.Data, plaintext)
}
})
t.Run("wrong hash fails", func(t *testing.T) {
// Use a fake hash that doesn't match the actual plaintext
fakeHash := strings.Repeat("ab", 32) // 64 hex chars
fakePath := "blobs/" + fakeHash[:2] + "/" + fakeHash[2:4] + "/" + fakeHash
mockStorage.mu.Lock()
mockStorage.data[fakePath] = encryptedData
mockStorage.mu.Unlock()
_, err := tv.FetchAndDecryptBlob(ctx, fakeHash, int64(len(encryptedData)), identity)
if err == nil {
t.Fatal("expected error for mismatched hash, got nil")
}
if !strings.Contains(err.Error(), "hash mismatch") {
t.Fatalf("expected hash mismatch error, got: %v", err)
}
})
}

View File

@@ -0,0 +1,69 @@
package vaultik
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"filippo.io/age"
"git.eeqj.de/sneak/vaultik/internal/blobgen"
)
// FetchAndDecryptBlobResult holds the result of fetching and decrypting a blob.
type FetchAndDecryptBlobResult struct {
Data []byte
}
// FetchAndDecryptBlob downloads a blob, decrypts and decompresses it, then
// verifies that the double-SHA-256 hash of the plaintext matches blobHash.
// This ensures blob integrity end-to-end during restore operations.
func (v *Vaultik) FetchAndDecryptBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) (*FetchAndDecryptBlobResult, error) {
rc, _, err := v.FetchBlob(ctx, blobHash, expectedSize)
if err != nil {
return nil, err
}
defer func() { _ = rc.Close() }()
reader, err := blobgen.NewReader(rc, identity)
if err != nil {
return nil, fmt.Errorf("creating blob reader: %w", err)
}
defer func() { _ = reader.Close() }()
data, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("reading blob data: %w", err)
}
// Verify blob integrity: compute double-SHA-256 of the decrypted plaintext
// and compare to the expected blob hash. The blob hash is SHA256(SHA256(plaintext))
// as produced by blobgen.Writer.Sum256().
firstHash := sha256.Sum256(data)
secondHash := sha256.Sum256(firstHash[:])
actualHashHex := hex.EncodeToString(secondHash[:])
if actualHashHex != blobHash {
return nil, fmt.Errorf("blob hash mismatch: expected %s, got %s", blobHash[:16], actualHashHex[:16])
}
return &FetchAndDecryptBlobResult{Data: data}, nil
}
// FetchBlob downloads a blob and returns a reader for the encrypted data.
func (v *Vaultik) FetchBlob(ctx context.Context, blobHash string, expectedSize int64) (io.ReadCloser, int64, error) {
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)
rc, err := v.Storage.Get(ctx, blobPath)
if err != nil {
return nil, 0, fmt.Errorf("downloading blob %s: %w", blobHash[:16], err)
}
info, err := v.Storage.Stat(ctx, blobPath)
if err != nil {
_ = rc.Close()
return nil, 0, fmt.Errorf("stat blob %s: %w", blobHash[:16], err)
}
return rc, info.Size, nil
}

View File

@@ -7,9 +7,6 @@ import (
"sync"
)
// defaultMaxBlobCacheBytes is the default maximum size of the disk blob cache (10 GB).
const defaultMaxBlobCacheBytes = 10 << 30 // 10 GiB
// blobDiskCacheEntry tracks a cached blob on disk.
type blobDiskCacheEntry struct {
key string

View File

@@ -109,83 +109,32 @@ func (v *Vaultik) Restore(opts *RestoreOptions) error {
// Step 5: Restore files
result := &RestoreResult{}
blobCache, err := newBlobDiskCache(defaultMaxBlobCacheBytes)
blobCache, err := newBlobDiskCache(4 * v.Config.BlobSizeLimit.Int64())
if err != nil {
return fmt.Errorf("creating blob cache: %w", err)
}
defer func() { _ = blobCache.Close() }()
// Calculate total bytes for progress bar
var totalBytes int64
for _, file := range files {
totalBytes += file.Size
}
_, _ = fmt.Fprintf(v.Stdout, "Restoring %d files (%s)...\n",
len(files),
humanize.Bytes(uint64(totalBytes)),
)
// Create progress bar if stderr is a terminal
isTTY := isTerminal(v.Stderr)
var bar *progressbar.ProgressBar
if isTTY {
bar = progressbar.NewOptions64(
totalBytes,
progressbar.OptionSetDescription("Restoring"),
progressbar.OptionSetWriter(v.Stderr),
progressbar.OptionShowBytes(true),
progressbar.OptionShowCount(),
progressbar.OptionSetWidth(40),
progressbar.OptionThrottle(100*time.Millisecond),
progressbar.OptionOnCompletion(func() {
v.printlnStderr()
}),
progressbar.OptionSetRenderBlankState(true),
)
}
filesProcessed := 0
for _, file := range files {
for i, file := range files {
if v.ctx.Err() != nil {
return v.ctx.Err()
}
if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil {
log.Error("Failed to restore file", "path", file.Path, "error", err)
filesProcessed++
// Update progress bar even on failure
if bar != nil {
_ = bar.Add64(file.Size)
}
// Periodic structured log for non-terminal contexts (headless/CI)
if !isTTY && filesProcessed%100 == 0 {
log.Info("Restore progress",
"files", fmt.Sprintf("%d/%d", filesProcessed, len(files)),
"bytes_restored", humanize.Bytes(uint64(result.BytesRestored)),
)
}
// Continue with other files
continue
}
filesProcessed++
// Update progress bar
if bar != nil {
_ = bar.Add64(file.Size)
}
// Periodic structured log for non-terminal contexts (headless/CI)
if !isTTY && (filesProcessed%100 == 0 || filesProcessed == len(files)) {
// Progress logging
if (i+1)%100 == 0 || i+1 == len(files) {
log.Info("Restore progress",
"files", fmt.Sprintf("%d/%d", filesProcessed, len(files)),
"bytes_restored", humanize.Bytes(uint64(result.BytesRestored)),
"files", fmt.Sprintf("%d/%d", i+1, len(files)),
"bytes", humanize.Bytes(uint64(result.BytesRestored)),
)
}
}
if bar != nil {
_ = bar.Finish()
}
result.Duration = time.Since(startTime)
log.Info("Restore complete",
@@ -530,53 +479,6 @@ func (v *Vaultik) restoreRegularFile(
return nil
}
// BlobFetchResult holds the result of fetching and decrypting a blob.
type BlobFetchResult struct {
Data []byte
CompressedSize int64
}
// FetchAndDecryptBlob downloads a blob from storage, decrypts and decompresses it.
func (v *Vaultik) FetchAndDecryptBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) (*BlobFetchResult, error) {
// Construct blob path with sharding
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)
reader, err := v.Storage.Get(ctx, blobPath)
if err != nil {
return nil, fmt.Errorf("downloading blob: %w", err)
}
defer func() { _ = reader.Close() }()
// Read encrypted data
encryptedData, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("reading blob data: %w", err)
}
// Decrypt and decompress
blobReader, err := blobgen.NewReader(bytes.NewReader(encryptedData), identity)
if err != nil {
return nil, fmt.Errorf("creating decryption reader: %w", err)
}
defer func() { _ = blobReader.Close() }()
data, err := io.ReadAll(blobReader)
if err != nil {
return nil, fmt.Errorf("decrypting blob: %w", err)
}
log.Debug("Downloaded and decrypted blob",
"hash", blobHash[:16],
"encrypted_size", humanize.Bytes(uint64(len(encryptedData))),
"decrypted_size", humanize.Bytes(uint64(len(data))),
)
return &BlobFetchResult{
Data: data,
CompressedSize: int64(len(encryptedData)),
}, nil
}
// downloadBlob downloads and decrypts a blob
func (v *Vaultik) downloadBlob(ctx context.Context, blobHash string, expectedSize int64, identity age.Identity) ([]byte, error) {
result, err := v.FetchAndDecryptBlob(ctx, blobHash, expectedSize, identity)
@@ -622,7 +524,7 @@ func (v *Vaultik) verifyRestoredFiles(
// Create progress bar if output is a terminal
var bar *progressbar.ProgressBar
if isTerminal(v.Stderr) {
if isTerminal() {
bar = progressbar.NewOptions64(
totalBytes,
progressbar.OptionSetDescription("Verifying"),
@@ -730,11 +632,7 @@ func (v *Vaultik) verifyFile(
return bytesVerified, nil
}
// isTerminal returns true if the given writer is connected to a terminal.
// Returns false if the writer does not expose a file descriptor (e.g. in tests).
func isTerminal(w io.Writer) bool {
if f, ok := w.(*os.File); ok {
return term.IsTerminal(int(f.Fd()))
}
return false
// isTerminal returns true if stdout is a terminal
func isTerminal() bool {
return term.IsTerminal(int(os.Stdout.Fd()))
}

View File

@@ -129,7 +129,7 @@ func (v *Vaultik) GetFilesystem() afero.Fs {
return v.Fs
}
// printfStdout writes formatted output to stdout for user-facing messages.
// printfStdout writes formatted output to stdout.
func (v *Vaultik) printfStdout(format string, args ...any) {
_, _ = fmt.Fprintf(v.Stdout, format, args...)
}
@@ -139,28 +139,11 @@ func (v *Vaultik) printlnStdout(args ...any) {
_, _ = fmt.Fprintln(v.Stdout, args...)
}
// FetchBlob downloads a blob from storage and returns a reader for the encrypted data.
func (v *Vaultik) FetchBlob(ctx context.Context, blobHash string, expectedSize int64) (io.ReadCloser, int64, error) {
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)
reader, err := v.Storage.Get(ctx, blobPath)
if err != nil {
return nil, 0, fmt.Errorf("downloading blob: %w", err)
}
return reader, expectedSize, nil
}
// printfStderr writes formatted output to stderr.
func (v *Vaultik) printfStderr(format string, args ...any) {
_, _ = fmt.Fprintf(v.Stderr, format, args...)
}
// printlnStderr writes a line to stderr.
func (v *Vaultik) printlnStderr(args ...any) {
_, _ = fmt.Fprintln(v.Stderr, args...)
}
// scanStdin reads a line of input from stdin.
func (v *Vaultik) scanStdin(a ...any) (int, error) {
return fmt.Fscanln(v.Stdin, a...)