Refactor vault functionality to dedicated package, fix import cycles with interface pattern, fix tests
This commit is contained in:
237
internal/vault/vault_test.go
Normal file
237
internal/vault/vault_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
func TestVaultOperations(t *testing.T) {
|
||||
// Save original environment variables
|
||||
oldMnemonic := os.Getenv(secret.EnvMnemonic)
|
||||
oldPassphrase := os.Getenv(secret.EnvUnlockPassphrase)
|
||||
|
||||
// Clean up after test
|
||||
defer func() {
|
||||
if oldMnemonic != "" {
|
||||
os.Setenv(secret.EnvMnemonic, oldMnemonic)
|
||||
} else {
|
||||
os.Unsetenv(secret.EnvMnemonic)
|
||||
}
|
||||
|
||||
if oldPassphrase != "" {
|
||||
os.Setenv(secret.EnvUnlockPassphrase, oldPassphrase)
|
||||
} else {
|
||||
os.Unsetenv(secret.EnvUnlockPassphrase)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set test environment variables
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
os.Setenv(secret.EnvMnemonic, testMnemonic)
|
||||
os.Setenv(secret.EnvUnlockPassphrase, "test-passphrase")
|
||||
|
||||
// Use in-memory filesystem
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Test vault creation
|
||||
t.Run("CreateVault", func(t *testing.T) {
|
||||
vlt, err := CreateVault(fs, stateDir, "test-vault")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault: %v", err)
|
||||
}
|
||||
|
||||
if vlt.GetName() != "test-vault" {
|
||||
t.Errorf("Expected vault name 'test-vault', got '%s'", vlt.GetName())
|
||||
}
|
||||
|
||||
// Check vault directory exists
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
|
||||
exists, err := afero.DirExists(fs, vaultDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check vault directory: %v", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
t.Errorf("Vault directory should exist")
|
||||
}
|
||||
})
|
||||
|
||||
// Test vault listing
|
||||
t.Run("ListVaults", func(t *testing.T) {
|
||||
vaults, err := ListVaults(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list vaults: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, vault := range vaults {
|
||||
if vault == "test-vault" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("Expected to find 'test-vault' in vault list")
|
||||
}
|
||||
})
|
||||
|
||||
// Test vault selection
|
||||
t.Run("SelectVault", func(t *testing.T) {
|
||||
err := SelectVault(fs, stateDir, "test-vault")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select vault: %v", err)
|
||||
}
|
||||
|
||||
// Test getting current vault
|
||||
currentVault, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current vault: %v", err)
|
||||
}
|
||||
|
||||
if currentVault.GetName() != "test-vault" {
|
||||
t.Errorf("Expected current vault 'test-vault', got '%s'", currentVault.GetName())
|
||||
}
|
||||
})
|
||||
|
||||
// Test secret operations
|
||||
t.Run("SecretOperations", func(t *testing.T) {
|
||||
vlt, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current vault: %v", err)
|
||||
}
|
||||
|
||||
// First, derive the long-term key from the test mnemonic
|
||||
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Get the public key from the derived identity
|
||||
ltPublicKey := ltIdentity.Recipient().String()
|
||||
|
||||
// Get the vault directory
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Write the correct public key to the pub.age file
|
||||
pubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, pubKeyPath, []byte(ltPublicKey), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write long-term public key: %v", err)
|
||||
}
|
||||
|
||||
// Unlock the vault with the derived identity
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Now add a secret
|
||||
secretName := "test/secret"
|
||||
secretValue := []byte("test-secret-value")
|
||||
|
||||
err = vlt.AddSecret(secretName, secretValue, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add secret: %v", err)
|
||||
}
|
||||
|
||||
// List secrets
|
||||
secrets, err := vlt.ListSecrets()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list secrets: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, secret := range secrets {
|
||||
if secret == secretName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("Expected to find secret '%s' in list", secretName)
|
||||
}
|
||||
|
||||
// Get secret value
|
||||
retrievedValue, err := vlt.GetSecret(secretName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get secret: %v", err)
|
||||
}
|
||||
|
||||
if string(retrievedValue) != string(secretValue) {
|
||||
t.Errorf("Expected secret value '%s', got '%s'", string(secretValue), string(retrievedValue))
|
||||
}
|
||||
})
|
||||
|
||||
// Test unlock key operations
|
||||
t.Run("UnlockKeyOperations", func(t *testing.T) {
|
||||
vlt, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current vault: %v", err)
|
||||
}
|
||||
|
||||
// Test vault unlocking (should happen automatically via mnemonic)
|
||||
if vlt.Locked() {
|
||||
_, err := vlt.UnlockVault()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unlock vault: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a passphrase unlock key
|
||||
passphraseKey, err := vlt.CreatePassphraseKey("test-passphrase")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create passphrase key: %v", err)
|
||||
}
|
||||
|
||||
// List unlock keys
|
||||
keys, err := vlt.ListUnlockKeys()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list unlock keys: %v", err)
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
t.Errorf("Expected at least one unlock key")
|
||||
}
|
||||
|
||||
// Check key type
|
||||
keyFound := false
|
||||
for _, key := range keys {
|
||||
if key.Type == "passphrase" {
|
||||
keyFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !keyFound {
|
||||
t.Errorf("Expected to find passphrase unlock key")
|
||||
}
|
||||
|
||||
// Test selecting unlock key
|
||||
err = vlt.SelectUnlockKey(passphraseKey.GetID())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select unlock key: %v", err)
|
||||
}
|
||||
|
||||
// Test getting current unlock key
|
||||
currentKey, err := vlt.GetCurrentUnlockKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current unlock key: %v", err)
|
||||
}
|
||||
|
||||
if currentKey.GetID() != passphraseKey.GetID() {
|
||||
t.Errorf("Expected current unlock key ID '%s', got '%s'", passphraseKey.GetID(), currentKey.GetID())
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user