// secret/internal/vault/vault_test.go

package vault
import (
"os"
"path/filepath"
"testing"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/spf13/afero"
)
// TestVaultOperations exercises the vault workflow end to end on an in-memory
// afero filesystem: creating a vault, listing and selecting it, adding and
// reading back a secret, and creating and selecting a passphrase unlock key.
//
// The subtests are ordered steps that share state (the filesystem, the
// selected vault, and the pub.age file written in SecretOperations), so the
// test stops at the first failing step instead of reporting the follow-on
// failures that step would cause.
func TestVaultOperations(t *testing.T) {
	// restoreEnv snapshots key now and returns a func that puts it back
	// exactly: re-set to the old value if the variable was present (even if
	// it was set to the empty string), or unset if it was absent.
	restoreEnv := func(key string) func() {
		old, had := os.LookupEnv(key)
		return func() {
			var err error
			if had {
				err = os.Setenv(key, old)
			} else {
				err = os.Unsetenv(key)
			}
			if err != nil {
				t.Errorf("Failed to restore %s: %v", key, err)
			}
		}
	}
	// Registered before anything is changed, so the environment is restored
	// even if setting one of the variables below fails part-way.
	defer restoreEnv(secret.EnvMnemonic)()
	defer restoreEnv(secret.EnvUnlockPassphrase)()

	// Fixed mnemonic so the long-term key derived from it is the same on
	// every run; SecretOperations derives the identity from it directly.
	testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
	if err := os.Setenv(secret.EnvMnemonic, testMnemonic); err != nil {
		t.Fatalf("Failed to set %s: %v", secret.EnvMnemonic, err)
	}
	if err := os.Setenv(secret.EnvUnlockPassphrase, "test-passphrase"); err != nil {
		t.Fatalf("Failed to set %s: %v", secret.EnvUnlockPassphrase, err)
	}

	// Use in-memory filesystem
	fs := afero.NewMemMapFs()
	stateDir := "/test/state"

	// run executes one ordered step as a subtest and aborts the parent test
	// if it fails, because every later step builds on the state it leaves.
	run := func(name string, step func(t *testing.T)) {
		t.Helper()
		if !t.Run(name, step) {
			t.Fatalf("Subtest %s failed; skipping the remaining dependent subtests", name)
		}
	}

	// Test vault creation
	run("CreateVault", func(t *testing.T) {
		vlt, err := CreateVault(fs, stateDir, "test-vault")
		if err != nil {
			t.Fatalf("Failed to create vault: %v", err)
		}
		if vlt.GetName() != "test-vault" {
			t.Errorf("Expected vault name 'test-vault', got '%s'", vlt.GetName())
		}

		// Check vault directory exists
		vaultDir, err := vlt.GetDirectory()
		if err != nil {
			t.Fatalf("Failed to get vault directory: %v", err)
		}
		exists, err := afero.DirExists(fs, vaultDir)
		if err != nil {
			t.Fatalf("Failed to check vault directory: %v", err)
		}
		if !exists {
			t.Errorf("Vault directory should exist")
		}
	})

	// Test vault listing
	run("ListVaults", func(t *testing.T) {
		vaults, err := ListVaults(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to list vaults: %v", err)
		}
		found := false
		for _, name := range vaults {
			if name == "test-vault" {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected to find 'test-vault' in vault list")
		}
	})

	// Test vault selection
	run("SelectVault", func(t *testing.T) {
		err := SelectVault(fs, stateDir, "test-vault")
		if err != nil {
			t.Fatalf("Failed to select vault: %v", err)
		}

		// Test getting current vault
		currentVault, err := GetCurrentVault(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to get current vault: %v", err)
		}
		if currentVault.GetName() != "test-vault" {
			t.Errorf("Expected current vault 'test-vault', got '%s'", currentVault.GetName())
		}
	})

	// Test secret operations
	run("SecretOperations", func(t *testing.T) {
		vlt, err := GetCurrentVault(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to get current vault: %v", err)
		}

		// First, derive the long-term key from the test mnemonic
		ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
		if err != nil {
			t.Fatalf("Failed to derive long-term key: %v", err)
		}
		ltPublicKey := ltIdentity.Recipient().String()

		// Overwrite pub.age in the vault directory with the public key of
		// the derived identity, so the vault's recipient matches ltIdentity.
		vaultDir, err := vlt.GetDirectory()
		if err != nil {
			t.Fatalf("Failed to get vault directory: %v", err)
		}
		pubKeyPath := filepath.Join(vaultDir, "pub.age")
		err = afero.WriteFile(fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms)
		if err != nil {
			t.Fatalf("Failed to write long-term public key: %v", err)
		}

		// Unlock the vault with the derived identity.
		// NOTE(review): any result of Unlock is discarded — confirm it cannot fail.
		vlt.Unlock(ltIdentity)

		// Now add a secret
		secretName := "test/secret"
		secretValue := []byte("test-secret-value")
		err = vlt.AddSecret(secretName, secretValue, false)
		if err != nil {
			t.Fatalf("Failed to add secret: %v", err)
		}

		// List secrets. The loop variable is named so that it does not
		// shadow the imported secret package.
		secrets, err := vlt.ListSecrets()
		if err != nil {
			t.Fatalf("Failed to list secrets: %v", err)
		}
		found := false
		for _, name := range secrets {
			if name == secretName {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("Expected to find secret '%s' in list", secretName)
		}

		// Get secret value and check it round-trips unchanged
		retrievedValue, err := vlt.GetSecret(secretName)
		if err != nil {
			t.Fatalf("Failed to get secret: %v", err)
		}
		if string(retrievedValue) != string(secretValue) {
			t.Errorf("Expected secret value '%s', got '%s'", string(secretValue), string(retrievedValue))
		}
	})

	// Test unlock key operations
	run("UnlockKeyOperations", func(t *testing.T) {
		vlt, err := GetCurrentVault(fs, stateDir)
		if err != nil {
			t.Fatalf("Failed to get current vault: %v", err)
		}

		// The handle from GetCurrentVault may start locked; unlock it via
		// UnlockVault (presumably using the mnemonic set in the environment
		// above — confirm).
		if vlt.Locked() {
			if _, err := vlt.UnlockVault(); err != nil {
				t.Fatalf("Failed to unlock vault: %v", err)
			}
		}

		// Create a passphrase unlock key
		passphraseKey, err := vlt.CreatePassphraseKey("test-passphrase")
		if err != nil {
			t.Fatalf("Failed to create passphrase key: %v", err)
		}

		// List unlock keys
		keys, err := vlt.ListUnlockKeys()
		if err != nil {
			t.Fatalf("Failed to list unlock keys: %v", err)
		}
		if len(keys) == 0 {
			t.Errorf("Expected at least one unlock key")
		}

		// Check key type
		keyFound := false
		for _, key := range keys {
			if key.Type == "passphrase" {
				keyFound = true
				break
			}
		}
		if !keyFound {
			t.Errorf("Expected to find passphrase unlock key")
		}

		// Test selecting unlock key
		err = vlt.SelectUnlockKey(passphraseKey.GetID())
		if err != nil {
			t.Fatalf("Failed to select unlock key: %v", err)
		}

		// Test getting current unlock key
		currentKey, err := vlt.GetCurrentUnlockKey()
		if err != nil {
			t.Fatalf("Failed to get current unlock key: %v", err)
		}
		if currentKey.GetID() != passphraseKey.GetID() {
			t.Errorf("Expected current unlock key ID '%s', got '%s'", passphraseKey.GetID(), currentKey.GetID())
		}
	})
}