diff --git a/internal/secret/keychainunlocker.go b/internal/secret/keychainunlocker.go index 0a34659..f59ddbd 100644 --- a/internal/secret/keychainunlocker.go +++ b/internal/secret/keychainunlocker.go @@ -220,6 +220,68 @@ func generateKeychainUnlockerName(vaultName string) (string, error) { return fmt.Sprintf("secret-%s-%s-%s", vaultName, hostname, enrollmentDate), nil } +// getLongTermPrivateKey retrieves the long-term private key either from environment or current unlocker +func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) ([]byte, error) { + // Check if mnemonic is available in environment variable + envMnemonic := os.Getenv(EnvMnemonic) + if envMnemonic != "" { + // Use mnemonic directly to derive long-term key + ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0) + if err != nil { + return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err) + } + return []byte(ltIdentity.String()), nil + } + + // Get the vault to access current unlocker + currentUnlocker, err := vault.GetCurrentUnlocker() + if err != nil { + return nil, fmt.Errorf("failed to get current unlocker: %w", err) + } + + // Get the current unlocker identity + currentUnlockerIdentity, err := currentUnlocker.GetIdentity() + if err != nil { + return nil, fmt.Errorf("failed to get current unlocker identity: %w", err) + } + + // Get encrypted long-term key from current unlocker, handling different types + var encryptedLtPrivKey []byte + switch currentUnlocker := currentUnlocker.(type) { + case *PassphraseUnlocker: + // Read the encrypted long-term private key from passphrase unlocker + encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted long-term key from current passphrase unlocker: %w", err) + } + + case *PGPUnlocker: + // Read the encrypted long-term private key from PGP unlocker + encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted long-term key from current PGP unlocker: %w", err) + } + + case *KeychainUnlocker: + // Read the encrypted long-term private key from another keychain unlocker + encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) + if err != nil { + return nil, fmt.Errorf("failed to read encrypted long-term key from current keychain unlocker: %w", err) + } + + default: + return nil, fmt.Errorf("unsupported current unlocker type for keychain unlocker creation") + } + + // Decrypt long-term private key using current unlocker + ltPrivKeyData, err := DecryptWithIdentity(encryptedLtPrivKey, currentUnlockerIdentity) + if err != nil { + return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err) + } + + return ltPrivKeyData, nil +} + // CreateKeychainUnlocker creates a new keychain unlocker and stores it in the vault func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, error) { // Check if we're on macOS @@ -282,62 +344,9 @@ func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, er } // Step 5: Get or derive the long-term private key - var ltPrivKeyData []byte - - // Check if mnemonic is available in environment variable - if envMnemonic := os.Getenv(EnvMnemonic); envMnemonic != "" { - // Use mnemonic directly to derive long-term key - ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0) - if err != nil { - return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err) - } - ltPrivKeyData = []byte(ltIdentity.String()) - } else { - // Get the vault to access current unlocker - currentUnlocker, err := vault.GetCurrentUnlocker() - if err != nil { - return nil, fmt.Errorf("failed to get current unlocker: %w", err) - } - - // Get the current unlocker identity - currentUnlockerIdentity, err := currentUnlocker.GetIdentity() - if err != nil { - return nil, fmt.Errorf("failed to get current unlocker identity: %w", err) - } - - // Get encrypted long-term key from current unlocker, handling different types - var encryptedLtPrivKey []byte - switch currentUnlocker := currentUnlocker.(type) { - case *PassphraseUnlocker: - // Read the encrypted long-term private key from passphrase unlocker - encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) - if err != nil { - return nil, fmt.Errorf("failed to read encrypted long-term key from current passphrase unlocker: %w", err) - } - - case *PGPUnlocker: - // Read the encrypted long-term private key from PGP unlocker - encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) - if err != nil { - return nil, fmt.Errorf("failed to read encrypted long-term key from current PGP unlocker: %w", err) - } - - case *KeychainUnlocker: - // Read the encrypted long-term private key from another keychain unlocker - encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) - if err != nil { - return nil, fmt.Errorf("failed to read encrypted long-term key from current keychain unlocker: %w", err) - } - - default: - return nil, fmt.Errorf("unsupported current unlocker type for keychain unlocker creation") - } - - // Decrypt long-term private key using current unlocker - ltPrivKeyData, err = DecryptWithIdentity(encryptedLtPrivKey, currentUnlockerIdentity) - if err != nil { - return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err) - } + ltPrivKeyData, err := getLongTermPrivateKey(fs, vault) + if err != nil { + return nil, err } // Step 6: Encrypt long-term private key to the new age unlocker diff --git a/internal/secret/passphraseunlocker.go b/internal/secret/passphraseunlocker.go index 4e87171..df7dd88 100644 --- a/internal/secret/passphraseunlocker.go +++ b/internal/secret/passphraseunlocker.go @@ -18,6 +18,32 @@ type PassphraseUnlocker struct { Passphrase string } +// getPassphrase retrieves the passphrase from memory, environment, or user input +func (p *PassphraseUnlocker) getPassphrase() (string, error) { + // First check if we already have the passphrase + if p.Passphrase != "" { + Debug("Using in-memory passphrase", "unlocker_id", p.GetID()) + return p.Passphrase, nil + } + + Debug("No passphrase in memory, checking environment") + // Check environment variable for passphrase + passphraseStr := os.Getenv(EnvUnlockPassphrase) + if passphraseStr != "" { + Debug("Using passphrase from environment", "unlocker_id", p.GetID()) + return passphraseStr, nil + } + + Debug("No passphrase in environment, prompting user") + // Prompt for passphrase + passphraseStr, err := ReadPassphrase("Enter unlock passphrase: ") + if err != nil { + Debug("Failed to read passphrase", "error", err, "unlocker_id", p.GetID()) + return "", fmt.Errorf("failed to read passphrase: %w", err) + } + return passphraseStr, nil +} + // GetIdentity implements Unlocker interface for passphrase-based unlockers func (p *PassphraseUnlocker) GetIdentity() (*age.X25519Identity, error) { DebugWith("Getting passphrase unlocker identity", @@ -25,26 +51,9 @@ func (p *PassphraseUnlocker) GetIdentity() (*age.X25519Identity, error) { slog.String("unlocker_type", p.GetType()), ) - // First check if we already have the passphrase - passphraseStr := p.Passphrase - if passphraseStr == "" { - Debug("No passphrase in memory, checking environment") - // Check environment variable for passphrase - passphraseStr = os.Getenv(EnvUnlockPassphrase) - if passphraseStr == "" { - Debug("No passphrase in environment, prompting user") - // Prompt for passphrase - var err error - passphraseStr, err = ReadPassphrase("Enter unlock passphrase: ") - if err != nil { - Debug("Failed to read passphrase", "error", err, "unlocker_id", p.GetID()) - return nil, fmt.Errorf("failed to read passphrase: %w", err) - } - } else { - Debug("Using passphrase from environment", "unlocker_id", p.GetID()) - } - } else { - Debug("Using in-memory passphrase", "unlocker_id", p.GetID()) + passphraseStr, err := p.getPassphrase() + if err != nil { + return nil, err } // Read encrypted private key of unlocker diff --git a/internal/secret/pgpunlocker.go b/internal/secret/pgpunlocker.go index 963d291..0ce0318 100644 --- a/internal/secret/pgpunlocker.go +++ b/internal/secret/pgpunlocker.go @@ -12,7 +12,6 @@ import ( "time" "filippo.io/age" - "git.eeqj.de/sneak/secret/pkg/agehd" "github.com/spf13/afero" ) @@ -227,55 +226,9 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc } // Step 3: Get or derive the long-term private key - var ltPrivKeyData []byte - - // Check if mnemonic is available in environment variable - if envMnemonic := os.Getenv(EnvMnemonic); envMnemonic != "" { - // Use mnemonic directly to derive long-term key - ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0) - if err != nil { - return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err) - } - ltPrivKeyData = []byte(ltIdentity.String()) - } else { - // Get the vault to access current unlocker - currentUnlocker, err := vault.GetCurrentUnlocker() - if err != nil { - return nil, fmt.Errorf("failed to get current unlocker: %w", err) - } - - // Get the current unlocker identity - currentUnlockerIdentity, err := currentUnlocker.GetIdentity() - if err != nil { - return nil, fmt.Errorf("failed to get current unlocker identity: %w", err) - } - - // Get encrypted long-term key from current unlocker, handling different types - var encryptedLtPrivKey []byte - switch currentUnlocker := currentUnlocker.(type) { - case *PassphraseUnlocker: - // Read the encrypted long-term private key from passphrase unlocker - encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) - if err != nil { - return nil, fmt.Errorf("failed to read encrypted long-term key from current passphrase unlocker: %w", err) - } - - case *PGPUnlocker: - // Read the encrypted long-term private key from PGP unlocker - encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age")) - if err != nil { - return nil, fmt.Errorf("failed to read encrypted long-term key from current PGP unlocker: %w", err) - } - - default: - return nil, fmt.Errorf("unsupported current unlocker type for PGP unlocker creation") - } - - // Step 6: Decrypt long-term private key using current unlocker - ltPrivKeyData, err = DecryptWithIdentity(encryptedLtPrivKey, currentUnlockerIdentity) - if err != nil { - return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err) - } + ltPrivKeyData, err := getLongTermPrivateKey(fs, vault) + if err != nil { + return nil, err } // Step 7: Encrypt long-term private key to the new age unlocker diff --git a/internal/secret/version.go b/internal/secret/version.go index 2b35106..6d1efd9 100644 --- a/internal/secret/version.go +++ b/internal/secret/version.go @@ -89,17 +89,24 @@ func GenerateVersionName(fs afero.Fs, secretDir string) (string, error) { prefix := today + "." for _, entry := range entries { - if entry.IsDir() && strings.HasPrefix(entry.Name(), prefix) { - // Extract serial number - parts := strings.Split(entry.Name(), ".") - if len(parts) == versionNameParts { - var serial int - if _, err := fmt.Sscanf(parts[1], "%03d", &serial); err == nil { - if serial > maxSerial { - maxSerial = serial - } - } - } + // Skip non-directories and those without correct prefix + if !entry.IsDir() || !strings.HasPrefix(entry.Name(), prefix) { + continue + } + + // Extract serial number + parts := strings.Split(entry.Name(), ".") + if len(parts) != versionNameParts { + continue + } + + var serial int + if _, err := fmt.Sscanf(parts[1], "%03d", &serial); err != nil { + continue + } + + if serial > maxSerial { + maxSerial = serial } } diff --git a/internal/vault/management.go b/internal/vault/management.go index 6e8b8c2..9489d5c 100644 --- a/internal/vault/management.go +++ b/internal/vault/management.go @@ -30,65 +30,60 @@ func isValidVaultName(name string) bool { return matched } +// resolveRelativeSymlink resolves a relative symlink target to an absolute path +func resolveRelativeSymlink(symlinkPath, target string) (string, error) { + // Get the current directory before changing + originalDir, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("failed to get current directory: %w", err) + } + secret.Debug("Got current directory", "original_dir", originalDir) + + // Change to the symlink's directory + symlinkDir := filepath.Dir(symlinkPath) + secret.Debug("Changing to symlink directory", "symlink_path", symlinkDir) + secret.Debug("About to call os.Chdir - this might hang if symlink is broken") + if err := os.Chdir(symlinkDir); err != nil { + return "", fmt.Errorf("failed to change to symlink directory: %w", err) + } + secret.Debug("Changed to symlink directory successfully - os.Chdir completed") + + // Get the absolute path of the target + secret.Debug("Getting absolute path of current directory") + absolutePath, err := os.Getwd() + if err != nil { + // Try to restore original directory before returning error + _ = os.Chdir(originalDir) + return "", fmt.Errorf("failed to get absolute path: %w", err) + } + secret.Debug("Got absolute path", "absolute_path", absolutePath) + + // Restore the original directory + secret.Debug("Restoring original directory", "original_dir", originalDir) + if err := os.Chdir(originalDir); err != nil { + return "", fmt.Errorf("failed to restore original directory: %w", err) + } + secret.Debug("Restored original directory successfully") + + return absolutePath, nil +} + // ResolveVaultSymlink resolves the currentvault symlink by reading either the symlink target or file contents // This function is designed to work on both Unix and Windows systems, as well as with in-memory filesystems func ResolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) { secret.Debug("resolveVaultSymlink starting", "symlink_path", symlinkPath) // First try to handle the path as a real symlink (works on Unix systems) - if _, ok := fs.(*afero.OsFs); ok { - secret.Debug("Using real filesystem symlink resolution") - - // Check if the symlink exists - secret.Debug("Checking symlink target", "symlink_path", symlinkPath) - target, err := os.Readlink(symlinkPath) + _, isOsFs := fs.(*afero.OsFs) + if isOsFs { + target, err := tryResolveOsSymlink(symlinkPath) if err == nil { - secret.Debug("Symlink points to", "target", target) - - // On real filesystem, we need to handle relative symlinks - // by resolving them relative to the symlink's directory - if !filepath.IsAbs(target) { - // Get the current directory before changing - originalDir, err := os.Getwd() - if err != nil { - return "", fmt.Errorf("failed to get current directory: %w", err) - } - secret.Debug("Got current directory", "original_dir", originalDir) - - // Change to the symlink's directory - symlinkDir := filepath.Dir(symlinkPath) - secret.Debug("Changing to symlink directory", "symlink_path", symlinkDir) - secret.Debug("About to call os.Chdir - this might hang if symlink is broken") - if err := os.Chdir(symlinkDir); err != nil { - return "", fmt.Errorf("failed to change to symlink directory: %w", err) - } - secret.Debug("Changed to symlink directory successfully - os.Chdir completed") - - // Get the absolute path of the target - secret.Debug("Getting absolute path of current directory") - absolutePath, err := os.Getwd() - if err != nil { - // Try to restore original directory before returning error - _ = os.Chdir(originalDir) - return "", fmt.Errorf("failed to get absolute path: %w", err) - } - secret.Debug("Got absolute path", "absolute_path", absolutePath) - - // Restore the original directory - secret.Debug("Restoring original directory", "original_dir", originalDir) - if err := os.Chdir(originalDir); err != nil { - return "", fmt.Errorf("failed to restore original directory: %w", err) - } - secret.Debug("Restored original directory successfully") - - // Use the absolute path of the target - target = absolutePath - } - secret.Debug("resolveVaultSymlink completed successfully", "result", target) - return target, nil } + // Fall through to fallback if symlink resolution failed + } else { + secret.Debug("Not using OS filesystem, skipping symlink resolution") } // Fallback: treat it as a regular file containing the target path @@ -108,6 +103,28 @@ func ResolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) { return target, nil } +// tryResolveOsSymlink attempts to resolve a symlink on OS filesystems +func tryResolveOsSymlink(symlinkPath string) (string, error) { + secret.Debug("Using real filesystem symlink resolution") + + // Check if the symlink exists + secret.Debug("Checking symlink target", "symlink_path", symlinkPath) + target, err := os.Readlink(symlinkPath) + if err != nil { + return "", err + } + + secret.Debug("Symlink points to", "target", target) + + // On real filesystem, we need to handle relative symlinks + // by resolving them relative to the symlink's directory + if !filepath.IsAbs(target) { + return resolveRelativeSymlink(symlinkPath, target) + } + + return target, nil +} + // GetCurrentVault gets the current vault from the file system func GetCurrentVault(fs afero.Fs, stateDir string) (*Vault, error) { secret.Debug("Getting current vault", "state_dir", stateDir) @@ -174,6 +191,53 @@ func ListVaults(fs afero.Fs, stateDir string) ([]string, error) { return vaults, nil } +// processMnemonicForVault handles mnemonic processing for vault creation +func processMnemonicForVault(fs afero.Fs, stateDir, vaultDir, vaultName string) (derivationIndex uint32, publicKeyHash string, familyHash string, err error) { + // Check if mnemonic is available in environment + mnemonic := os.Getenv(secret.EnvMnemonic) + + if mnemonic == "" { + secret.Debug("No mnemonic in environment, vault created without long-term key", "vault", vaultName) + // Use 0 for derivation index when no mnemonic is provided + return 0, "", "", nil + } + + secret.Debug("Mnemonic found in environment, deriving long-term key", "vault", vaultName) + + // Get the next available derivation index for this mnemonic + derivationIndex, err = GetNextDerivationIndex(fs, stateDir, mnemonic) + if err != nil { + return 0, "", "", fmt.Errorf("failed to get next derivation index: %w", err) + } + + // Derive the long-term key using the actual derivation index + ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex) + if err != nil { + return 0, "", "", fmt.Errorf("failed to derive long-term key: %w", err) + } + + // Write the public key + ltPubKey := ltIdentity.Recipient().String() + ltPubKeyPath := filepath.Join(vaultDir, "pub.age") + if err := afero.WriteFile(fs, ltPubKeyPath, []byte(ltPubKey), secret.FilePerms); err != nil { + return 0, "", "", fmt.Errorf("failed to write long-term public key: %w", err) + } + secret.Debug("Wrote long-term public key", "path", ltPubKeyPath) + + // Compute verification hash from actual derivation index + publicKeyHash = ComputeDoubleSHA256([]byte(ltIdentity.Recipient().String())) + + // Compute family hash from index 0 (same for all vaults with this mnemonic) + // This is used to identify which vaults belong to the same mnemonic family + identity0, err := agehd.DeriveIdentity(mnemonic, 0) + if err != nil { + return 0, "", "", fmt.Errorf("failed to derive identity for index 0: %w", err) + } + familyHash = ComputeDoubleSHA256([]byte(identity0.Recipient().String())) + + return derivationIndex, publicKeyHash, familyHash, nil +} + // CreateVault creates a new vault func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) { secret.Debug("Creating new vault", "name", name, "state_dir", stateDir) @@ -206,50 +270,10 @@ func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) { return nil, fmt.Errorf("failed to create unlockers directory: %w", err) } - // Check if mnemonic is available in environment - mnemonic := os.Getenv(secret.EnvMnemonic) - var derivationIndex uint32 - var publicKeyHash string - var familyHash string - - if mnemonic != "" { - secret.Debug("Mnemonic found in environment, deriving long-term key", "vault", name) - - // Get the next available derivation index for this mnemonic - var err error - derivationIndex, err = GetNextDerivationIndex(fs, stateDir, mnemonic) - if err != nil { - return nil, fmt.Errorf("failed to get next derivation index: %w", err) - } - - // Derive the long-term key using the actual derivation index - ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex) - if err != nil { - return nil, fmt.Errorf("failed to derive long-term key: %w", err) - } - - // Write the public key - ltPubKey := ltIdentity.Recipient().String() - ltPubKeyPath := filepath.Join(vaultDir, "pub.age") - if err := afero.WriteFile(fs, ltPubKeyPath, []byte(ltPubKey), secret.FilePerms); err != nil { - return nil, fmt.Errorf("failed to write long-term public key: %w", err) - } - secret.Debug("Wrote long-term public key", "path", ltPubKeyPath) - - // Compute verification hash from actual derivation index - publicKeyHash = ComputeDoubleSHA256([]byte(ltIdentity.Recipient().String())) - - // Compute family hash from index 0 (same for all vaults with this mnemonic) - // This is used to identify which vaults belong to the same mnemonic family - identity0, err := agehd.DeriveIdentity(mnemonic, 0) - if err != nil { - return nil, fmt.Errorf("failed to derive identity for index 0: %w", err) - } - familyHash = ComputeDoubleSHA256([]byte(identity0.Recipient().String())) - } else { - secret.Debug("No mnemonic in environment, vault created without long-term key", "vault", name) - // Use 0 for derivation index when no mnemonic is provided - derivationIndex = 0 + // Process mnemonic if available + derivationIndex, publicKeyHash, familyHash, err := processMnemonicForVault(fs, stateDir, vaultDir, name) + if err != nil { + return nil, err } // Save vault metadata diff --git a/internal/vault/unlockers.go b/internal/vault/unlockers.go index 598e2c6..117abd7 100644 --- a/internal/vault/unlockers.go +++ b/internal/vault/unlockers.go @@ -33,31 +33,9 @@ func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) { } // Resolve the symlink to get the target directory - var unlockerDir string - if linkReader, ok := v.fs.(afero.LinkReader); ok { - secret.Debug("Resolving unlocker symlink using afero") - // Try to read as symlink first - unlockerDir, err = linkReader.ReadlinkIfPossible(currentUnlockerPath) - if err != nil { - secret.Debug("Failed to read symlink, falling back to file contents", - "error", err, "symlink_path", currentUnlockerPath) - // Fallback: read the path from file contents - unlockerDirBytes, err := afero.ReadFile(v.fs, currentUnlockerPath) - if err != nil { - secret.Debug("Failed to read unlocker path file", "error", err, "path", currentUnlockerPath) - return nil, fmt.Errorf("failed to read current unlocker: %w", err) - } - unlockerDir = strings.TrimSpace(string(unlockerDirBytes)) - } - } else { - secret.Debug("Reading unlocker path (filesystem doesn't support symlinks)") - // Fallback for filesystems that don't support symlinks: read the path from file contents - unlockerDirBytes, err := afero.ReadFile(v.fs, currentUnlockerPath) - if err != nil { - secret.Debug("Failed to read unlocker path file", "error", err, "path", currentUnlockerPath) - return nil, fmt.Errorf("failed to read current unlocker: %w", err) - } - unlockerDir = strings.TrimSpace(string(unlockerDirBytes)) + unlockerDir, err := v.resolveUnlockerDirectory(currentUnlockerPath) + if err != nil { + return nil, err } secret.DebugWith("Resolved unlocker directory", @@ -114,6 +92,95 @@ func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) { return unlocker, nil } +// resolveUnlockerDirectory resolves the unlocker directory from a symlink or file +func (v *Vault) resolveUnlockerDirectory(currentUnlockerPath string) (string, error) { + linkReader, ok := v.fs.(afero.LinkReader) + if !ok { + // Fallback for filesystems that don't support symlinks + return v.readUnlockerPathFromFile(currentUnlockerPath) + } + + secret.Debug("Resolving unlocker symlink using afero") + // Try to read as symlink first + unlockerDir, err := linkReader.ReadlinkIfPossible(currentUnlockerPath) + if err == nil { + return unlockerDir, nil + } + + secret.Debug("Failed to read symlink, falling back to file contents", + "error", err, "symlink_path", currentUnlockerPath) + // Fallback: read the path from file contents + return v.readUnlockerPathFromFile(currentUnlockerPath) +} + +// readUnlockerPathFromFile reads the unlocker directory path from a file +func (v *Vault) readUnlockerPathFromFile(path string) (string, error) { + secret.Debug("Reading unlocker path from file", "path", path) + unlockerDirBytes, err := afero.ReadFile(v.fs, path) + if err != nil { + secret.Debug("Failed to read unlocker path file", "error", err, "path", path) + return "", fmt.Errorf("failed to read current unlocker: %w", err) + } + return strings.TrimSpace(string(unlockerDirBytes)), nil +} + +// findUnlockerByID finds an unlocker by its ID and returns the unlocker instance and its directory path +func (v *Vault) findUnlockerByID(unlockersDir, unlockerID string) (secret.Unlocker, string, error) { + files, err := afero.ReadDir(v.fs, unlockersDir) + if err != nil { + return nil, "", fmt.Errorf("failed to read unlockers directory: %w", err) + } + + for _, file := range files { + if !file.IsDir() { + continue + } + + // Read metadata file + metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json") + exists, err := afero.Exists(v.fs, metadataPath) + if err != nil { + return nil, "", fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err) + } + if !exists { + // Skip directories without metadata - they might not be unlockers + continue + } + + metadataBytes, err := afero.ReadFile(v.fs, metadataPath) + if err != nil { + return nil, "", fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err) + } + + var metadata UnlockerMetadata + if err := json.Unmarshal(metadataBytes, &metadata); err != nil { + return nil, "", fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err) + } + + unlockerDirPath := filepath.Join(unlockersDir, file.Name()) + + // Create the appropriate unlocker instance + var tempUnlocker secret.Unlocker + switch metadata.Type { + case "passphrase": + tempUnlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDirPath, metadata) + case "pgp": + tempUnlocker = secret.NewPGPUnlocker(v.fs, unlockerDirPath, metadata) + case "keychain": + tempUnlocker = secret.NewKeychainUnlocker(v.fs, unlockerDirPath, metadata) + default: + continue + } + + // Check if this unlocker's ID matches + if tempUnlocker.GetID() == unlockerID { + return tempUnlocker, unlockerDirPath, nil + } + } + + return nil, "", nil +} + // ListUnlockers returns a list of available unlockers for this vault func (v *Vault) ListUnlockers() ([]UnlockerMetadata, error) { vaultDir, err := v.GetDirectory() @@ -178,58 +245,10 @@ func (v *Vault) RemoveUnlocker(unlockerID string) error { // Find the unlocker directory and create the unlocker instance unlockersDir := filepath.Join(vaultDir, "unlockers.d") - // List directories in unlockers.d - files, err := afero.ReadDir(v.fs, unlockersDir) + // Find the unlocker by ID + unlocker, _, err := v.findUnlockerByID(unlockersDir, unlockerID) if err != nil { - return fmt.Errorf("failed to read unlockers directory: %w", err) - } - - var unlocker secret.Unlocker - var unlockerDirPath string - for _, file := range files { - if file.IsDir() { - // Read metadata file - metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json") - exists, err := afero.Exists(v.fs, metadataPath) - if err != nil { - return fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err) - } - if !exists { - // Skip directories without metadata - they might not be unlockers - continue - } - - metadataBytes, err := afero.ReadFile(v.fs, metadataPath) - if err != nil { - return fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err) - } - - var metadata UnlockerMetadata - if err := json.Unmarshal(metadataBytes, &metadata); err != nil { - return fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err) - } - - unlockerDirPath = filepath.Join(unlockersDir, file.Name()) - - // Create the appropriate unlocker instance - var tempUnlocker secret.Unlocker - switch metadata.Type { - case "passphrase": - tempUnlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDirPath, metadata) - case "pgp": - tempUnlocker = secret.NewPGPUnlocker(v.fs, unlockerDirPath, metadata) - case "keychain": - tempUnlocker = secret.NewKeychainUnlocker(v.fs, unlockerDirPath, metadata) - default: - continue - } - - // Check if this unlocker's ID matches - if tempUnlocker.GetID() == unlockerID { - unlocker = tempUnlocker - break - } - } + return err } if unlocker == nil { @@ -250,57 +269,10 @@ func (v *Vault) SelectUnlocker(unlockerID string) error { // Find the unlocker directory by ID unlockersDir := filepath.Join(vaultDir, "unlockers.d") - // List directories in unlockers.d to find the unlocker - files, err := afero.ReadDir(v.fs, unlockersDir) + // Find the unlocker by ID + _, targetUnlockerDir, err := v.findUnlockerByID(unlockersDir, unlockerID) if err != nil { - return fmt.Errorf("failed to read unlockers directory: %w", err) - } - - var targetUnlockerDir string - for _, file := range files { - if file.IsDir() { - // Read metadata file - metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json") - exists, err := afero.Exists(v.fs, metadataPath) - if err != nil { - return fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err) - } - if !exists { - // Skip directories without metadata - they might not be unlockers - continue - } - - metadataBytes, err := afero.ReadFile(v.fs, metadataPath) - if err != nil { - return fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err) - } - - var metadata UnlockerMetadata - if err := json.Unmarshal(metadataBytes, &metadata); err != nil { - return fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err) - } - - unlockerDirPath := filepath.Join(unlockersDir, file.Name()) - - // Create the appropriate unlocker instance - var tempUnlocker secret.Unlocker - switch metadata.Type { - case "passphrase": - tempUnlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDirPath, metadata) - case "pgp": - tempUnlocker = secret.NewPGPUnlocker(v.fs, unlockerDirPath, metadata) - case "keychain": - tempUnlocker = secret.NewKeychainUnlocker(v.fs, unlockerDirPath, metadata) - default: - continue - } - - // Check if this unlocker's ID matches - if tempUnlocker.GetID() == unlockerID { - targetUnlockerDir = unlockerDirPath - break - } - } + return err } if targetUnlockerDir == "" {