secret/internal/cli/vault.go
clawbot 6be4601763 refactor: return errors from NewCLIInstance instead of panicking
Change NewCLIInstance() and NewCLIInstanceWithFs() to return
(*Instance, error) instead of panicking on DetermineStateDir failure.

Callers in RunE contexts propagate the error. Callers in command
construction (for shell completion) use log.Fatalf. Test callers
use t.Fatalf.

Addresses review feedback on PR #18.
2026-02-19 23:53:35 -08:00

525 lines
15 KiB
Go

package cli
import (
"log"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39"
)
// newVaultCmd builds the "vault" parent command and registers the list,
// create, select, import and remove subcommands beneath it, in that order.
func newVaultCmd() *cobra.Command {
	vaultCmd := &cobra.Command{
		Use:   "vault",
		Short: "Manage vaults",
		Long:  `Create, list, and select vaults for organizing secrets.`,
	}

	subcommands := []*cobra.Command{
		newVaultListCmd(),
		newVaultCreateCmd(),
		newVaultSelectCmd(),
		newVaultImportCmd(),
		newVaultRemoveCmd(),
	}
	vaultCmd.AddCommand(subcommands...)

	return vaultCmd
}
// newVaultListCmd builds "vault list" (alias "ls"). It prints the known vaults
// as text, or as JSON when the --json flag is set.
func newVaultListCmd() *cobra.Command {
	listCmd := &cobra.Command{
		Use:     "list",
		Aliases: []string{"ls"},
		Short:   "List available vaults",
	}

	listCmd.RunE = func(c *cobra.Command, _ []string) error {
		asJSON, _ := c.Flags().GetBool("json")

		cli, initErr := NewCLIInstance()
		if initErr != nil {
			return fmt.Errorf("failed to initialize CLI: %w", initErr)
		}

		return cli.ListVaults(c, asJSON)
	}

	listCmd.Flags().Bool("json", false, "Output in JSON format")

	return listCmd
}
// newVaultCreateCmd builds "vault create <name>". It takes exactly one
// positional argument, the name of the vault to create.
func newVaultCreateCmd() *cobra.Command {
	createCmd := &cobra.Command{
		Use:   "create <name>",
		Short: "Create a new vault",
		Args:  cobra.ExactArgs(1),
	}

	createCmd.RunE = func(c *cobra.Command, args []string) error {
		cli, initErr := NewCLIInstance()
		if initErr != nil {
			return fmt.Errorf("failed to initialize CLI: %w", initErr)
		}

		return cli.CreateVault(c, args[0])
	}

	return createCmd
}
// newVaultSelectCmd builds "vault select <name>", which marks an existing vault
// as the current one. Vault names are offered for shell completion.
func newVaultSelectCmd() *cobra.Command {
	return &cobra.Command{
		Use:               "select <name>",
		Short:             "Select a vault as current",
		Args:              cobra.ExactArgs(1),
		ValidArgsFunction: lazyVaultNamesCompletionFunc(),
		RunE: func(cmd *cobra.Command, args []string) error {
			cli, err := NewCLIInstance()
			if err != nil {
				return fmt.Errorf("failed to initialize CLI: %w", err)
			}
			return cli.SelectVault(cmd, args[0])
		},
	}
}

// newVaultImportCmd builds "vault import [vault-name]", which imports a BIP39
// mnemonic into the named vault ("default" when no name is given).
func newVaultImportCmd() *cobra.Command {
	return &cobra.Command{
		Use:               "import <vault-name>",
		Short:             "Import a mnemonic into a vault",
		Long:              `Import a BIP39 mnemonic phrase into the specified vault (default if not specified).`,
		Args:              cobra.MaximumNArgs(1),
		ValidArgsFunction: lazyVaultNamesCompletionFunc(),
		RunE: func(cmd *cobra.Command, args []string) error {
			vaultName := "default"
			if len(args) > 0 {
				vaultName = args[0]
			}
			cli, err := NewCLIInstance()
			if err != nil {
				return fmt.Errorf("failed to initialize CLI: %w", err)
			}
			return cli.VaultImport(cmd, vaultName)
		},
	}
}

