secret/internal/vault/vault.go
clawbot 341428d9ca fix: NumSecrets() now correctly counts secrets by checking for current file
NumSecrets() previously looked for non-directory, non-'current' files
directly under each secret directory, but the only children are
'current' (file, excluded) and 'versions' (directory, excluded),
so it always returned 0.

Now checks for the existence of the 'current' file, which is the
canonical indicator that a secret exists and has an active version.

This fixes the safety check in UnlockersRemove that was always
allowing removal of the last unlocker.
2026-02-08 12:04:15 -08:00

252 lines
7.9 KiB
Go

package vault
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"filippo.io/age"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/spf13/afero"
)
// Vault represents a single named secrets vault stored beneath a state
// directory on the configured filesystem.
type Vault struct {
	// Name is the vault's name; it is also the final path element of the
	// vault's directory (see GetDirectory).
	Name string

	// fs is the filesystem through which vault data is read.
	fs afero.Fs

	// stateDir is the root state directory; the vault lives under
	// <stateDir>/vaults.d/<Name>.
	stateDir string

	// longTermKey is the in-memory long-term key; it is nil while the vault
	// is locked.
	longTermKey *age.X25519Identity
}
// NewVault builds a Vault called name that stores its data under stateDir on
// the filesystem fs. The returned vault holds no long-term key, so it starts
// out locked.
func NewVault(fs afero.Fs, stateDir string, name string) *Vault {
	secret.Debug("Creating NewVault instance")

	// new(Vault) leaves longTermKey at its zero value (nil), i.e. locked.
	vault := new(Vault)
	vault.Name = name
	vault.fs = fs
	vault.stateDir = stateDir

	secret.Debug("Created NewVault instance successfully")

	return vault
}
// Locked reports whether the vault is locked, meaning that no long-term key
// is currently held in memory.
func (v *Vault) Locked() bool {
	hasKey := v.longTermKey != nil

	return !hasKey
}
// Unlock stores key as the vault's in-memory long-term key. The key is stored
// as given, without validation; passing nil leaves the vault locked.
func (v *Vault) Unlock(key *age.X25519Identity) {
	v.longTermKey = key
}
// GetLongTermKey returns the long-term key currently held in memory, or nil
// when the vault is locked. It never attempts to derive or decrypt a key.
func (v *Vault) GetLongTermKey() *age.X25519Identity {
	key := v.longTermKey

	return key
}
// ClearLongTermKey drops the vault's reference to the in-memory long-term
// key, which locks the vault. Only the reference held by the vault is
// cleared; the key material itself is not overwritten here.
func (v *Vault) ClearLongTermKey() {
	v.longTermKey = nil
}
// GetOrDeriveLongTermKey returns the vault's long-term identity, obtaining and
// caching it first when the vault is locked.
//
// Sources are tried in this order:
//  1. The key already held in memory.
//  2. The mnemonic in the environment variable named by secret.EnvMnemonic.
//     The identity is derived at the derivation index recorded in the vault
//     metadata and is accepted only when the hash of its public key equals
//     the hash stored in that metadata.
//  3. The current unlocker. Its identity decrypts the "longterm.age" file in
//     the unlocker's directory and the plaintext is parsed as an X25519
//     identity.
//
// A non-empty mnemonic is authoritative: when that path fails, the error is
// returned and the unlocker is not consulted. On success via source 2 or 3
// the identity is stored on the vault, so later calls are served from memory.
func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
	// Fast path: a key is already cached, so the vault is unlocked.
	if cached := v.longTermKey; cached != nil {
		return cached, nil
	}

	secret.Debug("Vault is locked, attempting to unlock", "vault_name", v.Name)

	// A mnemonic supplied through the environment takes precedence over any
	// unlocker.
	mnemonic := os.Getenv(secret.EnvMnemonic)
	if mnemonic != "" {
		secret.Debug("Using mnemonic from environment for long-term key derivation", "vault_name", v.Name)

		// The derivation index to use is recorded in the vault's metadata.
		dir, err := v.GetDirectory()
		if err != nil {
			return nil, fmt.Errorf("failed to get vault directory: %w", err)
		}

		meta, err := LoadVaultMetadata(v.fs, dir)
		if err != nil {
			secret.Debug("Failed to load vault metadata", "error", err, "vault_name", v.Name)

			return nil, fmt.Errorf("failed to load vault metadata: %w", err)
		}

		derived, err := agehd.DeriveIdentity(mnemonic, meta.DerivationIndex)
		if err != nil {
			secret.Debug("Failed to derive long-term key from mnemonic", "error", err, "vault_name", v.Name)

			return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
		}

		// Reject the derived identity unless the hash of its public key
		// matches the hash recorded for this vault.
		gotHash := ComputeDoubleSHA256([]byte(derived.Recipient().String()))
		if gotHash != meta.PublicKeyHash {
			secret.Debug("Derived public key hash does not match stored hash",
				"vault_name", v.Name,
				"derived_hash", gotHash,
				"stored_hash", meta.PublicKeyHash,
				"derivation_index", meta.DerivationIndex)

			return nil, fmt.Errorf("derived public key does not match vault: mnemonic may be incorrect")
		}

		secret.DebugWith("Successfully derived long-term key from mnemonic",
			slog.String("vault_name", v.Name),
			slog.String("public_key", derived.Recipient().String()),
			slog.Uint64("derivation_index", uint64(meta.DerivationIndex)),
		)

		// Keep the identity in memory so the next call takes the fast path.
		v.Unlock(derived)
		secret.Debug("Vault is unlocked (lt key in memory) via mnemonic", "vault_name", v.Name)

		return derived, nil
	}

	// No mnemonic in the environment: fall back to the current unlocker.
	secret.Debug("No mnemonic available, using current unlocker to unlock vault", "vault_name", v.Name)

	cur, err := v.GetCurrentUnlocker()
	if err != nil {
		secret.Debug("Failed to get current unlocker", "error", err, "vault_name", v.Name)

		return nil, fmt.Errorf("failed to get current unlocker: %w", err)
	}

	secret.DebugWith("Retrieved current unlocker for vault unlock",
		slog.String("vault_name", v.Name),
		slog.String("unlocker_type", cur.GetType()),
		slog.String("unlocker_id", cur.GetID()),
	)

	// The unlocker's identity is what decrypts the stored long-term key.
	decryptWith, err := cur.GetIdentity()
	if err != nil {
		secret.Debug("Failed to get unlocker identity", "error", err, "unlocker_type", cur.GetType())

		return nil, fmt.Errorf("failed to get unlocker identity: %w", err)
	}

	// The encrypted long-term private key sits in the unlocker's directory.
	keyPath := filepath.Join(cur.GetDirectory(), "longterm.age")
	secret.Debug("Reading encrypted long-term private key", "path", keyPath)

	ciphertext, err := afero.ReadFile(v.fs, keyPath)
	if err != nil {
		secret.Debug("Failed to read encrypted long-term private key", "error", err, "path", keyPath)

		return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err)
	}

	secret.DebugWith("Read encrypted long-term private key",
		slog.String("vault_name", v.Name),
		slog.String("unlocker_type", cur.GetType()),
		slog.Int("encrypted_length", len(ciphertext)),
	)

	secret.Debug("Decrypting long-term private key with unlocker", "unlocker_type", cur.GetType())

	plaintext, err := secret.DecryptWithIdentity(ciphertext, decryptWith)
	if err != nil {
		secret.Debug("Failed to decrypt long-term private key", "error", err, "unlocker_type", cur.GetType())

		return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
	}
	// Destroy the decrypted buffer on every return path from here on.
	defer plaintext.Destroy()

	secret.DebugWith("Successfully decrypted long-term private key",
		slog.String("vault_name", v.Name),
		slog.String("unlocker_type", cur.GetType()),
		slog.Int("decrypted_length", plaintext.Size()),
	)

	secret.Debug("Parsing long-term private key", "vault_name", v.Name)

	parsed, err := age.ParseX25519Identity(plaintext.String())
	if err != nil {
		secret.Debug("Failed to parse long-term private key", "error", err, "vault_name", v.Name)

		return nil, fmt.Errorf("failed to parse long-term private key: %w", err)
	}

	secret.DebugWith("Successfully obtained long-term identity via unlocker",
		slog.String("vault_name", v.Name),
		slog.String("unlocker_type", cur.GetType()),
		slog.String("public_key", parsed.Recipient().String()),
	)

	// Keep the identity in memory so the next call takes the fast path.
	v.Unlock(parsed)
	secret.Debug("Vault is unlocked (lt key in memory) via unlocker",
		"vault_name", v.Name, "unlocker_type", cur.GetType())

	return parsed, nil
}
// GetDirectory returns the path of the vault's directory, which is
// <stateDir>/vaults.d/<Name>. The path is only computed, not checked for
// existence, and the returned error is always nil in this implementation.
func (v *Vault) GetDirectory() (string, error) {
	dir := filepath.Join(v.stateDir, "vaults.d", v.Name)

	return dir, nil
}
// GetName returns the vault's Name field. It exists so that Vault satisfies
// VaultInterface.
func (v *Vault) GetName() string {
	name := v.Name

	return name
}
// GetFilesystem returns the filesystem the vault was created with. It exists
// so that Vault satisfies VaultInterface.
func (v *Vault) GetFilesystem() afero.Fs {
	filesystem := v.fs

	return filesystem
}
// NumSecrets returns the number of secrets stored in the vault.
//
// A secret is counted when its directory under <vault dir>/secrets.d contains
// a "current" entry, the pointer to the secret's active version. A missing
// secrets.d directory yields a count of zero and no error.
//
// Failing to check for or read secrets.d is returned as an error rather than
// reported as "zero secrets", because callers may use the count as a safety
// check. A failure to check one secret's "current" entry is not fatal: that
// secret is left out of the count.
func (v *Vault) NumSecrets() (int, error) {
	vaultDir, err := v.GetDirectory()
	if err != nil {
		return 0, fmt.Errorf("failed to get vault directory: %w", err)
	}

	secretsDir := filepath.Join(vaultDir, "secrets.d")

	// DirExists reports (false, nil) when the path is simply absent; any
	// non-nil error means the check itself failed and must not be mistaken
	// for an empty vault.
	dirExists, err := afero.DirExists(v.fs, secretsDir)
	if err != nil {
		return 0, fmt.Errorf("failed to check secrets directory: %w", err)
	}
	if !dirExists {
		return 0, nil
	}

	entries, err := afero.ReadDir(v.fs, secretsDir)
	if err != nil {
		return 0, fmt.Errorf("failed to read secrets directory: %w", err)
	}

	// Count only directories that have a "current" version pointer file.
	count := 0
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}

		currentFile := filepath.Join(secretsDir, entry.Name(), "current")

		hasCurrent, statErr := afero.Exists(v.fs, currentFile)
		if statErr != nil {
			// Deliberately best-effort: skip a secret we cannot inspect.
			continue
		}
		if hasCurrent {
			count++
		}
	}

	return count, nil
}