uses protected memory buffers now for all secrets in ram

This commit is contained in:
2025-07-15 08:32:33 +02:00
parent d3ca006886
commit 7596049828
22 changed files with 786 additions and 133 deletions

View File

@@ -82,13 +82,15 @@ func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error {
return fmt.Errorf("failed to generate age key: %w", err)
}
ageSecretKey = identity.String()
// Store the generated key as a secret using secure buffer
secureBuffer := memguard.NewBufferFromBytes([]byte(ageSecretKey))
// Store the generated key directly in a secure buffer
identityStr := identity.String()
secureBuffer := memguard.NewBufferFromBytes([]byte(identityStr))
defer secureBuffer.Destroy()
err = vlt.AddSecret(secretName, secureBuffer.Bytes(), false)
// Set ageSecretKey for later use (we need it for encryption)
ageSecretKey = identityStr
err = vlt.AddSecret(secretName, secureBuffer, false)
if err != nil {
return fmt.Errorf("failed to store age key: %w", err)
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"git.eeqj.de/sneak/secret/internal/vault"
"github.com/awnumar/memguard"
"github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39"
)
@@ -136,7 +137,11 @@ func (cli *Instance) GenerateSecret(
return err
}
if err := vlt.AddSecret(secretName, []byte(secretValue), force); err != nil {
// Protect the generated secret immediately
secretBuffer := memguard.NewBufferFromBytes([]byte(secretValue))
defer secretBuffer.Destroy()
if err := vlt.AddSecret(secretName, secretBuffer, force); err != nil {
return err
}

View File

@@ -166,8 +166,11 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
}
// Encrypt long-term private key to unlocker
ltPrivKeyData := []byte(ltIdentity.String())
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyData, unlockerRecipient)
// Use memguard to protect the private key in memory
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
defer ltPrivKeyBuffer.Destroy()
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerRecipient)
if err != nil {
return fmt.Errorf("failed to encrypt long-term private key: %w", err)
}

View File

