Files
secret/internal/vault/vault.go
sneak a3d3fb3b69 secure-enclave-unlocker (#24)
Co-authored-by: clawbot <clawbot@eeqj.de>
Reviewed-on: #24
Reviewed-by: clawbot <clawbot@noreply.example.org>
Co-authored-by: sneak <sneak@sneak.berlin>
Co-committed-by: sneak <sneak@sneak.berlin>
2026-03-14 07:36:28 +01:00

250 lines
7.6 KiB
Go

package vault
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"filippo.io/age"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/spf13/afero"
)
// Vault represents a secrets vault: a named directory beneath a state
// directory, accessed through an afero filesystem, together with an optional
// long-term key cached in memory while the vault is unlocked.
type Vault struct {
// Name is the vault's name; it is also the directory name under "vaults.d".
Name string
// fs is the filesystem through which all vault files are accessed.
fs afero.Fs
// stateDir is the root state directory that contains "vaults.d".
stateDir string
longTermKey *age.X25519Identity // In-memory long-term key when unlocked
}
// NewVault builds a Vault called name whose files live under stateDir on the
// given filesystem. The returned vault starts out locked: no long-term key is
// held in memory until Unlock or GetOrDeriveLongTermKey is called.
func NewVault(fs afero.Fs, stateDir string, name string) *Vault {
	secret.Debug("Creating NewVault instance")

	// new(Vault) leaves longTermKey at its zero value (nil), i.e. locked.
	vault := new(Vault)
	vault.Name = name
	vault.fs = fs
	vault.stateDir = stateDir

	secret.Debug("Created NewVault instance successfully")

	return vault
}
// Locked reports whether the vault is locked, meaning no long-term key is
// currently held in memory.
func (v *Vault) Locked() bool {
	return v.longTermKey == nil
}
// Unlock stores key as the vault's in-memory long-term key. The key is stored
// as given without validation; passing nil leaves the vault locked.
func (v *Vault) Unlock(key *age.X25519Identity) {
	v.longTermKey = key
}
// GetLongTermKey returns the long-term key currently held in memory, or nil
// when the vault is locked. It never attempts to derive or decrypt the key.
func (v *Vault) GetLongTermKey() *age.X25519Identity {
	return v.longTermKey
}
// ClearLongTermKey locks the vault by dropping its reference to the in-memory
// long-term key. Only the reference is cleared here.
func (v *Vault) ClearLongTermKey() {
	v.longTermKey = nil
}
// GetOrDeriveLongTermKey returns the vault's long-term key, obtaining it if the
// vault is currently locked.
//
// Sources are tried in this order:
//  1. the key already cached in memory;
//  2. the mnemonic in the secret.EnvMnemonic environment variable, derived at
//     the index recorded in the vault metadata and checked against the stored
//     public key hash;
//  3. the vault's current unlocker.
//
// When a mnemonic is present in the environment, any failure on that path is
// returned as an error; the unlocker is not tried as a fallback. On success
// the key is cached on the vault so later calls return it directly.
func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
	// Fast path: the vault is already unlocked.
	if cached := v.longTermKey; cached != nil {
		return cached, nil
	}

	secret.Debug("Vault is locked, attempting to unlock", "vault_name", v.Name)

	// First choice: derive the key from a mnemonic supplied via the environment.
	mnemonic := os.Getenv(secret.EnvMnemonic)
	if mnemonic != "" {
		secret.Debug("Using mnemonic from environment for long-term key derivation", "vault_name", v.Name)

		// The derivation index lives in the vault metadata.
		dir, err := v.GetDirectory()
		if err != nil {
			return nil, fmt.Errorf("failed to get vault directory: %w", err)
		}

		meta, err := LoadVaultMetadata(v.fs, dir)
		if err != nil {
			secret.Debug("Failed to load vault metadata", "error", err, "vault_name", v.Name)

			return nil, fmt.Errorf("failed to load vault metadata: %w", err)
		}

		identity, err := agehd.DeriveIdentity(mnemonic, meta.DerivationIndex)
		if err != nil {
			secret.Debug("Failed to derive long-term key from mnemonic", "error", err, "vault_name", v.Name)

			return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
		}

		// Reject a mnemonic that does not belong to this vault by comparing
		// the hash of the derived public key with the one stored in metadata.
		pubKey := identity.Recipient().String()
		pubKeyHash := ComputeDoubleSHA256([]byte(pubKey))
		if pubKeyHash != meta.PublicKeyHash {
			secret.Debug("Derived public key hash does not match stored hash",
				"vault_name", v.Name,
				"derived_hash", pubKeyHash,
				"stored_hash", meta.PublicKeyHash,
				"derivation_index", meta.DerivationIndex)

			return nil, fmt.Errorf("derived public key does not match vault: mnemonic may be incorrect")
		}

		secret.DebugWith("Successfully derived long-term key from mnemonic",
			slog.String("vault_name", v.Name),
			slog.String("public_key", pubKey),
			slog.Uint64("derivation_index", uint64(meta.DerivationIndex)),
		)

		// Remember the key so subsequent calls take the fast path.
		v.Unlock(identity)
		secret.Debug("Vault is unlocked (lt key in memory) via mnemonic", "vault_name", v.Name)

		return identity, nil
	}

	// Second choice: no mnemonic, so go through the vault's current unlocker.
	secret.Debug("No mnemonic available, using current unlocker to unlock vault", "vault_name", v.Name)

	unlocker, err := v.GetCurrentUnlocker()
	if err != nil {
		secret.Debug("Failed to get current unlocker", "error", err, "vault_name", v.Name)

		return nil, fmt.Errorf("failed to get current unlocker: %w", err)
	}

	secret.DebugWith("Retrieved current unlocker for vault unlock",
		slog.String("vault_name", v.Name),
		slog.String("unlocker_type", unlocker.GetType()),
		slog.String("unlocker_id", unlocker.GetID()),
	)

	// Secure Enclave unlockers yield the long-term key straight from
	// GetIdentity(); every other unlocker yields its own identity, which is
	// then used to decrypt longterm.age. unlockLongTermKey handles both.
	identity, err := v.unlockLongTermKey(unlocker)
	if err != nil {
		return nil, err
	}

	secret.DebugWith("Successfully obtained long-term identity via unlocker",
		slog.String("vault_name", v.Name),
		slog.String("unlocker_type", unlocker.GetType()),
		slog.String("public_key", identity.Recipient().String()),
	)

	// Remember the key so subsequent calls take the fast path.
	v.Unlock(identity)
	secret.Debug("Vault is unlocked (lt key in memory) via unlocker",
		"vault_name", v.Name, "unlocker_type", unlocker.GetType())

	return identity, nil
}
// unlockLongTermKey obtains the vault's long-term key by way of unlocker.
//
// For an unlocker of type "secure-enclave" the identity returned by
// GetIdentity() is the long-term key itself. For every other type,
// GetIdentity() returns the unlocker's own identity, which is used to decrypt
// the "longterm.age" file in the unlocker's directory; the plaintext is then
// parsed as an age X25519 identity.
func (v *Vault) unlockLongTermKey(unlocker secret.Unlocker) (*age.X25519Identity, error) {
	// Secure Enclave path: no intermediate identity, no longterm.age file.
	if unlocker.GetType() == "secure-enclave" {
		secret.Debug("SE unlocker: decrypting long-term key directly via Secure Enclave")

		key, err := unlocker.GetIdentity()
		if err != nil {
			return nil, fmt.Errorf("failed to decrypt long-term key via SE: %w", err)
		}

		return key, nil
	}

	// Standard path, step 1: obtain the unlocker's own identity.
	intermediate, err := unlocker.GetIdentity()
	if err != nil {
		return nil, fmt.Errorf("failed to get unlocker identity: %w", err)
	}

	// Step 2: read the long-term key encrypted to that identity.
	ciphertext, err := afero.ReadFile(v.fs, filepath.Join(unlocker.GetDirectory(), "longterm.age"))
	if err != nil {
		return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err)
	}

	// Step 3: decrypt it; the plaintext buffer is destroyed on return.
	plaintext, err := secret.DecryptWithIdentity(ciphertext, intermediate)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
	}
	defer plaintext.Destroy()

	// Step 4: parse the plaintext into an age identity.
	key, err := age.ParseX25519Identity(plaintext.String())
	if err != nil {
		return nil, fmt.Errorf("failed to parse long-term private key: %w", err)
	}

	return key, nil
}
// GetDirectory returns the path of the vault's directory, which is
// <stateDir>/vaults.d/<Name>. The path is only computed, not checked for
// existence, and the returned error is always nil.
func (v *Vault) GetDirectory() (string, error) {
	dir := filepath.Join(v.stateDir, "vaults.d", v.Name)

	return dir, nil
}
// GetName returns the vault's Name field. It exists so that Vault satisfies
// VaultInterface.
func (v *Vault) GetName() string {
	return v.Name
}
// GetFilesystem returns the afero filesystem the vault was created with. It
// exists so that Vault satisfies VaultInterface.
func (v *Vault) GetFilesystem() afero.Fs {
	return v.fs
}
// NumSecrets returns the number of secrets stored in the vault.
//
// A secret is a directory directly under <vault>/secrets.d that contains a
// "current" file (the pointer to the active version). Plain files and
// directories without a "current" file are not counted.
//
// If secrets.d does not exist, or exists but is not a directory, the count is
// 0 with a nil error. An error is returned when the vault directory cannot be
// determined, when secrets.d cannot be checked for a reason other than it
// being absent, or when it cannot be listed. Individual entries whose
// "current" file cannot be checked are skipped rather than failing the count.
func (v *Vault) NumSecrets() (int, error) {
	vaultDir, err := v.GetDirectory()
	if err != nil {
		return 0, fmt.Errorf("failed to get vault directory: %w", err)
	}

	secretsDir := filepath.Join(vaultDir, "secrets.d")

	// DirExists reports a missing directory as (false, nil); a non-nil error
	// means the check itself failed, which must not be reported as "0 secrets".
	hasSecretsDir, err := afero.DirExists(v.fs, secretsDir)
	if err != nil {
		return 0, fmt.Errorf("failed to check secrets directory: %w", err)
	}
	if !hasSecretsDir {
		return 0, nil
	}

	entries, err := afero.ReadDir(v.fs, secretsDir)
	if err != nil {
		return 0, fmt.Errorf("failed to read secrets directory: %w", err)
	}

	// Count only directories that have a "current" version pointer file.
	count := 0
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}

		currentFile := filepath.Join(secretsDir, entry.Name(), "current")

		// Best effort: a directory we cannot inspect is skipped, not fatal.
		if hasCurrent, err := afero.Exists(v.fs, currentFile); err == nil && hasCurrent {
			count++
		}
	}

	return count, nil
}