diff --git a/internal/vaultik/remove_snapshot_test.go b/internal/vaultik/remove_snapshot_test.go new file mode 100644 index 0000000..33c4d07 --- /dev/null +++ b/internal/vaultik/remove_snapshot_test.go @@ -0,0 +1,351 @@ +package vaultik_test + +import ( + "bytes" + "context" + "io" + "strings" + "sync" + "testing" + + "git.eeqj.de/sneak/vaultik/internal/log" + "git.eeqj.de/sneak/vaultik/internal/snapshot" + "git.eeqj.de/sneak/vaultik/internal/storage" + "git.eeqj.de/sneak/vaultik/internal/vaultik" + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testStorer implements storage.Storer for testing +type testStorer struct { + mu sync.Mutex + data map[string][]byte +} + +func newTestStorer() *testStorer { + return &testStorer{ + data: make(map[string][]byte), + } +} + +func (s *testStorer) Put(ctx context.Context, key string, reader io.Reader) error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := io.ReadAll(reader) + if err != nil { + return err + } + s.data[key] = data + return nil +} + +func (s *testStorer) PutWithProgress(ctx context.Context, key string, reader io.Reader, size int64, progress storage.ProgressCallback) error { + return s.Put(ctx, key, reader) +} + +func (s *testStorer) Get(ctx context.Context, key string) (io.ReadCloser, error) { + s.mu.Lock() + defer s.mu.Unlock() + + data, exists := s.data[key] + if !exists { + return nil, storage.ErrNotFound + } + return io.NopCloser(bytes.NewReader(data)), nil +} + +func (s *testStorer) Stat(ctx context.Context, key string) (*storage.ObjectInfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + + data, exists := s.data[key] + if !exists { + return nil, storage.ErrNotFound + } + return &storage.ObjectInfo{ + Key: key, + Size: int64(len(data)), + }, nil +} + +func (s *testStorer) Delete(ctx context.Context, key string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.data, key) + return nil +} + +func (s *testStorer) List(ctx context.Context, prefix string) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var keys []string + for key := range s.data { + if prefix == "" || strings.HasPrefix(key, prefix) { + keys = append(keys, key) + } + } + return keys, nil +} + +func (s *testStorer) ListStream(ctx context.Context, prefix string) <-chan storage.ObjectInfo { + ch := make(chan storage.ObjectInfo) + + go func() { + defer close(ch) + s.mu.Lock() + defer s.mu.Unlock() + + for key, data := range s.data { + if prefix == "" || strings.HasPrefix(key, prefix) { + ch <- storage.ObjectInfo{ + Key: key, + Size: int64(len(data)), + } + } + } + }() + + return ch +} + +func (s *testStorer) hasKey(key string) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, exists := s.data[key] + return exists +} + +func (s *testStorer) keyCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.data) +} + +func (s *testStorer) Info() storage.StorageInfo { + return storage.StorageInfo{ + Type: "test", + Location: "memory", + } +} + +// addManifest creates a compressed manifest in storage +func addManifest(t *testing.T, store *testStorer, snapshotID string, blobHashes []string) { + t.Helper() + + blobs := make([]snapshot.BlobInfo, len(blobHashes)) + for i, hash := range blobHashes { + blobs[i] = snapshot.BlobInfo{ + Hash: hash, + CompressedSize: 1000, + } + } + + manifest := &snapshot.Manifest{ + SnapshotID: snapshotID, + BlobCount: len(blobs), + Blobs: blobs, + } + + data, err := snapshot.EncodeManifest(manifest, 3) + require.NoError(t, err) + + key := "metadata/" + snapshotID + "/manifest.json.zst" + err = store.Put(context.Background(), key, bytes.NewReader(data)) + require.NoError(t, err) +} + +// addBlob adds a fake blob to storage +func addBlob(t *testing.T, store *testStorer, hash string) { + t.Helper() + + // Create zstd compressed data + var buf bytes.Buffer + writer, _ := zstd.NewWriter(&buf) + _, _ = writer.Write([]byte("blob data")) + _ = writer.Close() + + key := "blobs/" + hash[:2] + "/" + hash[2:4] + "/" + hash + err := store.Put(context.Background(), key, bytes.NewReader(buf.Bytes())) + require.NoError(t, err) +} + +// ============================================================================ +// Unit Tests for RemoveSnapshot +// ============================================================================ + +func TestRemoveSnapshot_LocalOnly(t *testing.T) { + log.Initialize(log.Config{}) + + store := newTestStorer() + + blobA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + addManifest(t, store, "snapshot-001", []string{blobA}) + addBlob(t, store, blobA) + + tv := vaultik.NewForTesting(store) + + opts := &vaultik.RemoveOptions{Force: true} + result, err := tv.RemoveSnapshot("snapshot-001", opts) + + require.NoError(t, err) + assert.Equal(t, "snapshot-001", result.SnapshotID) + assert.False(t, result.RemoteRemoved) + + // Blobs should NOT be deleted (that's what prune is for) + assert.True(t, store.hasKey("blobs/aa/aa/"+blobA)) + // Remote metadata should NOT be deleted (no --remote flag) + assert.True(t, store.hasKey("metadata/snapshot-001/manifest.json.zst")) + + // Verify output + assert.Contains(t, tv.Stdout.String(), "Removed snapshot 'snapshot-001' from local database") +} + +func TestRemoveSnapshot_WithRemote(t *testing.T) { + log.Initialize(log.Config{}) + + store := newTestStorer() + + blobA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + addManifest(t, store, "snapshot-001", []string{blobA}) + addBlob(t, store, blobA) + + tv := vaultik.NewForTesting(store) + + opts := &vaultik.RemoveOptions{Force: true, Remote: true} + result, err := tv.RemoveSnapshot("snapshot-001", opts) + + require.NoError(t, err) + assert.Equal(t, "snapshot-001", result.SnapshotID) + assert.True(t, result.RemoteRemoved) + + // Blobs should NOT be deleted + assert.True(t, store.hasKey("blobs/aa/aa/"+blobA)) + // Remote metadata SHOULD be deleted + assert.False(t, store.hasKey("metadata/snapshot-001/manifest.json.zst")) + + // Verify output mentions prune + assert.Contains(t, tv.Stdout.String(), "Removed snapshot 'snapshot-001' from local database") + assert.Contains(t, tv.Stdout.String(), "Removed snapshot metadata from remote storage") + assert.Contains(t, tv.Stdout.String(), "Run 'vaultik prune' to remove orphaned blobs") +} + +func TestRemoveSnapshot_DryRun(t *testing.T) { + log.Initialize(log.Config{}) + + store := newTestStorer() + + blobA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + addManifest(t, store, "snapshot-001", []string{blobA}) + addBlob(t, store, blobA) + + initialCount := store.keyCount() + + tv := vaultik.NewForTesting(store) + + opts := &vaultik.RemoveOptions{Force: true, DryRun: true, Remote: true} + result, err := tv.RemoveSnapshot("snapshot-001", opts) + + require.NoError(t, err) + assert.True(t, result.DryRun) + + // Nothing should be deleted + assert.Equal(t, initialCount, store.keyCount()) + assert.True(t, store.hasKey("blobs/aa/aa/"+blobA)) + assert.True(t, store.hasKey("metadata/snapshot-001/manifest.json.zst")) + + // Verify dry run message + assert.Contains(t, tv.Stdout.String(), "[Dry run - no changes made]") +} + +func TestRemoveAllSnapshots_RequiresForce(t *testing.T) { + log.Initialize(log.Config{}) + + store := newTestStorer() + addManifest(t, store, "snapshot-001", []string{}) + addManifest(t, store, "snapshot-002", []string{}) + + tv := vaultik.NewForTesting(store) + + opts := &vaultik.RemoveOptions{All: true} // No Force + _, err := tv.RemoveAllSnapshots(opts) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "--all requires --force") +} + +func TestRemoveAllSnapshots_WithForce(t *testing.T) { + log.Initialize(log.Config{}) + + store := newTestStorer() + + blobA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + addManifest(t, store, "snapshot-001", []string{blobA}) + addManifest(t, store, "snapshot-002", []string{blobA}) + addBlob(t, store, blobA) + + tv := vaultik.NewForTesting(store) + + opts := &vaultik.RemoveOptions{All: true, Force: true, Remote: true} + result, err := tv.RemoveAllSnapshots(opts) + + require.NoError(t, err) + assert.Len(t, result.SnapshotsRemoved, 2) + assert.True(t, result.RemoteRemoved) + + // Blobs should NOT be deleted + assert.True(t, store.hasKey("blobs/aa/aa/"+blobA)) + // Remote metadata SHOULD be deleted + assert.False(t, store.hasKey("metadata/snapshot-001/manifest.json.zst")) + assert.False(t, store.hasKey("metadata/snapshot-002/manifest.json.zst")) + + // Verify output + assert.Contains(t, tv.Stdout.String(), "Removed 2 snapshot(s)") + assert.Contains(t, tv.Stdout.String(), "Run 'vaultik prune' to remove orphaned blobs") +} + +func TestRemoveAllSnapshots_DryRun(t *testing.T) { + log.Initialize(log.Config{}) + + store := newTestStorer() + addManifest(t, store, "snapshot-001", []string{}) + addManifest(t, store, "snapshot-002", []string{}) + + initialCount := store.keyCount() + + tv := vaultik.NewForTesting(store) + + opts := &vaultik.RemoveOptions{All: true, Force: true, DryRun: true} + result, err := tv.RemoveAllSnapshots(opts) + + require.NoError(t, err) + assert.True(t, result.DryRun) + assert.Len(t, result.SnapshotsRemoved, 2) + + // Nothing should be deleted + assert.Equal(t, initialCount, store.keyCount()) + + // Verify dry run message + assert.Contains(t, tv.Stdout.String(), "[Dry run - no changes made]") +} + +func TestRemoveAllSnapshots_NoSnapshots(t *testing.T) { + log.Initialize(log.Config{}) + + store := newTestStorer() + // No snapshots added + + tv := vaultik.NewForTesting(store) + + opts := &vaultik.RemoveOptions{All: true, Force: true} + result, err := tv.RemoveAllSnapshots(opts) + + require.NoError(t, err) + assert.Len(t, result.SnapshotsRemoved, 0) + + // Verify output + assert.Contains(t, tv.Stdout.String(), "No snapshots found") +} diff --git a/internal/vaultik/verify_test.go b/internal/vaultik/verify_test.go new file mode 100644 index 0000000..7e03056 --- /dev/null +++ b/internal/vaultik/verify_test.go @@ -0,0 +1,92 @@ +package vaultik_test + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "io" + "testing" + + "git.eeqj.de/sneak/vaultik/internal/crypto" + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestTeeReaderWithDecryption tests that TeeReader correctly hashes all encrypted +// bytes when streaming through age decryption and zstd decompression. +// This validates the verification path: hash encrypted blob -> decrypt -> decompress. +func TestTeeReaderWithDecryption(t *testing.T) { + // Test data - use random data that doesn't compress well (5MB) + testData := make([]byte, 5*1024*1024) + _, err := rand.Read(testData) + require.NoError(t, err) + + // Compress the data + var compressedBuf bytes.Buffer + compressor, err := zstd.NewWriter(&compressedBuf, zstd.WithEncoderLevel(zstd.SpeedDefault)) + require.NoError(t, err) + _, err = compressor.Write(testData) + require.NoError(t, err) + err = compressor.Close() + require.NoError(t, err) + + // Encrypt the compressed data + testRecipient := "age1cplgrwj77ta54dnmydvvmzn64ltk83ankxl5sww04mrtmu62kv3s89gmvv" + testSecretKey := "AGE-SECRET-KEY-1C77PYNTHXSHNNC6EYR2W52UWYXACXA5JT00J9CCW9986M3XY87PSGP89AQ" + + encryptor, err := crypto.NewEncryptor([]string{testRecipient}) + require.NoError(t, err) + + var encryptedBuf bytes.Buffer + err = encryptor.EncryptStream(&encryptedBuf, bytes.NewReader(compressedBuf.Bytes())) + require.NoError(t, err) + + encryptedData := encryptedBuf.Bytes() + + // Calculate the expected hash of the encrypted data directly + expectedHash := sha256.Sum256(encryptedData) + expectedHashStr := hex.EncodeToString(expectedHash[:]) + + t.Logf("Encrypted data size: %d bytes", len(encryptedData)) + t.Logf("Expected hash: %s", expectedHashStr) + + // Now simulate what verifyBlob does: use TeeReader to hash while decrypting + decryptor, err := crypto.NewDecryptor(testSecretKey) + require.NoError(t, err) + + // Create hasher and tee reader + hasher := sha256.New() + reader := bytes.NewReader(encryptedData) + teeReader := io.TeeReader(reader, hasher) + + // Decrypt through the tee reader + decryptedReader, err := decryptor.DecryptStream(teeReader) + require.NoError(t, err) + + // Decompress + decompressor, err := zstd.NewReader(decryptedReader) + require.NoError(t, err) + defer decompressor.Close() + + // Read all decompressed data (simulating chunk verification) + decompressedData, err := io.ReadAll(decompressor) + require.NoError(t, err) + + // Verify we got the original data back + assert.Equal(t, testData, decompressedData, "Decompressed data should match original") + + // Drain remaining decompressed data (should be 0) + remaining, err := io.Copy(io.Discard, decompressor) + require.NoError(t, err) + assert.Equal(t, int64(0), remaining, "No remaining decompressed data") + + // Calculate hash from tee reader + calculatedHashStr := hex.EncodeToString(hasher.Sum(nil)) + t.Logf("Calculated hash (before drain): %s", calculatedHashStr) + + // Verify the hash matches the direct hash of encrypted data + assert.Equal(t, expectedHashStr, calculatedHashStr, + "Hash calculated via TeeReader should match direct hash of encrypted data") +}