This commit is contained in:
2025-05-29 13:02:39 -07:00
parent ddb395901b
commit 8cc15fde3d
5 changed files with 66 additions and 76 deletions

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"git.eeqj.de/sneak/secret/internal/secret"
"github.com/spf13/afero"
@@ -16,87 +17,54 @@ func init() {
})
}
// resolveVaultSymlink resolves the currentvault symlink by changing into it and getting the absolute path
// isValidVaultName validates vault names according to the format [a-z0-9\.\-\_]+
// Note: We don't allow slashes in vault names unlike secret names
func isValidVaultName(name string) bool {
if name == "" {
return false
}
matched, _ := regexp.MatchString(`^[a-z0-9\.\-\_]+$`, name)
return matched
}
// resolveVaultSymlink resolves the currentvault symlink by reading either the symlink target or file contents
// This function is designed to work on both Unix and Windows systems, as well as with in-memory filesystems
func resolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) {
secret.Debug("resolveVaultSymlink starting", "symlink_path", symlinkPath)
// For real filesystems, we can use os.Chdir and os.Getwd
// First try to handle the path as a real symlink (works on Unix systems)
if _, ok := fs.(*afero.OsFs); ok {
secret.Debug("Using real filesystem symlink resolution")
secret.Debug("Trying real filesystem symlink resolution")
// Check what the symlink points to first
secret.Debug("Checking symlink target", "symlink_path", symlinkPath)
// Check if it's a real symlink first (will work on Unix)
linkTarget, err := os.Readlink(symlinkPath)
if err != nil {
secret.Debug("Failed to read symlink target", "error", err, "symlink_path", symlinkPath)
// Maybe it's not a symlink, try reading as file
secret.Debug("Trying to read as file instead of symlink")
targetBytes, err := os.ReadFile(symlinkPath)
if err != nil {
secret.Debug("Failed to read as file", "error", err)
return "", fmt.Errorf("failed to read vault symlink: %w", err)
if err == nil {
// Successfully read as symlink (Unix path)
secret.Debug("Successfully read as real symlink", "target", linkTarget)
// Convert relative paths to absolute if needed
if !filepath.IsAbs(linkTarget) {
linkTarget = filepath.Join(filepath.Dir(symlinkPath), linkTarget)
}
targetPath := string(targetBytes)
secret.Debug("Read target path from file", "target_path", targetPath)
return targetPath, nil
secret.Debug("Resolved symlink path", "path", linkTarget)
return linkTarget, nil
}
secret.Debug("Symlink points to", "target", linkTarget)
// Save current directory so we can restore it
secret.Debug("Getting current directory")
originalDir, err := os.Getwd()
if err != nil {
secret.Debug("Failed to get current directory", "error", err)
return "", fmt.Errorf("failed to get current directory: %w", err)
}
secret.Debug("Got current directory", "original_dir", originalDir)
// Change to the symlink directory
secret.Debug("Changing to symlink directory", "symlink_path", symlinkPath)
secret.Debug("About to call os.Chdir - this might hang if symlink is broken")
if err := os.Chdir(symlinkPath); err != nil {
secret.Debug("Failed to change to symlink directory", "error", err)
return "", fmt.Errorf("failed to change to symlink directory: %w", err)
}
secret.Debug("Changed to symlink directory successfully - os.Chdir completed")
// Get absolute path of current directory (which is the resolved symlink)
secret.Debug("Getting absolute path of current directory")
absolutePath, err := os.Getwd()
if err != nil {
secret.Debug("Failed to get absolute path", "error", err)
// Try to restore original directory before returning error
if restoreErr := os.Chdir(originalDir); restoreErr != nil {
secret.Debug("Failed to restore original directory after error", "error", restoreErr)
}
return "", fmt.Errorf("failed to get absolute path: %w", err)
}
secret.Debug("Got absolute path", "absolute_path", absolutePath)
// Restore original directory
secret.Debug("Restoring original directory", "original_dir", originalDir)
if err := os.Chdir(originalDir); err != nil {
secret.Debug("Failed to restore original directory", "error", err)
// Don't return error here since we got what we needed
} else {
secret.Debug("Restored original directory successfully")
}
secret.Debug("resolveVaultSymlink completed successfully", "result", absolutePath)
return absolutePath, nil
secret.Debug("Not a real symlink or on Windows, trying as regular file", "error", err)
}
// For in-memory filesystems, read the symlink content directly
secret.Debug("Using in-memory filesystem symlink resolution")
// For Windows or in-memory filesystems, or when symlink reading fails:
// Read the path from a regular file (our fallback approach)
secret.Debug("Reading symlink target from file")
content, err := afero.ReadFile(fs, symlinkPath)
if err != nil {
secret.Debug("Failed to read symlink content", "error", err)
return "", fmt.Errorf("failed to read vault symlink: %w", err)
secret.Debug("Failed to read from file", "error", err)
return "", fmt.Errorf("failed to read vault path: %w", err)
}
targetPath := string(content)
secret.Debug("Read symlink target from in-memory filesystem", "target_path", targetPath)
secret.Debug("Read target path from file", "target_path", targetPath)
return targetPath, nil
}
@@ -171,6 +139,13 @@ func ListVaults(fs afero.Fs, stateDir string) ([]string, error) {
func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) {
secret.Debug("Creating new vault", "name", name, "state_dir", stateDir)
// Validate vault name
if !isValidVaultName(name) {
secret.Debug("Invalid vault name provided", "vault_name", name)
return nil, fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name)
}
secret.Debug("Vault name validation passed", "vault_name", name)
// Create vault directory structure
vaultDir := filepath.Join(stateDir, "vaults.d", name)
secret.Debug("Creating vault directory structure", "vault_dir", vaultDir)
@@ -214,6 +189,13 @@ func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) {
func SelectVault(fs afero.Fs, stateDir string, name string) error {
secret.Debug("Selecting vault", "vault_name", name, "state_dir", stateDir)
// Validate vault name
if !isValidVaultName(name) {
secret.Debug("Invalid vault name provided", "vault_name", name)
return fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name)
}
secret.Debug("Vault name validation passed", "vault_name", name)
// Check if vault exists
vaultDir := filepath.Join(stateDir, "vaults.d", name)
exists, err := afero.DirExists(fs, vaultDir)

View File

@@ -73,6 +73,10 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
slog.String("public_key", ltIdentity.Recipient().String()),
)
// Cache the derived key by unlocking the vault
v.Unlock(ltIdentity)
secret.Debug("Vault is unlocked (lt key in memory) via mnemonic", "vault_name", v.Name)
return ltIdentity, nil
}
@@ -144,6 +148,10 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
slog.String("public_key", ltIdentity.Recipient().String()),
)
// Cache the derived key by unlocking the vault
v.Unlock(ltIdentity)
secret.Debug("Vault is unlocked (lt key in memory) via unlock key", "vault_name", v.Name, "unlock_key_type", unlockKey.GetType())
return ltIdentity, nil
}