NumSecrets() previously counted only non-directory entries other than 'current' directly under each secret directory. Because the only children there are 'current' (a file, which was excluded) and 'versions' (a directory, which was excluded), it always returned 0. It now checks for the existence of the 'current' file, which is the canonical indicator that a secret exists and has an active version. This fixes the safety check in UnlockersRemove, which previously always allowed removal of the last unlocker.
246 lines · 6.1 KiB · Go
package vault
|
|
|
|
import (
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"git.eeqj.de/sneak/secret/internal/secret"
|
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
|
"github.com/awnumar/memguard"
|
|
"github.com/spf13/afero"
|
|
)
|
|
|
|
func TestVaultOperations(t *testing.T) {
|
|
// Test environment will be cleaned up automatically by t.Setenv
|
|
|
|
// Set test environment variables
|
|
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
|
t.Setenv(secret.EnvMnemonic, testMnemonic)
|
|
t.Setenv(secret.EnvUnlockPassphrase, "test-passphrase")
|
|
|
|
// Use in-memory filesystem
|
|
fs := afero.NewMemMapFs()
|
|
stateDir := "/test/state"
|
|
|
|
// Test vault creation
|
|
t.Run("CreateVault", func(t *testing.T) {
|
|
vlt, err := CreateVault(fs, stateDir, "test-vault")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault: %v", err)
|
|
}
|
|
|
|
if vlt.GetName() != "test-vault" {
|
|
t.Errorf("Expected vault name 'test-vault', got '%s'", vlt.GetName())
|
|
}
|
|
|
|
// Check vault directory exists
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
exists, err := afero.DirExists(fs, vaultDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to check vault directory: %v", err)
|
|
}
|
|
|
|
if !exists {
|
|
t.Errorf("Vault directory should exist")
|
|
}
|
|
})
|
|
|
|
// Test vault listing
|
|
t.Run("ListVaults", func(t *testing.T) {
|
|
vaults, err := ListVaults(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to list vaults: %v", err)
|
|
}
|
|
|
|
found := false
|
|
for _, vault := range vaults {
|
|
if vault == "test-vault" {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("Expected to find 'test-vault' in vault list")
|
|
}
|
|
})
|
|
|
|
// Test vault selection
|
|
t.Run("SelectVault", func(t *testing.T) {
|
|
err := SelectVault(fs, stateDir, "test-vault")
|
|
if err != nil {
|
|
t.Fatalf("Failed to select vault: %v", err)
|
|
}
|
|
|
|
// Test getting current vault
|
|
currentVault, err := GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault: %v", err)
|
|
}
|
|
|
|
if currentVault.GetName() != "test-vault" {
|
|
t.Errorf("Expected current vault 'test-vault', got '%s'", currentVault.GetName())
|
|
}
|
|
})
|
|
|
|
// Test secret operations
|
|
t.Run("SecretOperations", func(t *testing.T) {
|
|
vlt, err := GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault: %v", err)
|
|
}
|
|
|
|
// First, derive the long-term key from the test mnemonic
|
|
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("Failed to derive long-term key: %v", err)
|
|
}
|
|
|
|
// Get the public key from the derived identity
|
|
ltPublicKey := ltIdentity.Recipient().String()
|
|
|
|
// Get the vault directory
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
// Write the correct public key to the pub.age file
|
|
pubKeyPath := filepath.Join(vaultDir, "pub.age")
|
|
err = afero.WriteFile(fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms)
|
|
if err != nil {
|
|
t.Fatalf("Failed to write long-term public key: %v", err)
|
|
}
|
|
|
|
// Unlock the vault with the derived identity
|
|
vlt.Unlock(ltIdentity)
|
|
|
|
// Now add a secret
|
|
secretName := "test/secret"
|
|
secretValue := []byte("test-secret-value")
|
|
expectedValue := make([]byte, len(secretValue))
|
|
copy(expectedValue, secretValue)
|
|
|
|
secretBuffer := memguard.NewBufferFromBytes(secretValue)
|
|
defer secretBuffer.Destroy()
|
|
|
|
err = vlt.AddSecret(secretName, secretBuffer, false)
|
|
if err != nil {
|
|
t.Fatalf("Failed to add secret: %v", err)
|
|
}
|
|
|
|
// List secrets
|
|
secrets, err := vlt.ListSecrets()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list secrets: %v", err)
|
|
}
|
|
|
|
found := false
|
|
for _, secret := range secrets {
|
|
if secret == secretName {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("Expected to find secret '%s' in list", secretName)
|
|
}
|
|
|
|
// Get secret value
|
|
retrievedValue, err := vlt.GetSecret(secretName)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get secret: %v", err)
|
|
}
|
|
|
|
if string(retrievedValue) != string(expectedValue) {
|
|
t.Errorf("Expected secret value '%s', got '%s'", string(expectedValue), string(retrievedValue))
|
|
}
|
|
})
|
|
|
|
// Test NumSecrets
|
|
t.Run("NumSecrets", func(t *testing.T) {
|
|
vlt, err := GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault: %v", err)
|
|
}
|
|
|
|
numSecrets, err := vlt.NumSecrets()
|
|
if err != nil {
|
|
t.Fatalf("Failed to count secrets: %v", err)
|
|
}
|
|
|
|
// We added one secret in SecretOperations
|
|
if numSecrets != 1 {
|
|
t.Errorf("Expected 1 secret, got %d", numSecrets)
|
|
}
|
|
})
|
|
|
|
// Test unlocker operations
|
|
t.Run("UnlockerOperations", func(t *testing.T) {
|
|
vlt, err := GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault: %v", err)
|
|
}
|
|
|
|
// Test vault unlocking (should happen automatically via mnemonic)
|
|
if vlt.Locked() {
|
|
_, err := vlt.UnlockVault()
|
|
if err != nil {
|
|
t.Fatalf("Failed to unlock vault: %v", err)
|
|
}
|
|
}
|
|
|
|
// Create a passphrase unlocker
|
|
passphraseBuffer := memguard.NewBufferFromBytes([]byte("test-passphrase"))
|
|
defer passphraseBuffer.Destroy()
|
|
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create passphrase unlocker: %v", err)
|
|
}
|
|
|
|
// List unlockers
|
|
unlockers, err := vlt.ListUnlockers()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list unlockers: %v", err)
|
|
}
|
|
|
|
if len(unlockers) == 0 {
|
|
t.Errorf("Expected at least one unlocker")
|
|
}
|
|
|
|
// Check key type
|
|
keyFound := false
|
|
for _, key := range unlockers {
|
|
if key.Type == "passphrase" {
|
|
keyFound = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !keyFound {
|
|
t.Errorf("Expected to find passphrase unlocker")
|
|
}
|
|
|
|
// Test selecting unlocker
|
|
err = vlt.SelectUnlocker(passphraseUnlocker.GetID())
|
|
if err != nil {
|
|
t.Fatalf("Failed to select unlocker: %v", err)
|
|
}
|
|
|
|
// Test getting current unlocker
|
|
currentUnlocker, err := vlt.GetCurrentUnlocker()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current unlocker: %v", err)
|
|
}
|
|
|
|
if currentUnlocker.GetID() != passphraseUnlocker.GetID() {
|
|
t.Errorf("Expected current unlocker ID '%s', got '%s'", passphraseUnlocker.GetID(), currentUnlocker.GetID())
|
|
}
|
|
})
|
|
}
|