Compare commits

..

1 Commits

Author SHA1 Message Date
clawbot
79ae572cc3 test: add test for getLongTermPrivateKey derivation index
Verifies that getLongTermPrivateKey reads the derivation index from
vault metadata instead of using hardcoded index 0. Test creates a
mock vault with DerivationIndex=5 and confirms the derived key
matches index 5.
2026-02-08 17:45:34 -08:00
6 changed files with 110 additions and 95 deletions

View File

@ -7,10 +7,12 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"filippo.io/age"
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard" "github.com/awnumar/memguard"
"github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39" "github.com/tyler-smith/go-bip39"
) )
@ -152,8 +154,35 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
return fmt.Errorf("failed to create unlocker: %w", err) return fmt.Errorf("failed to create unlocker: %w", err)
} }
// Note: CreatePassphraseUnlocker already encrypts and writes the long-term // Encrypt long-term private key to the unlocker
// private key to longterm.age, so no need to do it again here. unlockerDir := passphraseUnlocker.GetDirectory()
// Read unlocker public key
unlockerPubKeyData, err := afero.ReadFile(cli.fs, filepath.Join(unlockerDir, "pub.age"))
if err != nil {
return fmt.Errorf("failed to read unlocker public key: %w", err)
}
unlockerRecipient, err := age.ParseX25519Recipient(string(unlockerPubKeyData))
if err != nil {
return fmt.Errorf("failed to parse unlocker public key: %w", err)
}
// Encrypt long-term private key to unlocker
// Use memguard to protect the private key in memory
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
defer ltPrivKeyBuffer.Destroy()
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerRecipient)
if err != nil {
return fmt.Errorf("failed to encrypt long-term private key: %w", err)
}
// Write encrypted long-term private key
ltPrivKeyPath := filepath.Join(unlockerDir, "longterm.age")
if err := afero.WriteFile(cli.fs, ltPrivKeyPath, encryptedLtPrivKey, secret.FilePerms); err != nil {
return fmt.Errorf("failed to write encrypted long-term private key: %w", err)
}
if cmd != nil { if cmd != nil {
cmd.Printf("\nDefault vault created and configured\n") cmd.Printf("\nDefault vault created and configured\n")

View File

@ -68,11 +68,6 @@ func DecryptWithIdentity(data []byte, identity age.Identity) (*memguard.LockedBu
// Create a secure buffer for the decrypted data // Create a secure buffer for the decrypted data
resultBuffer := memguard.NewBufferFromBytes(result) resultBuffer := memguard.NewBufferFromBytes(result)
// Zero out the original slice to prevent plaintext from lingering in unprotected memory
for i := range result {
result[i] = 0
}
return resultBuffer, nil return resultBuffer, nil
} }

View File

@ -2,67 +2,81 @@ package secret
import ( import (
"encoding/json" "encoding/json"
"path/filepath"
"testing" "testing"
"time" "time"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard" "github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"git.eeqj.de/sneak/secret/pkg/agehd"
) )
// mockVault implements VaultInterface for testing getLongTermPrivateKey // realVault is a minimal VaultInterface backed by a real afero filesystem,
type mockVaultForDerivation struct { // using the same directory layout as vault.Vault.
dir string type realVault struct {
name string
stateDir string
fs afero.Fs fs afero.Fs
} }
func (m *mockVaultForDerivation) GetDirectory() (string, error) { return m.dir, nil } func (v *realVault) GetDirectory() (string, error) {
func (m *mockVaultForDerivation) GetName() string { return "test-vault" } return filepath.Join(v.stateDir, "vaults.d", v.name), nil
func (m *mockVaultForDerivation) GetFilesystem() afero.Fs { return m.fs }
// Stubs for unused interface methods
func (m *mockVaultForDerivation) AddSecret(name string, value *memguard.LockedBuffer, force bool) error {
return nil
} }
func (m *mockVaultForDerivation) GetCurrentUnlocker() (Unlocker, error) { return nil, nil } func (v *realVault) GetName() string { return v.name }
func (m *mockVaultForDerivation) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*PassphraseUnlocker, error) { func (v *realVault) GetFilesystem() afero.Fs { return v.fs }
return nil, nil
// Unused by getLongTermPrivateKey — these satisfy VaultInterface.
func (v *realVault) AddSecret(string, *memguard.LockedBuffer, bool) error { panic("not used") }
func (v *realVault) GetCurrentUnlocker() (Unlocker, error) { panic("not used") }
func (v *realVault) CreatePassphraseUnlocker(*memguard.LockedBuffer) (*PassphraseUnlocker, error) {
panic("not used")
}
// createRealVault sets up a complete vault directory structure on an in-memory
// filesystem, identical to what vault.CreateVault produces.
func createRealVault(t *testing.T, fs afero.Fs, stateDir, name string, derivationIndex uint32) *realVault {
t.Helper()
vaultDir := filepath.Join(stateDir, "vaults.d", name)
require.NoError(t, fs.MkdirAll(filepath.Join(vaultDir, "secrets.d"), DirPerms))
require.NoError(t, fs.MkdirAll(filepath.Join(vaultDir, "unlockers.d"), DirPerms))
metadata := VaultMetadata{
CreatedAt: time.Now(),
DerivationIndex: derivationIndex,
}
metaBytes, err := json.Marshal(metadata)
require.NoError(t, err)
require.NoError(t, afero.WriteFile(fs, filepath.Join(vaultDir, "vault-metadata.json"), metaBytes, FilePerms))
return &realVault{name: name, stateDir: stateDir, fs: fs}
} }
func TestGetLongTermPrivateKeyUsesVaultDerivationIndex(t *testing.T) { func TestGetLongTermPrivateKeyUsesVaultDerivationIndex(t *testing.T) {
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" const testMnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
// Derive keys at index 0 and index 5 — they must differ // Derive expected keys at two different indices to prove they differ.
key0, err := agehd.DeriveIdentity(testMnemonic, 0) key0, err := agehd.DeriveIdentity(testMnemonic, 0)
require.NoError(t, err) require.NoError(t, err)
key5, err := agehd.DeriveIdentity(testMnemonic, 5) key5, err := agehd.DeriveIdentity(testMnemonic, 5)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, key0.String(), key5.String(), "different derivation indices must produce different keys") require.NotEqual(t, key0.String(), key5.String(),
"sanity check: different derivation indices must produce different keys")
// Set up in-memory filesystem with vault metadata at index 5 // Build a real vault with DerivationIndex=5 on an in-memory filesystem.
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
vaultDir := "/test-vault" vault := createRealVault(t, fs, "/state", "test-vault", 5)
require.NoError(t, fs.MkdirAll(vaultDir, 0o700))
metadata := VaultMetadata{
CreatedAt: time.Now(),
DerivationIndex: 5,
}
metaBytes, err := json.Marshal(metadata)
require.NoError(t, err)
require.NoError(t, afero.WriteFile(fs, vaultDir+"/vault-metadata.json", metaBytes, 0o600))
// Set the mnemonic env var
t.Setenv(EnvMnemonic, testMnemonic) t.Setenv(EnvMnemonic, testMnemonic)
vault := &mockVaultForDerivation{dir: vaultDir, fs: fs}
result, err := getLongTermPrivateKey(fs, vault) result, err := getLongTermPrivateKey(fs, vault)
require.NoError(t, err) require.NoError(t, err)
defer result.Destroy() defer result.Destroy()
// The derived key should match index 5, NOT index 0 assert.Equal(t, key5.String(), string(result.Bytes()),
assert.Equal(t, key5.String(), string(result.Bytes()), "getLongTermPrivateKey should use vault's derivation index (5), not hardcoded 0") "getLongTermPrivateKey should derive at vault's DerivationIndex (5)")
assert.NotEqual(t, key0.String(), string(result.Bytes()), "getLongTermPrivateKey must not use hardcoded index 0") assert.NotEqual(t, key0.String(), string(result.Bytes()),
"getLongTermPrivateKey must not use hardcoded index 0")
} }

