Compare commits

...

10 Commits

Author SHA1 Message Date
7596049828 uses protected memory buffers now for all secrets in ram 2025-07-15 08:32:33 +02:00
d3ca006886 Merge branch 'main' into fix-memory-security 2025-07-15 07:36:13 +02:00
f91281e991 Merge branch 'fix-test-json-fields' 2025-07-15 07:35:58 +02:00
7c5e78db17 fix: update JSON fields from snake_case to camelCase and make tests quiet by default
- Update all JSON field references in tests from snake_case to camelCase
- Update vault list JSON output to use currentVault instead of current_vault
- Make integration tests quiet by default unless run with -v flag
- Fix tests that were using exec.Command to use in-process execution helpers
- Tests now only show debug output when explicitly requested or on failure
2025-07-15 07:35:48 +02:00
8e374b3d24 add test binaries to gitignore 2025-07-15 07:24:15 +02:00
c9774e89e0 WIP: refactor to use memguard for secure memory handling
- Add memguard dependency
- Update ReadPassphrase to return LockedBuffer
- Update EncryptWithPassphrase/DecryptWithPassphrase to accept LockedBuffer
- Remove string wrapper functions
- Update all callers to create LockedBuffers at entry points
- Update interfaces and mock implementations
2025-07-15 07:23:58 +02:00
f9938135c6 fix: resolve all remaining linter issues (staticcheck, tagliatelle, lll)
- Fix staticcheck QF1011: Remove explicit type declaration for io.Writer variables
- Fix tagliatelle: Change all JSON tags from snake_case to camelCase
  - created_at → createdAt
  - keychain_item_name → keychainItemName
  - age_public_key → agePublicKey
  - age_priv_key_passphrase → agePrivKeyPassphrase
  - encrypted_longterm_key → encryptedLongtermKey
  - derivation_index → derivationIndex
  - public_key_hash → publicKeyHash
  - mnemonic_family_hash → mnemonicFamilyHash
  - gpg_key_id → gpgKeyId
- Fix lll: Break long function signature line to stay under 120 character limit

All linter issues have been resolved. The codebase now passes all linter checks.
2025-07-15 06:33:25 +02:00
386a27c0b6 fix: resolve all revive linter issues
Added missing package comments:
- cmd/secret/main.go
- internal/cli/cli.go
- internal/secret/constants.go
- internal/vault/management.go
- pkg/bip85/bip85.go

Fixed comment format issues for exported items:
- EnvStateDir, EnvMnemonic, EnvUnlockPassphrase, EnvGPGKeyID in constants.go
- Metadata, UnlockerMetadata, SecretMetadata, Configuration in metadata.go
- AppBIP39, AppHDWIF, AppXPRV in bip85.go

Replaced unused parameters with underscore (_):
- generate.go:39 - parameter 'args'
- init.go:30 - parameter 'args'
- unlockers.go:39,77,102 - parameter 'args' or 'cmd'
- vault.go:37 - parameter 'args'
- management.go:34 - parameter 'target'
2025-07-15 06:06:48 +02:00
080a3dc253 fix: resolve all nlreturn linter errors
Add blank lines before return statements in all files to satisfy
the nlreturn linter. This improves code readability by providing
visual separation before return statements.

