secret/internal/cli/vault.go

268 lines
7.4 KiB
Go

package cli
import (
"encoding/json"
"fmt"
"os"
"strings"
"time"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39"
)
// newVaultCmd builds the "vault" parent command and registers its
// list, create, select, and import subcommands.
func newVaultCmd() *cobra.Command {
	vaultCmd := &cobra.Command{
		Use:   "vault",
		Short: "Manage vaults",
		Long:  `Create, list, and select vaults for organizing secrets.`,
	}

	vaultCmd.AddCommand(
		newVaultListCmd(),
		newVaultCreateCmd(),
		newVaultSelectCmd(),
		newVaultImportCmd(),
	)

	return vaultCmd
}
// newVaultListCmd builds "vault list", which prints the known vaults,
// optionally as JSON when --json is given.
func newVaultListCmd() *cobra.Command {
	listCmd := &cobra.Command{
		Use:   "list",
		Short: "List available vaults",
	}
	listCmd.Flags().Bool("json", false, "Output in JSON format")

	listCmd.RunE = func(c *cobra.Command, _ []string) error {
		// The "json" flag is registered as a bool just above, so the lookup
		// error is deliberately ignored.
		asJSON, _ := c.Flags().GetBool("json")
		return NewCLIInstance().ListVaults(asJSON)
	}

	return listCmd
}
// newVaultCreateCmd builds "vault create <name>", which creates a vault
// using its single required argument as the vault name.
func newVaultCreateCmd() *cobra.Command {
	createCmd := &cobra.Command{
		Use:   "create <name>",
		Short: "Create a new vault",
		Args:  cobra.ExactArgs(1),
	}

	createCmd.RunE = func(_ *cobra.Command, args []string) error {
		name := args[0]
		return NewCLIInstance().CreateVault(name)
	}

	return createCmd
}
// newVaultSelectCmd builds "vault select <name>", which marks the named
// vault as the current one.
func newVaultSelectCmd() *cobra.Command {
	selectCmd := &cobra.Command{
		Use:   "select <name>",
		Short: "Select a vault as current",
		Args:  cobra.ExactArgs(1),
	}

	selectCmd.RunE = func(_ *cobra.Command, args []string) error {
		target := args[0]
		return NewCLIInstance().SelectVault(target)
	}

	return selectCmd
}
// newVaultImportCmd builds "vault import [vault-name]". The vault name is
// optional and falls back to "default"; the actual import (including reading
// the mnemonic and passphrase) is done by CLIInstance.VaultImport.
func newVaultImportCmd() *cobra.Command {
	return &cobra.Command{
		// Square brackets mark the argument as optional, matching
		// cobra.MaximumNArgs(1) and the "default" fallback below.
		Use:   "import [vault-name]",
		Short: "Import a mnemonic into a vault",
		Long: `Import a BIP39 mnemonic phrase into the specified vault, or into the vault
named "default" if no vault name is given.

The mnemonic is read from the SB_SECRET_MNEMONIC environment variable and the
unlock passphrase from SB_UNLOCK_PASSPHRASE.`,
		Args: cobra.MaximumNArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			vaultName := "default"
			if len(args) > 0 {
				vaultName = args[0]
			}
			cli := NewCLIInstance()
			return cli.VaultImport(vaultName)
		},
	}
}
// ListVaults prints the vaults stored under the CLI's state directory.
//
// With jsonOutput set, it prints an indented JSON object holding the vault
// names under "vaults" (always an array, never null) and the current vault's
// name under "current_vault" (empty when no current vault can be resolved).
// Otherwise it prints a human-readable list, tagging the current vault with
// "(current)".
//
// Errors from listing vaults or from JSON marshaling are returned; failing to
// resolve the current vault is not an error and just means nothing is marked.
func (cli *CLIInstance) ListVaults(jsonOutput bool) error {
	vaults, err := vault.ListVaults(cli.fs, cli.stateDir)
	if err != nil {
		return err
	}

	// currentVaultName resolves the selected vault on demand, so the text
	// path with no vaults never performs the lookup.
	currentVaultName := func() string {
		currentVlt, lookupErr := vault.GetCurrentVault(cli.fs, cli.stateDir)
		if lookupErr != nil {
			return ""
		}
		return currentVlt.GetName()
	}

	if jsonOutput {
		// encoding/json renders a nil slice as null; emit [] instead so the
		// "vaults" field is always an array for consumers.
		if len(vaults) == 0 {
			vaults = []string{}
		}
		result := map[string]interface{}{
			"vaults":        vaults,
			"current_vault": currentVaultName(),
		}
		jsonBytes, err := json.MarshalIndent(result, "", " ")
		if err != nil {
			return err
		}
		fmt.Println(string(jsonBytes))
		return nil
	}

	fmt.Println("Available vaults:")
	if len(vaults) == 0 {
		fmt.Println(" (none)")
		return nil
	}

	currentVault := currentVaultName()
	for _, vaultName := range vaults {
		if vaultName == currentVault {
			fmt.Printf(" %s (current)\n", vaultName)
		} else {
			fmt.Printf(" %s\n", vaultName)
		}
	}
	return nil
}
// CreateVault creates a vault called name in the CLI's state directory and
// prints a confirmation naming the created vault. Any error from the vault
// package is returned unchanged and nothing is printed.
func (cli *CLIInstance) CreateVault(name string) error {
	secret.Debug("Creating new vault", "name", name, "state_dir", cli.stateDir)

	created, createErr := vault.CreateVault(cli.fs, cli.stateDir, name)
	if createErr != nil {
		return createErr
	}

	fmt.Printf("Created vault '%s'\n", created.GetName())
	return nil
}
// SelectVault makes the vault called name the current vault and prints a
// confirmation. Any error from the vault package is returned unchanged and
// nothing is printed.
func (cli *CLIInstance) SelectVault(name string) error {
	selectErr := vault.SelectVault(cli.fs, cli.stateDir, name)
	if selectErr != nil {
		return selectErr
	}

	fmt.Printf("Selected vault '%s' as current\n", name)
	return nil
}
// VaultImport imports a BIP39 mnemonic into the existing vault named vaultName.
//
// The mnemonic is read from the secret.EnvMnemonic environment variable and the
// unlock passphrase from secret.EnvUnlockPassphrase. The function:
//   - derives the vault's long-term age identity at the next free derivation
//     index for that mnemonic;
//   - writes the long-term public key to <vaultDir>/pub.age;
//   - saves the vault metadata;
//   - creates a passphrase-protected unlocker.
//
// All inputs (vault existence, mnemonic, passphrase) are validated before
// anything is written. A missing or invalid input therefore returns an error
// without modifying the vault directory.
func (cli *CLIInstance) VaultImport(vaultName string) error {
	secret.Debug("Importing mnemonic into vault", "vault_name", vaultName, "state_dir", cli.stateDir)

	// Resolve the vault by name; import only works on a vault that already exists.
	vlt := vault.NewVault(cli.fs, cli.stateDir, vaultName)
	vaultDir, err := vlt.GetDirectory()
	if err != nil {
		return err
	}
	exists, err := afero.DirExists(cli.fs, vaultDir)
	if err != nil {
		return fmt.Errorf("failed to check if vault exists: %w", err)
	}
	if !exists {
		return fmt.Errorf("vault '%s' does not exist", vaultName)
	}

	// Read and validate the mnemonic.
	mnemonic := os.Getenv(secret.EnvMnemonic)
	if mnemonic == "" {
		return fmt.Errorf("SB_SECRET_MNEMONIC environment variable not set")
	}
	secret.Debug("Validating BIP39 mnemonic", "word_count", len(strings.Fields(mnemonic)))
	if !bip39.IsMnemonicValid(mnemonic) {
		return fmt.Errorf("invalid BIP39 mnemonic")
	}

	// Read the passphrase before any file is written. Otherwise a missing value
	// would leave pub.age and the metadata on disk with no unlocker.
	passphraseStr := os.Getenv(secret.EnvUnlockPassphrase)
	if passphraseStr == "" {
		return fmt.Errorf("SB_UNLOCK_PASSPHRASE environment variable not set")
	}
	secret.Debug("Using unlock passphrase from environment variable")

	// The mnemonic's double-SHA256 is the key used to track which derivation
	// indices have already been handed out for this mnemonic.
	mnemonicHash := vault.ComputeDoubleSHA256([]byte(mnemonic))
	secret.Debug("Calculated mnemonic hash", "hash", mnemonicHash)

	derivationIndex, err := vault.GetNextDerivationIndex(cli.fs, cli.stateDir, mnemonicHash)
	if err != nil {
		secret.Debug("Failed to get next derivation index", "error", err)
		return fmt.Errorf("failed to get next derivation index: %w", err)
	}
	secret.Debug("Using derivation index", "index", derivationIndex)

	// Derive the long-term identity for this vault at the chosen index.
	secret.Debug("Deriving long-term key from mnemonic", "index", derivationIndex)
	ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex)
	if err != nil {
		return fmt.Errorf("failed to derive long-term key: %w", err)
	}
	ltKeyHash := vault.ComputeDoubleSHA256([]byte(ltIdentity.String()))
	secret.Debug("Calculated long-term key hash", "hash", ltKeyHash)

	// Persist the long-term public key alongside the vault.
	ltPublicKey := ltIdentity.Recipient().String()
	secret.Debug("Storing long-term public key", "pubkey", ltPublicKey, "vault_dir", vaultDir)
	pubKeyPath := fmt.Sprintf("%s/pub.age", vaultDir)
	if err := afero.WriteFile(cli.fs, pubKeyPath, []byte(ltPublicKey), 0600); err != nil {
		return fmt.Errorf("failed to store long-term public key: %w", err)
	}

	// Record how this vault's key was derived so it can be re-derived later.
	metadata := &vault.VaultMetadata{
		Name:            vaultName,
		CreatedAt:       time.Now(),
		DerivationIndex: derivationIndex,
		LongTermKeyHash: ltKeyHash,
		MnemonicHash:    mnemonicHash,
	}
	if err := vault.SaveVaultMetadata(cli.fs, vaultDir, metadata); err != nil {
		secret.Debug("Failed to save vault metadata", "error", err)
		return fmt.Errorf("failed to save vault metadata: %w", err)
	}
	secret.Debug("Saved vault metadata with derivation index and key hash")

	// Unlock the vault with the derived long-term key so an unlocker can be
	// created for it.
	// NOTE(review): any return value of Unlock is ignored here — confirm it
	// cannot fail.
	vlt.Unlock(ltIdentity)

	secret.Debug("Creating passphrase-protected unlocker")
	passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseStr)
	if err != nil {
		secret.Debug("Failed to create unlocker", "error", err)
		return fmt.Errorf("failed to create unlocker: %w", err)
	}

	fmt.Printf("Successfully imported mnemonic into vault '%s'\n", vaultName)
	fmt.Printf("Long-term public key: %s\n", ltPublicKey)
	fmt.Printf("Unlocker ID: %s\n", passphraseUnlocker.GetID())
	return nil
}