secret/internal/vault/vault_test.go
clawbot 341428d9ca fix: NumSecrets() now correctly counts secrets by checking for current file
NumSecrets() previously looked for non-directory, non-'current' files
directly under each secret directory, but the only children are
'current' (file, excluded) and 'versions' (directory, excluded),
so it always returned 0.

Now checks for the existence of the 'current' file, which is the
canonical indicator that a secret exists and has an active version.

This fixes the safety check in UnlockersRemove that was always
allowing removal of the last unlocker.
2026-02-08 12:04:15 -08:00

246 lines
6.1 KiB
Go

package vault
import (
"path/filepath"
"testing"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero"
)
// TestVaultOperations exercises the vault lifecycle end to end on an
// in-memory filesystem: creating, listing and selecting a vault, storing and
// reading back a secret, counting secrets, and creating/selecting unlockers.
//
// The subtests share one filesystem and state directory and run in order;
// each relies on state left behind by the earlier ones (for example,
// NumSecrets expects exactly the one secret added by SecretOperations).
func TestVaultOperations(t *testing.T) {
	// t.Setenv restores the previous values automatically when the test ends.
	testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
	t.Setenv(secret.EnvMnemonic, testMnemonic)
	t.Setenv(secret.EnvUnlockPassphrase, "test-passphrase")

	// Use an in-memory filesystem so the test never touches the real disk.
	fs := afero.NewMemMapFs()
	stateDir := "/test/state"

	t.Run("CreateVault", func(t *testing.T) {
		vlt, err := CreateVault(fs, stateDir, "test-vault")
		if err != nil {
			t.Fatalf("Failed to create vault: %v", err)
		}

		if vlt.GetName() != "test-vault" {
			t.Errorf("Expected vault name 'test-vault', got '%s'", vlt.GetName())
		}

		// The vault's directory must exist on the filesystem after creation.
		vaultDir, err := vlt.GetDirectory()
		if err != nil {
			t.Fatalf("Failed to get vault directory: %v", err)
		}

		exists, err := afero.DirExists(fs, vaultDir)
		if err != nil {
			t.Fatalf("Failed to check vault directory: %v", err)
		}
		if !exists {
			t.Errorf("Vault directory should exist")
		}
	})

	t.Run("ListVaults", func(t *testing.T) {
		vaults, err := ListVaults(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to list vaults: %v", err)
		}

		found := false
		for _, name := range vaults {
			if name == "test-vault" {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected to find 'test-vault' in vault list")
		}
	})

	t.Run("SelectVault", func(t *testing.T) {
		err := SelectVault(fs, stateDir, "test-vault")
		if err != nil {
			t.Fatalf("Failed to select vault: %v", err)
		}

		// The selected vault must now be reported as the current one.
		currentVault, err := GetCurrentVault(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to get current vault: %v", err)
		}
		if currentVault.GetName() != "test-vault" {
			t.Errorf("Expected current vault 'test-vault', got '%s'", currentVault.GetName())
		}
	})

	t.Run("SecretOperations", func(t *testing.T) {
		vlt, err := GetCurrentVault(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to get current vault: %v", err)
		}

		// Derive the long-term identity (index 0) from the test mnemonic.
		ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
		if err != nil {
			t.Fatalf("Failed to derive long-term key: %v", err)
		}
		ltPublicKey := ltIdentity.Recipient().String()

		vaultDir, err := vlt.GetDirectory()
		if err != nil {
			t.Fatalf("Failed to get vault directory: %v", err)
		}

		// Write the public half of ltIdentity to pub.age so the vault's
		// stored public key matches the identity used to unlock it below.
		pubKeyPath := filepath.Join(vaultDir, "pub.age")
		err = afero.WriteFile(fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms)
		if err != nil {
			t.Fatalf("Failed to write long-term public key: %v", err)
		}

		vlt.Unlock(ltIdentity)

		secretName := "test/secret"
		secretValue := []byte("test-secret-value")

		// memguard.NewBufferFromBytes wipes its source slice, so keep an
		// independent copy of the plaintext for the comparison at the end.
		expectedValue := make([]byte, len(secretValue))
		copy(expectedValue, secretValue)

		secretBuffer := memguard.NewBufferFromBytes(secretValue)
		defer secretBuffer.Destroy()

		err = vlt.AddSecret(secretName, secretBuffer, false)
		if err != nil {
			t.Fatalf("Failed to add secret: %v", err)
		}

		secrets, err := vlt.ListSecrets()
		if err != nil {
			t.Fatalf("Failed to list secrets: %v", err)
		}

		// The loop variable is deliberately not called "secret": that name
		// would shadow the imported secret package.
		found := false
		for _, name := range secrets {
			if name == secretName {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected to find secret '%s' in list", secretName)
		}

		retrievedValue, err := vlt.GetSecret(secretName)
		if err != nil {
			t.Fatalf("Failed to get secret: %v", err)
		}
		if string(retrievedValue) != string(expectedValue) {
			t.Errorf("Expected secret value '%s', got '%s'", string(expectedValue), string(retrievedValue))
		}
	})

	// Regression check for NumSecrets, counted on a freshly loaded vault
	// handle rather than the one used to add the secret.
	t.Run("NumSecrets", func(t *testing.T) {
		vlt, err := GetCurrentVault(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to get current vault: %v", err)
		}

		numSecrets, err := vlt.NumSecrets()
		if err != nil {
			t.Fatalf("Failed to count secrets: %v", err)
		}

		// We added one secret in SecretOperations
		if numSecrets != 1 {
			t.Errorf("Expected 1 secret, got %d", numSecrets)
		}
	})

	t.Run("UnlockerOperations", func(t *testing.T) {
		vlt, err := GetCurrentVault(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to get current vault: %v", err)
		}

		// Unlock the handle if needed. NOTE(review): UnlockVault is expected
		// to use the mnemonic set via t.Setenv above — confirm.
		// Assign to the existing err rather than shadowing it with :=.
		if vlt.Locked() {
			if _, err = vlt.UnlockVault(); err != nil {
				t.Fatalf("Failed to unlock vault: %v", err)
			}
		}

		passphraseBuffer := memguard.NewBufferFromBytes([]byte("test-passphrase"))
		defer passphraseBuffer.Destroy()

		passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
		if err != nil {
			t.Fatalf("Failed to create passphrase unlocker: %v", err)
		}

		unlockers, err := vlt.ListUnlockers()
		if err != nil {
			t.Fatalf("Failed to list unlockers: %v", err)
		}
		if len(unlockers) == 0 {
			t.Errorf("Expected at least one unlocker")
		}

		// At least one listed unlocker must be of the passphrase type.
		keyFound := false
		for _, u := range unlockers {
			if u.Type == "passphrase" {
				keyFound = true
				break
			}
		}
		if !keyFound {
			t.Errorf("Expected to find passphrase unlocker")
		}

		err = vlt.SelectUnlocker(passphraseUnlocker.GetID())
		if err != nil {
			t.Fatalf("Failed to select unlocker: %v", err)
		}

		// The selection must be reflected by GetCurrentUnlocker.
		currentUnlocker, err := vlt.GetCurrentUnlocker()
		if err != nil {
			t.Fatalf("Failed to get current unlocker: %v", err)
		}
		if currentUnlocker.GetID() != passphraseUnlocker.GetID() {
			t.Errorf("Expected current unlocker ID '%s', got '%s'", passphraseUnlocker.GetID(), currentUnlocker.GetID())
		}
	})
}