diff --git a/TODO.md b/TODO.md index cb24a4a..928ac12 100644 --- a/TODO.md +++ b/TODO.md @@ -9,7 +9,7 @@ prioritized from most critical (top) to least critical (bottom). ### Functions accepting bare []byte for sensitive data - [x] **1. Secret.Save accepts unprotected data**: `internal/secret/secret.go:67` - `Save(value []byte, force bool)` - ✓ REMOVED - deprecated function deleted - [x] **2. EncryptWithPassphrase accepts unprotected data**: `internal/secret/crypto.go:73` - `EncryptWithPassphrase(data []byte, passphrase *memguard.LockedBuffer)` - ✓ FIXED - now accepts LockedBuffer for data -- [ ] **3. storeInKeychain accepts unprotected data**: `internal/secret/keychainunlocker.go:469` - `storeInKeychain(itemName string, data []byte)` - stores secrets in keychain with unprotected data +- [x] **3. storeInKeychain accepts unprotected data**: `internal/secret/keychainunlocker.go:469` - `storeInKeychain(itemName string, data []byte)` - ✓ FIXED - now accepts LockedBuffer for data - [ ] **4. gpgEncryptDefault accepts unprotected data**: `internal/secret/pgpunlocker.go:351` - `gpgEncryptDefault(data []byte, keyID string)` - encrypts unprotected data ### Functions returning unprotected secrets diff --git a/internal/secret/pgpunlock_test.go b/internal/secret/pgpunlock_test.go index 308a9c2..6ba8054 100644 --- a/internal/secret/pgpunlock_test.go +++ b/internal/secret/pgpunlock_test.go @@ -45,7 +45,10 @@ pinentry-mode loopback origDecryptFunc := secret.GPGDecryptFunc // Set custom GPG functions for this test - secret.GPGEncryptFunc = func(data []byte, keyID string) ([]byte, error) { + secret.GPGEncryptFunc = func(data *memguard.LockedBuffer, keyID string) ([]byte, error) { + if data == nil { + return nil, fmt.Errorf("data buffer is nil") + } cmd := exec.Command("gpg", "--homedir", gnupgHomeDir, "--batch", @@ -60,7 +63,7 @@ pinentry-mode loopback var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr - cmd.Stdin = bytes.NewReader(data) + cmd.Stdin = bytes.NewReader(data.Bytes()) if err := cmd.Run(); err != nil { return nil, fmt.Errorf("GPG encryption failed: %w\nStderr: %s", err, stderr.String()) @@ -444,8 +447,9 @@ Passphrase: ` + testPassphrase + ` } // GPG encrypt the private key using our custom encrypt function - privKeyData := []byte(ageIdentity.String()) - encryptedOutput, err := secret.GPGEncryptFunc(privKeyData, keyID) + privKeyBuffer := memguard.NewBufferFromBytes([]byte(ageIdentity.String())) + defer privKeyBuffer.Destroy() + encryptedOutput, err := secret.GPGEncryptFunc(privKeyBuffer, keyID) if err != nil { t.Fatalf("Failed to encrypt with GPG: %v", err) } diff --git a/internal/secret/pgpunlocker.go b/internal/secret/pgpunlocker.go index d252477..177d8cf 100644 --- a/internal/secret/pgpunlocker.go +++ b/internal/secret/pgpunlocker.go @@ -20,7 +20,8 @@ import ( var ( // GPGEncryptFunc is the function used for GPG encryption // Can be overridden in tests to provide a non-interactive implementation - GPGEncryptFunc = gpgEncryptDefault //nolint:gochecknoglobals // Required for test mocking + //nolint:gochecknoglobals // Required for test mocking + GPGEncryptFunc func(data *memguard.LockedBuffer, keyID string) ([]byte, error) = gpgEncryptDefault // GPGDecryptFunc is the function used for GPG decryption // Can be overridden in tests to provide a non-interactive implementation @@ -253,7 +254,7 @@ func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnloc agePrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(ageIdentity.String())) defer agePrivateKeyBuffer.Destroy() - encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBuffer.Bytes(), gpgKeyID) + encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBuffer, gpgKeyID) if err != nil { return nil, fmt.Errorf("failed to encrypt age private key with GPG: %w", err) } @@ -348,13 +349,16 @@ func checkGPGAvailable() error { } // gpgEncryptDefault is the default implementation of GPG encryption -func gpgEncryptDefault(data []byte, keyID string) ([]byte, error) { +func gpgEncryptDefault(data *memguard.LockedBuffer, keyID string) ([]byte, error) { + if data == nil { + return nil, fmt.Errorf("data buffer is nil") + } if err := validateGPGKeyID(keyID); err != nil { return nil, fmt.Errorf("invalid GPG key ID: %w", err) } cmd := exec.Command("gpg", "--trust-model", "always", "--armor", "--encrypt", "-r", keyID) - cmd.Stdin = strings.NewReader(string(data)) + cmd.Stdin = strings.NewReader(data.String()) output, err := cmd.Output() if err != nil {