// newVaultRemoveCmd builds "vault remove <name>" (alias "rm"). The --force/-f
// flag is required to remove a vault that still contains secrets.
func newVaultRemoveCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:     "remove <name>",
		Aliases: []string{"rm"},
		Short:   "Remove a vault",
		Long: `Remove a vault. Requires --force if the vault contains secrets. Will automatically ` +
			`switch to another vault if removing the currently selected one.`,
		Args:              cobra.ExactArgs(1),
		ValidArgsFunction: lazyVaultNamesCompletionFunc(),
		RunE: func(cmd *cobra.Command, args []string) error {
			force, _ := cmd.Flags().GetBool("force")
			cli, err := NewCLIInstance()
			if err != nil {
				return fmt.Errorf("failed to initialize CLI: %w", err)
			}
			return cli.RemoveVault(cmd, args[0], force)
		},
	}
	cmd.Flags().BoolP("force", "f", false, "Force removal even if vault contains secrets")
	return cmd
}

// lazyVaultNamesCompletionFunc returns a ValidArgsFunction that completes
// vault names.
//
// The CLI instance (and therefore the state directory) is resolved only when
// completion is actually requested, never while the command tree is being
// built. Previously, a failure to resolve it at construction time called
// log.Fatalf, which terminated every invocation of the program (even --help)
// before any RunE could report a proper error.
func lazyVaultNamesCompletionFunc() func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
	return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
		cli, err := NewCLIInstance()
		if err != nil {
			// Debugging aid only. ShellCompDirectiveError tells the shell
			// that completion failed and the result should be ignored.
			log.Printf("vault name completion unavailable: %v", err)
			return nil, cobra.ShellCompDirectiveError
		}
		return getVaultNamesCompletionFunc(cli.fs, cli.stateDir)(cmd, args, toComplete)
	}
}
// ListVaults prints every vault found under the CLI's state directory.
//
// When jsonOutput is true it prints an indented JSON object with the keys
// "vaults" and "currentVault". The latter is "" when no current vault can be
// loaded. Otherwise it prints a text list that marks the current vault.
// The current vault is only looked up in text mode when at least one vault
// exists. Errors from vault.ListVaults and json.MarshalIndent are returned
// unchanged.
func (cli *Instance) ListVaults(cmd *cobra.Command, jsonOutput bool) error {
	vaults, err := vault.ListVaults(cli.fs, cli.stateDir)
	if err != nil {
		return err
	}

	// currentName returns the selected vault's name, or "" if none can be loaded.
	currentName := func() string {
		current, lookupErr := vault.GetCurrentVault(cli.fs, cli.stateDir)
		if lookupErr != nil {
			return ""
		}
		return current.GetName()
	}

	if jsonOutput {
		payload := map[string]interface{}{
			"vaults":       vaults,
			"currentVault": currentName(),
		}
		encoded, marshalErr := json.MarshalIndent(payload, "", " ")
		if marshalErr != nil {
			return marshalErr
		}
		cmd.Println(string(encoded))
		return nil
	}

	cmd.Println("Available vaults:")
	if len(vaults) == 0 {
		cmd.Println(" (none)")
		return nil
	}

	active := currentName()
	for _, name := range vaults {
		if name != active {
			cmd.Printf(" %s\n", name)
			continue
		}
		cmd.Printf(" %s (current)\n", name)
	}
	return nil
}
// CreateVault creates a new vault called name. It derives the vault's long-term
// key from a BIP39 mnemonic and protects that key with a passphrase unlocker.
//
// The mnemonic comes from the environment variable named by secret.EnvMnemonic.
// If that is empty, it is read from the terminal without echo. The unlock
// passphrase comes from secret.EnvUnlockPassphrase, or from an interactive
// prompt if that is empty.
//
// Both inputs are gathered, and the mnemonic validated, before the vault is
// created on disk. A cancelled or failed passphrase prompt therefore does not
// leave behind a half-initialized vault without an unlocker.
//
// On success it prints the vault name, the long-term public key and the
// unlocker ID.
func (cli *Instance) CreateVault(cmd *cobra.Command, name string) error {
	secret.Debug("Creating new vault", "name", name, "state_dir", cli.stateDir)

	// Get or prompt for mnemonic
	var mnemonicStr string
	if envMnemonic := os.Getenv(secret.EnvMnemonic); envMnemonic != "" {
		secret.Debug("Using mnemonic from environment variable")
		mnemonicStr = envMnemonic
	} else {
		secret.Debug("Prompting user for mnemonic phrase")
		// Read mnemonic securely without echo
		mnemonicBuffer, err := secret.ReadPassphrase("Enter your BIP39 mnemonic phrase: ")
		if err != nil {
			secret.Debug("Failed to read mnemonic from stdin", "error", err)
			return fmt.Errorf("failed to read mnemonic: %w", err)
		}
		defer mnemonicBuffer.Destroy()
		mnemonicStr = mnemonicBuffer.String()
		fmt.Fprintln(os.Stderr) // Add newline after hidden input
	}
	if mnemonicStr == "" {
		return fmt.Errorf("mnemonic cannot be empty")
	}

	// Validate the mnemonic
	mnemonicWords := strings.Fields(mnemonicStr)
	secret.Debug("Validating BIP39 mnemonic", "word_count", len(mnemonicWords))
	if !bip39.IsMnemonicValid(mnemonicStr) {
		return fmt.Errorf("invalid BIP39 mnemonic phrase")
	}

	// Get or prompt for the unlock passphrase now, before anything is written,
	// so that an aborted or mismatched prompt leaves no partial vault behind.
	var passphraseBuffer *memguard.LockedBuffer
	if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" {
		secret.Debug("Using unlock passphrase from environment variable")
		passphraseBuffer = memguard.NewBufferFromBytes([]byte(envPassphrase))
	} else {
		secret.Debug("Prompting user for unlock passphrase")
		// Use secure passphrase input with confirmation
		var err error
		passphraseBuffer, err = readSecurePassphrase("Enter passphrase for unlocker: ")
		if err != nil {
			return fmt.Errorf("failed to read passphrase: %w", err)
		}
	}
	defer passphraseBuffer.Destroy()

	// vault.CreateVault reads the mnemonic from the environment, so expose it
	// there for the duration of this call and restore the previous state after.
	originalMnemonic := os.Getenv(secret.EnvMnemonic)
	_ = os.Setenv(secret.EnvMnemonic, mnemonicStr)
	defer func() {
		if originalMnemonic != "" {
			_ = os.Setenv(secret.EnvMnemonic, originalMnemonic)
		} else {
			_ = os.Unsetenv(secret.EnvMnemonic)
		}
	}()

	// Create the vault - it will handle key derivation internally
	vlt, err := vault.CreateVault(cli.fs, cli.stateDir, name)
	if err != nil {
		return err
	}

	// Get the vault metadata to retrieve the derivation index
	vaultDir := filepath.Join(cli.stateDir, "vaults.d", name)
	metadata, err := vault.LoadVaultMetadata(cli.fs, vaultDir)
	if err != nil {
		return fmt.Errorf("failed to load vault metadata: %w", err)
	}

	// Derive the long-term key using the same index that CreateVault used
	ltIdentity, err := agehd.DeriveIdentity(mnemonicStr, metadata.DerivationIndex)
	if err != nil {
		return fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
	}

	// Unlock the vault with the derived long-term key
	vlt.Unlock(ltIdentity)

	// Create passphrase-protected unlocker
	secret.Debug("Creating passphrase-protected unlocker")
	passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
	if err != nil {
		return fmt.Errorf("failed to create unlocker: %w", err)
	}

	cmd.Printf("Created vault '%s'\n", vlt.GetName())
	cmd.Printf("Long-term public key: %s\n", ltIdentity.Recipient().String())
	cmd.Printf("Unlocker ID: %s\n", passphraseUnlocker.GetID())

	return nil
}
// SelectVault makes the vault called name the current vault and, on success,
// reports the change on the command's output. Any error from
// vault.SelectVault is returned unchanged and nothing is printed.
func (cli *Instance) SelectVault(cmd *cobra.Command, name string) error {
	selectErr := vault.SelectVault(cli.fs, cli.stateDir, name)
	if selectErr == nil {
		cmd.Printf("Selected vault '%s' as current\n", name)
	}
	return selectErr
}
// VaultImport imports the BIP39 mnemonic from the environment into the
// existing vault vaultName. It:
//   - derives a long-term key at the next derivation index reported by
//     vault.GetNextDerivationIndex for that mnemonic,
//   - stores the public key in pub.age,
//   - records the derivation index, public-key hash and mnemonic-family hash
//     in the vault metadata,
//   - creates a passphrase unlocker.
//
// The mnemonic and unlock passphrase are read from the environment variables
// named by secret.EnvMnemonic and secret.EnvUnlockPassphrase. Both are
// validated before anything is written. Previously a missing passphrase left
// pub.age on disk, so every later import attempt failed with "already has a
// long-term key configured".
func (cli *Instance) VaultImport(cmd *cobra.Command, vaultName string) error {
	secret.Debug("Importing mnemonic into vault", "vault_name", vaultName, "state_dir", cli.stateDir)

	// Get the specific vault by name and make sure its directory exists.
	vlt := vault.NewVault(cli.fs, cli.stateDir, vaultName)
	vaultDir, err := vlt.GetDirectory()
	if err != nil {
		return err
	}
	exists, err := afero.DirExists(cli.fs, vaultDir)
	if err != nil {
		return fmt.Errorf("failed to check if vault exists: %w", err)
	}
	if !exists {
		return fmt.Errorf("vault '%s' does not exist", vaultName)
	}

	// Refuse to import over an existing long-term public key.
	pubKeyPath := filepath.Join(vaultDir, "pub.age")
	if _, err := cli.fs.Stat(pubKeyPath); err == nil {
		return fmt.Errorf("vault '%s' already has a long-term key configured", vaultName)
	}

	// Gather and validate every input before touching the disk.
	mnemonic := os.Getenv(secret.EnvMnemonic)
	if mnemonic == "" {
		return fmt.Errorf("SB_SECRET_MNEMONIC environment variable not set")
	}
	mnemonicWords := strings.Fields(mnemonic)
	secret.Debug("Validating BIP39 mnemonic", "word_count", len(mnemonicWords))
	if !bip39.IsMnemonicValid(mnemonic) {
		return fmt.Errorf("invalid BIP39 mnemonic")
	}

	passphraseStr := os.Getenv(secret.EnvUnlockPassphrase)
	if passphraseStr == "" {
		return fmt.Errorf("SB_UNLOCK_PASSPHRASE environment variable not set")
	}
	secret.Debug("Using unlock passphrase from environment variable")
	passphraseBuffer := memguard.NewBufferFromBytes([]byte(passphraseStr))
	defer passphraseBuffer.Destroy()

	// Get the next available derivation index for this mnemonic
	derivationIndex, err := vault.GetNextDerivationIndex(cli.fs, cli.stateDir, mnemonic)
	if err != nil {
		secret.Debug("Failed to get next derivation index", "error", err)
		return fmt.Errorf("failed to get next derivation index: %w", err)
	}
	secret.Debug("Using derivation index", "index", derivationIndex)

	// Derive long-term key from mnemonic with the appropriate index
	secret.Debug("Deriving long-term key from mnemonic", "index", derivationIndex)
	ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex)
	if err != nil {
		return fmt.Errorf("failed to derive long-term key: %w", err)
	}

	// Store long-term public key in vault
	ltPublicKey := ltIdentity.Recipient().String()
	secret.Debug("Storing long-term public key", "pubkey", ltPublicKey, "vault_dir", vaultDir)
	if err := afero.WriteFile(cli.fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms); err != nil {
		return fmt.Errorf("failed to store long-term public key: %w", err)
	}

	// Hash of the public key at the derivation index actually used, so the
	// derived key can later be verified against what was stored.
	publicKeyHash := vault.ComputeDoubleSHA256([]byte(ltPublicKey))

	// Family hash from index 0 (same for all vaults with this mnemonic); it
	// identifies which vaults belong to the same mnemonic family.
	identity0, err := agehd.DeriveIdentity(mnemonic, 0)
	if err != nil {
		return fmt.Errorf("failed to derive identity for index 0: %w", err)
	}
	familyHash := vault.ComputeDoubleSHA256([]byte(identity0.Recipient().String()))

	// Load existing metadata, starting fresh if it cannot be loaded.
	// NOTE(review): any load error (not only "file missing") falls back to new
	// metadata here. Confirm a corrupt metadata file should be overwritten.
	existingMetadata, err := vault.LoadVaultMetadata(cli.fs, vaultDir)
	if err != nil {
		existingMetadata = &vault.Metadata{
			CreatedAt: time.Now(),
		}
	}

	// Update metadata with new derivation info
	existingMetadata.DerivationIndex = derivationIndex
	existingMetadata.PublicKeyHash = publicKeyHash
	existingMetadata.MnemonicFamilyHash = familyHash
	if err := vault.SaveVaultMetadata(cli.fs, vaultDir, existingMetadata); err != nil {
		secret.Debug("Failed to save vault metadata", "error", err)
		return fmt.Errorf("failed to save vault metadata: %w", err)
	}
	secret.Debug("Saved vault metadata with derivation index and public key hash")

	// Unlock the vault with the derived long-term key
	vlt.Unlock(ltIdentity)

	// Create passphrase-protected unlocker
	secret.Debug("Creating passphrase-protected unlocker")
	passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
	if err != nil {
		secret.Debug("Failed to create unlocker", "error", err)
		return fmt.Errorf("failed to create unlocker: %w", err)
	}

	cmd.Printf("Successfully imported mnemonic into vault '%s'\n", vaultName)
	cmd.Printf("Long-term public key: %s\n", ltPublicKey)
	cmd.Printf("Unlocker ID: %s\n", passphraseUnlocker.GetID())

	return nil
}
// RemoveVault deletes the vault called name and its directory tree.
//
// Safety checks:
//   - the vault must appear in vault.ListVaults;
//   - it must not be the only vault;
//   - force must be set if its secrets.d directory has any entries.
//
// If the vault is the current one, another listed vault is selected first.
// Any error while inspecting secrets.d aborts the removal. Because this is a
// destructive operation, an unreadable secrets directory must not be mistaken
// for an empty one (previously such errors were silently ignored).
func (cli *Instance) RemoveVault(cmd *cobra.Command, name string, force bool) error {
	// Get list of all vaults
	vaults, err := vault.ListVaults(cli.fs, cli.stateDir)
	if err != nil {
		return fmt.Errorf("failed to list vaults: %w", err)
	}

	// Check if vault exists
	vaultExists := false
	for _, v := range vaults {
		if v == name {
			vaultExists = true
			break
		}
	}
	if !vaultExists {
		return fmt.Errorf("vault '%s' does not exist", name)
	}

	// Don't allow removing the last vault
	if len(vaults) == 1 {
		return fmt.Errorf("cannot remove the last vault")
	}

	// Check if this is the current vault
	currentVault, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
	if err != nil {
		return fmt.Errorf("failed to get current vault: %w", err)
	}
	isCurrentVault := currentVault.GetName() == name

	// Load the vault to check for secrets
	vlt := vault.NewVault(cli.fs, cli.stateDir, name)
	vaultDir, err := vlt.GetDirectory()
	if err != nil {
		return fmt.Errorf("failed to get vault directory: %w", err)
	}

	// Check if vault has secrets; fail safe on any filesystem error.
	secretsDir := filepath.Join(vaultDir, "secrets.d")
	secretsDirExists, err := afero.DirExists(cli.fs, secretsDir)
	if err != nil {
		return fmt.Errorf("failed to check secrets directory: %w", err)
	}
	hasSecrets := false
	if secretsDirExists {
		entries, err := afero.ReadDir(cli.fs, secretsDir)
		if err != nil {
			return fmt.Errorf("failed to read secrets directory: %w", err)
		}
		hasSecrets = len(entries) > 0
	}

	// Require --force if vault has secrets
	if hasSecrets && !force {
		return fmt.Errorf("vault '%s' contains secrets; use --force to remove", name)
	}

	// If removing current vault, switch to another vault first
	if isCurrentVault {
		var newVault string
		for _, v := range vaults {
			if v != name {
				newVault = v
				break
			}
		}
		if err := vault.SelectVault(cli.fs, cli.stateDir, newVault); err != nil {
			return fmt.Errorf("failed to switch to vault '%s': %w", newVault, err)
		}
		cmd.Printf("Switched current vault to '%s'\n", newVault)
	}

	// Remove the vault directory
	if err := cli.fs.RemoveAll(vaultDir); err != nil {
		return fmt.Errorf("failed to remove vault directory: %w", err)
	}

	cmd.Printf("Removed vault '%s'\n", name)
	if hasSecrets {
		cmd.Printf("Warning: Vault contained secrets that have been permanently deleted\n")
	}

	return nil
}