uses protected memory buffers now for all secrets in ram
This commit is contained in:
parent
d3ca006886
commit
7596049828
13
TODO.md
13
TODO.md
@ -6,22 +6,9 @@ prioritized from most critical (top) to least critical (bottom).
|
|||||||
|
|
||||||
## Code Cleanups
|
## Code Cleanups
|
||||||
|
|
||||||
* none of the integration tests should be searching for a binary or trying
|
|
||||||
to execute another process. the integration tests cannot make another
|
|
||||||
process or depend on a compiled file, they must do all of their testing in
|
|
||||||
the current (test) process.
|
|
||||||
|
|
||||||
* we shouldn't be passing around a statedir, it should be read from the
|
* we shouldn't be passing around a statedir, it should be read from the
|
||||||
environment or default.
|
environment or default.
|
||||||
|
|
||||||
## CRITICAL SECURITY ISSUES - Must Fix Before 1.0
|
|
||||||
|
|
||||||
- [ ] **1. Memory security vulnerabilities**: Sensitive data (passwords,
|
|
||||||
private keys, passphrases) stored as strings are not properly zeroed from
|
|
||||||
memory after use. Memory dumps or swap files could expose secrets. Found
|
|
||||||
in crypto.go:107, passphraseunlocker.go:29-48, cli/crypto.go:89,193,
|
|
||||||
pgpunlocker.go:278, keychainunlocker.go:252,346.
|
|
||||||
|
|
||||||
## HIGH PRIORITY SECURITY ISSUES
|
## HIGH PRIORITY SECURITY ISSUES
|
||||||
|
|
||||||
- [ ] **4. Application crashes on corrupted metadata**: Code panics instead
|
- [ ] **4. Application crashes on corrupted metadata**: Code panics instead
|
||||||
|
@ -82,13 +82,15 @@ func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error {
|
|||||||
return fmt.Errorf("failed to generate age key: %w", err)
|
return fmt.Errorf("failed to generate age key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ageSecretKey = identity.String()
|
// Store the generated key directly in a secure buffer
|
||||||
|
identityStr := identity.String()
|
||||||
// Store the generated key as a secret using secure buffer
|
secureBuffer := memguard.NewBufferFromBytes([]byte(identityStr))
|
||||||
secureBuffer := memguard.NewBufferFromBytes([]byte(ageSecretKey))
|
|
||||||
defer secureBuffer.Destroy()
|
defer secureBuffer.Destroy()
|
||||||
|
|
||||||
err = vlt.AddSecret(secretName, secureBuffer.Bytes(), false)
|
// Set ageSecretKey for later use (we need it for encryption)
|
||||||
|
ageSecretKey = identityStr
|
||||||
|
|
||||||
|
err = vlt.AddSecret(secretName, secureBuffer, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to store age key: %w", err)
|
return fmt.Errorf("failed to store age key: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.eeqj.de/sneak/secret/internal/vault"
|
"git.eeqj.de/sneak/secret/internal/vault"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/tyler-smith/go-bip39"
|
"github.com/tyler-smith/go-bip39"
|
||||||
)
|
)
|
||||||
@ -136,7 +137,11 @@ func (cli *Instance) GenerateSecret(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := vlt.AddSecret(secretName, []byte(secretValue), force); err != nil {
|
// Protect the generated secret immediately
|
||||||
|
secretBuffer := memguard.NewBufferFromBytes([]byte(secretValue))
|
||||||
|
defer secretBuffer.Destroy()
|
||||||
|
|
||||||
|
if err := vlt.AddSecret(secretName, secretBuffer, force); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,8 +166,11 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt long-term private key to unlocker
|
// Encrypt long-term private key to unlocker
|
||||||
ltPrivKeyData := []byte(ltIdentity.String())
|
// Use memguard to protect the private key in memory
|
||||||
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyData, unlockerRecipient)
|
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
|
||||||
|
defer ltPrivKeyBuffer.Destroy()
|
||||||
|
|
||||||
|
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerRecipient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to encrypt long-term private key: %w", err)
|
return fmt.Errorf("failed to encrypt long-term private key: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
"git.eeqj.de/sneak/secret/internal/secret"
|
"git.eeqj.de/sneak/secret/internal/secret"
|
||||||
"git.eeqj.de/sneak/secret/internal/vault"
|
"git.eeqj.de/sneak/secret/internal/vault"
|
||||||
"github.com/spf13/afero"
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -103,6 +103,20 @@ func newImportCmd() *cobra.Command {
|
|||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateBufferSize updates the buffer size based on usage pattern
|
||||||
|
func updateBufferSize(currentSize int, sameSize *int) int {
|
||||||
|
*sameSize++
|
||||||
|
const doubleAfterBuffers = 2
|
||||||
|
const growthFactor = 2
|
||||||
|
if *sameSize >= doubleAfterBuffers {
|
||||||
|
*sameSize = 0
|
||||||
|
|
||||||
|
return currentSize * growthFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
return currentSize
|
||||||
|
}
|
||||||
|
|
||||||
// AddSecret adds a secret to the current vault
|
// AddSecret adds a secret to the current vault
|
||||||
func (cli *Instance) AddSecret(secretName string, force bool) error {
|
func (cli *Instance) AddSecret(secretName string, force bool) error {
|
||||||
secret.Debug("CLI AddSecret starting", "secret_name", secretName, "force", force)
|
secret.Debug("CLI AddSecret starting", "secret_name", secretName, "force", force)
|
||||||
@ -116,24 +130,82 @@ func (cli *Instance) AddSecret(secretName string, force bool) error {
|
|||||||
|
|
||||||
secret.Debug("Got current vault", "vault_name", vlt.GetName())
|
secret.Debug("Got current vault", "vault_name", vlt.GetName())
|
||||||
|
|
||||||
// Read secret value from stdin
|
// Read secret value directly into protected buffers
|
||||||
secret.Debug("Reading secret value from stdin")
|
secret.Debug("Reading secret value from stdin into protected buffers")
|
||||||
value, err := io.ReadAll(cli.cmd.InOrStdin())
|
|
||||||
if err != nil {
|
const initialSize = 4 * 1024 // 4KB initial buffer
|
||||||
return fmt.Errorf("failed to read secret value: %w", err)
|
const maxSize = 100 * 1024 * 1024 // 100MB max
|
||||||
|
|
||||||
|
type bufferInfo struct {
|
||||||
|
buffer *memguard.LockedBuffer
|
||||||
|
used int
|
||||||
}
|
}
|
||||||
|
|
||||||
secret.Debug("Read secret value from stdin", "value_length", len(value))
|
var buffers []bufferInfo
|
||||||
|
defer func() {
|
||||||
|
for _, b := range buffers {
|
||||||
|
b.buffer.Destroy()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Remove trailing newline if present
|
reader := cli.cmd.InOrStdin()
|
||||||
if len(value) > 0 && value[len(value)-1] == '\n' {
|
totalSize := 0
|
||||||
value = value[:len(value)-1]
|
currentBufferSize := initialSize
|
||||||
secret.Debug("Removed trailing newline", "new_length", len(value))
|
sameSize := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Create a new buffer
|
||||||
|
buffer := memguard.NewBuffer(currentBufferSize)
|
||||||
|
n, err := io.ReadFull(reader, buffer.Bytes())
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
// No data read, destroy the unused buffer
|
||||||
|
buffer.Destroy()
|
||||||
|
} else {
|
||||||
|
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
|
||||||
|
totalSize += n
|
||||||
|
|
||||||
|
if totalSize > maxSize {
|
||||||
|
return fmt.Errorf("secret too large: exceeds 100MB limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we filled the buffer, consider growing for next iteration
|
||||||
|
if n == currentBufferSize {
|
||||||
|
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to read secret value: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for trailing newline in the last buffer
|
||||||
|
if len(buffers) > 0 && totalSize > 0 {
|
||||||
|
lastBuffer := &buffers[len(buffers)-1]
|
||||||
|
if lastBuffer.buffer.Bytes()[lastBuffer.used-1] == '\n' {
|
||||||
|
lastBuffer.used--
|
||||||
|
totalSize--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
secret.Debug("Read secret value from stdin", "value_length", totalSize, "buffers", len(buffers))
|
||||||
|
|
||||||
|
// Combine all buffers into a single protected buffer
|
||||||
|
valueBuffer := memguard.NewBuffer(totalSize)
|
||||||
|
defer valueBuffer.Destroy()
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
for _, b := range buffers {
|
||||||
|
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
|
||||||
|
offset += b.used
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the secret to the vault
|
// Add the secret to the vault
|
||||||
secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", len(value), "force", force)
|
secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", valueBuffer.Size(), "force", force)
|
||||||
if err := vlt.AddSecret(secretName, value, force); err != nil {
|
if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
|
||||||
secret.Debug("vault.AddSecret failed", "error", err)
|
secret.Debug("vault.AddSecret failed", "error", err)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@ -298,14 +370,77 @@ func (cli *Instance) ImportSecret(cmd *cobra.Command, secretName, sourceFile str
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read secret value from the source file
|
// Read secret value from the source file into protected buffers
|
||||||
value, err := afero.ReadFile(cli.fs, sourceFile)
|
file, err := cli.fs.Open(sourceFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err)
|
return fmt.Errorf("failed to open file %s: %w", sourceFile, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := file.Close(); err != nil {
|
||||||
|
secret.Debug("Failed to close file", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
const initialSize = 4 * 1024 // 4KB initial buffer
|
||||||
|
const maxSize = 100 * 1024 * 1024 // 100MB max
|
||||||
|
|
||||||
|
type bufferInfo struct {
|
||||||
|
buffer *memguard.LockedBuffer
|
||||||
|
used int
|
||||||
|
}
|
||||||
|
|
||||||
|
var buffers []bufferInfo
|
||||||
|
defer func() {
|
||||||
|
for _, b := range buffers {
|
||||||
|
b.buffer.Destroy()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
totalSize := 0
|
||||||
|
currentBufferSize := initialSize
|
||||||
|
sameSize := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Create a new buffer
|
||||||
|
buffer := memguard.NewBuffer(currentBufferSize)
|
||||||
|
n, err := io.ReadFull(file, buffer.Bytes())
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
// No data read, destroy the unused buffer
|
||||||
|
buffer.Destroy()
|
||||||
|
} else {
|
||||||
|
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
|
||||||
|
totalSize += n
|
||||||
|
|
||||||
|
if totalSize > maxSize {
|
||||||
|
return fmt.Errorf("secret file too large: exceeds 100MB limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we filled the buffer, consider growing for next iteration
|
||||||
|
if n == currentBufferSize {
|
||||||
|
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine all buffers into a single protected buffer
|
||||||
|
valueBuffer := memguard.NewBuffer(totalSize)
|
||||||
|
defer valueBuffer.Destroy()
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
for _, b := range buffers {
|
||||||
|
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
|
||||||
|
offset += b.used
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the secret in the vault
|
// Store the secret in the vault
|
||||||
if err := vlt.AddSecret(secretName, value, force); err != nil {
|
if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
425
internal/cli/secrets_size_test.go
Normal file
425
internal/cli/secrets_size_test.go
Normal file
@ -0,0 +1,425 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.eeqj.de/sneak/secret/internal/secret"
|
||||||
|
"git.eeqj.de/sneak/secret/internal/vault"
|
||||||
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||||
|
"github.com/spf13/afero"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAddSecretVariousSizes tests adding secrets of various sizes through stdin
|
||||||
|
func TestAddSecretVariousSizes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
size int
|
||||||
|
shouldError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1KB secret",
|
||||||
|
size: 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "10KB secret",
|
||||||
|
size: 10 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100KB secret",
|
||||||
|
size: 100 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1MB secret",
|
||||||
|
size: 1024 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "10MB secret",
|
||||||
|
size: 10 * 1024 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "99MB secret",
|
||||||
|
size: 99 * 1024 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100MB secret minus 1 byte",
|
||||||
|
size: 100*1024*1024 - 1,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "101MB secret - should fail",
|
||||||
|
size: 101 * 1024 * 1024,
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "secret too large: exceeds 100MB limit",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Set up test environment
|
||||||
|
fs := afero.NewMemMapFs()
|
||||||
|
stateDir := "/test/state"
|
||||||
|
|
||||||
|
// Set test mnemonic
|
||||||
|
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||||
|
|
||||||
|
// Create vault
|
||||||
|
vaultName := "test-vault"
|
||||||
|
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set current vault
|
||||||
|
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||||
|
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||||
|
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Get vault and set up long-term key
|
||||||
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
vlt.Unlock(ltIdentity)
|
||||||
|
|
||||||
|
// Generate test data of specified size
|
||||||
|
testData := make([]byte, tt.size)
|
||||||
|
_, err = rand.Read(testData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add newline that will be stripped
|
||||||
|
testDataWithNewline := append(testData, '\n')
|
||||||
|
|
||||||
|
// Create fake stdin
|
||||||
|
stdin := bytes.NewReader(testDataWithNewline)
|
||||||
|
|
||||||
|
// Create command with fake stdin
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetIn(stdin)
|
||||||
|
|
||||||
|
// Create CLI instance
|
||||||
|
cli := NewCLIInstance()
|
||||||
|
cli.fs = fs
|
||||||
|
cli.stateDir = stateDir
|
||||||
|
cli.cmd = cmd
|
||||||
|
|
||||||
|
// Test adding the secret
|
||||||
|
secretName := fmt.Sprintf("test-secret-%d", tt.size)
|
||||||
|
err = cli.AddSecret(secretName, false)
|
||||||
|
|
||||||
|
if tt.shouldError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the secret was stored correctly
|
||||||
|
retrievedValue, err := vlt.GetSecret(secretName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original (without newline)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestImportSecretVariousSizes tests importing secrets of various sizes from files
|
||||||
|
func TestImportSecretVariousSizes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
size int
|
||||||
|
shouldError bool
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1KB file",
|
||||||
|
size: 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "10KB file",
|
||||||
|
size: 10 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100KB file",
|
||||||
|
size: 100 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1MB file",
|
||||||
|
size: 1024 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "10MB file",
|
||||||
|
size: 10 * 1024 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "99MB file",
|
||||||
|
size: 99 * 1024 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100MB file",
|
||||||
|
size: 100 * 1024 * 1024,
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "101MB file - should fail",
|
||||||
|
size: 101 * 1024 * 1024,
|
||||||
|
shouldError: true,
|
||||||
|
errorMsg: "secret file too large: exceeds 100MB limit",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Set up test environment
|
||||||
|
fs := afero.NewMemMapFs()
|
||||||
|
stateDir := "/test/state"
|
||||||
|
|
||||||
|
// Set test mnemonic
|
||||||
|
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||||
|
|
||||||
|
// Create vault
|
||||||
|
vaultName := "test-vault"
|
||||||
|
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set current vault
|
||||||
|
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||||
|
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||||
|
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Get vault and set up long-term key
|
||||||
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
vlt.Unlock(ltIdentity)
|
||||||
|
|
||||||
|
// Generate test data of specified size
|
||||||
|
testData := make([]byte, tt.size)
|
||||||
|
_, err = rand.Read(testData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Write test data to file
|
||||||
|
testFile := fmt.Sprintf("/test/secret-%d.bin", tt.size)
|
||||||
|
err = afero.WriteFile(fs, testFile, testData, 0o600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create command
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
|
||||||
|
// Create CLI instance
|
||||||
|
cli := NewCLIInstance()
|
||||||
|
cli.fs = fs
|
||||||
|
cli.stateDir = stateDir
|
||||||
|
|
||||||
|
// Test importing the secret
|
||||||
|
secretName := fmt.Sprintf("imported-secret-%d", tt.size)
|
||||||
|
err = cli.ImportSecret(cmd, secretName, testFile, false)
|
||||||
|
|
||||||
|
if tt.shouldError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the secret was stored correctly
|
||||||
|
retrievedValue, err := vlt.GetSecret(secretName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddSecretBufferGrowth tests that our buffer growth strategy works correctly
|
||||||
|
func TestAddSecretBufferGrowth(t *testing.T) {
|
||||||
|
// Test various sizes that should trigger buffer growth
|
||||||
|
sizes := []int{
|
||||||
|
1, // Single byte
|
||||||
|
100, // Small
|
||||||
|
4095, // Just under initial 4KB
|
||||||
|
4096, // Exactly 4KB
|
||||||
|
4097, // Just over 4KB
|
||||||
|
8191, // Just under 8KB (first double)
|
||||||
|
8192, // Exactly 8KB
|
||||||
|
8193, // Just over 8KB
|
||||||
|
12288, // 12KB (should trigger second double)
|
||||||
|
16384, // 16KB
|
||||||
|
32768, // 32KB (after more doublings)
|
||||||
|
65536, // 64KB
|
||||||
|
131072, // 128KB
|
||||||
|
524288, // 512KB
|
||||||
|
1048576, // 1MB
|
||||||
|
2097152, // 2MB
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, size := range sizes {
|
||||||
|
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
|
||||||
|
// Set up test environment
|
||||||
|
fs := afero.NewMemMapFs()
|
||||||
|
stateDir := "/test/state"
|
||||||
|
|
||||||
|
// Set test mnemonic
|
||||||
|
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||||
|
|
||||||
|
// Create vault
|
||||||
|
vaultName := "test-vault"
|
||||||
|
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set current vault
|
||||||
|
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||||
|
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||||
|
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Get vault and set up long-term key
|
||||||
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
vlt.Unlock(ltIdentity)
|
||||||
|
|
||||||
|
// Create test data of exactly the specified size
|
||||||
|
// Use a pattern that's easy to verify
|
||||||
|
testData := make([]byte, size)
|
||||||
|
for i := range testData {
|
||||||
|
testData[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create fake stdin without newline
|
||||||
|
stdin := bytes.NewReader(testData)
|
||||||
|
|
||||||
|
// Create command with fake stdin
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetIn(stdin)
|
||||||
|
|
||||||
|
// Create CLI instance
|
||||||
|
cli := NewCLIInstance()
|
||||||
|
cli.fs = fs
|
||||||
|
cli.stateDir = stateDir
|
||||||
|
cli.cmd = cmd
|
||||||
|
|
||||||
|
// Test adding the secret
|
||||||
|
secretName := fmt.Sprintf("buffer-test-%d", size)
|
||||||
|
err = cli.AddSecret(secretName, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the secret was stored correctly
|
||||||
|
retrievedValue, err := vlt.GetSecret(secretName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original exactly")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddSecretStreamingBehavior tests that we handle streaming input correctly
|
||||||
|
func TestAddSecretStreamingBehavior(t *testing.T) {
|
||||||
|
// Set up test environment
|
||||||
|
fs := afero.NewMemMapFs()
|
||||||
|
stateDir := "/test/state"
|
||||||
|
|
||||||
|
// Set test mnemonic
|
||||||
|
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||||
|
|
||||||
|
// Create vault
|
||||||
|
vaultName := "test-vault"
|
||||||
|
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set current vault
|
||||||
|
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||||
|
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||||
|
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Get vault and set up long-term key
|
||||||
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
vlt.Unlock(ltIdentity)
|
||||||
|
|
||||||
|
// Create a custom reader that simulates slow streaming input
|
||||||
|
// This will help verify our buffer handling works correctly with partial reads
|
||||||
|
testData := []byte(strings.Repeat("Hello, World! ", 1000)) // ~14KB
|
||||||
|
slowReader := &slowReader{
|
||||||
|
data: testData,
|
||||||
|
chunkSize: 1000, // Read 1KB at a time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create command with slow reader as stdin
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetIn(slowReader)
|
||||||
|
|
||||||
|
// Create CLI instance
|
||||||
|
cli := NewCLIInstance()
|
||||||
|
cli.fs = fs
|
||||||
|
cli.stateDir = stateDir
|
||||||
|
cli.cmd = cmd
|
||||||
|
|
||||||
|
// Test adding the secret
|
||||||
|
err = cli.AddSecret("streaming-test", false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the secret was stored correctly
|
||||||
|
retrievedValue, err := vlt.GetSecret("streaming-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
|
||||||
|
}
|
||||||
|
|
||||||
|
// slowReader simulates a reader that returns data in small chunks
|
||||||
|
type slowReader struct {
|
||||||
|
data []byte
|
||||||
|
offset int
|
||||||
|
chunkSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *slowReader) Read(p []byte) (n int, err error) {
|
||||||
|
if r.offset >= len(r.data) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read at most chunkSize bytes
|
||||||
|
remaining := len(r.data) - r.offset
|
||||||
|
toRead := r.chunkSize
|
||||||
|
if toRead > remaining {
|
||||||
|
toRead = remaining
|
||||||
|
}
|
||||||
|
if toRead > len(p) {
|
||||||
|
toRead = len(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
n = copy(p, r.data[r.offset:r.offset+toRead])
|
||||||
|
r.offset += n
|
||||||
|
|
||||||
|
if r.offset >= len(r.data) {
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
@ -26,11 +26,21 @@ import (
|
|||||||
"git.eeqj.de/sneak/secret/internal/secret"
|
"git.eeqj.de/sneak/secret/internal/secret"
|
||||||
"git.eeqj.de/sneak/secret/internal/vault"
|
"git.eeqj.de/sneak/secret/internal/vault"
|
||||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Helper function to add a secret to vault with proper buffer protection
|
||||||
|
func addTestSecret(t *testing.T, vlt *vault.Vault, name string, value []byte, force bool) {
|
||||||
|
t.Helper()
|
||||||
|
buffer := memguard.NewBufferFromBytes(value)
|
||||||
|
defer buffer.Destroy()
|
||||||
|
err := vlt.AddSecret(name, buffer, force)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to set up a vault with long-term key
|
// Helper function to set up a vault with long-term key
|
||||||
func setupTestVault(t *testing.T, fs afero.Fs, stateDir string) {
|
func setupTestVault(t *testing.T, fs afero.Fs, stateDir string) {
|
||||||
// Set mnemonic for testing
|
// Set mnemonic for testing
|
||||||
@ -70,13 +80,11 @@ func TestListVersionsCommand(t *testing.T) {
|
|||||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = vlt.AddSecret("test/secret", []byte("version-1"), false)
|
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
err = vlt.AddSecret("test/secret", []byte("version-2"), true)
|
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create a command for output capture
|
// Create a command for output capture
|
||||||
cmd := newRootCmd()
|
cmd := newRootCmd()
|
||||||
@ -144,13 +152,11 @@ func TestPromoteVersionCommand(t *testing.T) {
|
|||||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = vlt.AddSecret("test/secret", []byte("version-1"), false)
|
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
err = vlt.AddSecret("test/secret", []byte("version-2"), true)
|
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Get versions
|
// Get versions
|
||||||
vaultDir, _ := vlt.GetDirectory()
|
vaultDir, _ := vlt.GetDirectory()
|
||||||
@ -201,8 +207,7 @@ func TestPromoteNonExistentVersion(t *testing.T) {
|
|||||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = vlt.AddSecret("test/secret", []byte("value"), false)
|
addTestSecret(t, vlt, "test/secret", []byte("value"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create a command for output capture
|
// Create a command for output capture
|
||||||
cmd := newRootCmd()
|
cmd := newRootCmd()
|
||||||
@ -228,13 +233,11 @@ func TestGetSecretWithVersion(t *testing.T) {
|
|||||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = vlt.AddSecret("test/secret", []byte("version-1"), false)
|
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
err = vlt.AddSecret("test/secret", []byte("version-2"), true)
|
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Get versions
|
// Get versions
|
||||||
vaultDir, _ := vlt.GetDirectory()
|
vaultDir, _ := vlt.GetDirectory()
|
||||||
|
@ -13,8 +13,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// EncryptToRecipient encrypts data to a recipient using age
|
// EncryptToRecipient encrypts data to a recipient using age
|
||||||
func EncryptToRecipient(data []byte, recipient age.Recipient) ([]byte, error) {
|
// The data parameter should be a LockedBuffer for secure memory handling
|
||||||
Debug("EncryptToRecipient starting", "data_length", len(data))
|
func EncryptToRecipient(data *memguard.LockedBuffer, recipient age.Recipient) ([]byte, error) {
|
||||||
|
if data == nil {
|
||||||
|
return nil, fmt.Errorf("data buffer is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
Debug("EncryptToRecipient starting", "data_length", data.Size())
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
Debug("Creating age encryptor")
|
Debug("Creating age encryptor")
|
||||||
@ -27,7 +32,7 @@ func EncryptToRecipient(data []byte, recipient age.Recipient) ([]byte, error) {
|
|||||||
Debug("Created age encryptor successfully")
|
Debug("Created age encryptor successfully")
|
||||||
|
|
||||||
Debug("Writing data to encryptor")
|
Debug("Writing data to encryptor")
|
||||||
if _, err := w.Write(data); err != nil {
|
if _, err := w.Write(data.Bytes()); err != nil {
|
||||||
Debug("Failed to write data to encryptor", "error", err)
|
Debug("Failed to write data to encryptor", "error", err)
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed to write data: %w", err)
|
return nil, fmt.Errorf("failed to write data: %w", err)
|
||||||
@ -77,7 +82,11 @@ func EncryptWithPassphrase(data []byte, passphrase *memguard.LockedBuffer) ([]by
|
|||||||
return nil, fmt.Errorf("failed to create scrypt recipient: %w", err)
|
return nil, fmt.Errorf("failed to create scrypt recipient: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return EncryptToRecipient(data, recipient)
|
// Create a secure buffer for the data
|
||||||
|
dataBuffer := memguard.NewBufferFromBytes(data)
|
||||||
|
defer dataBuffer.Destroy()
|
||||||
|
|
||||||
|
return EncryptToRecipient(dataBuffer, recipient)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecryptWithPassphrase decrypts data using a passphrase with age's scrypt-based decryption
|
// DecryptWithPassphrase decrypts data using a passphrase with age's scrypt-based decryption
|
||||||
@ -138,4 +147,3 @@ func ReadPassphrase(prompt string) (*memguard.LockedBuffer, error) {
|
|||||||
|
|
||||||
return secureBuffer, nil
|
return secureBuffer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -245,7 +245,8 @@ func generateKeychainUnlockerName(vaultName string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getLongTermPrivateKey retrieves the long-term private key either from environment or current unlocker
|
// getLongTermPrivateKey retrieves the long-term private key either from environment or current unlocker
|
||||||
func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) ([]byte, error) {
|
// Returns a LockedBuffer to ensure the private key is protected in memory
|
||||||
|
func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) (*memguard.LockedBuffer, error) {
|
||||||
// Check if mnemonic is available in environment variable
|
// Check if mnemonic is available in environment variable
|
||||||
envMnemonic := os.Getenv(EnvMnemonic)
|
envMnemonic := os.Getenv(EnvMnemonic)
|
||||||
if envMnemonic != "" {
|
if envMnemonic != "" {
|
||||||
@ -255,7 +256,8 @@ func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
|
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return []byte(ltIdentity.String()), nil
|
// Return the private key in a secure buffer
|
||||||
|
return memguard.NewBufferFromBytes([]byte(ltIdentity.String())), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the vault to access current unlocker
|
// Get the vault to access current unlocker
|
||||||
@ -304,7 +306,8 @@ func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
|
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ltPrivKeyData, nil
|
// Return the decrypted key in a secure buffer
|
||||||
|
return memguard.NewBufferFromBytes(ltPrivKeyData), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateKeychainUnlocker creates a new keychain unlocker and stores it in the vault
|
// CreateKeychainUnlocker creates a new keychain unlocker and stores it in the vault
|
||||||
@ -380,6 +383,7 @@ func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, er
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer ltPrivKeyData.Destroy()
|
||||||
|
|
||||||
// Step 6: Encrypt long-term private key to the new age unlocker
|
// Step 6: Encrypt long-term private key to the new age unlocker
|
||||||
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())
|
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())
|
||||||
|
@ -113,8 +113,9 @@ func TestPassphraseUnlockerWithRealFS(t *testing.T) {
|
|||||||
t.Fatalf("Failed to parse recipient: %v", err)
|
t.Fatalf("Failed to parse recipient: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ltPrivKeyData := []byte(ltIdentity.String())
|
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
|
||||||
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyData, recipient)
|
defer ltPrivKeyBuffer.Destroy()
|
||||||
|
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, recipient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to encrypt long-term private key: %v", err)
|
t.Fatalf("Failed to encrypt long-term private key: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -173,7 +173,11 @@ func NewPassphraseUnlocker(fs afero.Fs, directory string, metadata UnlockerMetad
|
|||||||
|
|
||||||
// CreatePassphraseUnlocker creates a new passphrase-protected unlocker
|
// CreatePassphraseUnlocker creates a new passphrase-protected unlocker
|
||||||
// The passphrase must be provided as a LockedBuffer for security
|
// The passphrase must be provided as a LockedBuffer for security
|
||||||
func CreatePassphraseUnlocker(fs afero.Fs, stateDir string, passphrase *memguard.LockedBuffer) (*PassphraseUnlocker, error) {
|
func CreatePassphraseUnlocker(
|
||||||
|
fs afero.Fs,
|
||||||
|
stateDir string,
|
||||||
|
passphrase *memguard.LockedBuffer,
|
||||||
|
) (*PassphraseUnlocker, error) {
|
||||||
// Get current vault
|
// Get current vault
|
||||||
currentVault, err := GetCurrentVault(fs, stateDir)
|
currentVault, err := GetCurrentVault(fs, stateDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"filippo.io/age"
|
"filippo.io/age"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -233,6 +234,7 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer ltPrivKeyData.Destroy()
|
||||||
|
|
||||||
// Step 7: Encrypt long-term private key to the new age unlocker
|
// Step 7: Encrypt long-term private key to the new age unlocker
|
||||||
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())
|
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())
|
||||||
@ -247,8 +249,11 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Step 8: Encrypt age private key to the GPG key ID
|
// Step 8: Encrypt age private key to the GPG key ID
|
||||||
agePrivateKeyBytes := []byte(ageIdentity.String())
|
// Use memguard to protect the private key in memory
|
||||||
encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBytes, gpgKeyID)
|
agePrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(ageIdentity.String()))
|
||||||
|
defer agePrivateKeyBuffer.Destroy()
|
||||||
|
|
||||||
|
encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBuffer.Bytes(), gpgKeyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to encrypt age private key with GPG: %w", err)
|
return nil, fmt.Errorf("failed to encrypt age private key with GPG: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -18,7 +18,7 @@ import (
|
|||||||
// VaultInterface defines the interface that vault implementations must satisfy
|
// VaultInterface defines the interface that vault implementations must satisfy
|
||||||
type VaultInterface interface {
|
type VaultInterface interface {
|
||||||
GetDirectory() (string, error)
|
GetDirectory() (string, error)
|
||||||
AddSecret(name string, value []byte, force bool) error
|
AddSecret(name string, value *memguard.LockedBuffer, force bool) error
|
||||||
GetName() string
|
GetName() string
|
||||||
GetFilesystem() afero.Fs
|
GetFilesystem() afero.Fs
|
||||||
GetCurrentUnlocker() (Unlocker, error)
|
GetCurrentUnlocker() (Unlocker, error)
|
||||||
@ -72,7 +72,12 @@ func (s *Secret) Save(value []byte, force bool) error {
|
|||||||
slog.Bool("force", force),
|
slog.Bool("force", force),
|
||||||
)
|
)
|
||||||
|
|
||||||
err := s.vault.AddSecret(s.Name, value, force)
|
// Create a secure buffer for the value - note that the caller
|
||||||
|
// should ideally pass a LockedBuffer directly to vault.AddSecret
|
||||||
|
valueBuffer := memguard.NewBufferFromBytes(value)
|
||||||
|
defer valueBuffer.Destroy()
|
||||||
|
|
||||||
|
err := s.vault.AddSecret(s.Name, valueBuffer, force)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Debug("Failed to save secret", "error", err, "secret_name", s.Name)
|
Debug("Failed to save secret", "error", err, "secret_name", s.Name)
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ func (m *MockVault) GetDirectory() (string, error) {
|
|||||||
return m.directory, nil
|
return m.directory, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockVault) AddSecret(name string, value []byte, _ bool) error {
|
func (m *MockVault) AddSecret(name string, value *memguard.LockedBuffer, _ bool) error {
|
||||||
// Create secret directory with proper storage name conversion
|
// Create secret directory with proper storage name conversion
|
||||||
storageName := strings.ReplaceAll(name, "/", "%")
|
storageName := strings.ReplaceAll(name, "/", "%")
|
||||||
secretDir := filepath.Join(m.directory, "secrets.d", storageName)
|
secretDir := filepath.Join(m.directory, "secrets.d", storageName)
|
||||||
@ -75,7 +75,7 @@ func (m *MockVault) AddSecret(name string, value []byte, _ bool) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt value to version's public key
|
// Encrypt value to version's public key (value is already a LockedBuffer)
|
||||||
encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient())
|
encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -88,7 +88,9 @@ func (m *MockVault) AddSecret(name string, value []byte, _ bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt version private key to long-term public key
|
// Encrypt version private key to long-term public key
|
||||||
encryptedPrivKey, err := EncryptToRecipient([]byte(versionIdentity.String()), ltIdentity.Recipient())
|
versionPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String()))
|
||||||
|
defer versionPrivKeyBuffer.Destroy()
|
||||||
|
encryptedPrivKey, err := EncryptToRecipient(versionPrivKeyBuffer, ltIdentity.Recipient())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -180,9 +182,13 @@ func TestPerSecretKeyFunctionality(t *testing.T) {
|
|||||||
secretName := "test-secret"
|
secretName := "test-secret"
|
||||||
secretValue := []byte("this is a test secret value")
|
secretValue := []byte("this is a test secret value")
|
||||||
|
|
||||||
|
// Create a secure buffer for the test value
|
||||||
|
valueBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||||
|
defer valueBuffer.Destroy()
|
||||||
|
|
||||||
// Test AddSecret
|
// Test AddSecret
|
||||||
t.Run("AddSecret", func(t *testing.T) {
|
t.Run("AddSecret", func(t *testing.T) {
|
||||||
err := vault.AddSecret(secretName, secretValue, false)
|
err := vault.AddSecret(secretName, valueBuffer, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("AddSecret failed: %v", err)
|
t.Fatalf("AddSecret failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"filippo.io/age"
|
"filippo.io/age"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/oklog/ulid/v2"
|
"github.com/oklog/ulid/v2"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
)
|
)
|
||||||
@ -120,11 +121,15 @@ func GenerateVersionName(fs afero.Fs, secretDir string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Save saves the version metadata and value
|
// Save saves the version metadata and value
|
||||||
func (sv *Version) Save(value []byte) error {
|
func (sv *Version) Save(value *memguard.LockedBuffer) error {
|
||||||
|
if value == nil {
|
||||||
|
return fmt.Errorf("value buffer is nil")
|
||||||
|
}
|
||||||
|
|
||||||
DebugWith("Saving secret version",
|
DebugWith("Saving secret version",
|
||||||
slog.String("secret_name", sv.SecretName),
|
slog.String("secret_name", sv.SecretName),
|
||||||
slog.String("version", sv.Version),
|
slog.String("version", sv.Version),
|
||||||
slog.Int("value_length", len(value)),
|
slog.Int("value_length", value.Size()),
|
||||||
)
|
)
|
||||||
|
|
||||||
fs := sv.vault.GetFilesystem()
|
fs := sv.vault.GetFilesystem()
|
||||||
@ -146,7 +151,9 @@ func (sv *Version) Save(value []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
versionPublicKey := versionIdentity.Recipient().String()
|
versionPublicKey := versionIdentity.Recipient().String()
|
||||||
versionPrivateKey := versionIdentity.String()
|
// Store private key in memguard buffer immediately
|
||||||
|
versionPrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String()))
|
||||||
|
defer versionPrivateKeyBuffer.Destroy()
|
||||||
|
|
||||||
DebugWith("Generated version keypair",
|
DebugWith("Generated version keypair",
|
||||||
slog.String("version", sv.Version),
|
slog.String("version", sv.Version),
|
||||||
@ -202,7 +209,7 @@ func (sv *Version) Save(value []byte) error {
|
|||||||
|
|
||||||
// Step 6: Encrypt the version's private key to the long-term public key
|
// Step 6: Encrypt the version's private key to the long-term public key
|
||||||
Debug("Encrypting version private key to long-term public key", "version", sv.Version)
|
Debug("Encrypting version private key to long-term public key", "version", sv.Version)
|
||||||
encryptedPrivKey, err := EncryptToRecipient([]byte(versionPrivateKey), ltRecipient)
|
encryptedPrivKey, err := EncryptToRecipient(versionPrivateKeyBuffer, ltRecipient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Debug("Failed to encrypt version private key", "error", err, "version", sv.Version)
|
Debug("Failed to encrypt version private key", "error", err, "version", sv.Version)
|
||||||
|
|
||||||
@ -228,7 +235,10 @@ func (sv *Version) Save(value []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt metadata to the version's public key
|
// Encrypt metadata to the version's public key
|
||||||
encryptedMetadata, err := EncryptToRecipient(metadataBytes, versionIdentity.Recipient())
|
metadataBuffer := memguard.NewBufferFromBytes(metadataBytes)
|
||||||
|
defer metadataBuffer.Destroy()
|
||||||
|
|
||||||
|
encryptedMetadata, err := EncryptToRecipient(metadataBuffer, versionIdentity.Recipient())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Debug("Failed to encrypt version metadata", "error", err, "version", sv.Version)
|
Debug("Failed to encrypt version metadata", "error", err, "version", sv.Version)
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ func (m *MockVersionVault) GetDirectory() (string, error) {
|
|||||||
return filepath.Join(m.stateDir, "vaults.d", m.Name), nil
|
return filepath.Join(m.stateDir, "vaults.d", m.Name), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockVersionVault) AddSecret(_ string, _ []byte, _ bool) error {
|
func (m *MockVersionVault) AddSecret(_ string, _ *memguard.LockedBuffer, _ bool) error {
|
||||||
return fmt.Errorf("not implemented in mock")
|
return fmt.Errorf("not implemented in mock")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,7 +165,9 @@ func TestSecretVersionSave(t *testing.T) {
|
|||||||
sv := NewVersion(vault, "test/secret", "20231215.001")
|
sv := NewVersion(vault, "test/secret", "20231215.001")
|
||||||
testValue := []byte("test-secret-value")
|
testValue := []byte("test-secret-value")
|
||||||
|
|
||||||
err = sv.Save(testValue)
|
testBuffer := memguard.NewBufferFromBytes(testValue)
|
||||||
|
defer testBuffer.Destroy()
|
||||||
|
err = sv.Save(testBuffer)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify files were created
|
// Verify files were created
|
||||||
@ -203,7 +205,9 @@ func TestSecretVersionLoadMetadata(t *testing.T) {
|
|||||||
sv.Metadata.NotBefore = &epochPlusOne
|
sv.Metadata.NotBefore = &epochPlusOne
|
||||||
sv.Metadata.NotAfter = &now
|
sv.Metadata.NotAfter = &now
|
||||||
|
|
||||||
err = sv.Save([]byte("test-value"))
|
testBuffer := memguard.NewBufferFromBytes([]byte("test-value"))
|
||||||
|
defer testBuffer.Destroy()
|
||||||
|
err = sv.Save(testBuffer)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create new version object and load metadata
|
// Create new version object and load metadata
|
||||||
@ -242,15 +246,19 @@ func TestSecretVersionGetValue(t *testing.T) {
|
|||||||
// Create and save a version
|
// Create and save a version
|
||||||
sv := NewVersion(vault, "test/secret", "20231215.001")
|
sv := NewVersion(vault, "test/secret", "20231215.001")
|
||||||
originalValue := []byte("test-secret-value-12345")
|
originalValue := []byte("test-secret-value-12345")
|
||||||
|
expectedValue := make([]byte, len(originalValue))
|
||||||
|
copy(expectedValue, originalValue)
|
||||||
|
|
||||||
err = sv.Save(originalValue)
|
originalBuffer := memguard.NewBufferFromBytes(originalValue)
|
||||||
|
defer originalBuffer.Destroy()
|
||||||
|
err = sv.Save(originalBuffer)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Retrieve the value
|
// Retrieve the value
|
||||||
retrievedValue, err := sv.GetValue(ltIdentity)
|
retrievedValue, err := sv.GetValue(ltIdentity)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, originalValue, retrievedValue)
|
assert.Equal(t, expectedValue, retrievedValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListVersions(t *testing.T) {
|
func TestListVersions(t *testing.T) {
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"git.eeqj.de/sneak/secret/internal/secret"
|
"git.eeqj.de/sneak/secret/internal/secret"
|
||||||
"git.eeqj.de/sneak/secret/internal/vault"
|
"git.eeqj.de/sneak/secret/internal/vault"
|
||||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -107,8 +108,13 @@ func TestVaultWithRealFilesystem(t *testing.T) {
|
|||||||
// Create a secret with a deeply nested path
|
// Create a secret with a deeply nested path
|
||||||
deepPath := "api/credentials/production/database/primary"
|
deepPath := "api/credentials/production/database/primary"
|
||||||
secretValue := []byte("supersecretdbpassword")
|
secretValue := []byte("supersecretdbpassword")
|
||||||
|
expectedValue := make([]byte, len(secretValue))
|
||||||
|
copy(expectedValue, secretValue)
|
||||||
|
|
||||||
err = vlt.AddSecret(deepPath, secretValue, false)
|
secretBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||||
|
defer secretBuffer.Destroy()
|
||||||
|
|
||||||
|
err = vlt.AddSecret(deepPath, secretBuffer, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add secret with deep path: %v", err)
|
t.Fatalf("Failed to add secret with deep path: %v", err)
|
||||||
}
|
}
|
||||||
@ -137,9 +143,9 @@ func TestVaultWithRealFilesystem(t *testing.T) {
|
|||||||
t.Fatalf("Failed to retrieve deep path secret: %v", err)
|
t.Fatalf("Failed to retrieve deep path secret: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(retrievedValue) != string(secretValue) {
|
if string(retrievedValue) != string(expectedValue) {
|
||||||
t.Errorf("Retrieved value doesn't match. Expected %q, got %q",
|
t.Errorf("Retrieved value doesn't match. Expected %q, got %q",
|
||||||
string(secretValue), string(retrievedValue))
|
string(expectedValue), string(retrievedValue))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -368,7 +374,11 @@ func TestVaultWithRealFilesystem(t *testing.T) {
|
|||||||
// Add a secret to vault1
|
// Add a secret to vault1
|
||||||
secretName := "test-secret"
|
secretName := "test-secret"
|
||||||
secretValue := []byte("secret in vault1")
|
secretValue := []byte("secret in vault1")
|
||||||
if err := vault1.AddSecret(secretName, secretValue, false); err != nil {
|
|
||||||
|
secretBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||||
|
defer secretBuffer.Destroy()
|
||||||
|
|
||||||
|
if err := vault1.AddSecret(secretName, secretBuffer, false); err != nil {
|
||||||
t.Fatalf("Failed to add secret to vault1: %v", err)
|
t.Fatalf("Failed to add secret to vault1: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,11 +29,21 @@ import (
|
|||||||
|
|
||||||
"git.eeqj.de/sneak/secret/internal/secret"
|
"git.eeqj.de/sneak/secret/internal/secret"
|
||||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Helper function to add a secret to vault with proper buffer protection
|
||||||
|
func addTestSecret(t *testing.T, vault *Vault, name string, value []byte, force bool) {
|
||||||
|
t.Helper()
|
||||||
|
buffer := memguard.NewBufferFromBytes(value)
|
||||||
|
defer buffer.Destroy()
|
||||||
|
err := vault.AddSecret(name, buffer, force)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
// TestVersionIntegrationWorkflow tests the complete version workflow
|
// TestVersionIntegrationWorkflow tests the complete version workflow
|
||||||
func TestVersionIntegrationWorkflow(t *testing.T) {
|
func TestVersionIntegrationWorkflow(t *testing.T) {
|
||||||
fs := afero.NewMemMapFs()
|
fs := afero.NewMemMapFs()
|
||||||
@ -66,8 +76,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
|
|||||||
|
|
||||||
// Step 1: Create initial version
|
// Step 1: Create initial version
|
||||||
t.Run("create_initial_version", func(t *testing.T) {
|
t.Run("create_initial_version", func(t *testing.T) {
|
||||||
err := vault.AddSecret(secretName, []byte("version-1-data"), false)
|
addTestSecret(t, vault, secretName, []byte("version-1-data"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify secret can be retrieved
|
// Verify secret can be retrieved
|
||||||
value, err := vault.GetSecret(secretName)
|
value, err := vault.GetSecret(secretName)
|
||||||
@ -108,8 +117,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
|
|||||||
firstVersionName = versions[0]
|
firstVersionName = versions[0]
|
||||||
|
|
||||||
// Create second version
|
// Create second version
|
||||||
err = vault.AddSecret(secretName, []byte("version-2-data"), true)
|
addTestSecret(t, vault, secretName, []byte("version-2-data"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify new value is current
|
// Verify new value is current
|
||||||
value, err := vault.GetSecret(secretName)
|
value, err := vault.GetSecret(secretName)
|
||||||
@ -142,8 +150,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
|
|||||||
t.Run("create_third_version", func(t *testing.T) {
|
t.Run("create_third_version", func(t *testing.T) {
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
err := vault.AddSecret(secretName, []byte("version-3-data"), true)
|
addTestSecret(t, vault, secretName, []byte("version-3-data"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify we now have three versions
|
// Verify we now have three versions
|
||||||
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
|
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
|
||||||
@ -214,8 +221,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
|
|||||||
secretDir := filepath.Join(vaultDir, "secrets.d", "limit%test", "versions")
|
secretDir := filepath.Join(vaultDir, "secrets.d", "limit%test", "versions")
|
||||||
|
|
||||||
// Create 998 versions (we already have one from the first AddSecret)
|
// Create 998 versions (we already have one from the first AddSecret)
|
||||||
err := vault.AddSecret(limitSecretName, []byte("initial"), false)
|
addTestSecret(t, vault, limitSecretName, []byte("initial"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Get today's date for consistent version names
|
// Get today's date for consistent version names
|
||||||
today := time.Now().Format("20060102")
|
today := time.Now().Format("20060102")
|
||||||
@ -255,7 +261,9 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// Try to add secret without force when it exists
|
// Try to add secret without force when it exists
|
||||||
err = vault.AddSecret(secretName, []byte("should-fail"), false)
|
failBuffer := memguard.NewBufferFromBytes([]byte("should-fail"))
|
||||||
|
defer failBuffer.Destroy()
|
||||||
|
err = vault.AddSecret(secretName, failBuffer, false)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "already exists")
|
assert.Contains(t, err.Error(), "already exists")
|
||||||
})
|
})
|
||||||
@ -272,8 +280,7 @@ func TestVersionConcurrency(t *testing.T) {
|
|||||||
secretName := "concurrent/test"
|
secretName := "concurrent/test"
|
||||||
|
|
||||||
// Create initial version
|
// Create initial version
|
||||||
err := vault.AddSecret(secretName, []byte("initial"), false)
|
addTestSecret(t, vault, secretName, []byte("initial"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Test concurrent reads
|
// Test concurrent reads
|
||||||
t.Run("concurrent_reads", func(t *testing.T) {
|
t.Run("concurrent_reads", func(t *testing.T) {
|
||||||
@ -326,8 +333,10 @@ func TestVersionCompatibility(t *testing.T) {
|
|||||||
|
|
||||||
// Create old-style encrypted value directly in secret directory
|
// Create old-style encrypted value directly in secret directory
|
||||||
testValue := []byte("legacy-value")
|
testValue := []byte("legacy-value")
|
||||||
|
testValueBuffer := memguard.NewBufferFromBytes(testValue)
|
||||||
|
defer testValueBuffer.Destroy()
|
||||||
ltRecipient := ltIdentity.Recipient()
|
ltRecipient := ltIdentity.Recipient()
|
||||||
encrypted, err := secret.EncryptToRecipient(testValue, ltRecipient)
|
encrypted, err := secret.EncryptToRecipient(testValueBuffer, ltRecipient)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
valuePath := filepath.Join(secretDir, "value.age")
|
valuePath := filepath.Join(secretDir, "value.age")
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"filippo.io/age"
|
"filippo.io/age"
|
||||||
"git.eeqj.de/sneak/secret/internal/secret"
|
"git.eeqj.de/sneak/secret/internal/secret"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -98,11 +99,15 @@ func isValidSecretName(name string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddSecret adds a secret to this vault
|
// AddSecret adds a secret to this vault
|
||||||
func (v *Vault) AddSecret(name string, value []byte, force bool) error {
|
func (v *Vault) AddSecret(name string, value *memguard.LockedBuffer, force bool) error {
|
||||||
|
if value == nil {
|
||||||
|
return fmt.Errorf("value buffer is nil")
|
||||||
|
}
|
||||||
|
|
||||||
secret.DebugWith("Adding secret to vault",
|
secret.DebugWith("Adding secret to vault",
|
||||||
slog.String("vault_name", v.Name),
|
slog.String("vault_name", v.Name),
|
||||||
slog.String("secret_name", name),
|
slog.String("secret_name", name),
|
||||||
slog.Int("value_length", len(value)),
|
slog.Int("value_length", value.Size()),
|
||||||
slog.Bool("force", force),
|
slog.Bool("force", force),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -195,7 +200,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
|
|||||||
// We'll update the previous version's notAfter after we save the new version
|
// We'll update the previous version's notAfter after we save the new version
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the new version
|
// Save the new version - pass the LockedBuffer directly
|
||||||
if err := newVersion.Save(value); err != nil {
|
if err := newVersion.Save(value); err != nil {
|
||||||
secret.Debug("Failed to save new version", "error", err, "version", versionName)
|
secret.Debug("Failed to save new version", "error", err, "version", versionName)
|
||||||
|
|
||||||
@ -272,7 +277,10 @@ func updateVersionMetadata(fs afero.Fs, version *secret.Version, ltIdentity *age
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt metadata to the version's public key
|
// Encrypt metadata to the version's public key
|
||||||
encryptedMetadata, err := secret.EncryptToRecipient(metadataBytes, versionIdentity.Recipient())
|
metadataBuffer := memguard.NewBufferFromBytes(metadataBytes)
|
||||||
|
defer metadataBuffer.Destroy()
|
||||||
|
|
||||||
|
encryptedMetadata, err := secret.EncryptToRecipient(metadataBuffer, versionIdentity.Recipient())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to encrypt version metadata: %w", err)
|
return fmt.Errorf("failed to encrypt version metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -24,11 +24,21 @@ import (
|
|||||||
|
|
||||||
"git.eeqj.de/sneak/secret/internal/secret"
|
"git.eeqj.de/sneak/secret/internal/secret"
|
||||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Helper function to add a secret to vault with proper buffer protection
|
||||||
|
func addTestSecretToVault(t *testing.T, vault *Vault, name string, value []byte, force bool) {
|
||||||
|
t.Helper()
|
||||||
|
buffer := memguard.NewBufferFromBytes(value)
|
||||||
|
defer buffer.Destroy()
|
||||||
|
err := vault.AddSecret(name, buffer, force)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper function to create a vault with long-term key set up
|
// Helper function to create a vault with long-term key set up
|
||||||
func createTestVaultWithKey(t *testing.T, fs afero.Fs, stateDir, vaultName string) *Vault {
|
func createTestVaultWithKey(t *testing.T, fs afero.Fs, stateDir, vaultName string) *Vault {
|
||||||
// Set mnemonic for testing
|
// Set mnemonic for testing
|
||||||
@ -65,9 +75,10 @@ func TestVaultAddSecretCreatesVersion(t *testing.T) {
|
|||||||
// Add a secret
|
// Add a secret
|
||||||
secretName := "test/secret"
|
secretName := "test/secret"
|
||||||
secretValue := []byte("initial-value")
|
secretValue := []byte("initial-value")
|
||||||
|
expectedValue := make([]byte, len(secretValue))
|
||||||
|
copy(expectedValue, secretValue)
|
||||||
|
|
||||||
err := vault.AddSecret(secretName, secretValue, false)
|
addTestSecretToVault(t, vault, secretName, secretValue, false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Check that version directory was created
|
// Check that version directory was created
|
||||||
vaultDir, _ := vault.GetDirectory()
|
vaultDir, _ := vault.GetDirectory()
|
||||||
@ -88,7 +99,7 @@ func TestVaultAddSecretCreatesVersion(t *testing.T) {
|
|||||||
// Get the secret value
|
// Get the secret value
|
||||||
retrievedValue, err := vault.GetSecret(secretName)
|
retrievedValue, err := vault.GetSecret(secretName)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, secretValue, retrievedValue)
|
assert.Equal(t, expectedValue, retrievedValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVaultAddSecretMultipleVersions(t *testing.T) {
|
func TestVaultAddSecretMultipleVersions(t *testing.T) {
|
||||||
@ -101,17 +112,17 @@ func TestVaultAddSecretMultipleVersions(t *testing.T) {
|
|||||||
secretName := "test/secret"
|
secretName := "test/secret"
|
||||||
|
|
||||||
// Add first version
|
// Add first version
|
||||||
err := vault.AddSecret(secretName, []byte("version-1"), false)
|
addTestSecretToVault(t, vault, secretName, []byte("version-1"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Try to add again without force - should fail
|
// Try to add again without force - should fail
|
||||||
err = vault.AddSecret(secretName, []byte("version-2"), false)
|
failBuffer := memguard.NewBufferFromBytes([]byte("version-2"))
|
||||||
|
defer failBuffer.Destroy()
|
||||||
|
err := vault.AddSecret(secretName, failBuffer, false)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "already exists")
|
assert.Contains(t, err.Error(), "already exists")
|
||||||
|
|
||||||
// Add with force - should create new version
|
// Add with force - should create new version
|
||||||
err = vault.AddSecret(secretName, []byte("version-2"), true)
|
addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Check that we have two versions
|
// Check that we have two versions
|
||||||
vaultDir, _ := vault.GetDirectory()
|
vaultDir, _ := vault.GetDirectory()
|
||||||
@ -136,14 +147,12 @@ func TestVaultGetSecretVersion(t *testing.T) {
|
|||||||
secretName := "test/secret"
|
secretName := "test/secret"
|
||||||
|
|
||||||
// Add multiple versions
|
// Add multiple versions
|
||||||
err := vault.AddSecret(secretName, []byte("version-1"), false)
|
addTestSecretToVault(t, vault, secretName, []byte("version-1"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Small delay to ensure different version names
|
// Small delay to ensure different version names
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
err = vault.AddSecret(secretName, []byte("version-2"), true)
|
addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Get versions list
|
// Get versions list
|
||||||
vaultDir, _ := vault.GetDirectory()
|
vaultDir, _ := vault.GetDirectory()
|
||||||
@ -185,7 +194,9 @@ func TestVaultVersionTimestamps(t *testing.T) {
|
|||||||
|
|
||||||
// Add first version
|
// Add first version
|
||||||
beforeFirst := time.Now()
|
beforeFirst := time.Now()
|
||||||
err = vault.AddSecret(secretName, []byte("version-1"), false)
|
v1Buffer := memguard.NewBufferFromBytes([]byte("version-1"))
|
||||||
|
defer v1Buffer.Destroy()
|
||||||
|
err = vault.AddSecret(secretName, v1Buffer, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
afterFirst := time.Now()
|
afterFirst := time.Now()
|
||||||
|
|
||||||
@ -212,8 +223,7 @@ func TestVaultVersionTimestamps(t *testing.T) {
|
|||||||
// Add second version
|
// Add second version
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
beforeSecond := time.Now()
|
beforeSecond := time.Now()
|
||||||
err = vault.AddSecret(secretName, []byte("version-2"), true)
|
addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
|
||||||
require.NoError(t, err)
|
|
||||||
afterSecond := time.Now()
|
afterSecond := time.Now()
|
||||||
|
|
||||||
// Get updated versions
|
// Get updated versions
|
||||||
@ -249,11 +259,10 @@ func TestVaultGetNonExistentVersion(t *testing.T) {
|
|||||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||||
|
|
||||||
// Add a secret
|
// Add a secret
|
||||||
err := vault.AddSecret("test/secret", []byte("value"), false)
|
addTestSecretToVault(t, vault, "test/secret", []byte("value"), false)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Try to get non-existent version
|
// Try to get non-existent version
|
||||||
_, err = vault.GetSecretVersion("test/secret", "20991231.999")
|
_, err := vault.GetSecretVersion("test/secret", "20991231.999")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "not found")
|
assert.Contains(t, err.Error(), "not found")
|
||||||
}
|
}
|
||||||
@ -281,7 +290,9 @@ func TestUpdateVersionMetadata(t *testing.T) {
|
|||||||
version.Metadata.NotAfter = nil
|
version.Metadata.NotAfter = nil
|
||||||
|
|
||||||
// Save version
|
// Save version
|
||||||
err = version.Save([]byte("test-value"))
|
testBuffer := memguard.NewBufferFromBytes([]byte("test-value"))
|
||||||
|
defer testBuffer.Destroy()
|
||||||
|
err = version.Save(testBuffer)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Update metadata
|
// Update metadata
|
||||||
|
@ -346,10 +346,7 @@ func (v *Vault) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*se
|
|||||||
|
|
||||||
// Encrypt private key with passphrase
|
// Encrypt private key with passphrase
|
||||||
privKeyStr := unlockerIdentity.String()
|
privKeyStr := unlockerIdentity.String()
|
||||||
privKeyBuffer := memguard.NewBufferFromBytes([]byte(privKeyStr))
|
encryptedPrivKey, err := secret.EncryptWithPassphrase([]byte(privKeyStr), passphrase)
|
||||||
defer privKeyBuffer.Destroy()
|
|
||||||
|
|
||||||
encryptedPrivKey, err := secret.EncryptWithPassphrase(privKeyBuffer.Bytes(), passphrase)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to encrypt unlocker private key: %w", err)
|
return nil, fmt.Errorf("failed to encrypt unlocker private key: %w", err)
|
||||||
}
|
}
|
||||||
@ -385,8 +382,10 @@ func (v *Vault) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*se
|
|||||||
return nil, fmt.Errorf("failed to get long-term key: %w", err)
|
return nil, fmt.Errorf("failed to get long-term key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ltPrivKey := []byte(ltIdentity.String())
|
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
|
||||||
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKey, unlockerIdentity.Recipient())
|
defer ltPrivKeyBuffer.Destroy()
|
||||||
|
|
||||||
|
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerIdentity.Recipient())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to encrypt long-term private key: %w", err)
|
return nil, fmt.Errorf("failed to encrypt long-term private key: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -122,8 +122,13 @@ func TestVaultOperations(t *testing.T) {
|
|||||||
// Now add a secret
|
// Now add a secret
|
||||||
secretName := "test/secret"
|
secretName := "test/secret"
|
||||||
secretValue := []byte("test-secret-value")
|
secretValue := []byte("test-secret-value")
|
||||||
|
expectedValue := make([]byte, len(secretValue))
|
||||||
|
copy(expectedValue, secretValue)
|
||||||
|
|
||||||
err = vlt.AddSecret(secretName, secretValue, false)
|
secretBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||||
|
defer secretBuffer.Destroy()
|
||||||
|
|
||||||
|
err = vlt.AddSecret(secretName, secretBuffer, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to add secret: %v", err)
|
t.Fatalf("Failed to add secret: %v", err)
|
||||||
}
|
}
|
||||||
@ -152,8 +157,8 @@ func TestVaultOperations(t *testing.T) {
|
|||||||
t.Fatalf("Failed to get secret: %v", err)
|
t.Fatalf("Failed to get secret: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(retrievedValue) != string(secretValue) {
|
if string(retrievedValue) != string(expectedValue) {
|
||||||
t.Errorf("Expected secret value '%s', got '%s'", string(secretValue), string(retrievedValue))
|
t.Errorf("Expected secret value '%s', got '%s'", string(expectedValue), string(retrievedValue))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user