Changes made across 24 files:
- internal/cli/*.go
- internal/secret/*.go
- internal/vault/*.go
- pkg/agehd/agehd.go
- pkg/bip85/bip85.go

All 143 nlreturn issues have been resolved.
2025-07-15 06:00:32 +02:00
811ddee3b7 fix: resolve all nestif linter errors
- Extract getLongTermPrivateKey helper function to reduce nesting in keychainunlocker.go and pgpunlocker.go
- Add getPassphrase helper method to reduce nesting in passphraseunlocker.go
- Refactor version serial extraction to use early returns in version.go
- Extract resolveRelativeSymlink and tryResolveOsSymlink helpers in management.go
- Add processMnemonicForVault helper to reduce nesting in vault creation
- Extract resolveUnlockerDirectory and readUnlockerPathFromFile helpers in unlockers.go
- Add findUnlockerByID helper to reduce duplicate code in RemoveUnlocker and SelectUnlocker

All tests pass after refactoring.
2025-07-15 05:47:16 +02:00
43 changed files with 1535 additions and 627 deletions

View File

@ -17,7 +17,13 @@
"Bash(gofumpt:*)", "Bash(gofumpt:*)",
"Bash(git stash:*)", "Bash(git stash:*)",
"Bash(git commit:*)", "Bash(git commit:*)",
"Bash(git push:*)" "Bash(git push:*)",
"Bash(golangci-lint:*)",
"Bash(git checkout:*)",
"Bash(ls:*)",
"WebFetch(domain:golangci-lint.run)",
"Bash(go:*)",
"WebFetch(domain:pkg.go.dev)"
], ],
"deny": [] "deny": []
} }

3
.gitignore vendored
View File

@ -3,4 +3,5 @@
/secret /secret
*.log *.log
cli.test cli.test
vault.test
*.test

View File

@ -81,10 +81,6 @@ issues:
max-issues-per-linter: 0 max-issues-per-linter: 0
max-same-issues: 0 max-same-issues: 0
exclude-rules: exclude-rules:
# Exclude all linters from running on test files
- path: _test\.go
# Allow long lines in generated code or test data
- path: ".*_gen\\.go" - path: ".*_gen\\.go"
linters: linters:
- lll - lll

13
TODO.md
View File

@ -6,22 +6,9 @@ prioritized from most critical (top) to least critical (bottom).
## Code Cleanups ## Code Cleanups
* none of the integration tests should be searching for a binary or trying
to execute another process. the integration tests cannot make another
process or depend on a compiled file, they must do all of their testing in
the current (test) process.
* we shouldn't be passing around a statedir, it should be read from the * we shouldn't be passing around a statedir, it should be read from the
environment or default. environment or default.
## CRITICAL SECURITY ISSUES - Must Fix Before 1.0
- [ ] **1. Memory security vulnerabilities**: Sensitive data (passwords,
private keys, passphrases) stored as strings are not properly zeroed from
memory after use. Memory dumps or swap files could expose secrets. Found
in crypto.go:107, passphraseunlocker.go:29-48, cli/crypto.go:89,193,
pgpunlocker.go:278, keychainunlocker.go:252,346.
## HIGH PRIORITY SECURITY ISSUES ## HIGH PRIORITY SECURITY ISSUES
- [ ] **4. Application crashes on corrupted metadata**: Code panics instead - [ ] **4. Application crashes on corrupted metadata**: Code panics instead

View File

@ -1,3 +1,4 @@
// Package main is the entry point for the secret CLI application.
package main package main
import "git.eeqj.de/sneak/secret/internal/cli" import "git.eeqj.de/sneak/secret/internal/cli"

2
go.mod
View File

@ -4,6 +4,7 @@ go 1.24.1
require ( require (
filippo.io/age v1.2.1 filippo.io/age v1.2.1
github.com/awnumar/memguard v0.22.5
github.com/btcsuite/btcd v0.24.2 github.com/btcsuite/btcd v0.24.2
github.com/btcsuite/btcd/btcec/v2 v2.1.3 github.com/btcsuite/btcd/btcec/v2 v2.1.3
github.com/btcsuite/btcd/btcutil v1.1.6 github.com/btcsuite/btcd/btcutil v1.1.6
@ -18,6 +19,7 @@ require (
) )
require ( require (
github.com/awnumar/memcall v0.2.0 // indirect
github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect

5
go.sum
View File

@ -3,6 +3,10 @@ c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805/go.mod h1:FomMrUJ2Lxt5jCLmZ
filippo.io/age v1.2.1 h1:X0TZjehAZylOIj4DubWYU1vWQxv9bJpo+Uu2/LGhi1o= filippo.io/age v1.2.1 h1:X0TZjehAZylOIj4DubWYU1vWQxv9bJpo+Uu2/LGhi1o=
filippo.io/age v1.2.1/go.mod h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004= filippo.io/age v1.2.1/go.mod h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004=
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/awnumar/memcall v0.2.0 h1:sRaogqExTOOkkNwO9pzJsL8jrOV29UuUW7teRMfbqtI=
github.com/awnumar/memcall v0.2.0/go.mod h1:S911igBPR9CThzd/hYQQmTc9SWNu3ZHIlCGaWsWsoJo=
github.com/awnumar/memguard v0.22.5 h1:PH7sbUVERS5DdXh3+mLo8FDcl1eIeVjJVYMnyuYpvuI=
github.com/awnumar/memguard v0.22.5/go.mod h1:+APmZGThMBWjnMlKiSM1X7MVpbIVewen2MTkqWkA/zE=
github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ=
github.com/btcsuite/btcd v0.22.0-beta.0.20220111032746-97732e52810c/go.mod h1:tjmYdS6MLJ5/s0Fj4DbLgSbDHbEqLJrtnHecBFkdz5M= github.com/btcsuite/btcd v0.22.0-beta.0.20220111032746-97732e52810c/go.mod h1:tjmYdS6MLJ5/s0Fj4DbLgSbDHbEqLJrtnHecBFkdz5M=
github.com/btcsuite/btcd v0.23.5-0.20231215221805-96c9fd8078fd/go.mod h1:nm3Bko6zh6bWP60UxwoT5LzdGJsQJaPo6HjduXq9p6A= github.com/btcsuite/btcd v0.23.5-0.20231215221805-96c9fd8078fd/go.mod h1:nm3Bko6zh6bWP60UxwoT5LzdGJsQJaPo6HjduXq9p6A=
@ -107,6 +111,7 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -1,3 +1,4 @@
// Package cli implements the command-line interface for the secret application.
package cli package cli
import ( import (
@ -37,6 +38,7 @@ func NewCLIInstance() *Instance {
// NewCLIInstanceWithFs creates a new CLI instance with the given filesystem (for testing) // NewCLIInstanceWithFs creates a new CLI instance with the given filesystem (for testing)
func NewCLIInstanceWithFs(fs afero.Fs) *Instance { func NewCLIInstanceWithFs(fs afero.Fs) *Instance {
stateDir := secret.DetermineStateDir("") stateDir := secret.DetermineStateDir("")
return &Instance{ return &Instance{
fs: fs, fs: fs,
stateDir: stateDir, stateDir: stateDir,

View File

@ -8,6 +8,7 @@ import (
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"github.com/awnumar/memguard"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -81,10 +82,15 @@ func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error {
return fmt.Errorf("failed to generate age key: %w", err) return fmt.Errorf("failed to generate age key: %w", err)
} }
ageSecretKey = identity.String() // Store the generated key directly in a secure buffer
identityStr := identity.String()
secureBuffer := memguard.NewBufferFromBytes([]byte(identityStr))
defer secureBuffer.Destroy()
// Store the generated key as a secret // Set ageSecretKey for later use (we need it for encryption)
err = vlt.AddSecret(secretName, []byte(ageSecretKey), false) ageSecretKey = identityStr
err = vlt.AddSecret(secretName, secureBuffer, false)
if err != nil { if err != nil {
return fmt.Errorf("failed to store age key: %w", err) return fmt.Errorf("failed to store age key: %w", err)
} }
@ -95,7 +101,16 @@ func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error {
return fmt.Errorf("failed to get secret value: %w", err) return fmt.Errorf("failed to get secret value: %w", err)
} }
ageSecretKey = string(secretValue) // Create secure buffer for the secret value
secureBuffer := memguard.NewBufferFromBytes(secretValue)
defer secureBuffer.Destroy()
// Clear the original secret value
for i := range secretValue {
secretValue[i] = 0
}
ageSecretKey = secureBuffer.String()
// Validate that it's a valid age secret key // Validate that it's a valid age secret key
if !isValidAgeSecretKey(ageSecretKey) { if !isValidAgeSecretKey(ageSecretKey) {
@ -103,8 +118,11 @@ func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error {
} }
} }
// Parse the secret key // Parse the secret key using secure buffer
identity, err := age.ParseX25519Identity(ageSecretKey) finalSecureBuffer := memguard.NewBufferFromBytes([]byte(ageSecretKey))
defer finalSecureBuffer.Destroy()
identity, err := age.ParseX25519Identity(finalSecureBuffer.String())
if err != nil { if err != nil {
return fmt.Errorf("failed to parse age secret key: %w", err) return fmt.Errorf("failed to parse age secret key: %w", err)
} }
@ -124,7 +142,7 @@ func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error {
} }
// Set up output writer // Set up output writer
var output io.Writer = cli.cmd.OutOrStdout() output := cli.cmd.OutOrStdout()
if outputFile != "" { if outputFile != "" {
file, err := cli.fs.Create(outputFile) file, err := cli.fs.Create(outputFile)
if err != nil { if err != nil {
@ -185,15 +203,22 @@ func (cli *Instance) Decrypt(secretName, inputFile, outputFile string) error {
return fmt.Errorf("failed to get secret value: %w", err) return fmt.Errorf("failed to get secret value: %w", err)
} }
ageSecretKey := string(secretValue) // Create secure buffer for the secret value
secureBuffer := memguard.NewBufferFromBytes(secretValue)
defer secureBuffer.Destroy()
// Clear the original secret value
for i := range secretValue {
secretValue[i] = 0
}
// Validate that it's a valid age secret key // Validate that it's a valid age secret key
if !isValidAgeSecretKey(ageSecretKey) { if !isValidAgeSecretKey(secureBuffer.String()) {
return fmt.Errorf("secret '%s' does not contain a valid age secret key", secretName) return fmt.Errorf("secret '%s' does not contain a valid age secret key", secretName)
} }
// Parse the age secret key to get the identity // Parse the age secret key to get the identity
identity, err := age.ParseX25519Identity(ageSecretKey) identity, err := age.ParseX25519Identity(secureBuffer.String())
if err != nil { if err != nil {
return fmt.Errorf("failed to parse age secret key: %w", err) return fmt.Errorf("failed to parse age secret key: %w", err)
} }
@ -210,7 +235,7 @@ func (cli *Instance) Decrypt(secretName, inputFile, outputFile string) error {
} }
// Set up output writer // Set up output writer
var output io.Writer = cli.cmd.OutOrStdout() output := cli.cmd.OutOrStdout()
if outputFile != "" { if outputFile != "" {
file, err := cli.fs.Create(outputFile) file, err := cli.fs.Create(outputFile)
if err != nil { if err != nil {
@ -236,6 +261,7 @@ func (cli *Instance) Decrypt(secretName, inputFile, outputFile string) error {
// isValidAgeSecretKey checks if a string is a valid age secret key by attempting to parse it // isValidAgeSecretKey checks if a string is a valid age secret key by attempting to parse it
func isValidAgeSecretKey(key string) bool { func isValidAgeSecretKey(key string) bool {
_, err := age.ParseX25519Identity(key) _, err := age.ParseX25519Identity(key)
return err == nil return err == nil
} }

View File

@ -7,6 +7,7 @@ import (
"os" "os"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"github.com/awnumar/memguard"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39" "github.com/tyler-smith/go-bip39"
) )
@ -36,8 +37,9 @@ func newGenerateMnemonicCmd() *cobra.Command {
Long: `Generate a cryptographically secure random BIP39 ` + Long: `Generate a cryptographically secure random BIP39 ` +
`mnemonic phrase that can be used with 'secret init' ` + `mnemonic phrase that can be used with 'secret init' ` +
`or 'secret import'.`, `or 'secret import'.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
cli := NewCLIInstance() cli := NewCLIInstance()
return cli.GenerateMnemonic(cmd) return cli.GenerateMnemonic(cmd)
}, },
} }
@ -135,7 +137,11 @@ func (cli *Instance) GenerateSecret(
return err return err
} }
if err := vlt.AddSecret(secretName, []byte(secretValue), force); err != nil { // Protect the generated secret immediately
secretBuffer := memguard.NewBufferFromBytes([]byte(secretValue))
defer secretBuffer.Destroy()
if err := vlt.AddSecret(secretName, secretBuffer, force); err != nil {
return err return err
} }
@ -147,12 +153,14 @@ func (cli *Instance) GenerateSecret(
// generateRandomBase58 generates a random base58 string of the specified length // generateRandomBase58 generates a random base58 string of the specified length
func generateRandomBase58(length int) (string, error) { func generateRandomBase58(length int) (string, error) {
const base58Chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" const base58Chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
return generateRandomString(length, base58Chars) return generateRandomString(length, base58Chars)
} }
// generateRandomAlnum generates a random alphanumeric string of the specified length // generateRandomAlnum generates a random alphanumeric string of the specified length
func generateRandomAlnum(length int) (string, error) { func generateRandomAlnum(length int) (string, error) {
const alnumChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" const alnumChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
return generateRandomString(length, alnumChars) return generateRandomString(length, alnumChars)
} }

View File

@ -11,6 +11,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39" "github.com/tyler-smith/go-bip39"
@ -27,8 +28,9 @@ func NewInitCmd() *cobra.Command {
} }
// RunInit is the exported function that handles the init command // RunInit is the exported function that handles the init command
func RunInit(cmd *cobra.Command, args []string) error { func RunInit(cmd *cobra.Command, _ []string) error {
cli := NewCLIInstance() cli := NewCLIInstance()
return cli.Init(cmd) return cli.Init(cmd)
} }
@ -42,6 +44,7 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
if err := cli.fs.MkdirAll(stateDir, secret.DirPerms); err != nil { if err := cli.fs.MkdirAll(stateDir, secret.DirPerms); err != nil {
secret.Debug("Failed to create state directory", "error", err) secret.Debug("Failed to create state directory", "error", err)
return fmt.Errorf("failed to create state directory: %w", err) return fmt.Errorf("failed to create state directory: %w", err)
} }
@ -62,12 +65,14 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
mnemonicStr, err = readLineFromStdin("Enter your BIP39 mnemonic phrase: ") mnemonicStr, err = readLineFromStdin("Enter your BIP39 mnemonic phrase: ")
if err != nil { if err != nil {
secret.Debug("Failed to read mnemonic from stdin", "error", err) secret.Debug("Failed to read mnemonic from stdin", "error", err)
return fmt.Errorf("failed to read mnemonic: %w", err) return fmt.Errorf("failed to read mnemonic: %w", err)
} }
} }
if mnemonicStr == "" { if mnemonicStr == "" {
secret.Debug("Empty mnemonic provided") secret.Debug("Empty mnemonic provided")
return fmt.Errorf("mnemonic cannot be empty") return fmt.Errorf("mnemonic cannot be empty")
} }
@ -75,6 +80,7 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
secret.DebugWith("Validating BIP39 mnemonic", slog.Int("word_count", len(strings.Fields(mnemonicStr)))) secret.DebugWith("Validating BIP39 mnemonic", slog.Int("word_count", len(strings.Fields(mnemonicStr))))
if !bip39.IsMnemonicValid(mnemonicStr) { if !bip39.IsMnemonicValid(mnemonicStr) {
secret.Debug("Invalid BIP39 mnemonic provided") secret.Debug("Invalid BIP39 mnemonic provided")
return fmt.Errorf("invalid BIP39 mnemonic phrase\nRun 'secret generate mnemonic' to create a valid mnemonic") return fmt.Errorf("invalid BIP39 mnemonic phrase\nRun 'secret generate mnemonic' to create a valid mnemonic")
} }
@ -94,6 +100,7 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
vlt, err := vault.CreateVault(cli.fs, cli.stateDir, "default") vlt, err := vault.CreateVault(cli.fs, cli.stateDir, "default")
if err != nil { if err != nil {
secret.Debug("Failed to create default vault", "error", err) secret.Debug("Failed to create default vault", "error", err)
return fmt.Errorf("failed to create default vault: %w", err) return fmt.Errorf("failed to create default vault: %w", err)
} }
@ -102,6 +109,7 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
metadata, err := vault.LoadVaultMetadata(cli.fs, vaultDir) metadata, err := vault.LoadVaultMetadata(cli.fs, vaultDir)
if err != nil { if err != nil {
secret.Debug("Failed to load vault metadata", "error", err) secret.Debug("Failed to load vault metadata", "error", err)
return fmt.Errorf("failed to load vault metadata: %w", err) return fmt.Errorf("failed to load vault metadata: %w", err)
} }
@ -109,6 +117,7 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
ltIdentity, err := agehd.DeriveIdentity(mnemonicStr, metadata.DerivationIndex) ltIdentity, err := agehd.DeriveIdentity(mnemonicStr, metadata.DerivationIndex)
if err != nil { if err != nil {
secret.Debug("Failed to derive long-term key", "error", err) secret.Debug("Failed to derive long-term key", "error", err)
return fmt.Errorf("failed to derive long-term key from mnemonic: %w", err) return fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
} }
ltPubKey := ltIdentity.Recipient().String() ltPubKey := ltIdentity.Recipient().String()
@ -117,25 +126,28 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
vlt.Unlock(ltIdentity) vlt.Unlock(ltIdentity)
// Prompt for passphrase for unlocker // Prompt for passphrase for unlocker
var passphraseStr string var passphraseBuffer *memguard.LockedBuffer
if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" { if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" {
secret.Debug("Using unlock passphrase from environment variable") secret.Debug("Using unlock passphrase from environment variable")
passphraseStr = envPassphrase passphraseBuffer = memguard.NewBufferFromBytes([]byte(envPassphrase))
} else { } else {
secret.Debug("Prompting user for unlock passphrase") secret.Debug("Prompting user for unlock passphrase")
// Use secure passphrase input with confirmation // Use secure passphrase input with confirmation
passphraseStr, err = readSecurePassphrase("Enter passphrase for unlocker: ") passphraseBuffer, err = readSecurePassphrase("Enter passphrase for unlocker: ")
if err != nil { if err != nil {
secret.Debug("Failed to read unlock passphrase", "error", err) secret.Debug("Failed to read unlock passphrase", "error", err)
return fmt.Errorf("failed to read passphrase: %w", err) return fmt.Errorf("failed to read passphrase: %w", err)
} }
} }
defer passphraseBuffer.Destroy()
// Create passphrase-protected unlocker // Create passphrase-protected unlocker
secret.Debug("Creating passphrase-protected unlocker") secret.Debug("Creating passphrase-protected unlocker")
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseStr) passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
if err != nil { if err != nil {
secret.Debug("Failed to create unlocker", "error", err) secret.Debug("Failed to create unlocker", "error", err)
return fmt.Errorf("failed to create unlocker: %w", err) return fmt.Errorf("failed to create unlocker: %w", err)
} }
@ -154,8 +166,11 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
} }
// Encrypt long-term private key to unlocker // Encrypt long-term private key to unlocker
ltPrivKeyData := []byte(ltIdentity.String()) // Use memguard to protect the private key in memory
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyData, unlockerRecipient) ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
defer ltPrivKeyBuffer.Destroy()
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerRecipient)
if err != nil { if err != nil {
return fmt.Errorf("failed to encrypt long-term private key: %w", err) return fmt.Errorf("failed to encrypt long-term private key: %w", err)
} }
@ -180,23 +195,27 @@ func (cli *Instance) Init(cmd *cobra.Command) error {
// readSecurePassphrase reads a passphrase securely from the terminal without echoing // readSecurePassphrase reads a passphrase securely from the terminal without echoing
// This version adds confirmation (read twice) for creating new unlockers // This version adds confirmation (read twice) for creating new unlockers
func readSecurePassphrase(prompt string) (string, error) { // Returns a LockedBuffer containing the passphrase
func readSecurePassphrase(prompt string) (*memguard.LockedBuffer, error) {
// Get the first passphrase // Get the first passphrase
passphrase1, err := secret.ReadPassphrase(prompt) passphraseBuffer1, err := secret.ReadPassphrase(prompt)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer passphraseBuffer1.Destroy()
// Read confirmation passphrase // Read confirmation passphrase
passphrase2, err := secret.ReadPassphrase("Confirm passphrase: ") passphraseBuffer2, err := secret.ReadPassphrase("Confirm passphrase: ")
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read passphrase confirmation: %w", err) return nil, fmt.Errorf("failed to read passphrase confirmation: %w", err)
} }
defer passphraseBuffer2.Destroy()
// Compare passphrases // Compare passphrases
if passphrase1 != passphrase2 { if passphraseBuffer1.String() != passphraseBuffer2.String() {
return "", fmt.Errorf("passphrases do not match") return nil, fmt.Errorf("passphrases do not match")
} }
return passphrase1, nil // Create a new buffer with the confirmed passphrase
return memguard.NewBufferFromBytes(passphraseBuffer1.Bytes()), nil
} }

View File

@ -52,11 +52,12 @@ func TestMain(m *testing.M) {
// all functionality of the secret manager using a real filesystem in a temporary directory. // all functionality of the secret manager using a real filesystem in a temporary directory.
// This test serves as both validation and documentation of the program's behavior. // This test serves as both validation and documentation of the program's behavior.
func TestSecretManagerIntegration(t *testing.T) { func TestSecretManagerIntegration(t *testing.T) {
// Enable debug logging to diagnose issues // Only enable debug logging if running with -v flag
t.Setenv("GODEBUG", "berlin.sneak.pkg.secret") if testing.Verbose() {
t.Setenv("GODEBUG", "berlin.sneak.pkg.secret")
// Reinitialize debug logging to pick up the environment variable change // Reinitialize debug logging to pick up the environment variable change
secret.InitDebugLogging() secret.InitDebugLogging()
}
// Test configuration // Test configuration
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
@ -266,7 +267,7 @@ func TestSecretManagerIntegration(t *testing.T) {
// Test 26: Large secret values // Test 26: Large secret values
// Purpose: Test with large secret values (e.g., certificates) // Purpose: Test with large secret values (e.g., certificates)
// Expected: Proper storage and retrieval // Expected: Proper storage and retrieval
test26LargeSecrets(t, tempDir, secretPath, testMnemonic, runSecret, runSecretWithEnv) test26LargeSecrets(t, tempDir, secretPath, testMnemonic, runSecret, runSecretWithEnv, runSecretWithStdin)
// Test 27: Special characters in values // Test 27: Special characters in values
// Purpose: Test secrets with newlines, unicode, binary data // Purpose: Test secrets with newlines, unicode, binary data
@ -380,8 +381,8 @@ func test01Initialize(t *testing.T, tempDir, testMnemonic, testPassphrase string
t.Logf("Parsed metadata: %+v", metadata) t.Logf("Parsed metadata: %+v", metadata)
// Verify metadata fields // Verify metadata fields
assert.Equal(t, float64(0), metadata["derivation_index"], "first vault should have index 0") assert.Equal(t, float64(0), metadata["derivationIndex"], "first vault should have index 0")
assert.Contains(t, metadata, "public_key_hash", "should contain public key hash") assert.Contains(t, metadata, "publicKeyHash", "should contain public key hash")
assert.Contains(t, metadata, "createdAt", "should contain creation timestamp") assert.Contains(t, metadata, "createdAt", "should contain creation timestamp")
// Verify the longterm.age file in passphrase unlocker // Verify the longterm.age file in passphrase unlocker
@ -411,8 +412,8 @@ func test02ListVaults(t *testing.T, runSecret func(...string) (string, error)) {
require.NoError(t, err, "JSON output should be valid") require.NoError(t, err, "JSON output should be valid")
// Verify current vault // Verify current vault
currentVault, ok := response["current_vault"] currentVault, ok := response["currentVault"]
require.True(t, ok, "response should contain current_vault") require.True(t, ok, "response should contain currentVault")
assert.Equal(t, "default", currentVault, "current vault should be default") assert.Equal(t, "default", currentVault, "current vault should be default")
// Verify vaults list // Verify vaults list
@ -520,14 +521,14 @@ func test04ImportMnemonic(t *testing.T, tempDir, testMnemonic, testPassphrase st
require.NoError(t, err, "vault metadata should be valid JSON") require.NoError(t, err, "vault metadata should be valid JSON")
// Work vault should have a different derivation index than default (0) // Work vault should have a different derivation index than default (0)
derivIndex, ok := metadata["derivation_index"].(float64) derivIndex, ok := metadata["derivationIndex"].(float64)
require.True(t, ok, "derivation_index should be a number") require.True(t, ok, "derivationIndex should be a number")
assert.NotEqual(t, float64(0), derivIndex, "work vault should have non-zero derivation index") assert.NotEqual(t, float64(0), derivIndex, "work vault should have non-zero derivation index")
// Verify public key hash is stored // Verify public key hash is stored
assert.Contains(t, metadata, "public_key_hash", "should contain public key hash") assert.Contains(t, metadata, "publicKeyHash", "should contain public key hash")
pubKeyHash, ok := metadata["public_key_hash"].(string) pubKeyHash, ok := metadata["publicKeyHash"].(string)
require.True(t, ok, "public_key_hash should be a string") require.True(t, ok, "publicKeyHash should be a string")
assert.NotEmpty(t, pubKeyHash, "public key hash should not be empty") assert.NotEmpty(t, pubKeyHash, "public key hash should not be empty")
} }
@ -876,8 +877,8 @@ func test11ListSecrets(t *testing.T, testMnemonic string, runSecret func(...stri
var listResponse struct { var listResponse struct {
Secrets []struct { Secrets []struct {
Name string `json:"name"` Name string `json:"name"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updated_at"` UpdatedAt string `json:"updatedAt"`
} `json:"secrets"` } `json:"secrets"`
Filter string `json:"filter,omitempty"` Filter string `json:"filter,omitempty"`
} }
@ -1377,7 +1378,7 @@ func test19DisasterRecovery(t *testing.T, tempDir, secretPath, testMnemonic stri
require.NoError(t, err, "read vault metadata") require.NoError(t, err, "read vault metadata")
var metadata struct { var metadata struct {
DerivationIndex uint32 `json:"derivation_index"` DerivationIndex uint32 `json:"derivationIndex"`
} }
err = json.Unmarshal(metadataBytes, &metadata) err = json.Unmarshal(metadataBytes, &metadata)
require.NoError(t, err, "parse vault metadata") require.NoError(t, err, "parse vault metadata")
@ -1531,7 +1532,7 @@ func test22JSONOutput(t *testing.T, runSecret func(...string) (string, error)) {
err = json.Unmarshal([]byte(output), &vaultListResponse) err = json.Unmarshal([]byte(output), &vaultListResponse)
require.NoError(t, err, "vault list JSON should be valid") require.NoError(t, err, "vault list JSON should be valid")
assert.Contains(t, vaultListResponse, "vaults", "should have vaults key") assert.Contains(t, vaultListResponse, "vaults", "should have vaults key")
assert.Contains(t, vaultListResponse, "current_vault", "should have current_vault key") assert.Contains(t, vaultListResponse, "currentVault", "should have currentVault key")
// Test secret list --json (already tested in test 11) // Test secret list --json (already tested in test 11)
@ -1687,7 +1688,7 @@ func test25ConcurrentOperations(t *testing.T, testMnemonic string, runSecret fun
// to avoid conflicts, but reads should always work // to avoid conflicts, but reads should always work
} }
func test26LargeSecrets(t *testing.T, tempDir, secretPath, testMnemonic string, runSecret func(...string) (string, error), runSecretWithEnv func(map[string]string, ...string) (string, error)) { func test26LargeSecrets(t *testing.T, tempDir, secretPath, testMnemonic string, runSecret func(...string) (string, error), runSecretWithEnv func(map[string]string, ...string) (string, error), runSecretWithStdin func(string, map[string]string, ...string) (string, error)) {
// Make sure we're in default vault // Make sure we're in default vault
_, err := runSecret("vault", "select", "default") _, err := runSecret("vault", "select", "default")
require.NoError(t, err, "vault select should succeed") require.NoError(t, err, "vault select should succeed")
@ -1700,16 +1701,10 @@ func test26LargeSecrets(t *testing.T, tempDir, secretPath, testMnemonic string,
assert.Greater(t, len(largeValue), 10000, "should be > 10KB") assert.Greater(t, len(largeValue), 10000, "should be > 10KB")
// Add large secret // Add large secret
cmd := exec.Command(secretPath, "add", "large/secret", "--force") _, err = runSecretWithStdin(largeValue, map[string]string{
cmd.Env = []string{ "SB_SECRET_MNEMONIC": testMnemonic,
fmt.Sprintf("SB_SECRET_STATE_DIR=%s", tempDir), }, "add", "large/secret", "--force")
fmt.Sprintf("SB_SECRET_MNEMONIC=%s", testMnemonic), require.NoError(t, err, "add large secret should succeed")
fmt.Sprintf("PATH=%s", os.Getenv("PATH")),
fmt.Sprintf("HOME=%s", os.Getenv("HOME")),
}
cmd.Stdin = strings.NewReader(largeValue)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "add large secret should succeed: %s", string(output))
// Retrieve and verify // Retrieve and verify
retrievedValue, err := runSecretWithEnv(map[string]string{ retrievedValue, err := runSecretWithEnv(map[string]string{
@ -1725,15 +1720,9 @@ BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTgwMjI4MTQwMzQ5WhcNMjgwMjI2MTQwMzQ5WjBF aWRnaXRzIFB0eSBMdGQwHhcNMTgwMjI4MTQwMzQ5WhcNMjgwMjI2MTQwMzQ5WjBF
-----END CERTIFICATE-----` -----END CERTIFICATE-----`
cmd = exec.Command(secretPath, "add", "cert/test", "--force") _, err = runSecretWithStdin(certValue, map[string]string{
cmd.Env = []string{ "SB_SECRET_MNEMONIC": testMnemonic,
fmt.Sprintf("SB_SECRET_STATE_DIR=%s", tempDir), }, "add", "cert/test", "--force")
fmt.Sprintf("SB_SECRET_MNEMONIC=%s", testMnemonic),
fmt.Sprintf("PATH=%s", os.Getenv("PATH")),
fmt.Sprintf("HOME=%s", os.Getenv("HOME")),
}
cmd.Stdin = strings.NewReader(certValue)
_, err = cmd.CombinedOutput()
require.NoError(t, err, "add certificate should succeed") require.NoError(t, err, "add certificate should succeed")
// Retrieve and verify certificate // Retrieve and verify certificate
@ -1821,10 +1810,10 @@ func test28VaultMetadata(t *testing.T, tempDir string) {
require.NoError(t, err, "default vault metadata should be valid JSON") require.NoError(t, err, "default vault metadata should be valid JSON")
// Verify required fields // Verify required fields
assert.Equal(t, float64(0), defaultMetadata["derivation_index"]) assert.Equal(t, float64(0), defaultMetadata["derivationIndex"])
assert.Contains(t, defaultMetadata, "createdAt") assert.Contains(t, defaultMetadata, "createdAt")
assert.Contains(t, defaultMetadata, "public_key_hash") assert.Contains(t, defaultMetadata, "publicKeyHash")
assert.Contains(t, defaultMetadata, "mnemonic_family_hash") assert.Contains(t, defaultMetadata, "mnemonicFamilyHash")
// Check work vault metadata // Check work vault metadata
workMetadataPath := filepath.Join(tempDir, "vaults.d", "work", "vault-metadata.json") workMetadataPath := filepath.Join(tempDir, "vaults.d", "work", "vault-metadata.json")
@ -1836,12 +1825,12 @@ func test28VaultMetadata(t *testing.T, tempDir string) {
require.NoError(t, err, "work vault metadata should be valid JSON") require.NoError(t, err, "work vault metadata should be valid JSON")
// Work vault should have different derivation index // Work vault should have different derivation index
workIndex := workMetadata["derivation_index"].(float64) workIndex := workMetadata["derivationIndex"].(float64)
assert.NotEqual(t, float64(0), workIndex, "work vault should have non-zero derivation index") assert.NotEqual(t, float64(0), workIndex, "work vault should have non-zero derivation index")
// Both vaults created with same mnemonic should have same mnemonic_family_hash // Both vaults created with same mnemonic should have same mnemonicFamilyHash
assert.Equal(t, defaultMetadata["mnemonic_family_hash"], workMetadata["mnemonic_family_hash"], assert.Equal(t, defaultMetadata["mnemonicFamilyHash"], workMetadata["mnemonicFamilyHash"],
"vaults from same mnemonic should have same mnemonic_family_hash") "vaults from same mnemonic should have same mnemonicFamilyHash")
} }
func test29SymlinkHandling(t *testing.T, tempDir, secretPath, testMnemonic string) { func test29SymlinkHandling(t *testing.T, tempDir, secretPath, testMnemonic string) {
@ -2000,14 +1989,14 @@ func test31EnvMnemonicUsesVaultDerivationIndex(t *testing.T, tempDir, secretPath
var defaultMetadata map[string]interface{} var defaultMetadata map[string]interface{}
err := json.Unmarshal(defaultMetadataBytes, &defaultMetadata) err := json.Unmarshal(defaultMetadataBytes, &defaultMetadata)
require.NoError(t, err, "default vault metadata should be valid JSON") require.NoError(t, err, "default vault metadata should be valid JSON")
assert.Equal(t, float64(0), defaultMetadata["derivation_index"], "default vault should have index 0") assert.Equal(t, float64(0), defaultMetadata["derivationIndex"], "default vault should have index 0")
workMetadataPath := filepath.Join(tempDir, "vaults.d", "work", "vault-metadata.json") workMetadataPath := filepath.Join(tempDir, "vaults.d", "work", "vault-metadata.json")
workMetadataBytes := readFile(t, workMetadataPath) workMetadataBytes := readFile(t, workMetadataPath)
var workMetadata map[string]interface{} var workMetadata map[string]interface{}
err = json.Unmarshal(workMetadataBytes, &workMetadata) err = json.Unmarshal(workMetadataBytes, &workMetadata)
require.NoError(t, err, "work vault metadata should be valid JSON") require.NoError(t, err, "work vault metadata should be valid JSON")
assert.Equal(t, float64(1), workMetadata["derivation_index"], "work vault should have index 1") assert.Equal(t, float64(1), workMetadata["derivationIndex"], "work vault should have index 1")
// Switch to work vault // Switch to work vault
_, err = runSecret("vault", "select", "work") _, err = runSecret("vault", "select", "work")

View File

@ -8,7 +8,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"github.com/spf13/afero" "github.com/awnumar/memguard"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -103,6 +103,20 @@ func newImportCmd() *cobra.Command {
return cmd return cmd
} }
// updateBufferSize updates the buffer size based on usage pattern
func updateBufferSize(currentSize int, sameSize *int) int {
*sameSize++
const doubleAfterBuffers = 2
const growthFactor = 2
if *sameSize >= doubleAfterBuffers {
*sameSize = 0
return currentSize * growthFactor
}
return currentSize
}
// AddSecret adds a secret to the current vault // AddSecret adds a secret to the current vault
func (cli *Instance) AddSecret(secretName string, force bool) error { func (cli *Instance) AddSecret(secretName string, force bool) error {
secret.Debug("CLI AddSecret starting", "secret_name", secretName, "force", force) secret.Debug("CLI AddSecret starting", "secret_name", secretName, "force", force)
@ -116,25 +130,84 @@ func (cli *Instance) AddSecret(secretName string, force bool) error {
secret.Debug("Got current vault", "vault_name", vlt.GetName()) secret.Debug("Got current vault", "vault_name", vlt.GetName())
// Read secret value from stdin // Read secret value directly into protected buffers
secret.Debug("Reading secret value from stdin") secret.Debug("Reading secret value from stdin into protected buffers")
value, err := io.ReadAll(cli.cmd.InOrStdin())
if err != nil { const initialSize = 4 * 1024 // 4KB initial buffer
return fmt.Errorf("failed to read secret value: %w", err) const maxSize = 100 * 1024 * 1024 // 100MB max
type bufferInfo struct {
buffer *memguard.LockedBuffer
used int
} }
secret.Debug("Read secret value from stdin", "value_length", len(value)) var buffers []bufferInfo
defer func() {
for _, b := range buffers {
b.buffer.Destroy()
}
}()
// Remove trailing newline if present reader := cli.cmd.InOrStdin()
if len(value) > 0 && value[len(value)-1] == '\n' { totalSize := 0
value = value[:len(value)-1] currentBufferSize := initialSize
secret.Debug("Removed trailing newline", "new_length", len(value)) sameSize := 0
for {
// Create a new buffer
buffer := memguard.NewBuffer(currentBufferSize)
n, err := io.ReadFull(reader, buffer.Bytes())
if n == 0 {
// No data read, destroy the unused buffer
buffer.Destroy()
} else {
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
totalSize += n
if totalSize > maxSize {
return fmt.Errorf("secret too large: exceeds 100MB limit")
}
// If we filled the buffer, consider growing for next iteration
if n == currentBufferSize {
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
}
}
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
} else if err != nil {
return fmt.Errorf("failed to read secret value: %w", err)
}
}
// Check for trailing newline in the last buffer
if len(buffers) > 0 && totalSize > 0 {
lastBuffer := &buffers[len(buffers)-1]
if lastBuffer.buffer.Bytes()[lastBuffer.used-1] == '\n' {
lastBuffer.used--
totalSize--
}
}
secret.Debug("Read secret value from stdin", "value_length", totalSize, "buffers", len(buffers))
// Combine all buffers into a single protected buffer
valueBuffer := memguard.NewBuffer(totalSize)
defer valueBuffer.Destroy()
offset := 0
for _, b := range buffers {
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
offset += b.used
} }
// Add the secret to the vault // Add the secret to the vault
secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", len(value), "force", force) secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", valueBuffer.Size(), "force", force)
if err := vlt.AddSecret(secretName, value, force); err != nil { if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
secret.Debug("vault.AddSecret failed", "error", err) secret.Debug("vault.AddSecret failed", "error", err)
return err return err
} }
@ -156,6 +229,7 @@ func (cli *Instance) GetSecretWithVersion(cmd *cobra.Command, secretName string,
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir) vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
if err != nil { if err != nil {
secret.Debug("Failed to get current vault", "error", err) secret.Debug("Failed to get current vault", "error", err)
return err return err
} }
@ -168,6 +242,7 @@ func (cli *Instance) GetSecretWithVersion(cmd *cobra.Command, secretName string,
} }
if err != nil { if err != nil {
secret.Debug("Failed to get secret", "error", err) secret.Debug("Failed to get secret", "error", err)
return err return err
} }
@ -295,14 +370,77 @@ func (cli *Instance) ImportSecret(cmd *cobra.Command, secretName, sourceFile str
return err return err
} }
// Read secret value from the source file // Read secret value from the source file into protected buffers
value, err := afero.ReadFile(cli.fs, sourceFile) file, err := cli.fs.Open(sourceFile)
if err != nil { if err != nil {
return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err) return fmt.Errorf("failed to open file %s: %w", sourceFile, err)
}
defer func() {
if err := file.Close(); err != nil {
secret.Debug("Failed to close file", "error", err)
}
}()
const initialSize = 4 * 1024 // 4KB initial buffer
const maxSize = 100 * 1024 * 1024 // 100MB max
type bufferInfo struct {
buffer *memguard.LockedBuffer
used int
}
var buffers []bufferInfo
defer func() {
for _, b := range buffers {
b.buffer.Destroy()
}
}()
totalSize := 0
currentBufferSize := initialSize
sameSize := 0
for {
// Create a new buffer
buffer := memguard.NewBuffer(currentBufferSize)
n, err := io.ReadFull(file, buffer.Bytes())
if n == 0 {
// No data read, destroy the unused buffer
buffer.Destroy()
} else {
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
totalSize += n
if totalSize > maxSize {
return fmt.Errorf("secret file too large: exceeds 100MB limit")
}
// If we filled the buffer, consider growing for next iteration
if n == currentBufferSize {
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
}
}
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
} else if err != nil {
return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err)
}
}
// Combine all buffers into a single protected buffer
valueBuffer := memguard.NewBuffer(totalSize)
defer valueBuffer.Destroy()
offset := 0
for _, b := range buffers {
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
offset += b.used
} }
// Store the secret in the vault // Store the secret in the vault
if err := vlt.AddSecret(secretName, value, force); err != nil { if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
return err return err
} }

View File

@ -0,0 +1,425 @@
package cli
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"path/filepath"
"strings"
"testing"
"git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestAddSecretVariousSizes tests adding secrets of various sizes through stdin
func TestAddSecretVariousSizes(t *testing.T) {
tests := []struct {
name string
size int
shouldError bool
errorMsg string
}{
{
name: "1KB secret",
size: 1024,
shouldError: false,
},
{
name: "10KB secret",
size: 10 * 1024,
shouldError: false,
},
{
name: "100KB secret",
size: 100 * 1024,
shouldError: false,
},
{
name: "1MB secret",
size: 1024 * 1024,
shouldError: false,
},
{
name: "10MB secret",
size: 10 * 1024 * 1024,
shouldError: false,
},
{
name: "99MB secret",
size: 99 * 1024 * 1024,
shouldError: false,
},
{
name: "100MB secret minus 1 byte",
size: 100*1024*1024 - 1,
shouldError: false,
},
{
name: "101MB secret - should fail",
size: 101 * 1024 * 1024,
shouldError: true,
errorMsg: "secret too large: exceeds 100MB limit",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Generate test data of specified size
testData := make([]byte, tt.size)
_, err = rand.Read(testData)
require.NoError(t, err)
// Add newline that will be stripped
testDataWithNewline := append(testData, '\n')
// Create fake stdin
stdin := bytes.NewReader(testDataWithNewline)
// Create command with fake stdin
cmd := &cobra.Command{}
cmd.SetIn(stdin)
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
cli.cmd = cmd
// Test adding the secret
secretName := fmt.Sprintf("test-secret-%d", tt.size)
err = cli.AddSecret(secretName, false)
if tt.shouldError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret(secretName)
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original (without newline)")
}
})
}
}
// TestImportSecretVariousSizes tests importing secrets of various sizes from files
func TestImportSecretVariousSizes(t *testing.T) {
tests := []struct {
name string
size int
shouldError bool
errorMsg string
}{
{
name: "1KB file",
size: 1024,
shouldError: false,
},
{
name: "10KB file",
size: 10 * 1024,
shouldError: false,
},
{
name: "100KB file",
size: 100 * 1024,
shouldError: false,
},
{
name: "1MB file",
size: 1024 * 1024,
shouldError: false,
},
{
name: "10MB file",
size: 10 * 1024 * 1024,
shouldError: false,
},
{
name: "99MB file",
size: 99 * 1024 * 1024,
shouldError: false,
},
{
name: "100MB file",
size: 100 * 1024 * 1024,
shouldError: false,
},
{
name: "101MB file - should fail",
size: 101 * 1024 * 1024,
shouldError: true,
errorMsg: "secret file too large: exceeds 100MB limit",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Generate test data of specified size
testData := make([]byte, tt.size)
_, err = rand.Read(testData)
require.NoError(t, err)
// Write test data to file
testFile := fmt.Sprintf("/test/secret-%d.bin", tt.size)
err = afero.WriteFile(fs, testFile, testData, 0o600)
require.NoError(t, err)
// Create command
cmd := &cobra.Command{}
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
// Test importing the secret
secretName := fmt.Sprintf("imported-secret-%d", tt.size)
err = cli.ImportSecret(cmd, secretName, testFile, false)
if tt.shouldError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret(secretName)
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
}
})
}
}
// TestAddSecretBufferGrowth tests that our buffer growth strategy works correctly
func TestAddSecretBufferGrowth(t *testing.T) {
// Test various sizes that should trigger buffer growth
sizes := []int{
1, // Single byte
100, // Small
4095, // Just under initial 4KB
4096, // Exactly 4KB
4097, // Just over 4KB
8191, // Just under 8KB (first double)
8192, // Exactly 8KB
8193, // Just over 8KB
12288, // 12KB (should trigger second double)
16384, // 16KB
32768, // 32KB (after more doublings)
65536, // 64KB
131072, // 128KB
524288, // 512KB
1048576, // 1MB
2097152, // 2MB
}
for _, size := range sizes {
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Create test data of exactly the specified size
// Use a pattern that's easy to verify
testData := make([]byte, size)
for i := range testData {
testData[i] = byte(i % 256)
}
// Create fake stdin without newline
stdin := bytes.NewReader(testData)
// Create command with fake stdin
cmd := &cobra.Command{}
cmd.SetIn(stdin)
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
cli.cmd = cmd
// Test adding the secret
secretName := fmt.Sprintf("buffer-test-%d", size)
err = cli.AddSecret(secretName, false)
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret(secretName)
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original exactly")
})
}
}
// TestAddSecretStreamingBehavior tests that we handle streaming input correctly
func TestAddSecretStreamingBehavior(t *testing.T) {
// Set up test environment
fs := afero.NewMemMapFs()
stateDir := "/test/state"
// Set test mnemonic
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
// Create vault
vaultName := "test-vault"
_, err := vault.CreateVault(fs, stateDir, vaultName)
require.NoError(t, err)
// Set current vault
currentVaultPath := filepath.Join(stateDir, "currentvault")
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
require.NoError(t, err)
// Get vault and set up long-term key
vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err)
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
require.NoError(t, err)
vlt.Unlock(ltIdentity)
// Create a custom reader that simulates slow streaming input
// This will help verify our buffer handling works correctly with partial reads
testData := []byte(strings.Repeat("Hello, World! ", 1000)) // ~14KB
slowReader := &slowReader{
data: testData,
chunkSize: 1000, // Read 1KB at a time
}
// Create command with slow reader as stdin
cmd := &cobra.Command{}
cmd.SetIn(slowReader)
// Create CLI instance
cli := NewCLIInstance()
cli.fs = fs
cli.stateDir = stateDir
cli.cmd = cmd
// Test adding the secret
err = cli.AddSecret("streaming-test", false)
require.NoError(t, err)
// Verify the secret was stored correctly
retrievedValue, err := vlt.GetSecret("streaming-test")
require.NoError(t, err)
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
}
// slowReader simulates a reader that returns data in small chunks
type slowReader struct {
data []byte
offset int
chunkSize int
}
func (r *slowReader) Read(p []byte) (n int, err error) {
if r.offset >= len(r.data) {
return 0, io.EOF
}
// Read at most chunkSize bytes
remaining := len(r.data) - r.offset
toRead := r.chunkSize
if toRead > remaining {
toRead = remaining
}
if toRead > len(p) {
toRead = len(p)
}
n = copy(p, r.data[r.offset:r.offset+toRead])
r.offset += n
if r.offset >= len(r.data) {
err = io.EOF
}
return n, err
}

View File

@ -10,6 +10,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -36,7 +37,7 @@ func newUnlockersListCmd() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "list", Use: "list",
Short: "List unlockers in the current vault", Short: "List unlockers in the current vault",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
jsonOutput, _ := cmd.Flags().GetBool("json") jsonOutput, _ := cmd.Flags().GetBool("json")
cli := NewCLIInstance() cli := NewCLIInstance()
@ -59,6 +60,7 @@ func newUnlockersAddCmd() *cobra.Command {
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
cli := NewCLIInstance() cli := NewCLIInstance()
return cli.UnlockersAdd(args[0], cmd) return cli.UnlockersAdd(args[0], cmd)
}, },
} }
@ -73,8 +75,9 @@ func newUnlockersRmCmd() *cobra.Command {
Use: "rm <unlocker-id>", Use: "rm <unlocker-id>",
Short: "Remove an unlocker", Short: "Remove an unlocker",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(_ *cobra.Command, args []string) error {
cli := NewCLIInstance() cli := NewCLIInstance()
return cli.UnlockersRemove(args[0]) return cli.UnlockersRemove(args[0])
}, },
} }
@ -97,8 +100,9 @@ func newUnlockerSelectSubCmd() *cobra.Command {
Use: "select <unlocker-id>", Use: "select <unlocker-id>",
Short: "Select an unlocker as current", Short: "Select an unlocker as current",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(_ *cobra.Command, args []string) error {
cli := NewCLIInstance() cli := NewCLIInstance()
return cli.UnlockerSelect(args[0]) return cli.UnlockerSelect(args[0])
}, },
} }
@ -122,7 +126,7 @@ func (cli *Instance) UnlockersList(jsonOutput bool) error {
type UnlockerInfo struct { type UnlockerInfo struct {
ID string `json:"id"` ID string `json:"id"`
Type string `json:"type"` Type string `json:"type"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"createdAt"`
Flags []string `json:"flags,omitempty"` Flags []string `json:"flags,omitempty"`
} }
@ -251,18 +255,19 @@ func (cli *Instance) UnlockersAdd(unlockerType string, cmd *cobra.Command) error
// The CreatePassphraseUnlocker method will handle getting the long-term key // The CreatePassphraseUnlocker method will handle getting the long-term key
// Check if passphrase is set in environment variable // Check if passphrase is set in environment variable
var passphraseStr string var passphraseBuffer *memguard.LockedBuffer
if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" { if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" {
passphraseStr = envPassphrase passphraseBuffer = memguard.NewBufferFromBytes([]byte(envPassphrase))
} else { } else {
// Use secure passphrase input with confirmation // Use secure passphrase input with confirmation
passphraseStr, err = readSecurePassphrase("Enter passphrase for unlocker: ") passphraseBuffer, err = readSecurePassphrase("Enter passphrase for unlocker: ")
if err != nil { if err != nil {
return fmt.Errorf("failed to read passphrase: %w", err) return fmt.Errorf("failed to read passphrase: %w", err)
} }
} }
defer passphraseBuffer.Destroy()
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseStr) passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
if err != nil { if err != nil {
return err return err
} }

View File

@ -10,6 +10,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/tyler-smith/go-bip39" "github.com/tyler-smith/go-bip39"
@ -34,7 +35,7 @@ func newVaultListCmd() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "list", Use: "list",
Short: "List available vaults", Short: "List available vaults",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, _ []string) error {
jsonOutput, _ := cmd.Flags().GetBool("json") jsonOutput, _ := cmd.Flags().GetBool("json")
cli := NewCLIInstance() cli := NewCLIInstance()
@ -55,6 +56,7 @@ func newVaultCreateCmd() *cobra.Command {
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
cli := NewCLIInstance() cli := NewCLIInstance()
return cli.CreateVault(cmd, args[0]) return cli.CreateVault(cmd, args[0])
}, },
} }
@ -67,6 +69,7 @@ func newVaultSelectCmd() *cobra.Command {
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
cli := NewCLIInstance() cli := NewCLIInstance()
return cli.SelectVault(cmd, args[0]) return cli.SelectVault(cmd, args[0])
}, },
} }
@ -106,8 +109,8 @@ func (cli *Instance) ListVaults(cmd *cobra.Command, jsonOutput bool) error {
} }
result := map[string]interface{}{ result := map[string]interface{}{
"vaults": vaults, "vaults": vaults,
"current_vault": currentVault, "currentVault": currentVault,
} }
jsonBytes, err := json.MarshalIndent(result, "", " ") jsonBytes, err := json.MarshalIndent(result, "", " ")
@ -209,6 +212,7 @@ func (cli *Instance) VaultImport(cmd *cobra.Command, vaultName string) error {
derivationIndex, err := vault.GetNextDerivationIndex(cli.fs, cli.stateDir, mnemonic) derivationIndex, err := vault.GetNextDerivationIndex(cli.fs, cli.stateDir, mnemonic)
if err != nil { if err != nil {
secret.Debug("Failed to get next derivation index", "error", err) secret.Debug("Failed to get next derivation index", "error", err)
return fmt.Errorf("failed to get next derivation index: %w", err) return fmt.Errorf("failed to get next derivation index: %w", err)
} }
secret.Debug("Using derivation index", "index", derivationIndex) secret.Debug("Using derivation index", "index", derivationIndex)
@ -256,6 +260,7 @@ func (cli *Instance) VaultImport(cmd *cobra.Command, vaultName string) error {
if err := vault.SaveVaultMetadata(cli.fs, vaultDir, existingMetadata); err != nil { if err := vault.SaveVaultMetadata(cli.fs, vaultDir, existingMetadata); err != nil {
secret.Debug("Failed to save vault metadata", "error", err) secret.Debug("Failed to save vault metadata", "error", err)
return fmt.Errorf("failed to save vault metadata: %w", err) return fmt.Errorf("failed to save vault metadata: %w", err)
} }
secret.Debug("Saved vault metadata with derivation index and public key hash") secret.Debug("Saved vault metadata with derivation index and public key hash")
@ -268,14 +273,19 @@ func (cli *Instance) VaultImport(cmd *cobra.Command, vaultName string) error {
secret.Debug("Using unlock passphrase from environment variable") secret.Debug("Using unlock passphrase from environment variable")
// Create secure buffer for passphrase
passphraseBuffer := memguard.NewBufferFromBytes([]byte(passphraseStr))
defer passphraseBuffer.Destroy()
// Unlock the vault with the derived long-term key // Unlock the vault with the derived long-term key
vlt.Unlock(ltIdentity) vlt.Unlock(ltIdentity)
// Create passphrase-protected unlocker // Create passphrase-protected unlocker
secret.Debug("Creating passphrase-protected unlocker") secret.Debug("Creating passphrase-protected unlocker")
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseStr) passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
if err != nil { if err != nil {
secret.Debug("Failed to create unlocker", "error", err) secret.Debug("Failed to create unlocker", "error", err)
return fmt.Errorf("failed to create unlocker: %w", err) return fmt.Errorf("failed to create unlocker: %w", err)
} }

View File

@ -19,6 +19,7 @@ const (
// newVersionCmd returns the version management command // newVersionCmd returns the version management command
func newVersionCmd() *cobra.Command { func newVersionCmd() *cobra.Command {
cli := NewCLIInstance() cli := NewCLIInstance()
return VersionCommands(cli) return VersionCommands(cli)
} }
@ -64,12 +65,14 @@ func (cli *Instance) ListVersions(cmd *cobra.Command, secretName string) error {
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir) vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
if err != nil { if err != nil {
secret.Debug("Failed to get current vault", "error", err) secret.Debug("Failed to get current vault", "error", err)
return err return err
} }
vaultDir, err := vlt.GetDirectory() vaultDir, err := vlt.GetDirectory()
if err != nil { if err != nil {
secret.Debug("Failed to get vault directory", "error", err) secret.Debug("Failed to get vault directory", "error", err)
return err return err
} }
@ -81,10 +84,12 @@ func (cli *Instance) ListVersions(cmd *cobra.Command, secretName string) error {
exists, err := afero.DirExists(cli.fs, secretDir) exists, err := afero.DirExists(cli.fs, secretDir)
if err != nil { if err != nil {
secret.Debug("Failed to check if secret exists", "error", err) secret.Debug("Failed to check if secret exists", "error", err)
return fmt.Errorf("failed to check if secret exists: %w", err) return fmt.Errorf("failed to check if secret exists: %w", err)
} }
if !exists { if !exists {
secret.Debug("Secret not found", "secret_name", secretName) secret.Debug("Secret not found", "secret_name", secretName)
return fmt.Errorf("secret '%s' not found", secretName) return fmt.Errorf("secret '%s' not found", secretName)
} }
@ -92,11 +97,13 @@ func (cli *Instance) ListVersions(cmd *cobra.Command, secretName string) error {
versions, err := secret.ListVersions(cli.fs, secretDir) versions, err := secret.ListVersions(cli.fs, secretDir)
if err != nil { if err != nil {
secret.Debug("Failed to list versions", "error", err) secret.Debug("Failed to list versions", "error", err)
return fmt.Errorf("failed to list versions: %w", err) return fmt.Errorf("failed to list versions: %w", err)
} }
if len(versions) == 0 { if len(versions) == 0 {
cmd.Println("No versions found") cmd.Println("No versions found")
return nil return nil
} }

View File

@ -26,11 +26,21 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Helper function to add a secret to vault with proper buffer protection
func addTestSecret(t *testing.T, vlt *vault.Vault, name string, value []byte, force bool) {
t.Helper()
buffer := memguard.NewBufferFromBytes(value)
defer buffer.Destroy()
err := vlt.AddSecret(name, buffer, force)
require.NoError(t, err)
}
// Helper function to set up a vault with long-term key // Helper function to set up a vault with long-term key
func setupTestVault(t *testing.T, fs afero.Fs, stateDir string) { func setupTestVault(t *testing.T, fs afero.Fs, stateDir string) {
// Set mnemonic for testing // Set mnemonic for testing
@ -70,13 +80,11 @@ func TestListVersionsCommand(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir) vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err) require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("version-1"), false) addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
err = vlt.AddSecret("test/secret", []byte("version-2"), true) addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
require.NoError(t, err)
// Create a command for output capture // Create a command for output capture
cmd := newRootCmd() cmd := newRootCmd()
@ -144,13 +152,11 @@ func TestPromoteVersionCommand(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir) vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err) require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("version-1"), false) addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
err = vlt.AddSecret("test/secret", []byte("version-2"), true) addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
require.NoError(t, err)
// Get versions // Get versions
vaultDir, _ := vlt.GetDirectory() vaultDir, _ := vlt.GetDirectory()
@ -201,8 +207,7 @@ func TestPromoteNonExistentVersion(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir) vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err) require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("value"), false) addTestSecret(t, vlt, "test/secret", []byte("value"), false)
require.NoError(t, err)
// Create a command for output capture // Create a command for output capture
cmd := newRootCmd() cmd := newRootCmd()
@ -228,13 +233,11 @@ func TestGetSecretWithVersion(t *testing.T) {
vlt, err := vault.GetCurrentVault(fs, stateDir) vlt, err := vault.GetCurrentVault(fs, stateDir)
require.NoError(t, err) require.NoError(t, err)
err = vlt.AddSecret("test/secret", []byte("version-1"), false) addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
err = vlt.AddSecret("test/secret", []byte("version-2"), true) addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
require.NoError(t, err)
// Get versions // Get versions
vaultDir, _ := vlt.GetDirectory() vaultDir, _ := vlt.GetDirectory()

View File

@ -1,3 +1,4 @@
// Package secret provides core types and constants for the secret application.
package secret package secret
import "os" import "os"
@ -6,11 +7,14 @@ const (
// AppID is the unique identifier for this application // AppID is the unique identifier for this application
AppID = "berlin.sneak.pkg.secret" AppID = "berlin.sneak.pkg.secret"
// Environment variable names // EnvStateDir is the environment variable for specifying the state directory
EnvStateDir = "SB_SECRET_STATE_DIR" EnvStateDir = "SB_SECRET_STATE_DIR"
EnvMnemonic = "SB_SECRET_MNEMONIC" // EnvMnemonic is the environment variable for providing the mnemonic phrase
EnvMnemonic = "SB_SECRET_MNEMONIC"
// EnvUnlockPassphrase is the environment variable for providing the unlock passphrase
EnvUnlockPassphrase = "SB_UNLOCK_PASSPHRASE" //nolint:gosec // G101: This is an env var name, not a credential EnvUnlockPassphrase = "SB_UNLOCK_PASSPHRASE" //nolint:gosec // G101: This is an env var name, not a credential
EnvGPGKeyID = "SB_GPG_KEY_ID" // EnvGPGKeyID is the environment variable for providing the GPG key ID
EnvGPGKeyID = "SB_GPG_KEY_ID"
) )
// File system permission constants // File system permission constants

View File

@ -8,25 +8,33 @@ import (
"syscall" "syscall"
"filippo.io/age" "filippo.io/age"
"github.com/awnumar/memguard"
"golang.org/x/term" "golang.org/x/term"
) )
// EncryptToRecipient encrypts data to a recipient using age // EncryptToRecipient encrypts data to a recipient using age
func EncryptToRecipient(data []byte, recipient age.Recipient) ([]byte, error) { // The data parameter should be a LockedBuffer for secure memory handling
Debug("EncryptToRecipient starting", "data_length", len(data)) func EncryptToRecipient(data *memguard.LockedBuffer, recipient age.Recipient) ([]byte, error) {
if data == nil {
return nil, fmt.Errorf("data buffer is nil")
}
Debug("EncryptToRecipient starting", "data_length", data.Size())
var buf bytes.Buffer var buf bytes.Buffer
Debug("Creating age encryptor") Debug("Creating age encryptor")
w, err := age.Encrypt(&buf, recipient) w, err := age.Encrypt(&buf, recipient)
if err != nil { if err != nil {
Debug("Failed to create encryptor", "error", err) Debug("Failed to create encryptor", "error", err)
return nil, fmt.Errorf("failed to create encryptor: %w", err) return nil, fmt.Errorf("failed to create encryptor: %w", err)
} }
Debug("Created age encryptor successfully") Debug("Created age encryptor successfully")
Debug("Writing data to encryptor") Debug("Writing data to encryptor")
if _, err := w.Write(data); err != nil { if _, err := w.Write(data.Bytes()); err != nil {
Debug("Failed to write data to encryptor", "error", err) Debug("Failed to write data to encryptor", "error", err)
return nil, fmt.Errorf("failed to write data: %w", err) return nil, fmt.Errorf("failed to write data: %w", err)
} }
Debug("Wrote data to encryptor successfully") Debug("Wrote data to encryptor successfully")
@ -34,6 +42,7 @@ func EncryptToRecipient(data []byte, recipient age.Recipient) ([]byte, error) {
Debug("Closing encryptor") Debug("Closing encryptor")
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
Debug("Failed to close encryptor", "error", err) Debug("Failed to close encryptor", "error", err)
return nil, fmt.Errorf("failed to close encryptor: %w", err) return nil, fmt.Errorf("failed to close encryptor: %w", err)
} }
Debug("Closed encryptor successfully") Debug("Closed encryptor successfully")
@ -60,18 +69,36 @@ func DecryptWithIdentity(data []byte, identity age.Identity) ([]byte, error) {
} }
// EncryptWithPassphrase encrypts data using a passphrase with age's scrypt-based encryption // EncryptWithPassphrase encrypts data using a passphrase with age's scrypt-based encryption
func EncryptWithPassphrase(data []byte, passphrase string) ([]byte, error) { // The passphrase parameter should be a LockedBuffer for secure memory handling
recipient, err := age.NewScryptRecipient(passphrase) func EncryptWithPassphrase(data []byte, passphrase *memguard.LockedBuffer) ([]byte, error) {
if passphrase == nil {
return nil, fmt.Errorf("passphrase buffer is nil")
}
// Get the passphrase string temporarily
passphraseStr := passphrase.String()
recipient, err := age.NewScryptRecipient(passphraseStr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create scrypt recipient: %w", err) return nil, fmt.Errorf("failed to create scrypt recipient: %w", err)
} }
return EncryptToRecipient(data, recipient) // Create a secure buffer for the data
dataBuffer := memguard.NewBufferFromBytes(data)
defer dataBuffer.Destroy()
return EncryptToRecipient(dataBuffer, recipient)
} }
// DecryptWithPassphrase decrypts data using a passphrase with age's scrypt-based decryption // DecryptWithPassphrase decrypts data using a passphrase with age's scrypt-based decryption
func DecryptWithPassphrase(encryptedData []byte, passphrase string) ([]byte, error) { // The passphrase parameter should be a LockedBuffer for secure memory handling
identity, err := age.NewScryptIdentity(passphrase) func DecryptWithPassphrase(encryptedData []byte, passphrase *memguard.LockedBuffer) ([]byte, error) {
if passphrase == nil {
return nil, fmt.Errorf("passphrase buffer is nil")
}
// Get the passphrase string temporarily
passphraseStr := passphrase.String()
identity, err := age.NewScryptIdentity(passphraseStr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create scrypt identity: %w", err) return nil, fmt.Errorf("failed to create scrypt identity: %w", err)
} }
@ -81,18 +108,19 @@ func DecryptWithPassphrase(encryptedData []byte, passphrase string) ([]byte, err
// ReadPassphrase reads a passphrase securely from the terminal without echoing // ReadPassphrase reads a passphrase securely from the terminal without echoing
// This version is for unlocking and doesn't require confirmation // This version is for unlocking and doesn't require confirmation
func ReadPassphrase(prompt string) (string, error) { // Returns a LockedBuffer containing the passphrase for secure memory handling
func ReadPassphrase(prompt string) (*memguard.LockedBuffer, error) {
// Check if stdin is a terminal // Check if stdin is a terminal
if !term.IsTerminal(syscall.Stdin) { if !term.IsTerminal(syscall.Stdin) {
// Not a terminal - never read passphrases from piped input for security reasons // Not a terminal - never read passphrases from piped input for security reasons
return "", fmt.Errorf("cannot read passphrase from non-terminal stdin " + return nil, fmt.Errorf("cannot read passphrase from non-terminal stdin " +
"(piped input or script). Please set the SB_UNLOCK_PASSPHRASE " + "(piped input or script). Please set the SB_UNLOCK_PASSPHRASE " +
"environment variable or run interactively") "environment variable or run interactively")
} }
// stdin is a terminal, check if stderr is also a terminal for interactive prompting // stdin is a terminal, check if stderr is also a terminal for interactive prompting
if !term.IsTerminal(syscall.Stderr) { if !term.IsTerminal(syscall.Stderr) {
return "", fmt.Errorf("cannot prompt for passphrase: stderr is not a terminal " + return nil, fmt.Errorf("cannot prompt for passphrase: stderr is not a terminal " +
"(running in non-interactive mode). Please set the SB_UNLOCK_PASSPHRASE " + "(running in non-interactive mode). Please set the SB_UNLOCK_PASSPHRASE " +
"environment variable") "environment variable")
} }
@ -101,13 +129,21 @@ func ReadPassphrase(prompt string) (string, error) {
fmt.Fprint(os.Stderr, prompt) // Write prompt to stderr, not stdout fmt.Fprint(os.Stderr, prompt) // Write prompt to stderr, not stdout
passphrase, err := term.ReadPassword(syscall.Stdin) passphrase, err := term.ReadPassword(syscall.Stdin)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read passphrase: %w", err) return nil, fmt.Errorf("failed to read passphrase: %w", err)
} }
fmt.Fprintln(os.Stderr) // Print newline to stderr since ReadPassword doesn't echo fmt.Fprintln(os.Stderr) // Print newline to stderr since ReadPassword doesn't echo
if len(passphrase) == 0 { if len(passphrase) == 0 {
return "", fmt.Errorf("passphrase cannot be empty") return nil, fmt.Errorf("passphrase cannot be empty")
} }
return string(passphrase), nil // Create a secure buffer and copy the passphrase
secureBuffer := memguard.NewBufferFromBytes(passphrase)
// Clear the original passphrase slice
for i := range passphrase {
passphrase[i] = 0
}
return secureBuffer, nil
} }

View File

@ -29,6 +29,7 @@ func InitDebugLogging() {
if !debugEnabled { if !debugEnabled {
// Create a no-op logger that discards all output // Create a no-op logger that discards all output
debugLogger = slog.New(slog.NewTextHandler(io.Discard, nil)) debugLogger = slog.New(slog.NewTextHandler(io.Discard, nil))
return return
} }

View File

@ -48,6 +48,7 @@ func DetermineStateDir(customConfigDir string) string {
if err != nil { if err != nil {
// Fallback to a reasonable default if we can't determine user config dir // Fallback to a reasonable default if we can't determine user config dir
homeDir, _ := os.UserHomeDir() homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, ".config", AppID) return filepath.Join(homeDir, ".config", AppID)
} }

View File

@ -13,6 +13,7 @@ import (
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -28,7 +29,7 @@ var keychainItemNameRegex = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
type KeychainUnlockerMetadata struct { type KeychainUnlockerMetadata struct {
UnlockerMetadata UnlockerMetadata
// Keychain item name // Keychain item name
KeychainItemName string `json:"keychain_item_name"` KeychainItemName string `json:"keychainItemName"`
} }
// KeychainUnlocker represents a macOS Keychain-protected unlocker // KeychainUnlocker represents a macOS Keychain-protected unlocker
@ -40,9 +41,9 @@ type KeychainUnlocker struct {
// KeychainData represents the data stored in the macOS keychain // KeychainData represents the data stored in the macOS keychain
type KeychainData struct { type KeychainData struct {
AgePublicKey string `json:"age_public_key"` AgePublicKey string `json:"agePublicKey"`
AgePrivKeyPassphrase string `json:"age_priv_key_passphrase"` AgePrivKeyPassphrase string `json:"agePrivKeyPassphrase"`
EncryptedLongtermKey string `json:"encrypted_longterm_key"` EncryptedLongtermKey string `json:"encryptedLongtermKey"`
} }
// GetIdentity implements Unlocker interface for Keychain-based unlockers // GetIdentity implements Unlocker interface for Keychain-based unlockers
@ -56,6 +57,7 @@ func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
keychainItemName, err := k.GetKeychainItemName() keychainItemName, err := k.GetKeychainItemName()
if err != nil { if err != nil {
Debug("Failed to get keychain item name", "error", err, "unlocker_id", k.GetID()) Debug("Failed to get keychain item name", "error", err, "unlocker_id", k.GetID())
return nil, fmt.Errorf("failed to get keychain item name: %w", err) return nil, fmt.Errorf("failed to get keychain item name: %w", err)
} }
@ -64,6 +66,7 @@ func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
keychainDataBytes, err := retrieveFromKeychain(keychainItemName) keychainDataBytes, err := retrieveFromKeychain(keychainItemName)
if err != nil { if err != nil {
Debug("Failed to retrieve data from keychain", "error", err, "keychain_item", keychainItemName) Debug("Failed to retrieve data from keychain", "error", err, "keychain_item", keychainItemName)
return nil, fmt.Errorf("failed to retrieve data from keychain: %w", err) return nil, fmt.Errorf("failed to retrieve data from keychain: %w", err)
} }
@ -76,6 +79,7 @@ func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
var keychainData KeychainData var keychainData KeychainData
if err := json.Unmarshal(keychainDataBytes, &keychainData); err != nil { if err := json.Unmarshal(keychainDataBytes, &keychainData); err != nil {
Debug("Failed to parse keychain data", "error", err, "unlocker_id", k.GetID()) Debug("Failed to parse keychain data", "error", err, "unlocker_id", k.GetID())
return nil, fmt.Errorf("failed to parse keychain data: %w", err) return nil, fmt.Errorf("failed to parse keychain data: %w", err)
} }
@ -88,6 +92,7 @@ func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
encryptedAgePrivKeyData, err := afero.ReadFile(k.fs, agePrivKeyPath) encryptedAgePrivKeyData, err := afero.ReadFile(k.fs, agePrivKeyPath)
if err != nil { if err != nil {
Debug("Failed to read encrypted age private key", "error", err, "path", agePrivKeyPath) Debug("Failed to read encrypted age private key", "error", err, "path", agePrivKeyPath)
return nil, fmt.Errorf("failed to read encrypted age private key: %w", err) return nil, fmt.Errorf("failed to read encrypted age private key: %w", err)
} }
@ -98,9 +103,14 @@ func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
// Step 5: Decrypt the age private key using the passphrase from keychain // Step 5: Decrypt the age private key using the passphrase from keychain
Debug("Decrypting age private key with keychain passphrase", "unlocker_id", k.GetID()) Debug("Decrypting age private key with keychain passphrase", "unlocker_id", k.GetID())
agePrivKeyData, err := DecryptWithPassphrase(encryptedAgePrivKeyData, keychainData.AgePrivKeyPassphrase) // Create secure buffer for the keychain passphrase
passphraseBuffer := memguard.NewBufferFromBytes([]byte(keychainData.AgePrivKeyPassphrase))
defer passphraseBuffer.Destroy()
agePrivKeyData, err := DecryptWithPassphrase(encryptedAgePrivKeyData, passphraseBuffer)
if err != nil { if err != nil {
Debug("Failed to decrypt age private key with keychain passphrase", "error", err, "unlocker_id", k.GetID()) Debug("Failed to decrypt age private key with keychain passphrase", "error", err, "unlocker_id", k.GetID())
return nil, fmt.Errorf("failed to decrypt age private key with keychain passphrase: %w", err) return nil, fmt.Errorf("failed to decrypt age private key with keychain passphrase: %w", err)
} }
@ -111,9 +121,20 @@ func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
// Step 6: Parse the decrypted age private key // Step 6: Parse the decrypted age private key
Debug("Parsing decrypted age private key", "unlocker_id", k.GetID()) Debug("Parsing decrypted age private key", "unlocker_id", k.GetID())
ageIdentity, err := age.ParseX25519Identity(string(agePrivKeyData))
// Create a secure buffer for the private key data
agePrivKeyBuffer := memguard.NewBufferFromBytes(agePrivKeyData)
defer agePrivKeyBuffer.Destroy()
// Clear the original private key data
for i := range agePrivKeyData {
agePrivKeyData[i] = 0
}
ageIdentity, err := age.ParseX25519Identity(agePrivKeyBuffer.String())
if err != nil { if err != nil {
Debug("Failed to parse age private key", "error", err, "unlocker_id", k.GetID()) Debug("Failed to parse age private key", "error", err, "unlocker_id", k.GetID())
return nil, fmt.Errorf("failed to parse age private key: %w", err) return nil, fmt.Errorf("failed to parse age private key: %w", err)
} }
@ -159,6 +180,7 @@ func (k *KeychainUnlocker) Remove() error {
keychainItemName, err := k.GetKeychainItemName() keychainItemName, err := k.GetKeychainItemName()
if err != nil { if err != nil {
Debug("Failed to get keychain item name during removal", "error", err, "unlocker_id", k.GetID()) Debug("Failed to get keychain item name during removal", "error", err, "unlocker_id", k.GetID())
return fmt.Errorf("failed to get keychain item name: %w", err) return fmt.Errorf("failed to get keychain item name: %w", err)
} }
@ -166,6 +188,7 @@ func (k *KeychainUnlocker) Remove() error {
Debug("Removing keychain item", "keychain_item", keychainItemName) Debug("Removing keychain item", "keychain_item", keychainItemName)
if err := deleteFromKeychain(keychainItemName); err != nil { if err := deleteFromKeychain(keychainItemName); err != nil {
Debug("Failed to remove keychain item", "error", err, "keychain_item", keychainItemName) Debug("Failed to remove keychain item", "error", err, "keychain_item", keychainItemName)
return fmt.Errorf("failed to remove keychain item: %w", err) return fmt.Errorf("failed to remove keychain item: %w", err)
} }
@ -173,6 +196,7 @@ func (k *KeychainUnlocker) Remove() error {
Debug("Removing keychain unlocker directory", "directory", k.Directory) Debug("Removing keychain unlocker directory", "directory", k.Directory)
if err := k.fs.RemoveAll(k.Directory); err != nil { if err := k.fs.RemoveAll(k.Directory); err != nil {
Debug("Failed to remove keychain unlocker directory", "error", err, "directory", k.Directory) Debug("Failed to remove keychain unlocker directory", "error", err, "directory", k.Directory)
return fmt.Errorf("failed to remove keychain unlocker directory: %w", err) return fmt.Errorf("failed to remove keychain unlocker directory: %w", err)
} }
@ -220,6 +244,72 @@ func generateKeychainUnlockerName(vaultName string) (string, error) {
return fmt.Sprintf("secret-%s-%s-%s", vaultName, hostname, enrollmentDate), nil return fmt.Sprintf("secret-%s-%s-%s", vaultName, hostname, enrollmentDate), nil
} }
// getLongTermPrivateKey retrieves the long-term private key either from environment or current unlocker
// Returns a LockedBuffer to ensure the private key is protected in memory
func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) (*memguard.LockedBuffer, error) {
// Check if mnemonic is available in environment variable
envMnemonic := os.Getenv(EnvMnemonic)
if envMnemonic != "" {
// Use mnemonic directly to derive long-term key
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0)
if err != nil {
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
}
// Return the private key in a secure buffer
return memguard.NewBufferFromBytes([]byte(ltIdentity.String())), nil
}
// Get the vault to access current unlocker
currentUnlocker, err := vault.GetCurrentUnlocker()
if err != nil {
return nil, fmt.Errorf("failed to get current unlocker: %w", err)
}
// Get the current unlocker identity
currentUnlockerIdentity, err := currentUnlocker.GetIdentity()
if err != nil {
return nil, fmt.Errorf("failed to get current unlocker identity: %w", err)
}
// Get encrypted long-term key from current unlocker, handling different types
var encryptedLtPrivKey []byte
switch currentUnlocker := currentUnlocker.(type) {
case *PassphraseUnlocker:
// Read the encrypted long-term private key from passphrase unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current passphrase unlocker: %w", err)
}
case *PGPUnlocker:
// Read the encrypted long-term private key from PGP unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current PGP unlocker: %w", err)
}
case *KeychainUnlocker:
// Read the encrypted long-term private key from another keychain unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current keychain unlocker: %w", err)
}
default:
return nil, fmt.Errorf("unsupported current unlocker type for keychain unlocker creation")
}
// Decrypt long-term private key using current unlocker
ltPrivKeyData, err := DecryptWithIdentity(encryptedLtPrivKey, currentUnlockerIdentity)
if err != nil {
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
}
// Return the decrypted key in a secure buffer
return memguard.NewBufferFromBytes(ltPrivKeyData), nil
}
// CreateKeychainUnlocker creates a new keychain unlocker and stores it in the vault // CreateKeychainUnlocker creates a new keychain unlocker and stores it in the vault
func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, error) { func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, error) {
// Check if we're on macOS // Check if we're on macOS
@ -270,8 +360,15 @@ func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, er
} }
// Step 4: Encrypt age private key with the generated passphrase and store on disk // Step 4: Encrypt age private key with the generated passphrase and store on disk
agePrivateKeyBytes := []byte(ageIdentity.String()) // Create secure buffers for both the private key and passphrase
encryptedAgePrivKey, err := EncryptWithPassphrase(agePrivateKeyBytes, agePrivKeyPassphrase) agePrivKeyStr := ageIdentity.String()
agePrivKeyBuffer := memguard.NewBufferFromBytes([]byte(agePrivKeyStr))
defer agePrivKeyBuffer.Destroy()
passphraseBuffer := memguard.NewBufferFromBytes([]byte(agePrivKeyPassphrase))
defer passphraseBuffer.Destroy()
encryptedAgePrivKey, err := EncryptWithPassphrase(agePrivKeyBuffer.Bytes(), passphraseBuffer)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encrypt age private key with passphrase: %w", err) return nil, fmt.Errorf("failed to encrypt age private key with passphrase: %w", err)
} }
@ -282,63 +379,11 @@ func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, er
} }
// Step 5: Get or derive the long-term private key // Step 5: Get or derive the long-term private key
var ltPrivKeyData []byte ltPrivKeyData, err := getLongTermPrivateKey(fs, vault)
if err != nil {
// Check if mnemonic is available in environment variable return nil, err
if envMnemonic := os.Getenv(EnvMnemonic); envMnemonic != "" {
// Use mnemonic directly to derive long-term key
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0)
if err != nil {
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
}
ltPrivKeyData = []byte(ltIdentity.String())
} else {
// Get the vault to access current unlocker
currentUnlocker, err := vault.GetCurrentUnlocker()
if err != nil {
return nil, fmt.Errorf("failed to get current unlocker: %w", err)
}
// Get the current unlocker identity
currentUnlockerIdentity, err := currentUnlocker.GetIdentity()
if err != nil {
return nil, fmt.Errorf("failed to get current unlocker identity: %w", err)
}
// Get encrypted long-term key from current unlocker, handling different types
var encryptedLtPrivKey []byte
switch currentUnlocker := currentUnlocker.(type) {
case *PassphraseUnlocker:
// Read the encrypted long-term private key from passphrase unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current passphrase unlocker: %w", err)
}
case *PGPUnlocker:
// Read the encrypted long-term private key from PGP unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current PGP unlocker: %w", err)
}
case *KeychainUnlocker:
// Read the encrypted long-term private key from another keychain unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current keychain unlocker: %w", err)
}
default:
return nil, fmt.Errorf("unsupported current unlocker type for keychain unlocker creation")
}
// Decrypt long-term private key using current unlocker
ltPrivKeyData, err = DecryptWithIdentity(encryptedLtPrivKey, currentUnlockerIdentity)
if err != nil {
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
}
} }
defer ltPrivKeyData.Destroy()
// Step 6: Encrypt long-term private key to the new age unlocker // Step 6: Encrypt long-term private key to the new age unlocker
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient()) encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())

View File

@ -8,11 +8,11 @@ import (
type VaultMetadata struct { type VaultMetadata struct {
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
DerivationIndex uint32 `json:"derivation_index"` DerivationIndex uint32 `json:"derivationIndex"`
// Double SHA256 hash of the actual long-term public key // Double SHA256 hash of the actual long-term public key
PublicKeyHash string `json:"public_key_hash,omitempty"` PublicKeyHash string `json:"publicKeyHash,omitempty"`
// Double SHA256 hash of index-0 key (for grouping vaults from same mnemonic) // Double SHA256 hash of index-0 key (for grouping vaults from same mnemonic)
MnemonicFamilyHash string `json:"mnemonic_family_hash,omitempty"` MnemonicFamilyHash string `json:"mnemonicFamilyHash,omitempty"`
} }
// UnlockerMetadata contains information about an unlocker // UnlockerMetadata contains information about an unlocker

View File

@ -9,6 +9,7 @@ import (
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -76,7 +77,9 @@ func TestPassphraseUnlockerWithRealFS(t *testing.T) {
// Test encrypting private key with passphrase // Test encrypting private key with passphrase
t.Run("EncryptPrivateKey", func(t *testing.T) { t.Run("EncryptPrivateKey", func(t *testing.T) {
privKeyData := []byte(agePrivateKey) privKeyData := []byte(agePrivateKey)
encryptedPrivKey, err := secret.EncryptWithPassphrase(privKeyData, testPassphrase) passphraseBuffer := memguard.NewBufferFromBytes([]byte(testPassphrase))
defer passphraseBuffer.Destroy()
encryptedPrivKey, err := secret.EncryptWithPassphrase(privKeyData, passphraseBuffer)
if err != nil { if err != nil {
t.Fatalf("Failed to encrypt private key: %v", err) t.Fatalf("Failed to encrypt private key: %v", err)
} }
@ -110,8 +113,9 @@ func TestPassphraseUnlockerWithRealFS(t *testing.T) {
t.Fatalf("Failed to parse recipient: %v", err) t.Fatalf("Failed to parse recipient: %v", err)
} }
ltPrivKeyData := []byte(ltIdentity.String()) ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyData, recipient) defer ltPrivKeyBuffer.Destroy()
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, recipient)
if err != nil { if err != nil {
t.Fatalf("Failed to encrypt long-term private key: %v", err) t.Fatalf("Failed to encrypt long-term private key: %v", err)
} }

View File

@ -7,6 +7,7 @@ import (
"path/filepath" "path/filepath"
"filippo.io/age" "filippo.io/age"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -15,7 +16,40 @@ type PassphraseUnlocker struct {
Directory string Directory string
Metadata UnlockerMetadata Metadata UnlockerMetadata
fs afero.Fs fs afero.Fs
Passphrase string Passphrase *memguard.LockedBuffer // Secure buffer for passphrase
}
// getPassphrase retrieves the passphrase from memory, environment, or user input
// Returns a LockedBuffer for secure memory handling
func (p *PassphraseUnlocker) getPassphrase() (*memguard.LockedBuffer, error) {
// First check if we already have the passphrase
if p.Passphrase != nil && p.Passphrase.IsAlive() {
Debug("Using in-memory passphrase", "unlocker_id", p.GetID())
// Return a copy of the passphrase buffer
return memguard.NewBufferFromBytes(p.Passphrase.Bytes()), nil
}
Debug("No passphrase in memory, checking environment")
// Check environment variable for passphrase
passphraseStr := os.Getenv(EnvUnlockPassphrase)
if passphraseStr != "" {
Debug("Using passphrase from environment", "unlocker_id", p.GetID())
// Convert to secure buffer
secureBuffer := memguard.NewBufferFromBytes([]byte(passphraseStr))
return secureBuffer, nil
}
Debug("No passphrase in environment, prompting user")
// Prompt for passphrase
secureBuffer, err := ReadPassphrase("Enter unlock passphrase: ")
if err != nil {
Debug("Failed to read passphrase", "error", err, "unlocker_id", p.GetID())
return nil, fmt.Errorf("failed to read passphrase: %w", err)
}
return secureBuffer, nil
} }
// GetIdentity implements Unlocker interface for passphrase-based unlockers // GetIdentity implements Unlocker interface for passphrase-based unlockers
@ -25,27 +59,11 @@ func (p *PassphraseUnlocker) GetIdentity() (*age.X25519Identity, error) {
slog.String("unlocker_type", p.GetType()), slog.String("unlocker_type", p.GetType()),
) )
// First check if we already have the passphrase passphraseBuffer, err := p.getPassphrase()
passphraseStr := p.Passphrase if err != nil {
if passphraseStr == "" { return nil, err
Debug("No passphrase in memory, checking environment")
// Check environment variable for passphrase
passphraseStr = os.Getenv(EnvUnlockPassphrase)
if passphraseStr == "" {
Debug("No passphrase in environment, prompting user")
// Prompt for passphrase
var err error
passphraseStr, err = ReadPassphrase("Enter unlock passphrase: ")
if err != nil {
Debug("Failed to read passphrase", "error", err, "unlocker_id", p.GetID())
return nil, fmt.Errorf("failed to read passphrase: %w", err)
}
} else {
Debug("Using passphrase from environment", "unlocker_id", p.GetID())
}
} else {
Debug("Using in-memory passphrase", "unlocker_id", p.GetID())
} }
defer passphraseBuffer.Destroy()
// Read encrypted private key of unlocker // Read encrypted private key of unlocker
unlockerPrivPath := filepath.Join(p.Directory, "priv.age") unlockerPrivPath := filepath.Join(p.Directory, "priv.age")
@ -54,6 +72,7 @@ func (p *PassphraseUnlocker) GetIdentity() (*age.X25519Identity, error) {
encryptedPrivKeyData, err := afero.ReadFile(p.fs, unlockerPrivPath) encryptedPrivKeyData, err := afero.ReadFile(p.fs, unlockerPrivPath)
if err != nil { if err != nil {
Debug("Failed to read passphrase unlocker private key", "error", err, "path", unlockerPrivPath) Debug("Failed to read passphrase unlocker private key", "error", err, "path", unlockerPrivPath)
return nil, fmt.Errorf("failed to read unlocker private key: %w", err) return nil, fmt.Errorf("failed to read unlocker private key: %w", err)
} }
@ -65,9 +84,10 @@ func (p *PassphraseUnlocker) GetIdentity() (*age.X25519Identity, error) {
Debug("Decrypting unlocker private key with passphrase", "unlocker_id", p.GetID()) Debug("Decrypting unlocker private key with passphrase", "unlocker_id", p.GetID())
// Decrypt the unlocker private key with passphrase // Decrypt the unlocker private key with passphrase
privKeyData, err := DecryptWithPassphrase(encryptedPrivKeyData, passphraseStr) privKeyData, err := DecryptWithPassphrase(encryptedPrivKeyData, passphraseBuffer)
if err != nil { if err != nil {
Debug("Failed to decrypt unlocker private key", "error", err, "unlocker_id", p.GetID()) Debug("Failed to decrypt unlocker private key", "error", err, "unlocker_id", p.GetID())
return nil, fmt.Errorf("failed to decrypt unlocker private key: %w", err) return nil, fmt.Errorf("failed to decrypt unlocker private key: %w", err)
} }
@ -78,9 +98,20 @@ func (p *PassphraseUnlocker) GetIdentity() (*age.X25519Identity, error) {
// Parse the decrypted private key // Parse the decrypted private key
Debug("Parsing decrypted unlocker identity", "unlocker_id", p.GetID()) Debug("Parsing decrypted unlocker identity", "unlocker_id", p.GetID())
identity, err := age.ParseX25519Identity(string(privKeyData))
// Create a secure buffer for the private key data
privKeyBuffer := memguard.NewBufferFromBytes(privKeyData)
defer privKeyBuffer.Destroy()
// Clear the original private key data
for i := range privKeyData {
privKeyData[i] = 0
}
identity, err := age.ParseX25519Identity(privKeyBuffer.String())
if err != nil { if err != nil {
Debug("Failed to parse unlocker private key", "error", err, "unlocker_id", p.GetID()) Debug("Failed to parse unlocker private key", "error", err, "unlocker_id", p.GetID())
return nil, fmt.Errorf("failed to parse unlocker private key: %w", err) return nil, fmt.Errorf("failed to parse unlocker private key: %w", err)
} }
@ -111,11 +142,17 @@ func (p *PassphraseUnlocker) GetDirectory() string {
func (p *PassphraseUnlocker) GetID() string { func (p *PassphraseUnlocker) GetID() string {
// Generate ID using creation timestamp: YYYY-MM-DD.HH.mm-passphrase // Generate ID using creation timestamp: YYYY-MM-DD.HH.mm-passphrase
createdAt := p.Metadata.CreatedAt createdAt := p.Metadata.CreatedAt
return fmt.Sprintf("%s-passphrase", createdAt.Format("2006-01-02.15.04")) return fmt.Sprintf("%s-passphrase", createdAt.Format("2006-01-02.15.04"))
} }
// Remove implements Unlocker interface - removes the passphrase unlocker // Remove implements Unlocker interface - removes the passphrase unlocker
func (p *PassphraseUnlocker) Remove() error { func (p *PassphraseUnlocker) Remove() error {
// Clean up the passphrase from memory if it exists
if p.Passphrase != nil && p.Passphrase.IsAlive() {
p.Passphrase.Destroy()
}
// For passphrase unlockers, we just need to remove the directory // For passphrase unlockers, we just need to remove the directory
// No external resources (like keychain items) to clean up // No external resources (like keychain items) to clean up
if err := p.fs.RemoveAll(p.Directory); err != nil { if err := p.fs.RemoveAll(p.Directory); err != nil {
@ -135,7 +172,12 @@ func NewPassphraseUnlocker(fs afero.Fs, directory string, metadata UnlockerMetad
} }
// CreatePassphraseUnlocker creates a new passphrase-protected unlocker // CreatePassphraseUnlocker creates a new passphrase-protected unlocker
func CreatePassphraseUnlocker(fs afero.Fs, stateDir string, passphrase string) (*PassphraseUnlocker, error) { // The passphrase must be provided as a LockedBuffer for security
func CreatePassphraseUnlocker(
fs afero.Fs,
stateDir string,
passphrase *memguard.LockedBuffer,
) (*PassphraseUnlocker, error) {
// Get current vault // Get current vault
currentVault, err := GetCurrentVault(fs, stateDir) currentVault, err := GetCurrentVault(fs, stateDir)
if err != nil { if err != nil {

View File

@ -16,6 +16,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -270,7 +271,9 @@ Passphrase: ` + testPassphrase + `
vlt.Unlock(ltIdentity) vlt.Unlock(ltIdentity)
// Create a passphrase unlocker first (to have current unlocker) // Create a passphrase unlocker first (to have current unlocker)
passUnlocker, err := vlt.CreatePassphraseUnlocker("test-passphrase") passphraseBuffer := memguard.NewBufferFromBytes([]byte("test-passphrase"))
defer passphraseBuffer.Destroy()
passUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
if err != nil { if err != nil {
t.Fatalf("Failed to create passphrase unlocker: %v", err) t.Fatalf("Failed to create passphrase unlocker: %v", err)
} }
@ -357,9 +360,9 @@ Passphrase: ` + testPassphrase + `
var metadata struct { var metadata struct {
ID string `json:"id"` ID string `json:"id"`
Type string `json:"type"` Type string `json:"type"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"createdAt"`
Flags []string `json:"flags"` Flags []string `json:"flags"`
GPGKeyID string `json:"gpg_key_id"` GPGKeyID string `json:"gpgKeyId"`
} }
if err := json.Unmarshal(metadataBytes, &metadata); err != nil { if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
@ -396,7 +399,7 @@ Passphrase: ` + testPassphrase + `
// Create PGP metadata with GPG key ID // Create PGP metadata with GPG key ID
type PGPUnlockerMetadata struct { type PGPUnlockerMetadata struct {
secret.UnlockerMetadata secret.UnlockerMetadata
GPGKeyID string `json:"gpg_key_id"` GPGKeyID string `json:"gpgKeyId"`
} }
pgpMetadata := PGPUnlockerMetadata{ pgpMetadata := PGPUnlockerMetadata{

View File

@ -12,7 +12,7 @@ import (
"time" "time"
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/pkg/agehd" "github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -44,7 +44,7 @@ var (
type PGPUnlockerMetadata struct { type PGPUnlockerMetadata struct {
UnlockerMetadata UnlockerMetadata
// GPG key ID used for encryption // GPG key ID used for encryption
GPGKeyID string `json:"gpg_key_id"` GPGKeyID string `json:"gpgKeyId"`
} }
// PGPUnlocker represents a PGP-protected unlocker // PGPUnlocker represents a PGP-protected unlocker
@ -68,6 +68,7 @@ func (p *PGPUnlocker) GetIdentity() (*age.X25519Identity, error) {
encryptedAgePrivKeyData, err := afero.ReadFile(p.fs, agePrivKeyPath) encryptedAgePrivKeyData, err := afero.ReadFile(p.fs, agePrivKeyPath)
if err != nil { if err != nil {
Debug("Failed to read PGP-encrypted age private key", "error", err, "path", agePrivKeyPath) Debug("Failed to read PGP-encrypted age private key", "error", err, "path", agePrivKeyPath)
return nil, fmt.Errorf("failed to read encrypted age private key: %w", err) return nil, fmt.Errorf("failed to read encrypted age private key: %w", err)
} }
@ -81,6 +82,7 @@ func (p *PGPUnlocker) GetIdentity() (*age.X25519Identity, error) {
agePrivKeyData, err := GPGDecryptFunc(encryptedAgePrivKeyData) agePrivKeyData, err := GPGDecryptFunc(encryptedAgePrivKeyData)
if err != nil { if err != nil {
Debug("Failed to decrypt age private key with GPG", "error", err, "unlocker_id", p.GetID()) Debug("Failed to decrypt age private key with GPG", "error", err, "unlocker_id", p.GetID())
return nil, fmt.Errorf("failed to decrypt age private key with GPG: %w", err) return nil, fmt.Errorf("failed to decrypt age private key with GPG: %w", err)
} }
@ -94,6 +96,7 @@ func (p *PGPUnlocker) GetIdentity() (*age.X25519Identity, error) {
ageIdentity, err := age.ParseX25519Identity(string(agePrivKeyData)) ageIdentity, err := age.ParseX25519Identity(string(agePrivKeyData))
if err != nil { if err != nil {
Debug("Failed to parse age private key", "error", err, "unlocker_id", p.GetID()) Debug("Failed to parse age private key", "error", err, "unlocker_id", p.GetID())
return nil, fmt.Errorf("failed to parse age private key: %w", err) return nil, fmt.Errorf("failed to parse age private key: %w", err)
} }
@ -227,56 +230,11 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc
} }
// Step 3: Get or derive the long-term private key // Step 3: Get or derive the long-term private key
var ltPrivKeyData []byte ltPrivKeyData, err := getLongTermPrivateKey(fs, vault)
if err != nil {
// Check if mnemonic is available in environment variable return nil, err
if envMnemonic := os.Getenv(EnvMnemonic); envMnemonic != "" {
// Use mnemonic directly to derive long-term key
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0)
if err != nil {
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
}
ltPrivKeyData = []byte(ltIdentity.String())
} else {
// Get the vault to access current unlocker
currentUnlocker, err := vault.GetCurrentUnlocker()
if err != nil {
return nil, fmt.Errorf("failed to get current unlocker: %w", err)
}
// Get the current unlocker identity
currentUnlockerIdentity, err := currentUnlocker.GetIdentity()
if err != nil {
return nil, fmt.Errorf("failed to get current unlocker identity: %w", err)
}
// Get encrypted long-term key from current unlocker, handling different types
var encryptedLtPrivKey []byte
switch currentUnlocker := currentUnlocker.(type) {
case *PassphraseUnlocker:
// Read the encrypted long-term private key from passphrase unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current passphrase unlocker: %w", err)
}
case *PGPUnlocker:
// Read the encrypted long-term private key from PGP unlocker
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
if err != nil {
return nil, fmt.Errorf("failed to read encrypted long-term key from current PGP unlocker: %w", err)
}
default:
return nil, fmt.Errorf("unsupported current unlocker type for PGP unlocker creation")
}
// Step 6: Decrypt long-term private key using current unlocker
ltPrivKeyData, err = DecryptWithIdentity(encryptedLtPrivKey, currentUnlockerIdentity)
if err != nil {
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
}
} }
defer ltPrivKeyData.Destroy()
// Step 7: Encrypt long-term private key to the new age unlocker // Step 7: Encrypt long-term private key to the new age unlocker
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient()) encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())
@ -291,8 +249,11 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc
} }
// Step 8: Encrypt age private key to the GPG key ID // Step 8: Encrypt age private key to the GPG key ID
agePrivateKeyBytes := []byte(ageIdentity.String()) // Use memguard to protect the private key in memory
encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBytes, gpgKeyID) agePrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(ageIdentity.String()))
defer agePrivateKeyBuffer.Destroy()
encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBuffer.Bytes(), gpgKeyID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encrypt age private key with GPG: %w", err) return nil, fmt.Errorf("failed to encrypt age private key with GPG: %w", err)
} }

View File

@ -11,17 +11,18 @@ import (
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
// VaultInterface defines the interface that vault implementations must satisfy // VaultInterface defines the interface that vault implementations must satisfy
type VaultInterface interface { type VaultInterface interface {
GetDirectory() (string, error) GetDirectory() (string, error)
AddSecret(name string, value []byte, force bool) error AddSecret(name string, value *memguard.LockedBuffer, force bool) error
GetName() string GetName() string
GetFilesystem() afero.Fs GetFilesystem() afero.Fs
GetCurrentUnlocker() (Unlocker, error) GetCurrentUnlocker() (Unlocker, error)
CreatePassphraseUnlocker(passphrase string) (*PassphraseUnlocker, error) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*PassphraseUnlocker, error)
} }
// Secret represents a secret in a vault // Secret represents a secret in a vault
@ -71,9 +72,15 @@ func (s *Secret) Save(value []byte, force bool) error {
slog.Bool("force", force), slog.Bool("force", force),
) )
err := s.vault.AddSecret(s.Name, value, force) // Create a secure buffer for the value - note that the caller
// should ideally pass a LockedBuffer directly to vault.AddSecret
valueBuffer := memguard.NewBufferFromBytes(value)
defer valueBuffer.Destroy()
err := s.vault.AddSecret(s.Name, valueBuffer, force)
if err != nil { if err != nil {
Debug("Failed to save secret", "error", err, "secret_name", s.Name) Debug("Failed to save secret", "error", err, "secret_name", s.Name)
return err return err
} }
@ -93,10 +100,12 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
exists, err := s.Exists() exists, err := s.Exists()
if err != nil { if err != nil {
Debug("Failed to check if secret exists during GetValue", "error", err, "secret_name", s.Name) Debug("Failed to check if secret exists during GetValue", "error", err, "secret_name", s.Name)
return nil, fmt.Errorf("failed to check if secret exists: %w", err) return nil, fmt.Errorf("failed to check if secret exists: %w", err)
} }
if !exists { if !exists {
Debug("Secret not found during GetValue", "secret_name", s.Name, "vault_name", s.vault.GetName()) Debug("Secret not found during GetValue", "secret_name", s.Name, "vault_name", s.vault.GetName())
return nil, fmt.Errorf("secret %s not found", s.Name) return nil, fmt.Errorf("secret %s not found", s.Name)
} }
@ -106,6 +115,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
currentVersion, err := GetCurrentVersion(s.vault.GetFilesystem(), s.Directory) currentVersion, err := GetCurrentVersion(s.vault.GetFilesystem(), s.Directory)
if err != nil { if err != nil {
Debug("Failed to get current version", "error", err, "secret_name", s.Name) Debug("Failed to get current version", "error", err, "secret_name", s.Name)
return nil, fmt.Errorf("failed to get current version: %w", err) return nil, fmt.Errorf("failed to get current version: %w", err)
} }
@ -120,6 +130,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
vaultDir, err := s.vault.GetDirectory() vaultDir, err := s.vault.GetDirectory()
if err != nil { if err != nil {
Debug("Failed to get vault directory", "error", err, "secret_name", s.Name) Debug("Failed to get vault directory", "error", err, "secret_name", s.Name)
return nil, fmt.Errorf("failed to get vault directory: %w", err) return nil, fmt.Errorf("failed to get vault directory: %w", err)
} }
@ -128,12 +139,14 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
metadataBytes, err := afero.ReadFile(s.vault.GetFilesystem(), metadataPath) metadataBytes, err := afero.ReadFile(s.vault.GetFilesystem(), metadataPath)
if err != nil { if err != nil {
Debug("Failed to read vault metadata", "error", err, "path", metadataPath) Debug("Failed to read vault metadata", "error", err, "path", metadataPath)
return nil, fmt.Errorf("failed to read vault metadata: %w", err) return nil, fmt.Errorf("failed to read vault metadata: %w", err)
} }
var metadata VaultMetadata var metadata VaultMetadata
if err := json.Unmarshal(metadataBytes, &metadata); err != nil { if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
Debug("Failed to parse vault metadata", "error", err, "secret_name", s.Name) Debug("Failed to parse vault metadata", "error", err, "secret_name", s.Name)
return nil, fmt.Errorf("failed to parse vault metadata: %w", err) return nil, fmt.Errorf("failed to parse vault metadata: %w", err)
} }
@ -147,6 +160,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, metadata.DerivationIndex) ltIdentity, err := agehd.DeriveIdentity(envMnemonic, metadata.DerivationIndex)
if err != nil { if err != nil {
Debug("Failed to derive long-term key from mnemonic for secret", "error", err, "secret_name", s.Name) Debug("Failed to derive long-term key from mnemonic for secret", "error", err, "secret_name", s.Name)
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err) return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
} }
@ -161,6 +175,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
// Use the provided unlocker to get the vault's long-term private key // Use the provided unlocker to get the vault's long-term private key
if unlocker == nil { if unlocker == nil {
Debug("No unlocker provided for secret decryption", "secret_name", s.Name) Debug("No unlocker provided for secret decryption", "secret_name", s.Name)
return nil, fmt.Errorf("unlocker required to decrypt secret") return nil, fmt.Errorf("unlocker required to decrypt secret")
} }
@ -174,6 +189,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
unlockIdentity, err := unlocker.GetIdentity() unlockIdentity, err := unlocker.GetIdentity()
if err != nil { if err != nil {
Debug("Failed to get unlocker identity", "error", err, "secret_name", s.Name, "unlocker_type", unlocker.GetType()) Debug("Failed to get unlocker identity", "error", err, "secret_name", s.Name, "unlocker_type", unlocker.GetType())
return nil, fmt.Errorf("failed to get unlocker identity: %w", err) return nil, fmt.Errorf("failed to get unlocker identity: %w", err)
} }
@ -184,6 +200,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
encryptedLtPrivKey, err := afero.ReadFile(s.vault.GetFilesystem(), encryptedLtPrivKeyPath) encryptedLtPrivKey, err := afero.ReadFile(s.vault.GetFilesystem(), encryptedLtPrivKeyPath)
if err != nil { if err != nil {
Debug("Failed to read encrypted long-term private key", "error", err, "path", encryptedLtPrivKeyPath) Debug("Failed to read encrypted long-term private key", "error", err, "path", encryptedLtPrivKeyPath)
return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err) return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err)
} }
@ -192,6 +209,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
ltPrivKeyData, err := DecryptWithIdentity(encryptedLtPrivKey, unlockIdentity) ltPrivKeyData, err := DecryptWithIdentity(encryptedLtPrivKey, unlockIdentity)
if err != nil { if err != nil {
Debug("Failed to decrypt long-term private key", "error", err, "secret_name", s.Name) Debug("Failed to decrypt long-term private key", "error", err, "secret_name", s.Name)
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err) return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
} }
@ -200,6 +218,7 @@ func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
ltIdentity, err := age.ParseX25519Identity(string(ltPrivKeyData)) ltIdentity, err := age.ParseX25519Identity(string(ltPrivKeyData))
if err != nil { if err != nil {
Debug("Failed to parse long-term private key", "error", err, "secret_name", s.Name) Debug("Failed to parse long-term private key", "error", err, "secret_name", s.Name)
return nil, fmt.Errorf("failed to parse long-term private key: %w", err) return nil, fmt.Errorf("failed to parse long-term private key: %w", err)
} }
@ -228,12 +247,14 @@ func (s *Secret) LoadMetadata() error {
// GetMetadata returns the secret metadata (deprecated) // GetMetadata returns the secret metadata (deprecated)
func (s *Secret) GetMetadata() Metadata { func (s *Secret) GetMetadata() Metadata {
Debug("GetMetadata called but is deprecated in versioned model", "secret_name", s.Name) Debug("GetMetadata called but is deprecated in versioned model", "secret_name", s.Name)
return s.Metadata return s.Metadata
} }
// GetEncryptedData is deprecated - data is now stored in versions // GetEncryptedData is deprecated - data is now stored in versions
func (s *Secret) GetEncryptedData() ([]byte, error) { func (s *Secret) GetEncryptedData() ([]byte, error) {
Debug("GetEncryptedData called but is deprecated in versioned model", "secret_name", s.Name) Debug("GetEncryptedData called but is deprecated in versioned model", "secret_name", s.Name)
return nil, fmt.Errorf("GetEncryptedData is deprecated - use version-specific methods") return nil, fmt.Errorf("GetEncryptedData is deprecated - use version-specific methods")
} }
@ -248,11 +269,13 @@ func (s *Secret) Exists() (bool, error) {
exists, err := afero.DirExists(s.vault.GetFilesystem(), s.Directory) exists, err := afero.DirExists(s.vault.GetFilesystem(), s.Directory)
if err != nil { if err != nil {
Debug("Failed to check secret directory existence", "error", err, "secret_dir", s.Directory) Debug("Failed to check secret directory existence", "error", err, "secret_dir", s.Directory)
return false, err return false, err
} }
if !exists { if !exists {
Debug("Secret directory does not exist", "secret_dir", s.Directory) Debug("Secret directory does not exist", "secret_dir", s.Directory)
return false, nil return false, nil
} }
@ -260,6 +283,7 @@ func (s *Secret) Exists() (bool, error) {
_, err = GetCurrentVersion(s.vault.GetFilesystem(), s.Directory) _, err = GetCurrentVersion(s.vault.GetFilesystem(), s.Directory)
if err != nil { if err != nil {
Debug("No current version found", "error", err, "secret_name", s.Name) Debug("No current version found", "error", err, "secret_name", s.Name)
return false, nil return false, nil
} }

View File

@ -9,6 +9,7 @@ import (
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -25,7 +26,7 @@ func (m *MockVault) GetDirectory() (string, error) {
return m.directory, nil return m.directory, nil
} }
func (m *MockVault) AddSecret(name string, value []byte, _ bool) error { func (m *MockVault) AddSecret(name string, value *memguard.LockedBuffer, _ bool) error {
// Create secret directory with proper storage name conversion // Create secret directory with proper storage name conversion
storageName := strings.ReplaceAll(name, "/", "%") storageName := strings.ReplaceAll(name, "/", "%")
secretDir := filepath.Join(m.directory, "secrets.d", storageName) secretDir := filepath.Join(m.directory, "secrets.d", storageName)
@ -74,7 +75,7 @@ func (m *MockVault) AddSecret(name string, value []byte, _ bool) error {
return err return err
} }
// Encrypt value to version's public key // Encrypt value to version's public key (value is already a LockedBuffer)
encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient()) encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient())
if err != nil { if err != nil {
return err return err
@ -87,7 +88,9 @@ func (m *MockVault) AddSecret(name string, value []byte, _ bool) error {
} }
// Encrypt version private key to long-term public key // Encrypt version private key to long-term public key
encryptedPrivKey, err := EncryptToRecipient([]byte(versionIdentity.String()), ltIdentity.Recipient()) versionPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String()))
defer versionPrivKeyBuffer.Destroy()
encryptedPrivKey, err := EncryptToRecipient(versionPrivKeyBuffer, ltIdentity.Recipient())
if err != nil { if err != nil {
return err return err
} }
@ -120,7 +123,7 @@ func (m *MockVault) GetCurrentUnlocker() (Unlocker, error) {
return nil, nil return nil, nil
} }
func (m *MockVault) CreatePassphraseUnlocker(_ string) (*PassphraseUnlocker, error) { func (m *MockVault) CreatePassphraseUnlocker(_ *memguard.LockedBuffer) (*PassphraseUnlocker, error) {
return nil, nil return nil, nil
} }
@ -179,9 +182,13 @@ func TestPerSecretKeyFunctionality(t *testing.T) {
secretName := "test-secret" secretName := "test-secret"
secretValue := []byte("this is a test secret value") secretValue := []byte("this is a test secret value")
// Create a secure buffer for the test value
valueBuffer := memguard.NewBufferFromBytes(secretValue)
defer valueBuffer.Destroy()
// Test AddSecret // Test AddSecret
t.Run("AddSecret", func(t *testing.T) { t.Run("AddSecret", func(t *testing.T) {
err := vault.AddSecret(secretName, secretValue, false) err := vault.AddSecret(secretName, valueBuffer, false)
if err != nil { if err != nil {
t.Fatalf("AddSecret failed: %v", err) t.Fatalf("AddSecret failed: %v", err)
} }

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"filippo.io/age" "filippo.io/age"
"github.com/awnumar/memguard"
"github.com/oklog/ulid/v2" "github.com/oklog/ulid/v2"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -89,17 +90,24 @@ func GenerateVersionName(fs afero.Fs, secretDir string) (string, error) {
prefix := today + "." prefix := today + "."
for _, entry := range entries { for _, entry := range entries {
if entry.IsDir() && strings.HasPrefix(entry.Name(), prefix) { // Skip non-directories and those without correct prefix
// Extract serial number if !entry.IsDir() || !strings.HasPrefix(entry.Name(), prefix) {
parts := strings.Split(entry.Name(), ".") continue
if len(parts) == versionNameParts { }
var serial int
if _, err := fmt.Sscanf(parts[1], "%03d", &serial); err == nil { // Extract serial number
if serial > maxSerial { parts := strings.Split(entry.Name(), ".")
maxSerial = serial if len(parts) != versionNameParts {
} continue
} }
}
var serial int
if _, err := fmt.Sscanf(parts[1], "%03d", &serial); err != nil {
continue
}
if serial > maxSerial {
maxSerial = serial
} }
} }
@ -113,11 +121,15 @@ func GenerateVersionName(fs afero.Fs, secretDir string) (string, error) {
} }
// Save saves the version metadata and value // Save saves the version metadata and value
func (sv *Version) Save(value []byte) error { func (sv *Version) Save(value *memguard.LockedBuffer) error {
if value == nil {
return fmt.Errorf("value buffer is nil")
}
DebugWith("Saving secret version", DebugWith("Saving secret version",
slog.String("secret_name", sv.SecretName), slog.String("secret_name", sv.SecretName),
slog.String("version", sv.Version), slog.String("version", sv.Version),
slog.Int("value_length", len(value)), slog.Int("value_length", value.Size()),
) )
fs := sv.vault.GetFilesystem() fs := sv.vault.GetFilesystem()
@ -125,6 +137,7 @@ func (sv *Version) Save(value []byte) error {
// Create version directory // Create version directory
if err := fs.MkdirAll(sv.Directory, DirPerms); err != nil { if err := fs.MkdirAll(sv.Directory, DirPerms); err != nil {
Debug("Failed to create version directory", "error", err, "dir", sv.Directory) Debug("Failed to create version directory", "error", err, "dir", sv.Directory)
return fmt.Errorf("failed to create version directory: %w", err) return fmt.Errorf("failed to create version directory: %w", err)
} }
@ -133,11 +146,14 @@ func (sv *Version) Save(value []byte) error {
versionIdentity, err := age.GenerateX25519Identity() versionIdentity, err := age.GenerateX25519Identity()
if err != nil { if err != nil {
Debug("Failed to generate version keypair", "error", err, "version", sv.Version) Debug("Failed to generate version keypair", "error", err, "version", sv.Version)
return fmt.Errorf("failed to generate version keypair: %w", err) return fmt.Errorf("failed to generate version keypair: %w", err)
} }
versionPublicKey := versionIdentity.Recipient().String() versionPublicKey := versionIdentity.Recipient().String()
versionPrivateKey := versionIdentity.String() // Store private key in memguard buffer immediately
versionPrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String()))
defer versionPrivateKeyBuffer.Destroy()
DebugWith("Generated version keypair", DebugWith("Generated version keypair",
slog.String("version", sv.Version), slog.String("version", sv.Version),
@ -149,6 +165,7 @@ func (sv *Version) Save(value []byte) error {
Debug("Writing version public key", "path", pubKeyPath) Debug("Writing version public key", "path", pubKeyPath)
if err := afero.WriteFile(fs, pubKeyPath, []byte(versionPublicKey), FilePerms); err != nil { if err := afero.WriteFile(fs, pubKeyPath, []byte(versionPublicKey), FilePerms); err != nil {
Debug("Failed to write version public key", "error", err, "path", pubKeyPath) Debug("Failed to write version public key", "error", err, "path", pubKeyPath)
return fmt.Errorf("failed to write version public key: %w", err) return fmt.Errorf("failed to write version public key: %w", err)
} }
@ -157,6 +174,7 @@ func (sv *Version) Save(value []byte) error {
encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient()) encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient())
if err != nil { if err != nil {
Debug("Failed to encrypt version value", "error", err, "version", sv.Version) Debug("Failed to encrypt version value", "error", err, "version", sv.Version)
return fmt.Errorf("failed to encrypt version value: %w", err) return fmt.Errorf("failed to encrypt version value: %w", err)
} }
@ -165,6 +183,7 @@ func (sv *Version) Save(value []byte) error {
Debug("Writing encrypted version value", "path", valuePath) Debug("Writing encrypted version value", "path", valuePath)
if err := afero.WriteFile(fs, valuePath, encryptedValue, FilePerms); err != nil { if err := afero.WriteFile(fs, valuePath, encryptedValue, FilePerms); err != nil {
Debug("Failed to write encrypted version value", "error", err, "path", valuePath) Debug("Failed to write encrypted version value", "error", err, "path", valuePath)
return fmt.Errorf("failed to write encrypted version value: %w", err) return fmt.Errorf("failed to write encrypted version value: %w", err)
} }
@ -176,6 +195,7 @@ func (sv *Version) Save(value []byte) error {
ltPubKeyData, err := afero.ReadFile(fs, ltPubKeyPath) ltPubKeyData, err := afero.ReadFile(fs, ltPubKeyPath)
if err != nil { if err != nil {
Debug("Failed to read long-term public key", "error", err, "path", ltPubKeyPath) Debug("Failed to read long-term public key", "error", err, "path", ltPubKeyPath)
return fmt.Errorf("failed to read long-term public key: %w", err) return fmt.Errorf("failed to read long-term public key: %w", err)
} }
@ -183,14 +203,16 @@ func (sv *Version) Save(value []byte) error {
ltRecipient, err := age.ParseX25519Recipient(string(ltPubKeyData)) ltRecipient, err := age.ParseX25519Recipient(string(ltPubKeyData))
if err != nil { if err != nil {
Debug("Failed to parse long-term public key", "error", err) Debug("Failed to parse long-term public key", "error", err)
return fmt.Errorf("failed to parse long-term public key: %w", err) return fmt.Errorf("failed to parse long-term public key: %w", err)
} }
// Step 6: Encrypt the version's private key to the long-term public key // Step 6: Encrypt the version's private key to the long-term public key
Debug("Encrypting version private key to long-term public key", "version", sv.Version) Debug("Encrypting version private key to long-term public key", "version", sv.Version)
encryptedPrivKey, err := EncryptToRecipient([]byte(versionPrivateKey), ltRecipient) encryptedPrivKey, err := EncryptToRecipient(versionPrivateKeyBuffer, ltRecipient)
if err != nil { if err != nil {
Debug("Failed to encrypt version private key", "error", err, "version", sv.Version) Debug("Failed to encrypt version private key", "error", err, "version", sv.Version)
return fmt.Errorf("failed to encrypt version private key: %w", err) return fmt.Errorf("failed to encrypt version private key: %w", err)
} }
@ -199,6 +221,7 @@ func (sv *Version) Save(value []byte) error {
Debug("Writing encrypted version private key", "path", privKeyPath) Debug("Writing encrypted version private key", "path", privKeyPath)
if err := afero.WriteFile(fs, privKeyPath, encryptedPrivKey, FilePerms); err != nil { if err := afero.WriteFile(fs, privKeyPath, encryptedPrivKey, FilePerms); err != nil {
Debug("Failed to write encrypted version private key", "error", err, "path", privKeyPath) Debug("Failed to write encrypted version private key", "error", err, "path", privKeyPath)
return fmt.Errorf("failed to write encrypted version private key: %w", err) return fmt.Errorf("failed to write encrypted version private key: %w", err)
} }
@ -207,13 +230,18 @@ func (sv *Version) Save(value []byte) error {
metadataBytes, err := json.MarshalIndent(sv.Metadata, "", " ") metadataBytes, err := json.MarshalIndent(sv.Metadata, "", " ")
if err != nil { if err != nil {
Debug("Failed to marshal version metadata", "error", err) Debug("Failed to marshal version metadata", "error", err)
return fmt.Errorf("failed to marshal version metadata: %w", err) return fmt.Errorf("failed to marshal version metadata: %w", err)
} }
// Encrypt metadata to the version's public key // Encrypt metadata to the version's public key
encryptedMetadata, err := EncryptToRecipient(metadataBytes, versionIdentity.Recipient()) metadataBuffer := memguard.NewBufferFromBytes(metadataBytes)
defer metadataBuffer.Destroy()
encryptedMetadata, err := EncryptToRecipient(metadataBuffer, versionIdentity.Recipient())
if err != nil { if err != nil {
Debug("Failed to encrypt version metadata", "error", err, "version", sv.Version) Debug("Failed to encrypt version metadata", "error", err, "version", sv.Version)
return fmt.Errorf("failed to encrypt version metadata: %w", err) return fmt.Errorf("failed to encrypt version metadata: %w", err)
} }
@ -221,6 +249,7 @@ func (sv *Version) Save(value []byte) error {
Debug("Writing encrypted version metadata", "path", metadataPath) Debug("Writing encrypted version metadata", "path", metadataPath)
if err := afero.WriteFile(fs, metadataPath, encryptedMetadata, FilePerms); err != nil { if err := afero.WriteFile(fs, metadataPath, encryptedMetadata, FilePerms); err != nil {
Debug("Failed to write encrypted version metadata", "error", err, "path", metadataPath) Debug("Failed to write encrypted version metadata", "error", err, "path", metadataPath)
return fmt.Errorf("failed to write encrypted version metadata: %w", err) return fmt.Errorf("failed to write encrypted version metadata: %w", err)
} }
@ -243,6 +272,7 @@ func (sv *Version) LoadMetadata(ltIdentity *age.X25519Identity) error {
encryptedPrivKey, err := afero.ReadFile(fs, encryptedPrivKeyPath) encryptedPrivKey, err := afero.ReadFile(fs, encryptedPrivKeyPath)
if err != nil { if err != nil {
Debug("Failed to read encrypted version private key", "error", err, "path", encryptedPrivKeyPath) Debug("Failed to read encrypted version private key", "error", err, "path", encryptedPrivKeyPath)
return fmt.Errorf("failed to read encrypted version private key: %w", err) return fmt.Errorf("failed to read encrypted version private key: %w", err)
} }
@ -250,6 +280,7 @@ func (sv *Version) LoadMetadata(ltIdentity *age.X25519Identity) error {
versionPrivKeyData, err := DecryptWithIdentity(encryptedPrivKey, ltIdentity) versionPrivKeyData, err := DecryptWithIdentity(encryptedPrivKey, ltIdentity)
if err != nil { if err != nil {
Debug("Failed to decrypt version private key", "error", err, "version", sv.Version) Debug("Failed to decrypt version private key", "error", err, "version", sv.Version)
return fmt.Errorf("failed to decrypt version private key: %w", err) return fmt.Errorf("failed to decrypt version private key: %w", err)
} }
@ -257,6 +288,7 @@ func (sv *Version) LoadMetadata(ltIdentity *age.X25519Identity) error {
versionIdentity, err := age.ParseX25519Identity(string(versionPrivKeyData)) versionIdentity, err := age.ParseX25519Identity(string(versionPrivKeyData))
if err != nil { if err != nil {
Debug("Failed to parse version private key", "error", err, "version", sv.Version) Debug("Failed to parse version private key", "error", err, "version", sv.Version)
return fmt.Errorf("failed to parse version private key: %w", err) return fmt.Errorf("failed to parse version private key: %w", err)
} }
@ -265,6 +297,7 @@ func (sv *Version) LoadMetadata(ltIdentity *age.X25519Identity) error {
encryptedMetadata, err := afero.ReadFile(fs, encryptedMetadataPath) encryptedMetadata, err := afero.ReadFile(fs, encryptedMetadataPath)
if err != nil { if err != nil {
Debug("Failed to read encrypted version metadata", "error", err, "path", encryptedMetadataPath) Debug("Failed to read encrypted version metadata", "error", err, "path", encryptedMetadataPath)
return fmt.Errorf("failed to read encrypted version metadata: %w", err) return fmt.Errorf("failed to read encrypted version metadata: %w", err)
} }
@ -272,6 +305,7 @@ func (sv *Version) LoadMetadata(ltIdentity *age.X25519Identity) error {
metadataBytes, err := DecryptWithIdentity(encryptedMetadata, versionIdentity) metadataBytes, err := DecryptWithIdentity(encryptedMetadata, versionIdentity)
if err != nil { if err != nil {
Debug("Failed to decrypt version metadata", "error", err, "version", sv.Version) Debug("Failed to decrypt version metadata", "error", err, "version", sv.Version)
return fmt.Errorf("failed to decrypt version metadata: %w", err) return fmt.Errorf("failed to decrypt version metadata: %w", err)
} }
@ -279,6 +313,7 @@ func (sv *Version) LoadMetadata(ltIdentity *age.X25519Identity) error {
var metadata VersionMetadata var metadata VersionMetadata
if err := json.Unmarshal(metadataBytes, &metadata); err != nil { if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
Debug("Failed to unmarshal version metadata", "error", err, "version", sv.Version) Debug("Failed to unmarshal version metadata", "error", err, "version", sv.Version)
return fmt.Errorf("failed to unmarshal version metadata: %w", err) return fmt.Errorf("failed to unmarshal version metadata: %w", err)
} }
@ -310,6 +345,7 @@ func (sv *Version) GetValue(ltIdentity *age.X25519Identity) ([]byte, error) {
encryptedPrivKey, err := afero.ReadFile(fs, encryptedPrivKeyPath) encryptedPrivKey, err := afero.ReadFile(fs, encryptedPrivKeyPath)
if err != nil { if err != nil {
Debug("Failed to read encrypted version private key", "error", err, "path", encryptedPrivKeyPath) Debug("Failed to read encrypted version private key", "error", err, "path", encryptedPrivKeyPath)
return nil, fmt.Errorf("failed to read encrypted version private key: %w", err) return nil, fmt.Errorf("failed to read encrypted version private key: %w", err)
} }
Debug("Successfully read encrypted version private key", "path", encryptedPrivKeyPath, "size", len(encryptedPrivKey)) Debug("Successfully read encrypted version private key", "path", encryptedPrivKeyPath, "size", len(encryptedPrivKey))
@ -319,6 +355,7 @@ func (sv *Version) GetValue(ltIdentity *age.X25519Identity) ([]byte, error) {
versionPrivKeyData, err := DecryptWithIdentity(encryptedPrivKey, ltIdentity) versionPrivKeyData, err := DecryptWithIdentity(encryptedPrivKey, ltIdentity)
if err != nil { if err != nil {
Debug("Failed to decrypt version private key", "error", err, "version", sv.Version) Debug("Failed to decrypt version private key", "error", err, "version", sv.Version)
return nil, fmt.Errorf("failed to decrypt version private key: %w", err) return nil, fmt.Errorf("failed to decrypt version private key: %w", err)
} }
Debug("Successfully decrypted version private key", "version", sv.Version, "size", len(versionPrivKeyData)) Debug("Successfully decrypted version private key", "version", sv.Version, "size", len(versionPrivKeyData))
@ -327,6 +364,7 @@ func (sv *Version) GetValue(ltIdentity *age.X25519Identity) ([]byte, error) {
versionIdentity, err := age.ParseX25519Identity(string(versionPrivKeyData)) versionIdentity, err := age.ParseX25519Identity(string(versionPrivKeyData))
if err != nil { if err != nil {
Debug("Failed to parse version private key", "error", err, "version", sv.Version) Debug("Failed to parse version private key", "error", err, "version", sv.Version)
return nil, fmt.Errorf("failed to parse version private key: %w", err) return nil, fmt.Errorf("failed to parse version private key: %w", err)
} }
@ -336,6 +374,7 @@ func (sv *Version) GetValue(ltIdentity *age.X25519Identity) ([]byte, error) {
encryptedValue, err := afero.ReadFile(fs, encryptedValuePath) encryptedValue, err := afero.ReadFile(fs, encryptedValuePath)
if err != nil { if err != nil {
Debug("Failed to read encrypted version value", "error", err, "path", encryptedValuePath) Debug("Failed to read encrypted version value", "error", err, "path", encryptedValuePath)
return nil, fmt.Errorf("failed to read encrypted version value: %w", err) return nil, fmt.Errorf("failed to read encrypted version value: %w", err)
} }
Debug("Successfully read encrypted value", "path", encryptedValuePath, "size", len(encryptedValue)) Debug("Successfully read encrypted value", "path", encryptedValuePath, "size", len(encryptedValue))
@ -345,6 +384,7 @@ func (sv *Version) GetValue(ltIdentity *age.X25519Identity) ([]byte, error) {
value, err := DecryptWithIdentity(encryptedValue, versionIdentity) value, err := DecryptWithIdentity(encryptedValue, versionIdentity)
if err != nil { if err != nil {
Debug("Failed to decrypt version value", "error", err, "version", sv.Version) Debug("Failed to decrypt version value", "error", err, "version", sv.Version)
return nil, fmt.Errorf("failed to decrypt version value: %w", err) return nil, fmt.Errorf("failed to decrypt version value: %w", err)
} }

View File

@ -41,6 +41,7 @@ import (
"time" "time"
"filippo.io/age" "filippo.io/age"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -58,7 +59,7 @@ func (m *MockVersionVault) GetDirectory() (string, error) {
return filepath.Join(m.stateDir, "vaults.d", m.Name), nil return filepath.Join(m.stateDir, "vaults.d", m.Name), nil
} }
func (m *MockVersionVault) AddSecret(_ string, _ []byte, _ bool) error { func (m *MockVersionVault) AddSecret(_ string, _ *memguard.LockedBuffer, _ bool) error {
return fmt.Errorf("not implemented in mock") return fmt.Errorf("not implemented in mock")
} }
@ -74,7 +75,7 @@ func (m *MockVersionVault) GetCurrentUnlocker() (Unlocker, error) {
return nil, fmt.Errorf("not implemented in mock") return nil, fmt.Errorf("not implemented in mock")
} }
func (m *MockVersionVault) CreatePassphraseUnlocker(_ string) (*PassphraseUnlocker, error) { func (m *MockVersionVault) CreatePassphraseUnlocker(_ *memguard.LockedBuffer) (*PassphraseUnlocker, error) {
return nil, fmt.Errorf("not implemented in mock") return nil, fmt.Errorf("not implemented in mock")
} }
@ -164,7 +165,9 @@ func TestSecretVersionSave(t *testing.T) {
sv := NewVersion(vault, "test/secret", "20231215.001") sv := NewVersion(vault, "test/secret", "20231215.001")
testValue := []byte("test-secret-value") testValue := []byte("test-secret-value")
err = sv.Save(testValue) testBuffer := memguard.NewBufferFromBytes(testValue)
defer testBuffer.Destroy()
err = sv.Save(testBuffer)
require.NoError(t, err) require.NoError(t, err)
// Verify files were created // Verify files were created
@ -202,7 +205,9 @@ func TestSecretVersionLoadMetadata(t *testing.T) {
sv.Metadata.NotBefore = &epochPlusOne sv.Metadata.NotBefore = &epochPlusOne
sv.Metadata.NotAfter = &now sv.Metadata.NotAfter = &now
err = sv.Save([]byte("test-value")) testBuffer := memguard.NewBufferFromBytes([]byte("test-value"))
defer testBuffer.Destroy()
err = sv.Save(testBuffer)
require.NoError(t, err) require.NoError(t, err)
// Create new version object and load metadata // Create new version object and load metadata
@ -241,15 +246,19 @@ func TestSecretVersionGetValue(t *testing.T) {
// Create and save a version // Create and save a version
sv := NewVersion(vault, "test/secret", "20231215.001") sv := NewVersion(vault, "test/secret", "20231215.001")
originalValue := []byte("test-secret-value-12345") originalValue := []byte("test-secret-value-12345")
expectedValue := make([]byte, len(originalValue))
copy(expectedValue, originalValue)
err = sv.Save(originalValue) originalBuffer := memguard.NewBufferFromBytes(originalValue)
defer originalBuffer.Destroy()
err = sv.Save(originalBuffer)
require.NoError(t, err) require.NoError(t, err)
// Retrieve the value // Retrieve the value
retrievedValue, err := sv.GetValue(ltIdentity) retrievedValue, err := sv.GetValue(ltIdentity)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, originalValue, retrievedValue) assert.Equal(t, expectedValue, retrievedValue)
} }
func TestListVersions(t *testing.T) { func TestListVersions(t *testing.T) {

View File

@ -8,6 +8,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/internal/vault" "git.eeqj.de/sneak/secret/internal/vault"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -107,8 +108,13 @@ func TestVaultWithRealFilesystem(t *testing.T) {
// Create a secret with a deeply nested path // Create a secret with a deeply nested path
deepPath := "api/credentials/production/database/primary" deepPath := "api/credentials/production/database/primary"
secretValue := []byte("supersecretdbpassword") secretValue := []byte("supersecretdbpassword")
expectedValue := make([]byte, len(secretValue))
copy(expectedValue, secretValue)
err = vlt.AddSecret(deepPath, secretValue, false) secretBuffer := memguard.NewBufferFromBytes(secretValue)
defer secretBuffer.Destroy()
err = vlt.AddSecret(deepPath, secretBuffer, false)
if err != nil { if err != nil {
t.Fatalf("Failed to add secret with deep path: %v", err) t.Fatalf("Failed to add secret with deep path: %v", err)
} }
@ -137,9 +143,9 @@ func TestVaultWithRealFilesystem(t *testing.T) {
t.Fatalf("Failed to retrieve deep path secret: %v", err) t.Fatalf("Failed to retrieve deep path secret: %v", err)
} }
if string(retrievedValue) != string(secretValue) { if string(retrievedValue) != string(expectedValue) {
t.Errorf("Retrieved value doesn't match. Expected %q, got %q", t.Errorf("Retrieved value doesn't match. Expected %q, got %q",
string(secretValue), string(retrievedValue)) string(expectedValue), string(retrievedValue))
} }
}) })
@ -368,7 +374,11 @@ func TestVaultWithRealFilesystem(t *testing.T) {
// Add a secret to vault1 // Add a secret to vault1
secretName := "test-secret" secretName := "test-secret"
secretValue := []byte("secret in vault1") secretValue := []byte("secret in vault1")
if err := vault1.AddSecret(secretName, secretValue, false); err != nil {
secretBuffer := memguard.NewBufferFromBytes(secretValue)
defer secretBuffer.Destroy()
if err := vault1.AddSecret(secretName, secretBuffer, false); err != nil {
t.Fatalf("Failed to add secret to vault1: %v", err) t.Fatalf("Failed to add secret to vault1: %v", err)
} }

View File

@ -29,11 +29,21 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Helper function to add a secret to vault with proper buffer protection
func addTestSecret(t *testing.T, vault *Vault, name string, value []byte, force bool) {
t.Helper()
buffer := memguard.NewBufferFromBytes(value)
defer buffer.Destroy()
err := vault.AddSecret(name, buffer, force)
require.NoError(t, err)
}
// TestVersionIntegrationWorkflow tests the complete version workflow // TestVersionIntegrationWorkflow tests the complete version workflow
func TestVersionIntegrationWorkflow(t *testing.T) { func TestVersionIntegrationWorkflow(t *testing.T) {
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
@ -66,8 +76,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
// Step 1: Create initial version // Step 1: Create initial version
t.Run("create_initial_version", func(t *testing.T) { t.Run("create_initial_version", func(t *testing.T) {
err := vault.AddSecret(secretName, []byte("version-1-data"), false) addTestSecret(t, vault, secretName, []byte("version-1-data"), false)
require.NoError(t, err)
// Verify secret can be retrieved // Verify secret can be retrieved
value, err := vault.GetSecret(secretName) value, err := vault.GetSecret(secretName)
@ -108,8 +117,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
firstVersionName = versions[0] firstVersionName = versions[0]
// Create second version // Create second version
err = vault.AddSecret(secretName, []byte("version-2-data"), true) addTestSecret(t, vault, secretName, []byte("version-2-data"), true)
require.NoError(t, err)
// Verify new value is current // Verify new value is current
value, err := vault.GetSecret(secretName) value, err := vault.GetSecret(secretName)
@ -142,8 +150,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
t.Run("create_third_version", func(t *testing.T) { t.Run("create_third_version", func(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
err := vault.AddSecret(secretName, []byte("version-3-data"), true) addTestSecret(t, vault, secretName, []byte("version-3-data"), true)
require.NoError(t, err)
// Verify we now have three versions // Verify we now have three versions
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test") secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
@ -214,8 +221,7 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
secretDir := filepath.Join(vaultDir, "secrets.d", "limit%test", "versions") secretDir := filepath.Join(vaultDir, "secrets.d", "limit%test", "versions")
// Create 998 versions (we already have one from the first AddSecret) // Create 998 versions (we already have one from the first AddSecret)
err := vault.AddSecret(limitSecretName, []byte("initial"), false) addTestSecret(t, vault, limitSecretName, []byte("initial"), false)
require.NoError(t, err)
// Get today's date for consistent version names // Get today's date for consistent version names
today := time.Now().Format("20060102") today := time.Now().Format("20060102")
@ -255,7 +261,9 @@ func TestVersionIntegrationWorkflow(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
// Try to add secret without force when it exists // Try to add secret without force when it exists
err = vault.AddSecret(secretName, []byte("should-fail"), false) failBuffer := memguard.NewBufferFromBytes([]byte("should-fail"))
defer failBuffer.Destroy()
err = vault.AddSecret(secretName, failBuffer, false)
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "already exists") assert.Contains(t, err.Error(), "already exists")
}) })
@ -272,8 +280,7 @@ func TestVersionConcurrency(t *testing.T) {
secretName := "concurrent/test" secretName := "concurrent/test"
// Create initial version // Create initial version
err := vault.AddSecret(secretName, []byte("initial"), false) addTestSecret(t, vault, secretName, []byte("initial"), false)
require.NoError(t, err)
// Test concurrent reads // Test concurrent reads
t.Run("concurrent_reads", func(t *testing.T) { t.Run("concurrent_reads", func(t *testing.T) {
@ -326,8 +333,10 @@ func TestVersionCompatibility(t *testing.T) {
// Create old-style encrypted value directly in secret directory // Create old-style encrypted value directly in secret directory
testValue := []byte("legacy-value") testValue := []byte("legacy-value")
testValueBuffer := memguard.NewBufferFromBytes(testValue)
defer testValueBuffer.Destroy()
ltRecipient := ltIdentity.Recipient() ltRecipient := ltIdentity.Recipient()
encrypted, err := secret.EncryptToRecipient(testValue, ltRecipient) encrypted, err := secret.EncryptToRecipient(testValueBuffer, ltRecipient)
require.NoError(t, err) require.NoError(t, err)
valuePath := filepath.Join(secretDir, "value.age") valuePath := filepath.Join(secretDir, "value.age")

View File

@ -1,3 +1,4 @@
// Package vault provides functionality for managing encrypted vaults.
package vault package vault
import ( import (
@ -30,65 +31,62 @@ func isValidVaultName(name string) bool {
return matched return matched
} }
// resolveRelativeSymlink resolves a relative symlink target to an absolute path
func resolveRelativeSymlink(symlinkPath, _ string) (string, error) {
// Get the current directory before changing
originalDir, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("failed to get current directory: %w", err)
}
secret.Debug("Got current directory", "original_dir", originalDir)
// Change to the symlink's directory
symlinkDir := filepath.Dir(symlinkPath)
secret.Debug("Changing to symlink directory", "symlink_path", symlinkDir)
secret.Debug("About to call os.Chdir - this might hang if symlink is broken")
if err := os.Chdir(symlinkDir); err != nil {
return "", fmt.Errorf("failed to change to symlink directory: %w", err)
}
secret.Debug("Changed to symlink directory successfully - os.Chdir completed")
// Get the absolute path of the target
secret.Debug("Getting absolute path of current directory")
absolutePath, err := os.Getwd()
if err != nil {
// Try to restore original directory before returning error
_ = os.Chdir(originalDir)
return "", fmt.Errorf("failed to get absolute path: %w", err)
}
secret.Debug("Got absolute path", "absolute_path", absolutePath)
// Restore the original directory
secret.Debug("Restoring original directory", "original_dir", originalDir)
if err := os.Chdir(originalDir); err != nil {
return "", fmt.Errorf("failed to restore original directory: %w", err)
}
secret.Debug("Restored original directory successfully")
return absolutePath, nil
}
// ResolveVaultSymlink resolves the currentvault symlink by reading either the symlink target or file contents // ResolveVaultSymlink resolves the currentvault symlink by reading either the symlink target or file contents
// This function is designed to work on both Unix and Windows systems, as well as with in-memory filesystems // This function is designed to work on both Unix and Windows systems, as well as with in-memory filesystems
func ResolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) { func ResolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) {
secret.Debug("resolveVaultSymlink starting", "symlink_path", symlinkPath) secret.Debug("resolveVaultSymlink starting", "symlink_path", symlinkPath)
// First try to handle the path as a real symlink (works on Unix systems) // First try to handle the path as a real symlink (works on Unix systems)
if _, ok := fs.(*afero.OsFs); ok { _, isOsFs := fs.(*afero.OsFs)
secret.Debug("Using real filesystem symlink resolution") if isOsFs {
target, err := tryResolveOsSymlink(symlinkPath)
// Check if the symlink exists
secret.Debug("Checking symlink target", "symlink_path", symlinkPath)
target, err := os.Readlink(symlinkPath)
if err == nil { if err == nil {
secret.Debug("Symlink points to", "target", target)
// On real filesystem, we need to handle relative symlinks
// by resolving them relative to the symlink's directory
if !filepath.IsAbs(target) {
// Get the current directory before changing
originalDir, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("failed to get current directory: %w", err)
}
secret.Debug("Got current directory", "original_dir", originalDir)
// Change to the symlink's directory
symlinkDir := filepath.Dir(symlinkPath)
secret.Debug("Changing to symlink directory", "symlink_path", symlinkDir)
secret.Debug("About to call os.Chdir - this might hang if symlink is broken")
if err := os.Chdir(symlinkDir); err != nil {
return "", fmt.Errorf("failed to change to symlink directory: %w", err)
}
secret.Debug("Changed to symlink directory successfully - os.Chdir completed")
// Get the absolute path of the target
secret.Debug("Getting absolute path of current directory")
absolutePath, err := os.Getwd()
if err != nil {
// Try to restore original directory before returning error
_ = os.Chdir(originalDir)
return "", fmt.Errorf("failed to get absolute path: %w", err)
}
secret.Debug("Got absolute path", "absolute_path", absolutePath)
// Restore the original directory
secret.Debug("Restoring original directory", "original_dir", originalDir)
if err := os.Chdir(originalDir); err != nil {
return "", fmt.Errorf("failed to restore original directory: %w", err)
}
secret.Debug("Restored original directory successfully")
// Use the absolute path of the target
target = absolutePath
}
secret.Debug("resolveVaultSymlink completed successfully", "result", target) secret.Debug("resolveVaultSymlink completed successfully", "result", target)
return target, nil return target, nil
} }
// Fall through to fallback if symlink resolution failed
} else {
secret.Debug("Not using OS filesystem, skipping symlink resolution")
} }
// Fallback: treat it as a regular file containing the target path // Fallback: treat it as a regular file containing the target path
@ -97,6 +95,7 @@ func ResolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) {
fileData, err := afero.ReadFile(fs, symlinkPath) fileData, err := afero.ReadFile(fs, symlinkPath)
if err != nil { if err != nil {
secret.Debug("Failed to read target path file", "error", err) secret.Debug("Failed to read target path file", "error", err)
return "", fmt.Errorf("failed to read vault symlink: %w", err) return "", fmt.Errorf("failed to read vault symlink: %w", err)
} }
@ -108,6 +107,28 @@ func ResolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) {
return target, nil return target, nil
} }
// tryResolveOsSymlink attempts to resolve a symlink on OS filesystems
func tryResolveOsSymlink(symlinkPath string) (string, error) {
secret.Debug("Using real filesystem symlink resolution")
// Check if the symlink exists
secret.Debug("Checking symlink target", "symlink_path", symlinkPath)
target, err := os.Readlink(symlinkPath)
if err != nil {
return "", err
}
secret.Debug("Symlink points to", "target", target)
// On real filesystem, we need to handle relative symlinks
// by resolving them relative to the symlink's directory
if !filepath.IsAbs(target) {
return resolveRelativeSymlink(symlinkPath, target)
}
return target, nil
}
// GetCurrentVault gets the current vault from the file system // GetCurrentVault gets the current vault from the file system
func GetCurrentVault(fs afero.Fs, stateDir string) (*Vault, error) { func GetCurrentVault(fs afero.Fs, stateDir string) (*Vault, error) {
secret.Debug("Getting current vault", "state_dir", stateDir) secret.Debug("Getting current vault", "state_dir", stateDir)
@ -119,6 +140,7 @@ func GetCurrentVault(fs afero.Fs, stateDir string) (*Vault, error) {
_, err := fs.Stat(currentVaultPath) _, err := fs.Stat(currentVaultPath)
if err != nil { if err != nil {
secret.Debug("Failed to stat current vault symlink", "error", err, "path", currentVaultPath) secret.Debug("Failed to stat current vault symlink", "error", err, "path", currentVaultPath)
return nil, fmt.Errorf("failed to read current vault symlink: %w", err) return nil, fmt.Errorf("failed to read current vault symlink: %w", err)
} }
@ -174,6 +196,54 @@ func ListVaults(fs afero.Fs, stateDir string) ([]string, error) {
return vaults, nil return vaults, nil
} }
// processMnemonicForVault handles mnemonic processing for vault creation
func processMnemonicForVault(fs afero.Fs, stateDir, vaultDir, vaultName string) (
derivationIndex uint32, publicKeyHash string, familyHash string, err error) {
// Check if mnemonic is available in environment
mnemonic := os.Getenv(secret.EnvMnemonic)
if mnemonic == "" {
secret.Debug("No mnemonic in environment, vault created without long-term key", "vault", vaultName)
// Use 0 for derivation index when no mnemonic is provided
return 0, "", "", nil
}
secret.Debug("Mnemonic found in environment, deriving long-term key", "vault", vaultName)
// Get the next available derivation index for this mnemonic
derivationIndex, err = GetNextDerivationIndex(fs, stateDir, mnemonic)
if err != nil {
return 0, "", "", fmt.Errorf("failed to get next derivation index: %w", err)
}
// Derive the long-term key using the actual derivation index
ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex)
if err != nil {
return 0, "", "", fmt.Errorf("failed to derive long-term key: %w", err)
}
// Write the public key
ltPubKey := ltIdentity.Recipient().String()
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
if err := afero.WriteFile(fs, ltPubKeyPath, []byte(ltPubKey), secret.FilePerms); err != nil {
return 0, "", "", fmt.Errorf("failed to write long-term public key: %w", err)
}
secret.Debug("Wrote long-term public key", "path", ltPubKeyPath)
// Compute verification hash from actual derivation index
publicKeyHash = ComputeDoubleSHA256([]byte(ltIdentity.Recipient().String()))
// Compute family hash from index 0 (same for all vaults with this mnemonic)
// This is used to identify which vaults belong to the same mnemonic family
identity0, err := agehd.DeriveIdentity(mnemonic, 0)
if err != nil {
return 0, "", "", fmt.Errorf("failed to derive identity for index 0: %w", err)
}
familyHash = ComputeDoubleSHA256([]byte(identity0.Recipient().String()))
return derivationIndex, publicKeyHash, familyHash, nil
}
// CreateVault creates a new vault // CreateVault creates a new vault
func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) { func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) {
secret.Debug("Creating new vault", "name", name, "state_dir", stateDir) secret.Debug("Creating new vault", "name", name, "state_dir", stateDir)
@ -181,6 +251,7 @@ func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) {
// Validate vault name // Validate vault name
if !isValidVaultName(name) { if !isValidVaultName(name) {
secret.Debug("Invalid vault name provided", "vault_name", name) secret.Debug("Invalid vault name provided", "vault_name", name)
return nil, fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name) return nil, fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name)
} }
secret.Debug("Vault name validation passed", "vault_name", name) secret.Debug("Vault name validation passed", "vault_name", name)
@ -206,50 +277,10 @@ func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) {
return nil, fmt.Errorf("failed to create unlockers directory: %w", err) return nil, fmt.Errorf("failed to create unlockers directory: %w", err)
} }
// Check if mnemonic is available in environment // Process mnemonic if available
mnemonic := os.Getenv(secret.EnvMnemonic) derivationIndex, publicKeyHash, familyHash, err := processMnemonicForVault(fs, stateDir, vaultDir, name)
var derivationIndex uint32 if err != nil {
var publicKeyHash string return nil, err
var familyHash string
if mnemonic != "" {
secret.Debug("Mnemonic found in environment, deriving long-term key", "vault", name)
// Get the next available derivation index for this mnemonic
var err error
derivationIndex, err = GetNextDerivationIndex(fs, stateDir, mnemonic)
if err != nil {
return nil, fmt.Errorf("failed to get next derivation index: %w", err)
}
// Derive the long-term key using the actual derivation index
ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex)
if err != nil {
return nil, fmt.Errorf("failed to derive long-term key: %w", err)
}
// Write the public key
ltPubKey := ltIdentity.Recipient().String()
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
if err := afero.WriteFile(fs, ltPubKeyPath, []byte(ltPubKey), secret.FilePerms); err != nil {
return nil, fmt.Errorf("failed to write long-term public key: %w", err)
}
secret.Debug("Wrote long-term public key", "path", ltPubKeyPath)
// Compute verification hash from actual derivation index
publicKeyHash = ComputeDoubleSHA256([]byte(ltIdentity.Recipient().String()))
// Compute family hash from index 0 (same for all vaults with this mnemonic)
// This is used to identify which vaults belong to the same mnemonic family
identity0, err := agehd.DeriveIdentity(mnemonic, 0)
if err != nil {
return nil, fmt.Errorf("failed to derive identity for index 0: %w", err)
}
familyHash = ComputeDoubleSHA256([]byte(identity0.Recipient().String()))
} else {
secret.Debug("No mnemonic in environment, vault created without long-term key", "vault", name)
// Use 0 for derivation index when no mnemonic is provided
derivationIndex = 0
} }
// Save vault metadata // Save vault metadata
@ -282,6 +313,7 @@ func SelectVault(fs afero.Fs, stateDir string, name string) error {
// Validate vault name // Validate vault name
if !isValidVaultName(name) { if !isValidVaultName(name) {
secret.Debug("Invalid vault name provided", "vault_name", name) secret.Debug("Invalid vault name provided", "vault_name", name)
return fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name) return fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name)
} }
secret.Debug("Vault name validation passed", "vault_name", name) secret.Debug("Vault name validation passed", "vault_name", name)
@ -313,6 +345,7 @@ func SelectVault(fs afero.Fs, stateDir string, name string) error {
secret.Debug("Creating vault symlink", "target", targetPath, "link", currentVaultPath) secret.Debug("Creating vault symlink", "target", targetPath, "link", currentVaultPath)
if err := os.Symlink(targetPath, currentVaultPath); err == nil { if err := os.Symlink(targetPath, currentVaultPath); err == nil {
secret.Debug("Successfully selected vault", "vault_name", name) secret.Debug("Successfully selected vault", "vault_name", name)
return nil return nil
} }
// If symlink creation fails, fall back to regular file // If symlink creation fails, fall back to regular file

View File

@ -12,13 +12,17 @@ import (
"github.com/spf13/afero" "github.com/spf13/afero"
) )
// Alias the metadata types from secret package for convenience // Metadata is an alias for secret.VaultMetadata
type ( type Metadata = secret.VaultMetadata
Metadata = secret.VaultMetadata
UnlockerMetadata = secret.UnlockerMetadata // UnlockerMetadata is an alias for secret.UnlockerMetadata
SecretMetadata = secret.Metadata type UnlockerMetadata = secret.UnlockerMetadata
Configuration = secret.Configuration
) // SecretMetadata is an alias for secret.Metadata
type SecretMetadata = secret.Metadata
// Configuration is an alias for secret.Configuration
type Configuration = secret.Configuration
// ComputeDoubleSHA256 computes the double SHA256 hash of data and returns it as hex // ComputeDoubleSHA256 computes the double SHA256 hash of data and returns it as hex
func ComputeDoubleSHA256(data []byte) string { func ComputeDoubleSHA256(data []byte) string {

View File

@ -11,6 +11,7 @@ import (
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -21,6 +22,7 @@ func (v *Vault) ListSecrets() ([]string, error) {
vaultDir, err := v.GetDirectory() vaultDir, err := v.GetDirectory()
if err != nil { if err != nil {
secret.Debug("Failed to get vault directory for secret listing", "error", err, "vault_name", v.Name) secret.Debug("Failed to get vault directory for secret listing", "error", err, "vault_name", v.Name)
return nil, err return nil, err
} }
@ -30,10 +32,12 @@ func (v *Vault) ListSecrets() ([]string, error) {
exists, err := afero.DirExists(v.fs, secretsDir) exists, err := afero.DirExists(v.fs, secretsDir)
if err != nil { if err != nil {
secret.Debug("Failed to check secrets directory", "error", err, "secrets_dir", secretsDir) secret.Debug("Failed to check secrets directory", "error", err, "secrets_dir", secretsDir)
return nil, fmt.Errorf("failed to check if secrets directory exists: %w", err) return nil, fmt.Errorf("failed to check if secrets directory exists: %w", err)
} }
if !exists { if !exists {
secret.Debug("Secrets directory does not exist", "secrets_dir", secretsDir, "vault_name", v.Name) secret.Debug("Secrets directory does not exist", "secrets_dir", secretsDir, "vault_name", v.Name)
return []string{}, nil return []string{}, nil
} }
@ -41,6 +45,7 @@ func (v *Vault) ListSecrets() ([]string, error) {
files, err := afero.ReadDir(v.fs, secretsDir) files, err := afero.ReadDir(v.fs, secretsDir)
if err != nil { if err != nil {
secret.Debug("Failed to read secrets directory", "error", err, "secrets_dir", secretsDir) secret.Debug("Failed to read secrets directory", "error", err, "secrets_dir", secretsDir)
return nil, fmt.Errorf("failed to read secrets directory: %w", err) return nil, fmt.Errorf("failed to read secrets directory: %w", err)
} }
@ -94,17 +99,22 @@ func isValidSecretName(name string) bool {
} }
// AddSecret adds a secret to this vault // AddSecret adds a secret to this vault
func (v *Vault) AddSecret(name string, value []byte, force bool) error { func (v *Vault) AddSecret(name string, value *memguard.LockedBuffer, force bool) error {
if value == nil {
return fmt.Errorf("value buffer is nil")
}
secret.DebugWith("Adding secret to vault", secret.DebugWith("Adding secret to vault",
slog.String("vault_name", v.Name), slog.String("vault_name", v.Name),
slog.String("secret_name", name), slog.String("secret_name", name),
slog.Int("value_length", len(value)), slog.Int("value_length", value.Size()),
slog.Bool("force", force), slog.Bool("force", force),
) )
// Validate secret name // Validate secret name
if !isValidSecretName(name) { if !isValidSecretName(name) {
secret.Debug("Invalid secret name provided", "secret_name", name) secret.Debug("Invalid secret name provided", "secret_name", name)
return fmt.Errorf("invalid secret name '%s': must match pattern [a-z0-9.\\-_/]+", name) return fmt.Errorf("invalid secret name '%s': must match pattern [a-z0-9.\\-_/]+", name)
} }
secret.Debug("Secret name validation passed", "secret_name", name) secret.Debug("Secret name validation passed", "secret_name", name)
@ -113,6 +123,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
vaultDir, err := v.GetDirectory() vaultDir, err := v.GetDirectory()
if err != nil { if err != nil {
secret.Debug("Failed to get vault directory for secret addition", "error", err, "vault_name", v.Name) secret.Debug("Failed to get vault directory for secret addition", "error", err, "vault_name", v.Name)
return err return err
} }
secret.Debug("Got vault directory", "vault_dir", vaultDir) secret.Debug("Got vault directory", "vault_dir", vaultDir)
@ -131,6 +142,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
exists, err := afero.DirExists(v.fs, secretDir) exists, err := afero.DirExists(v.fs, secretDir)
if err != nil { if err != nil {
secret.Debug("Failed to check if secret exists", "error", err, "secret_dir", secretDir) secret.Debug("Failed to check if secret exists", "error", err, "secret_dir", secretDir)
return fmt.Errorf("failed to check if secret exists: %w", err) return fmt.Errorf("failed to check if secret exists: %w", err)
} }
secret.Debug("Secret existence check complete", "exists", exists) secret.Debug("Secret existence check complete", "exists", exists)
@ -142,6 +154,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
if exists { if exists {
if !force { if !force {
secret.Debug("Secret already exists and force not specified", "secret_name", name, "secret_dir", secretDir) secret.Debug("Secret already exists and force not specified", "secret_name", name, "secret_dir", secretDir)
return fmt.Errorf("secret %s already exists (use --force to overwrite)", name) return fmt.Errorf("secret %s already exists (use --force to overwrite)", name)
} }
@ -156,6 +169,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
secret.Debug("Creating secret directory", "secret_dir", secretDir) secret.Debug("Creating secret directory", "secret_dir", secretDir)
if err := v.fs.MkdirAll(secretDir, secret.DirPerms); err != nil { if err := v.fs.MkdirAll(secretDir, secret.DirPerms); err != nil {
secret.Debug("Failed to create secret directory", "error", err, "secret_dir", secretDir) secret.Debug("Failed to create secret directory", "error", err, "secret_dir", secretDir)
return fmt.Errorf("failed to create secret directory: %w", err) return fmt.Errorf("failed to create secret directory: %w", err)
} }
secret.Debug("Created secret directory successfully") secret.Debug("Created secret directory successfully")
@ -165,6 +179,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
versionName, err := secret.GenerateVersionName(v.fs, secretDir) versionName, err := secret.GenerateVersionName(v.fs, secretDir)
if err != nil { if err != nil {
secret.Debug("Failed to generate version name", "error", err, "secret_name", name) secret.Debug("Failed to generate version name", "error", err, "secret_name", name)
return fmt.Errorf("failed to generate version name: %w", err) return fmt.Errorf("failed to generate version name: %w", err)
} }
@ -185,9 +200,10 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
// We'll update the previous version's notAfter after we save the new version // We'll update the previous version's notAfter after we save the new version
} }
// Save the new version // Save the new version - pass the LockedBuffer directly
if err := newVersion.Save(value); err != nil { if err := newVersion.Save(value); err != nil {
secret.Debug("Failed to save new version", "error", err, "version", versionName) secret.Debug("Failed to save new version", "error", err, "version", versionName)
return fmt.Errorf("failed to save version: %w", err) return fmt.Errorf("failed to save version: %w", err)
} }
@ -197,12 +213,14 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
ltIdentity, err := v.GetOrDeriveLongTermKey() ltIdentity, err := v.GetOrDeriveLongTermKey()
if err != nil { if err != nil {
secret.Debug("Failed to get long-term key for metadata update", "error", err) secret.Debug("Failed to get long-term key for metadata update", "error", err)
return fmt.Errorf("failed to get long-term key: %w", err) return fmt.Errorf("failed to get long-term key: %w", err)
} }
// Load previous version metadata // Load previous version metadata
if err := previousVersion.LoadMetadata(ltIdentity); err != nil { if err := previousVersion.LoadMetadata(ltIdentity); err != nil {
secret.Debug("Failed to load previous version metadata", "error", err) secret.Debug("Failed to load previous version metadata", "error", err)
return fmt.Errorf("failed to load previous version metadata: %w", err) return fmt.Errorf("failed to load previous version metadata: %w", err)
} }
@ -212,6 +230,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
// Re-save the metadata (we need to implement an update method) // Re-save the metadata (we need to implement an update method)
if err := updateVersionMetadata(v.fs, previousVersion, ltIdentity); err != nil { if err := updateVersionMetadata(v.fs, previousVersion, ltIdentity); err != nil {
secret.Debug("Failed to update previous version metadata", "error", err) secret.Debug("Failed to update previous version metadata", "error", err)
return fmt.Errorf("failed to update previous version metadata: %w", err) return fmt.Errorf("failed to update previous version metadata: %w", err)
} }
} }
@ -219,6 +238,7 @@ func (v *Vault) AddSecret(name string, value []byte, force bool) error {
// Set current symlink to new version // Set current symlink to new version
if err := secret.SetCurrentVersion(v.fs, secretDir, versionName); err != nil { if err := secret.SetCurrentVersion(v.fs, secretDir, versionName); err != nil {
secret.Debug("Failed to set current version", "error", err, "version", versionName) secret.Debug("Failed to set current version", "error", err, "version", versionName)
return fmt.Errorf("failed to set current version: %w", err) return fmt.Errorf("failed to set current version: %w", err)
} }
@ -257,7 +277,10 @@ func updateVersionMetadata(fs afero.Fs, version *secret.Version, ltIdentity *age
} }
// Encrypt metadata to the version's public key // Encrypt metadata to the version's public key
encryptedMetadata, err := secret.EncryptToRecipient(metadataBytes, versionIdentity.Recipient()) metadataBuffer := memguard.NewBufferFromBytes(metadataBytes)
defer metadataBuffer.Destroy()
encryptedMetadata, err := secret.EncryptToRecipient(metadataBuffer, versionIdentity.Recipient())
if err != nil { if err != nil {
return fmt.Errorf("failed to encrypt version metadata: %w", err) return fmt.Errorf("failed to encrypt version metadata: %w", err)
} }
@ -293,6 +316,7 @@ func (v *Vault) GetSecretVersion(name string, version string) ([]byte, error) {
vaultDir, err := v.GetDirectory() vaultDir, err := v.GetDirectory()
if err != nil { if err != nil {
secret.Debug("Failed to get vault directory", "error", err, "vault_name", v.Name) secret.Debug("Failed to get vault directory", "error", err, "vault_name", v.Name)
return nil, err return nil, err
} }
@ -304,10 +328,12 @@ func (v *Vault) GetSecretVersion(name string, version string) ([]byte, error) {
exists, err := afero.DirExists(v.fs, secretDir) exists, err := afero.DirExists(v.fs, secretDir)
if err != nil { if err != nil {
secret.Debug("Failed to check if secret exists", "error", err, "secret_name", name) secret.Debug("Failed to check if secret exists", "error", err, "secret_name", name)
return nil, fmt.Errorf("failed to check if secret exists: %w", err) return nil, fmt.Errorf("failed to check if secret exists: %w", err)
} }
if !exists { if !exists {
secret.Debug("Secret not found in vault", "secret_name", name, "vault_name", v.Name) secret.Debug("Secret not found in vault", "secret_name", name, "vault_name", v.Name)
return nil, fmt.Errorf("secret %s not found", name) return nil, fmt.Errorf("secret %s not found", name)
} }
@ -317,6 +343,7 @@ func (v *Vault) GetSecretVersion(name string, version string) ([]byte, error) {
currentVersion, err := secret.GetCurrentVersion(v.fs, secretDir) currentVersion, err := secret.GetCurrentVersion(v.fs, secretDir)
if err != nil { if err != nil {
secret.Debug("Failed to get current version", "error", err, "secret_name", name) secret.Debug("Failed to get current version", "error", err, "secret_name", name)
return nil, fmt.Errorf("failed to get current version: %w", err) return nil, fmt.Errorf("failed to get current version: %w", err)
} }
version = currentVersion version = currentVersion
@ -331,10 +358,12 @@ func (v *Vault) GetSecretVersion(name string, version string) ([]byte, error) {
exists, err = afero.DirExists(v.fs, versionPath) exists, err = afero.DirExists(v.fs, versionPath)
if err != nil { if err != nil {
secret.Debug("Failed to check if version exists", "error", err, "version", version) secret.Debug("Failed to check if version exists", "error", err, "version", version)
return nil, fmt.Errorf("failed to check if version exists: %w", err) return nil, fmt.Errorf("failed to check if version exists: %w", err)
} }
if !exists { if !exists {
secret.Debug("Version not found", "version", version, "secret_name", name) secret.Debug("Version not found", "version", version, "secret_name", name)
return nil, fmt.Errorf("version %s not found for secret %s", version, name) return nil, fmt.Errorf("version %s not found for secret %s", version, name)
} }
@ -344,6 +373,7 @@ func (v *Vault) GetSecretVersion(name string, version string) ([]byte, error) {
longTermIdentity, err := v.UnlockVault() longTermIdentity, err := v.UnlockVault()
if err != nil { if err != nil {
secret.Debug("Failed to unlock vault", "error", err, "vault_name", v.Name) secret.Debug("Failed to unlock vault", "error", err, "vault_name", v.Name)
return nil, fmt.Errorf("failed to unlock vault: %w", err) return nil, fmt.Errorf("failed to unlock vault: %w", err)
} }
@ -359,6 +389,7 @@ func (v *Vault) GetSecretVersion(name string, version string) ([]byte, error) {
decryptedValue, err := secretVersion.GetValue(longTermIdentity) decryptedValue, err := secretVersion.GetValue(longTermIdentity)
if err != nil { if err != nil {
secret.Debug("Failed to decrypt version value", "error", err, "version", version, "secret_name", name) secret.Debug("Failed to decrypt version value", "error", err, "version", version, "secret_name", name)
return nil, fmt.Errorf("failed to decrypt version: %w", err) return nil, fmt.Errorf("failed to decrypt version: %w", err)
} }
@ -386,6 +417,7 @@ func (v *Vault) UnlockVault() (*age.X25519Identity, error) {
// If vault is already unlocked, return the cached key // If vault is already unlocked, return the cached key
if !v.Locked() { if !v.Locked() {
secret.Debug("Vault already unlocked, returning cached long-term key", "vault_name", v.Name) secret.Debug("Vault already unlocked, returning cached long-term key", "vault_name", v.Name)
return v.longTermKey, nil return v.longTermKey, nil
} }
@ -393,6 +425,7 @@ func (v *Vault) UnlockVault() (*age.X25519Identity, error) {
longTermIdentity, err := v.GetOrDeriveLongTermKey() longTermIdentity, err := v.GetOrDeriveLongTermKey()
if err != nil { if err != nil {
secret.Debug("Failed to get or derive long-term key", "error", err, "vault_name", v.Name) secret.Debug("Failed to get or derive long-term key", "error", err, "vault_name", v.Name)
return nil, fmt.Errorf("failed to get long-term key: %w", err) return nil, fmt.Errorf("failed to get long-term key: %w", err)
} }

View File

@ -24,11 +24,21 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Helper function to add a secret to vault with proper buffer protection
func addTestSecretToVault(t *testing.T, vault *Vault, name string, value []byte, force bool) {
t.Helper()
buffer := memguard.NewBufferFromBytes(value)
defer buffer.Destroy()
err := vault.AddSecret(name, buffer, force)
require.NoError(t, err)
}
// Helper function to create a vault with long-term key set up // Helper function to create a vault with long-term key set up
func createTestVaultWithKey(t *testing.T, fs afero.Fs, stateDir, vaultName string) *Vault { func createTestVaultWithKey(t *testing.T, fs afero.Fs, stateDir, vaultName string) *Vault {
// Set mnemonic for testing // Set mnemonic for testing
@ -65,9 +75,10 @@ func TestVaultAddSecretCreatesVersion(t *testing.T) {
// Add a secret // Add a secret
secretName := "test/secret" secretName := "test/secret"
secretValue := []byte("initial-value") secretValue := []byte("initial-value")
expectedValue := make([]byte, len(secretValue))
copy(expectedValue, secretValue)
err := vault.AddSecret(secretName, secretValue, false) addTestSecretToVault(t, vault, secretName, secretValue, false)
require.NoError(t, err)
// Check that version directory was created // Check that version directory was created
vaultDir, _ := vault.GetDirectory() vaultDir, _ := vault.GetDirectory()
@ -88,7 +99,7 @@ func TestVaultAddSecretCreatesVersion(t *testing.T) {
// Get the secret value // Get the secret value
retrievedValue, err := vault.GetSecret(secretName) retrievedValue, err := vault.GetSecret(secretName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, secretValue, retrievedValue) assert.Equal(t, expectedValue, retrievedValue)
} }
func TestVaultAddSecretMultipleVersions(t *testing.T) { func TestVaultAddSecretMultipleVersions(t *testing.T) {
@ -101,17 +112,17 @@ func TestVaultAddSecretMultipleVersions(t *testing.T) {
secretName := "test/secret" secretName := "test/secret"
// Add first version // Add first version
err := vault.AddSecret(secretName, []byte("version-1"), false) addTestSecretToVault(t, vault, secretName, []byte("version-1"), false)
require.NoError(t, err)
// Try to add again without force - should fail // Try to add again without force - should fail
err = vault.AddSecret(secretName, []byte("version-2"), false) failBuffer := memguard.NewBufferFromBytes([]byte("version-2"))
defer failBuffer.Destroy()
err := vault.AddSecret(secretName, failBuffer, false)
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "already exists") assert.Contains(t, err.Error(), "already exists")
// Add with force - should create new version // Add with force - should create new version
err = vault.AddSecret(secretName, []byte("version-2"), true) addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
require.NoError(t, err)
// Check that we have two versions // Check that we have two versions
vaultDir, _ := vault.GetDirectory() vaultDir, _ := vault.GetDirectory()
@ -136,14 +147,12 @@ func TestVaultGetSecretVersion(t *testing.T) {
secretName := "test/secret" secretName := "test/secret"
// Add multiple versions // Add multiple versions
err := vault.AddSecret(secretName, []byte("version-1"), false) addTestSecretToVault(t, vault, secretName, []byte("version-1"), false)
require.NoError(t, err)
// Small delay to ensure different version names // Small delay to ensure different version names
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
err = vault.AddSecret(secretName, []byte("version-2"), true) addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
require.NoError(t, err)
// Get versions list // Get versions list
vaultDir, _ := vault.GetDirectory() vaultDir, _ := vault.GetDirectory()
@ -185,7 +194,9 @@ func TestVaultVersionTimestamps(t *testing.T) {
// Add first version // Add first version
beforeFirst := time.Now() beforeFirst := time.Now()
err = vault.AddSecret(secretName, []byte("version-1"), false) v1Buffer := memguard.NewBufferFromBytes([]byte("version-1"))
defer v1Buffer.Destroy()
err = vault.AddSecret(secretName, v1Buffer, false)
require.NoError(t, err) require.NoError(t, err)
afterFirst := time.Now() afterFirst := time.Now()
@ -212,8 +223,7 @@ func TestVaultVersionTimestamps(t *testing.T) {
// Add second version // Add second version
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
beforeSecond := time.Now() beforeSecond := time.Now()
err = vault.AddSecret(secretName, []byte("version-2"), true) addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
require.NoError(t, err)
afterSecond := time.Now() afterSecond := time.Now()
// Get updated versions // Get updated versions
@ -249,11 +259,10 @@ func TestVaultGetNonExistentVersion(t *testing.T) {
vault := createTestVaultWithKey(t, fs, stateDir, "test") vault := createTestVaultWithKey(t, fs, stateDir, "test")
// Add a secret // Add a secret
err := vault.AddSecret("test/secret", []byte("value"), false) addTestSecretToVault(t, vault, "test/secret", []byte("value"), false)
require.NoError(t, err)
// Try to get non-existent version // Try to get non-existent version
_, err = vault.GetSecretVersion("test/secret", "20991231.999") _, err := vault.GetSecretVersion("test/secret", "20991231.999")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "not found") assert.Contains(t, err.Error(), "not found")
} }
@ -281,7 +290,9 @@ func TestUpdateVersionMetadata(t *testing.T) {
version.Metadata.NotAfter = nil version.Metadata.NotAfter = nil
// Save version // Save version
err = version.Save([]byte("test-value")) testBuffer := memguard.NewBufferFromBytes([]byte("test-value"))
defer testBuffer.Destroy()
err = version.Save(testBuffer)
require.NoError(t, err) require.NoError(t, err)
// Update metadata // Update metadata

View File

@ -10,6 +10,7 @@ import (
"filippo.io/age" "filippo.io/age"
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -20,6 +21,7 @@ func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) {
vaultDir, err := v.GetDirectory() vaultDir, err := v.GetDirectory()
if err != nil { if err != nil {
secret.Debug("Failed to get vault directory for unlocker", "error", err, "vault_name", v.Name) secret.Debug("Failed to get vault directory for unlocker", "error", err, "vault_name", v.Name)
return nil, err return nil, err
} }
@ -29,35 +31,14 @@ func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) {
_, err = v.fs.Stat(currentUnlockerPath) _, err = v.fs.Stat(currentUnlockerPath)
if err != nil { if err != nil {
secret.Debug("Failed to stat current unlocker symlink", "error", err, "path", currentUnlockerPath) secret.Debug("Failed to stat current unlocker symlink", "error", err, "path", currentUnlockerPath)
return nil, fmt.Errorf("failed to read current unlocker: %w", err) return nil, fmt.Errorf("failed to read current unlocker: %w", err)
} }
// Resolve the symlink to get the target directory // Resolve the symlink to get the target directory
var unlockerDir string unlockerDir, err := v.resolveUnlockerDirectory(currentUnlockerPath)
if linkReader, ok := v.fs.(afero.LinkReader); ok { if err != nil {
secret.Debug("Resolving unlocker symlink using afero") return nil, err
// Try to read as symlink first
unlockerDir, err = linkReader.ReadlinkIfPossible(currentUnlockerPath)
if err != nil {
secret.Debug("Failed to read symlink, falling back to file contents",
"error", err, "symlink_path", currentUnlockerPath)
// Fallback: read the path from file contents
unlockerDirBytes, err := afero.ReadFile(v.fs, currentUnlockerPath)
if err != nil {
secret.Debug("Failed to read unlocker path file", "error", err, "path", currentUnlockerPath)
return nil, fmt.Errorf("failed to read current unlocker: %w", err)
}
unlockerDir = strings.TrimSpace(string(unlockerDirBytes))
}
} else {
secret.Debug("Reading unlocker path (filesystem doesn't support symlinks)")
// Fallback for filesystems that don't support symlinks: read the path from file contents
unlockerDirBytes, err := afero.ReadFile(v.fs, currentUnlockerPath)
if err != nil {
secret.Debug("Failed to read unlocker path file", "error", err, "path", currentUnlockerPath)
return nil, fmt.Errorf("failed to read current unlocker: %w", err)
}
unlockerDir = strings.TrimSpace(string(unlockerDirBytes))
} }
secret.DebugWith("Resolved unlocker directory", secret.DebugWith("Resolved unlocker directory",
@ -72,12 +53,14 @@ func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) {
metadataBytes, err := afero.ReadFile(v.fs, metadataPath) metadataBytes, err := afero.ReadFile(v.fs, metadataPath)
if err != nil { if err != nil {
secret.Debug("Failed to read unlocker metadata", "error", err, "path", metadataPath) secret.Debug("Failed to read unlocker metadata", "error", err, "path", metadataPath)
return nil, fmt.Errorf("failed to read unlocker metadata: %w", err) return nil, fmt.Errorf("failed to read unlocker metadata: %w", err)
} }
var metadata UnlockerMetadata var metadata UnlockerMetadata
if err := json.Unmarshal(metadataBytes, &metadata); err != nil { if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
secret.Debug("Failed to parse unlocker metadata", "error", err, "path", metadataPath) secret.Debug("Failed to parse unlocker metadata", "error", err, "path", metadataPath)
return nil, fmt.Errorf("failed to parse unlocker metadata: %w", err) return nil, fmt.Errorf("failed to parse unlocker metadata: %w", err)
} }
@ -102,6 +85,7 @@ func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) {
unlocker = secret.NewKeychainUnlocker(v.fs, unlockerDir, metadata) unlocker = secret.NewKeychainUnlocker(v.fs, unlockerDir, metadata)
default: default:
secret.Debug("Unsupported unlocker type", "type", metadata.Type) secret.Debug("Unsupported unlocker type", "type", metadata.Type)
return nil, fmt.Errorf("unsupported unlocker type: %s", metadata.Type) return nil, fmt.Errorf("unsupported unlocker type: %s", metadata.Type)
} }
@ -114,6 +98,97 @@ func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) {
return unlocker, nil return unlocker, nil
} }
// resolveUnlockerDirectory resolves the unlocker directory from a symlink or file
func (v *Vault) resolveUnlockerDirectory(currentUnlockerPath string) (string, error) {
linkReader, ok := v.fs.(afero.LinkReader)
if !ok {
// Fallback for filesystems that don't support symlinks
return v.readUnlockerPathFromFile(currentUnlockerPath)
}
secret.Debug("Resolving unlocker symlink using afero")
// Try to read as symlink first
unlockerDir, err := linkReader.ReadlinkIfPossible(currentUnlockerPath)
if err == nil {
return unlockerDir, nil
}
secret.Debug("Failed to read symlink, falling back to file contents",
"error", err, "symlink_path", currentUnlockerPath)
// Fallback: read the path from file contents
return v.readUnlockerPathFromFile(currentUnlockerPath)
}
// readUnlockerPathFromFile reads the unlocker directory path from a file
func (v *Vault) readUnlockerPathFromFile(path string) (string, error) {
secret.Debug("Reading unlocker path from file", "path", path)
unlockerDirBytes, err := afero.ReadFile(v.fs, path)
if err != nil {
secret.Debug("Failed to read unlocker path file", "error", err, "path", path)
return "", fmt.Errorf("failed to read current unlocker: %w", err)
}
return strings.TrimSpace(string(unlockerDirBytes)), nil
}
// findUnlockerByID finds an unlocker by its ID and returns the unlocker instance and its directory path
func (v *Vault) findUnlockerByID(unlockersDir, unlockerID string) (secret.Unlocker, string, error) {
files, err := afero.ReadDir(v.fs, unlockersDir)
if err != nil {
return nil, "", fmt.Errorf("failed to read unlockers directory: %w", err)
}
for _, file := range files {
if !file.IsDir() {
continue
}
// Read metadata file
metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json")
exists, err := afero.Exists(v.fs, metadataPath)
if err != nil {
return nil, "", fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err)
}
if !exists {
// Skip directories without metadata - they might not be unlockers
continue
}
metadataBytes, err := afero.ReadFile(v.fs, metadataPath)
if err != nil {
return nil, "", fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err)
}
var metadata UnlockerMetadata
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
return nil, "", fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err)
}
unlockerDirPath := filepath.Join(unlockersDir, file.Name())
// Create the appropriate unlocker instance
var tempUnlocker secret.Unlocker
switch metadata.Type {
case "passphrase":
tempUnlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDirPath, metadata)
case "pgp":
tempUnlocker = secret.NewPGPUnlocker(v.fs, unlockerDirPath, metadata)
case "keychain":
tempUnlocker = secret.NewKeychainUnlocker(v.fs, unlockerDirPath, metadata)
default:
continue
}
// Check if this unlocker's ID matches
if tempUnlocker.GetID() == unlockerID {
return tempUnlocker, unlockerDirPath, nil
}
}
return nil, "", nil
}
// ListUnlockers returns a list of available unlockers for this vault // ListUnlockers returns a list of available unlockers for this vault
func (v *Vault) ListUnlockers() ([]UnlockerMetadata, error) { func (v *Vault) ListUnlockers() ([]UnlockerMetadata, error) {
vaultDir, err := v.GetDirectory() vaultDir, err := v.GetDirectory()
@ -178,58 +253,10 @@ func (v *Vault) RemoveUnlocker(unlockerID string) error {
// Find the unlocker directory and create the unlocker instance // Find the unlocker directory and create the unlocker instance
unlockersDir := filepath.Join(vaultDir, "unlockers.d") unlockersDir := filepath.Join(vaultDir, "unlockers.d")
// List directories in unlockers.d // Find the unlocker by ID
files, err := afero.ReadDir(v.fs, unlockersDir) unlocker, _, err := v.findUnlockerByID(unlockersDir, unlockerID)
if err != nil { if err != nil {
return fmt.Errorf("failed to read unlockers directory: %w", err) return err
}
var unlocker secret.Unlocker
var unlockerDirPath string
for _, file := range files {
if file.IsDir() {
// Read metadata file
metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json")
exists, err := afero.Exists(v.fs, metadataPath)
if err != nil {
return fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err)
}
if !exists {
// Skip directories without metadata - they might not be unlockers
continue
}
metadataBytes, err := afero.ReadFile(v.fs, metadataPath)
if err != nil {
return fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err)
}
var metadata UnlockerMetadata
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
return fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err)
}
unlockerDirPath = filepath.Join(unlockersDir, file.Name())
// Create the appropriate unlocker instance
var tempUnlocker secret.Unlocker
switch metadata.Type {
case "passphrase":
tempUnlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDirPath, metadata)
case "pgp":
tempUnlocker = secret.NewPGPUnlocker(v.fs, unlockerDirPath, metadata)
case "keychain":
tempUnlocker = secret.NewKeychainUnlocker(v.fs, unlockerDirPath, metadata)
default:
continue
}
// Check if this unlocker's ID matches
if tempUnlocker.GetID() == unlockerID {
unlocker = tempUnlocker
break
}
}
} }
if unlocker == nil { if unlocker == nil {
@ -250,57 +277,10 @@ func (v *Vault) SelectUnlocker(unlockerID string) error {
// Find the unlocker directory by ID // Find the unlocker directory by ID
unlockersDir := filepath.Join(vaultDir, "unlockers.d") unlockersDir := filepath.Join(vaultDir, "unlockers.d")
// List directories in unlockers.d to find the unlocker // Find the unlocker by ID
files, err := afero.ReadDir(v.fs, unlockersDir) _, targetUnlockerDir, err := v.findUnlockerByID(unlockersDir, unlockerID)
if err != nil { if err != nil {
return fmt.Errorf("failed to read unlockers directory: %w", err) return err
}
var targetUnlockerDir string
for _, file := range files {
if file.IsDir() {
// Read metadata file
metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json")
exists, err := afero.Exists(v.fs, metadataPath)
if err != nil {
return fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err)
}
if !exists {
// Skip directories without metadata - they might not be unlockers
continue
}
metadataBytes, err := afero.ReadFile(v.fs, metadataPath)
if err != nil {
return fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err)
}
var metadata UnlockerMetadata
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
return fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err)
}
unlockerDirPath := filepath.Join(unlockersDir, file.Name())
// Create the appropriate unlocker instance
var tempUnlocker secret.Unlocker
switch metadata.Type {
case "passphrase":
tempUnlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDirPath, metadata)
case "pgp":
tempUnlocker = secret.NewPGPUnlocker(v.fs, unlockerDirPath, metadata)
case "keychain":
tempUnlocker = secret.NewKeychainUnlocker(v.fs, unlockerDirPath, metadata)
default:
continue
}
// Check if this unlocker's ID matches
if tempUnlocker.GetID() == unlockerID {
targetUnlockerDir = unlockerDirPath
break
}
}
} }
if targetUnlockerDir == "" { if targetUnlockerDir == "" {
@ -337,7 +317,8 @@ func (v *Vault) SelectUnlocker(unlockerID string) error {
} }
// CreatePassphraseUnlocker creates a new passphrase-protected unlocker // CreatePassphraseUnlocker creates a new passphrase-protected unlocker
func (v *Vault) CreatePassphraseUnlocker(passphrase string) (*secret.PassphraseUnlocker, error) { // The passphrase must be provided as a LockedBuffer for security
func (v *Vault) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*secret.PassphraseUnlocker, error) {
vaultDir, err := v.GetDirectory() vaultDir, err := v.GetDirectory()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get vault directory: %w", err) return nil, fmt.Errorf("failed to get vault directory: %w", err)
@ -364,8 +345,8 @@ func (v *Vault) CreatePassphraseUnlocker(passphrase string) (*secret.PassphraseU
} }
// Encrypt private key with passphrase // Encrypt private key with passphrase
privKeyData := []byte(unlockerIdentity.String()) privKeyStr := unlockerIdentity.String()
encryptedPrivKey, err := secret.EncryptWithPassphrase(privKeyData, passphrase) encryptedPrivKey, err := secret.EncryptWithPassphrase([]byte(privKeyStr), passphrase)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encrypt unlocker private key: %w", err) return nil, fmt.Errorf("failed to encrypt unlocker private key: %w", err)
} }
@ -401,8 +382,10 @@ func (v *Vault) CreatePassphraseUnlocker(passphrase string) (*secret.PassphraseU
return nil, fmt.Errorf("failed to get long-term key: %w", err) return nil, fmt.Errorf("failed to get long-term key: %w", err)
} }
ltPrivKey := []byte(ltIdentity.String()) ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKey, unlockerIdentity.Recipient()) defer ltPrivKeyBuffer.Destroy()
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerIdentity.Recipient())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encrypt long-term private key: %w", err) return nil, fmt.Errorf("failed to encrypt long-term private key: %w", err)
} }

View File

@ -76,12 +76,14 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
metadata, err := LoadVaultMetadata(v.fs, vaultDir) metadata, err := LoadVaultMetadata(v.fs, vaultDir)
if err != nil { if err != nil {
secret.Debug("Failed to load vault metadata", "error", err, "vault_name", v.Name) secret.Debug("Failed to load vault metadata", "error", err, "vault_name", v.Name)
return nil, fmt.Errorf("failed to load vault metadata: %w", err) return nil, fmt.Errorf("failed to load vault metadata: %w", err)
} }
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, metadata.DerivationIndex) ltIdentity, err := agehd.DeriveIdentity(envMnemonic, metadata.DerivationIndex)
if err != nil { if err != nil {
secret.Debug("Failed to derive long-term key from mnemonic", "error", err, "vault_name", v.Name) secret.Debug("Failed to derive long-term key from mnemonic", "error", err, "vault_name", v.Name)
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err) return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
} }
@ -117,6 +119,7 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
unlocker, err := v.GetCurrentUnlocker() unlocker, err := v.GetCurrentUnlocker()
if err != nil { if err != nil {
secret.Debug("Failed to get current unlocker", "error", err, "vault_name", v.Name) secret.Debug("Failed to get current unlocker", "error", err, "vault_name", v.Name)
return nil, fmt.Errorf("failed to get current unlocker: %w", err) return nil, fmt.Errorf("failed to get current unlocker: %w", err)
} }
@ -130,6 +133,7 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
unlockerIdentity, err := unlocker.GetIdentity() unlockerIdentity, err := unlocker.GetIdentity()
if err != nil { if err != nil {
secret.Debug("Failed to get unlocker identity", "error", err, "unlocker_type", unlocker.GetType()) secret.Debug("Failed to get unlocker identity", "error", err, "unlocker_type", unlocker.GetType())
return nil, fmt.Errorf("failed to get unlocker identity: %w", err) return nil, fmt.Errorf("failed to get unlocker identity: %w", err)
} }
@ -141,6 +145,7 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
encryptedLtPrivKey, err := afero.ReadFile(v.fs, encryptedLtPrivKeyPath) encryptedLtPrivKey, err := afero.ReadFile(v.fs, encryptedLtPrivKeyPath)
if err != nil { if err != nil {
secret.Debug("Failed to read encrypted long-term private key", "error", err, "path", encryptedLtPrivKeyPath) secret.Debug("Failed to read encrypted long-term private key", "error", err, "path", encryptedLtPrivKeyPath)
return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err) return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err)
} }
@ -155,6 +160,7 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
ltPrivKeyData, err := secret.DecryptWithIdentity(encryptedLtPrivKey, unlockerIdentity) ltPrivKeyData, err := secret.DecryptWithIdentity(encryptedLtPrivKey, unlockerIdentity)
if err != nil { if err != nil {
secret.Debug("Failed to decrypt long-term private key", "error", err, "unlocker_type", unlocker.GetType()) secret.Debug("Failed to decrypt long-term private key", "error", err, "unlocker_type", unlocker.GetType())
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err) return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
} }
@ -169,6 +175,7 @@ func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
ltIdentity, err := age.ParseX25519Identity(string(ltPrivKeyData)) ltIdentity, err := age.ParseX25519Identity(string(ltPrivKeyData))
if err != nil { if err != nil {
secret.Debug("Failed to parse long-term private key", "error", err, "vault_name", v.Name) secret.Debug("Failed to parse long-term private key", "error", err, "vault_name", v.Name)
return nil, fmt.Errorf("failed to parse long-term private key: %w", err) return nil, fmt.Errorf("failed to parse long-term private key: %w", err)
} }

View File

@ -6,6 +6,7 @@ import (
"git.eeqj.de/sneak/secret/internal/secret" "git.eeqj.de/sneak/secret/internal/secret"
"git.eeqj.de/sneak/secret/pkg/agehd" "git.eeqj.de/sneak/secret/pkg/agehd"
"github.com/awnumar/memguard"
"github.com/spf13/afero" "github.com/spf13/afero"
) )
@ -121,8 +122,13 @@ func TestVaultOperations(t *testing.T) {
// Now add a secret // Now add a secret
secretName := "test/secret" secretName := "test/secret"
secretValue := []byte("test-secret-value") secretValue := []byte("test-secret-value")
expectedValue := make([]byte, len(secretValue))
copy(expectedValue, secretValue)
err = vlt.AddSecret(secretName, secretValue, false) secretBuffer := memguard.NewBufferFromBytes(secretValue)
defer secretBuffer.Destroy()
err = vlt.AddSecret(secretName, secretBuffer, false)
if err != nil { if err != nil {
t.Fatalf("Failed to add secret: %v", err) t.Fatalf("Failed to add secret: %v", err)
} }
@ -151,8 +157,8 @@ func TestVaultOperations(t *testing.T) {
t.Fatalf("Failed to get secret: %v", err) t.Fatalf("Failed to get secret: %v", err)
} }
if string(retrievedValue) != string(secretValue) { if string(retrievedValue) != string(expectedValue) {
t.Errorf("Expected secret value '%s', got '%s'", string(secretValue), string(retrievedValue)) t.Errorf("Expected secret value '%s', got '%s'", string(expectedValue), string(retrievedValue))
} }
}) })
@ -172,7 +178,9 @@ func TestVaultOperations(t *testing.T) {
} }
// Create a passphrase unlocker // Create a passphrase unlocker
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker("test-passphrase") passphraseBuffer := memguard.NewBufferFromBytes([]byte("test-passphrase"))
defer passphraseBuffer.Destroy()
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
if err != nil { if err != nil {
t.Fatalf("Failed to create passphrase unlocker: %v", err) t.Fatalf("Failed to create passphrase unlocker: %v", err)
} }

View File

@ -21,11 +21,11 @@ import (
) )
const ( const (
purpose = uint32(83696968) // fixed by BIP-85 ("bip") purpose = uint32(83696968) // fixed by BIP-85 ("bip")
vendorID = uint32(592366788) // berlin.sneak vendorID = uint32(592366788) // berlin.sneak
appID = uint32(733482323) // secret appID = uint32(733482323) // secret
hrp = "age-secret-key-" // Bech32 HRP used by age hrp = "age-secret-key-" // Bech32 HRP used by age
x25519KeySize = 32 // 256-bit key size for X25519 x25519KeySize = 32 // 256-bit key size for X25519
) )
// clamp applies RFC-7748 clamping to a 32-byte scalar. // clamp applies RFC-7748 clamping to a 32-byte scalar.

View File

@ -1,3 +1,4 @@
// Package bip85 implements BIP85 deterministic entropy derivation.
package bip85 package bip85
import ( import (
@ -27,14 +28,16 @@ const (
// BIP85_KEY_HMAC_KEY is the HMAC key used for deriving the entropy // BIP85_KEY_HMAC_KEY is the HMAC key used for deriving the entropy
BIP85_KEY_HMAC_KEY = "bip-entropy-from-k" //nolint:revive // ALL_CAPS used for BIP85 constants BIP85_KEY_HMAC_KEY = "bip-entropy-from-k" //nolint:revive // ALL_CAPS used for BIP85 constants
// Application numbers // AppBIP39 is the application number for BIP39 mnemonics
AppBIP39 = 39 // BIP39 mnemonics AppBIP39 = 39
AppHDWIF = 2 // WIF for Bitcoin Core // AppHDWIF is the application number for WIF (Wallet Import Format) for Bitcoin Core
AppXPRV = 32 // Extended private key AppHDWIF = 2
APP_HEX = 128169 //nolint:revive // ALL_CAPS used for BIP85 constants // AppXPRV is the application number for extended private key
APP_PWD64 = 707764 // Base64 passwords //nolint:revive // ALL_CAPS used for BIP85 constants AppXPRV = 32
AppPWD85 = 707785 // Base85 passwords APP_HEX = 128169 //nolint:revive // ALL_CAPS used for BIP85 constants
APP_RSA = 828365 //nolint:revive // ALL_CAPS used for BIP85 constants APP_PWD64 = 707764 // Base64 passwords //nolint:revive // ALL_CAPS used for BIP85 constants
AppPWD85 = 707785 // Base85 passwords
APP_RSA = 828365 //nolint:revive // ALL_CAPS used for BIP85 constants
) )
// Version bytes for extended keys // Version bytes for extended keys