426 lines
12 KiB
Go
426 lines
12 KiB
Go
package vault_test
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"git.eeqj.de/sneak/secret/internal/secret"
|
|
"git.eeqj.de/sneak/secret/internal/vault"
|
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
|
"github.com/spf13/afero"
|
|
)
|
|
|
|
func TestVaultWithRealFilesystem(t *testing.T) {
|
|
// Create a temporary directory for our tests
|
|
tempDir, err := os.MkdirTemp("", "secret-test-")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir) // Clean up after test
|
|
|
|
// Use the real filesystem
|
|
fs := afero.NewOsFs()
|
|
|
|
// Test mnemonic
|
|
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
|
|
|
// Save original environment variables
|
|
oldMnemonic := os.Getenv(secret.EnvMnemonic)
|
|
oldPassphrase := os.Getenv(secret.EnvUnlockPassphrase)
|
|
|
|
// Set test environment variables
|
|
os.Setenv(secret.EnvMnemonic, testMnemonic)
|
|
os.Setenv(secret.EnvUnlockPassphrase, "test-passphrase")
|
|
|
|
// Clean up after test
|
|
defer func() {
|
|
if oldMnemonic != "" {
|
|
os.Setenv(secret.EnvMnemonic, oldMnemonic)
|
|
} else {
|
|
os.Unsetenv(secret.EnvMnemonic)
|
|
}
|
|
|
|
if oldPassphrase != "" {
|
|
os.Setenv(secret.EnvUnlockPassphrase, oldPassphrase)
|
|
} else {
|
|
os.Unsetenv(secret.EnvUnlockPassphrase)
|
|
}
|
|
}()
|
|
|
|
// Test symlink handling
|
|
t.Run("SymlinkHandling", func(t *testing.T) {
|
|
stateDir := filepath.Join(tempDir, "symlink-test")
|
|
if err := os.MkdirAll(stateDir, 0700); err != nil {
|
|
t.Fatalf("Failed to create state dir: %v", err)
|
|
}
|
|
|
|
// Create a test vault
|
|
vlt, err := vault.CreateVault(fs, stateDir, "test-vault")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault: %v", err)
|
|
}
|
|
|
|
// Get the vault directory
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
// Create a symlink to the vault directory in a different location
|
|
symlinkPath := filepath.Join(tempDir, "test-symlink")
|
|
if err := os.Symlink(vaultDir, symlinkPath); err != nil {
|
|
t.Fatalf("Failed to create symlink: %v", err)
|
|
}
|
|
|
|
// Test that we can resolve the symlink correctly
|
|
resolvedPath, err := vault.ResolveVaultSymlink(fs, symlinkPath)
|
|
if err != nil {
|
|
t.Fatalf("Failed to resolve symlink: %v", err)
|
|
}
|
|
|
|
// On some platforms, the resolved path might have different case or format
|
|
// We'll use filepath.EvalSymlinks to get the canonical path for comparison
|
|
expectedPath, err := filepath.EvalSymlinks(vaultDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to evaluate symlink: %v", err)
|
|
}
|
|
actualPath, err := filepath.EvalSymlinks(resolvedPath)
|
|
if err != nil {
|
|
t.Fatalf("Failed to evaluate resolved path: %v", err)
|
|
}
|
|
|
|
if actualPath != expectedPath {
|
|
t.Errorf("Expected symlink to resolve to %s, got %s", expectedPath, actualPath)
|
|
}
|
|
})
|
|
|
|
// Test secret operations with deeply nested paths
|
|
t.Run("DeepPathSecrets", func(t *testing.T) {
|
|
stateDir := filepath.Join(tempDir, "deep-path-test")
|
|
if err := os.MkdirAll(stateDir, 0700); err != nil {
|
|
t.Fatalf("Failed to create state dir: %v", err)
|
|
}
|
|
|
|
// Create a test vault
|
|
vlt, err := vault.CreateVault(fs, stateDir, "test-vault")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault: %v", err)
|
|
}
|
|
|
|
// Derive long-term key from mnemonic
|
|
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("Failed to derive long-term key: %v", err)
|
|
}
|
|
|
|
// Get the vault directory
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
// Write long-term public key
|
|
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
|
pubKey := ltIdentity.Recipient().String()
|
|
if err := afero.WriteFile(fs, ltPubKeyPath, []byte(pubKey), secret.FilePerms); err != nil {
|
|
t.Fatalf("Failed to write long-term public key: %v", err)
|
|
}
|
|
|
|
// Unlock the vault
|
|
vlt.Unlock(ltIdentity)
|
|
|
|
// Create a secret with a deeply nested path
|
|
deepPath := "api/credentials/production/database/primary"
|
|
secretValue := []byte("supersecretdbpassword")
|
|
|
|
err = vlt.AddSecret(deepPath, secretValue, false)
|
|
if err != nil {
|
|
t.Fatalf("Failed to add secret with deep path: %v", err)
|
|
}
|
|
|
|
// List secrets and verify our deep path secret is there
|
|
secrets, err := vlt.ListSecrets()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list secrets: %v", err)
|
|
}
|
|
|
|
found := false
|
|
for _, s := range secrets {
|
|
if s == deepPath {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("Deep path secret not found in listed secrets")
|
|
}
|
|
|
|
// Retrieve the secret and verify its value
|
|
retrievedValue, err := vlt.GetSecret(deepPath)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve deep path secret: %v", err)
|
|
}
|
|
|
|
if string(retrievedValue) != string(secretValue) {
|
|
t.Errorf("Retrieved value doesn't match. Expected %q, got %q",
|
|
string(secretValue), string(retrievedValue))
|
|
}
|
|
})
|
|
|
|
// Test key caching in GetOrDeriveLongTermKey
|
|
t.Run("KeyCaching", func(t *testing.T) {
|
|
stateDir := filepath.Join(tempDir, "key-cache-test")
|
|
if err := os.MkdirAll(stateDir, 0700); err != nil {
|
|
t.Fatalf("Failed to create state dir: %v", err)
|
|
}
|
|
|
|
// Create a test vault
|
|
vlt, err := vault.CreateVault(fs, stateDir, "test-vault")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault: %v", err)
|
|
}
|
|
|
|
// Derive long-term key from mnemonic
|
|
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("Failed to derive long-term key: %v", err)
|
|
}
|
|
|
|
// Get the vault directory
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
// Write long-term public key
|
|
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
|
pubKey := ltIdentity.Recipient().String()
|
|
if err := afero.WriteFile(fs, ltPubKeyPath, []byte(pubKey), secret.FilePerms); err != nil {
|
|
t.Fatalf("Failed to write long-term public key: %v", err)
|
|
}
|
|
|
|
// Verify the vault is locked initially
|
|
if !vlt.Locked() {
|
|
t.Errorf("Vault should be locked initially")
|
|
}
|
|
|
|
// First call to GetOrDeriveLongTermKey should derive and cache the key
|
|
firstKey, err := vlt.GetOrDeriveLongTermKey()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get long-term key: %v", err)
|
|
}
|
|
|
|
// Verify the vault is now unlocked
|
|
if vlt.Locked() {
|
|
t.Errorf("Vault should be unlocked after GetOrDeriveLongTermKey")
|
|
}
|
|
|
|
// Second call should return the cached key without re-deriving
|
|
secondKey, err := vlt.GetOrDeriveLongTermKey()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get cached long-term key: %v", err)
|
|
}
|
|
|
|
// Verify both keys are the same instance
|
|
if firstKey != secondKey {
|
|
t.Errorf("Second key call should return same instance as first call")
|
|
}
|
|
|
|
// Verify the public key matches what we expect
|
|
expectedPubKey := ltIdentity.Recipient().String()
|
|
actualPubKey := firstKey.Recipient().String()
|
|
if actualPubKey != expectedPubKey {
|
|
t.Errorf("Public key mismatch. Expected %s, got %s", expectedPubKey, actualPubKey)
|
|
}
|
|
|
|
// Now clear the key and verify it's locked again
|
|
vlt.ClearLongTermKey()
|
|
if !vlt.Locked() {
|
|
t.Errorf("Vault should be locked after clearing key")
|
|
}
|
|
|
|
// Get the key again and verify it works
|
|
thirdKey, err := vlt.GetOrDeriveLongTermKey()
|
|
if err != nil {
|
|
t.Fatalf("Failed to re-derive long-term key: %v", err)
|
|
}
|
|
|
|
// Verify the public key still matches
|
|
actualPubKey = thirdKey.Recipient().String()
|
|
if actualPubKey != expectedPubKey {
|
|
t.Errorf("Re-derived public key mismatch. Expected %s, got %s", expectedPubKey, actualPubKey)
|
|
}
|
|
})
|
|
|
|
// Test vault name validation
|
|
t.Run("VaultNameValidation", func(t *testing.T) {
|
|
stateDir := filepath.Join(tempDir, "name-validation-test")
|
|
if err := os.MkdirAll(stateDir, 0700); err != nil {
|
|
t.Fatalf("Failed to create state dir: %v", err)
|
|
}
|
|
|
|
// Test valid vault names
|
|
validNames := []string{
|
|
"default",
|
|
"test-vault",
|
|
"production.vault",
|
|
"vault_123",
|
|
"a-very-long-vault-name-with-dashes",
|
|
}
|
|
|
|
for _, name := range validNames {
|
|
_, err := vault.CreateVault(fs, stateDir, name)
|
|
if err != nil {
|
|
t.Errorf("Failed to create vault with valid name %q: %v", name, err)
|
|
}
|
|
}
|
|
|
|
// Test invalid vault names
|
|
invalidNames := []string{
|
|
"", // Empty
|
|
"UPPERCASE", // Uppercase not allowed
|
|
"invalid/name", // Slashes not allowed in vault names
|
|
"invalid name", // Spaces not allowed
|
|
"invalid@name", // Special chars not allowed
|
|
}
|
|
|
|
for _, name := range invalidNames {
|
|
_, err := vault.CreateVault(fs, stateDir, name)
|
|
if err == nil {
|
|
t.Errorf("Expected error creating vault with invalid name %q, but got none", name)
|
|
}
|
|
}
|
|
})
|
|
|
|
// Test multiple vaults and switching between them
|
|
t.Run("MultipleVaults", func(t *testing.T) {
|
|
stateDir := filepath.Join(tempDir, "multi-vault-test")
|
|
if err := os.MkdirAll(stateDir, 0700); err != nil {
|
|
t.Fatalf("Failed to create state dir: %v", err)
|
|
}
|
|
|
|
// Create three vaults
|
|
vaultNames := []string{"vault1", "vault2", "vault3"}
|
|
for _, name := range vaultNames {
|
|
_, err := vault.CreateVault(fs, stateDir, name)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault %s: %v", name, err)
|
|
}
|
|
}
|
|
|
|
// List vaults and verify all three are there
|
|
vaults, err := vault.ListVaults(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to list vaults: %v", err)
|
|
}
|
|
|
|
if len(vaults) != 3 {
|
|
t.Errorf("Expected 3 vaults, got %d", len(vaults))
|
|
}
|
|
|
|
// Test switching between vaults
|
|
for _, name := range vaultNames {
|
|
// Select the vault
|
|
if err := vault.SelectVault(fs, stateDir, name); err != nil {
|
|
t.Fatalf("Failed to select vault %s: %v", name, err)
|
|
}
|
|
|
|
// Get current vault and verify it's the one we selected
|
|
currentVault, err := vault.GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault after selecting %s: %v", name, err)
|
|
}
|
|
|
|
if currentVault.GetName() != name {
|
|
t.Errorf("Expected current vault to be %s, got %s", name, currentVault.GetName())
|
|
}
|
|
}
|
|
})
|
|
|
|
// Test adding a secret in one vault and verifying it's not visible in another
|
|
t.Run("VaultIsolation", func(t *testing.T) {
|
|
stateDir := filepath.Join(tempDir, "isolation-test")
|
|
if err := os.MkdirAll(stateDir, 0700); err != nil {
|
|
t.Fatalf("Failed to create state dir: %v", err)
|
|
}
|
|
|
|
// Create two vaults
|
|
vault1, err := vault.CreateVault(fs, stateDir, "vault1")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault1: %v", err)
|
|
}
|
|
|
|
vault2, err := vault.CreateVault(fs, stateDir, "vault2")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault2: %v", err)
|
|
}
|
|
|
|
// Derive long-term key from mnemonic
|
|
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("Failed to derive long-term key: %v", err)
|
|
}
|
|
|
|
// Setup both vaults with the same long-term key
|
|
for _, vlt := range []*vault.Vault{vault1, vault2} {
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
|
pubKey := ltIdentity.Recipient().String()
|
|
if err := afero.WriteFile(fs, ltPubKeyPath, []byte(pubKey), secret.FilePerms); err != nil {
|
|
t.Fatalf("Failed to write long-term public key: %v", err)
|
|
}
|
|
|
|
vlt.Unlock(ltIdentity)
|
|
}
|
|
|
|
// Add a secret to vault1
|
|
secretName := "test-secret"
|
|
secretValue := []byte("secret in vault1")
|
|
if err := vault1.AddSecret(secretName, secretValue, false); err != nil {
|
|
t.Fatalf("Failed to add secret to vault1: %v", err)
|
|
}
|
|
|
|
// Verify the secret exists in vault1
|
|
vault1Secrets, err := vault1.ListSecrets()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list secrets in vault1: %v", err)
|
|
}
|
|
|
|
found := false
|
|
for _, s := range vault1Secrets {
|
|
if s == secretName {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("Secret not found in vault1")
|
|
}
|
|
|
|
// Verify the secret does NOT exist in vault2
|
|
vault2Secrets, err := vault2.ListSecrets()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list secrets in vault2: %v", err)
|
|
}
|
|
|
|
found = false
|
|
for _, s := range vault2Secrets {
|
|
if s == secretName {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if found {
|
|
t.Errorf("Secret from vault1 should not be visible in vault2")
|
|
}
|
|
})
|
|
}
|