View File

@ -4,8 +4,6 @@
package secret package secret
import ( import (
"fmt"
"filippo.io/age" "filippo.io/age"
"github.com/awnumar/memguard" "github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -24,59 +22,52 @@ type KeychainUnlocker struct {
fs afero.Fs fs afero.Fs
} }
var errKeychainNotSupported = fmt.Errorf("keychain unlockers are only supported on macOS") // GetIdentity panics on non-Darwin platforms
// GetIdentity returns an error on non-Darwin platforms
func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) { func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
return nil, errKeychainNotSupported panic("keychain unlockers are only supported on macOS")
} }
// GetType returns the unlocker type // GetType panics on non-Darwin platforms
func (k *KeychainUnlocker) GetType() string { func (k *KeychainUnlocker) GetType() string {
return "keychain" panic("keychain unlockers are only supported on macOS")
} }
// GetMetadata returns the unlocker metadata // GetMetadata panics on non-Darwin platforms
func (k *KeychainUnlocker) GetMetadata() UnlockerMetadata { func (k *KeychainUnlocker) GetMetadata() UnlockerMetadata {
return k.Metadata panic("keychain unlockers are only supported on macOS")
} }
// GetDirectory returns the unlocker directory // GetDirectory panics on non-Darwin platforms
func (k *KeychainUnlocker) GetDirectory() string { func (k *KeychainUnlocker) GetDirectory() string {
return k.Directory panic("keychain unlockers are only supported on macOS")
} }
// GetID returns the unlocker ID // GetID returns the unlocker ID
func (k *KeychainUnlocker) GetID() string { func (k *KeychainUnlocker) GetID() string {
return fmt.Sprintf("%s-keychain", k.Metadata.CreatedAt.Format("2006-01-02.15.04")) panic("keychain unlockers are only supported on macOS")
} }
// GetKeychainItemName returns an error on non-Darwin platforms // GetKeychainItemName panics on non-Darwin platforms
func (k *KeychainUnlocker) GetKeychainItemName() (string, error) { func (k *KeychainUnlocker) GetKeychainItemName() (string, error) {
return "", errKeychainNotSupported panic("keychain unlockers are only supported on macOS")
} }
// Remove returns an error on non-Darwin platforms // Remove panics on non-Darwin platforms
func (k *KeychainUnlocker) Remove() error { func (k *KeychainUnlocker) Remove() error {
return errKeychainNotSupported panic("keychain unlockers are only supported on macOS")
} }
// NewKeychainUnlocker creates a stub KeychainUnlocker on non-Darwin platforms. // NewKeychainUnlocker panics on non-Darwin platforms
// The returned instance's methods that require macOS functionality will return errors.
func NewKeychainUnlocker(fs afero.Fs, directory string, metadata UnlockerMetadata) *KeychainUnlocker { func NewKeychainUnlocker(fs afero.Fs, directory string, metadata UnlockerMetadata) *KeychainUnlocker {
return &KeychainUnlocker{ panic("keychain unlockers are only supported on macOS")
Directory: directory,
Metadata: metadata,
fs: fs,
}
} }
// CreateKeychainUnlocker returns an error on non-Darwin platforms // CreateKeychainUnlocker panics on non-Darwin platforms
func CreateKeychainUnlocker(_ afero.Fs, _ string) (*KeychainUnlocker, error) { func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, error) {
return nil, errKeychainNotSupported panic("keychain unlockers are only supported on macOS")
} }
// getLongTermPrivateKey returns an error on non-Darwin platforms // getLongTermPrivateKey panics on non-Darwin platforms
func getLongTermPrivateKey(_ afero.Fs, _ VaultInterface) (*memguard.LockedBuffer, error) { func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) (*memguard.LockedBuffer, error) {
return nil, errKeychainNotSupported panic("keychain unlockers are only supported on macOS")
} }