@@ -8,7 +8,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault"
"github.com/spf13/afero"
"github.com/awnumar/memguard"
"github.com/spf13/cobra"
)
@@ -103,6 +103,20 @@ func newImportCmd() *cobra.Command {
return cmd
}
// updateBufferSize updates the buffer size based on usage pattern
func updateBufferSize(currentSize int, sameSize *int) int {
*sameSize++
const doubleAfterBuffers = 2
const growthFactor = 2
if *sameSize >= doubleAfterBuffers {
*sameSize = 0
return currentSize * growthFactor
}
return currentSize
}
// AddSecret adds a secret to the current vault
func (cli *Instance) AddSecret(secretName string, force bool) error {
secret.Debug("CLI AddSecret starting", "secret_name", secretName, "force", force)
@@ -116,24 +130,82 @@ func (cli *Instance) AddSecret(secretName string, force bool) error {
secret.Debug("Got current vault", "vault_name", vlt.GetName())
// Read secret value from stdin
secret.Debug("Reading secret value from stdin")
value, err := io.ReadAll(cli.cmd.InOrStdin())
if err != nil {
return fmt.Errorf("failed to read secret value: %w", err)
// Read secret value directly into protected buffers
secret.Debug("Reading secret value from stdin into protected buffers")
const initialSize = 4 * 1024 // 4KB initial buffer
const maxSize = 100 * 1024 * 1024 // 100MB max
type bufferInfo struct {
buffer *memguard.LockedBuffer
used int
}
secret.Debug("Read secret value from stdin", "value_length", len(value))
var buffers []bufferInfo
defer func() {
for _, b := range buffers {
b.buffer.Destroy()
}
}()
// Remove trailing newline if present
if len(value) > 0 && value[len(value)-1] == '\n' {
value = value[:len(value)-1]
secret.Debug("Removed trailing newline", "new_length", len(value))
reader := cli.cmd.InOrStdin()
totalSize := 0
currentBufferSize := initialSize
sameSize := 0
for {
// Create a new buffer
buffer := memguard.NewBuffer(currentBufferSize)
n, err := io.ReadFull(reader, buffer.Bytes())
if n == 0 {
// No data read, destroy the unused buffer
buffer.Destroy()
} else {
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
totalSize += n
if totalSize > maxSize {
return fmt.Errorf("secret too large: exceeds 100MB limit")
}
// If we filled the buffer, consider growing for next iteration
if n == currentBufferSize {
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
}
}
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
} else if err != nil {
return fmt.Errorf("failed to read secret value: %w", err)
}
}
// Check for trailing newline in the last buffer
if len(buffers) > 0 && totalSize > 0 {
lastBuffer := &buffers[len(buffers)-1]
if lastBuffer.buffer.Bytes()[lastBuffer.used-1] == '\n' {
lastBuffer.used--
totalSize--
}
}
secret.Debug("Read secret value from stdin", "value_length", totalSize, "buffers", len(buffers))
// Combine all buffers into a single protected buffer
valueBuffer := memguard.NewBuffer(totalSize)
defer valueBuffer.Destroy()
offset := 0
for _, b := range buffers {
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
offset += b.used
}
// Add the secret to the vault
secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", len(value), "force", force)
if err := vlt.AddSecret(secretName, value, force); err != nil {
secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", valueBuffer.Size(), "force", force)
if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
secret.Debug("vault.AddSecret failed", "error", err)
return err
@@ -298,14 +370,77 @@ func (cli *Instance) ImportSecret(cmd *cobra.Command, secretName, sourceFile str
return err
}
// Read secret value from the source file
value, err := afero.ReadFile(cli.fs, sourceFile)
// Read secret value from the source file into protected buffers
file, err := cli.fs.Open(sourceFile)
if err != nil {
return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err)
return fmt.Errorf("failed to open file %s: %w", sourceFile, err)
}
defer func() {
if err := file.Close(); err != nil {
secret.Debug("Failed to close file", "error", err)
}
}()
const initialSize = 4 * 1024 // 4KB initial buffer
const maxSize = 100 * 1024 * 1024 // 100MB max
type bufferInfo struct {
buffer *memguard.LockedBuffer
used int
}
var buffers []bufferInfo
defer func() {
for _, b := range buffers {
b.buffer.Destroy()
}
}()
totalSize := 0
currentBufferSize := initialSize
sameSize := 0
for {
// Create a new buffer
buffer := memguard.NewBuffer(currentBufferSize)
n, err := io.ReadFull(file, buffer.Bytes())
if n == 0 {
// No data read, destroy the unused buffer
buffer.Destroy()
} else {
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
totalSize += n
if totalSize > maxSize {
return fmt.Errorf("secret file too large: exceeds 100MB limit")
}
// If we filled the buffer, consider growing for next iteration
if n == currentBufferSize {
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
}
}
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
} else if err != nil {
return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err)
}
}
// Combine all buffers into a single protected buffer
valueBuffer := memguard.NewBuffer(totalSize)
defer valueBuffer.Destroy()
offset := 0
for _, b := range buffers {
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
offset += b.used
}
// Store the secret in the vault
if err := vlt.AddSecret(secretName, value, force); err != nil {
if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
return err
}

View File

@@ -0,0 +1,425 @@
package cli
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"path/filepath"
"strings"
"testing"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestAddSecretVariousSizes tests adding secrets of various sizes through stdin
func TestAddSecretVariousSizes(t *testing.T) {
tests := []struct {
name string
size int
shouldError bool
errorMsg string
}{
{
name: "1KB secret",
size: 1024,
shouldError: false,
},
{
name: "10KB secret",
size: 10 * 1024,
shouldError: false,
},
{
name: "100KB secret",
size: 100 * 1024,
shouldError: false,
},
{
name: "1MB secret",
size: 1024 * 1024,
shouldError: false,
},
{
name: "10MB secret",
size: 10 * 1024 * 1024,
shouldError: false,
},
{
name: "99MB secret",
size: 99 * 1024 * 1024,
shouldError: false,
},
{
name: "100MB secret minus 1 byte",
size: 100*1024*1024 - 1,
shouldError: false,
},
{
name: "101MB secret - should fail",
size: 101 * 1024 * 1024,
shouldError: true,
errorMsg: "secret too large: exceeds 100MB limit",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Generate test data of specified size
testData := make([]byte, tt.size)
_, err = rand.Read(testData)
require.NoError(t, err)
// Add newline that will be stripped
testDataWithNewline := append(testData, '\n')
// Create fake stdin
stdin := bytes.NewReader(testDataWithNewline)
// Create command with fake stdin
cmd := &cobra.Command{}
cmd.SetIn(stdin)
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
cli.cmd = cmd
// Test adding the secret
secretName := fmt.Sprintf("test-secret-%d", tt.size)
err = cli.AddSecret(secretName, false)
if tt.shouldError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret(secretName)
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original (without newline)")
}
})
}
}
// TestImportSecretVariousSizes tests importing secrets of various sizes from files
func TestImportSecretVariousSizes(t *testing.T) {
tests := []struct {
name string
size int
shouldError bool
errorMsg string
}{
{
name: "1KB file",
size: 1024,
shouldError: false,
},
{
name: "10KB file",
size: 10 * 1024,
shouldError: false,
},
{
name: "100KB file",
size: 100 * 1024,
shouldError: false,
},
{
name: "1MB file",
size: 1024 * 1024,
shouldError: false,
},
{
name: "10MB file",
size: 10 * 1024 * 1024,
shouldError: false,
},
{
name: "99MB file",
size: 99 * 1024 * 1024,
shouldError: false,
},
{
name: "100MB file",
size: 100 * 1024 * 1024,
shouldError: false,
},
{
name: "101MB file - should fail",
size: 101 * 1024 * 1024,
shouldError: true,
errorMsg: "secret file too large: exceeds 100MB limit",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Generate test data of specified size
testData := make([]byte, tt.size)
_, err = rand.Read(testData)
require.NoError(t, err)
// Write test data to file
testFile := fmt.Sprintf("/test/secret-%d.bin", tt.size)
err = afero.WriteFile(fs, testFile, testData, 0o600)
require.NoError(t, err)
// Create command
cmd := &cobra.Command{}
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
// Test importing the secret
secretName := fmt.Sprintf("imported-secret-%d", tt.size)
err = cli.ImportSecret(cmd, secretName, testFile, false)
if tt.shouldError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret(secretName)
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
}
})
}
}
// TestAddSecretBufferGrowth tests that our buffer growth strategy works correctly
func TestAddSecretBufferGrowth(t *testing.T) {
// Test various sizes that should trigger buffer growth
sizes := []int{
1, // Single byte
100, // Small
4095, // Just under initial 4KB
4096, // Exactly 4KB
4097, // Just over 4KB
8191, // Just under 8KB (first double)
8192, // Exactly 8KB
8193, // Just over 8KB
12288, // 12KB (should trigger second double)
16384, // 16KB
32768, // 32KB (after more doublings)
65536, // 64KB
131072, // 128KB
524288, // 512KB
1048576, // 1MB
2097152, // 2MB
}
for _, size := range sizes {
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Create test data of exactly the specified size
// Use a pattern that's easy to verify
testData := make([]byte, size)
for i := range testData {
testData[i] = byte(i % 256)
}
// Create fake stdin without newline
stdin := bytes.NewReader(testData)
// Create command with fake stdin
cmd := &cobra.Command{}
cmd.SetIn(stdin)
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
cli.cmd = cmd
// Test adding the secret
secretName := fmt.Sprintf("buffer-test-%d", size)
err = cli.AddSecret(secretName, false)
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret(secretName)
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original exactly")
})
}
}
// TestAddSecretStreamingBehavior tests that we handle streaming input correctly
func TestAddSecretStreamingBehavior(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Create a custom reader that simulates slow streaming input
// This will help verify our buffer handling works correctly with partial reads
testData := []byte(strings.Repeat("Hello, World! ", 1000)) // ~14KB
slowReader := &slowReader{
data: testData,
chunkSize: 1000, // Read 1KB at a time
}
// Create command with slow reader as stdin
cmd := &cobra.Command{}
cmd.SetIn(slowReader)
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
cli.cmd = cmd
// Test adding the secret
err = cli.AddSecret("streaming-test", false)
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret("streaming-test")
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
}
// slowReader simulates a reader that returns data in small chunks
type slowReader struct {
data []byte
offset int
chunkSize int
}
func (r *slowReader) Read(p []byte) (n int, err error) {
if r.offset >= len(r.data) {
return 0, io.EOF
}
// Read at most chunkSize bytes
remaining := len(r.data) - r.offset
toRead := r.chunkSize
if toRead > remaining {
toRead = remaining
}
if toRead > len(p) {
toRead = len(p)
}
n = copy(p, r.data[r.offset:r.offset+toRead])
r.offset += n
if r.offset >= len(r.data) {
err = io.EOF
}
return n, err
}

