238 lines
5.7 KiB
Go
238 lines
5.7 KiB
Go
package vault
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"git.eeqj.de/sneak/secret/internal/secret"
|
|
"git.eeqj.de/sneak/secret/pkg/agehd"
|
|
"github.com/spf13/afero"
|
|
)
|
|
|
|
func TestVaultOperations(t *testing.T) {
|
|
// Save original environment variables
|
|
oldMnemonic := os.Getenv(secret.EnvMnemonic)
|
|
oldPassphrase := os.Getenv(secret.EnvUnlockPassphrase)
|
|
|
|
// Clean up after test
|
|
defer func() {
|
|
if oldMnemonic != "" {
|
|
os.Setenv(secret.EnvMnemonic, oldMnemonic)
|
|
} else {
|
|
os.Unsetenv(secret.EnvMnemonic)
|
|
}
|
|
|
|
if oldPassphrase != "" {
|
|
os.Setenv(secret.EnvUnlockPassphrase, oldPassphrase)
|
|
} else {
|
|
os.Unsetenv(secret.EnvUnlockPassphrase)
|
|
}
|
|
}()
|
|
|
|
// Set test environment variables
|
|
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
|
os.Setenv(secret.EnvMnemonic, testMnemonic)
|
|
os.Setenv(secret.EnvUnlockPassphrase, "test-passphrase")
|
|
|
|
// Use in-memory filesystem
|
|
fs := afero.NewMemMapFs()
|
|
stateDir := "/test/state"
|
|
|
|
// Test vault creation
|
|
t.Run("CreateVault", func(t *testing.T) {
|
|
vlt, err := CreateVault(fs, stateDir, "test-vault")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create vault: %v", err)
|
|
}
|
|
|
|
if vlt.GetName() != "test-vault" {
|
|
t.Errorf("Expected vault name 'test-vault', got '%s'", vlt.GetName())
|
|
}
|
|
|
|
// Check vault directory exists
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
exists, err := afero.DirExists(fs, vaultDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to check vault directory: %v", err)
|
|
}
|
|
|
|
if !exists {
|
|
t.Errorf("Vault directory should exist")
|
|
}
|
|
})
|
|
|
|
// Test vault listing
|
|
t.Run("ListVaults", func(t *testing.T) {
|
|
vaults, err := ListVaults(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to list vaults: %v", err)
|
|
}
|
|
|
|
found := false
|
|
for _, vault := range vaults {
|
|
if vault == "test-vault" {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("Expected to find 'test-vault' in vault list")
|
|
}
|
|
})
|
|
|
|
// Test vault selection
|
|
t.Run("SelectVault", func(t *testing.T) {
|
|
err := SelectVault(fs, stateDir, "test-vault")
|
|
if err != nil {
|
|
t.Fatalf("Failed to select vault: %v", err)
|
|
}
|
|
|
|
// Test getting current vault
|
|
currentVault, err := GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault: %v", err)
|
|
}
|
|
|
|
if currentVault.GetName() != "test-vault" {
|
|
t.Errorf("Expected current vault 'test-vault', got '%s'", currentVault.GetName())
|
|
}
|
|
})
|
|
|
|
// Test secret operations
|
|
t.Run("SecretOperations", func(t *testing.T) {
|
|
vlt, err := GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault: %v", err)
|
|
}
|
|
|
|
// First, derive the long-term key from the test mnemonic
|
|
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("Failed to derive long-term key: %v", err)
|
|
}
|
|
|
|
// Get the public key from the derived identity
|
|
ltPublicKey := ltIdentity.Recipient().String()
|
|
|
|
// Get the vault directory
|
|
vaultDir, err := vlt.GetDirectory()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get vault directory: %v", err)
|
|
}
|
|
|
|
// Write the correct public key to the pub.age file
|
|
pubKeyPath := filepath.Join(vaultDir, "pub.age")
|
|
err = afero.WriteFile(fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms)
|
|
if err != nil {
|
|
t.Fatalf("Failed to write long-term public key: %v", err)
|
|
}
|
|
|
|
// Unlock the vault with the derived identity
|
|
vlt.Unlock(ltIdentity)
|
|
|
|
// Now add a secret
|
|
secretName := "test/secret"
|
|
secretValue := []byte("test-secret-value")
|
|
|
|
err = vlt.AddSecret(secretName, secretValue, false)
|
|
if err != nil {
|
|
t.Fatalf("Failed to add secret: %v", err)
|
|
}
|
|
|
|
// List secrets
|
|
secrets, err := vlt.ListSecrets()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list secrets: %v", err)
|
|
}
|
|
|
|
found := false
|
|
for _, secret := range secrets {
|
|
if secret == secretName {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("Expected to find secret '%s' in list", secretName)
|
|
}
|
|
|
|
// Get secret value
|
|
retrievedValue, err := vlt.GetSecret(secretName)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get secret: %v", err)
|
|
}
|
|
|
|
if string(retrievedValue) != string(secretValue) {
|
|
t.Errorf("Expected secret value '%s', got '%s'", string(secretValue), string(retrievedValue))
|
|
}
|
|
})
|
|
|
|
// Test unlock key operations
|
|
t.Run("UnlockKeyOperations", func(t *testing.T) {
|
|
vlt, err := GetCurrentVault(fs, stateDir)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current vault: %v", err)
|
|
}
|
|
|
|
// Test vault unlocking (should happen automatically via mnemonic)
|
|
if vlt.Locked() {
|
|
_, err := vlt.UnlockVault()
|
|
if err != nil {
|
|
t.Fatalf("Failed to unlock vault: %v", err)
|
|
}
|
|
}
|
|
|
|
// Create a passphrase unlock key
|
|
passphraseKey, err := vlt.CreatePassphraseKey("test-passphrase")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create passphrase key: %v", err)
|
|
}
|
|
|
|
// List unlock keys
|
|
keys, err := vlt.ListUnlockKeys()
|
|
if err != nil {
|
|
t.Fatalf("Failed to list unlock keys: %v", err)
|
|
}
|
|
|
|
if len(keys) == 0 {
|
|
t.Errorf("Expected at least one unlock key")
|
|
}
|
|
|
|
// Check key type
|
|
keyFound := false
|
|
for _, key := range keys {
|
|
if key.Type == "passphrase" {
|
|
keyFound = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !keyFound {
|
|
t.Errorf("Expected to find passphrase unlock key")
|
|
}
|
|
|
|
// Test selecting unlock key
|
|
err = vlt.SelectUnlockKey(passphraseKey.GetID())
|
|
if err != nil {
|
|
t.Fatalf("Failed to select unlock key: %v", err)
|
|
}
|
|
|
|
// Test getting current unlock key
|
|
currentKey, err := vlt.GetCurrentUnlockKey()
|
|
if err != nil {
|
|
t.Fatalf("Failed to get current unlock key: %v", err)
|
|
}
|
|
|
|
if currentKey.GetID() != passphraseKey.GetID() {
|
|
t.Errorf("Expected current unlock key ID '%s', got '%s'", passphraseKey.GetID(), currentKey.GetID())
|
|
}
|
|
})
|
|
}
|