package vault import ( "os" "path/filepath" "testing" "git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/pkg/agehd" "github.com/spf13/afero" ) func TestVaultOperations(t *testing.T) { // Save original environment variables oldMnemonic := os.Getenv(secret.EnvMnemonic) oldPassphrase := os.Getenv(secret.EnvUnlockPassphrase) // Clean up after test defer func() { if oldMnemonic != "" { os.Setenv(secret.EnvMnemonic, oldMnemonic) } else { os.Unsetenv(secret.EnvMnemonic) } if oldPassphrase != "" { os.Setenv(secret.EnvUnlockPassphrase, oldPassphrase) } else { os.Unsetenv(secret.EnvUnlockPassphrase) } }() // Set test environment variables testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" os.Setenv(secret.EnvMnemonic, testMnemonic) os.Setenv(secret.EnvUnlockPassphrase, "test-passphrase") // Use in-memory filesystem fs := afero.NewMemMapFs() stateDir := "/test/state" // Test vault creation t.Run("CreateVault", func(t *testing.T) { vlt, err := CreateVault(fs, stateDir, "test-vault") if err != nil { t.Fatalf("Failed to create vault: %v", err) } if vlt.GetName() != "test-vault" { t.Errorf("Expected vault name 'test-vault', got '%s'", vlt.GetName()) } // Check vault directory exists vaultDir, err := vlt.GetDirectory() if err != nil { t.Fatalf("Failed to get vault directory: %v", err) } exists, err := afero.DirExists(fs, vaultDir) if err != nil { t.Fatalf("Failed to check vault directory: %v", err) } if !exists { t.Errorf("Vault directory should exist") } }) // Test vault listing t.Run("ListVaults", func(t *testing.T) { vaults, err := ListVaults(fs, stateDir) if err != nil { t.Fatalf("Failed to list vaults: %v", err) } found := false for _, vault := range vaults { if vault == "test-vault" { found = true break } } if !found { t.Errorf("Expected to find 'test-vault' in vault list") } }) // Test vault selection t.Run("SelectVault", func(t *testing.T) { err := SelectVault(fs, stateDir, "test-vault") if err != nil { t.Fatalf("Failed to select vault: %v", err) } // Test getting current vault currentVault, err := GetCurrentVault(fs, stateDir) if err != nil { t.Fatalf("Failed to get current vault: %v", err) } if currentVault.GetName() != "test-vault" { t.Errorf("Expected current vault 'test-vault', got '%s'", currentVault.GetName()) } }) // Test secret operations t.Run("SecretOperations", func(t *testing.T) { vlt, err := GetCurrentVault(fs, stateDir) if err != nil { t.Fatalf("Failed to get current vault: %v", err) } // First, derive the long-term key from the test mnemonic ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0) if err != nil { t.Fatalf("Failed to derive long-term key: %v", err) } // Get the public key from the derived identity ltPublicKey := ltIdentity.Recipient().String() // Get the vault directory vaultDir, err := vlt.GetDirectory() if err != nil { t.Fatalf("Failed to get vault directory: %v", err) } // Write the correct public key to the pub.age file pubKeyPath := filepath.Join(vaultDir, "pub.age") err = afero.WriteFile(fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms) if err != nil { t.Fatalf("Failed to write long-term public key: %v", err) } // Unlock the vault with the derived identity vlt.Unlock(ltIdentity) // Now add a secret secretName := "test/secret" secretValue := []byte("test-secret-value") err = vlt.AddSecret(secretName, secretValue, false) if err != nil { t.Fatalf("Failed to add secret: %v", err) } // List secrets secrets, err := vlt.ListSecrets() if err != nil { t.Fatalf("Failed to list secrets: %v", err) } found := false for _, secret := range secrets { if secret == secretName { found = true break } } if !found { t.Errorf("Expected to find secret '%s' in list", secretName) } // Get secret value retrievedValue, err := vlt.GetSecret(secretName) if err != nil { t.Fatalf("Failed to get secret: %v", err) } if string(retrievedValue) != string(secretValue) { t.Errorf("Expected secret value '%s', got '%s'", string(secretValue), string(retrievedValue)) } }) // Test unlock key operations t.Run("UnlockKeyOperations", func(t *testing.T) { vlt, err := GetCurrentVault(fs, stateDir) if err != nil { t.Fatalf("Failed to get current vault: %v", err) } // Test vault unlocking (should happen automatically via mnemonic) if vlt.Locked() { _, err := vlt.UnlockVault() if err != nil { t.Fatalf("Failed to unlock vault: %v", err) } } // Create a passphrase unlock key passphraseKey, err := vlt.CreatePassphraseKey("test-passphrase") if err != nil { t.Fatalf("Failed to create passphrase key: %v", err) } // List unlock keys keys, err := vlt.ListUnlockKeys() if err != nil { t.Fatalf("Failed to list unlock keys: %v", err) } if len(keys) == 0 { t.Errorf("Expected at least one unlock key") } // Check key type keyFound := false for _, key := range keys { if key.Type == "passphrase" { keyFound = true break } } if !keyFound { t.Errorf("Expected to find passphrase unlock key") } // Test selecting unlock key err = vlt.SelectUnlockKey(passphraseKey.GetID()) if err != nil { t.Fatalf("Failed to select unlock key: %v", err) } // Test getting current unlock key currentKey, err := vlt.GetCurrentUnlockKey() if err != nil { t.Fatalf("Failed to get current unlock key: %v", err) } if currentKey.GetID() != passphraseKey.GetID() { t.Errorf("Expected current unlock key ID '%s', got '%s'", passphraseKey.GetID(), currentKey.GetID()) } }) }