View File

@ -227,23 +227,27 @@ func (v *Vault) NumSecrets() (int, error) {
return 0, fmt.Errorf("failed to read secrets directory: %w", err) return 0, fmt.Errorf("failed to read secrets directory: %w", err)
} }
// Count only directories that have a "current" version pointer file // Count only directories that contain at least one version file
count := 0 count := 0
for _, entry := range entries { for _, entry := range entries {
if !entry.IsDir() { if !entry.IsDir() {
continue continue
} }
// A valid secret has a "current" file pointing to the active version // Check if this secret directory contains any version files
secretDir := filepath.Join(secretsDir, entry.Name()) secretDir := filepath.Join(secretsDir, entry.Name())
currentFile := filepath.Join(secretDir, "current") versionFiles, err := afero.ReadDir(v.fs, secretDir)
exists, err := afero.Exists(v.fs, currentFile)
if err != nil { if err != nil {
continue // Skip directories we can't read continue // Skip directories we can't read
} }
if exists { // Look for at least one version file (excluding "current" symlink)
for _, vFile := range versionFiles {
if !vFile.IsDir() && vFile.Name() != "current" {
count++ count++
break // Found at least one version, count this secret
}
} }
} }

View File

@ -162,24 +162,6 @@ func TestVaultOperations(t *testing.T) {
} }
}) })
// Test NumSecrets
t.Run("NumSecrets", func(t *testing.T) {
vlt, err := GetCurrentVault(fs, stateDir)
if err != nil {
t.Fatalf("Failed to get current vault: %v", err)
}
numSecrets, err := vlt.NumSecrets()
if err != nil {
t.Fatalf("Failed to count secrets: %v", err)
}
// We added one secret in SecretOperations
if numSecrets != 1 {
t.Errorf("Expected 1 secret, got %d", numSecrets)
}
})
// Test unlocker operations // Test unlocker operations
t.Run("UnlockerOperations", func(t *testing.T) { t.Run("UnlockerOperations", func(t *testing.T) {
vlt, err := GetCurrentVault(fs, stateDir) vlt, err := GetCurrentVault(fs, stateDir)