Refactor vault functionality to dedicated package, fix import cycles with interface pattern, fix tests

This commit is contained in:
2025-05-29 12:48:36 -07:00
parent c33385be6c
commit ddb395901b
18 changed files with 1847 additions and 2128 deletions

View File

@@ -4,15 +4,14 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"syscall"
"strings"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39"
"golang.org/x/term"
)
func newVaultCmd() *cobra.Command {
@@ -38,7 +37,7 @@ func newVaultListCmd() *cobra.Command {
jsonOutput, _ := cmd.Flags().GetBool("json")
cli := NewCLIInstance()
return cli.VaultList(jsonOutput)
return cli.ListVaults(jsonOutput)
},
}
@@ -53,7 +52,7 @@ func newVaultCreateCmd() *cobra.Command {
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
cli := NewCLIInstance()
return cli.VaultCreate(args[0])
return cli.CreateVault(args[0])
},
}
}
@@ -65,7 +64,7 @@ func newVaultSelectCmd() *cobra.Command {
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
cli := NewCLIInstance()
return cli.VaultSelect(args[0])
return cli.SelectVault(args[0])
},
}
}
@@ -83,166 +82,155 @@ func newVaultImportCmd() *cobra.Command {
}
cli := NewCLIInstance()
return cli.Import(vaultName)
return cli.VaultImport(vaultName)
},
}
}
// VaultList lists available vaults
func (cli *CLIInstance) VaultList(jsonOutput bool) error {
vaults, err := secret.ListVaults(cli.fs, cli.stateDir)
// ListVaults lists all available vaults
func (cli *CLIInstance) ListVaults(jsonOutput bool) error {
vaults, err := vault.ListVaults(cli.fs, cli.stateDir)
if err != nil {
return err
}
if jsonOutput {
// JSON output
output := map[string]interface{}{
"vaults": vaults,
// Get current vault name for context
currentVault := ""
if currentVlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir); err == nil {
currentVault = currentVlt.GetName()
}
jsonBytes, err := json.MarshalIndent(output, "", " ")
result := map[string]interface{}{
"vaults": vaults,
"current_vault": currentVault,
}
jsonBytes, err := json.MarshalIndent(result, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal JSON: %w", err)
return err
}
fmt.Println(string(jsonBytes))
} else {
// Pretty table output
// Text output
fmt.Println("Available vaults:")
if len(vaults) == 0 {
fmt.Println("No vaults found.")
fmt.Println("Run 'secret init' to create the default vault.")
return nil
}
// Get current vault for highlighting
currentVault, err := secret.GetCurrentVault(cli.fs, cli.stateDir)
if err != nil {
fmt.Printf("%-20s %s\n", "VAULT", "STATUS")
fmt.Printf("%-20s %s\n", "-----", "------")
for _, vault := range vaults {
fmt.Printf("%-20s %s\n", vault, "")
}
fmt.Println(" (none)")
} else {
fmt.Printf("%-20s %s\n", "VAULT", "STATUS")
fmt.Printf("%-20s %s\n", "-----", "------")
// Try to get current vault for marking
currentVault := ""
if currentVlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir); err == nil {
currentVault = currentVlt.GetName()
}
for _, vault := range vaults {
status := ""
if vault == currentVault.Name {
status = "(current)"
for _, vaultName := range vaults {
if vaultName == currentVault {
fmt.Printf(" %s (current)\n", vaultName)
} else {
fmt.Printf(" %s\n", vaultName)
}
fmt.Printf("%-20s %s\n", vault, status)
}
}
fmt.Printf("\nTotal: %d vault(s)\n", len(vaults))
}
return nil
}
// VaultCreate creates a new vault
func (cli *CLIInstance) VaultCreate(name string) error {
_, err := secret.CreateVault(cli.fs, cli.stateDir, name)
return err
}
// CreateVault creates a new vault
func (cli *CLIInstance) CreateVault(name string) error {
secret.Debug("Creating new vault", "name", name, "state_dir", cli.stateDir)
// VaultSelect selects a vault as current
func (cli *CLIInstance) VaultSelect(name string) error {
return secret.SelectVault(cli.fs, cli.stateDir, name)
}
// Import imports a mnemonic into a vault
func (cli *CLIInstance) Import(vaultName string) error {
var mnemonicStr string
// Check if mnemonic is set in environment variable
if envMnemonic := os.Getenv(secret.EnvMnemonic); envMnemonic != "" {
mnemonicStr = envMnemonic
} else {
// Read mnemonic from stdin using shared line reader
var err error
mnemonicStr, err = readLineFromStdin("Enter your BIP39 mnemonic phrase: ")
if err != nil {
return fmt.Errorf("failed to read mnemonic: %w", err)
}
}
if mnemonicStr == "" {
return fmt.Errorf("mnemonic cannot be empty")
}
// Validate the mnemonic using BIP39
if !bip39.IsMnemonicValid(mnemonicStr) {
return fmt.Errorf("invalid BIP39 mnemonic phrase\nRun 'secret generate mnemonic' to create a valid mnemonic")
}
return cli.importMnemonic(vaultName, mnemonicStr)
}
// importMnemonic imports a BIP39 mnemonic into the specified vault
func (cli *CLIInstance) importMnemonic(vaultName, mnemonic string) error {
// Derive long-term keypair from mnemonic
ltIdentity, err := agehd.DeriveIdentity(mnemonic, 0)
vlt, err := vault.CreateVault(cli.fs, cli.stateDir, name)
if err != nil {
return fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
return err
}
fmt.Printf("Created vault '%s'\n", vlt.GetName())
return nil
}
// SelectVault selects a vault as the current one
func (cli *CLIInstance) SelectVault(name string) error {
if err := vault.SelectVault(cli.fs, cli.stateDir, name); err != nil {
return err
}
fmt.Printf("Selected vault '%s' as current\n", name)
return nil
}
// VaultImport imports a mnemonic into a specific vault
func (cli *CLIInstance) VaultImport(vaultName string) error {
secret.Debug("Importing mnemonic into vault", "vault_name", vaultName, "state_dir", cli.stateDir)
// Get the specific vault by name
vlt := vault.NewVault(cli.fs, vaultName, cli.stateDir)
// Check if vault exists
stateDir := cli.GetStateDir()
vaultDir := filepath.Join(stateDir, "vaults.d", vaultName)
vaultDir, err := vlt.GetDirectory()
if err != nil {
return err
}
exists, err := afero.DirExists(cli.fs, vaultDir)
if err != nil {
return fmt.Errorf("failed to check if vault exists: %w", err)
}
if !exists {
return fmt.Errorf("vault %s does not exist", vaultName)
return fmt.Errorf("vault '%s' does not exist", vaultName)
}
// Get mnemonic from environment
mnemonic := os.Getenv(secret.EnvMnemonic)
if mnemonic == "" {
return fmt.Errorf("SB_SECRET_MNEMONIC environment variable not set")
}
// Validate the mnemonic
mnemonicWords := strings.Fields(mnemonic)
secret.Debug("Validating BIP39 mnemonic", "word_count", len(mnemonicWords))
if !bip39.IsMnemonicValid(mnemonic) {
return fmt.Errorf("invalid BIP39 mnemonic")
}
// Derive long-term key from mnemonic
secret.Debug("Deriving long-term key from mnemonic", "index", 0)
ltIdentity, err := agehd.DeriveIdentity(mnemonic, 0)
if err != nil {
return fmt.Errorf("failed to derive long-term key: %w", err)
}
// Store long-term public key in vault
ltPubKey := ltIdentity.Recipient().String()
if err := afero.WriteFile(cli.fs, filepath.Join(vaultDir, "pub.age"), []byte(ltPubKey), 0600); err != nil {
return fmt.Errorf("failed to write long-term public key: %w", err)
ltPublicKey := ltIdentity.Recipient().String()
secret.Debug("Storing long-term public key", "pubkey", ltPublicKey, "vault_dir", vaultDir)
pubKeyPath := fmt.Sprintf("%s/pub.age", vaultDir)
if err := afero.WriteFile(cli.fs, pubKeyPath, []byte(ltPublicKey), 0600); err != nil {
return fmt.Errorf("failed to store long-term public key: %w", err)
}
// Get the vault instance and unlock it
vault := secret.NewVault(cli.fs, vaultName, cli.stateDir)
vault.Unlock(ltIdentity)
// Get passphrase from environment variable
passphraseStr := os.Getenv(secret.EnvUnlockPassphrase)
if passphraseStr == "" {
return fmt.Errorf("SB_UNLOCK_PASSPHRASE environment variable not set")
}
secret.Debug("Using unlock passphrase from environment variable")
// Unlock the vault with the derived long-term key
vlt.Unlock(ltIdentity)
// Create passphrase-protected unlock key
secret.Debug("Creating passphrase-protected unlock key")
passphraseKey, err := vlt.CreatePassphraseKey(passphraseStr)
if err != nil {
secret.Debug("Failed to create unlock key", "error", err)
return fmt.Errorf("failed to create unlock key: %w", err)
}
fmt.Printf("Successfully imported mnemonic into vault '%s'\n", vaultName)
fmt.Printf("Long-term public key: %s\n", ltPubKey)
// Try to create unlock key only if running interactively
if term.IsTerminal(int(syscall.Stderr)) {
// Get or create passphrase for unlock key
var passphraseStr string
if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" {
passphraseStr = envPassphrase
} else {
// Use secure passphrase input with confirmation
passphraseStr, err = readSecurePassphrase("Enter passphrase for unlock key: ")
if err != nil {
fmt.Printf("Warning: Failed to create unlock key: %v\n", err)
fmt.Printf("You can create unlock keys later with 'secret keys add passphrase'\n")
return nil
}
}
// Create passphrase-protected unlock key (vault is now unlocked)
passphraseKey, err := vault.CreatePassphraseKey(passphraseStr)
if err != nil {
fmt.Printf("Warning: Failed to create unlock key: %v\n", err)
fmt.Printf("You can create unlock keys later with 'secret keys add passphrase'\n")
return nil
}
fmt.Printf("Unlock key ID: %s\n", passphraseKey.GetMetadata().ID)
} else {
fmt.Printf("Running in non-interactive mode - unlock key not created\n")
fmt.Printf("You can create unlock keys later with 'secret keys add passphrase'\n")
}
fmt.Printf("Long-term public key: %s\n", ltPublicKey)
fmt.Printf("Unlock key ID: %s\n", passphraseKey.GetID())
return nil
}