View File

@@ -26,11 +26,21 @@ import (
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Helper function to add a secret to vault with proper buffer protection
func addTestSecret(t *testing.T, vlt *vault.Vault, name string, value []byte, force bool) {
t.Helper()
buffer := memguard.NewBufferFromBytes(value)
defer buffer.Destroy()
err := vlt.AddSecret(name, buffer, force)
require.NoError(t, err)
}
// Helper function to set up a vault with long-term key
func setupTestVault(t *testing.T, fs afero.Fs, stateDir string) {
// Set mnemonic for testing
@@ -70,13 +80,11 @@ func TestListVersionsCommand(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("version-1"), false)
require.NoError(t, err)
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
time.Sleep(10 * time.Millisecond)
err = vlt.AddSecret("test/secret", []byte("version-2"), true)
require.NoError(t, err)
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
// Create a command for output capture
cmd := newRootCmd()
@@ -144,13 +152,11 @@ func TestPromoteVersionCommand(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("version-1"), false)
require.NoError(t, err)
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
time.Sleep(10 * time.Millisecond)
err = vlt.AddSecret("test/secret", []byte("version-2"), true)
require.NoError(t, err)
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
// Get versions
vaultDir, _ := vlt.GetDirectory()
@@ -201,8 +207,7 @@ func TestPromoteNonExistentVersion(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("value"), false)
require.NoError(t, err)
addTestSecret(t, vlt, "test/secret", []byte("value"), false)
// Create a command for output capture
cmd := newRootCmd()
@@ -228,13 +233,11 @@ func TestGetSecretWithVersion(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("version-1"), false)
require.NoError(t, err)
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
time.Sleep(10 * time.Millisecond)
err = vlt.AddSecret("test/secret", []byte("version-2"), true)
require.NoError(t, err)
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
// Get versions
vaultDir, _ := vlt.GetDirectory()