Compare commits

3 Commits

Author SHA1 Message Date
9815b8e265 Merge branch 'main' into fix/issue-3 2026-02-09 02:43:52 +01:00
clawbot
4c7ed3fc51 test: add test for getLongTermPrivateKey derivation index
Verifies that getLongTermPrivateKey reads the derivation index from
vault metadata instead of using hardcoded index 0. Test creates a
mock vault with DerivationIndex=5 and confirms the derived key
matches index 5.
2026-02-08 17:21:31 -08:00
clawbot
2a4ceb2045 fix: use vault derivation index in getLongTermPrivateKey instead of hardcoded 0
Previously, getLongTermPrivateKey() always used derivation index 0 when
deriving the long-term key from a mnemonic. This caused wrong key
derivation for vaults with index > 0 (second+ vault from same mnemonic),
leading to silent data corruption in keychain unlocker creation.

Now reads the vault's actual DerivationIndex from vault-metadata.json.
2026-02-08 12:03:06 -08:00
2 changed files with 87 additions and 2 deletions

View File

@@ -0,0 +1,68 @@
package secret
import (
"encoding/json"
"testing"
"time"
"github.com/awnumar/memguard"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.eeqj.de/sneak/secret/pkg/agehd"
)
// mockVault implements VaultInterface for testing getLongTermPrivateKey
type mockVaultForDerivation struct {
dir string
fs afero.Fs
}
func (m *mockVaultForDerivation) GetDirectory() (string, error) { return m.dir, nil }
func (m *mockVaultForDerivation) GetName() string { return "test-vault" }
func (m *mockVaultForDerivation) GetFilesystem() afero.Fs { return m.fs }
// Stubs for unused interface methods
func (m *mockVaultForDerivation) AddSecret(name string, value *memguard.LockedBuffer, force bool) error {
return nil
}
func (m *mockVaultForDerivation) GetCurrentUnlocker() (Unlocker, error) { return nil, nil }
func (m *mockVaultForDerivation) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*PassphraseUnlocker, error) {
return nil, nil
}
func TestGetLongTermPrivateKeyUsesVaultDerivationIndex(t *testing.T) {
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
// Derive keys at index 0 and index 5 — they must differ
key0, err := agehd.DeriveIdentity(testMnemonic, 0)
require.NoError(t, err)
key5, err := agehd.DeriveIdentity(testMnemonic, 5)
require.NoError(t, err)
require.NotEqual(t, key0.String(), key5.String(), "different derivation indices must produce different keys")
// Set up in-memory filesystem with vault metadata at index 5
fs := afero.NewMemMapFs()
vaultDir := "/test-vault"
require.NoError(t, fs.MkdirAll(vaultDir, 0o700))
metadata := VaultMetadata{
CreatedAt: time.Now(),
DerivationIndex: 5,
}
metaBytes, err := json.Marshal(metadata)
require.NoError(t, err)
require.NoError(t, afero.WriteFile(fs, vaultDir+"/vault-metadata.json", metaBytes, 0o600))
// Set the mnemonic env var
t.Setenv(EnvMnemonic, testMnemonic)
vault := &mockVaultForDerivation{dir: vaultDir, fs: fs}
result, err := getLongTermPrivateKey(fs, vault)
require.NoError(t, err)
defer result.Destroy()
// The derived key should match index 5, NOT index 0
assert.Equal(t, key5.String(), string(result.Bytes()), "getLongTermPrivateKey should use vault's derivation index (5), not hardcoded 0")
assert.NotEqual(t, key0.String(), string(result.Bytes()), "getLongTermPrivateKey must not use hardcoded index 0")
}

View File

@@ -251,8 +251,25 @@ func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) (*memguard.LockedB
// Check if mnemonic is available in environment variable
envMnemonic := os.Getenv(EnvMnemonic)
if envMnemonic != "" {
// Use mnemonic directly to derive long-term key
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0)
// Read vault metadata to get the correct derivation index
vaultDir, err := vault.GetDirectory()
if err != nil {
return nil, fmt.Errorf("failed to get vault directory: %w", err)
}
metadataPath := filepath.Join(vaultDir, "vault-metadata.json")
metadataBytes, err := afero.ReadFile(fs, metadataPath)
if err != nil {
return nil, fmt.Errorf("failed to read vault metadata: %w", err)
}
var metadata VaultMetadata
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
return nil, fmt.Errorf("failed to parse vault metadata: %w", err)
}
// Use mnemonic with the vault's actual derivation index
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, metadata.DerivationIndex)
if err != nil {
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
}