diff --git a/TODO.md b/TODO.md index c53d963..2c5b561 100644 --- a/TODO.md +++ b/TODO.md @@ -6,22 +6,9 @@ prioritized from most critical (top) to least critical (bottom). ## Code Cleanups -* none of the integration tests should be searching for a binary or trying - to execute another process. the integration tests cannot make another - process or depend on a compiled file, they must do all of their testing in - the current (test) process. - * we shouldn't be passing around a statedir, it should be read from the environment or default. -## CRITICAL SECURITY ISSUES - Must Fix Before 1.0 - -- [ ] **1. Memory security vulnerabilities**: Sensitive data (passwords, - private keys, passphrases) stored as strings are not properly zeroed from - memory after use. Memory dumps or swap files could expose secrets. Found - in crypto.go:107, passphraseunlocker.go:29-48, cli/crypto.go:89,193, - pgpunlocker.go:278, keychainunlocker.go:252,346. - ## HIGH PRIORITY SECURITY ISSUES - [ ] **4. Application crashes on corrupted metadata**: Code panics instead diff --git a/internal/cli/crypto.go b/internal/cli/crypto.go index 560e165..db337a9 100644 --- a/internal/cli/crypto.go +++ b/internal/cli/crypto.go @@ -82,13 +82,15 @@ func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error { return fmt.Errorf("failed to generate age key: %w", err) } - ageSecretKey = identity.String() - - // Store the generated key as a secret using secure buffer - secureBuffer := memguard.NewBufferFromBytes([]byte(ageSecretKey)) + // Store the generated key directly in a secure buffer + identityStr := identity.String() + secureBuffer := memguard.NewBufferFromBytes([]byte(identityStr)) defer secureBuffer.Destroy() - err = vlt.AddSecret(secretName, secureBuffer.Bytes(), false) + // Set ageSecretKey for later use (we need it for encryption) + ageSecretKey = identityStr + + err = vlt.AddSecret(secretName, secureBuffer, false) if err != nil { return fmt.Errorf("failed to store age key: %w", err) } diff --git a/internal/cli/generate.go b/internal/cli/generate.go index 06d1fb8..d2bc79b 100644 --- a/internal/cli/generate.go +++ b/internal/cli/generate.go @@ -7,6 +7,7 @@ import ( "os" "git.eeqj.de/sneak/secret/internal/vault" + "github.com/awnumar/memguard" "github.com/spf13/cobra" "github.com/tyler-smith/go-bip39" ) @@ -136,7 +137,11 @@ func (cli *Instance) GenerateSecret( return err } - if err := vlt.AddSecret(secretName, []byte(secretValue), force); err != nil { + // Protect the generated secret immediately + secretBuffer := memguard.NewBufferFromBytes([]byte(secretValue)) + defer secretBuffer.Destroy() + + if err := vlt.AddSecret(secretName, secretBuffer, force); err != nil { return err } diff --git a/internal/cli/init.go b/internal/cli/init.go index 91e04ee..4181d50 100644 --- a/internal/cli/init.go +++ b/internal/cli/init.go @@ -166,8 +166,11 @@ func (cli *Instance) Init(cmd *cobra.Command) error { } // Encrypt long-term private key to unlocker - ltPrivKeyData := []byte(ltIdentity.String()) - encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyData, unlockerRecipient) + // Use memguard to protect the private key in memory + ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String())) + defer ltPrivKeyBuffer.Destroy() + + encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerRecipient) if err != nil { return fmt.Errorf("failed to encrypt long-term private key: %w", err) } diff --git a/internal/cli/secrets.go b/internal/cli/secrets.go index 00900d4..efcabde 100644 --- a/internal/cli/secrets.go +++ b/internal/cli/secrets.go @@ -8,7 +8,7 @@ import ( "git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/vault" - "github.com/spf13/afero" + "github.com/awnumar/memguard" "github.com/spf13/cobra" ) @@ -103,6 +103,20 @@ func newImportCmd() *cobra.Command { return cmd } +// updateBufferSize updates the buffer size based on usage pattern +func updateBufferSize(currentSize int, sameSize *int) int { + *sameSize++ + const doubleAfterBuffers = 2 + const growthFactor = 2 + if *sameSize >= doubleAfterBuffers { + *sameSize = 0 + + return currentSize * growthFactor + } + + return currentSize +} + // AddSecret adds a secret to the current vault func (cli *Instance) AddSecret(secretName string, force bool) error { secret.Debug("CLI AddSecret starting", "secret_name", secretName, "force", force) @@ -116,24 +130,82 @@ func (cli *Instance) AddSecret(secretName string, force bool) error { secret.Debug("Got current vault", "vault_name", vlt.GetName()) - // Read secret value from stdin - secret.Debug("Reading secret value from stdin") - value, err := io.ReadAll(cli.cmd.InOrStdin()) - if err != nil { - return fmt.Errorf("failed to read secret value: %w", err) + // Read secret value directly into protected buffers + secret.Debug("Reading secret value from stdin into protected buffers") + + const initialSize = 4 * 1024 // 4KB initial buffer + const maxSize = 100 * 1024 * 1024 // 100MB max + + type bufferInfo struct { + buffer *memguard.LockedBuffer + used int } - secret.Debug("Read secret value from stdin", "value_length", len(value)) + var buffers []bufferInfo + defer func() { + for _, b := range buffers { + b.buffer.Destroy() + } + }() - // Remove trailing newline if present - if len(value) > 0 && value[len(value)-1] == '\n' { - value = value[:len(value)-1] - secret.Debug("Removed trailing newline", "new_length", len(value)) + reader := cli.cmd.InOrStdin() + totalSize := 0 + currentBufferSize := initialSize + sameSize := 0 + + for { + // Create a new buffer + buffer := memguard.NewBuffer(currentBufferSize) + n, err := io.ReadFull(reader, buffer.Bytes()) + + if n == 0 { + // No data read, destroy the unused buffer + buffer.Destroy() + } else { + buffers = append(buffers, bufferInfo{buffer: buffer, used: n}) + totalSize += n + + if totalSize > maxSize { + return fmt.Errorf("secret too large: exceeds 100MB limit") + } + + // If we filled the buffer, consider growing for next iteration + if n == currentBufferSize { + currentBufferSize = updateBufferSize(currentBufferSize, &sameSize) + } + } + + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } else if err != nil { + return fmt.Errorf("failed to read secret value: %w", err) + } + } + + // Check for trailing newline in the last buffer + if len(buffers) > 0 && totalSize > 0 { + lastBuffer := &buffers[len(buffers)-1] + if lastBuffer.buffer.Bytes()[lastBuffer.used-1] == '\n' { + lastBuffer.used-- + totalSize-- + } + } + + secret.Debug("Read secret value from stdin", "value_length", totalSize, "buffers", len(buffers)) + + // Combine all buffers into a single protected buffer + valueBuffer := memguard.NewBuffer(totalSize) + defer valueBuffer.Destroy() + + offset := 0 + for _, b := range buffers { + copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used]) + offset += b.used } // Add the secret to the vault - secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", len(value), "force", force) - if err := vlt.AddSecret(secretName, value, force); err != nil { + secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", valueBuffer.Size(), "force", force) + if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil { secret.Debug("vault.AddSecret failed", "error", err) return err @@ -298,14 +370,77 @@ func (cli *Instance) ImportSecret(cmd *cobra.Command, secretName, sourceFile str return err } - // Read secret value from the source file - value, err := afero.ReadFile(cli.fs, sourceFile) + // Read secret value from the source file into protected buffers + file, err := cli.fs.Open(sourceFile) if err != nil { - return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err) + return fmt.Errorf("failed to open file %s: %w", sourceFile, err) + } + defer func() { + if err := file.Close(); err != nil { + secret.Debug("Failed to close file", "error", err) + } + }() + + const initialSize = 4 * 1024 // 4KB initial buffer + const maxSize = 100 * 1024 * 1024 // 100MB max + + type bufferInfo struct { + buffer *memguard.LockedBuffer + used int + } + + var buffers []bufferInfo + defer func() { + for _, b := range buffers { + b.buffer.Destroy() + } + }() + + totalSize := 0 + currentBufferSize := initialSize + sameSize := 0 + + for { + // Create a new buffer + buffer := memguard.NewBuffer(currentBufferSize) + n, err := io.ReadFull(file, buffer.Bytes()) + + if n == 0 { + // No data read, destroy the unused buffer + buffer.Destroy() + } else { + buffers = append(buffers, bufferInfo{buffer: buffer, used: n}) + totalSize += n + + if totalSize > maxSize { + return fmt.Errorf("secret file too large: exceeds 100MB limit") + } + + // If we filled the buffer, consider growing for next iteration + if n == currentBufferSize { + currentBufferSize = updateBufferSize(currentBufferSize, &sameSize) + } + } + + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } else if err != nil { + return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err) + } + } + + // Combine all buffers into a single protected buffer + valueBuffer := memguard.NewBuffer(totalSize) + defer valueBuffer.Destroy() + + offset := 0 + for _, b := range buffers { + copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used]) + offset += b.used } // Store the secret in the vault - if err := vlt.AddSecret(secretName, value, force); err != nil { + if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil { return err } diff --git a/internal/cli/secrets_size_test.go b/internal/cli/secrets_size_test.go new file mode 100644 index 0000000..e996c24 --- /dev/null +++ b/internal/cli/secrets_size_test.go @@ -0,0 +1,425 @@ +package cli + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "path/filepath" + "strings" + "testing" + + "git.eeqj.de/sneak/secret/internal/secret" + "git.eeqj.de/sneak/secret/internal/vault" + "git.eeqj.de/sneak/secret/pkg/agehd" + "github.com/spf13/afero" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestAddSecretVariousSizes tests adding secrets of various sizes through stdin +func TestAddSecretVariousSizes(t *testing.T) { + tests := []struct { + name string + size int + shouldError bool + errorMsg string + }{ + { + name: "1KB secret", + size: 1024, + shouldError: false, + }, + { + name: "10KB secret", + size: 10 * 1024, + shouldError: false, + }, + { + name: "100KB secret", + size: 100 * 1024, + shouldError: false, + }, + { + name: "1MB secret", + size: 1024 * 1024, + shouldError: false, + }, + { + name: "10MB secret", + size: 10 * 1024 * 1024, + shouldError: false, + }, + { + name: "99MB secret", + size: 99 * 1024 * 1024, + shouldError: false, + }, + { + name: "100MB secret minus 1 byte", + size: 100*1024*1024 - 1, + shouldError: false, + }, + { + name: "101MB secret - should fail", + size: 101 * 1024 * 1024, + shouldError: true, + errorMsg: "secret too large: exceeds 100MB limit", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up test environment + fs := afero.NewMemMapFs() + stateDir := "/test/state" + + // Set test mnemonic + t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") + + // Create vault + vaultName := "test-vault" + _, err := vault.CreateVault(fs, stateDir, vaultName) + require.NoError(t, err) + + // Set current vault + currentVaultPath := filepath.Join(stateDir, "currentvault") + vaultPath := filepath.Join(stateDir, "vaults.d", vaultName) + err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600) + require.NoError(t, err) + + // Get vault and set up long-term key + vlt, err := vault.GetCurrentVault(fs, stateDir) + require.NoError(t, err) + + ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0) + require.NoError(t, err) + vlt.Unlock(ltIdentity) + + // Generate test data of specified size + testData := make([]byte, tt.size) + _, err = rand.Read(testData) + require.NoError(t, err) + + // Add newline that will be stripped + testDataWithNewline := append(testData, '\n') + + // Create fake stdin + stdin := bytes.NewReader(testDataWithNewline) + + // Create command with fake stdin + cmd := &cobra.Command{} + cmd.SetIn(stdin) + + // Create CLI instance + cli := NewCLIInstance() + cli.fs = fs + cli.stateDir = stateDir + cli.cmd = cmd + + // Test adding the secret + secretName := fmt.Sprintf("test-secret-%d", tt.size) + err = cli.AddSecret(secretName, false) + + if tt.shouldError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + + // Verify the secret was stored correctly + retrievedValue, err := vlt.GetSecret(secretName) + require.NoError(t, err) + assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original (without newline)") + } + }) + } +} + +// TestImportSecretVariousSizes tests importing secrets of various sizes from files +func TestImportSecretVariousSizes(t *testing.T) { + tests := []struct { + name string + size int + shouldError bool + errorMsg string + }{ + { + name: "1KB file", + size: 1024, + shouldError: false, + }, + { + name: "10KB file", + size: 10 * 1024, + shouldError: false, + }, + { + name: "100KB file", + size: 100 * 1024, + shouldError: false, + }, + { + name: "1MB file", + size: 1024 * 1024, + shouldError: false, + }, + { + name: "10MB file", + size: 10 * 1024 * 1024, + shouldError: false, + }, + { + name: "99MB file", + size: 99 * 1024 * 1024, + shouldError: false, + }, + { + name: "100MB file", + size: 100 * 1024 * 1024, + shouldError: false, + }, + { + name: "101MB file - should fail", + size: 101 * 1024 * 1024, + shouldError: true, + errorMsg: "secret file too large: exceeds 100MB limit", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up test environment + fs := afero.NewMemMapFs() + stateDir := "/test/state" + + // Set test mnemonic + t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") + + // Create vault + vaultName := "test-vault" + _, err := vault.CreateVault(fs, stateDir, vaultName) + require.NoError(t, err) + + // Set current vault + currentVaultPath := filepath.Join(stateDir, "currentvault") + vaultPath := filepath.Join(stateDir, "vaults.d", vaultName) + err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600) + require.NoError(t, err) + + // Get vault and set up long-term key + vlt, err := vault.GetCurrentVault(fs, stateDir) + require.NoError(t, err) + + ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0) + require.NoError(t, err) + vlt.Unlock(ltIdentity) + + // Generate test data of specified size + testData := make([]byte, tt.size) + _, err = rand.Read(testData) + require.NoError(t, err) + + // Write test data to file + testFile := fmt.Sprintf("/test/secret-%d.bin", tt.size) + err = afero.WriteFile(fs, testFile, testData, 0o600) + require.NoError(t, err) + + // Create command + cmd := &cobra.Command{} + + // Create CLI instance + cli := NewCLIInstance() + cli.fs = fs + cli.stateDir = stateDir + + // Test importing the secret + secretName := fmt.Sprintf("imported-secret-%d", tt.size) + err = cli.ImportSecret(cmd, secretName, testFile, false) + + if tt.shouldError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + + // Verify the secret was stored correctly + retrievedValue, err := vlt.GetSecret(secretName) + require.NoError(t, err) + assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original") + } + }) + } +} + +// TestAddSecretBufferGrowth tests that our buffer growth strategy works correctly +func TestAddSecretBufferGrowth(t *testing.T) { + // Test various sizes that should trigger buffer growth + sizes := []int{ + 1, // Single byte + 100, // Small + 4095, // Just under initial 4KB + 4096, // Exactly 4KB + 4097, // Just over 4KB + 8191, // Just under 8KB (first double) + 8192, // Exactly 8KB + 8193, // Just over 8KB + 12288, // 12KB (should trigger second double) + 16384, // 16KB + 32768, // 32KB (after more doublings) + 65536, // 64KB + 131072, // 128KB + 524288, // 512KB + 1048576, // 1MB + 2097152, // 2MB + } + + for _, size := range sizes { + t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { + // Set up test environment + fs := afero.NewMemMapFs() + stateDir := "/test/state" + + // Set test mnemonic + t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") + + // Create vault + vaultName := "test-vault" + _, err := vault.CreateVault(fs, stateDir, vaultName) + require.NoError(t, err) + + // Set current vault + currentVaultPath := filepath.Join(stateDir, "currentvault") + vaultPath := filepath.Join(stateDir, "vaults.d", vaultName) + err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600) + require.NoError(t, err) + + // Get vault and set up long-term key + vlt, err := vault.GetCurrentVault(fs, stateDir) + require.NoError(t, err) + + ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0) + require.NoError(t, err) + vlt.Unlock(ltIdentity) + + // Create test data of exactly the specified size + // Use a pattern that's easy to verify + testData := make([]byte, size) + for i := range testData { + testData[i] = byte(i % 256) + } + + // Create fake stdin without newline + stdin := bytes.NewReader(testData) + + // Create command with fake stdin + cmd := &cobra.Command{} + cmd.SetIn(stdin) + + // Create CLI instance + cli := NewCLIInstance() + cli.fs = fs + cli.stateDir = stateDir + cli.cmd = cmd + + // Test adding the secret + secretName := fmt.Sprintf("buffer-test-%d", size) + err = cli.AddSecret(secretName, false) + require.NoError(t, err) + + // Verify the secret was stored correctly + retrievedValue, err := vlt.GetSecret(secretName) + require.NoError(t, err) + assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original exactly") + }) + } +} + +// TestAddSecretStreamingBehavior tests that we handle streaming input correctly +func TestAddSecretStreamingBehavior(t *testing.T) { + // Set up test environment + fs := afero.NewMemMapFs() + stateDir := "/test/state" + + // Set test mnemonic + t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") + + // Create vault + vaultName := "test-vault" + _, err := vault.CreateVault(fs, stateDir, vaultName) + require.NoError(t, err) + + // Set current vault + currentVaultPath := filepath.Join(stateDir, "currentvault") + vaultPath := filepath.Join(stateDir, "vaults.d", vaultName) + err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600) + require.NoError(t, err) + + // Get vault and set up long-term key + vlt, err := vault.GetCurrentVault(fs, stateDir) + require.NoError(t, err) + + ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0) + require.NoError(t, err) + vlt.Unlock(ltIdentity) + + // Create a custom reader that simulates slow streaming input + // This will help verify our buffer handling works correctly with partial reads + testData := []byte(strings.Repeat("Hello, World! ", 1000)) // ~14KB + slowReader := &slowReader{ + data: testData, + chunkSize: 1000, // Read 1KB at a time + } + + // Create command with slow reader as stdin + cmd := &cobra.Command{} + cmd.SetIn(slowReader) + + // Create CLI instance + cli := NewCLIInstance() + cli.fs = fs + cli.stateDir = stateDir + cli.cmd = cmd + + // Test adding the secret + err = cli.AddSecret("streaming-test", false) + require.NoError(t, err) + + // Verify the secret was stored correctly + retrievedValue, err := vlt.GetSecret("streaming-test") + require.NoError(t, err) + assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original") +} + +// slowReader simulates a reader that returns data in small chunks +type slowReader struct { + data []byte + offset int + chunkSize int +} + +func (r *slowReader) Read(p []byte) (n int, err error) { + if r.offset >= len(r.data) { + return 0, io.EOF + } + + // Read at most chunkSize bytes + remaining := len(r.data) - r.offset + toRead := r.chunkSize + if toRead > remaining { + toRead = remaining + } + if toRead > len(p) { + toRead = len(p) + } + + n = copy(p, r.data[r.offset:r.offset+toRead]) + r.offset += n + + if r.offset >= len(r.data) { + err = io.EOF + } + + return n, err +} \ No newline at end of file diff --git a/internal/cli/version_test.go b/internal/cli/version_test.go index 2b937a4..2bf687d 100644 --- a/internal/cli/version_test.go +++ b/internal/cli/version_test.go @@ -26,11 +26,21 @@ import ( "git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/pkg/agehd" + "github.com/awnumar/memguard" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// Helper function to add a secret to vault with proper buffer protection +func addTestSecret(t *testing.T, vlt *vault.Vault, name string, value []byte, force bool) { + t.Helper() + buffer := memguard.NewBufferFromBytes(value) + defer buffer.Destroy() + err := vlt.AddSecret(name, buffer, force) + require.NoError(t, err) +} + // Helper function to set up a vault with long-term key func setupTestVault(t *testing.T, fs afero.Fs, stateDir string) { // Set mnemonic for testing @@ -70,13 +80,11 @@ func TestListVersionsCommand(t *testing.T) { vlt, err := vault.GetCurrentVault(fs, stateDir) require.NoError(t, err) - err = vlt.AddSecret("test/secret", []byte("version-1"), false) - require.NoError(t, err) + addTestSecret(t, vlt, "test/secret", []byte("version-1"), false) time.Sleep(10 * time.Millisecond) - err = vlt.AddSecret("test/secret", []byte("version-2"), true) - require.NoError(t, err) + addTestSecret(t, vlt, "test/secret", []byte("version-2"), true) // Create a command for output capture cmd := newRootCmd() @@ -144,13 +152,11 @@ func TestPromoteVersionCommand(t *testing.T) { vlt, err := vault.GetCurrentVault(fs, stateDir) require.NoError(t, err) - err = vlt.AddSecret("test/secret", []byte("version-1"), false) - require.NoError(t, err) + addTestSecret(t, vlt, "test/secret", []byte("version-1"), false) time.Sleep(10 * time.Millisecond) - err = vlt.AddSecret("test/secret", []byte("version-2"), true) - require.NoError(t, err) + addTestSecret(t, vlt, "test/secret", []byte("version-2"), true) // Get versions vaultDir, _ := vlt.GetDirectory() @@ -201,8 +207,7 @@ func TestPromoteNonExistentVersion(t *testing.T) { vlt, err := vault.GetCurrentVault(fs, stateDir) require.NoError(t, err) - err = vlt.AddSecret("test/secret", []byte("value"), false) - require.NoError(t, err) + addTestSecret(t, vlt, "test/secret", []byte("value"), false) // Create a command for output capture cmd := newRootCmd() @@ -228,13 +233,11 @@ func TestGetSecretWithVersion(t *testing.T) { vlt, err := vault.GetCurrentVault(fs, stateDir) require.NoError(t, err) - err = vlt.AddSecret("test/secret", []byte("version-1"), false) - require.NoError(t, err) + addTestSecret(t, vlt, "test/secret", []byte("version-1"), false) time.Sleep(10 * time.Millisecond) - err = vlt.AddSecret("test/secret", []byte("version-2"), true) - require.NoError(t, err) + addTestSecret(t, vlt, "test/secret", []byte("version-2"), true) // Get versions vaultDir, _ := vlt.GetDirectory() diff --git a/internal/secret/crypto.go b/internal/secret/crypto.go index 5fcf3e2..2233a51 100644 --- a/internal/secret/crypto.go +++ b/internal/secret/crypto.go @@ -13,8 +13,13 @@ import ( ) // EncryptToRecipient encrypts data to a recipient using age -func EncryptToRecipient(data []byte, recipient age.Recipient) ([]byte, error) { - Debug("EncryptToRecipient starting", "data_length", len(data)) +// The data parameter should be a LockedBuffer for secure memory handling +func EncryptToRecipient(data *memguard.LockedBuffer, recipient age.Recipient) ([]byte, error) { + if data == nil { + return nil, fmt.Errorf("data buffer is nil") + } + + Debug("EncryptToRecipient starting", "data_length", data.Size()) var buf bytes.Buffer Debug("Creating age encryptor") @@ -27,7 +32,7 @@ func EncryptToRecipient(data []byte, recipient age.Recipient) ([]byte, error) { Debug("Created age encryptor successfully") Debug("Writing data to encryptor") - if _, err := w.Write(data); err != nil { + if _, err := w.Write(data.Bytes()); err != nil { Debug("Failed to write data to encryptor", "error", err) return nil, fmt.Errorf("failed to write data: %w", err) @@ -77,7 +82,11 @@ func EncryptWithPassphrase(data []byte, passphrase *memguard.LockedBuffer) ([]by return nil, fmt.Errorf("failed to create scrypt recipient: %w", err) } - return EncryptToRecipient(data, recipient) + // Create a secure buffer for the data + dataBuffer := memguard.NewBufferFromBytes(data) + defer dataBuffer.Destroy() + + return EncryptToRecipient(dataBuffer, recipient) } // DecryptWithPassphrase decrypts data using a passphrase with age's scrypt-based decryption @@ -138,4 +147,3 @@ func ReadPassphrase(prompt string) (*memguard.LockedBuffer, error) { return secureBuffer, nil } - diff --git a/internal/secret/keychainunlocker.go b/internal/secret/keychainunlocker.go index 1f48bf6..d6c9eb9 100644 --- a/internal/secret/keychainunlocker.go +++ b/internal/secret/keychainunlocker.go @@ -106,7 +106,7 @@ func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) { // Create secure buffer for the keychain passphrase passphraseBuffer := memguard.NewBufferFromBytes([]byte(keychainData.AgePrivKeyPassphrase)) defer passphraseBuffer.Destroy() - + agePrivKeyData, err := DecryptWithPassphrase(encryptedAgePrivKeyData, passphraseBuffer) if err != nil { Debug("Failed to decrypt age private key with keychain passphrase", "error", err, "unlocker_id", k.GetID()) @@ -245,7 +245,8 @@ func generateKeychainUnlockerName(vaultName string) (string, error) { } // getLongTermPrivateKey retrieves the long-term private key either from environment or current unlocker -func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) ([]byte, error) { +// Returns a LockedBuffer to ensure the private key is protected in memory +func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) (*memguard.LockedBuffer, error) { // Check if mnemonic is available in environment variable envMnemonic := os.Getenv(EnvMnemonic) if envMnemonic != "" { @@ -255,7 +256,8 @@ func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) ([]byte, error) { return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err) } - return []byte(ltIdentity.String()), nil + // Return the private key in a secure buffer + return memguard.NewBufferFromBytes([]byte(ltIdentity.String())), nil } // Get the vault to access current unlocker @@ -304,7 +306,8 @@ func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) ([]byte, error) { return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err) } - return ltPrivKeyData, nil + // Return the decrypted key in a secure buffer + return memguard.NewBufferFromBytes(ltPrivKeyData), nil } // CreateKeychainUnlocker creates a new keychain unlocker and stores it in the vault @@ -361,7 +364,7 @@ func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, er agePrivKeyStr := ageIdentity.String() agePrivKeyBuffer := memguard.NewBufferFromBytes([]byte(agePrivKeyStr)) defer agePrivKeyBuffer.Destroy() - + passphraseBuffer := memguard.NewBufferFromBytes([]byte(agePrivKeyPassphrase)) defer passphraseBuffer.Destroy() @@ -380,6 +383,7 @@ func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, er if err != nil { return nil, err } + defer ltPrivKeyData.Destroy() // Step 6: Encrypt long-term private key to the new age unlocker encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient()) diff --git a/internal/secret/passphrase_test.go b/internal/secret/passphrase_test.go index d58257b..36e56d3 100644 --- a/internal/secret/passphrase_test.go +++ b/internal/secret/passphrase_test.go @@ -113,8 +113,9 @@ func TestPassphraseUnlockerWithRealFS(t *testing.T) { t.Fatalf("Failed to parse recipient: %v", err) } - ltPrivKeyData := []byte(ltIdentity.String()) - encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyData, recipient) + ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String())) + defer ltPrivKeyBuffer.Destroy() + encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, recipient) if err != nil { t.Fatalf("Failed to encrypt long-term private key: %v", err) } diff --git a/internal/secret/passphraseunlocker.go b/internal/secret/passphraseunlocker.go index 37a269b..de1f882 100644 --- a/internal/secret/passphraseunlocker.go +++ b/internal/secret/passphraseunlocker.go @@ -36,7 +36,7 @@ func (p *PassphraseUnlocker) getPassphrase() (*memguard.LockedBuffer, error) { Debug("Using passphrase from environment", "unlocker_id", p.GetID()) // Convert to secure buffer secureBuffer := memguard.NewBufferFromBytes([]byte(passphraseStr)) - + return secureBuffer, nil } @@ -45,7 +45,7 @@ func (p *PassphraseUnlocker) getPassphrase() (*memguard.LockedBuffer, error) { secureBuffer, err := ReadPassphrase("Enter unlock passphrase: ") if err != nil { Debug("Failed to read passphrase", "error", err, "unlocker_id", p.GetID()) - + return nil, fmt.Errorf("failed to read passphrase: %w", err) } @@ -173,7 +173,11 @@ func NewPassphraseUnlocker(fs afero.Fs, directory string, metadata UnlockerMetad // CreatePassphraseUnlocker creates a new passphrase-protected unlocker // The passphrase must be provided as a LockedBuffer for security -func CreatePassphraseUnlocker(fs afero.Fs, stateDir string, passphrase *memguard.LockedBuffer) (*PassphraseUnlocker, error) { +func CreatePassphraseUnlocker( + fs afero.Fs, + stateDir string, + passphrase *memguard.LockedBuffer, +) (*PassphraseUnlocker, error) { // Get current vault currentVault, err := GetCurrentVault(fs, stateDir) if err != nil { diff --git a/internal/secret/pgpunlocker.go b/internal/secret/pgpunlocker.go index a938b12..d252477 100644 --- a/internal/secret/pgpunlocker.go +++ b/internal/secret/pgpunlocker.go @@ -12,6 +12,7 @@ import ( "time" "filippo.io/age" + "github.com/awnumar/memguard" "github.com/spf13/afero" ) @@ -233,6 +234,7 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc if err != nil { return nil, err } + defer ltPrivKeyData.Destroy() // Step 7: Encrypt long-term private key to the new age unlocker encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient()) @@ -247,8 +249,11 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc } // Step 8: Encrypt age private key to the GPG key ID - agePrivateKeyBytes := []byte(ageIdentity.String()) - encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBytes, gpgKeyID) + // Use memguard to protect the private key in memory + agePrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(ageIdentity.String())) + defer agePrivateKeyBuffer.Destroy() + + encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBuffer.Bytes(), gpgKeyID) if err != nil { return nil, fmt.Errorf("failed to encrypt age private key with GPG: %w", err) } diff --git a/internal/secret/secret.go b/internal/secret/secret.go index 66ce9ff..a6ea870 100644 --- a/internal/secret/secret.go +++ b/internal/secret/secret.go @@ -18,7 +18,7 @@ import ( // VaultInterface defines the interface that vault implementations must satisfy type VaultInterface interface { GetDirectory() (string, error) - AddSecret(name string, value []byte, force bool) error + AddSecret(name string, value *memguard.LockedBuffer, force bool) error GetName() string GetFilesystem() afero.Fs GetCurrentUnlocker() (Unlocker, error) @@ -72,7 +72,12 @@ func (s *Secret) Save(value []byte, force bool) error { slog.Bool("force", force), ) - err := s.vault.AddSecret(s.Name, value, force) + // Create a secure buffer for the value - note that the caller + // should ideally pass a LockedBuffer directly to vault.AddSecret + valueBuffer := memguard.NewBufferFromBytes(value) + defer valueBuffer.Destroy() + + err := s.vault.AddSecret(s.Name, valueBuffer, force) if err != nil { Debug("Failed to save secret", "error", err, "secret_name", s.Name) diff --git a/internal/secret/secret_test.go b/internal/secret/secret_test.go index 742e3bb..390f58a 100644 --- a/internal/secret/secret_test.go +++ b/internal/secret/secret_test.go @@ -26,7 +26,7 @@ func (m *MockVault) GetDirectory() (string, error) { return m.directory, nil } -func (m *MockVault) AddSecret(name string, value []byte, _ bool) error { +func (m *MockVault) AddSecret(name string, value *memguard.LockedBuffer, _ bool) error { // Create secret directory with proper storage name conversion storageName := strings.ReplaceAll(name, "/", "%") secretDir := filepath.Join(m.directory, "secrets.d", storageName) @@ -75,7 +75,7 @@ func (m *MockVault) AddSecret(name string, value []byte, _ bool) error { return err } - // Encrypt value to version's public key + // Encrypt value to version's public key (value is already a LockedBuffer) encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient()) if err != nil { return err @@ -88,7 +88,9 @@ func (m *MockVault) AddSecret(name string, value []byte, _ bool) error { } // Encrypt version private key to long-term public key - encryptedPrivKey, err := EncryptToRecipient([]byte(versionIdentity.String()), ltIdentity.Recipient()) + versionPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String())) + defer versionPrivKeyBuffer.Destroy() + encryptedPrivKey, err := EncryptToRecipient(versionPrivKeyBuffer, ltIdentity.Recipient()) if err != nil { return err } @@ -180,9 +182,13 @@ func TestPerSecretKeyFunctionality(t *testing.T) { secretName := "test-secret" secretValue := []byte("this is a test secret value") + // Create a secure buffer for the test value + valueBuffer := memguard.NewBufferFromBytes(secretValue) + defer valueBuffer.Destroy() + // Test AddSecret t.Run("AddSecret", func(t *testing.T) { - err := vault.AddSecret(secretName, secretValue, false) + err := vault.AddSecret(secretName, valueBuffer, false) if err != nil { t.Fatalf("AddSecret failed: %v", err) } diff --git a/internal/secret/version.go b/internal/secret/version.go index f3dfd2e..9036b2d 100644 --- a/internal/secret/version.go +++ b/internal/secret/version.go @@ -11,6 +11,7 @@ import ( "time" "filippo.io/age" + "github.com/awnumar/memguard" "github.com/oklog/ulid/v2" "github.com/spf13/afero" ) @@ -120,11 +121,15 @@ func GenerateVersionName(fs afero.Fs, secretDir string) (string, error) { } // Save saves the version metadata and value -func (sv *Version) Save(value []byte) error { +func (sv *Version) Save(value *memguard.LockedBuffer) error { + if value == nil { + return fmt.Errorf("value buffer is nil") + } + DebugWith("Saving secret version", slog.String("secret_name", sv.SecretName), slog.String("version", sv.Version), - slog.Int("value_length", len(value)), + slog.Int("value_length", value.Size()), ) fs := sv.vault.GetFilesystem() @@ -146,7 +151,9 @@ func (sv *Version) Save(value []byte) error { } versionPublicKey := versionIdentity.Recipient().String() - versionPrivateKey := versionIdentity.String() + // Store private key in memguard buffer immediately + versionPrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String())) + defer versionPrivateKeyBuffer.Destroy() DebugWith("Generated version keypair", slog.String("version", sv.Version), @@ -202,7 +209,7 @@ func (sv *Version) Save(value []byte) error { // Step 6: Encrypt the version's private key to the long-term public key Debug("Encrypting version private key to long-term public key", "version", sv.Version) - encryptedPrivKey, err := EncryptToRecipient([]byte(versionPrivateKey), ltRecipient) + encryptedPrivKey, err := EncryptToRecipient(versionPrivateKeyBuffer, ltRecipient) if err != nil { Debug("Failed to encrypt version private key", "error", err, "version", sv.Version) @@ -228,7 +235,10 @@ func (sv *Version) Save(value []byte) error { } // Encrypt metadata to the version's public key - encryptedMetadata, err := EncryptToRecipient(metadataBytes, versionIdentity.Recipient()) + metadataBuffer := memguard.NewBufferFromBytes(metadataBytes) + defer metadataBuffer.Destroy() + + encryptedMetadata, err := EncryptToRecipient(metadataBuffer, versionIdentity.Recipient()) if err != nil { Debug("Failed to encrypt version metadata", "error", err, "version", sv.Version) diff --git a/internal/secret/version_test.go b/internal/secret/version_test.go index 182e703..3ca95d3 100644 --- a/internal/secret/version_test.go +++ b/internal/secret/version_test.go @@ -59,7 +59,7 @@ func (m *MockVersionVault) GetDirectory() (string, error) { return filepath.Join(m.stateDir, "vaults.d", m.Name), nil } -func (m *MockVersionVault) AddSecret(_ string, _ []byte, _ bool) error { +func (m *MockVersionVault) AddSecret(_ string, _ *memguard.LockedBuffer, _ bool) error { return fmt.Errorf("not implemented in mock") } @@ -165,7 +165,9 @@ func TestSecretVersionSave(t *testing.T) { sv := NewVersion(vault, "test/secret", "20231215.001") testValue := []byte("test-secret-value") - err = sv.Save(testValue) + testBuffer := memguard.NewBufferFromBytes(testValue) + defer testBuffer.Destroy() + err = sv.Save(testBuffer) require.NoError(t, err) // Verify files were created @@ -203,7 +205,9 @@ func TestSecretVersionLoadMetadata(t *testing.T) { sv.Metadata.NotBefore = &epochPlusOne sv.Metadata.NotAfter = &now - err = sv.Save([]byte("test-value")) + testBuffer := memguard.NewBufferFromBytes([]byte("test-value")) + defer testBuffer.Destroy() + err = sv.Save(testBuffer) require.NoError(t, err) // Create new version object and load metadata @@ -242,15 +246,19 @@ func TestSecretVersionGetValue(t *testing.T) { // Create and save a version sv := NewVersion(vault, "test/secret", "20231215.001") originalValue := []byte("test-secret-value-12345") + expectedValue := make([]byte, len(originalValue)) + copy(expectedValue, originalValue) - err = sv.Save(originalValue) + originalBuffer := memguard.NewBufferFromBytes(originalValue) + defer originalBuffer.Destroy() + err = sv.Save(originalBuffer) require.NoError(t, err) // Retrieve the value retrievedValue, err := sv.GetValue(ltIdentity) require.NoError(t, err) - assert.Equal(t, originalValue, retrievedValue) + assert.Equal(t, expectedValue, retrievedValue) } func TestListVersions(t *testing.T) { diff --git a/internal/vault/integration_test.go b/internal/vault/integration_test.go index 393a88b..1ea8122 100644 --- a/internal/vault/integration_test.go +++ b/internal/vault/integration_test.go @@ -8,6 +8,7 @@ import ( "git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/pkg/agehd" + "github.com/awnumar/memguard" "github.com/spf13/afero" ) @@ -107,8 +108,13 @@ func TestVaultWithRealFilesystem(t *testing.T) { // Create a secret with a deeply nested path deepPath := "api/credentials/production/database/primary" secretValue := []byte("supersecretdbpassword") + expectedValue := make([]byte, len(secretValue)) + copy(expectedValue, secretValue) - err = vlt.AddSecret(deepPath, secretValue, false) + secretBuffer := memguard.NewBufferFromBytes(secretValue) + defer secretBuffer.Destroy() + + err = vlt.AddSecret(deepPath, secretBuffer, false) if err != nil { t.Fatalf("Failed to add secret with deep path: %v", err) } @@ -137,9 +143,9 @@ func TestVaultWithRealFilesystem(t *testing.T) { t.Fatalf("Failed to retrieve deep path secret: %v", err) } - if string(retrievedValue) != string(secretValue) { + if string(retrievedValue) != string(expectedValue) { t.Errorf("Retrieved value doesn't match. Expected %q, got %q", - string(secretValue), string(retrievedValue)) + string(expectedValue), string(retrievedValue)) } }) @@ -368,7 +374,11 @@ func TestVaultWithRealFilesystem(t *testing.T) { // Add a secret to vault1 secretName := "test-secret" secretValue := []byte("secret in vault1") - if err := vault1.AddSecret(secretName, secretValue, false); err != nil { + + secretBuffer := memguard.NewBufferFromBytes(secretValue) + defer secretBuffer.Destroy() + + if err := vault1.AddSecret(secretName, secretBuffer, false); err != nil { t.Fatalf("Failed to add secret to vault1: %v", err) } diff --git a/internal/vault/integration_version_test.go b/internal/vault/integration_version_test.go index 3043f2d..ec83858 100644 --- a/internal/vault/integration_version_test.go +++ b/internal/vault/integration_version_test.go @@ -29,11 +29,21 @@ import ( "git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/pkg/agehd" + "github.com/awnumar/memguard" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// Helper function to add a secret to vault with proper buffer protection +func addTestSecret(t *testing.T, vault *Vault, name string, value []byte, force bool) { + t.Helper() + buffer := memguard.NewBufferFromBytes(value) + defer buffer.Destroy() + err := vault.AddSecret(name, buffer, force) + require.NoError(t, err) +} + // TestVersionIntegrationWorkflow tests the complete version workflow func TestVersionIntegrationWorkflow(t *testing.T) { fs := afero.NewMemMapFs() @@ -66,8 +76,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) { // Step 1: Create initial version t.Run("create_initial_version", func(t *testing.T) { - err := vault.AddSecret(secretName, []byte("version-1-data"), false) - require.NoError(t, err) + addTestSecret(t, vault, secretName, []byte("version-1-data"), false) // Verify secret can be retrieved value, err := vault.GetSecret(secretName) @@ -108,8 +117,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) { firstVersionName = versions[0] // Create second version - err = vault.AddSecret(secretName, []byte("version-2-data"), true) - require.NoError(t, err) + addTestSecret(t, vault, secretName, []byte("version-2-data"), true) // Verify new value is current value, err := vault.GetSecret(secretName) @@ -142,8 +150,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) { t.Run("create_third_version", func(t *testing.T) { time.Sleep(10 * time.Millisecond) - err := vault.AddSecret(secretName, []byte("version-3-data"), true) - require.NoError(t, err) + addTestSecret(t, vault, secretName, []byte("version-3-data"), true) // Verify we now have three versions secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test") @@ -214,8 +221,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) { secretDir := filepath.Join(vaultDir, "secrets.d", "limit%test", "versions") // Create 998 versions (we already have one from the first AddSecret) - err := vault.AddSecret(limitSecretName, []byte("initial"), false) - require.NoError(t, err) + addTestSecret(t, vault, limitSecretName, []byte("initial"), false) // Get today's date for consistent version names today := time.Now().Format("20060102") @@ -255,7 +261,9 @@ func TestVersionIntegrationWorkflow(t *testing.T) { assert.Error(t, err) // Try to add secret without force when it exists - err = vault.AddSecret(secretName, []byte("should-fail"), false) + failBuffer := memguard.NewBufferFromBytes([]byte("should-fail")) + defer failBuffer.Destroy() + err = vault.AddSecret(secretName, failBuffer, false) assert.Error(t, err) assert.Contains(t, err.Error(), "already exists") }) @@ -272,8 +280,7 @@ func TestVersionConcurrency(t *testing.T) { secretName := "concurrent/test" // Create initial version - err := vault.AddSecret(secretName, []byte("initial"), false) - require.NoError(t, err) + addTestSecret(t, vault, secretName, []byte("initial"), false) // Test concurrent reads t.Run("concurrent_reads", func(t *testing.T) { @@ -326,8 +333,10 @@ func TestVersionCompatibility(t *testing.T) { // Create old-style encrypted value directly in secret directory testValue := []byte("legacy-value") + testValueBuffer := memguard.NewBufferFromBytes(testValue) + defer testValueBuffer.Destroy() ltRecipient := ltIdentity.Recipient() - encrypted, err := secret.EncryptToRecipient(testValue, ltRecipient) + encrypted, err := secret.EncryptToRecipient(testValueBuffer, ltRecipient) require.NoError(t, err) valuePath := filepath.Join(secretDir, "value.age") diff --git a/internal/vault/secrets.go b/internal/vault/secrets.go index 4300094..8bc2805 100644 --- a/internal/vault/secrets.go +++ b/internal/vault/secrets.go @@ -11,6 +11,7 @@ import ( "filippo.io/age" "git.eeqj.de/sneak/secret/internal/secret" + "github.com/awnumar/memguard" "github.com/spf13/afero" ) @@ -98,11 +99,15 @@ func isValidSecretName(name string) bool { } // AddSecret adds a secret to this vault -func (v *Vault) AddSecret(name string, value []byte, force bool) error { +func (v *Vault) AddSecret(name string, value *memguard.LockedBuffer, force bool) error { + if value == nil { + return fmt.Errorf("value buffer is nil") + } + secret.DebugWith("Adding secret to vault", slog.String("vault_name", v.Name), slog.String("secret_name", name), - slog.Int("value_length", len(value)), + slog.Int("value_length", value.Size()), slog.Bool("force", force), ) @@ -195,7 +200,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error { // We'll update the previous version's notAfter after we save the new version } - // Save the new version + // Save the new version - pass the LockedBuffer directly if err := newVersion.Save(value); err != nil { secret.Debug("Failed to save new version", "error", err, "version", versionName) @@ -272,7 +277,10 @@ func updateVersionMetadata(fs afero.Fs, version *secret.Version, ltIdentity *age } // Encrypt metadata to the version's public key - encryptedMetadata, err := secret.EncryptToRecipient(metadataBytes, versionIdentity.Recipient()) + metadataBuffer := memguard.NewBufferFromBytes(metadataBytes) + defer metadataBuffer.Destroy() + + encryptedMetadata, err := secret.EncryptToRecipient(metadataBuffer, versionIdentity.Recipient()) if err != nil { return fmt.Errorf("failed to encrypt version metadata: %w", err) } diff --git a/internal/vault/secrets_version_test.go b/internal/vault/secrets_version_test.go index 0bf20b4..c846d33 100644 --- a/internal/vault/secrets_version_test.go +++ b/internal/vault/secrets_version_test.go @@ -24,11 +24,21 @@ import ( "git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/pkg/agehd" + "github.com/awnumar/memguard" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// Helper function to add a secret to vault with proper buffer protection +func addTestSecretToVault(t *testing.T, vault *Vault, name string, value []byte, force bool) { + t.Helper() + buffer := memguard.NewBufferFromBytes(value) + defer buffer.Destroy() + err := vault.AddSecret(name, buffer, force) + require.NoError(t, err) +} + // Helper function to create a vault with long-term key set up func createTestVaultWithKey(t *testing.T, fs afero.Fs, stateDir, vaultName string) *Vault { // Set mnemonic for testing @@ -65,9 +75,10 @@ func TestVaultAddSecretCreatesVersion(t *testing.T) { // Add a secret secretName := "test/secret" secretValue := []byte("initial-value") + expectedValue := make([]byte, len(secretValue)) + copy(expectedValue, secretValue) - err := vault.AddSecret(secretName, secretValue, false) - require.NoError(t, err) + addTestSecretToVault(t, vault, secretName, secretValue, false) // Check that version directory was created vaultDir, _ := vault.GetDirectory() @@ -88,7 +99,7 @@ func TestVaultAddSecretCreatesVersion(t *testing.T) { // Get the secret value retrievedValue, err := vault.GetSecret(secretName) require.NoError(t, err) - assert.Equal(t, secretValue, retrievedValue) + assert.Equal(t, expectedValue, retrievedValue) } func TestVaultAddSecretMultipleVersions(t *testing.T) { @@ -101,17 +112,17 @@ func TestVaultAddSecretMultipleVersions(t *testing.T) { secretName := "test/secret" // Add first version - err := vault.AddSecret(secretName, []byte("version-1"), false) - require.NoError(t, err) + addTestSecretToVault(t, vault, secretName, []byte("version-1"), false) // Try to add again without force - should fail - err = vault.AddSecret(secretName, []byte("version-2"), false) + failBuffer := memguard.NewBufferFromBytes([]byte("version-2")) + defer failBuffer.Destroy() + err := vault.AddSecret(secretName, failBuffer, false) assert.Error(t, err) assert.Contains(t, err.Error(), "already exists") // Add with force - should create new version - err = vault.AddSecret(secretName, []byte("version-2"), true) - require.NoError(t, err) + addTestSecretToVault(t, vault, secretName, []byte("version-2"), true) // Check that we have two versions vaultDir, _ := vault.GetDirectory() @@ -136,14 +147,12 @@ func TestVaultGetSecretVersion(t *testing.T) { secretName := "test/secret" // Add multiple versions - err := vault.AddSecret(secretName, []byte("version-1"), false) - require.NoError(t, err) + addTestSecretToVault(t, vault, secretName, []byte("version-1"), false) // Small delay to ensure different version names time.Sleep(10 * time.Millisecond) - err = vault.AddSecret(secretName, []byte("version-2"), true) - require.NoError(t, err) + addTestSecretToVault(t, vault, secretName, []byte("version-2"), true) // Get versions list vaultDir, _ := vault.GetDirectory() @@ -185,7 +194,9 @@ func TestVaultVersionTimestamps(t *testing.T) { // Add first version beforeFirst := time.Now() - err = vault.AddSecret(secretName, []byte("version-1"), false) + v1Buffer := memguard.NewBufferFromBytes([]byte("version-1")) + defer v1Buffer.Destroy() + err = vault.AddSecret(secretName, v1Buffer, false) require.NoError(t, err) afterFirst := time.Now() @@ -212,8 +223,7 @@ func TestVaultVersionTimestamps(t *testing.T) { // Add second version time.Sleep(10 * time.Millisecond) beforeSecond := time.Now() - err = vault.AddSecret(secretName, []byte("version-2"), true) - require.NoError(t, err) + addTestSecretToVault(t, vault, secretName, []byte("version-2"), true) afterSecond := time.Now() // Get updated versions @@ -249,11 +259,10 @@ func TestVaultGetNonExistentVersion(t *testing.T) { vault := createTestVaultWithKey(t, fs, stateDir, "test") // Add a secret - err := vault.AddSecret("test/secret", []byte("value"), false) - require.NoError(t, err) + addTestSecretToVault(t, vault, "test/secret", []byte("value"), false) // Try to get non-existent version - _, err = vault.GetSecretVersion("test/secret", "20991231.999") + _, err := vault.GetSecretVersion("test/secret", "20991231.999") assert.Error(t, err) assert.Contains(t, err.Error(), "not found") } @@ -281,7 +290,9 @@ func TestUpdateVersionMetadata(t *testing.T) { version.Metadata.NotAfter = nil // Save version - err = version.Save([]byte("test-value")) + testBuffer := memguard.NewBufferFromBytes([]byte("test-value")) + defer testBuffer.Destroy() + err = version.Save(testBuffer) require.NoError(t, err) // Update metadata diff --git a/internal/vault/unlockers.go b/internal/vault/unlockers.go index 2deab9e..11e6e4e 100644 --- a/internal/vault/unlockers.go +++ b/internal/vault/unlockers.go @@ -346,10 +346,7 @@ func (v *Vault) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*se // Encrypt private key with passphrase privKeyStr := unlockerIdentity.String() - privKeyBuffer := memguard.NewBufferFromBytes([]byte(privKeyStr)) - defer privKeyBuffer.Destroy() - - encryptedPrivKey, err := secret.EncryptWithPassphrase(privKeyBuffer.Bytes(), passphrase) + encryptedPrivKey, err := secret.EncryptWithPassphrase([]byte(privKeyStr), passphrase) if err != nil { return nil, fmt.Errorf("failed to encrypt unlocker private key: %w", err) } @@ -385,8 +382,10 @@ func (v *Vault) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*se return nil, fmt.Errorf("failed to get long-term key: %w", err) } - ltPrivKey := []byte(ltIdentity.String()) - encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKey, unlockerIdentity.Recipient()) + ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String())) + defer ltPrivKeyBuffer.Destroy() + + encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerIdentity.Recipient()) if err != nil { return nil, fmt.Errorf("failed to encrypt long-term private key: %w", err) } diff --git a/internal/vault/vault_test.go b/internal/vault/vault_test.go index be6ed9e..bed6752 100644 --- a/internal/vault/vault_test.go +++ b/internal/vault/vault_test.go @@ -122,8 +122,13 @@ func TestVaultOperations(t *testing.T) { // Now add a secret secretName := "test/secret" secretValue := []byte("test-secret-value") + expectedValue := make([]byte, len(secretValue)) + copy(expectedValue, secretValue) - err = vlt.AddSecret(secretName, secretValue, false) + secretBuffer := memguard.NewBufferFromBytes(secretValue) + defer secretBuffer.Destroy() + + err = vlt.AddSecret(secretName, secretBuffer, false) if err != nil { t.Fatalf("Failed to add secret: %v", err) } @@ -152,8 +157,8 @@ func TestVaultOperations(t *testing.T) { t.Fatalf("Failed to get secret: %v", err) } - if string(retrievedValue) != string(secretValue) { - t.Errorf("Expected secret value '%s', got '%s'", string(secretValue), string(retrievedValue)) + if string(retrievedValue) != string(expectedValue) { + t.Errorf("Expected secret value '%s', got '%s'", string(expectedValue), string(retrievedValue)) } })