Compare commits
99 Commits
6a8bd3388c
...
fix-memory
| Author | SHA1 | Date | |
|---|---|---|---|
| 7596049828 | |||
| d3ca006886 | |||
| f91281e991 | |||
| 7c5e78db17 | |||
| 8e374b3d24 | |||
| c9774e89e0 | |||
| f9938135c6 | |||
| 386a27c0b6 | |||
| 080a3dc253 | |||
| 811ddee3b7 | |||
| 4e242c3491 | |||
| 54fce0f187 | |||
| 93a32217e0 | |||
| 95ba80f618 | |||
| d710323bd0 | |||
| 38b450cbcf | |||
| 6fe49344e2 | |||
| 6e01ae6002 | |||
| 11e43542cf | |||
| 2256a37b72 | |||
| 533133486c | |||
| eb19fa4b97 | |||
| 5ed850196b | |||
| be1f323a09 | |||
| bdcddadf90 | |||
| 4062242063 | |||
| abcc7b6c3a | |||
| 9e35bf21a3 | |||
| 2a1e0337fd | |||
| dcc15008cd | |||
| dd2e95f8af | |||
| c450e1c13d | |||
| c6935d8f0f | |||
| 5d973f76ec | |||
| fd125c5fe1 | |||
| 08a42b16dd | |||
| b736789ecb | |||
| f569bc55ea | |||
| 9231409c5c | |||
| 0d140b4636 | |||
| 9e74b34b5d | |||
| 47afe117f4 | |||
| 4fe49ca8d0 | |||
| 8ca7796d04 | |||
| dcab84249f | |||
| e5b18202f3 | |||
| efc9456948 | |||
| c52430554a | |||
| fd7ab06fb1 | |||
| 434b73d834 | |||
| 985d79d3c0 | |||
| 004dce5472 | |||
| 0b31fba663 | |||
| 6958b2a6e2 | |||
| fd4194503c | |||
| a1800a8e88 | |||
| 03e0ee2f95 | |||
| 9adf0c0803 | |||
| e9d03987f9 | |||
| b0e3cdd3d0 | |||
| 2e3fc475cf | |||
| 1f89fce21b | |||
| 512b742c46 | |||
| 02be4b2a55 | |||
| e036d280c0 | |||
| ac81023ea0 | |||
| d76a4cbf4d | |||
| fbda2d91af | |||
| f59ee4d2d6 | |||
| 0bf8e71b52 | |||
| 34d6870e6a | |||
| 1a1b11c5a3 | |||
| 85d7ef21eb | |||
| a4d7225036 | |||
| 8dc2e9d748 | |||
| 8cc15fde3d | |||
| ddb395901b | |||
| c33385be6c | |||
| e95609ce69 | |||
| 345709a306 | |||
| 4b59d6fb82 | |||
| 5ca657c104 | |||
| bbaf1cbd97 | |||
| f838c8cb98 | |||
| 43767c725f | |||
| b26794e21a | |||
| 7dc14da4af | |||
| 3d90388b5b | |||
| 8c08c2e748 | |||
| ee49ace397 | |||
| 1b8ea9695b | |||
| 9f0f5cc8a1 | |||
| 89a8af2aa1 | |||
| 659b5ba508 | |||
| bb82d10f91 | |||
| c526b68f58 | |||
| 2443256338 | |||
| 354681b298 | |||
| efedbe405f |
30
.claude/settings.local.json
Normal file
30
.claude/settings.local.json
Normal file
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(go mod why:*)",
|
||||
"Bash(go list:*)",
|
||||
"Bash(~/go/bin/govulncheck -mode=module .)",
|
||||
"Bash(go test:*)",
|
||||
"Bash(grep:*)",
|
||||
"Bash(rg:*)",
|
||||
"Bash(find:*)",
|
||||
"Bash(make test:*)",
|
||||
"Bash(go doc:*)",
|
||||
"Bash(make fmt:*)",
|
||||
"Bash(make:*)",
|
||||
"Bash(golangci-lint run:*)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(gofumpt:*)",
|
||||
"Bash(git stash:*)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(git push:*)",
|
||||
"Bash(golangci-lint:*)",
|
||||
"Bash(git checkout:*)",
|
||||
"Bash(ls:*)",
|
||||
"WebFetch(domain:golangci-lint.run)",
|
||||
"Bash(go:*)",
|
||||
"WebFetch(domain:pkg.go.dev)"
|
||||
],
|
||||
"deny": []
|
||||
}
|
||||
}
|
||||
3
.cursorrules
Normal file
3
.cursorrules
Normal file
@@ -0,0 +1,3 @@
|
||||
EXTREMELY IMPORTANT: Read and follow the policies, procedures, and
|
||||
instructions in the `AGENTS.md` file in the root of the repository. Make
|
||||
sure you follow *all* of the instructions meticulously.
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -1 +1,7 @@
|
||||
secret
|
||||
.DS_Store
|
||||
**/.DS_Store
|
||||
/secret
|
||||
*.log
|
||||
cli.test
|
||||
vault.test
|
||||
*.test
|
||||
|
||||
91
.golangci.yml
Normal file
91
.golangci.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
go: "1.24"
|
||||
tests: false
|
||||
|
||||
linters:
|
||||
enable:
|
||||
# Additional linters requested
|
||||
- testifylint # Checks usage of github.com/stretchr/testify
|
||||
- usetesting # usetesting is an analyzer that detects using os.Setenv instead of t.Setenv since Go 1.17
|
||||
- tagliatelle # Checks the struct tags
|
||||
- nlreturn # nlreturn checks for a new line before return and branch statements
|
||||
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value
|
||||
- nestif # Reports deeply nested if statements
|
||||
- mnd # An analyzer to detect magic numbers
|
||||
- lll # Reports long lines
|
||||
- intrange # intrange is a linter to find places where for loops could make use of an integer range
|
||||
- gochecknoglobals # Check that no global variables exist
|
||||
|
||||
# Default/existing linters that are commonly useful
|
||||
- govet
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- unused
|
||||
- ineffassign
|
||||
- misspell
|
||||
- revive
|
||||
- gosec
|
||||
- unconvert
|
||||
- unparam
|
||||
|
||||
linters-settings:
|
||||
lll:
|
||||
line-length: 120
|
||||
|
||||
mnd:
|
||||
# List of enabled checks, see https://github.com/tommy-muehle/go-mnd/#checks for description.
|
||||
checks:
|
||||
- argument
|
||||
- case
|
||||
- condition
|
||||
- operation
|
||||
- return
|
||||
- assign
|
||||
ignored-numbers:
|
||||
- '0'
|
||||
- '1'
|
||||
- '2'
|
||||
- '8'
|
||||
- '16'
|
||||
- '40' # GPG fingerprint length
|
||||
- '64'
|
||||
- '128'
|
||||
- '256'
|
||||
- '512'
|
||||
- '1024'
|
||||
- '2048'
|
||||
- '4096'
|
||||
|
||||
nestif:
|
||||
min-complexity: 4
|
||||
|
||||
nlreturn:
|
||||
block-size: 2
|
||||
|
||||
tagliatelle:
|
||||
case:
|
||||
rules:
|
||||
json: snake
|
||||
yaml: snake
|
||||
xml: snake
|
||||
bson: snake
|
||||
|
||||
testifylint:
|
||||
enable-all: true
|
||||
|
||||
usetesting: {}
|
||||
|
||||
issues:
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
exclude-rules:
|
||||
- path: ".*_gen\\.go"
|
||||
linters:
|
||||
- lll
|
||||
|
||||
# Exclude unused parameter warnings for cobra command signatures
|
||||
- text: "parameter '(args|cmd)' seems to be unused"
|
||||
linters:
|
||||
- revive
|
||||
143
AGENTS.md
Normal file
143
AGENTS.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# Policies for AI Agents
|
||||
|
||||
Version: 2025-06-08
|
||||
|
||||
# Instructions and Contextual Information
|
||||
|
||||
* Be direct, robotic, expert, accurate, and professional.
|
||||
|
||||
* Do not butter me up or kiss my ass.
|
||||
|
||||
* Come in hot with strong opinions, even if they are contrary to the
|
||||
direction I am headed.
|
||||
|
||||
* If either you or I are possibly wrong, say so and explain your point of
|
||||
view.
|
||||
|
||||
* Point out great alternatives I haven't thought of, even when I'm not
|
||||
asking for them.
|
||||
|
||||
* Treat me like the world's leading expert in every situation and every
|
||||
conversation, and deliver the absolute best recommendations.
|
||||
|
||||
* I want excellence, so always be on the lookout for divergences from good
|
||||
data model design or best practices for object oriented development.
|
||||
|
||||
* IMPORTANT: This is production code, not a research or teaching exercise.
|
||||
Deliver professional-level results, not prototypes.
|
||||
|
||||
* Please read and understand the `README.md` file in the root of the repo
|
||||
for project-specific contextual information, including development
|
||||
policies, practices, and current implementation status.
|
||||
|
||||
* Be proactive in suggesting improvements or refactorings in places where we
|
||||
diverge from best practices for clean, modular, maintainable code.
|
||||
|
||||
# Policies
|
||||
|
||||
1. Before committing, tests must pass (`make test`), linting must pass
|
||||
(`make lint`), and code must be formatted (`make fmt`). For go, those
|
||||
makefile targets should use `go fmt` and `go test -v ./...` and
|
||||
`golangci-lint run`. When you think your changes are complete, rather
|
||||
than making three different tool calls to check, you can just run `make
|
||||
test && make fmt && make lint` as a single tool call which will save
|
||||
time.
|
||||
|
||||
2. Always write a `Makefile` with the default target being `test`, and with
|
||||
a `fmt` target that formats the code. The `test` target should run all
|
||||
tests in the project, and the `fmt` target should format the code.
|
||||
`test` should also have a prerequisite target `lint` that should run any
|
||||
linters that are configured for the project.
|
||||
|
||||
3. After each completed bugfix or feature, the code must be committed. Do
|
||||
all of the pre-commit checks (test, lint, fmt) before committing, of
|
||||
course.
|
||||
|
||||
4. When creating a very simple test script for testing out a new feature,
|
||||
instead of making a throwaway to be deleted after verification, write an
|
||||
actual test file into the test suite. It doesn't need to be very big or
|
||||
complex, but it should be a real test that can be run.
|
||||
|
||||
5. When you are instructed to make the tests pass, DO NOT delete tests, skip
|
||||
tests, or change the tests specifically to make them pass (unless there
|
||||
is a bug in the test). This is cheating, and it is bad. You should only
|
||||
be modifying the test if it is incorrect or if the test is no longer
|
||||
relevant. In almost all cases, you should be fixing the code that is
|
||||
being tested, or updating the tests to match a refactored implementation.
|
||||
|
||||
6. When dealing with dates and times or timestamps, always use, display, and
|
||||
store UTC. Set the local timezone to UTC on startup. If the user needs
|
||||
to see the time in a different timezone, store the user's timezone in a
|
||||
separate field and convert the UTC time to the user's timezone when
|
||||
displaying it. For internal use and internal applications and
|
||||
administrative purposes, always display UTC.
|
||||
|
||||
7. Always write tests, even if they are extremely simple and just check for
|
||||
correct syntax (ability to compile/import). If you are writing a new
|
||||
feature, write a test for it. You don't need to target complete
|
||||
coverage, but you should at least test any new functionality you add. If
|
||||
you are fixing a bug, write a test first that reproduces the bug, and
|
||||
then fix the bug in the code.
|
||||
|
||||
8. When implementing new features, be aware of potential side-effects (such
|
||||
as state files on disk, data in the database, etc.) and ensure that it is
|
||||
possible to mock or stub these side-effects in tests.
|
||||
|
||||
9. Always use structured logging. Log any relevant state/context with the
|
||||
messages (but do not log secrets). If stdout is not a terminal, output
|
||||
the structured logs in jsonl format.
|
||||
|
||||
10. Avoid using bare strings or numbers in code, especially if they appear
|
||||
anywhere more than once. Always define a constant (usually at the top
|
||||
of the file) and give it a descriptive name, then use that constant in
|
||||
the code instead of the bare string or number.
|
||||
|
||||
11. You do not need to summarize your changes in the chat after making them.
|
||||
Making the changes and committing them is sufficient. If anything out
|
||||
of the ordinary happened, please explain it, but in the normal case
|
||||
where you found and fixed the bug, or implemented the feature, there is
|
||||
no need for the end-of-change summary.
|
||||
|
||||
12. Do not create additional files in the root directory of the project
|
||||
without asking permission first. Configuration files, documentation, and
|
||||
build files are acceptable in the root, but source code and other files
|
||||
should be organized in appropriate subdirectories.
|
||||
|
||||
## Python-Specific Guidelines
|
||||
|
||||
1. **Type Annotations (UP006)**: Use built-in collection types directly for type annotations instead of importing from `typing`. This avoids the UP006 linter error.
|
||||
|
||||
**Good (modern Python 3.9+):**
|
||||
```python
|
||||
def process_items(items: list[str]) -> dict[str, int]:
|
||||
counts: dict[str, int] = {}
|
||||
return counts
|
||||
```
|
||||
|
||||
**Avoid (triggers UP006):**
|
||||
```python
|
||||
from typing import List, Dict
|
||||
|
||||
def process_items(items: List[str]) -> Dict[str, int]:
|
||||
counts: Dict[str, int] = {}
|
||||
return counts
|
||||
```
|
||||
|
||||
For optional types, use the `|` operator instead of `Union`:
|
||||
```python
|
||||
# Good
|
||||
def get_value(key: str) -> str | None:
|
||||
return None
|
||||
|
||||
# Avoid
|
||||
from typing import Optional, Union
|
||||
def get_value(key: str) -> Optional[str]:
|
||||
return None
|
||||
```
|
||||
|
||||
2. **Import Organization**: Follow the standard Python import order:
|
||||
- Standard library imports
|
||||
- Third-party imports
|
||||
- Local application imports
|
||||
|
||||
Each group should be separated by a blank line.
|
||||
28
CLAUDE.md
Normal file
28
CLAUDE.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# Rules
|
||||
|
||||
Read the rules in AGENTS.md and follow them.
|
||||
|
||||
# Memory
|
||||
|
||||
* Claude is an inanimate tool. The spam that Claude attempts to insert into
|
||||
commit messages (which it erroneously refers to as "attribution") is not
|
||||
attribution, as I am the sole author of code created using Claude. It is
|
||||
corporate advertising for Anthropic and is therefore completely
|
||||
unacceptable in commit messages.
|
||||
|
||||
* Tests should always be run before committing code. No commits should be
|
||||
made that do not pass tests.
|
||||
|
||||
* Code should always be formatted before committing. Do not commit
|
||||
unformatted code.
|
||||
|
||||
* Code should always be linted before committing. Do not commit
|
||||
unlinted code.
|
||||
|
||||
* The test suite is fast and local. When running tests, don't run
|
||||
individual parts of the test suite, always run the whole thing by running
|
||||
"make test".
|
||||
|
||||
* Do not stop working on a task until you have reached the definition of
|
||||
done provided to you in the initial instruction. Don't do part or most of
|
||||
the work, do all of the work until the criteria for done are met.
|
||||
13
LICENSE
Normal file
13
LICENSE
Normal file
@@ -0,0 +1,13 @@
|
||||
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
|
||||
Version 2, December 2004
|
||||
|
||||
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim or modified
|
||||
copies of this license document, and changing it is allowed as long
|
||||
as the name is changed.
|
||||
|
||||
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
|
||||
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
|
||||
|
||||
0. You just DO WHAT THE FUCK YOU WANT TO.
|
||||
20
Makefile
20
Makefile
@@ -1,10 +1,26 @@
|
||||
default: test
|
||||
default: check
|
||||
|
||||
build: ./secret
|
||||
|
||||
# Simple build (no code signing needed)
|
||||
./secret:
|
||||
go build -v -o $@ cmd/secret/main.go
|
||||
|
||||
vet:
|
||||
go vet ./...
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
go test ./... || go test -v ./...
|
||||
|
||||
fmt:
|
||||
go fmt ./...
|
||||
|
||||
lint:
|
||||
golangci-lint run --timeout 5m
|
||||
|
||||
# Check all code quality (build + vet + lint + unit tests)
|
||||
check: ./secret vet lint test
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
rm -f ./secret
|
||||
|
||||
461
README.md
461
README.md
@@ -1,136 +1,375 @@
|
||||
# architecture
|
||||
# Secret - Hierarchical Secret Manager
|
||||
|
||||
* `secret vault` allows you to change 'vaults'. vaults are just a universe
|
||||
of secrets (and a single associated long-term key). you can have multiple
|
||||
vaults, each with its own long-term key and secrets. this is useful for
|
||||
separating work and personal secrets, or for separating different projects
|
||||
with different long-term keys.
|
||||
Secret is a modern, secure command-line secret manager that implements a hierarchical key architecture for storing and managing sensitive data. It supports multiple vaults, various unlock mechanisms, and provides secure storage using the Age encryption library.
|
||||
|
||||
* `secret vault list`
|
||||
* `secret vault create <name>` creates a new profile
|
||||
* `secret vault select <name>` selects (switches to) a profile
|
||||
## Core Architecture
|
||||
|
||||
the first and initial vault is titled `default`.
|
||||
### Three-Layer Key Hierarchy
|
||||
|
||||
* `secret init` initializes a new vault. this will create a new profile and
|
||||
generate a new long-term keypair. the long-term keypair is used to
|
||||
encrypt and decrypt secrets. the long-term keypair is stored in the
|
||||
vault. the private key for the vault is encrypted to a short-term
|
||||
keypair. the short-term keypair private key is encrypted to a passphrase.
|
||||
to generate the long-term keypair, a random bip32 seed phrase is
|
||||
generated, then the process proceeds exactly as `secret import private`.
|
||||
Secret implements a sophisticated three-layer key architecture:
|
||||
|
||||
the randomly generated bip32 seed phrase is shown to the user.
|
||||
1. **Long-term Keys**: Derived from BIP39 mnemonic phrases, these provide the foundation for all encryption
|
||||
2. **Unlockers**: Short-term keys that encrypt the long-term keys, supporting multiple authentication methods
|
||||
3. **Version-specific Keys**: Per-version keys that encrypt individual secret values
|
||||
|
||||
if there is already a vault, `secret init` exits with an error.
|
||||
### Version Management
|
||||
|
||||
* `secret import [vaultname]` will derive a long-term key pair from a bip32 seed
|
||||
phrase and import it into the named vault. if no vault name is specified,
|
||||
`default` is used. if the named vault already exists, it exits with
|
||||
an error.
|
||||
Each secret maintains a history of versions, with each version having:
|
||||
- Its own encryption key pair
|
||||
- Encrypted metadata including creation time and validity period
|
||||
- Immutable value storage
|
||||
- Atomic version switching via symlink updates
|
||||
|
||||
first:
|
||||
### Vault System
|
||||
|
||||
* the long term key pair will be derived in memory
|
||||
* a random short term (unlock) key pair will be derived in memory
|
||||
Vaults provide logical separation of secrets, each with its own long-term key and unlocker set. This allows for complete isolation between different contexts (work, personal, projects).
|
||||
|
||||
then:
|
||||
## Installation
|
||||
|
||||
* the long term key pair public key will be written to disk
|
||||
* the short term key pair public key will be written to disk
|
||||
* the long term key pair private key will be encrypted to the short term
|
||||
public key and written to disk
|
||||
* the short term key pair private key will be encrypted to a passphrase
|
||||
and written to disk
|
||||
Build from source:
|
||||
```bash
|
||||
git clone <repository>
|
||||
cd secret
|
||||
make build
|
||||
```
|
||||
|
||||
* `secret enroll sep` creates a short-term keypair inside the secure enclave
|
||||
of a macOS device. it will then use one of your existing short-term
|
||||
keypairs to decrypt the long-term keypair, re-encrypt it to the secure
|
||||
enclave short-term keypair, and write it to disk. it requires an existing
|
||||
vault, and errors otherwise.
|
||||
## Quick Start
|
||||
|
||||
* short-term keypairs are called 'unlock keys'.
|
||||
1. **Initialize the secret manager**:
|
||||
```bash
|
||||
secret init
|
||||
```
|
||||
This creates the default vault and prompts for a BIP39 mnemonic phrase.
|
||||
|
||||
* `secret add <secret>` adds a secret to the vault. this will generate a
|
||||
keypair (secret-specific key) and encrypt the private portion of the
|
||||
secret-specific key (called an 'unlock key') to the long-term keypair and
|
||||
write it to disk. if the secret already exists it will not overwrite, but
|
||||
will exit with an error, unless `--force`/`-f` is used. the secret
|
||||
identifier is [a-z0-9\.\-\_\/]+ and is used as a storage directory name.
|
||||
slashes are converted to % signs.
|
||||
2. **Generate a mnemonic** (if needed):
|
||||
```bash
|
||||
secret generate mnemonic
|
||||
```
|
||||
|
||||
in a future version, overwriting a secret will cause the current secret to
|
||||
get moved to a timestamped history archive.
|
||||
3. **Add a secret**:
|
||||
```bash
|
||||
echo "my-password" | secret add myservice/password
|
||||
```
|
||||
|
||||
* `secret get <secret>` retrieves a secret from the vault. this will use an
|
||||
unlock keypair to decrypt the long-term keypair in memory, then use the
|
||||
long-term keypair to decrypt the secret-specific keypair, which is then
|
||||
used to decrypt the secret. the secret is then returned in plaintext on
|
||||
stdout.
|
||||
4. **Retrieve a secret**:
|
||||
```bash
|
||||
secret get myservice/password
|
||||
```
|
||||
|
||||
* `secret keys list` lists the short-term keypairs in the current vault.
|
||||
this will show the public keys of the short-term keypairs and their
|
||||
creation dates, as well as any flags (such as `hsm`). their identifiers
|
||||
are a metahash of the public key data using the sha256 algorithm.
|
||||
## Commands Reference
|
||||
|
||||
* `secret keys rm <keyid>` removes a short-term keypair from the vault. this will
|
||||
remove the short-term keypair from the vault, and remove the long-term
|
||||
keypair from the short-term keypair.
|
||||
### Initialization
|
||||
|
||||
* `secret keys add pgp <pgp keyid>` adds a new short-term keypair to the vault.
|
||||
this will generate a new short-term keypair and encrypt it to a given gpg
|
||||
key, to allow unlocking a vault with an existing gpg key, for people who
|
||||
use yubikeys or other gpg keys with an agent. the new short-term keypair
|
||||
is randomly generated, the public key stored, and the private key encrypted
|
||||
to the gpg key and stored.
|
||||
#### `secret init`
|
||||
Initializes the secret manager with a default vault. Prompts for a BIP39 mnemonic phrase and creates the initial directory structure.
|
||||
|
||||
* `secret key select <keyid>` selects a short-term keypair to use for
|
||||
`secret get` operations.
|
||||
**Environment Variables:**
|
||||
- `SB_SECRET_MNEMONIC`: Pre-set mnemonic phrase
|
||||
- `SB_UNLOCK_PASSPHRASE`: Pre-set unlock passphrase
|
||||
|
||||
# file layout
|
||||
### Vault Management
|
||||
|
||||
$BASE = ~/.config/berlin.sneak.pkg.secret (on linux per XDG)
|
||||
$BASE = ~/Library/Application Support/berlin.sneak.pkg.secret (on macOS)
|
||||
#### `secret vault list [--json]`
|
||||
Lists all available vaults.
|
||||
|
||||
$BASE/configuration.json
|
||||
$BASE/currentvault -> $BASE/vaults.d/default (symlink)
|
||||
$BASE/vaults.d/default/
|
||||
$BASE/vaults.d/default/vault-metadata.json
|
||||
$BASE/vaults.d/default/pub.age
|
||||
$BASE/vaults.d/default/current-unlock-key -> $BASE/vaults.d/default/unlock.d/passphrase (symlink)
|
||||
$BASE/vaults.d/default/unlock.d/passphrase/unlock-metadata.json
|
||||
$BASE/vaults.d/default/unlock.d/passphrase/pub.age
|
||||
$BASE/vaults.d/default/unlock.d/passphrase/priv.age
|
||||
$BASE/vaults.d/default/unlock.d/passphrase/longterm.age # long-term keypair, encrypted to this short-term keypair
|
||||
$BASE/vaults.d/default/unlock.d/sep/unlock-metadata.json
|
||||
$BASE/vaults.d/default/unlock.d/sep/pub.age
|
||||
$BASE/vaults.d/default/unlock.d/sep/priv.age
|
||||
$BASE/vaults.d/default/unlock.d/sep/longterm.age # long-term keypair, encrypted to this short-term keypair
|
||||
$BASE/vaults.d/default/unlock.d/pgp/unlock-metadata.json
|
||||
$BASE/vaults.d/default/unlock.d/pgp/pub.age
|
||||
$BASE/vaults.d/default/unlock.d/pgp/priv.asc
|
||||
$BASE/vaults.d/default/unlock.d/pgp/longterm.age # long-term keypair, encrypted to this short-term keypair
|
||||
$BASE/vaults.d/default/secrets.d/my-tinder-password/value.age
|
||||
$BASE/vaults.d/default/secrets.d/my-tinder-password/pub.age
|
||||
$BASE/vaults.d/default/secrets.d/my-tinder-password/priv.age # secret-specific key, encrypted to long-term key
|
||||
$BASE/vaults.d/default/secrets.d/my-tinder-password/secret-metadata.json
|
||||
$BASE/vaults.d/default/secrets.d/mail%berlin.sneak.secrets.imaplogin/value.age
|
||||
$BASE/vaults.d/default/secrets.d/mail%berlin.sneak.secrets.imaplogin/pub.age
|
||||
$BASE/vaults.d/default/secrets.d/mail%berlin.sneak.secrets.imaplogin/priv.age
|
||||
$BASE/vaults.d/default/secrets.d/mail%berlin.sneak.secrets.imaplogin/secret-metadata.json
|
||||
#### `secret vault create <name>`
|
||||
Creates a new vault with the specified name.
|
||||
|
||||
# example configuration.json
|
||||
#### `secret vault select <name>`
|
||||
Switches to the specified vault for subsequent operations.
|
||||
|
||||
### Secret Management
|
||||
|
||||
#### `secret add <secret-name> [--force]`
|
||||
Adds a secret to the current vault. Reads the secret value from stdin.
|
||||
- `--force, -f`: Overwrite existing secret
|
||||
|
||||
**Secret Name Format:** `[a-z0-9\.\-\_\/]+`
|
||||
- Forward slashes (`/`) are converted to percent signs (`%`) for storage
|
||||
- Examples: `database/password`, `api.key`, `ssh_private_key`
|
||||
|
||||
#### `secret get <secret-name> [--version <version>]`
|
||||
Retrieves and outputs a secret value to stdout.
|
||||
- `--version, -v`: Get a specific version (default: current)
|
||||
|
||||
#### `secret list [filter] [--json]` / `secret ls`
|
||||
Lists all secrets in the current vault. Optional filter for substring matching.
|
||||
|
||||
### Version Management
|
||||
|
||||
#### `secret version list <secret-name>`
|
||||
Lists all versions of a secret showing creation time, status, and validity period.
|
||||
|
||||
#### `secret version promote <secret-name> <version>`
|
||||
Promotes a specific version to current by updating the symlink. Does not modify any timestamps, allowing for rollback scenarios.
|
||||
|
||||
### Key Generation
|
||||
|
||||
#### `secret generate mnemonic`
|
||||
Generates a cryptographically secure BIP39 mnemonic phrase.
|
||||
|
||||
#### `secret generate secret <name> [--length=16] [--type=base58] [--force]`
|
||||
Generates and stores a random secret.
|
||||
- `--length, -l`: Length of generated secret (default: 16)
|
||||
- `--type, -t`: Type of secret (`base58`, `alnum`)
|
||||
- `--force, -f`: Overwrite existing secret
|
||||
|
||||
### Unlocker Management
|
||||
|
||||
#### `secret unlockers list [--json]`
|
||||
Lists all unlockers in the current vault with their metadata.
|
||||
|
||||
#### `secret unlockers add <type> [options]`
|
||||
Creates a new unlocker of the specified type:
|
||||
|
||||
**Types:**
|
||||
- `passphrase`: Traditional passphrase-protected unlocker
|
||||
- `pgp`: Uses an existing GPG key for encryption/decryption
|
||||
|
||||
**Options:**
|
||||
- `--keyid <id>`: GPG key ID (required for PGP type)
|
||||
|
||||
#### `secret unlockers rm <unlocker-id>`
|
||||
Removes an unlocker.
|
||||
|
||||
#### `secret unlocker select <unlocker-id>`
|
||||
Selects an unlocker as the current default for operations.
|
||||
|
||||
### Import Operations
|
||||
|
||||
#### `secret import <secret-name> --source <filename>`
|
||||
Imports a secret from a file and stores it in the current vault under the given name.
|
||||
|
||||
#### `secret vault import [vault-name]`
|
||||
Imports a mnemonic phrase into the specified vault (defaults to "default").
|
||||
|
||||
### Encryption Operations
|
||||
|
||||
#### `secret encrypt <secret-name> [--input=file] [--output=file]`
|
||||
Encrypts data using an Age key stored as a secret. If the secret doesn't exist, generates a new Age key.
|
||||
|
||||
#### `secret decrypt <secret-name> [--input=file] [--output=file]`
|
||||
Decrypts data using an Age key stored as a secret.
|
||||
|
||||
## Storage Architecture
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
~/.local/share/secret/
|
||||
├── vaults.d/
|
||||
│ ├── default/
|
||||
│ │ ├── unlockers.d/
|
||||
│ │ │ ├── passphrase/ # Passphrase unlocker
|
||||
│ │ │ └── pgp/ # PGP unlocker
|
||||
│ │ ├── secrets.d/
|
||||
│ │ │ ├── api%key/ # Secret: api/key
|
||||
│ │ │ │ ├── versions/
|
||||
│ │ │ │ │ ├── 20231215.001/ # Version directory
|
||||
│ │ │ │ │ │ ├── pub.age # Version public key
|
||||
│ │ │ │ │ │ ├── priv.age # Version private key (encrypted)
|
||||
│ │ │ │ │ │ ├── value.age # Encrypted value
|
||||
│ │ │ │ │ │ └── metadata.age # Encrypted metadata
|
||||
│ │ │ │ │ └── 20231216.001/ # Another version
|
||||
│ │ │ │ └── current -> versions/20231216.001
|
||||
│ │ │ └── database%password/ # Secret: database/password
|
||||
│ │ │ ├── versions/
|
||||
│ │ │ └── current -> versions/20231215.001
|
||||
│ │ ├── vault-metadata.json # Vault metadata
|
||||
│ │ ├── pub.age # Long-term public key
|
||||
│ │ └── current-unlocker -> ../unlockers.d/passphrase
|
||||
│ └── work/
|
||||
│ ├── unlockers.d/
|
||||
│ ├── secrets.d/
|
||||
│ ├── vault-metadata.json
|
||||
│ ├── pub.age
|
||||
│ └── current-unlocker
|
||||
└── currentvault -> vaults.d/default
|
||||
```
|
||||
|
||||
### Key Management and Encryption Flow
|
||||
|
||||
#### Long-term Keys
|
||||
- **Source**: Derived from BIP39 mnemonic phrases using hierarchical deterministic (HD) key derivation
|
||||
- **Purpose**: Master keys for each vault, used to encrypt secret-specific keys
|
||||
- **Storage**: Public key stored as `pub.age`, private key encrypted by unlockers
|
||||
|
||||
#### Unlockers
|
||||
Unlockers provide different authentication methods to access the long-term keys:
|
||||
|
||||
1. **Passphrase Unlockers**:
|
||||
- Encrypted with user-provided passphrase
|
||||
- Stored as encrypted Age keys
|
||||
- Cross-platform compatible
|
||||
|
||||
2. **PGP Unlockers**:
|
||||
- Uses existing GPG key infrastructure
|
||||
- Leverages existing key management workflows
|
||||
- Strong authentication through GPG
|
||||
|
||||
Each vault maintains its own set of unlockers and one long-term key. The long-term key is encrypted to each unlocker, allowing any authorized unlocker to access vault secrets.
|
||||
|
||||
#### Secret-specific Keys
|
||||
- Each secret has its own encryption key pair
|
||||
- Private key encrypted to the vault's long-term key
|
||||
- Provides forward secrecy and granular access control
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `SB_SECRET_STATE_DIR`: Custom state directory location
|
||||
- `SB_SECRET_MNEMONIC`: Pre-set mnemonic phrase (avoids interactive prompt)
|
||||
- `SB_UNLOCK_PASSPHRASE`: Pre-set unlock passphrase (avoids interactive prompt)
|
||||
- `SB_GPG_KEY_ID`: GPG key ID for PGP unlockers
|
||||
|
||||
## Security Features
|
||||
|
||||
### Encryption
|
||||
- Uses the [Age encryption library](https://age-encryption.org/) with X25519 keys
|
||||
- All private keys are encrypted at rest
|
||||
- No plaintext secrets stored on disk
|
||||
|
||||
### Access Control
|
||||
- Multiple authentication methods supported
|
||||
- Hierarchical key architecture provides defense in depth
|
||||
- Vault isolation prevents cross-contamination
|
||||
|
||||
### Forward Secrecy
|
||||
- Per-version encryption keys limit exposure if compromised
|
||||
- Each version is independently encrypted
|
||||
- Long-term keys protected by multiple unlocker layers
|
||||
- Historical versions remain encrypted with their original keys
|
||||
|
||||
### Hardware Integration
|
||||
- Hardware token support via PGP/GPG integration
|
||||
|
||||
## Examples
|
||||
|
||||
### Basic Workflow
|
||||
```bash
|
||||
# Initialize with a new mnemonic
|
||||
secret generate mnemonic # Copy the output
|
||||
secret init # Paste the mnemonic when prompted
|
||||
|
||||
# Add some secrets
|
||||
echo "supersecret123" | secret add database/prod/password
|
||||
echo "api-key-xyz" | secret add services/api/key
|
||||
echo "ssh-private-key-content" | secret add ssh/servers/web01
|
||||
|
||||
# List and retrieve secrets
|
||||
secret list
|
||||
secret get database/prod/password
|
||||
secret get services/api/key
|
||||
```
|
||||
|
||||
### Multi-vault Setup
|
||||
```bash
|
||||
# Create separate vaults for different contexts
|
||||
secret vault create work
|
||||
secret vault create personal
|
||||
|
||||
# Work with work vault
|
||||
secret vault select work
|
||||
echo "work-db-pass" | secret add database/password
|
||||
secret unlockers add passphrase # Add passphrase authentication
|
||||
|
||||
# Switch to personal vault
|
||||
secret vault select personal
|
||||
echo "personal-email-pass" | secret add email/password
|
||||
|
||||
# List all vaults
|
||||
secret vault list
|
||||
```
|
||||
|
||||
### Advanced Authentication
|
||||
```bash
|
||||
# Add multiple unlock methods
|
||||
secret unlockers add passphrase # Password-based
|
||||
secret unlockers add pgp --keyid ABCD1234 # GPG key
|
||||
|
||||
# List unlockers
|
||||
secret unlockers list
|
||||
|
||||
# Select a specific unlocker
|
||||
secret unlocker select <unlocker-id>
|
||||
```
|
||||
|
||||
### Encryption/Decryption with Age Keys
|
||||
```bash
|
||||
# Generate an Age key and store it as a secret
|
||||
secret generate secret encryption/mykey
|
||||
|
||||
# Encrypt a file using the stored key
|
||||
secret encrypt encryption/mykey --input document.txt --output document.txt.age
|
||||
|
||||
# Decrypt the file
|
||||
secret decrypt encryption/mykey --input document.txt.age --output document.txt
|
||||
```
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Cryptographic Primitives
|
||||
- **Key Derivation**: BIP32/BIP39 hierarchical deterministic key derivation
|
||||
- **Encryption**: Age (X25519 + ChaCha20-Poly1305)
|
||||
- **Key Exchange**: X25519 elliptic curve Diffie-Hellman
|
||||
- **Authentication**: Poly1305 MAC
|
||||
- **Hashing**: Double SHA-256 for public key identification
|
||||
|
||||
### File Formats
|
||||
- **Age Files**: Standard Age encryption format (.age extension)
|
||||
- **Metadata**: JSON format with timestamps and type information
|
||||
- **Vault Metadata**: JSON containing vault name, creation time, derivation index, and public key hash
|
||||
|
||||
### Vault Management
|
||||
- **Derivation Index**: Each vault uses a unique derivation index from the mnemonic
|
||||
- **Public Key Hash**: Double SHA-256 hash of the index-0 public key identifies vaults from the same mnemonic
|
||||
- **Automatic Key Derivation**: When creating vaults with a mnemonic, keys are automatically derived
|
||||
|
||||
### Cross-Platform Support
|
||||
- **macOS**: Full support including Keychain integration
|
||||
- **Linux**: Full support (excluding Keychain features)
|
||||
- **Windows**: Basic support (filesystem operations only)
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Threat Model
|
||||
- Protects against unauthorized access to secret values
|
||||
- Provides defense against compromise of individual components
|
||||
- Supports hardware-backed authentication where available
|
||||
|
||||
### Best Practices
|
||||
1. Use strong, unique passphrases for unlockers
|
||||
2. Enable hardware authentication (Keychain, hardware tokens) when available
|
||||
3. Regularly audit unlockers and remove unused ones
|
||||
4. Keep mnemonic phrases securely backed up offline
|
||||
5. Use separate vaults for different security contexts
|
||||
|
||||
### Limitations
|
||||
- Requires access to unlockers for secret retrieval
|
||||
- Mnemonic phrases must be securely stored and backed up
|
||||
- Hardware features limited to supported platforms
|
||||
|
||||
## Development
|
||||
|
||||
### Building
|
||||
```bash
|
||||
make build # Build binary
|
||||
make test # Run tests
|
||||
make lint # Run linter
|
||||
```
|
||||
|
||||
### Testing
|
||||
The project includes comprehensive tests:
|
||||
```bash
|
||||
make test # Run all tests
|
||||
go test ./... # Unit tests
|
||||
go test -tags=integration -v ./internal/cli # Integration tests
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Authentication Methods**: Supports passphrase-based and PGP-based unlockers
|
||||
- **Vault Isolation**: Complete separation between different vaults
|
||||
- **Per-Secret Encryption**: Each secret has its own encryption key
|
||||
- **BIP39 Mnemonic Support**: Keyless operation using mnemonic phrases
|
||||
- **Cross-Platform**: Works on macOS, Linux, and other Unix-like systems
|
||||
|
||||
`json
|
||||
{
|
||||
"$id": "https://berlin.sneak.pkg.secret/configuration.json",
|
||||
"$schema": "https://berlin.sneak.pkg.secret/configuration.schema.json",
|
||||
"version": 1,
|
||||
"configuration": {
|
||||
"createdAt": "2025-01-01T00:00:00Z",
|
||||
"requireAuth": false,
|
||||
"pubKey": "<public key of the long-term keypair>",
|
||||
"pubKeyFingerprint": "<fingerprint of the long-term keypair>"
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
102
TODO.md
Normal file
102
TODO.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# TODO for 1.0 Release
|
||||
|
||||
This document outlines the bugs, issues, and improvements that need to be
|
||||
addressed before the 1.0 release of the secret manager. Items are
|
||||
prioritized from most critical (top) to least critical (bottom).
|
||||
|
||||
## Code Cleanups
|
||||
|
||||
* we shouldn't be passing around a statedir, it should be read from the
|
||||
environment or default.
|
||||
|
||||
## HIGH PRIORITY SECURITY ISSUES
|
||||
|
||||
- [ ] **4. Application crashes on corrupted metadata**: Code panics instead
|
||||
of returning errors when metadata is corrupt, causing denial of service.
|
||||
Found in pgpunlocker.go:116 and keychainunlocker.go:141.
|
||||
|
||||
- [ ] **5. Insufficient input validation**: Secret names allow potentially
|
||||
dangerous patterns including dots that could enable path traversal attacks
|
||||
(vault/secrets.go:70-93).
|
||||
|
||||
- [ ] **6. Race conditions in file operations**: Multiple concurrent
|
||||
operations could corrupt the vault state due to lack of file locking
|
||||
mechanisms.
|
||||
|
||||
- [ ] **7. Insecure temporary file handling**: Temporary files containing
|
||||
sensitive data may not be properly cleaned up or secured.
|
||||
|
||||
## HIGH PRIORITY FUNCTIONALITY ISSUES
|
||||
|
||||
- [ ] **8. Inappropriate Cobra usage printing**: Commands currently print
|
||||
usage information for all errors, including internal program failures.
|
||||
Usage should only be printed when the user provides incorrect arguments or
|
||||
invalid commands.
|
||||
|
||||
- [ ] **9. Missing current unlock key initialization**: When creating
|
||||
vaults, no default unlock key is selected, which can cause operations to
|
||||
fail.
|
||||
|
||||
- [ ] **10. Add confirmation prompts for destructive operations**:
|
||||
Operations like `keys rm` and vault deletion should require confirmation.
|
||||
|
||||
- [ ] **11. No secret deletion command**: Missing `secret rm <secret-name>`
|
||||
functionality.
|
||||
|
||||
- [ ] **12. Missing vault deletion command**: No way to delete vaults that
|
||||
are no longer needed.
|
||||
|
||||
## MEDIUM PRIORITY ISSUES
|
||||
|
||||
- [ ] **13. Inconsistent error messages**: Error messages need
|
||||
standardization and should be user-friendly. Many errors currently expose
|
||||
internal implementation details.
|
||||
|
||||
- [ ] **14. No graceful handling of corrupted state**: If key files are
|
||||
corrupted or missing, the tool should provide clear error messages and
|
||||
recovery suggestions.
|
||||
|
||||
- [ ] **15. No validation of GPG key existence**: Should verify the
|
||||
specified GPG key exists before creating PGP unlock keys.
|
||||
|
||||
- [ ] **16. Better separation of concerns**: Some functions in CLI do too
|
||||
much and should be split.
|
||||
|
||||
- [ ] **17. Environment variable security**: Sensitive data read from
|
||||
environment variables (SB_UNLOCK_PASSPHRASE, SB_SECRET_MNEMONIC) without
|
||||
proper clearing. Document security implications.
|
||||
|
||||
- [ ] **18. No secure memory allocation**: No use of mlock/munlock to
|
||||
prevent sensitive data from being swapped to disk.
|
||||
|
||||
## LOWER PRIORITY ENHANCEMENTS
|
||||
|
||||
- [ ] **19. Add `--help` examples**: Command help should include practical examples for each operation.
|
||||
|
||||
- [ ] **20. Add shell completion**: Bash/Zsh completion for commands and secret names.
|
||||
|
||||
- [ ] **21. Colored output**: Use colors to improve readability of lists and error messages.
|
||||
|
||||
- [ ] **22. Add `--quiet` flag**: Option to suppress non-essential output.
|
||||
|
||||
- [ ] **23. Smart secret name suggestions**: When a secret name is not found, suggest similar names.
|
||||
|
||||
- [ ] **24. Audit logging**: Log all secret access and modifications for security auditing.
|
||||
|
||||
- [ ] **25. Integration tests for hardware features**: Automated testing of Keychain and GPG functionality.
|
||||
|
||||
- [ ] **26. Consistent naming conventions**: Some variables and functions use inconsistent naming patterns.
|
||||
|
||||
- [ ] **27. Export/import functionality**: Add ability to export/import entire vaults, not just individual secrets.
|
||||
|
||||
- [ ] **28. Batch operations**: Add commands to process multiple secrets at once.
|
||||
|
||||
- [ ] **29. Search functionality**: Add ability to search secret names and potentially contents.
|
||||
|
||||
- [ ] **30. Secret metadata**: Add support for descriptions, tags, or other metadata with secrets.
|
||||
|
||||
## COMPLETED ITEMS ✓
|
||||
|
||||
- [x] **Missing secret history/versioning**: ✓ Implemented - versioning system exists with --version flag support
|
||||
- [x] **XDG compliance on Linux**: ✓ Implemented - uses os.UserConfigDir() which respects XDG_CONFIG_HOME
|
||||
- [x] **Consistent interface implementation**: ✓ Implemented - Unlocker interface is well-defined and consistently implemented
|
||||
8
cmd/secret/main.go
Normal file
8
cmd/secret/main.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Package main is the entry point for the secret CLI application.
|
||||
package main
|
||||
|
||||
import "git.eeqj.de/sneak/secret/internal/cli"
|
||||
|
||||
func main() {
|
||||
cli.Entry()
|
||||
}
|
||||
27
go.mod
27
go.mod
@@ -2,20 +2,31 @@ module git.eeqj.de/sneak/secret
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require github.com/spf13/cobra v1.9.1
|
||||
require (
|
||||
filippo.io/age v1.2.1
|
||||
github.com/awnumar/memguard v0.22.5
|
||||
github.com/btcsuite/btcd v0.24.2
|
||||
github.com/btcsuite/btcd/btcec/v2 v2.1.3
|
||||
github.com/btcsuite/btcd/btcutil v1.1.6
|
||||
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d
|
||||
github.com/oklog/ulid/v2 v2.1.1
|
||||
github.com/spf13/afero v1.14.0
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/tyler-smith/go-bip39 v1.1.0
|
||||
golang.org/x/crypto v0.38.0
|
||||
golang.org/x/term v0.32.0
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/age v1.2.1 // indirect
|
||||
github.com/btcsuite/btcd v0.24.2 // indirect
|
||||
github.com/btcsuite/btcd/btcec/v2 v2.1.3 // indirect
|
||||
github.com/btcsuite/btcd/btcutil v1.1.6 // indirect
|
||||
github.com/awnumar/memcall v0.2.0 // indirect
|
||||
github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/spf13/afero v1.14.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/tyler-smith/go-bip39 v1.1.0 // indirect
|
||||
golang.org/x/crypto v0.38.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.25.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
22
go.sum
22
go.sum
@@ -1,6 +1,12 @@
|
||||
c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805 h1:u2qwJeEvnypw+OCPUHmoZE3IqwfuN5kgDfo5MLzpNM0=
|
||||
c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805/go.mod h1:FomMrUJ2Lxt5jCLmZkG3FHa72zUprnhd3v/Z18Snm4w=
|
||||
filippo.io/age v1.2.1 h1:X0TZjehAZylOIj4DubWYU1vWQxv9bJpo+Uu2/LGhi1o=
|
||||
filippo.io/age v1.2.1/go.mod h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004=
|
||||
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
|
||||
github.com/awnumar/memcall v0.2.0 h1:sRaogqExTOOkkNwO9pzJsL8jrOV29UuUW7teRMfbqtI=
|
||||
github.com/awnumar/memcall v0.2.0/go.mod h1:S911igBPR9CThzd/hYQQmTc9SWNu3ZHIlCGaWsWsoJo=
|
||||
github.com/awnumar/memguard v0.22.5 h1:PH7sbUVERS5DdXh3+mLo8FDcl1eIeVjJVYMnyuYpvuI=
|
||||
github.com/awnumar/memguard v0.22.5/go.mod h1:+APmZGThMBWjnMlKiSM1X7MVpbIVewen2MTkqWkA/zE=
|
||||
github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ=
|
||||
github.com/btcsuite/btcd v0.22.0-beta.0.20220111032746-97732e52810c/go.mod h1:tjmYdS6MLJ5/s0Fj4DbLgSbDHbEqLJrtnHecBFkdz5M=
|
||||
github.com/btcsuite/btcd v0.23.5-0.20231215221805-96c9fd8078fd/go.mod h1:nm3Bko6zh6bWP60UxwoT5LzdGJsQJaPo6HjduXq9p6A=
|
||||
@@ -19,6 +25,7 @@ github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtyd
|
||||
github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 h1:59Kx4K6lzOW5w6nFlA0v5+lk/6sjybR934QNHSJZPTQ=
|
||||
github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
|
||||
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA=
|
||||
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d h1:yJzD/yFppdVCf6ApMkVy8cUxV0XrxdP9rVf6D87/Mng=
|
||||
github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg=
|
||||
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg=
|
||||
github.com/btcsuite/goleveldb v0.0.0-20160330041536-7834afc9e8cd/go.mod h1:F+uVaaLLH7j4eDXPRvw78tMflu7Ie2bzYOH4Y8rRKBY=
|
||||
@@ -30,6 +37,7 @@ github.com/btcsuite/winsvc v1.0.0/go.mod h1:jsenWakMcC0zFBFurPLEAyrnc/teJEM1O46f
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc=
|
||||
@@ -57,6 +65,8 @@ github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJS
|
||||
github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ=
|
||||
github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s=
|
||||
github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
|
||||
@@ -65,6 +75,8 @@ github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA=
|
||||
@@ -79,6 +91,7 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc=
|
||||
github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8=
|
||||
@@ -86,8 +99,6 @@ github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3C
|
||||
golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||
golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
|
||||
golang.org/x/net v0.0.0-20180719180050-a680a1efc54d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -100,15 +111,16 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
|
||||
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -123,6 +135,7 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ
|
||||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
|
||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
@@ -130,4 +143,5 @@ gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
package agehd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"filippo.io/age"
|
||||
)
|
||||
|
||||
const (
|
||||
mnemonic = "abandon abandon abandon abandon abandon " +
|
||||
"abandon abandon abandon abandon abandon abandon about"
|
||||
|
||||
// Test xprv from BIP85 test vectors
|
||||
testXPRV = "xprv9s21ZrQH143K2LBWUUQRFXhucrQqBpKdRRxNVq2zBqsx8HVqFk2uYo8kmbaLLHRdqtQpUm98uKfu3vca1LqdGhUtyoFnCNkfmXRyPXLjbKb"
|
||||
)
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
id, err := DeriveIdentity(mnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("secret: %s", id.String())
|
||||
t.Logf("recipient: %s", id.Recipient().String())
|
||||
|
||||
var ct bytes.Buffer
|
||||
w, err := age.Encrypt(&ct, id.Recipient())
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt init: %v", err)
|
||||
}
|
||||
if _, err = io.WriteString(w, "hello world"); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
if err = w.Close(); err != nil {
|
||||
t.Fatalf("encrypt close: %v", err)
|
||||
}
|
||||
|
||||
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), id)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt init: %v", err)
|
||||
}
|
||||
dec, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
|
||||
if got := string(dec); got != "hello world" {
|
||||
t.Fatalf("round-trip mismatch: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveIdentityFromXPRV(t *testing.T) {
|
||||
id, err := DeriveIdentityFromXPRV(testXPRV, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive from xprv: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("xprv secret: %s", id.String())
|
||||
t.Logf("xprv recipient: %s", id.Recipient().String())
|
||||
|
||||
// Test encryption/decryption with xprv-derived identity
|
||||
var ct bytes.Buffer
|
||||
w, err := age.Encrypt(&ct, id.Recipient())
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt init: %v", err)
|
||||
}
|
||||
if _, err = io.WriteString(w, "hello from xprv"); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
if err = w.Close(); err != nil {
|
||||
t.Fatalf("encrypt close: %v", err)
|
||||
}
|
||||
|
||||
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), id)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt init: %v", err)
|
||||
}
|
||||
dec, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
|
||||
if got := string(dec); got != "hello from xprv" {
|
||||
t.Fatalf("round-trip mismatch: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeterministicDerivation(t *testing.T) {
|
||||
// Test that the same mnemonic and index always produce the same identity
|
||||
id1, err := DeriveIdentity(mnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive 1: %v", err)
|
||||
}
|
||||
|
||||
id2, err := DeriveIdentity(mnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive 2: %v", err)
|
||||
}
|
||||
|
||||
if id1.String() != id2.String() {
|
||||
t.Fatalf("identities should be deterministic: %s != %s", id1.String(), id2.String())
|
||||
}
|
||||
|
||||
// Test that different indices produce different identities
|
||||
id3, err := DeriveIdentity(mnemonic, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("derive 3: %v", err)
|
||||
}
|
||||
|
||||
if id1.String() == id3.String() {
|
||||
t.Fatalf("different indices should produce different identities")
|
||||
}
|
||||
|
||||
t.Logf("Index 0: %s", id1.String())
|
||||
t.Logf("Index 1: %s", id3.String())
|
||||
}
|
||||
|
||||
func TestDeterministicXPRVDerivation(t *testing.T) {
|
||||
// Test that the same xprv and index always produce the same identity
|
||||
id1, err := DeriveIdentityFromXPRV(testXPRV, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive 1: %v", err)
|
||||
}
|
||||
|
||||
id2, err := DeriveIdentityFromXPRV(testXPRV, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive 2: %v", err)
|
||||
}
|
||||
|
||||
if id1.String() != id2.String() {
|
||||
t.Fatalf("xprv identities should be deterministic: %s != %s", id1.String(), id2.String())
|
||||
}
|
||||
|
||||
// Test that different indices with same xprv produce different identities
|
||||
id3, err := DeriveIdentityFromXPRV(testXPRV, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("derive 3: %v", err)
|
||||
}
|
||||
|
||||
if id1.String() == id3.String() {
|
||||
t.Fatalf("different indices should produce different identities")
|
||||
}
|
||||
|
||||
t.Logf("XPRV Index 0: %s", id1.String())
|
||||
t.Logf("XPRV Index 1: %s", id3.String())
|
||||
}
|
||||
|
||||
func TestMnemonicVsXPRVConsistency(t *testing.T) {
|
||||
// Test that deriving from mnemonic and from the corresponding xprv produces the same result
|
||||
// Note: This test is removed because the test mnemonic and test xprv are from different sources
|
||||
// and are not expected to produce the same results.
|
||||
t.Skip("Skipping consistency test - test mnemonic and xprv are from different sources")
|
||||
}
|
||||
|
||||
func TestEntropyLength(t *testing.T) {
|
||||
// Test that DeriveEntropy returns exactly 32 bytes
|
||||
entropy, err := DeriveEntropy(mnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive entropy: %v", err)
|
||||
}
|
||||
|
||||
if len(entropy) != 32 {
|
||||
t.Fatalf("expected 32 bytes of entropy, got %d", len(entropy))
|
||||
}
|
||||
|
||||
t.Logf("Entropy (32 bytes): %x", entropy)
|
||||
|
||||
// Test that DeriveEntropyFromXPRV returns exactly 32 bytes
|
||||
entropyXPRV, err := DeriveEntropyFromXPRV(testXPRV, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("derive entropy from xprv: %v", err)
|
||||
}
|
||||
|
||||
if len(entropyXPRV) != 32 {
|
||||
t.Fatalf("expected 32 bytes of entropy from xprv, got %d", len(entropyXPRV))
|
||||
}
|
||||
|
||||
t.Logf("XPRV Entropy (32 bytes): %x", entropyXPRV)
|
||||
|
||||
// Note: We don't compare the entropy values since the test mnemonic and test xprv
|
||||
// are from different sources and should produce different entropy values.
|
||||
}
|
||||
|
||||
func TestIdentityFromEntropy(t *testing.T) {
|
||||
// Test that IdentityFromEntropy works with custom entropy
|
||||
entropy := make([]byte, 32)
|
||||
for i := range entropy {
|
||||
entropy[i] = byte(i)
|
||||
}
|
||||
|
||||
id, err := IdentityFromEntropy(entropy)
|
||||
if err != nil {
|
||||
t.Fatalf("identity from entropy: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Custom entropy identity: %s", id.String())
|
||||
|
||||
// Test that it rejects wrong-sized entropy
|
||||
_, err = IdentityFromEntropy(entropy[:31])
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for 31-byte entropy")
|
||||
}
|
||||
|
||||
// Create a 33-byte slice to test rejection
|
||||
entropy33 := make([]byte, 33)
|
||||
copy(entropy33, entropy)
|
||||
_, err = IdentityFromEntropy(entropy33)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for 33-byte entropy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidXPRV(t *testing.T) {
|
||||
// Test with invalid xprv
|
||||
_, err := DeriveIdentityFromXPRV("invalid-xprv", 0)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid xprv")
|
||||
}
|
||||
|
||||
t.Logf("Got expected error for invalid xprv: %v", err)
|
||||
}
|
||||
99
internal/cli/cli.go
Normal file
99
internal/cli/cli.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Package cli implements the command-line interface for the secret application.
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Global scanner for consistent stdin reading
|
||||
var stdinScanner *bufio.Scanner //nolint:gochecknoglobals // Needed for consistent stdin handling
|
||||
|
||||
// Instance encapsulates all CLI functionality and state
|
||||
type Instance struct {
|
||||
fs afero.Fs
|
||||
stateDir string
|
||||
cmd *cobra.Command
|
||||
}
|
||||
|
||||
// NewCLIInstance creates a new CLI instance with the real filesystem
|
||||
func NewCLIInstance() *Instance {
|
||||
fs := afero.NewOsFs()
|
||||
stateDir := secret.DetermineStateDir("")
|
||||
|
||||
return &Instance{
|
||||
fs: fs,
|
||||
stateDir: stateDir,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCLIInstanceWithFs creates a new CLI instance with the given filesystem (for testing)
|
||||
func NewCLIInstanceWithFs(fs afero.Fs) *Instance {
|
||||
stateDir := secret.DetermineStateDir("")
|
||||
|
||||
return &Instance{
|
||||
fs: fs,
|
||||
stateDir: stateDir,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCLIInstanceWithStateDir creates a new CLI instance with custom state directory (for testing)
|
||||
func NewCLIInstanceWithStateDir(fs afero.Fs, stateDir string) *Instance {
|
||||
return &Instance{
|
||||
fs: fs,
|
||||
stateDir: stateDir,
|
||||
}
|
||||
}
|
||||
|
||||
// SetFilesystem sets the filesystem for this CLI instance (for testing)
|
||||
func (cli *Instance) SetFilesystem(fs afero.Fs) {
|
||||
cli.fs = fs
|
||||
}
|
||||
|
||||
// SetStateDir sets the state directory for this CLI instance (for testing)
|
||||
func (cli *Instance) SetStateDir(stateDir string) {
|
||||
cli.stateDir = stateDir
|
||||
}
|
||||
|
||||
// GetStateDir returns the state directory for this CLI instance
|
||||
func (cli *Instance) GetStateDir() string {
|
||||
return cli.stateDir
|
||||
}
|
||||
|
||||
// getStdinScanner returns a shared scanner for stdin to avoid buffering issues
|
||||
func getStdinScanner() *bufio.Scanner {
|
||||
if stdinScanner == nil {
|
||||
stdinScanner = bufio.NewScanner(os.Stdin)
|
||||
}
|
||||
|
||||
return stdinScanner
|
||||
}
|
||||
|
||||
// readLineFromStdin reads a single line from stdin with a prompt
|
||||
// Uses a shared scanner to avoid buffering issues between multiple calls
|
||||
func readLineFromStdin(prompt string) (string, error) {
|
||||
// Check if stderr is a terminal - if not, we can't prompt interactively
|
||||
if !term.IsTerminal(syscall.Stderr) {
|
||||
return "", fmt.Errorf("cannot prompt for input: stderr is not a terminal (running in non-interactive mode)")
|
||||
}
|
||||
|
||||
fmt.Fprint(os.Stderr, prompt) // Write prompt to stderr, not stdout
|
||||
scanner := getStdinScanner()
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("failed to read from stdin: %w", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed to read from stdin: EOF")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(scanner.Text()), nil
|
||||
}
|
||||
57
internal/cli/cli_test.go
Normal file
57
internal/cli/cli_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
func TestCLIInstanceStateDir(t *testing.T) {
|
||||
// Test the CLI instance state directory functionality
|
||||
fs := afero.NewMemMapFs()
|
||||
|
||||
// Create a test state directory
|
||||
testStateDir := "/test-state-dir"
|
||||
cli := NewCLIInstanceWithStateDir(fs, testStateDir)
|
||||
|
||||
if cli.GetStateDir() != testStateDir {
|
||||
t.Errorf("Expected state directory %q, got %q", testStateDir, cli.GetStateDir())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCLIInstanceWithFs(t *testing.T) {
|
||||
// Test creating CLI instance with custom filesystem
|
||||
fs := afero.NewMemMapFs()
|
||||
cli := NewCLIInstanceWithFs(fs)
|
||||
|
||||
// The state directory should be determined automatically
|
||||
stateDir := cli.GetStateDir()
|
||||
if stateDir == "" {
|
||||
t.Error("Expected non-empty state directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineStateDir(t *testing.T) {
|
||||
// Test the determineStateDir function from the secret package
|
||||
|
||||
// Test with environment variable set
|
||||
testEnvDir := "/test-env-dir"
|
||||
t.Setenv(secret.EnvStateDir, testEnvDir)
|
||||
|
||||
stateDir := secret.DetermineStateDir("")
|
||||
if stateDir != testEnvDir {
|
||||
t.Errorf("Expected state directory %q from environment, got %q", testEnvDir, stateDir)
|
||||
}
|
||||
|
||||
// Test with custom config dir
|
||||
_ = os.Unsetenv(secret.EnvStateDir)
|
||||
customConfigDir := "/custom-config"
|
||||
stateDir = secret.DetermineStateDir(customConfigDir)
|
||||
expectedDir := filepath.Join(customConfigDir, secret.AppID)
|
||||
if stateDir != expectedDir {
|
||||
t.Errorf("Expected state directory %q with custom config, got %q", expectedDir, stateDir)
|
||||
}
|
||||
}
|
||||
280
internal/cli/crypto.go
Normal file
280
internal/cli/crypto.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newEncryptCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "encrypt <secret-name>",
|
||||
Short: "Encrypt data using an age secret key stored in a secret",
|
||||
Long: `Encrypt data using an age secret key. If the secret doesn't exist, a new age key is generated and stored.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
inputFile, _ := cmd.Flags().GetString("input")
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
|
||||
cli := NewCLIInstance()
|
||||
cli.cmd = cmd
|
||||
|
||||
return cli.Encrypt(args[0], inputFile, outputFile)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringP("input", "i", "", "Input file (default: stdin)")
|
||||
cmd.Flags().StringP("output", "o", "", "Output file (default: stdout)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newDecryptCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "decrypt <secret-name>",
|
||||
Short: "Decrypt data using an age secret key stored in a secret",
|
||||
Long: `Decrypt data using an age secret key stored in the specified secret.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
inputFile, _ := cmd.Flags().GetString("input")
|
||||
outputFile, _ := cmd.Flags().GetString("output")
|
||||
|
||||
cli := NewCLIInstance()
|
||||
cli.cmd = cmd
|
||||
|
||||
return cli.Decrypt(args[0], inputFile, outputFile)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringP("input", "i", "", "Input file (default: stdin)")
|
||||
cmd.Flags().StringP("output", "o", "", "Output file (default: stdout)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Encrypt encrypts data using an age secret key stored in a secret
|
||||
func (cli *Instance) Encrypt(secretName, inputFile, outputFile string) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ageSecretKey string
|
||||
|
||||
// Check if secret exists
|
||||
secretObj := secret.NewSecret(vlt, secretName)
|
||||
exists, err := secretObj.Exists()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if secret exists: %w", err)
|
||||
}
|
||||
|
||||
if !exists { //nolint:nestif // Clear conditional logic for secret generation vs retrieval
|
||||
// Secret doesn't exist, generate new age key and store it
|
||||
identity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate age key: %w", err)
|
||||
}
|
||||
|
||||
// Store the generated key directly in a secure buffer
|
||||
identityStr := identity.String()
|
||||
secureBuffer := memguard.NewBufferFromBytes([]byte(identityStr))
|
||||
defer secureBuffer.Destroy()
|
||||
|
||||
// Set ageSecretKey for later use (we need it for encryption)
|
||||
ageSecretKey = identityStr
|
||||
|
||||
err = vlt.AddSecret(secretName, secureBuffer, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store age key: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Secret exists, get the age secret key from it
|
||||
secretValue, err := cli.getSecretValue(vlt, secretObj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get secret value: %w", err)
|
||||
}
|
||||
|
||||
// Create secure buffer for the secret value
|
||||
secureBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||
defer secureBuffer.Destroy()
|
||||
|
||||
// Clear the original secret value
|
||||
for i := range secretValue {
|
||||
secretValue[i] = 0
|
||||
}
|
||||
|
||||
ageSecretKey = secureBuffer.String()
|
||||
|
||||
// Validate that it's a valid age secret key
|
||||
if !isValidAgeSecretKey(ageSecretKey) {
|
||||
return fmt.Errorf("secret '%s' does not contain a valid age secret key", secretName)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the secret key using secure buffer
|
||||
finalSecureBuffer := memguard.NewBufferFromBytes([]byte(ageSecretKey))
|
||||
defer finalSecureBuffer.Destroy()
|
||||
|
||||
identity, err := age.ParseX25519Identity(finalSecureBuffer.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse age secret key: %w", err)
|
||||
}
|
||||
|
||||
// Get recipient from identity
|
||||
recipient := identity.Recipient()
|
||||
|
||||
// Set up input reader
|
||||
var input io.Reader = os.Stdin
|
||||
if inputFile != "" {
|
||||
file, err := cli.fs.Open(inputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open input file: %w", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
input = file
|
||||
}
|
||||
|
||||
// Set up output writer
|
||||
output := cli.cmd.OutOrStdout()
|
||||
if outputFile != "" {
|
||||
file, err := cli.fs.Create(outputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
output = file
|
||||
}
|
||||
|
||||
// Encrypt the data
|
||||
encryptor, err := age.Encrypt(output, recipient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create age encryptor: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(encryptor, input); err != nil {
|
||||
return fmt.Errorf("failed to encrypt data: %w", err)
|
||||
}
|
||||
|
||||
if err := encryptor.Close(); err != nil {
|
||||
return fmt.Errorf("failed to finalize encryption: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts data using an age secret key stored in a secret
|
||||
func (cli *Instance) Decrypt(secretName, inputFile, outputFile string) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if secret exists
|
||||
secretObj := secret.NewSecret(vlt, secretName)
|
||||
exists, err := secretObj.Exists()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if secret exists: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("secret '%s' does not exist", secretName)
|
||||
}
|
||||
|
||||
// Get the age secret key from the secret
|
||||
var secretValue []byte
|
||||
if os.Getenv(secret.EnvMnemonic) != "" {
|
||||
secretValue, err = secretObj.GetValue(nil)
|
||||
} else {
|
||||
unlocker, unlockErr := vlt.GetCurrentUnlocker()
|
||||
if unlockErr != nil {
|
||||
return fmt.Errorf("failed to get current unlocker: %w", unlockErr)
|
||||
}
|
||||
secretValue, err = secretObj.GetValue(unlocker)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get secret value: %w", err)
|
||||
}
|
||||
|
||||
// Create secure buffer for the secret value
|
||||
secureBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||
defer secureBuffer.Destroy()
|
||||
|
||||
// Clear the original secret value
|
||||
for i := range secretValue {
|
||||
secretValue[i] = 0
|
||||
}
|
||||
|
||||
// Validate that it's a valid age secret key
|
||||
if !isValidAgeSecretKey(secureBuffer.String()) {
|
||||
return fmt.Errorf("secret '%s' does not contain a valid age secret key", secretName)
|
||||
}
|
||||
|
||||
// Parse the age secret key to get the identity
|
||||
identity, err := age.ParseX25519Identity(secureBuffer.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse age secret key: %w", err)
|
||||
}
|
||||
|
||||
// Set up input reader
|
||||
var input io.Reader = os.Stdin
|
||||
if inputFile != "" {
|
||||
file, err := cli.fs.Open(inputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open input file: %w", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
input = file
|
||||
}
|
||||
|
||||
// Set up output writer
|
||||
output := cli.cmd.OutOrStdout()
|
||||
if outputFile != "" {
|
||||
file, err := cli.fs.Create(outputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
output = file
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
decryptor, err := age.Decrypt(input, identity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create age decryptor: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(output, decryptor); err != nil {
|
||||
return fmt.Errorf("failed to decrypt data: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidAgeSecretKey checks if a string is a valid age secret key by attempting to parse it
|
||||
func isValidAgeSecretKey(key string) bool {
|
||||
_, err := age.ParseX25519Identity(key)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// getSecretValue retrieves the value of a secret using the appropriate unlocker
|
||||
func (cli *Instance) getSecretValue(vlt *vault.Vault, secretObj *secret.Secret) ([]byte, error) {
|
||||
if os.Getenv(secret.EnvMnemonic) != "" {
|
||||
return secretObj.GetValue(nil)
|
||||
}
|
||||
|
||||
unlocker, err := vlt.GetCurrentUnlocker()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current unlocker: %w", err)
|
||||
}
|
||||
|
||||
return secretObj.GetValue(unlocker)
|
||||
}
|
||||
185
internal/cli/generate.go
Normal file
185
internal/cli/generate.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/tyler-smith/go-bip39"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSecretLength = 16
|
||||
mnemonicEntropyBits = 128
|
||||
)
|
||||
|
||||
func newGenerateCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "generate",
|
||||
Short: "Generate random data",
|
||||
Long: `Generate various types of random data including mnemonics and secrets.`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(newGenerateMnemonicCmd())
|
||||
cmd.AddCommand(newGenerateSecretCmd())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newGenerateMnemonicCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "mnemonic",
|
||||
Short: "Generate a random BIP39 mnemonic phrase",
|
||||
Long: `Generate a cryptographically secure random BIP39 ` +
|
||||
`mnemonic phrase that can be used with 'secret init' ` +
|
||||
`or 'secret import'.`,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.GenerateMnemonic(cmd)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newGenerateSecretCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "secret <name>",
|
||||
Short: "Generate a random secret and store it in the vault",
|
||||
Long: `Generate a cryptographically secure random secret and store it in the current vault under the given name.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
length, _ := cmd.Flags().GetInt("length")
|
||||
secretType, _ := cmd.Flags().GetString("type")
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.GenerateSecret(cmd, args[0], length, secretType, force)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().IntP("length", "l", defaultSecretLength, "Length of the generated secret (default 16)")
|
||||
cmd.Flags().StringP("type", "t", "base58", "Type of secret to generate (base58, alnum)")
|
||||
cmd.Flags().BoolP("force", "f", false, "Overwrite existing secret")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// GenerateMnemonic generates a random BIP39 mnemonic phrase
|
||||
func (cli *Instance) GenerateMnemonic(cmd *cobra.Command) error {
|
||||
// Generate 128 bits of entropy for a 12-word mnemonic
|
||||
entropy, err := bip39.NewEntropy(mnemonicEntropyBits)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate entropy: %w", err)
|
||||
}
|
||||
|
||||
// Create mnemonic from entropy
|
||||
mnemonic, err := bip39.NewMnemonic(entropy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate mnemonic: %w", err)
|
||||
}
|
||||
|
||||
// Output mnemonic to stdout
|
||||
cmd.Println(mnemonic)
|
||||
|
||||
// Output helpful information to stderr
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "⚠️ IMPORTANT: Save this mnemonic phrase securely!")
|
||||
fmt.Fprintln(os.Stderr, " • Write it down on paper and store it safely")
|
||||
fmt.Fprintln(os.Stderr, " • Do not store it digitally or share it with anyone")
|
||||
fmt.Fprintln(os.Stderr, " • You will need this phrase to recover your secrets")
|
||||
fmt.Fprintln(os.Stderr, " • If you lose this phrase, your secrets cannot be recovered")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use this mnemonic with:")
|
||||
fmt.Fprintln(os.Stderr, " secret init (to initialize a new secret manager)")
|
||||
fmt.Fprintln(os.Stderr, " secret import (to import into an existing vault)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSecret generates a random secret and stores it in the vault
|
||||
func (cli *Instance) GenerateSecret(
|
||||
cmd *cobra.Command,
|
||||
secretName string,
|
||||
length int,
|
||||
secretType string,
|
||||
force bool,
|
||||
) error {
|
||||
if length < 1 {
|
||||
return fmt.Errorf("length must be at least 1")
|
||||
}
|
||||
|
||||
var secretValue string
|
||||
var err error
|
||||
|
||||
switch secretType {
|
||||
case "base58":
|
||||
secretValue, err = generateRandomBase58(length)
|
||||
case "alnum":
|
||||
secretValue, err = generateRandomAlnum(length)
|
||||
case "mnemonic":
|
||||
return fmt.Errorf("mnemonic type not supported for secret generation, use 'secret generate mnemonic' instead")
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %s (supported: base58, alnum)", secretType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate random secret: %w", err)
|
||||
}
|
||||
|
||||
// Store the secret in the vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Protect the generated secret immediately
|
||||
secretBuffer := memguard.NewBufferFromBytes([]byte(secretValue))
|
||||
defer secretBuffer.Destroy()
|
||||
|
||||
if err := vlt.AddSecret(secretName, secretBuffer, force); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Generated and stored %d-character %s secret: %s\n", length, secretType, secretName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateRandomBase58 generates a random base58 string of the specified length
|
||||
func generateRandomBase58(length int) (string, error) {
|
||||
const base58Chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
|
||||
|
||||
return generateRandomString(length, base58Chars)
|
||||
}
|
||||
|
||||
// generateRandomAlnum generates a random alphanumeric string of the specified length
|
||||
func generateRandomAlnum(length int) (string, error) {
|
||||
const alnumChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
return generateRandomString(length, alnumChars)
|
||||
}
|
||||
|
||||
// generateRandomString generates a random string of the specified length using the given character set
|
||||
func generateRandomString(length int, charset string) (string, error) {
|
||||
if length <= 0 {
|
||||
return "", fmt.Errorf("length must be positive")
|
||||
}
|
||||
|
||||
result := make([]byte, length)
|
||||
charsetLen := big.NewInt(int64(len(charset)))
|
||||
|
||||
for i := range length {
|
||||
randomIndex, err := rand.Int(rand.Reader, charsetLen)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate random number: %w", err)
|
||||
}
|
||||
result[i] = charset[randomIndex.Int64()]
|
||||
}
|
||||
|
||||
return string(result), nil
|
||||
}
|
||||
221
internal/cli/init.go
Normal file
221
internal/cli/init.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/tyler-smith/go-bip39"
|
||||
)
|
||||
|
||||
// NewInitCmd creates the init command
|
||||
func NewInitCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Initialize the secrets manager",
|
||||
Long: `Create the necessary directory structure for storing secrets and generate encryption keys.`,
|
||||
RunE: RunInit,
|
||||
}
|
||||
}
|
||||
|
||||
// RunInit is the exported function that handles the init command
|
||||
func RunInit(cmd *cobra.Command, _ []string) error {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.Init(cmd)
|
||||
}
|
||||
|
||||
// Init initializes the secret manager
|
||||
func (cli *Instance) Init(cmd *cobra.Command) error {
|
||||
secret.Debug("Starting secret manager initialization")
|
||||
|
||||
// Create state directory
|
||||
stateDir := cli.GetStateDir()
|
||||
secret.DebugWith("Creating state directory", slog.String("path", stateDir))
|
||||
|
||||
if err := cli.fs.MkdirAll(stateDir, secret.DirPerms); err != nil {
|
||||
secret.Debug("Failed to create state directory", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to create state directory: %w", err)
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
cmd.Printf("Initialized secrets manager at: %s\n", stateDir)
|
||||
}
|
||||
|
||||
// Prompt for mnemonic
|
||||
var mnemonicStr string
|
||||
|
||||
if envMnemonic := os.Getenv(secret.EnvMnemonic); envMnemonic != "" {
|
||||
secret.Debug("Using mnemonic from environment variable")
|
||||
mnemonicStr = envMnemonic
|
||||
} else {
|
||||
secret.Debug("Prompting user for mnemonic phrase")
|
||||
// Read mnemonic from stdin using shared line reader
|
||||
var err error
|
||||
mnemonicStr, err = readLineFromStdin("Enter your BIP39 mnemonic phrase: ")
|
||||
if err != nil {
|
||||
secret.Debug("Failed to read mnemonic from stdin", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to read mnemonic: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if mnemonicStr == "" {
|
||||
secret.Debug("Empty mnemonic provided")
|
||||
|
||||
return fmt.Errorf("mnemonic cannot be empty")
|
||||
}
|
||||
|
||||
// Validate the mnemonic using BIP39
|
||||
secret.DebugWith("Validating BIP39 mnemonic", slog.Int("word_count", len(strings.Fields(mnemonicStr))))
|
||||
if !bip39.IsMnemonicValid(mnemonicStr) {
|
||||
secret.Debug("Invalid BIP39 mnemonic provided")
|
||||
|
||||
return fmt.Errorf("invalid BIP39 mnemonic phrase\nRun 'secret generate mnemonic' to create a valid mnemonic")
|
||||
}
|
||||
|
||||
// Set mnemonic in environment for CreateVault to use
|
||||
originalMnemonic := os.Getenv(secret.EnvMnemonic)
|
||||
_ = os.Setenv(secret.EnvMnemonic, mnemonicStr)
|
||||
defer func() {
|
||||
if originalMnemonic != "" {
|
||||
_ = os.Setenv(secret.EnvMnemonic, originalMnemonic)
|
||||
} else {
|
||||
_ = os.Unsetenv(secret.EnvMnemonic)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create the default vault - it will handle key derivation internally
|
||||
secret.Debug("Creating default vault")
|
||||
vlt, err := vault.CreateVault(cli.fs, cli.stateDir, "default")
|
||||
if err != nil {
|
||||
secret.Debug("Failed to create default vault", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to create default vault: %w", err)
|
||||
}
|
||||
|
||||
// Get the vault metadata to retrieve the derivation index
|
||||
vaultDir := filepath.Join(stateDir, "vaults.d", "default")
|
||||
metadata, err := vault.LoadVaultMetadata(cli.fs, vaultDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to load vault metadata", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to load vault metadata: %w", err)
|
||||
}
|
||||
|
||||
// Derive the long-term key using the same index that CreateVault used
|
||||
ltIdentity, err := agehd.DeriveIdentity(mnemonicStr, metadata.DerivationIndex)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to derive long-term key", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
|
||||
}
|
||||
ltPubKey := ltIdentity.Recipient().String()
|
||||
|
||||
// Unlock the vault with the derived long-term key
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Prompt for passphrase for unlocker
|
||||
var passphraseBuffer *memguard.LockedBuffer
|
||||
if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" {
|
||||
secret.Debug("Using unlock passphrase from environment variable")
|
||||
passphraseBuffer = memguard.NewBufferFromBytes([]byte(envPassphrase))
|
||||
} else {
|
||||
secret.Debug("Prompting user for unlock passphrase")
|
||||
// Use secure passphrase input with confirmation
|
||||
passphraseBuffer, err = readSecurePassphrase("Enter passphrase for unlocker: ")
|
||||
if err != nil {
|
||||
secret.Debug("Failed to read unlock passphrase", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to read passphrase: %w", err)
|
||||
}
|
||||
}
|
||||
defer passphraseBuffer.Destroy()
|
||||
|
||||
// Create passphrase-protected unlocker
|
||||
secret.Debug("Creating passphrase-protected unlocker")
|
||||
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to create unlocker", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to create unlocker: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt long-term private key to the unlocker
|
||||
unlockerDir := passphraseUnlocker.GetDirectory()
|
||||
|
||||
// Read unlocker public key
|
||||
unlockerPubKeyData, err := afero.ReadFile(cli.fs, filepath.Join(unlockerDir, "pub.age"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read unlocker public key: %w", err)
|
||||
}
|
||||
|
||||
unlockerRecipient, err := age.ParseX25519Recipient(string(unlockerPubKeyData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse unlocker public key: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt long-term private key to unlocker
|
||||
// Use memguard to protect the private key in memory
|
||||
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
|
||||
defer ltPrivKeyBuffer.Destroy()
|
||||
|
||||
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerRecipient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt long-term private key: %w", err)
|
||||
}
|
||||
|
||||
// Write encrypted long-term private key
|
||||
ltPrivKeyPath := filepath.Join(unlockerDir, "longterm.age")
|
||||
if err := afero.WriteFile(cli.fs, ltPrivKeyPath, encryptedLtPrivKey, secret.FilePerms); err != nil {
|
||||
return fmt.Errorf("failed to write encrypted long-term private key: %w", err)
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
cmd.Printf("\nDefault vault created and configured\n")
|
||||
cmd.Printf("Long-term public key: %s\n", ltPubKey)
|
||||
cmd.Printf("Unlocker ID: %s\n", passphraseUnlocker.GetID())
|
||||
cmd.Println("\nYour secret manager is ready to use!")
|
||||
cmd.Println("Note: When using SB_SECRET_MNEMONIC environment variable,")
|
||||
cmd.Println("unlockers are not required for secret operations.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readSecurePassphrase reads a passphrase securely from the terminal without echoing
|
||||
// This version adds confirmation (read twice) for creating new unlockers
|
||||
// Returns a LockedBuffer containing the passphrase
|
||||
func readSecurePassphrase(prompt string) (*memguard.LockedBuffer, error) {
|
||||
// Get the first passphrase
|
||||
passphraseBuffer1, err := secret.ReadPassphrase(prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer passphraseBuffer1.Destroy()
|
||||
|
||||
// Read confirmation passphrase
|
||||
passphraseBuffer2, err := secret.ReadPassphrase("Confirm passphrase: ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read passphrase confirmation: %w", err)
|
||||
}
|
||||
defer passphraseBuffer2.Destroy()
|
||||
|
||||
// Compare passphrases
|
||||
if passphraseBuffer1.String() != passphraseBuffer2.String() {
|
||||
return nil, fmt.Errorf("passphrases do not match")
|
||||
}
|
||||
|
||||
// Create a new buffer with the confirmed passphrase
|
||||
return memguard.NewBufferFromBytes(passphraseBuffer1.Bytes()), nil
|
||||
}
|
||||
2148
internal/cli/integration_test.go
Normal file
2148
internal/cli/integration_test.go
Normal file
File diff suppressed because it is too large
Load Diff
47
internal/cli/root.go
Normal file
47
internal/cli/root.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Entry is the entry point for the secret CLI application
|
||||
func Entry() {
|
||||
cmd := newRootCmd()
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func newRootCmd() *cobra.Command {
|
||||
secret.Debug("newRootCmd starting")
|
||||
cmd := &cobra.Command{
|
||||
Use: "secret",
|
||||
Short: "A simple secrets manager",
|
||||
Long: `A simple secrets manager to store and retrieve sensitive information securely.`,
|
||||
// Ensure usage is shown after errors
|
||||
SilenceUsage: false,
|
||||
SilenceErrors: false,
|
||||
}
|
||||
|
||||
secret.Debug("Adding subcommands to root command")
|
||||
// Add subcommands
|
||||
cmd.AddCommand(NewInitCmd())
|
||||
cmd.AddCommand(newGenerateCmd())
|
||||
cmd.AddCommand(newVaultCmd())
|
||||
cmd.AddCommand(newAddCmd())
|
||||
cmd.AddCommand(newGetCmd())
|
||||
cmd.AddCommand(newListCmd())
|
||||
cmd.AddCommand(newUnlockersCmd())
|
||||
cmd.AddCommand(newUnlockerCmd())
|
||||
cmd.AddCommand(newImportCmd())
|
||||
cmd.AddCommand(newEncryptCmd())
|
||||
cmd.AddCommand(newDecryptCmd())
|
||||
cmd.AddCommand(newVersionCmd())
|
||||
|
||||
secret.Debug("newRootCmd completed")
|
||||
|
||||
return cmd
|
||||
}
|
||||
450
internal/cli/secrets.go
Normal file
450
internal/cli/secrets.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newAddCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "add <secret-name>",
|
||||
Short: "Add a secret to the vault",
|
||||
Long: `Add a secret to the current vault. The secret value is read from stdin.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
secret.Debug("Add command RunE starting", "secret_name", args[0])
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
secret.Debug("Got force flag", "force", force)
|
||||
|
||||
cli := NewCLIInstance()
|
||||
cli.cmd = cmd // Set the command for stdin access
|
||||
secret.Debug("Created CLI instance, calling AddSecret")
|
||||
|
||||
return cli.AddSecret(args[0], force)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolP("force", "f", false, "Overwrite existing secret")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newGetCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "get <secret-name>",
|
||||
Short: "Retrieve a secret from the vault",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
version, _ := cmd.Flags().GetString("version")
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.GetSecretWithVersion(cmd, args[0], version)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringP("version", "v", "", "Get a specific version (default: current)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newListCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "list [filter]",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all secrets in the current vault",
|
||||
Long: `List all secrets in the current vault. Optionally filter by substring match in secret name.`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
jsonOutput, _ := cmd.Flags().GetBool("json")
|
||||
|
||||
var filter string
|
||||
if len(args) > 0 {
|
||||
filter = args[0]
|
||||
}
|
||||
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.ListSecrets(cmd, jsonOutput, filter)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("json", false, "Output in JSON format")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newImportCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "import <secret-name>",
|
||||
Short: "Import a secret from a file",
|
||||
Long: `Import a secret from a file and store it in the current vault under the given name.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
sourceFile, _ := cmd.Flags().GetString("source")
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.ImportSecret(cmd, args[0], sourceFile, force)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringP("source", "s", "", "Source file to import from (required)")
|
||||
cmd.Flags().BoolP("force", "f", false, "Overwrite existing secret")
|
||||
_ = cmd.MarkFlagRequired("source")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// updateBufferSize updates the buffer size based on usage pattern
|
||||
func updateBufferSize(currentSize int, sameSize *int) int {
|
||||
*sameSize++
|
||||
const doubleAfterBuffers = 2
|
||||
const growthFactor = 2
|
||||
if *sameSize >= doubleAfterBuffers {
|
||||
*sameSize = 0
|
||||
|
||||
return currentSize * growthFactor
|
||||
}
|
||||
|
||||
return currentSize
|
||||
}
|
||||
|
||||
// AddSecret adds a secret to the current vault
|
||||
func (cli *Instance) AddSecret(secretName string, force bool) error {
|
||||
secret.Debug("CLI AddSecret starting", "secret_name", secretName, "force", force)
|
||||
|
||||
// Get current vault
|
||||
secret.Debug("Getting current vault")
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
secret.Debug("Got current vault", "vault_name", vlt.GetName())
|
||||
|
||||
// Read secret value directly into protected buffers
|
||||
secret.Debug("Reading secret value from stdin into protected buffers")
|
||||
|
||||
const initialSize = 4 * 1024 // 4KB initial buffer
|
||||
const maxSize = 100 * 1024 * 1024 // 100MB max
|
||||
|
||||
type bufferInfo struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
used int
|
||||
}
|
||||
|
||||
var buffers []bufferInfo
|
||||
defer func() {
|
||||
for _, b := range buffers {
|
||||
b.buffer.Destroy()
|
||||
}
|
||||
}()
|
||||
|
||||
reader := cli.cmd.InOrStdin()
|
||||
totalSize := 0
|
||||
currentBufferSize := initialSize
|
||||
sameSize := 0
|
||||
|
||||
for {
|
||||
// Create a new buffer
|
||||
buffer := memguard.NewBuffer(currentBufferSize)
|
||||
n, err := io.ReadFull(reader, buffer.Bytes())
|
||||
|
||||
if n == 0 {
|
||||
// No data read, destroy the unused buffer
|
||||
buffer.Destroy()
|
||||
} else {
|
||||
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
|
||||
totalSize += n
|
||||
|
||||
if totalSize > maxSize {
|
||||
return fmt.Errorf("secret too large: exceeds 100MB limit")
|
||||
}
|
||||
|
||||
// If we filled the buffer, consider growing for next iteration
|
||||
if n == currentBufferSize {
|
||||
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to read secret value: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for trailing newline in the last buffer
|
||||
if len(buffers) > 0 && totalSize > 0 {
|
||||
lastBuffer := &buffers[len(buffers)-1]
|
||||
if lastBuffer.buffer.Bytes()[lastBuffer.used-1] == '\n' {
|
||||
lastBuffer.used--
|
||||
totalSize--
|
||||
}
|
||||
}
|
||||
|
||||
secret.Debug("Read secret value from stdin", "value_length", totalSize, "buffers", len(buffers))
|
||||
|
||||
// Combine all buffers into a single protected buffer
|
||||
valueBuffer := memguard.NewBuffer(totalSize)
|
||||
defer valueBuffer.Destroy()
|
||||
|
||||
offset := 0
|
||||
for _, b := range buffers {
|
||||
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
|
||||
offset += b.used
|
||||
}
|
||||
|
||||
// Add the secret to the vault
|
||||
secret.Debug("Calling vault.AddSecret", "secret_name", secretName, "value_length", valueBuffer.Size(), "force", force)
|
||||
if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
|
||||
secret.Debug("vault.AddSecret failed", "error", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
secret.Debug("vault.AddSecret completed successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSecret retrieves and prints a secret from the current vault
|
||||
func (cli *Instance) GetSecret(cmd *cobra.Command, secretName string) error {
|
||||
return cli.GetSecretWithVersion(cmd, secretName, "")
|
||||
}
|
||||
|
||||
// GetSecretWithVersion retrieves and prints a specific version of a secret
|
||||
func (cli *Instance) GetSecretWithVersion(cmd *cobra.Command, secretName string, version string) error {
|
||||
secret.Debug("GetSecretWithVersion called", "secretName", secretName, "version", version)
|
||||
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get current vault", "error", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the secret value
|
||||
var value []byte
|
||||
if version == "" {
|
||||
value, err = vlt.GetSecret(secretName)
|
||||
} else {
|
||||
value, err = vlt.GetSecretVersion(secretName, version)
|
||||
}
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get secret", "error", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
secret.Debug("Got secret value", "valueLength", len(value))
|
||||
|
||||
// Print the secret value to stdout
|
||||
cmd.Print(string(value))
|
||||
secret.Debug("Printed value to cmd")
|
||||
|
||||
// Debug: Log what we're actually printing
|
||||
secret.Debug("Secret retrieval debug info",
|
||||
"secretName", secretName,
|
||||
"version", version,
|
||||
"valueLength", len(value),
|
||||
"valueAsString", string(value),
|
||||
"isEmpty", len(value) == 0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListSecrets lists all secrets in the current vault
|
||||
func (cli *Instance) ListSecrets(cmd *cobra.Command, jsonOutput bool, filter string) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get list of secrets
|
||||
secrets, err := vlt.ListSecrets()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list secrets: %w", err)
|
||||
}
|
||||
|
||||
// Filter secrets if filter is provided
|
||||
var filteredSecrets []string
|
||||
if filter != "" {
|
||||
for _, secretName := range secrets {
|
||||
if strings.Contains(secretName, filter) {
|
||||
filteredSecrets = append(filteredSecrets, secretName)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
filteredSecrets = secrets
|
||||
}
|
||||
|
||||
if jsonOutput { //nolint:nestif // Separate JSON and table output formatting logic
|
||||
// For JSON output, get metadata for each secret
|
||||
secretsWithMetadata := make([]map[string]interface{}, 0, len(filteredSecrets))
|
||||
|
||||
for _, secretName := range filteredSecrets {
|
||||
secretInfo := map[string]interface{}{
|
||||
"name": secretName,
|
||||
}
|
||||
|
||||
// Try to get metadata using GetSecretObject
|
||||
if secretObj, err := vlt.GetSecretObject(secretName); err == nil {
|
||||
metadata := secretObj.GetMetadata()
|
||||
secretInfo["created_at"] = metadata.CreatedAt
|
||||
secretInfo["updated_at"] = metadata.UpdatedAt
|
||||
}
|
||||
|
||||
secretsWithMetadata = append(secretsWithMetadata, secretInfo)
|
||||
}
|
||||
|
||||
output := map[string]interface{}{
|
||||
"secrets": secretsWithMetadata,
|
||||
}
|
||||
if filter != "" {
|
||||
output["filter"] = filter
|
||||
}
|
||||
|
||||
jsonBytes, err := json.MarshalIndent(output, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal JSON: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println(string(jsonBytes))
|
||||
} else {
|
||||
// Pretty table output
|
||||
if len(filteredSecrets) == 0 {
|
||||
if filter != "" {
|
||||
cmd.Printf("No secrets found in vault '%s' matching filter '%s'.\n", vlt.GetName(), filter)
|
||||
} else {
|
||||
cmd.Println("No secrets found in current vault.")
|
||||
cmd.Println("Run 'secret add <name>' to create one.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current vault name for display
|
||||
if filter != "" {
|
||||
cmd.Printf("Secrets in vault '%s' matching '%s':\n\n", vlt.GetName(), filter)
|
||||
} else {
|
||||
cmd.Printf("Secrets in vault '%s':\n\n", vlt.GetName())
|
||||
}
|
||||
cmd.Printf("%-40s %-20s\n", "NAME", "LAST UPDATED")
|
||||
cmd.Printf("%-40s %-20s\n", "----", "------------")
|
||||
|
||||
for _, secretName := range filteredSecrets {
|
||||
lastUpdated := "unknown"
|
||||
if secretObj, err := vlt.GetSecretObject(secretName); err == nil {
|
||||
metadata := secretObj.GetMetadata()
|
||||
lastUpdated = metadata.UpdatedAt.Format("2006-01-02 15:04")
|
||||
}
|
||||
cmd.Printf("%-40s %-20s\n", secretName, lastUpdated)
|
||||
}
|
||||
|
||||
cmd.Printf("\nTotal: %d secret(s)", len(filteredSecrets))
|
||||
if filter != "" {
|
||||
cmd.Printf(" (filtered from %d)", len(secrets))
|
||||
}
|
||||
cmd.Println()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ImportSecret imports a secret from a file
|
||||
func (cli *Instance) ImportSecret(cmd *cobra.Command, secretName, sourceFile string, force bool) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read secret value from the source file into protected buffers
|
||||
file, err := cli.fs.Open(sourceFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file %s: %w", sourceFile, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
secret.Debug("Failed to close file", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
const initialSize = 4 * 1024 // 4KB initial buffer
|
||||
const maxSize = 100 * 1024 * 1024 // 100MB max
|
||||
|
||||
type bufferInfo struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
used int
|
||||
}
|
||||
|
||||
var buffers []bufferInfo
|
||||
defer func() {
|
||||
for _, b := range buffers {
|
||||
b.buffer.Destroy()
|
||||
}
|
||||
}()
|
||||
|
||||
totalSize := 0
|
||||
currentBufferSize := initialSize
|
||||
sameSize := 0
|
||||
|
||||
for {
|
||||
// Create a new buffer
|
||||
buffer := memguard.NewBuffer(currentBufferSize)
|
||||
n, err := io.ReadFull(file, buffer.Bytes())
|
||||
|
||||
if n == 0 {
|
||||
// No data read, destroy the unused buffer
|
||||
buffer.Destroy()
|
||||
} else {
|
||||
buffers = append(buffers, bufferInfo{buffer: buffer, used: n})
|
||||
totalSize += n
|
||||
|
||||
if totalSize > maxSize {
|
||||
return fmt.Errorf("secret file too large: exceeds 100MB limit")
|
||||
}
|
||||
|
||||
// If we filled the buffer, consider growing for next iteration
|
||||
if n == currentBufferSize {
|
||||
currentBufferSize = updateBufferSize(currentBufferSize, &sameSize)
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to read secret from file %s: %w", sourceFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Combine all buffers into a single protected buffer
|
||||
valueBuffer := memguard.NewBuffer(totalSize)
|
||||
defer valueBuffer.Destroy()
|
||||
|
||||
offset := 0
|
||||
for _, b := range buffers {
|
||||
copy(valueBuffer.Bytes()[offset:], b.buffer.Bytes()[:b.used])
|
||||
offset += b.used
|
||||
}
|
||||
|
||||
// Store the secret in the vault
|
||||
if err := vlt.AddSecret(secretName, valueBuffer, force); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Successfully imported secret '%s' from file '%s'\n", secretName, sourceFile)
|
||||
|
||||
return nil
|
||||
}
|
||||
425
internal/cli/secrets_size_test.go
Normal file
425
internal/cli/secrets_size_test.go
Normal file
@@ -0,0 +1,425 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAddSecretVariousSizes tests adding secrets of various sizes through stdin
|
||||
func TestAddSecretVariousSizes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
size int
|
||||
shouldError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "1KB secret",
|
||||
size: 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "10KB secret",
|
||||
size: 10 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "100KB secret",
|
||||
size: 100 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "1MB secret",
|
||||
size: 1024 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "10MB secret",
|
||||
size: 10 * 1024 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "99MB secret",
|
||||
size: 99 * 1024 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "100MB secret minus 1 byte",
|
||||
size: 100*1024*1024 - 1,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "101MB secret - should fail",
|
||||
size: 101 * 1024 * 1024,
|
||||
shouldError: true,
|
||||
errorMsg: "secret too large: exceeds 100MB limit",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up test environment
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Set test mnemonic
|
||||
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||
|
||||
// Create vault
|
||||
vaultName := "test-vault"
|
||||
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set current vault
|
||||
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get vault and set up long-term key
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||
require.NoError(t, err)
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Generate test data of specified size
|
||||
testData := make([]byte, tt.size)
|
||||
_, err = rand.Read(testData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add newline that will be stripped
|
||||
testDataWithNewline := append(testData, '\n')
|
||||
|
||||
// Create fake stdin
|
||||
stdin := bytes.NewReader(testDataWithNewline)
|
||||
|
||||
// Create command with fake stdin
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetIn(stdin)
|
||||
|
||||
// Create CLI instance
|
||||
cli := NewCLIInstance()
|
||||
cli.fs = fs
|
||||
cli.stateDir = stateDir
|
||||
cli.cmd = cmd
|
||||
|
||||
// Test adding the secret
|
||||
secretName := fmt.Sprintf("test-secret-%d", tt.size)
|
||||
err = cli.AddSecret(secretName, false)
|
||||
|
||||
if tt.shouldError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the secret was stored correctly
|
||||
retrievedValue, err := vlt.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original (without newline)")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestImportSecretVariousSizes tests importing secrets of various sizes from files
|
||||
func TestImportSecretVariousSizes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
size int
|
||||
shouldError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "1KB file",
|
||||
size: 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "10KB file",
|
||||
size: 10 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "100KB file",
|
||||
size: 100 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "1MB file",
|
||||
size: 1024 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "10MB file",
|
||||
size: 10 * 1024 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "99MB file",
|
||||
size: 99 * 1024 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "100MB file",
|
||||
size: 100 * 1024 * 1024,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "101MB file - should fail",
|
||||
size: 101 * 1024 * 1024,
|
||||
shouldError: true,
|
||||
errorMsg: "secret file too large: exceeds 100MB limit",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up test environment
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Set test mnemonic
|
||||
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||
|
||||
// Create vault
|
||||
vaultName := "test-vault"
|
||||
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set current vault
|
||||
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get vault and set up long-term key
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||
require.NoError(t, err)
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Generate test data of specified size
|
||||
testData := make([]byte, tt.size)
|
||||
_, err = rand.Read(testData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write test data to file
|
||||
testFile := fmt.Sprintf("/test/secret-%d.bin", tt.size)
|
||||
err = afero.WriteFile(fs, testFile, testData, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create command
|
||||
cmd := &cobra.Command{}
|
||||
|
||||
// Create CLI instance
|
||||
cli := NewCLIInstance()
|
||||
cli.fs = fs
|
||||
cli.stateDir = stateDir
|
||||
|
||||
// Test importing the secret
|
||||
secretName := fmt.Sprintf("imported-secret-%d", tt.size)
|
||||
err = cli.ImportSecret(cmd, secretName, testFile, false)
|
||||
|
||||
if tt.shouldError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the secret was stored correctly
|
||||
retrievedValue, err := vlt.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddSecretBufferGrowth tests that our buffer growth strategy works correctly
|
||||
func TestAddSecretBufferGrowth(t *testing.T) {
|
||||
// Test various sizes that should trigger buffer growth
|
||||
sizes := []int{
|
||||
1, // Single byte
|
||||
100, // Small
|
||||
4095, // Just under initial 4KB
|
||||
4096, // Exactly 4KB
|
||||
4097, // Just over 4KB
|
||||
8191, // Just under 8KB (first double)
|
||||
8192, // Exactly 8KB
|
||||
8193, // Just over 8KB
|
||||
12288, // 12KB (should trigger second double)
|
||||
16384, // 16KB
|
||||
32768, // 32KB (after more doublings)
|
||||
65536, // 64KB
|
||||
131072, // 128KB
|
||||
524288, // 512KB
|
||||
1048576, // 1MB
|
||||
2097152, // 2MB
|
||||
}
|
||||
|
||||
for _, size := range sizes {
|
||||
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
|
||||
// Set up test environment
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Set test mnemonic
|
||||
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||
|
||||
// Create vault
|
||||
vaultName := "test-vault"
|
||||
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set current vault
|
||||
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get vault and set up long-term key
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||
require.NoError(t, err)
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Create test data of exactly the specified size
|
||||
// Use a pattern that's easy to verify
|
||||
testData := make([]byte, size)
|
||||
for i := range testData {
|
||||
testData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
// Create fake stdin without newline
|
||||
stdin := bytes.NewReader(testData)
|
||||
|
||||
// Create command with fake stdin
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetIn(stdin)
|
||||
|
||||
// Create CLI instance
|
||||
cli := NewCLIInstance()
|
||||
cli.fs = fs
|
||||
cli.stateDir = stateDir
|
||||
cli.cmd = cmd
|
||||
|
||||
// Test adding the secret
|
||||
secretName := fmt.Sprintf("buffer-test-%d", size)
|
||||
err = cli.AddSecret(secretName, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the secret was stored correctly
|
||||
retrievedValue, err := vlt.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original exactly")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddSecretStreamingBehavior tests that we handle streaming input correctly
|
||||
func TestAddSecretStreamingBehavior(t *testing.T) {
|
||||
// Set up test environment
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Set test mnemonic
|
||||
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||
|
||||
// Create vault
|
||||
vaultName := "test-vault"
|
||||
_, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set current vault
|
||||
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||
vaultPath := filepath.Join(stateDir, "vaults.d", vaultName)
|
||||
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultPath), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get vault and set up long-term key
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
ltIdentity, err := agehd.DeriveIdentity("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", 0)
|
||||
require.NoError(t, err)
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Create a custom reader that simulates slow streaming input
|
||||
// This will help verify our buffer handling works correctly with partial reads
|
||||
testData := []byte(strings.Repeat("Hello, World! ", 1000)) // ~14KB
|
||||
slowReader := &slowReader{
|
||||
data: testData,
|
||||
chunkSize: 1000, // Read 1KB at a time
|
||||
}
|
||||
|
||||
// Create command with slow reader as stdin
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetIn(slowReader)
|
||||
|
||||
// Create CLI instance
|
||||
cli := NewCLIInstance()
|
||||
cli.fs = fs
|
||||
cli.stateDir = stateDir
|
||||
cli.cmd = cmd
|
||||
|
||||
// Test adding the secret
|
||||
err = cli.AddSecret("streaming-test", false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the secret was stored correctly
|
||||
retrievedValue, err := vlt.GetSecret("streaming-test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, retrievedValue, "Retrieved secret should match original")
|
||||
}
|
||||
|
||||
// slowReader simulates a reader that returns data in small chunks
|
||||
type slowReader struct {
|
||||
data []byte
|
||||
offset int
|
||||
chunkSize int
|
||||
}
|
||||
|
||||
func (r *slowReader) Read(p []byte) (n int, err error) {
|
||||
if r.offset >= len(r.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Read at most chunkSize bytes
|
||||
remaining := len(r.data) - r.offset
|
||||
toRead := r.chunkSize
|
||||
if toRead > remaining {
|
||||
toRead = remaining
|
||||
}
|
||||
if toRead > len(p) {
|
||||
toRead = len(p)
|
||||
}
|
||||
|
||||
n = copy(p, r.data[r.offset:r.offset+toRead])
|
||||
r.offset += n
|
||||
|
||||
if r.offset >= len(r.data) {
|
||||
err = io.EOF
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
63
internal/cli/test_helpers.go
Normal file
63
internal/cli/test_helpers.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
)
|
||||
|
||||
// ExecuteCommandInProcess executes a CLI command in-process for testing
|
||||
func ExecuteCommandInProcess(args []string, stdin string, env map[string]string) (string, error) {
|
||||
secret.Debug("ExecuteCommandInProcess called", "args", args)
|
||||
|
||||
// Save current environment
|
||||
savedEnv := make(map[string]string)
|
||||
for k := range env {
|
||||
savedEnv[k] = os.Getenv(k)
|
||||
}
|
||||
|
||||
// Set test environment
|
||||
for k, v := range env {
|
||||
_ = os.Setenv(k, v)
|
||||
}
|
||||
|
||||
// Create root command
|
||||
rootCmd := newRootCmd()
|
||||
|
||||
// Capture output
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
// Set stdin if provided
|
||||
if stdin != "" {
|
||||
rootCmd.SetIn(strings.NewReader(stdin))
|
||||
}
|
||||
|
||||
// Set args
|
||||
rootCmd.SetArgs(args)
|
||||
|
||||
// Execute command
|
||||
err := rootCmd.Execute()
|
||||
|
||||
output := buf.String()
|
||||
secret.Debug("Command execution completed", "error", err, "outputLength", len(output), "output", output)
|
||||
|
||||
// Add debug info for troubleshooting
|
||||
if len(output) == 0 && err == nil {
|
||||
secret.Debug("Warning: Command executed successfully but produced no output", "args", args)
|
||||
}
|
||||
|
||||
// Restore environment
|
||||
for k, v := range savedEnv {
|
||||
if v == "" {
|
||||
_ = os.Unsetenv(k)
|
||||
} else {
|
||||
_ = os.Setenv(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
return output, err
|
||||
}
|
||||
22
internal/cli/test_output_test.go
Normal file
22
internal/cli/test_output_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOutputCapture(t *testing.T) {
|
||||
// Test vault list command which we fixed
|
||||
output, err := ExecuteCommandInProcess([]string{"vault", "list"}, "", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, output, "Available vaults", "should capture vault list output")
|
||||
t.Logf("vault list output: %q", output)
|
||||
|
||||
// Test help command
|
||||
output, err = ExecuteCommandInProcess([]string{"--help"}, "", nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, output, "help output should not be empty")
|
||||
t.Logf("help output length: %d", len(output))
|
||||
}
|
||||
338
internal/cli/unlockers.go
Normal file
338
internal/cli/unlockers.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Import from init.go
|
||||
|
||||
// ... existing imports ...
|
||||
|
||||
func newUnlockersCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "unlockers",
|
||||
Short: "Manage unlockers",
|
||||
Long: `Create, list, and remove unlockers for the current vault.`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(newUnlockersListCmd())
|
||||
cmd.AddCommand(newUnlockersAddCmd())
|
||||
cmd.AddCommand(newUnlockersRmCmd())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newUnlockersListCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List unlockers in the current vault",
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
jsonOutput, _ := cmd.Flags().GetBool("json")
|
||||
|
||||
cli := NewCLIInstance()
|
||||
cli.cmd = cmd
|
||||
|
||||
return cli.UnlockersList(jsonOutput)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("json", false, "Output in JSON format")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newUnlockersAddCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "add <type>",
|
||||
Short: "Add a new unlocker",
|
||||
Long: `Add a new unlocker of the specified type (passphrase, keychain, pgp).`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.UnlockersAdd(args[0], cmd)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().String("keyid", "", "GPG key ID for PGP unlockers")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newUnlockersRmCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "rm <unlocker-id>",
|
||||
Short: "Remove an unlocker",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(_ *cobra.Command, args []string) error {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.UnlockersRemove(args[0])
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newUnlockerCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "unlocker",
|
||||
Short: "Manage current unlocker",
|
||||
Long: `Select the current unlocker for operations.`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(newUnlockerSelectSubCmd())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newUnlockerSelectSubCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "select <unlocker-id>",
|
||||
Short: "Select an unlocker as current",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(_ *cobra.Command, args []string) error {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.UnlockerSelect(args[0])
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UnlockersList lists unlockers in the current vault
|
||||
func (cli *Instance) UnlockersList(jsonOutput bool) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the metadata first
|
||||
unlockerMetadataList, err := vlt.ListUnlockers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load actual unlocker objects to get the proper IDs
|
||||
type UnlockerInfo struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Flags []string `json:"flags,omitempty"`
|
||||
}
|
||||
|
||||
var unlockers []UnlockerInfo
|
||||
for _, metadata := range unlockerMetadataList {
|
||||
// Create unlocker instance to get the proper ID
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the unlocker directory by type and created time
|
||||
unlockersDir := filepath.Join(vaultDir, "unlockers.d")
|
||||
files, err := afero.ReadDir(cli.fs, unlockersDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var unlocker secret.Unlocker
|
||||
for _, file := range files {
|
||||
if !file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
unlockerDir := filepath.Join(unlockersDir, file.Name())
|
||||
metadataPath := filepath.Join(unlockerDir, "unlocker-metadata.json")
|
||||
|
||||
// Check if this is the right unlocker by comparing metadata
|
||||
metadataBytes, err := afero.ReadFile(cli.fs, metadataPath)
|
||||
if err != nil {
|
||||
continue // FIXME this error needs to be handled
|
||||
}
|
||||
|
||||
var diskMetadata secret.UnlockerMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &diskMetadata); err != nil {
|
||||
continue // FIXME this error needs to be handled
|
||||
}
|
||||
|
||||
// Match by type and creation time
|
||||
if diskMetadata.Type == metadata.Type && diskMetadata.CreatedAt.Equal(metadata.CreatedAt) {
|
||||
// Create the appropriate unlocker instance
|
||||
switch metadata.Type {
|
||||
case "passphrase":
|
||||
unlocker = secret.NewPassphraseUnlocker(cli.fs, unlockerDir, diskMetadata)
|
||||
case "keychain":
|
||||
unlocker = secret.NewKeychainUnlocker(cli.fs, unlockerDir, diskMetadata)
|
||||
case "pgp":
|
||||
unlocker = secret.NewPGPUnlocker(cli.fs, unlockerDir, diskMetadata)
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Get the proper ID using the unlocker's ID() method
|
||||
var properID string
|
||||
if unlocker != nil {
|
||||
properID = unlocker.GetID()
|
||||
} else {
|
||||
// Generate ID as fallback
|
||||
properID = fmt.Sprintf("%s-%s", metadata.CreatedAt.Format("2006-01-02.15.04"), metadata.Type)
|
||||
}
|
||||
|
||||
unlockerInfo := UnlockerInfo{
|
||||
ID: properID,
|
||||
Type: metadata.Type,
|
||||
CreatedAt: metadata.CreatedAt,
|
||||
Flags: metadata.Flags,
|
||||
}
|
||||
unlockers = append(unlockers, unlockerInfo)
|
||||
}
|
||||
|
||||
if jsonOutput {
|
||||
// JSON output
|
||||
output := map[string]interface{}{
|
||||
"unlockers": unlockers,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.MarshalIndent(output, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal JSON: %w", err)
|
||||
}
|
||||
|
||||
cli.cmd.Println(string(jsonBytes))
|
||||
} else {
|
||||
// Pretty table output
|
||||
if len(unlockers) == 0 {
|
||||
cli.cmd.Println("No unlockers found in current vault.")
|
||||
cli.cmd.Println("Run 'secret unlockers add passphrase' to create one.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
cli.cmd.Printf("%-18s %-12s %-20s %s\n", "UNLOCKER ID", "TYPE", "CREATED", "FLAGS")
|
||||
cli.cmd.Printf("%-18s %-12s %-20s %s\n", "-----------", "----", "-------", "-----")
|
||||
|
||||
for _, unlocker := range unlockers {
|
||||
flags := ""
|
||||
if len(unlocker.Flags) > 0 {
|
||||
flags = strings.Join(unlocker.Flags, ",")
|
||||
}
|
||||
cli.cmd.Printf("%-18s %-12s %-20s %s\n",
|
||||
unlocker.ID,
|
||||
unlocker.Type,
|
||||
unlocker.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
flags)
|
||||
}
|
||||
|
||||
cli.cmd.Printf("\nTotal: %d unlocker(s)\n", len(unlockers))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnlockersAdd adds a new unlocker
|
||||
func (cli *Instance) UnlockersAdd(unlockerType string, cmd *cobra.Command) error {
|
||||
switch unlockerType {
|
||||
case "passphrase":
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current vault: %w", err)
|
||||
}
|
||||
|
||||
// For passphrase unlockers, we don't need the vault to be unlocked
|
||||
// The CreatePassphraseUnlocker method will handle getting the long-term key
|
||||
|
||||
// Check if passphrase is set in environment variable
|
||||
var passphraseBuffer *memguard.LockedBuffer
|
||||
if envPassphrase := os.Getenv(secret.EnvUnlockPassphrase); envPassphrase != "" {
|
||||
passphraseBuffer = memguard.NewBufferFromBytes([]byte(envPassphrase))
|
||||
} else {
|
||||
// Use secure passphrase input with confirmation
|
||||
passphraseBuffer, err = readSecurePassphrase("Enter passphrase for unlocker: ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read passphrase: %w", err)
|
||||
}
|
||||
}
|
||||
defer passphraseBuffer.Destroy()
|
||||
|
||||
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Created passphrase unlocker: %s\n", passphraseUnlocker.GetID())
|
||||
|
||||
return nil
|
||||
|
||||
case "keychain":
|
||||
keychainUnlocker, err := secret.CreateKeychainUnlocker(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create macOS Keychain unlocker: %w", err)
|
||||
}
|
||||
|
||||
cmd.Printf("Created macOS Keychain unlocker: %s\n", keychainUnlocker.GetID())
|
||||
if keyName, err := keychainUnlocker.GetKeychainItemName(); err == nil {
|
||||
cmd.Printf("Keychain Item Name: %s\n", keyName)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
case "pgp":
|
||||
// Get GPG key ID from flag or environment variable
|
||||
var gpgKeyID string
|
||||
if flagKeyID, _ := cmd.Flags().GetString("keyid"); flagKeyID != "" {
|
||||
gpgKeyID = flagKeyID
|
||||
} else if envKeyID := os.Getenv(secret.EnvGPGKeyID); envKeyID != "" {
|
||||
gpgKeyID = envKeyID
|
||||
} else {
|
||||
return fmt.Errorf("GPG key ID required: use --keyid flag or set SB_GPG_KEY_ID environment variable")
|
||||
}
|
||||
|
||||
pgpUnlocker, err := secret.CreatePGPUnlocker(cli.fs, cli.stateDir, gpgKeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Created PGP unlocker: %s\n", pgpUnlocker.GetID())
|
||||
cmd.Printf("GPG Key ID: %s\n", gpgKeyID)
|
||||
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported unlocker type: %s (supported: passphrase, keychain, pgp)", unlockerType)
|
||||
}
|
||||
}
|
||||
|
||||
// UnlockersRemove removes an unlocker
|
||||
func (cli *Instance) UnlockersRemove(unlockerID string) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return vlt.RemoveUnlocker(unlockerID)
|
||||
}
|
||||
|
||||
// UnlockerSelect selects an unlocker as current
|
||||
func (cli *Instance) UnlockerSelect(unlockerID string) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return vlt.SelectUnlocker(unlockerID)
|
||||
}
|
||||
297
internal/cli/vault.go
Normal file
297
internal/cli/vault.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/tyler-smith/go-bip39"
|
||||
)
|
||||
|
||||
func newVaultCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "vault",
|
||||
Short: "Manage vaults",
|
||||
Long: `Create, list, and select vaults for organizing secrets.`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(newVaultListCmd())
|
||||
cmd.AddCommand(newVaultCreateCmd())
|
||||
cmd.AddCommand(newVaultSelectCmd())
|
||||
cmd.AddCommand(newVaultImportCmd())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newVaultListCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List available vaults",
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
jsonOutput, _ := cmd.Flags().GetBool("json")
|
||||
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.ListVaults(cmd, jsonOutput)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("json", false, "Output in JSON format")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newVaultCreateCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "create <name>",
|
||||
Short: "Create a new vault",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.CreateVault(cmd, args[0])
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newVaultSelectCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "select <name>",
|
||||
Short: "Select a vault as current",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.SelectVault(cmd, args[0])
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newVaultImportCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "import <vault-name>",
|
||||
Short: "Import a mnemonic into a vault",
|
||||
Long: `Import a BIP39 mnemonic phrase into the specified vault (default if not specified).`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
vaultName := "default"
|
||||
if len(args) > 0 {
|
||||
vaultName = args[0]
|
||||
}
|
||||
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return cli.VaultImport(cmd, vaultName)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ListVaults lists all available vaults
|
||||
func (cli *Instance) ListVaults(cmd *cobra.Command, jsonOutput bool) error {
|
||||
vaults, err := vault.ListVaults(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if jsonOutput { //nolint:nestif // Separate JSON and text output formatting logic
|
||||
// Get current vault name for context
|
||||
currentVault := ""
|
||||
if currentVlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir); err == nil {
|
||||
currentVault = currentVlt.GetName()
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"vaults": vaults,
|
||||
"currentVault": currentVault,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.MarshalIndent(result, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Println(string(jsonBytes))
|
||||
} else {
|
||||
// Text output
|
||||
cmd.Println("Available vaults:")
|
||||
if len(vaults) == 0 {
|
||||
cmd.Println(" (none)")
|
||||
} else {
|
||||
// Try to get current vault for marking
|
||||
currentVault := ""
|
||||
if currentVlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir); err == nil {
|
||||
currentVault = currentVlt.GetName()
|
||||
}
|
||||
|
||||
for _, vaultName := range vaults {
|
||||
if vaultName == currentVault {
|
||||
cmd.Printf(" %s (current)\n", vaultName)
|
||||
} else {
|
||||
cmd.Printf(" %s\n", vaultName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateVault creates a new vault
|
||||
func (cli *Instance) CreateVault(cmd *cobra.Command, name string) error {
|
||||
secret.Debug("Creating new vault", "name", name, "state_dir", cli.stateDir)
|
||||
|
||||
vlt, err := vault.CreateVault(cli.fs, cli.stateDir, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Created vault '%s'\n", vlt.GetName())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SelectVault selects a vault as the current one
|
||||
func (cli *Instance) SelectVault(cmd *cobra.Command, name string) error {
|
||||
if err := vault.SelectVault(cli.fs, cli.stateDir, name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Selected vault '%s' as current\n", name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VaultImport imports a mnemonic into a specific vault
|
||||
func (cli *Instance) VaultImport(cmd *cobra.Command, vaultName string) error {
|
||||
secret.Debug("Importing mnemonic into vault", "vault_name", vaultName, "state_dir", cli.stateDir)
|
||||
|
||||
// Get the specific vault by name
|
||||
vlt := vault.NewVault(cli.fs, cli.stateDir, vaultName)
|
||||
|
||||
// Check if vault exists
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
exists, err := afero.DirExists(cli.fs, vaultDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if vault exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("vault '%s' does not exist", vaultName)
|
||||
}
|
||||
|
||||
// Check if vault already has a public key
|
||||
pubKeyPath := fmt.Sprintf("%s/pub.age", vaultDir)
|
||||
if _, err := cli.fs.Stat(pubKeyPath); err == nil {
|
||||
return fmt.Errorf("vault '%s' already has a long-term key configured", vaultName)
|
||||
}
|
||||
|
||||
// Get mnemonic from environment
|
||||
mnemonic := os.Getenv(secret.EnvMnemonic)
|
||||
if mnemonic == "" {
|
||||
return fmt.Errorf("SB_SECRET_MNEMONIC environment variable not set")
|
||||
}
|
||||
|
||||
// Validate the mnemonic
|
||||
mnemonicWords := strings.Fields(mnemonic)
|
||||
secret.Debug("Validating BIP39 mnemonic", "word_count", len(mnemonicWords))
|
||||
if !bip39.IsMnemonicValid(mnemonic) {
|
||||
return fmt.Errorf("invalid BIP39 mnemonic")
|
||||
}
|
||||
|
||||
// Get the next available derivation index for this mnemonic
|
||||
derivationIndex, err := vault.GetNextDerivationIndex(cli.fs, cli.stateDir, mnemonic)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get next derivation index", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to get next derivation index: %w", err)
|
||||
}
|
||||
secret.Debug("Using derivation index", "index", derivationIndex)
|
||||
|
||||
// Derive long-term key from mnemonic with the appropriate index
|
||||
secret.Debug("Deriving long-term key from mnemonic", "index", derivationIndex)
|
||||
ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive long-term key: %w", err)
|
||||
}
|
||||
|
||||
// Store long-term public key in vault
|
||||
ltPublicKey := ltIdentity.Recipient().String()
|
||||
secret.Debug("Storing long-term public key", "pubkey", ltPublicKey, "vault_dir", vaultDir)
|
||||
|
||||
if err := afero.WriteFile(cli.fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms); err != nil {
|
||||
return fmt.Errorf("failed to store long-term public key: %w", err)
|
||||
}
|
||||
|
||||
// Calculate public key hash from the actual derivation index being used
|
||||
// This is used to verify that the derived key matches what was stored
|
||||
publicKeyHash := vault.ComputeDoubleSHA256([]byte(ltIdentity.Recipient().String()))
|
||||
|
||||
// Calculate family hash from index 0 (same for all vaults with this mnemonic)
|
||||
// This is used to identify which vaults belong to the same mnemonic family
|
||||
identity0, err := agehd.DeriveIdentity(mnemonic, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to derive identity for index 0: %w", err)
|
||||
}
|
||||
familyHash := vault.ComputeDoubleSHA256([]byte(identity0.Recipient().String()))
|
||||
|
||||
// Load existing metadata
|
||||
existingMetadata, err := vault.LoadVaultMetadata(cli.fs, vaultDir)
|
||||
if err != nil {
|
||||
// If metadata doesn't exist, create new
|
||||
existingMetadata = &vault.Metadata{
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Update metadata with new derivation info
|
||||
existingMetadata.DerivationIndex = derivationIndex
|
||||
existingMetadata.PublicKeyHash = publicKeyHash
|
||||
existingMetadata.MnemonicFamilyHash = familyHash
|
||||
|
||||
if err := vault.SaveVaultMetadata(cli.fs, vaultDir, existingMetadata); err != nil {
|
||||
secret.Debug("Failed to save vault metadata", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to save vault metadata: %w", err)
|
||||
}
|
||||
secret.Debug("Saved vault metadata with derivation index and public key hash")
|
||||
|
||||
// Get passphrase from environment variable
|
||||
passphraseStr := os.Getenv(secret.EnvUnlockPassphrase)
|
||||
if passphraseStr == "" {
|
||||
return fmt.Errorf("SB_UNLOCK_PASSPHRASE environment variable not set")
|
||||
}
|
||||
|
||||
secret.Debug("Using unlock passphrase from environment variable")
|
||||
|
||||
// Create secure buffer for passphrase
|
||||
passphraseBuffer := memguard.NewBufferFromBytes([]byte(passphraseStr))
|
||||
defer passphraseBuffer.Destroy()
|
||||
|
||||
// Unlock the vault with the derived long-term key
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Create passphrase-protected unlocker
|
||||
secret.Debug("Creating passphrase-protected unlocker")
|
||||
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to create unlocker", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to create unlocker: %w", err)
|
||||
}
|
||||
|
||||
cmd.Printf("Successfully imported mnemonic into vault '%s'\n", vaultName)
|
||||
cmd.Printf("Long-term public key: %s\n", ltPublicKey)
|
||||
cmd.Printf("Unlocker ID: %s\n", passphraseUnlocker.GetID())
|
||||
|
||||
return nil
|
||||
}
|
||||
209
internal/cli/version.go
Normal file
209
internal/cli/version.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
tabWriterPadding = 2
|
||||
)
|
||||
|
||||
// newVersionCmd returns the version management command
|
||||
func newVersionCmd() *cobra.Command {
|
||||
cli := NewCLIInstance()
|
||||
|
||||
return VersionCommands(cli)
|
||||
}
|
||||
|
||||
// VersionCommands returns the version management commands
|
||||
func VersionCommands(cli *Instance) *cobra.Command {
|
||||
versionCmd := &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Manage secret versions",
|
||||
Long: "Commands for managing secret versions including listing, promoting, and retrieving specific versions",
|
||||
}
|
||||
|
||||
// List versions command
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list <secret-name>",
|
||||
Short: "List all versions of a secret",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cli.ListVersions(cmd, args[0])
|
||||
},
|
||||
}
|
||||
|
||||
// Promote version command
|
||||
promoteCmd := &cobra.Command{
|
||||
Use: "promote <secret-name> <version>",
|
||||
Short: "Promote a specific version to current",
|
||||
Long: "Updates the current symlink to point to the specified version without modifying timestamps",
|
||||
Args: cobra.ExactArgs(2), //nolint:mnd // Command requires exactly 2 arguments: secret-name and version
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cli.PromoteVersion(cmd, args[0], args[1])
|
||||
},
|
||||
}
|
||||
|
||||
versionCmd.AddCommand(listCmd, promoteCmd)
|
||||
|
||||
return versionCmd
|
||||
}
|
||||
|
||||
// ListVersions lists all versions of a secret
|
||||
func (cli *Instance) ListVersions(cmd *cobra.Command, secretName string) error {
|
||||
secret.Debug("ListVersions called", "secret_name", secretName)
|
||||
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get current vault", "error", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get vault directory", "error", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the encoded secret name
|
||||
encodedName := strings.ReplaceAll(secretName, "/", "%")
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", encodedName)
|
||||
|
||||
// Check if secret exists
|
||||
exists, err := afero.DirExists(cli.fs, secretDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to check if secret exists", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to check if secret exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
secret.Debug("Secret not found", "secret_name", secretName)
|
||||
|
||||
return fmt.Errorf("secret '%s' not found", secretName)
|
||||
}
|
||||
|
||||
// List all versions
|
||||
versions, err := secret.ListVersions(cli.fs, secretDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to list versions", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to list versions: %w", err)
|
||||
}
|
||||
|
||||
if len(versions) == 0 {
|
||||
cmd.Println("No versions found")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get current version
|
||||
currentVersion, err := secret.GetCurrentVersion(cli.fs, secretDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get current version", "error", err)
|
||||
currentVersion = ""
|
||||
}
|
||||
|
||||
// Get long-term key for decrypting metadata
|
||||
ltIdentity, err := vlt.GetOrDeriveLongTermKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get long-term key: %w", err)
|
||||
}
|
||||
|
||||
// Create table writer
|
||||
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, tabWriterPadding, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "VERSION\tCREATED\tSTATUS\tNOT_BEFORE\tNOT_AFTER")
|
||||
|
||||
// Load and display each version's metadata
|
||||
for _, version := range versions {
|
||||
sv := secret.NewVersion(vlt, secretName, version)
|
||||
|
||||
// Load metadata
|
||||
if err := sv.LoadMetadata(ltIdentity); err != nil {
|
||||
secret.Debug("Failed to load version metadata", "version", version, "error", err)
|
||||
// Display version with error
|
||||
status := "error"
|
||||
if version == currentVersion {
|
||||
status = "current (error)"
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\n", version, "-", status, "-", "-")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine status
|
||||
status := "expired"
|
||||
if version == currentVersion {
|
||||
status = "current"
|
||||
}
|
||||
|
||||
// Format timestamps
|
||||
createdAt := "-"
|
||||
if sv.Metadata.CreatedAt != nil {
|
||||
createdAt = sv.Metadata.CreatedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
notBefore := "-"
|
||||
if sv.Metadata.NotBefore != nil {
|
||||
notBefore = sv.Metadata.NotBefore.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
notAfter := "-"
|
||||
if sv.Metadata.NotAfter != nil {
|
||||
notAfter = sv.Metadata.NotAfter.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\n", version, createdAt, status, notBefore, notAfter)
|
||||
}
|
||||
|
||||
_ = w.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PromoteVersion promotes a specific version to current
|
||||
func (cli *Instance) PromoteVersion(cmd *cobra.Command, secretName string, version string) error {
|
||||
// Get current vault
|
||||
vlt, err := vault.GetCurrentVault(cli.fs, cli.stateDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the encoded secret name
|
||||
encodedName := strings.ReplaceAll(secretName, "/", "%")
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", encodedName)
|
||||
|
||||
// Check if version exists
|
||||
versionDir := filepath.Join(secretDir, "versions", version)
|
||||
exists, err := afero.DirExists(cli.fs, versionDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if version exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("version '%s' not found for secret '%s'", version, secretName)
|
||||
}
|
||||
|
||||
// Update the current symlink using the proper function
|
||||
if err := secret.SetCurrentVersion(cli.fs, secretDir, version); err != nil {
|
||||
return fmt.Errorf("failed to update current version: %w", err)
|
||||
}
|
||||
|
||||
cmd.Printf("Promoted version %s to current for secret '%s'\n", version, secretName)
|
||||
|
||||
return nil
|
||||
}
|
||||
310
internal/cli/version_test.go
Normal file
310
internal/cli/version_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
// Version CLI Command Tests
|
||||
//
|
||||
// Tests for version-related CLI commands:
|
||||
//
|
||||
// - TestListVersionsCommand: Tests `secret version list` command output
|
||||
// - TestListVersionsNonExistentSecret: Tests error handling for missing secrets
|
||||
// - TestPromoteVersionCommand: Tests `secret version promote` command
|
||||
// - TestPromoteNonExistentVersion: Tests error handling for invalid promotion
|
||||
// - TestGetSecretWithVersion: Tests `secret get --version` flag functionality
|
||||
// - TestVersionCommandStructure: Tests command structure and help text
|
||||
// - TestListVersionsEmptyOutput: Tests edge case with no versions
|
||||
//
|
||||
// Test Utilities:
|
||||
// - setupTestVault(): CLI test helper for vault initialization
|
||||
// - Uses consistent test mnemonic for reproducible testing
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper function to add a secret to vault with proper buffer protection
|
||||
func addTestSecret(t *testing.T, vlt *vault.Vault, name string, value []byte, force bool) {
|
||||
t.Helper()
|
||||
buffer := memguard.NewBufferFromBytes(value)
|
||||
defer buffer.Destroy()
|
||||
err := vlt.AddSecret(name, buffer, force)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Helper function to set up a vault with long-term key
|
||||
func setupTestVault(t *testing.T, fs afero.Fs, stateDir string) {
|
||||
// Set mnemonic for testing
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon " +
|
||||
"abandon abandon abandon abandon abandon about"
|
||||
t.Setenv(secret.EnvMnemonic, testMnemonic)
|
||||
|
||||
// Create vault
|
||||
vlt, err := vault.CreateVault(fs, stateDir, "default")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Derive and store long-term key from mnemonic
|
||||
mnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
ltIdentity, err := agehd.DeriveIdentity(mnemonic, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store long-term public key in vault
|
||||
vaultDir, _ := vlt.GetDirectory()
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, ltPubKeyPath, []byte(ltIdentity.Recipient().String()), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Select vault
|
||||
err = vault.SelectVault(fs, stateDir, "default")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestListVersionsCommand(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
cli := NewCLIInstanceWithStateDir(fs, stateDir)
|
||||
|
||||
// Set up vault with long-term key
|
||||
setupTestVault(t, fs, stateDir)
|
||||
|
||||
// Add a secret with multiple versions
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
|
||||
|
||||
// Create a command for output capture
|
||||
cmd := newRootCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
// List versions
|
||||
err = cli.ListVersions(cmd, "test/secret")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read output
|
||||
outputStr := buf.String()
|
||||
|
||||
// Verify output contains version headers
|
||||
assert.Contains(t, outputStr, "VERSION")
|
||||
assert.Contains(t, outputStr, "CREATED")
|
||||
assert.Contains(t, outputStr, "STATUS")
|
||||
assert.Contains(t, outputStr, "NOT_BEFORE")
|
||||
assert.Contains(t, outputStr, "NOT_AFTER")
|
||||
|
||||
// Should have current status for latest version
|
||||
assert.Contains(t, outputStr, "current")
|
||||
|
||||
// Should have two version entries
|
||||
lines := strings.Split(outputStr, "\n")
|
||||
versionLines := 0
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, ".001") || strings.Contains(line, ".002") {
|
||||
versionLines++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, versionLines)
|
||||
}
|
||||
|
||||
func TestListVersionsNonExistentSecret(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
cli := NewCLIInstanceWithStateDir(fs, stateDir)
|
||||
|
||||
// Set up vault with long-term key
|
||||
setupTestVault(t, fs, stateDir)
|
||||
|
||||
// Create a command for output capture
|
||||
cmd := newRootCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
// Try to list versions of non-existent secret
|
||||
err := cli.ListVersions(cmd, "nonexistent/secret")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestPromoteVersionCommand(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
cli := NewCLIInstanceWithStateDir(fs, stateDir)
|
||||
|
||||
// Set up vault with long-term key
|
||||
setupTestVault(t, fs, stateDir)
|
||||
|
||||
// Add a secret with multiple versions
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
|
||||
|
||||
// Get versions
|
||||
vaultDir, _ := vlt.GetDirectory()
|
||||
secretDir := vaultDir + "/secrets.d/test%secret"
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 2)
|
||||
|
||||
// Current should be version-2
|
||||
value, err := vlt.GetSecret("test/secret")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-2"), value)
|
||||
|
||||
// Promote first version
|
||||
firstVersion := versions[1] // Older version
|
||||
|
||||
// Create a command for output capture
|
||||
cmd := newRootCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cli.PromoteVersion(cmd, "test/secret", firstVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read output
|
||||
outputStr := buf.String()
|
||||
|
||||
// Verify success message
|
||||
assert.Contains(t, outputStr, "Promoted version")
|
||||
assert.Contains(t, outputStr, firstVersion)
|
||||
|
||||
// Verify current is now version-1
|
||||
value, err = vlt.GetSecret("test/secret")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-1"), value)
|
||||
}
|
||||
|
||||
func TestPromoteNonExistentVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
cli := NewCLIInstanceWithStateDir(fs, stateDir)
|
||||
|
||||
// Set up vault with long-term key
|
||||
setupTestVault(t, fs, stateDir)
|
||||
|
||||
// Add a secret
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
addTestSecret(t, vlt, "test/secret", []byte("value"), false)
|
||||
|
||||
// Create a command for output capture
|
||||
cmd := newRootCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
// Try to promote non-existent version
|
||||
err = cli.PromoteVersion(cmd, "test/secret", "20991231.999")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestGetSecretWithVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
cli := NewCLIInstanceWithStateDir(fs, stateDir)
|
||||
|
||||
// Set up vault with long-term key
|
||||
setupTestVault(t, fs, stateDir)
|
||||
|
||||
// Add a secret with multiple versions
|
||||
vlt, err := vault.GetCurrentVault(fs, stateDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
addTestSecret(t, vlt, "test/secret", []byte("version-1"), false)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
addTestSecret(t, vlt, "test/secret", []byte("version-2"), true)
|
||||
|
||||
// Get versions
|
||||
vaultDir, _ := vlt.GetDirectory()
|
||||
secretDir := vaultDir + "/secrets.d/test%secret"
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 2)
|
||||
|
||||
// Create a command for output capture
|
||||
cmd := newRootCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
|
||||
// Test getting current version (empty version string)
|
||||
err = cli.GetSecretWithVersion(cmd, "test/secret", "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "version-2", buf.String())
|
||||
|
||||
// Test getting specific version
|
||||
buf.Reset()
|
||||
firstVersion := versions[1] // Older version
|
||||
err = cli.GetSecretWithVersion(cmd, "test/secret", firstVersion)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "version-1", buf.String())
|
||||
}
|
||||
|
||||
func TestVersionCommandStructure(t *testing.T) {
|
||||
// Test that version commands are properly structured
|
||||
cli := NewCLIInstance()
|
||||
cmd := VersionCommands(cli)
|
||||
|
||||
assert.Equal(t, "version", cmd.Use)
|
||||
assert.Equal(t, "Manage secret versions", cmd.Short)
|
||||
|
||||
// Check subcommands
|
||||
listCmd := cmd.Commands()[0]
|
||||
assert.Equal(t, "list <secret-name>", listCmd.Use)
|
||||
assert.Equal(t, "List all versions of a secret", listCmd.Short)
|
||||
|
||||
promoteCmd := cmd.Commands()[1]
|
||||
assert.Equal(t, "promote <secret-name> <version>", promoteCmd.Use)
|
||||
assert.Equal(t, "Promote a specific version to current", promoteCmd.Short)
|
||||
}
|
||||
|
||||
func TestListVersionsEmptyOutput(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
cli := NewCLIInstanceWithStateDir(fs, stateDir)
|
||||
|
||||
// Set up vault with long-term key
|
||||
setupTestVault(t, fs, stateDir)
|
||||
|
||||
// Create a secret directory without versions (edge case)
|
||||
vaultDir := stateDir + "/vaults.d/default"
|
||||
secretDir := vaultDir + "/secrets.d/test%secret"
|
||||
err := fs.MkdirAll(secretDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a command for output capture
|
||||
cmd := newRootCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
// List versions - should show "No versions found"
|
||||
err = cli.ListVersions(cmd, "test/secret")
|
||||
|
||||
// Should succeed even with no versions
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
27
internal/secret/constants.go
Normal file
27
internal/secret/constants.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Package secret provides core types and constants for the secret application.
|
||||
package secret
|
||||
|
||||
import "os"
|
||||
|
||||
const (
|
||||
// AppID is the unique identifier for this application
|
||||
AppID = "berlin.sneak.pkg.secret"
|
||||
|
||||
// EnvStateDir is the environment variable for specifying the state directory
|
||||
EnvStateDir = "SB_SECRET_STATE_DIR"
|
||||
// EnvMnemonic is the environment variable for providing the mnemonic phrase
|
||||
EnvMnemonic = "SB_SECRET_MNEMONIC"
|
||||
// EnvUnlockPassphrase is the environment variable for providing the unlock passphrase
|
||||
EnvUnlockPassphrase = "SB_UNLOCK_PASSPHRASE" //nolint:gosec // G101: This is an env var name, not a credential
|
||||
// EnvGPGKeyID is the environment variable for providing the GPG key ID
|
||||
EnvGPGKeyID = "SB_GPG_KEY_ID"
|
||||
)
|
||||
|
||||
// File system permission constants
|
||||
const (
|
||||
// DirPerms is the permission used for directories (read-write-execute for owner only)
|
||||
DirPerms os.FileMode = 0o700
|
||||
|
||||
// FilePerms is the permission used for sensitive files (read-write for owner only)
|
||||
FilePerms os.FileMode = 0o600
|
||||
)
|
||||
149
internal/secret/crypto.go
Normal file
149
internal/secret/crypto.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"filippo.io/age"
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// EncryptToRecipient encrypts data to a recipient using age
|
||||
// The data parameter should be a LockedBuffer for secure memory handling
|
||||
func EncryptToRecipient(data *memguard.LockedBuffer, recipient age.Recipient) ([]byte, error) {
|
||||
if data == nil {
|
||||
return nil, fmt.Errorf("data buffer is nil")
|
||||
}
|
||||
|
||||
Debug("EncryptToRecipient starting", "data_length", data.Size())
|
||||
|
||||
var buf bytes.Buffer
|
||||
Debug("Creating age encryptor")
|
||||
w, err := age.Encrypt(&buf, recipient)
|
||||
if err != nil {
|
||||
Debug("Failed to create encryptor", "error", err)
|
||||
|
||||
return nil, fmt.Errorf("failed to create encryptor: %w", err)
|
||||
}
|
||||
Debug("Created age encryptor successfully")
|
||||
|
||||
Debug("Writing data to encryptor")
|
||||
if _, err := w.Write(data.Bytes()); err != nil {
|
||||
Debug("Failed to write data to encryptor", "error", err)
|
||||
|
||||
return nil, fmt.Errorf("failed to write data: %w", err)
|
||||
}
|
||||
Debug("Wrote data to encryptor successfully")
|
||||
|
||||
Debug("Closing encryptor")
|
||||
if err := w.Close(); err != nil {
|
||||
Debug("Failed to close encryptor", "error", err)
|
||||
|
||||
return nil, fmt.Errorf("failed to close encryptor: %w", err)
|
||||
}
|
||||
Debug("Closed encryptor successfully")
|
||||
|
||||
result := buf.Bytes()
|
||||
Debug("EncryptToRecipient completed successfully", "result_length", len(result))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecryptWithIdentity decrypts data with an identity using age
|
||||
func DecryptWithIdentity(data []byte, identity age.Identity) ([]byte, error) {
|
||||
r, err := age.Decrypt(bytes.NewReader(data), identity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create decryptor: %w", err)
|
||||
}
|
||||
|
||||
result, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read decrypted data: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EncryptWithPassphrase encrypts data using a passphrase with age's scrypt-based encryption
|
||||
// The passphrase parameter should be a LockedBuffer for secure memory handling
|
||||
func EncryptWithPassphrase(data []byte, passphrase *memguard.LockedBuffer) ([]byte, error) {
|
||||
if passphrase == nil {
|
||||
return nil, fmt.Errorf("passphrase buffer is nil")
|
||||
}
|
||||
|
||||
// Get the passphrase string temporarily
|
||||
passphraseStr := passphrase.String()
|
||||
recipient, err := age.NewScryptRecipient(passphraseStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create scrypt recipient: %w", err)
|
||||
}
|
||||
|
||||
// Create a secure buffer for the data
|
||||
dataBuffer := memguard.NewBufferFromBytes(data)
|
||||
defer dataBuffer.Destroy()
|
||||
|
||||
return EncryptToRecipient(dataBuffer, recipient)
|
||||
}
|
||||
|
||||
// DecryptWithPassphrase decrypts data using a passphrase with age's scrypt-based decryption
|
||||
// The passphrase parameter should be a LockedBuffer for secure memory handling
|
||||
func DecryptWithPassphrase(encryptedData []byte, passphrase *memguard.LockedBuffer) ([]byte, error) {
|
||||
if passphrase == nil {
|
||||
return nil, fmt.Errorf("passphrase buffer is nil")
|
||||
}
|
||||
|
||||
// Get the passphrase string temporarily
|
||||
passphraseStr := passphrase.String()
|
||||
identity, err := age.NewScryptIdentity(passphraseStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create scrypt identity: %w", err)
|
||||
}
|
||||
|
||||
return DecryptWithIdentity(encryptedData, identity)
|
||||
}
|
||||
|
||||
// ReadPassphrase reads a passphrase securely from the terminal without echoing
|
||||
// This version is for unlocking and doesn't require confirmation
|
||||
// Returns a LockedBuffer containing the passphrase for secure memory handling
|
||||
func ReadPassphrase(prompt string) (*memguard.LockedBuffer, error) {
|
||||
// Check if stdin is a terminal
|
||||
if !term.IsTerminal(syscall.Stdin) {
|
||||
// Not a terminal - never read passphrases from piped input for security reasons
|
||||
return nil, fmt.Errorf("cannot read passphrase from non-terminal stdin " +
|
||||
"(piped input or script). Please set the SB_UNLOCK_PASSPHRASE " +
|
||||
"environment variable or run interactively")
|
||||
}
|
||||
|
||||
// stdin is a terminal, check if stderr is also a terminal for interactive prompting
|
||||
if !term.IsTerminal(syscall.Stderr) {
|
||||
return nil, fmt.Errorf("cannot prompt for passphrase: stderr is not a terminal " +
|
||||
"(running in non-interactive mode). Please set the SB_UNLOCK_PASSPHRASE " +
|
||||
"environment variable")
|
||||
}
|
||||
|
||||
// Both stdin and stderr are terminals - use secure password reading
|
||||
fmt.Fprint(os.Stderr, prompt) // Write prompt to stderr, not stdout
|
||||
passphrase, err := term.ReadPassword(syscall.Stdin)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read passphrase: %w", err)
|
||||
}
|
||||
fmt.Fprintln(os.Stderr) // Print newline to stderr since ReadPassword doesn't echo
|
||||
|
||||
if len(passphrase) == 0 {
|
||||
return nil, fmt.Errorf("passphrase cannot be empty")
|
||||
}
|
||||
|
||||
// Create a secure buffer and copy the passphrase
|
||||
secureBuffer := memguard.NewBufferFromBytes(passphrase)
|
||||
|
||||
// Clear the original passphrase slice
|
||||
for i := range passphrase {
|
||||
passphrase[i] = 0
|
||||
}
|
||||
|
||||
return secureBuffer, nil
|
||||
}
|
||||
139
internal/secret/debug.go
Normal file
139
internal/secret/debug.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var (
|
||||
debugEnabled bool //nolint:gochecknoglobals // Package-wide debug state is necessary
|
||||
debugLogger *slog.Logger //nolint:gochecknoglobals // Package-wide logger instance is necessary
|
||||
)
|
||||
|
||||
func init() {
|
||||
InitDebugLogging()
|
||||
}
|
||||
|
||||
// InitDebugLogging initializes the debug logging system based on current GODEBUG environment variable
|
||||
func InitDebugLogging() {
|
||||
godebug := os.Getenv("GODEBUG")
|
||||
debugEnabled = strings.Contains(godebug, "berlin.sneak.pkg.secret")
|
||||
|
||||
if !debugEnabled {
|
||||
// Create a no-op logger that discards all output
|
||||
debugLogger = slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Disable stderr buffering for immediate debug output when debugging is enabled
|
||||
_, _, _ = syscall.Syscall(syscall.SYS_FCNTL, os.Stderr.Fd(), syscall.F_SETFL, syscall.O_SYNC)
|
||||
|
||||
// Check if STDERR is a TTY
|
||||
isTTY := term.IsTerminal(syscall.Stderr)
|
||||
|
||||
var handler slog.Handler
|
||||
if isTTY {
|
||||
// TTY output: colorized structured format
|
||||
handler = newColorizedHandler(os.Stderr)
|
||||
} else {
|
||||
// Non-TTY output: JSON Lines format
|
||||
handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})
|
||||
}
|
||||
|
||||
debugLogger = slog.New(handler)
|
||||
}
|
||||
|
||||
// IsDebugEnabled returns true if debug logging is enabled
|
||||
func IsDebugEnabled() bool {
|
||||
return debugEnabled
|
||||
}
|
||||
|
||||
// Debug logs a debug message with optional attributes
|
||||
func Debug(msg string, args ...any) {
|
||||
if !debugEnabled {
|
||||
return
|
||||
}
|
||||
debugLogger.Debug(msg, args...)
|
||||
}
|
||||
|
||||
// DebugF logs a formatted debug message with optional attributes
|
||||
func DebugF(format string, args ...any) {
|
||||
if !debugEnabled {
|
||||
return
|
||||
}
|
||||
debugLogger.Debug(fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
// DebugWith logs a debug message with structured attributes
|
||||
func DebugWith(msg string, attrs ...slog.Attr) {
|
||||
if !debugEnabled {
|
||||
return
|
||||
}
|
||||
debugLogger.LogAttrs(context.Background(), slog.LevelDebug, msg, attrs...)
|
||||
}
|
||||
|
||||
// colorizedHandler implements a TTY-friendly structured log handler
|
||||
type colorizedHandler struct {
|
||||
output io.Writer
|
||||
}
|
||||
|
||||
func newColorizedHandler(output io.Writer) slog.Handler {
|
||||
return &colorizedHandler{output: output}
|
||||
}
|
||||
|
||||
func (h *colorizedHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
// Explicitly check that debug is enabled AND the level is DEBUG or higher
|
||||
// This ensures we don't default to INFO level when debug is enabled
|
||||
return debugEnabled && level >= slog.LevelDebug
|
||||
}
|
||||
|
||||
func (h *colorizedHandler) Handle(_ context.Context, record slog.Record) error {
|
||||
if !debugEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Format: [DEBUG] message {key=value, key2=value2}
|
||||
output := fmt.Sprintf("\033[36m[DEBUG]\033[0m \033[1m%s\033[0m", record.Message)
|
||||
|
||||
if record.NumAttrs() > 0 {
|
||||
output += " \033[33m{"
|
||||
first := true
|
||||
record.Attrs(func(attr slog.Attr) bool {
|
||||
if !first {
|
||||
output += ", "
|
||||
}
|
||||
first = false
|
||||
output += fmt.Sprintf("%s=%#v", attr.Key, attr.Value.Any())
|
||||
|
||||
return true
|
||||
})
|
||||
output += "}\033[0m"
|
||||
}
|
||||
|
||||
output += "\n"
|
||||
_, err := h.output.Write([]byte(output))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *colorizedHandler) WithAttrs(_ []slog.Attr) slog.Handler {
|
||||
// For simplicity, return the same handler
|
||||
// In a more complex implementation, we'd create a new handler with the attrs
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *colorizedHandler) WithGroup(_ string) slog.Handler {
|
||||
// For simplicity, return the same handler
|
||||
// In a more complex implementation, we'd create a new handler with the group
|
||||
return h
|
||||
}
|
||||
121
internal/secret/debug_test.go
Normal file
121
internal/secret/debug_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func TestDebugLogging(t *testing.T) {
|
||||
// Test cleanup handled by t.Setenv
|
||||
defer InitDebugLogging() // Re-initialize debug system after test
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
godebug string
|
||||
expectEnabled bool
|
||||
}{
|
||||
{
|
||||
name: "debug enabled",
|
||||
godebug: "berlin.sneak.pkg.secret",
|
||||
expectEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "debug enabled with other flags",
|
||||
godebug: "other=1,berlin.sneak.pkg.secret,another=value",
|
||||
expectEnabled: true,
|
||||
},
|
||||
{
|
||||
name: "debug disabled",
|
||||
godebug: "other=1",
|
||||
expectEnabled: false,
|
||||
},
|
||||
{
|
||||
name: "debug disabled empty",
|
||||
godebug: "",
|
||||
expectEnabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set GODEBUG
|
||||
if tt.godebug != "" {
|
||||
t.Setenv("GODEBUG", tt.godebug)
|
||||
}
|
||||
|
||||
// Re-initialize debug system
|
||||
InitDebugLogging()
|
||||
|
||||
// Test if debug is enabled
|
||||
enabled := IsDebugEnabled()
|
||||
if enabled != tt.expectEnabled {
|
||||
t.Errorf("IsDebugEnabled() = %v, want %v", enabled, tt.expectEnabled)
|
||||
}
|
||||
|
||||
// If debug should be enabled, test that debug output works
|
||||
if tt.expectEnabled {
|
||||
// Capture debug output by redirecting the colorized handler
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Override the debug logger for testing
|
||||
oldLogger := debugLogger
|
||||
if term.IsTerminal(syscall.Stderr) {
|
||||
// TTY: use colorized handler with our buffer
|
||||
debugLogger = slog.New(newColorizedHandler(&buf))
|
||||
} else {
|
||||
// Non-TTY: use JSON handler with our buffer
|
||||
debugLogger = slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
}))
|
||||
}
|
||||
|
||||
// Test debug output
|
||||
Debug("test message", "key", "value")
|
||||
|
||||
// Restore original logger
|
||||
debugLogger = oldLogger
|
||||
|
||||
// Check that output was generated
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "test message") {
|
||||
t.Errorf("Debug output does not contain expected message. Got: %s", output)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugFunctions(t *testing.T) {
|
||||
// Enable debug for testing
|
||||
t.Setenv("GODEBUG", "berlin.sneak.pkg.secret")
|
||||
defer InitDebugLogging() // Re-initialize after test
|
||||
|
||||
InitDebugLogging()
|
||||
|
||||
if !IsDebugEnabled() {
|
||||
t.Log("Debug not enabled, but continuing with debug function tests anyway")
|
||||
}
|
||||
|
||||
// Test that debug functions don't panic and can be called
|
||||
t.Run("Debug", func(_ *testing.T) {
|
||||
Debug("test debug message")
|
||||
Debug("test with args", "key", "value", "number", 42)
|
||||
})
|
||||
|
||||
t.Run("DebugF", func(_ *testing.T) {
|
||||
DebugF("formatted message: %s %d", "test", 123)
|
||||
})
|
||||
|
||||
t.Run("DebugWith", func(_ *testing.T) {
|
||||
DebugWith("structured message",
|
||||
slog.String("string_key", "string_value"),
|
||||
slog.Int("int_key", 42),
|
||||
slog.Bool("bool_key", true),
|
||||
)
|
||||
})
|
||||
}
|
||||
56
internal/secret/helpers.go
Normal file
56
internal/secret/helpers.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// generateRandomString generates a random string of the specified length using the given character set
|
||||
func generateRandomString(length int, charset string) (string, error) {
|
||||
if length <= 0 {
|
||||
return "", fmt.Errorf("length must be positive")
|
||||
}
|
||||
|
||||
result := make([]byte, length)
|
||||
charsetLen := big.NewInt(int64(len(charset)))
|
||||
|
||||
for i := range length {
|
||||
randomIndex, err := rand.Int(rand.Reader, charsetLen)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate random number: %w", err)
|
||||
}
|
||||
result[i] = charset[randomIndex.Int64()]
|
||||
}
|
||||
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
// DetermineStateDir determines the state directory based on environment variables and OS
|
||||
func DetermineStateDir(customConfigDir string) string {
|
||||
// Check for environment variable first
|
||||
if envStateDir := os.Getenv(EnvStateDir); envStateDir != "" {
|
||||
return envStateDir
|
||||
}
|
||||
|
||||
// Use custom config dir if provided
|
||||
if customConfigDir != "" {
|
||||
return filepath.Join(customConfigDir, AppID)
|
||||
}
|
||||
|
||||
// Use os.UserConfigDir() which handles platform-specific directories:
|
||||
// - On Unix systems, it returns $XDG_CONFIG_HOME or $HOME/.config
|
||||
// - On Darwin, it returns $HOME/Library/Application Support
|
||||
// - On Windows, it returns %AppData%
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
// Fallback to a reasonable default if we can't determine user config dir
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
|
||||
return filepath.Join(homeDir, ".config", AppID)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, AppID)
|
||||
}
|
||||
530
internal/secret/keychainunlocker.go
Normal file
530
internal/secret/keychainunlocker.go
Normal file
@@ -0,0 +1,530 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
const (
|
||||
agePrivKeyPassphraseLength = 64
|
||||
)
|
||||
|
||||
// keychainItemNameRegex validates keychain item names
|
||||
// Allows alphanumeric characters, dots, hyphens, and underscores only
|
||||
var keychainItemNameRegex = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
|
||||
|
||||
// KeychainUnlockerMetadata extends UnlockerMetadata with keychain-specific data
|
||||
type KeychainUnlockerMetadata struct {
|
||||
UnlockerMetadata
|
||||
// Keychain item name
|
||||
KeychainItemName string `json:"keychainItemName"`
|
||||
}
|
||||
|
||||
// KeychainUnlocker represents a macOS Keychain-protected unlocker
|
||||
type KeychainUnlocker struct {
|
||||
Directory string
|
||||
Metadata UnlockerMetadata
|
||||
fs afero.Fs
|
||||
}
|
||||
|
||||
// KeychainData represents the data stored in the macOS keychain
|
||||
type KeychainData struct {
|
||||
AgePublicKey string `json:"agePublicKey"`
|
||||
AgePrivKeyPassphrase string `json:"agePrivKeyPassphrase"`
|
||||
EncryptedLongtermKey string `json:"encryptedLongtermKey"`
|
||||
}
|
||||
|
||||
// GetIdentity implements Unlocker interface for Keychain-based unlockers
|
||||
func (k *KeychainUnlocker) GetIdentity() (*age.X25519Identity, error) {
|
||||
DebugWith("Getting keychain unlocker identity",
|
||||
slog.String("unlocker_id", k.GetID()),
|
||||
slog.String("unlocker_type", k.GetType()),
|
||||
)
|
||||
|
||||
// Step 1: Get keychain item name
|
||||
keychainItemName, err := k.GetKeychainItemName()
|
||||
if err != nil {
|
||||
Debug("Failed to get keychain item name", "error", err, "unlocker_id", k.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to get keychain item name: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Retrieve data from keychain
|
||||
Debug("Retrieving data from macOS keychain", "keychain_item", keychainItemName)
|
||||
keychainDataBytes, err := retrieveFromKeychain(keychainItemName)
|
||||
if err != nil {
|
||||
Debug("Failed to retrieve data from keychain", "error", err, "keychain_item", keychainItemName)
|
||||
|
||||
return nil, fmt.Errorf("failed to retrieve data from keychain: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Retrieved data from keychain",
|
||||
slog.String("unlocker_id", k.GetID()),
|
||||
slog.Int("data_length", len(keychainDataBytes)),
|
||||
)
|
||||
|
||||
// Step 3: Parse keychain data
|
||||
var keychainData KeychainData
|
||||
if err := json.Unmarshal(keychainDataBytes, &keychainData); err != nil {
|
||||
Debug("Failed to parse keychain data", "error", err, "unlocker_id", k.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to parse keychain data: %w", err)
|
||||
}
|
||||
|
||||
Debug("Parsed keychain data successfully", "unlocker_id", k.GetID())
|
||||
|
||||
// Step 4: Read the encrypted age private key from filesystem
|
||||
agePrivKeyPath := filepath.Join(k.Directory, "priv.age")
|
||||
Debug("Reading encrypted age private key", "path", agePrivKeyPath)
|
||||
|
||||
encryptedAgePrivKeyData, err := afero.ReadFile(k.fs, agePrivKeyPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read encrypted age private key", "error", err, "path", agePrivKeyPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read encrypted age private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Read encrypted age private key",
|
||||
slog.String("unlocker_id", k.GetID()),
|
||||
slog.Int("encrypted_length", len(encryptedAgePrivKeyData)),
|
||||
)
|
||||
|
||||
// Step 5: Decrypt the age private key using the passphrase from keychain
|
||||
Debug("Decrypting age private key with keychain passphrase", "unlocker_id", k.GetID())
|
||||
// Create secure buffer for the keychain passphrase
|
||||
passphraseBuffer := memguard.NewBufferFromBytes([]byte(keychainData.AgePrivKeyPassphrase))
|
||||
defer passphraseBuffer.Destroy()
|
||||
|
||||
agePrivKeyData, err := DecryptWithPassphrase(encryptedAgePrivKeyData, passphraseBuffer)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt age private key with keychain passphrase", "error", err, "unlocker_id", k.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt age private key with keychain passphrase: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Successfully decrypted age private key with keychain passphrase",
|
||||
slog.String("unlocker_id", k.GetID()),
|
||||
slog.Int("decrypted_length", len(agePrivKeyData)),
|
||||
)
|
||||
|
||||
// Step 6: Parse the decrypted age private key
|
||||
Debug("Parsing decrypted age private key", "unlocker_id", k.GetID())
|
||||
|
||||
// Create a secure buffer for the private key data
|
||||
agePrivKeyBuffer := memguard.NewBufferFromBytes(agePrivKeyData)
|
||||
defer agePrivKeyBuffer.Destroy()
|
||||
|
||||
// Clear the original private key data
|
||||
for i := range agePrivKeyData {
|
||||
agePrivKeyData[i] = 0
|
||||
}
|
||||
|
||||
ageIdentity, err := age.ParseX25519Identity(agePrivKeyBuffer.String())
|
||||
if err != nil {
|
||||
Debug("Failed to parse age private key", "error", err, "unlocker_id", k.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to parse age private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Successfully parsed keychain age identity",
|
||||
slog.String("unlocker_id", k.GetID()),
|
||||
slog.String("public_key", ageIdentity.Recipient().String()),
|
||||
)
|
||||
|
||||
return ageIdentity, nil
|
||||
}
|
||||
|
||||
// GetType implements Unlocker interface
|
||||
func (k *KeychainUnlocker) GetType() string {
|
||||
return "keychain"
|
||||
}
|
||||
|
||||
// GetMetadata implements Unlocker interface
|
||||
func (k *KeychainUnlocker) GetMetadata() UnlockerMetadata {
|
||||
return k.Metadata
|
||||
}
|
||||
|
||||
// GetDirectory implements Unlocker interface
|
||||
func (k *KeychainUnlocker) GetDirectory() string {
|
||||
return k.Directory
|
||||
}
|
||||
|
||||
// GetID implements Unlocker interface - generates ID from keychain item name
|
||||
func (k *KeychainUnlocker) GetID() string {
|
||||
// Generate ID using keychain item name
|
||||
keychainItemName, err := k.GetKeychainItemName()
|
||||
if err != nil {
|
||||
// The vault metadata is corrupt - this is a fatal error
|
||||
// We cannot continue with a fallback ID as that would mask data corruption
|
||||
panic(fmt.Sprintf("Keychain unlocker metadata is corrupt or missing keychain item name: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-keychain", keychainItemName)
|
||||
}
|
||||
|
||||
// Remove implements Unlocker interface - removes the keychain unlocker
|
||||
func (k *KeychainUnlocker) Remove() error {
|
||||
// Step 1: Get keychain item name
|
||||
keychainItemName, err := k.GetKeychainItemName()
|
||||
if err != nil {
|
||||
Debug("Failed to get keychain item name during removal", "error", err, "unlocker_id", k.GetID())
|
||||
|
||||
return fmt.Errorf("failed to get keychain item name: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Remove from keychain
|
||||
Debug("Removing keychain item", "keychain_item", keychainItemName)
|
||||
if err := deleteFromKeychain(keychainItemName); err != nil {
|
||||
Debug("Failed to remove keychain item", "error", err, "keychain_item", keychainItemName)
|
||||
|
||||
return fmt.Errorf("failed to remove keychain item: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Remove directory
|
||||
Debug("Removing keychain unlocker directory", "directory", k.Directory)
|
||||
if err := k.fs.RemoveAll(k.Directory); err != nil {
|
||||
Debug("Failed to remove keychain unlocker directory", "error", err, "directory", k.Directory)
|
||||
|
||||
return fmt.Errorf("failed to remove keychain unlocker directory: %w", err)
|
||||
}
|
||||
|
||||
Debug("Successfully removed keychain unlocker", "unlocker_id", k.GetID(), "keychain_item", keychainItemName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewKeychainUnlocker creates a new KeychainUnlocker instance
|
||||
func NewKeychainUnlocker(fs afero.Fs, directory string, metadata UnlockerMetadata) *KeychainUnlocker {
|
||||
return &KeychainUnlocker{
|
||||
Directory: directory,
|
||||
Metadata: metadata,
|
||||
fs: fs,
|
||||
}
|
||||
}
|
||||
|
||||
// GetKeychainItemName returns the keychain item name from metadata
|
||||
func (k *KeychainUnlocker) GetKeychainItemName() (string, error) {
|
||||
// Load the metadata
|
||||
metadataPath := filepath.Join(k.Directory, "unlocker-metadata.json")
|
||||
metadataData, err := afero.ReadFile(k.fs, metadataPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read keychain metadata: %w", err)
|
||||
}
|
||||
|
||||
var keychainMetadata KeychainUnlockerMetadata
|
||||
if err := json.Unmarshal(metadataData, &keychainMetadata); err != nil {
|
||||
return "", fmt.Errorf("failed to parse keychain metadata: %w", err)
|
||||
}
|
||||
|
||||
return keychainMetadata.KeychainItemName, nil
|
||||
}
|
||||
|
||||
// generateKeychainUnlockerName generates a unique name for the keychain unlocker
|
||||
func generateKeychainUnlockerName(vaultName string) (string, error) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get hostname: %w", err)
|
||||
}
|
||||
|
||||
// Format: secret-<vault>-<hostname>-<date>
|
||||
enrollmentDate := time.Now().Format("2006-01-02")
|
||||
|
||||
return fmt.Sprintf("secret-%s-%s-%s", vaultName, hostname, enrollmentDate), nil
|
||||
}
|
||||
|
||||
// getLongTermPrivateKey retrieves the long-term private key either from environment or current unlocker
|
||||
// Returns a LockedBuffer to ensure the private key is protected in memory
|
||||
func getLongTermPrivateKey(fs afero.Fs, vault VaultInterface) (*memguard.LockedBuffer, error) {
|
||||
// Check if mnemonic is available in environment variable
|
||||
envMnemonic := os.Getenv(EnvMnemonic)
|
||||
if envMnemonic != "" {
|
||||
// Use mnemonic directly to derive long-term key
|
||||
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
|
||||
}
|
||||
|
||||
// Return the private key in a secure buffer
|
||||
return memguard.NewBufferFromBytes([]byte(ltIdentity.String())), nil
|
||||
}
|
||||
|
||||
// Get the vault to access current unlocker
|
||||
currentUnlocker, err := vault.GetCurrentUnlocker()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current unlocker: %w", err)
|
||||
}
|
||||
|
||||
// Get the current unlocker identity
|
||||
currentUnlockerIdentity, err := currentUnlocker.GetIdentity()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current unlocker identity: %w", err)
|
||||
}
|
||||
|
||||
// Get encrypted long-term key from current unlocker, handling different types
|
||||
var encryptedLtPrivKey []byte
|
||||
switch currentUnlocker := currentUnlocker.(type) {
|
||||
case *PassphraseUnlocker:
|
||||
// Read the encrypted long-term private key from passphrase unlocker
|
||||
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted long-term key from current passphrase unlocker: %w", err)
|
||||
}
|
||||
|
||||
case *PGPUnlocker:
|
||||
// Read the encrypted long-term private key from PGP unlocker
|
||||
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted long-term key from current PGP unlocker: %w", err)
|
||||
}
|
||||
|
||||
case *KeychainUnlocker:
|
||||
// Read the encrypted long-term private key from another keychain unlocker
|
||||
encryptedLtPrivKey, err = afero.ReadFile(fs, filepath.Join(currentUnlocker.GetDirectory(), "longterm.age"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted long-term key from current keychain unlocker: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported current unlocker type for keychain unlocker creation")
|
||||
}
|
||||
|
||||
// Decrypt long-term private key using current unlocker
|
||||
ltPrivKeyData, err := DecryptWithIdentity(encryptedLtPrivKey, currentUnlockerIdentity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
|
||||
}
|
||||
|
||||
// Return the decrypted key in a secure buffer
|
||||
return memguard.NewBufferFromBytes(ltPrivKeyData), nil
|
||||
}
|
||||
|
||||
// CreateKeychainUnlocker creates a new keychain unlocker and stores it in the vault
|
||||
func CreateKeychainUnlocker(fs afero.Fs, stateDir string) (*KeychainUnlocker, error) {
|
||||
// Check if we're on macOS
|
||||
if err := checkMacOSAvailable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get current vault using the GetCurrentVault function from the same package
|
||||
vault, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current vault: %w", err)
|
||||
}
|
||||
|
||||
// Generate the keychain item name
|
||||
keychainItemName, err := generateKeychainUnlockerName(vault.GetName())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate keychain item name: %w", err)
|
||||
}
|
||||
|
||||
// Create unlocker directory using the keychain item name as the directory name
|
||||
vaultDir, err := vault.GetDirectory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get vault directory: %w", err)
|
||||
}
|
||||
|
||||
unlockerDir := filepath.Join(vaultDir, "unlockers.d", keychainItemName)
|
||||
if err := fs.MkdirAll(unlockerDir, DirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create unlocker directory: %w", err)
|
||||
}
|
||||
|
||||
// Step 1: Generate a new age keypair for the keychain unlocker
|
||||
ageIdentity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate age keypair: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Generate a random passphrase for encrypting the age private key
|
||||
agePrivKeyPassphrase, err := generateRandomPassphrase(agePrivKeyPassphraseLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate age private key passphrase: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Store age recipient as plaintext
|
||||
ageRecipient := ageIdentity.Recipient().String()
|
||||
recipientPath := filepath.Join(unlockerDir, "pub.txt")
|
||||
if err := afero.WriteFile(fs, recipientPath, []byte(ageRecipient), FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write age recipient: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Encrypt age private key with the generated passphrase and store on disk
|
||||
// Create secure buffers for both the private key and passphrase
|
||||
agePrivKeyStr := ageIdentity.String()
|
||||
agePrivKeyBuffer := memguard.NewBufferFromBytes([]byte(agePrivKeyStr))
|
||||
defer agePrivKeyBuffer.Destroy()
|
||||
|
||||
passphraseBuffer := memguard.NewBufferFromBytes([]byte(agePrivKeyPassphrase))
|
||||
defer passphraseBuffer.Destroy()
|
||||
|
||||
encryptedAgePrivKey, err := EncryptWithPassphrase(agePrivKeyBuffer.Bytes(), passphraseBuffer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt age private key with passphrase: %w", err)
|
||||
}
|
||||
|
||||
agePrivKeyPath := filepath.Join(unlockerDir, "priv.age")
|
||||
if err := afero.WriteFile(fs, agePrivKeyPath, encryptedAgePrivKey, FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write encrypted age private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Get or derive the long-term private key
|
||||
ltPrivKeyData, err := getLongTermPrivateKey(fs, vault)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ltPrivKeyData.Destroy()
|
||||
|
||||
// Step 6: Encrypt long-term private key to the new age unlocker
|
||||
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt long-term private key to age unlocker: %w", err)
|
||||
}
|
||||
|
||||
// Write encrypted long-term private key
|
||||
ltPrivKeyPath := filepath.Join(unlockerDir, "longterm.age")
|
||||
if err := afero.WriteFile(fs, ltPrivKeyPath, encryptedLtPrivKeyToAge, FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write encrypted long-term private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: Prepare keychain data
|
||||
keychainData := KeychainData{
|
||||
AgePublicKey: ageRecipient,
|
||||
AgePrivKeyPassphrase: agePrivKeyPassphrase,
|
||||
EncryptedLongtermKey: hex.EncodeToString(encryptedLtPrivKeyToAge),
|
||||
}
|
||||
|
||||
keychainDataBytes, err := json.Marshal(keychainData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal keychain data: %w", err)
|
||||
}
|
||||
|
||||
// Step 8: Store data in keychain
|
||||
if err := storeInKeychain(keychainItemName, keychainDataBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to store data in keychain: %w", err)
|
||||
}
|
||||
|
||||
// Step 9: Create and write enhanced metadata
|
||||
keychainMetadata := KeychainUnlockerMetadata{
|
||||
UnlockerMetadata: UnlockerMetadata{
|
||||
Type: "keychain",
|
||||
CreatedAt: time.Now(),
|
||||
Flags: []string{"keychain", "macos"},
|
||||
},
|
||||
KeychainItemName: keychainItemName,
|
||||
}
|
||||
|
||||
metadataBytes, err := json.MarshalIndent(keychainMetadata, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal unlocker metadata: %w", err)
|
||||
}
|
||||
|
||||
if err := afero.WriteFile(fs,
|
||||
filepath.Join(unlockerDir, "unlocker-metadata.json"),
|
||||
metadataBytes, FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write unlocker metadata: %w", err)
|
||||
}
|
||||
|
||||
return &KeychainUnlocker{
|
||||
Directory: unlockerDir,
|
||||
Metadata: keychainMetadata.UnlockerMetadata,
|
||||
fs: fs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkMacOSAvailable verifies that we're running on macOS and security command is available
|
||||
func checkMacOSAvailable() error {
|
||||
cmd := exec.Command("/usr/bin/security", "help")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("macOS security command not available: %w (keychain unlockers are only supported on macOS)", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateKeychainItemName validates that a keychain item name is safe for command execution
|
||||
func validateKeychainItemName(itemName string) error {
|
||||
if itemName == "" {
|
||||
return fmt.Errorf("keychain item name cannot be empty")
|
||||
}
|
||||
|
||||
if !keychainItemNameRegex.MatchString(itemName) {
|
||||
return fmt.Errorf("invalid keychain item name format: %s", itemName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// storeInKeychain stores data in the macOS keychain using the security command
|
||||
func storeInKeychain(itemName string, data []byte) error {
|
||||
if err := validateKeychainItemName(itemName); err != nil {
|
||||
return fmt.Errorf("invalid keychain item name: %w", err)
|
||||
}
|
||||
cmd := exec.Command("/usr/bin/security", "add-generic-password", //nolint:gosec
|
||||
"-a", itemName,
|
||||
"-s", itemName,
|
||||
"-w", string(data),
|
||||
"-U") // Update if exists
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to store item in keychain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// retrieveFromKeychain retrieves data from the macOS keychain using the security command
|
||||
func retrieveFromKeychain(itemName string) ([]byte, error) {
|
||||
if err := validateKeychainItemName(itemName); err != nil {
|
||||
return nil, fmt.Errorf("invalid keychain item name: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("/usr/bin/security", "find-generic-password", //nolint:gosec
|
||||
"-a", itemName,
|
||||
"-s", itemName,
|
||||
"-w") // Return password only
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve item from keychain: %w", err)
|
||||
}
|
||||
|
||||
// Remove trailing newline if present
|
||||
if len(output) > 0 && output[len(output)-1] == '\n' {
|
||||
output = output[:len(output)-1]
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// deleteFromKeychain removes an item from the macOS keychain using the security command
|
||||
func deleteFromKeychain(itemName string) error {
|
||||
if err := validateKeychainItemName(itemName); err != nil {
|
||||
return fmt.Errorf("invalid keychain item name: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("/usr/bin/security", "delete-generic-password", //nolint:gosec
|
||||
"-a", itemName,
|
||||
"-s", itemName)
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to delete item from keychain: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateRandomPassphrase generates a random passphrase for encrypting the age private key
|
||||
func generateRandomPassphrase(length int) (string, error) {
|
||||
return generateRandomString(length, "0123456789abcdef")
|
||||
}
|
||||
34
internal/secret/metadata.go
Normal file
34
internal/secret/metadata.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// VaultMetadata contains information about a vault
|
||||
type VaultMetadata struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Description string `json:"description,omitempty"`
|
||||
DerivationIndex uint32 `json:"derivationIndex"`
|
||||
// Double SHA256 hash of the actual long-term public key
|
||||
PublicKeyHash string `json:"publicKeyHash,omitempty"`
|
||||
// Double SHA256 hash of index-0 key (for grouping vaults from same mnemonic)
|
||||
MnemonicFamilyHash string `json:"mnemonicFamilyHash,omitempty"`
|
||||
}
|
||||
|
||||
// UnlockerMetadata contains information about an unlocker
|
||||
type UnlockerMetadata struct {
|
||||
Type string `json:"type"` // passphrase, pgp, keychain
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Flags []string `json:"flags,omitempty"`
|
||||
}
|
||||
|
||||
// Metadata contains information about a secret
|
||||
type Metadata struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// Configuration represents the global configuration
|
||||
type Configuration struct {
|
||||
DefaultVault string `json:"defaultVault,omitempty"`
|
||||
}
|
||||
186
internal/secret/passphrase_test.go
Normal file
186
internal/secret/passphrase_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package secret_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
func TestPassphraseUnlockerWithRealFS(t *testing.T) {
|
||||
// This test uses real filesystem
|
||||
if os.Getenv("CI") == "true" {
|
||||
t.Log("Running in CI environment with real filesystem")
|
||||
}
|
||||
|
||||
// Create a temporary directory for our tests
|
||||
tempDir, err := os.MkdirTemp("", "secret-passphrase-test-")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir) // Clean up after test
|
||||
|
||||
// Use the real filesystem
|
||||
fs := afero.NewOsFs()
|
||||
|
||||
// Test data
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
testPassphrase := "test-passphrase-123"
|
||||
|
||||
// Create the directory structure
|
||||
unlockerDir := filepath.Join(tempDir, "unlocker")
|
||||
if err := os.MkdirAll(unlockerDir, secret.DirPerms); err != nil {
|
||||
t.Fatalf("Failed to create unlocker directory: %v", err)
|
||||
}
|
||||
|
||||
// Set up test metadata
|
||||
metadata := secret.UnlockerMetadata{
|
||||
Type: "passphrase",
|
||||
CreatedAt: time.Now(),
|
||||
Flags: []string{},
|
||||
}
|
||||
|
||||
// Create passphrase unlocker
|
||||
unlocker := secret.NewPassphraseUnlocker(fs, unlockerDir, metadata)
|
||||
|
||||
// Generate a test age identity
|
||||
ageIdentity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate age identity: %v", err)
|
||||
}
|
||||
agePrivateKey := ageIdentity.String()
|
||||
agePublicKey := ageIdentity.Recipient().String()
|
||||
|
||||
// Test writing public key
|
||||
t.Run("WritePublicKey", func(t *testing.T) {
|
||||
pubKeyPath := filepath.Join(unlockerDir, "pub.age")
|
||||
if err := afero.WriteFile(fs, pubKeyPath, []byte(agePublicKey), secret.FilePerms); err != nil {
|
||||
t.Fatalf("Failed to write public key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the file exists
|
||||
exists, err := afero.Exists(fs, pubKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if public key exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Errorf("Public key file should exist at %s", pubKeyPath)
|
||||
}
|
||||
})
|
||||
|
||||
// Test encrypting private key with passphrase
|
||||
t.Run("EncryptPrivateKey", func(t *testing.T) {
|
||||
privKeyData := []byte(agePrivateKey)
|
||||
passphraseBuffer := memguard.NewBufferFromBytes([]byte(testPassphrase))
|
||||
defer passphraseBuffer.Destroy()
|
||||
encryptedPrivKey, err := secret.EncryptWithPassphrase(privKeyData, passphraseBuffer)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt private key: %v", err)
|
||||
}
|
||||
|
||||
privKeyPath := filepath.Join(unlockerDir, "priv.age")
|
||||
if err := afero.WriteFile(fs, privKeyPath, encryptedPrivKey, secret.FilePerms); err != nil {
|
||||
t.Fatalf("Failed to write encrypted private key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the file exists
|
||||
exists, err := afero.Exists(fs, privKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if private key exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Errorf("Encrypted private key file should exist at %s", privKeyPath)
|
||||
}
|
||||
})
|
||||
|
||||
// Test writing long-term key
|
||||
t.Run("WriteLongTermKey", func(t *testing.T) {
|
||||
// Derive a long-term identity from the test mnemonic
|
||||
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term identity: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt long-term private key to the unlocker's recipient
|
||||
recipient, err := age.ParseX25519Recipient(agePublicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse recipient: %v", err)
|
||||
}
|
||||
|
||||
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
|
||||
defer ltPrivKeyBuffer.Destroy()
|
||||
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, recipient)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt long-term private key: %v", err)
|
||||
}
|
||||
|
||||
ltPrivKeyPath := filepath.Join(unlockerDir, "longterm.age")
|
||||
if err := afero.WriteFile(fs, ltPrivKeyPath, encryptedLtPrivKey, secret.FilePerms); err != nil {
|
||||
t.Fatalf("Failed to write encrypted long-term private key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the file exists
|
||||
exists, err := afero.Exists(fs, ltPrivKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if long-term key exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Errorf("Encrypted long-term key file should exist at %s", ltPrivKeyPath)
|
||||
}
|
||||
})
|
||||
|
||||
// Set test environment variable (cleaned up automatically)
|
||||
t.Setenv(secret.EnvUnlockPassphrase, testPassphrase)
|
||||
|
||||
// Test getting identity from environment variable
|
||||
t.Run("GetIdentityFromEnv", func(t *testing.T) {
|
||||
identity, err := unlocker.GetIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get identity from env: %v", err)
|
||||
}
|
||||
|
||||
// Verify the identity matches what we expect
|
||||
expectedPubKey := ageIdentity.Recipient().String()
|
||||
actualPubKey := identity.Recipient().String()
|
||||
if actualPubKey != expectedPubKey {
|
||||
t.Errorf("Public key mismatch. Expected %s, got %s", expectedPubKey, actualPubKey)
|
||||
}
|
||||
})
|
||||
|
||||
// Unset the environment variable to test interactive prompt
|
||||
os.Unsetenv(secret.EnvUnlockPassphrase)
|
||||
|
||||
// Test getting identity from prompt (this would require mocking the prompt)
|
||||
// For real integration tests, we'd need to provide a way to mock the passphrase input
|
||||
// Here we'll just verify the error is what we expect when no passphrase is available
|
||||
t.Run("GetIdentityWithoutEnv", func(t *testing.T) {
|
||||
// This should fail since we're not in an interactive terminal
|
||||
_, err := unlocker.GetIdentity()
|
||||
if err == nil {
|
||||
t.Errorf("Should have failed to get identity without passphrase env var")
|
||||
}
|
||||
})
|
||||
|
||||
// Test removing the unlocker
|
||||
t.Run("RemoveUnlocker", func(t *testing.T) {
|
||||
err := unlocker.Remove()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove unlocker: %v", err)
|
||||
}
|
||||
|
||||
// Verify the directory is gone
|
||||
exists, err := afero.DirExists(fs, unlockerDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if unlocker directory exists: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Errorf("Unlocker directory should not exist after removal")
|
||||
}
|
||||
})
|
||||
}
|
||||
188
internal/secret/passphraseunlocker.go
Normal file
188
internal/secret/passphraseunlocker.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"filippo.io/age"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// PassphraseUnlocker represents a passphrase-protected unlocker
|
||||
type PassphraseUnlocker struct {
|
||||
Directory string
|
||||
Metadata UnlockerMetadata
|
||||
fs afero.Fs
|
||||
Passphrase *memguard.LockedBuffer // Secure buffer for passphrase
|
||||
}
|
||||
|
||||
// getPassphrase retrieves the passphrase from memory, environment, or user input
|
||||
// Returns a LockedBuffer for secure memory handling
|
||||
func (p *PassphraseUnlocker) getPassphrase() (*memguard.LockedBuffer, error) {
|
||||
// First check if we already have the passphrase
|
||||
if p.Passphrase != nil && p.Passphrase.IsAlive() {
|
||||
Debug("Using in-memory passphrase", "unlocker_id", p.GetID())
|
||||
// Return a copy of the passphrase buffer
|
||||
return memguard.NewBufferFromBytes(p.Passphrase.Bytes()), nil
|
||||
}
|
||||
|
||||
Debug("No passphrase in memory, checking environment")
|
||||
// Check environment variable for passphrase
|
||||
passphraseStr := os.Getenv(EnvUnlockPassphrase)
|
||||
if passphraseStr != "" {
|
||||
Debug("Using passphrase from environment", "unlocker_id", p.GetID())
|
||||
// Convert to secure buffer
|
||||
secureBuffer := memguard.NewBufferFromBytes([]byte(passphraseStr))
|
||||
|
||||
return secureBuffer, nil
|
||||
}
|
||||
|
||||
Debug("No passphrase in environment, prompting user")
|
||||
// Prompt for passphrase
|
||||
secureBuffer, err := ReadPassphrase("Enter unlock passphrase: ")
|
||||
if err != nil {
|
||||
Debug("Failed to read passphrase", "error", err, "unlocker_id", p.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to read passphrase: %w", err)
|
||||
}
|
||||
|
||||
return secureBuffer, nil
|
||||
}
|
||||
|
||||
// GetIdentity implements Unlocker interface for passphrase-based unlockers
|
||||
func (p *PassphraseUnlocker) GetIdentity() (*age.X25519Identity, error) {
|
||||
DebugWith("Getting passphrase unlocker identity",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.String("unlocker_type", p.GetType()),
|
||||
)
|
||||
|
||||
passphraseBuffer, err := p.getPassphrase()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer passphraseBuffer.Destroy()
|
||||
|
||||
// Read encrypted private key of unlocker
|
||||
unlockerPrivPath := filepath.Join(p.Directory, "priv.age")
|
||||
Debug("Reading encrypted passphrase unlocker", "path", unlockerPrivPath)
|
||||
|
||||
encryptedPrivKeyData, err := afero.ReadFile(p.fs, unlockerPrivPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read passphrase unlocker private key", "error", err, "path", unlockerPrivPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read unlocker private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Read encrypted passphrase unlocker",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.Int("encrypted_length", len(encryptedPrivKeyData)),
|
||||
)
|
||||
|
||||
Debug("Decrypting unlocker private key with passphrase", "unlocker_id", p.GetID())
|
||||
|
||||
// Decrypt the unlocker private key with passphrase
|
||||
privKeyData, err := DecryptWithPassphrase(encryptedPrivKeyData, passphraseBuffer)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt unlocker private key", "error", err, "unlocker_id", p.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt unlocker private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Successfully decrypted unlocker private key",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.Int("decrypted_length", len(privKeyData)),
|
||||
)
|
||||
|
||||
// Parse the decrypted private key
|
||||
Debug("Parsing decrypted unlocker identity", "unlocker_id", p.GetID())
|
||||
|
||||
// Create a secure buffer for the private key data
|
||||
privKeyBuffer := memguard.NewBufferFromBytes(privKeyData)
|
||||
defer privKeyBuffer.Destroy()
|
||||
|
||||
// Clear the original private key data
|
||||
for i := range privKeyData {
|
||||
privKeyData[i] = 0
|
||||
}
|
||||
|
||||
identity, err := age.ParseX25519Identity(privKeyBuffer.String())
|
||||
if err != nil {
|
||||
Debug("Failed to parse unlocker private key", "error", err, "unlocker_id", p.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to parse unlocker private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Successfully parsed passphrase unlocker identity",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.String("public_key", identity.Recipient().String()),
|
||||
)
|
||||
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
// GetType implements Unlocker interface
|
||||
func (p *PassphraseUnlocker) GetType() string {
|
||||
return "passphrase"
|
||||
}
|
||||
|
||||
// GetMetadata implements Unlocker interface
|
||||
func (p *PassphraseUnlocker) GetMetadata() UnlockerMetadata {
|
||||
return p.Metadata
|
||||
}
|
||||
|
||||
// GetDirectory implements Unlocker interface
|
||||
func (p *PassphraseUnlocker) GetDirectory() string {
|
||||
return p.Directory
|
||||
}
|
||||
|
||||
// GetID implements Unlocker interface - generates ID from creation timestamp
|
||||
func (p *PassphraseUnlocker) GetID() string {
|
||||
// Generate ID using creation timestamp: YYYY-MM-DD.HH.mm-passphrase
|
||||
createdAt := p.Metadata.CreatedAt
|
||||
|
||||
return fmt.Sprintf("%s-passphrase", createdAt.Format("2006-01-02.15.04"))
|
||||
}
|
||||
|
||||
// Remove implements Unlocker interface - removes the passphrase unlocker
|
||||
func (p *PassphraseUnlocker) Remove() error {
|
||||
// Clean up the passphrase from memory if it exists
|
||||
if p.Passphrase != nil && p.Passphrase.IsAlive() {
|
||||
p.Passphrase.Destroy()
|
||||
}
|
||||
|
||||
// For passphrase unlockers, we just need to remove the directory
|
||||
// No external resources (like keychain items) to clean up
|
||||
if err := p.fs.RemoveAll(p.Directory); err != nil {
|
||||
return fmt.Errorf("failed to remove passphrase unlocker directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewPassphraseUnlocker creates a new PassphraseUnlocker instance
|
||||
func NewPassphraseUnlocker(fs afero.Fs, directory string, metadata UnlockerMetadata) *PassphraseUnlocker {
|
||||
return &PassphraseUnlocker{
|
||||
Directory: directory,
|
||||
Metadata: metadata,
|
||||
fs: fs,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePassphraseUnlocker creates a new passphrase-protected unlocker
|
||||
// The passphrase must be provided as a LockedBuffer for security
|
||||
func CreatePassphraseUnlocker(
|
||||
fs afero.Fs,
|
||||
stateDir string,
|
||||
passphrase *memguard.LockedBuffer,
|
||||
) (*PassphraseUnlocker, error) {
|
||||
// Get current vault
|
||||
currentVault, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current vault: %w", err)
|
||||
}
|
||||
|
||||
return currentVault.CreatePassphraseUnlocker(passphrase)
|
||||
}
|
||||
499
internal/secret/pgpunlock_test.go
Normal file
499
internal/secret/pgpunlock_test.go
Normal file
@@ -0,0 +1,499 @@
|
||||
package secret_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// Register vault with secret package for testing
|
||||
func init() {
|
||||
// Register the vault.GetCurrentVault function with the secret package
|
||||
secret.RegisterGetCurrentVaultFunc(func(fs afero.Fs, stateDir string) (secret.VaultInterface, error) {
|
||||
return vault.GetCurrentVault(fs, stateDir)
|
||||
})
|
||||
}
|
||||
|
||||
// setupNonInteractiveGPG creates a custom GPG environment for testing
|
||||
func setupNonInteractiveGPG(t *testing.T, _, passphrase, gnupgHomeDir string) {
|
||||
// Create GPG config file for non-interactive operation
|
||||
gpgConfPath := filepath.Join(gnupgHomeDir, "gpg.conf")
|
||||
gpgConfContent := `batch
|
||||
no-tty
|
||||
pinentry-mode loopback
|
||||
`
|
||||
if err := os.WriteFile(gpgConfPath, []byte(gpgConfContent), 0o600); err != nil {
|
||||
t.Fatalf("Failed to write GPG config file: %v", err)
|
||||
}
|
||||
|
||||
// Create a test-specific GPG implementation
|
||||
origEncryptFunc := secret.GPGEncryptFunc
|
||||
origDecryptFunc := secret.GPGDecryptFunc
|
||||
|
||||
// Set custom GPG functions for this test
|
||||
secret.GPGEncryptFunc = func(data []byte, keyID string) ([]byte, error) {
|
||||
cmd := exec.Command("gpg",
|
||||
"--homedir", gnupgHomeDir,
|
||||
"--batch",
|
||||
"--yes",
|
||||
"--pinentry-mode", "loopback",
|
||||
"--passphrase", passphrase,
|
||||
"--trust-model", "always",
|
||||
"--armor",
|
||||
"--encrypt",
|
||||
"-r", keyID)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
cmd.Stdin = bytes.NewReader(data)
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("GPG encryption failed: %w\nStderr: %s", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
||||
secret.GPGDecryptFunc = func(encryptedData []byte) ([]byte, error) {
|
||||
cmd := exec.Command("gpg",
|
||||
"--homedir", gnupgHomeDir,
|
||||
"--batch",
|
||||
"--yes",
|
||||
"--pinentry-mode", "loopback",
|
||||
"--passphrase", passphrase,
|
||||
"--quiet",
|
||||
"--decrypt")
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
cmd.Stdin = bytes.NewReader(encryptedData)
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("GPG decryption failed: %w\nStderr: %s", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
||||
// Restore original functions after test
|
||||
t.Cleanup(func() {
|
||||
secret.GPGEncryptFunc = origEncryptFunc
|
||||
secret.GPGDecryptFunc = origDecryptFunc
|
||||
})
|
||||
}
|
||||
|
||||
// runGPGWithPassphrase executes a GPG command with the specified passphrase
|
||||
func runGPGWithPassphrase(gnupgHome, passphrase string, args []string, input io.Reader) ([]byte, error) {
|
||||
cmdArgs := []string{
|
||||
"--homedir=" + gnupgHome,
|
||||
"--batch",
|
||||
"--yes",
|
||||
"--pinentry-mode", "loopback",
|
||||
"--passphrase", passphrase,
|
||||
}
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
|
||||
cmd := exec.Command("gpg", cmdArgs...)
|
||||
cmd.Stdin = input
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GPG command failed: %w\nStderr: %s", err, stderr.String())
|
||||
}
|
||||
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
|
||||
func TestPGPUnlockerWithRealFS(t *testing.T) {
|
||||
// Check if gpg is available
|
||||
if _, err := exec.LookPath("gpg"); err != nil {
|
||||
t.Log("GPG not available, PGP unlock key tests may not fully function")
|
||||
// Continue anyway to test what we can
|
||||
}
|
||||
|
||||
// Create a temporary directory for our tests
|
||||
tempDir, err := os.MkdirTemp("", "secret-pgp-test-")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir) // Clean up after test
|
||||
|
||||
// Create a temporary GNUPGHOME
|
||||
gnupgHomeDir := filepath.Join(tempDir, "gnupg")
|
||||
if err := os.MkdirAll(gnupgHomeDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create GNUPGHOME: %v", err)
|
||||
}
|
||||
|
||||
// Set new GNUPGHOME
|
||||
t.Setenv("GNUPGHOME", gnupgHomeDir)
|
||||
|
||||
// Test passphrase for GPG key
|
||||
testPassphrase := "test123"
|
||||
|
||||
// Setup non-interactive GPG with custom functions
|
||||
setupNonInteractiveGPG(t, tempDir, testPassphrase, gnupgHomeDir)
|
||||
|
||||
// Create GPG batch file for key generation
|
||||
batchFile := filepath.Join(tempDir, "gen-key-batch")
|
||||
batchContent := `%echo Generating a test key
|
||||
Key-Type: RSA
|
||||
Key-Length: 2048
|
||||
Name-Real: Test User
|
||||
Name-Email: test@example.com
|
||||
Expire-Date: 0
|
||||
Passphrase: ` + testPassphrase + `
|
||||
%commit
|
||||
%echo Key generation completed
|
||||
`
|
||||
if err := os.WriteFile(batchFile, []byte(batchContent), 0o600); err != nil {
|
||||
t.Fatalf("Failed to write batch file: %v", err)
|
||||
}
|
||||
|
||||
// Generate GPG key with batch mode
|
||||
t.Log("Generating GPG key...")
|
||||
_, err = runGPGWithPassphrase(gnupgHomeDir, testPassphrase,
|
||||
[]string{"--gen-key", batchFile}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate GPG key: %v", err)
|
||||
}
|
||||
t.Log("GPG key generated successfully")
|
||||
|
||||
// Get the key ID and fingerprint
|
||||
output, err := runGPGWithPassphrase(gnupgHomeDir, testPassphrase,
|
||||
[]string{"--list-secret-keys", "--with-colons", "--fingerprint"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list GPG keys: %v", err)
|
||||
}
|
||||
|
||||
// Parse output to get key ID and fingerprint
|
||||
var keyID, fingerprint string
|
||||
lines := strings.Split(string(output), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "sec:") {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) >= 5 {
|
||||
keyID = fields[4]
|
||||
}
|
||||
} else if strings.HasPrefix(line, "fpr:") {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) >= 10 && fields[9] != "" {
|
||||
fingerprint = fields[9]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if keyID == "" {
|
||||
t.Fatalf("Failed to find GPG key ID in output: %s", output)
|
||||
}
|
||||
if fingerprint == "" {
|
||||
t.Fatalf("Failed to find GPG fingerprint in output: %s", output)
|
||||
}
|
||||
t.Logf("Generated GPG key ID: %s", keyID)
|
||||
t.Logf("Generated GPG fingerprint: %s", fingerprint)
|
||||
|
||||
// Set the GPG_AGENT_INFO to empty to ensure gpg-agent doesn't interfere
|
||||
t.Setenv("GPG_AGENT_INFO", "")
|
||||
|
||||
// Use the real filesystem
|
||||
fs := afero.NewOsFs()
|
||||
|
||||
// Test data
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
|
||||
// Set test environment variables
|
||||
t.Setenv(secret.EnvMnemonic, testMnemonic)
|
||||
t.Setenv(secret.EnvGPGKeyID, keyID)
|
||||
|
||||
// Set up vault structure for testing
|
||||
stateDir := tempDir
|
||||
vaultName := "test-vault"
|
||||
|
||||
// Test creation of a PGP unlock key through a vault
|
||||
t.Run("CreatePGPUnlocker", func(t *testing.T) {
|
||||
// Set a limited test timeout to avoid hanging
|
||||
timer := time.AfterFunc(30*time.Second, func() {
|
||||
t.Fatalf("Test timed out after 30 seconds")
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
// Create a test vault directory structure
|
||||
vlt, err := vault.CreateVault(fs, stateDir, vaultName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault: %v", err)
|
||||
}
|
||||
|
||||
// Set the current vault
|
||||
err = vault.SelectVault(fs, stateDir, vaultName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select vault: %v", err)
|
||||
}
|
||||
|
||||
// Derive long-term key from mnemonic
|
||||
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Get the vault directory
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Write long-term public key
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
if err := afero.WriteFile(fs, ltPubKeyPath, []byte(ltIdentity.Recipient().String()), secret.FilePerms); err != nil {
|
||||
t.Fatalf("Failed to write long-term public key: %v", err)
|
||||
}
|
||||
|
||||
// Unlock the vault
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Create a passphrase unlocker first (to have current unlocker)
|
||||
passphraseBuffer := memguard.NewBufferFromBytes([]byte("test-passphrase"))
|
||||
defer passphraseBuffer.Destroy()
|
||||
passUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create passphrase unlocker: %v", err)
|
||||
}
|
||||
|
||||
// Verify passphrase unlocker was created
|
||||
if passUnlocker == nil {
|
||||
t.Fatal("Passphrase unlocker is nil")
|
||||
}
|
||||
|
||||
// Now create a PGP unlock key (this will use our custom GPGEncryptFunc)
|
||||
pgpUnlocker, err := secret.CreatePGPUnlocker(fs, stateDir, keyID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PGP unlock key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the PGP unlock key was created
|
||||
if pgpUnlocker == nil {
|
||||
t.Fatal("PGP unlock key is nil")
|
||||
}
|
||||
|
||||
// Check if the key has the correct type
|
||||
if pgpUnlocker.GetType() != "pgp" {
|
||||
t.Errorf("Expected PGP unlock key type 'pgp', got '%s'", pgpUnlocker.GetType())
|
||||
}
|
||||
|
||||
// Check if the key ID includes the GPG fingerprint
|
||||
if !strings.Contains(pgpUnlocker.GetID(), fingerprint) {
|
||||
t.Errorf("PGP unlock key ID '%s' does not contain GPG fingerprint '%s'", pgpUnlocker.GetID(), fingerprint)
|
||||
}
|
||||
|
||||
// Check if the key directory exists
|
||||
unlockerDir := pgpUnlocker.GetDirectory()
|
||||
keyExists, err := afero.DirExists(fs, unlockerDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if PGP key directory exists: %v", err)
|
||||
}
|
||||
if !keyExists {
|
||||
t.Errorf("PGP unlock key directory does not exist: %s", unlockerDir)
|
||||
}
|
||||
|
||||
// Check if required files exist
|
||||
recipientPath := filepath.Join(unlockerDir, "pub.txt")
|
||||
recipientExists, err := afero.Exists(fs, recipientPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if recipient file exists: %v", err)
|
||||
}
|
||||
if !recipientExists {
|
||||
t.Errorf("PGP unlock key recipient file does not exist: %s", recipientPath)
|
||||
}
|
||||
|
||||
privKeyPath := filepath.Join(unlockerDir, "priv.age.gpg")
|
||||
privKeyExists, err := afero.Exists(fs, privKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if private key file exists: %v", err)
|
||||
}
|
||||
if !privKeyExists {
|
||||
t.Errorf("PGP unlock key private key file does not exist: %s", privKeyPath)
|
||||
}
|
||||
|
||||
metadataPath := filepath.Join(unlockerDir, "unlocker-metadata.json")
|
||||
metadataExists, err := afero.Exists(fs, metadataPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if metadata file exists: %v", err)
|
||||
}
|
||||
if !metadataExists {
|
||||
t.Errorf("PGP unlock key metadata file does not exist: %s", metadataPath)
|
||||
}
|
||||
|
||||
longtermPath := filepath.Join(unlockerDir, "longterm.age")
|
||||
longtermExists, err := afero.Exists(fs, longtermPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if longterm key file exists: %v", err)
|
||||
}
|
||||
if !longtermExists {
|
||||
t.Errorf("PGP unlock key longterm key file does not exist: %s", longtermPath)
|
||||
}
|
||||
|
||||
// Read and verify metadata
|
||||
metadataBytes, err := afero.ReadFile(fs, metadataPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read metadata: %v", err)
|
||||
}
|
||||
|
||||
var metadata struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Flags []string `json:"flags"`
|
||||
GPGKeyID string `json:"gpgKeyId"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
t.Fatalf("Failed to parse metadata: %v", err)
|
||||
}
|
||||
|
||||
if metadata.Type != "pgp" {
|
||||
t.Errorf("Expected metadata type 'pgp', got '%s'", metadata.Type)
|
||||
}
|
||||
|
||||
if metadata.GPGKeyID != fingerprint {
|
||||
t.Errorf("Expected GPG fingerprint '%s', got '%s'", fingerprint, metadata.GPGKeyID)
|
||||
}
|
||||
})
|
||||
|
||||
// Set up key directory for individual tests
|
||||
unlockerDir := filepath.Join(tempDir, "unlocker")
|
||||
if err := os.MkdirAll(unlockerDir, secret.DirPerms); err != nil {
|
||||
t.Fatalf("Failed to create unlocker directory: %v", err)
|
||||
}
|
||||
|
||||
// Set up test metadata
|
||||
metadata := secret.UnlockerMetadata{
|
||||
Type: "pgp",
|
||||
CreatedAt: time.Now(),
|
||||
Flags: []string{"gpg", "encrypted"},
|
||||
}
|
||||
|
||||
// Create a PGP unlocker for the remaining tests
|
||||
unlocker := secret.NewPGPUnlocker(fs, unlockerDir, metadata)
|
||||
|
||||
// Test getting GPG key ID
|
||||
t.Run("GetGPGKeyID", func(t *testing.T) {
|
||||
// Create PGP metadata with GPG key ID
|
||||
type PGPUnlockerMetadata struct {
|
||||
secret.UnlockerMetadata
|
||||
GPGKeyID string `json:"gpgKeyId"`
|
||||
}
|
||||
|
||||
pgpMetadata := PGPUnlockerMetadata{
|
||||
UnlockerMetadata: metadata,
|
||||
GPGKeyID: fingerprint,
|
||||
}
|
||||
|
||||
// Write metadata file
|
||||
metadataPath := filepath.Join(unlockerDir, "unlocker-metadata.json")
|
||||
metadataBytes, err := json.MarshalIndent(pgpMetadata, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal metadata: %v", err)
|
||||
}
|
||||
if err := afero.WriteFile(fs, metadataPath, metadataBytes, secret.FilePerms); err != nil {
|
||||
t.Fatalf("Failed to write metadata: %v", err)
|
||||
}
|
||||
|
||||
// Get GPG key ID
|
||||
retrievedKeyID, err := unlocker.GetGPGKeyID()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get GPG key ID: %v", err)
|
||||
}
|
||||
|
||||
// Verify key ID (should be the fingerprint)
|
||||
if retrievedKeyID != fingerprint {
|
||||
t.Errorf("Expected GPG fingerprint '%s', got '%s'", fingerprint, retrievedKeyID)
|
||||
}
|
||||
})
|
||||
|
||||
// Test getting identity from PGP unlocker
|
||||
t.Run("GetIdentity", func(t *testing.T) {
|
||||
// Generate an age identity for testing
|
||||
ageIdentity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate age identity: %v", err)
|
||||
}
|
||||
|
||||
// Write the recipient
|
||||
recipientPath := filepath.Join(unlockerDir, "pub.txt")
|
||||
if err := afero.WriteFile(fs, recipientPath, []byte(ageIdentity.Recipient().String()), secret.FilePerms); err != nil {
|
||||
t.Fatalf("Failed to write recipient: %v", err)
|
||||
}
|
||||
|
||||
// GPG encrypt the private key using our custom encrypt function
|
||||
privKeyData := []byte(ageIdentity.String())
|
||||
encryptedOutput, err := secret.GPGEncryptFunc(privKeyData, keyID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt with GPG: %v", err)
|
||||
}
|
||||
|
||||
// Write the encrypted data to a file
|
||||
encryptedPath := filepath.Join(unlockerDir, "priv.age.gpg")
|
||||
if err := afero.WriteFile(fs, encryptedPath, encryptedOutput, secret.FilePerms); err != nil {
|
||||
t.Fatalf("Failed to write encrypted private key: %v", err)
|
||||
}
|
||||
|
||||
// Now try to get the identity - this will use our custom GPGDecryptFunc
|
||||
identity, err := unlocker.GetIdentity()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get identity: %v", err)
|
||||
}
|
||||
|
||||
// Verify the identity matches
|
||||
expectedPubKey := ageIdentity.Recipient().String()
|
||||
actualPubKey := identity.Recipient().String()
|
||||
if actualPubKey != expectedPubKey {
|
||||
t.Errorf("Expected public key '%s', got '%s'", expectedPubKey, actualPubKey)
|
||||
}
|
||||
})
|
||||
|
||||
// Test removing the unlocker
|
||||
t.Run("RemoveUnlocker", func(t *testing.T) {
|
||||
// Ensure unlocker directory exists before removal
|
||||
keyExists, err := afero.DirExists(fs, unlockerDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if unlocker directory exists: %v", err)
|
||||
}
|
||||
if !keyExists {
|
||||
t.Fatalf("Unlocker directory does not exist: %s", unlockerDir)
|
||||
}
|
||||
|
||||
// Remove unlocker
|
||||
err = unlocker.Remove()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove unlocker: %v", err)
|
||||
}
|
||||
|
||||
// Verify directory is gone
|
||||
keyExists, err = afero.DirExists(fs, unlockerDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if unlocker directory exists: %v", err)
|
||||
}
|
||||
if keyExists {
|
||||
t.Errorf("Unlocker directory still exists after removal: %s", unlockerDir)
|
||||
}
|
||||
})
|
||||
}
|
||||
378
internal/secret/pgpunlocker.go
Normal file
378
internal/secret/pgpunlocker.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// Variables to allow overriding in tests
|
||||
var (
|
||||
// GPGEncryptFunc is the function used for GPG encryption
|
||||
// Can be overridden in tests to provide a non-interactive implementation
|
||||
GPGEncryptFunc = gpgEncryptDefault //nolint:gochecknoglobals // Required for test mocking
|
||||
|
||||
// GPGDecryptFunc is the function used for GPG decryption
|
||||
// Can be overridden in tests to provide a non-interactive implementation
|
||||
GPGDecryptFunc = gpgDecryptDefault //nolint:gochecknoglobals // Required for test mocking
|
||||
|
||||
// gpgKeyIDRegex validates GPG key IDs
|
||||
// Allows either:
|
||||
// 1. Email addresses (user@domain.tld format)
|
||||
// 2. Short key IDs (8 hex characters)
|
||||
// 3. Long key IDs (16 hex characters)
|
||||
// 4. Full fingerprints (40 hex characters)
|
||||
gpgKeyIDRegex = regexp.MustCompile(
|
||||
`^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$|` +
|
||||
`^[A-Fa-f0-9]{8}$|` +
|
||||
`^[A-Fa-f0-9]{16}$|` +
|
||||
`^[A-Fa-f0-9]{40}$`,
|
||||
)
|
||||
)
|
||||
|
||||
// PGPUnlockerMetadata extends UnlockerMetadata with PGP-specific data
|
||||
type PGPUnlockerMetadata struct {
|
||||
UnlockerMetadata
|
||||
// GPG key ID used for encryption
|
||||
GPGKeyID string `json:"gpgKeyId"`
|
||||
}
|
||||
|
||||
// PGPUnlocker represents a PGP-protected unlocker
|
||||
type PGPUnlocker struct {
|
||||
Directory string
|
||||
Metadata UnlockerMetadata
|
||||
fs afero.Fs
|
||||
}
|
||||
|
||||
// GetIdentity implements Unlocker interface for PGP-based unlockers
|
||||
func (p *PGPUnlocker) GetIdentity() (*age.X25519Identity, error) {
|
||||
DebugWith("Getting PGP unlocker identity",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.String("unlocker_type", p.GetType()),
|
||||
)
|
||||
|
||||
// Step 1: Read the encrypted age private key from filesystem
|
||||
agePrivKeyPath := filepath.Join(p.Directory, "priv.age.gpg")
|
||||
Debug("Reading PGP-encrypted age private key", "path", agePrivKeyPath)
|
||||
|
||||
encryptedAgePrivKeyData, err := afero.ReadFile(p.fs, agePrivKeyPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read PGP-encrypted age private key", "error", err, "path", agePrivKeyPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read encrypted age private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Read PGP-encrypted age private key",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.Int("encrypted_length", len(encryptedAgePrivKeyData)),
|
||||
)
|
||||
|
||||
// Step 2: Decrypt the age private key using GPG
|
||||
Debug("Decrypting age private key with GPG", "unlocker_id", p.GetID())
|
||||
agePrivKeyData, err := GPGDecryptFunc(encryptedAgePrivKeyData)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt age private key with GPG", "error", err, "unlocker_id", p.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt age private key with GPG: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Successfully decrypted age private key with GPG",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.Int("decrypted_length", len(agePrivKeyData)),
|
||||
)
|
||||
|
||||
// Step 3: Parse the decrypted age private key
|
||||
Debug("Parsing decrypted age private key", "unlocker_id", p.GetID())
|
||||
ageIdentity, err := age.ParseX25519Identity(string(agePrivKeyData))
|
||||
if err != nil {
|
||||
Debug("Failed to parse age private key", "error", err, "unlocker_id", p.GetID())
|
||||
|
||||
return nil, fmt.Errorf("failed to parse age private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Successfully parsed PGP age identity",
|
||||
slog.String("unlocker_id", p.GetID()),
|
||||
slog.String("public_key", ageIdentity.Recipient().String()),
|
||||
)
|
||||
|
||||
return ageIdentity, nil
|
||||
}
|
||||
|
||||
// GetType implements Unlocker interface
|
||||
func (p *PGPUnlocker) GetType() string {
|
||||
return "pgp"
|
||||
}
|
||||
|
||||
// GetMetadata implements Unlocker interface
|
||||
func (p *PGPUnlocker) GetMetadata() UnlockerMetadata {
|
||||
return p.Metadata
|
||||
}
|
||||
|
||||
// GetDirectory implements Unlocker interface
|
||||
func (p *PGPUnlocker) GetDirectory() string {
|
||||
return p.Directory
|
||||
}
|
||||
|
||||
// GetID implements Unlocker interface - generates ID from GPG key ID
|
||||
func (p *PGPUnlocker) GetID() string {
|
||||
// Generate ID using GPG key ID: <keyid>-pgp
|
||||
gpgKeyID, err := p.GetGPGKeyID()
|
||||
if err != nil {
|
||||
// The vault metadata is corrupt - this is a fatal error
|
||||
// We cannot continue with a fallback ID as that would mask data corruption
|
||||
panic(fmt.Sprintf("PGP unlocker metadata is corrupt or missing GPG key ID: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-pgp", gpgKeyID)
|
||||
}
|
||||
|
||||
// Remove implements Unlocker interface - removes the PGP unlocker
|
||||
func (p *PGPUnlocker) Remove() error {
|
||||
// For PGP unlockers, we just need to remove the directory
|
||||
// No external resources (like keychain items) to clean up
|
||||
if err := p.fs.RemoveAll(p.Directory); err != nil {
|
||||
return fmt.Errorf("failed to remove PGP unlocker directory: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewPGPUnlocker creates a new PGPUnlocker instance
|
||||
func NewPGPUnlocker(fs afero.Fs, directory string, metadata UnlockerMetadata) *PGPUnlocker {
|
||||
return &PGPUnlocker{
|
||||
Directory: directory,
|
||||
Metadata: metadata,
|
||||
fs: fs,
|
||||
}
|
||||
}
|
||||
|
||||
// GetGPGKeyID returns the GPG key ID from metadata
|
||||
func (p *PGPUnlocker) GetGPGKeyID() (string, error) {
|
||||
// Load the metadata
|
||||
metadataPath := filepath.Join(p.Directory, "unlocker-metadata.json")
|
||||
metadataData, err := afero.ReadFile(p.fs, metadataPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read PGP metadata: %w", err)
|
||||
}
|
||||
|
||||
var pgpMetadata PGPUnlockerMetadata
|
||||
if err := json.Unmarshal(metadataData, &pgpMetadata); err != nil {
|
||||
return "", fmt.Errorf("failed to parse PGP metadata: %w", err)
|
||||
}
|
||||
|
||||
return pgpMetadata.GPGKeyID, nil
|
||||
}
|
||||
|
||||
// generatePGPUnlockerName generates a unique name for the PGP unlocker based on hostname and date
|
||||
func generatePGPUnlockerName() (string, error) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get hostname: %w", err)
|
||||
}
|
||||
|
||||
// Format: hostname-pgp-YYYY-MM-DD
|
||||
enrollmentDate := time.Now().Format("2006-01-02")
|
||||
|
||||
return fmt.Sprintf("%s-pgp-%s", hostname, enrollmentDate), nil
|
||||
}
|
||||
|
||||
// CreatePGPUnlocker creates a new PGP unlocker and stores it in the vault
|
||||
func CreatePGPUnlocker(fs afero.Fs, stateDir string, gpgKeyID string) (*PGPUnlocker, error) {
|
||||
// Check if GPG is available
|
||||
if err := checkGPGAvailable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get current vault
|
||||
vault, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current vault: %w", err)
|
||||
}
|
||||
|
||||
// Generate the unlocker name based on hostname and date
|
||||
unlockerName, err := generatePGPUnlockerName()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate unlocker name: %w", err)
|
||||
}
|
||||
|
||||
// Create unlocker directory using the generated name
|
||||
vaultDir, err := vault.GetDirectory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get vault directory: %w", err)
|
||||
}
|
||||
|
||||
unlockerDir := filepath.Join(vaultDir, "unlockers.d", unlockerName)
|
||||
if err := fs.MkdirAll(unlockerDir, DirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create unlocker directory: %w", err)
|
||||
}
|
||||
|
||||
// Step 1: Generate a new age keypair for the PGP unlocker
|
||||
ageIdentity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate age keypair: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Store age recipient as plaintext
|
||||
ageRecipient := ageIdentity.Recipient().String()
|
||||
recipientPath := filepath.Join(unlockerDir, "pub.txt")
|
||||
if err := afero.WriteFile(fs, recipientPath, []byte(ageRecipient), FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write age recipient: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Get or derive the long-term private key
|
||||
ltPrivKeyData, err := getLongTermPrivateKey(fs, vault)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ltPrivKeyData.Destroy()
|
||||
|
||||
// Step 7: Encrypt long-term private key to the new age unlocker
|
||||
encryptedLtPrivKeyToAge, err := EncryptToRecipient(ltPrivKeyData, ageIdentity.Recipient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt long-term private key to age unlocker: %w", err)
|
||||
}
|
||||
|
||||
// Write encrypted long-term private key
|
||||
ltPrivKeyPath := filepath.Join(unlockerDir, "longterm.age")
|
||||
if err := afero.WriteFile(fs, ltPrivKeyPath, encryptedLtPrivKeyToAge, FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write encrypted long-term private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 8: Encrypt age private key to the GPG key ID
|
||||
// Use memguard to protect the private key in memory
|
||||
agePrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(ageIdentity.String()))
|
||||
defer agePrivateKeyBuffer.Destroy()
|
||||
|
||||
encryptedAgePrivKey, err := GPGEncryptFunc(agePrivateKeyBuffer.Bytes(), gpgKeyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt age private key with GPG: %w", err)
|
||||
}
|
||||
|
||||
agePrivKeyPath := filepath.Join(unlockerDir, "priv.age.gpg")
|
||||
if err := afero.WriteFile(fs, agePrivKeyPath, encryptedAgePrivKey, FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write encrypted age private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 9: Resolve the GPG key ID to its full fingerprint
|
||||
fingerprint, err := resolveGPGKeyFingerprint(gpgKeyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve GPG key fingerprint: %w", err)
|
||||
}
|
||||
|
||||
// Step 10: Create and write enhanced metadata with full fingerprint
|
||||
pgpMetadata := PGPUnlockerMetadata{
|
||||
UnlockerMetadata: UnlockerMetadata{
|
||||
Type: "pgp",
|
||||
CreatedAt: time.Now(),
|
||||
Flags: []string{"gpg", "encrypted"},
|
||||
},
|
||||
GPGKeyID: fingerprint,
|
||||
}
|
||||
|
||||
metadataBytes, err := json.MarshalIndent(pgpMetadata, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal unlocker metadata: %w", err)
|
||||
}
|
||||
|
||||
if err := afero.WriteFile(fs,
|
||||
filepath.Join(unlockerDir, "unlocker-metadata.json"),
|
||||
metadataBytes, FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write unlocker metadata: %w", err)
|
||||
}
|
||||
|
||||
return &PGPUnlocker{
|
||||
Directory: unlockerDir,
|
||||
Metadata: pgpMetadata.UnlockerMetadata,
|
||||
fs: fs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validateGPGKeyID validates that a GPG key ID is safe for command execution
|
||||
func validateGPGKeyID(keyID string) error {
|
||||
if keyID == "" {
|
||||
return fmt.Errorf("GPG key ID cannot be empty")
|
||||
}
|
||||
|
||||
if !gpgKeyIDRegex.MatchString(keyID) {
|
||||
return fmt.Errorf("invalid GPG key ID format: %s", keyID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveGPGKeyFingerprint resolves any GPG key identifier to its full fingerprint
|
||||
func resolveGPGKeyFingerprint(keyID string) (string, error) {
|
||||
if err := validateGPGKeyID(keyID); err != nil {
|
||||
return "", fmt.Errorf("invalid GPG key ID: %w", err)
|
||||
}
|
||||
|
||||
// Use GPG to get the full fingerprint for the key
|
||||
cmd := exec.Command("gpg", "--list-keys", "--with-colons", "--fingerprint", keyID)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve GPG key fingerprint: %w", err)
|
||||
}
|
||||
|
||||
// Parse the output to extract the fingerprint
|
||||
lines := strings.Split(string(output), "\n")
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "fpr:") {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) >= 10 && fields[9] != "" {
|
||||
return fields[9], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("could not find fingerprint for GPG key: %s", keyID)
|
||||
}
|
||||
|
||||
// checkGPGAvailable verifies that GPG is available
|
||||
func checkGPGAvailable() error {
|
||||
cmd := exec.Command("gpg", "--version")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("GPG not available: %w (make sure 'gpg' command is installed and in PATH)", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// gpgEncryptDefault is the default implementation of GPG encryption
|
||||
func gpgEncryptDefault(data []byte, keyID string) ([]byte, error) {
|
||||
if err := validateGPGKeyID(keyID); err != nil {
|
||||
return nil, fmt.Errorf("invalid GPG key ID: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("gpg", "--trust-model", "always", "--armor", "--encrypt", "-r", keyID)
|
||||
cmd.Stdin = strings.NewReader(string(data))
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GPG encryption failed: %w", err)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// gpgDecryptDefault is the default implementation of GPG decryption
|
||||
func gpgDecryptDefault(encryptedData []byte) ([]byte, error) {
|
||||
cmd := exec.Command("gpg", "--quiet", "--decrypt")
|
||||
cmd.Stdin = strings.NewReader(string(encryptedData))
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GPG decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
321
internal/secret/secret.go
Normal file
321
internal/secret/secret.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// VaultInterface defines the interface that vault implementations must satisfy
|
||||
type VaultInterface interface {
|
||||
GetDirectory() (string, error)
|
||||
AddSecret(name string, value *memguard.LockedBuffer, force bool) error
|
||||
GetName() string
|
||||
GetFilesystem() afero.Fs
|
||||
GetCurrentUnlocker() (Unlocker, error)
|
||||
CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*PassphraseUnlocker, error)
|
||||
}
|
||||
|
||||
// Secret represents a secret in a vault
|
||||
type Secret struct {
|
||||
Name string
|
||||
Directory string
|
||||
Metadata Metadata
|
||||
vault VaultInterface
|
||||
}
|
||||
|
||||
// NewSecret creates a new Secret instance
|
||||
func NewSecret(vault VaultInterface, name string) *Secret {
|
||||
DebugWith("Creating new secret instance",
|
||||
slog.String("secret_name", name),
|
||||
slog.String("vault_name", vault.GetName()),
|
||||
)
|
||||
|
||||
// Convert slashes to percent signs for storage directory name
|
||||
storageName := strings.ReplaceAll(name, "/", "%")
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", storageName)
|
||||
|
||||
DebugWith("Secret storage details",
|
||||
slog.String("secret_name", name),
|
||||
slog.String("storage_name", storageName),
|
||||
slog.String("secret_dir", secretDir),
|
||||
)
|
||||
|
||||
return &Secret{
|
||||
Name: name,
|
||||
Directory: secretDir,
|
||||
vault: vault,
|
||||
Metadata: Metadata{
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Save is deprecated - use vault.AddSecret directly which creates versions
|
||||
// Kept for backward compatibility
|
||||
func (s *Secret) Save(value []byte, force bool) error {
|
||||
DebugWith("Saving secret (deprecated method)",
|
||||
slog.String("secret_name", s.Name),
|
||||
slog.String("vault_name", s.vault.GetName()),
|
||||
slog.Int("value_length", len(value)),
|
||||
slog.Bool("force", force),
|
||||
)
|
||||
|
||||
// Create a secure buffer for the value - note that the caller
|
||||
// should ideally pass a LockedBuffer directly to vault.AddSecret
|
||||
valueBuffer := memguard.NewBufferFromBytes(value)
|
||||
defer valueBuffer.Destroy()
|
||||
|
||||
err := s.vault.AddSecret(s.Name, valueBuffer, force)
|
||||
if err != nil {
|
||||
Debug("Failed to save secret", "error", err, "secret_name", s.Name)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
Debug("Successfully saved secret", "secret_name", s.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValue retrieves and decrypts the current version's value using the provided unlocker
|
||||
func (s *Secret) GetValue(unlocker Unlocker) ([]byte, error) {
|
||||
DebugWith("Getting secret value",
|
||||
slog.String("secret_name", s.Name),
|
||||
slog.String("vault_name", s.vault.GetName()),
|
||||
)
|
||||
|
||||
// Check if secret exists
|
||||
exists, err := s.Exists()
|
||||
if err != nil {
|
||||
Debug("Failed to check if secret exists during GetValue", "error", err, "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to check if secret exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
Debug("Secret not found during GetValue", "secret_name", s.Name, "vault_name", s.vault.GetName())
|
||||
|
||||
return nil, fmt.Errorf("secret %s not found", s.Name)
|
||||
}
|
||||
|
||||
Debug("Secret exists, getting current version", "secret_name", s.Name)
|
||||
|
||||
// Get current version
|
||||
currentVersion, err := GetCurrentVersion(s.vault.GetFilesystem(), s.Directory)
|
||||
if err != nil {
|
||||
Debug("Failed to get current version", "error", err, "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to get current version: %w", err)
|
||||
}
|
||||
|
||||
// Create version object
|
||||
version := NewVersion(s.vault, s.Name, currentVersion)
|
||||
|
||||
// Check if we have SB_SECRET_MNEMONIC environment variable for direct decryption
|
||||
if envMnemonic := os.Getenv(EnvMnemonic); envMnemonic != "" {
|
||||
Debug("Using mnemonic from environment for direct long-term key derivation", "secret_name", s.Name)
|
||||
|
||||
// Get vault directory to read metadata
|
||||
vaultDir, err := s.vault.GetDirectory()
|
||||
if err != nil {
|
||||
Debug("Failed to get vault directory", "error", err, "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to get vault directory: %w", err)
|
||||
}
|
||||
|
||||
// Load vault metadata to get the correct derivation index
|
||||
metadataPath := filepath.Join(vaultDir, "vault-metadata.json")
|
||||
metadataBytes, err := afero.ReadFile(s.vault.GetFilesystem(), metadataPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read vault metadata", "error", err, "path", metadataPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read vault metadata: %w", err)
|
||||
}
|
||||
|
||||
var metadata VaultMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
Debug("Failed to parse vault metadata", "error", err, "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to parse vault metadata: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Using vault derivation index from metadata",
|
||||
slog.String("secret_name", s.Name),
|
||||
slog.String("vault_name", s.vault.GetName()),
|
||||
slog.Uint64("derivation_index", uint64(metadata.DerivationIndex)),
|
||||
)
|
||||
|
||||
// Use mnemonic with the vault's derivation index from metadata
|
||||
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, metadata.DerivationIndex)
|
||||
if err != nil {
|
||||
Debug("Failed to derive long-term key from mnemonic for secret", "error", err, "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
|
||||
}
|
||||
|
||||
Debug("Successfully derived long-term key from mnemonic", "secret_name", s.Name)
|
||||
|
||||
// Use the long-term key to decrypt the version
|
||||
return version.GetValue(ltIdentity)
|
||||
}
|
||||
|
||||
Debug("Using unlocker for vault access", "secret_name", s.Name)
|
||||
|
||||
// Use the provided unlocker to get the vault's long-term private key
|
||||
if unlocker == nil {
|
||||
Debug("No unlocker provided for secret decryption", "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("unlocker required to decrypt secret")
|
||||
}
|
||||
|
||||
DebugWith("Getting vault's long-term key using unlocker",
|
||||
slog.String("secret_name", s.Name),
|
||||
slog.String("unlocker_type", unlocker.GetType()),
|
||||
slog.String("unlocker_id", unlocker.GetID()),
|
||||
)
|
||||
|
||||
// Step 1: Use the unlocker to get the vault's long-term private key
|
||||
unlockIdentity, err := unlocker.GetIdentity()
|
||||
if err != nil {
|
||||
Debug("Failed to get unlocker identity", "error", err, "secret_name", s.Name, "unlocker_type", unlocker.GetType())
|
||||
|
||||
return nil, fmt.Errorf("failed to get unlocker identity: %w", err)
|
||||
}
|
||||
|
||||
// Read the encrypted long-term private key from the unlocker directory
|
||||
encryptedLtPrivKeyPath := filepath.Join(unlocker.GetDirectory(), "longterm.age")
|
||||
Debug("Reading encrypted long-term private key", "path", encryptedLtPrivKeyPath)
|
||||
|
||||
encryptedLtPrivKey, err := afero.ReadFile(s.vault.GetFilesystem(), encryptedLtPrivKeyPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read encrypted long-term private key", "error", err, "path", encryptedLtPrivKeyPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt the encrypted long-term private key using the unlocker
|
||||
Debug("Decrypting long-term private key using unlocker", "secret_name", s.Name)
|
||||
ltPrivKeyData, err := DecryptWithIdentity(encryptedLtPrivKey, unlockIdentity)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt long-term private key", "error", err, "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
|
||||
}
|
||||
|
||||
// Parse the long-term private key
|
||||
Debug("Parsing long-term private key", "secret_name", s.Name)
|
||||
ltIdentity, err := age.ParseX25519Identity(string(ltPrivKeyData))
|
||||
if err != nil {
|
||||
Debug("Failed to parse long-term private key", "error", err, "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to parse long-term private key: %w", err)
|
||||
}
|
||||
|
||||
DebugWith("Successfully obtained vault's long-term key",
|
||||
slog.String("secret_name", s.Name),
|
||||
slog.String("public_key", ltIdentity.Recipient().String()),
|
||||
)
|
||||
|
||||
// Use the long-term key to decrypt the version
|
||||
return version.GetValue(ltIdentity)
|
||||
}
|
||||
|
||||
// LoadMetadata is deprecated - metadata is now per-version and encrypted
|
||||
func (s *Secret) LoadMetadata() error {
|
||||
Debug("LoadMetadata called but is deprecated in versioned model", "secret_name", s.Name)
|
||||
// For backward compatibility, we'll populate with basic info
|
||||
now := time.Now()
|
||||
s.Metadata = Metadata{
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMetadata returns the secret metadata (deprecated)
|
||||
func (s *Secret) GetMetadata() Metadata {
|
||||
Debug("GetMetadata called but is deprecated in versioned model", "secret_name", s.Name)
|
||||
|
||||
return s.Metadata
|
||||
}
|
||||
|
||||
// GetEncryptedData is deprecated - data is now stored in versions
|
||||
func (s *Secret) GetEncryptedData() ([]byte, error) {
|
||||
Debug("GetEncryptedData called but is deprecated in versioned model", "secret_name", s.Name)
|
||||
|
||||
return nil, fmt.Errorf("GetEncryptedData is deprecated - use version-specific methods")
|
||||
}
|
||||
|
||||
// Exists checks if the secret exists on disk
|
||||
func (s *Secret) Exists() (bool, error) {
|
||||
DebugWith("Checking if secret exists",
|
||||
slog.String("secret_name", s.Name),
|
||||
slog.String("vault_name", s.vault.GetName()),
|
||||
)
|
||||
|
||||
// Check if the secret directory exists and has a current symlink
|
||||
exists, err := afero.DirExists(s.vault.GetFilesystem(), s.Directory)
|
||||
if err != nil {
|
||||
Debug("Failed to check secret directory existence", "error", err, "secret_dir", s.Directory)
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
Debug("Secret directory does not exist", "secret_dir", s.Directory)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check if current symlink exists
|
||||
_, err = GetCurrentVersion(s.vault.GetFilesystem(), s.Directory)
|
||||
if err != nil {
|
||||
Debug("No current version found", "error", err, "secret_name", s.Name)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
DebugWith("Secret existence check result",
|
||||
slog.String("secret_name", s.Name),
|
||||
slog.Bool("exists", true),
|
||||
)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetCurrentVault gets the current vault from the file system
|
||||
// This function is a wrapper around the actual implementation in the vault package
|
||||
// and exists to break the import cycle.
|
||||
func GetCurrentVault(fs afero.Fs, stateDir string) (VaultInterface, error) {
|
||||
// This is a forward declaration. The actual implementation is provided
|
||||
// by the vault package when it calls RegisterGetCurrentVaultFunc.
|
||||
if getCurrentVaultFunc == nil {
|
||||
return nil, fmt.Errorf("GetCurrentVault function not registered")
|
||||
}
|
||||
|
||||
return getCurrentVaultFunc(fs, stateDir)
|
||||
}
|
||||
|
||||
// getCurrentVaultFunc is a function variable that will be set by the vault package
|
||||
// to implement the actual GetCurrentVault functionality
|
||||
//
|
||||
//nolint:gochecknoglobals // Required to break import cycle
|
||||
var getCurrentVaultFunc func(fs afero.Fs, stateDir string) (VaultInterface, error)
|
||||
|
||||
// RegisterGetCurrentVaultFunc allows the vault package to register its implementation
|
||||
// of GetCurrentVault to break the import cycle
|
||||
func RegisterGetCurrentVaultFunc(fn func(fs afero.Fs, stateDir string) (VaultInterface, error)) {
|
||||
getCurrentVaultFunc = fn
|
||||
}
|
||||
329
internal/secret/secret_test.go
Normal file
329
internal/secret/secret_test.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockVault is a test implementation of the VaultInterface
|
||||
type MockVault struct {
|
||||
name string
|
||||
fs afero.Fs
|
||||
directory string
|
||||
derivationIndex uint32
|
||||
}
|
||||
|
||||
func (m *MockVault) GetDirectory() (string, error) {
|
||||
return m.directory, nil
|
||||
}
|
||||
|
||||
func (m *MockVault) AddSecret(name string, value *memguard.LockedBuffer, _ bool) error {
|
||||
// Create secret directory with proper storage name conversion
|
||||
storageName := strings.ReplaceAll(name, "/", "%")
|
||||
secretDir := filepath.Join(m.directory, "secrets.d", storageName)
|
||||
if err := m.fs.MkdirAll(secretDir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create version directory with proper path
|
||||
versionName := "20240101.001" // Use a fixed version name for testing
|
||||
versionDir := filepath.Join(secretDir, "versions", versionName)
|
||||
if err := m.fs.MkdirAll(versionDir, 0o700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read the vault's long-term public key
|
||||
ltPubKeyPath := filepath.Join(m.directory, "pub.age")
|
||||
|
||||
// Derive long-term key using the vault's derivation index
|
||||
mnemonic := os.Getenv(EnvMnemonic)
|
||||
if mnemonic == "" {
|
||||
return fmt.Errorf("SB_SECRET_MNEMONIC not set")
|
||||
}
|
||||
|
||||
ltIdentity, err := agehd.DeriveIdentity(mnemonic, m.derivationIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write long-term public key if it doesn't exist
|
||||
if _, err := m.fs.Stat(ltPubKeyPath); os.IsNotExist(err) {
|
||||
pubKey := ltIdentity.Recipient().String()
|
||||
if err := afero.WriteFile(m.fs, ltPubKeyPath, []byte(pubKey), 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Generate version-specific keypair
|
||||
versionIdentity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write version public key
|
||||
pubKeyPath := filepath.Join(versionDir, "pub.age")
|
||||
if err := afero.WriteFile(m.fs, pubKeyPath, []byte(versionIdentity.Recipient().String()), 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encrypt value to version's public key (value is already a LockedBuffer)
|
||||
encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write encrypted value
|
||||
valuePath := filepath.Join(versionDir, "value.age")
|
||||
if err := afero.WriteFile(m.fs, valuePath, encryptedValue, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encrypt version private key to long-term public key
|
||||
versionPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String()))
|
||||
defer versionPrivKeyBuffer.Destroy()
|
||||
encryptedPrivKey, err := EncryptToRecipient(versionPrivKeyBuffer, ltIdentity.Recipient())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write encrypted version private key
|
||||
privKeyPath := filepath.Join(versionDir, "priv.age")
|
||||
if err := afero.WriteFile(m.fs, privKeyPath, encryptedPrivKey, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create current symlink pointing to the version
|
||||
currentLink := filepath.Join(secretDir, "current")
|
||||
// For MemMapFs, write a file with the target path
|
||||
if err := afero.WriteFile(m.fs, currentLink, []byte("versions/"+versionName), 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockVault) GetName() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockVault) GetFilesystem() afero.Fs {
|
||||
return m.fs
|
||||
}
|
||||
|
||||
func (m *MockVault) GetCurrentUnlocker() (Unlocker, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockVault) CreatePassphraseUnlocker(_ *memguard.LockedBuffer) (*PassphraseUnlocker, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestPerSecretKeyFunctionality(t *testing.T) {
|
||||
// Create an in-memory filesystem for testing
|
||||
fs := afero.NewMemMapFs()
|
||||
|
||||
// Set test mnemonic for direct encryption/decryption
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
t.Setenv(EnvMnemonic, testMnemonic)
|
||||
|
||||
// Set up a test vault structure
|
||||
baseDir := "/test-config/berlin.sneak.pkg.secret"
|
||||
vaultDir := filepath.Join(baseDir, "vaults.d", "test-vault")
|
||||
|
||||
// Create vault directory structure
|
||||
err := fs.MkdirAll(filepath.Join(vaultDir, "secrets.d"), DirPerms)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Generate a long-term keypair for the vault using the test mnemonic
|
||||
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate long-term identity: %v", err)
|
||||
}
|
||||
|
||||
// Write long-term public key
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(
|
||||
fs,
|
||||
ltPubKeyPath,
|
||||
[]byte(ltIdentity.Recipient().String()),
|
||||
0o600,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write long-term public key: %v", err)
|
||||
}
|
||||
|
||||
// Set current vault
|
||||
currentVaultPath := filepath.Join(baseDir, "currentvault")
|
||||
err = afero.WriteFile(fs, currentVaultPath, []byte(vaultDir), FilePerms)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set current vault: %v", err)
|
||||
}
|
||||
|
||||
// Create vault instance using the mock vault
|
||||
vault := &MockVault{
|
||||
name: "test-vault",
|
||||
fs: fs,
|
||||
directory: vaultDir,
|
||||
derivationIndex: 0,
|
||||
}
|
||||
|
||||
// Test data
|
||||
secretName := "test-secret"
|
||||
secretValue := []byte("this is a test secret value")
|
||||
|
||||
// Create a secure buffer for the test value
|
||||
valueBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||
defer valueBuffer.Destroy()
|
||||
|
||||
// Test AddSecret
|
||||
t.Run("AddSecret", func(t *testing.T) {
|
||||
err := vault.AddSecret(secretName, valueBuffer, false)
|
||||
if err != nil {
|
||||
t.Fatalf("AddSecret failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify that all expected files were created
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", secretName)
|
||||
|
||||
// Check versions directory exists
|
||||
versionsDir := filepath.Join(secretDir, "versions")
|
||||
versionsDirExists, err := afero.DirExists(fs, versionsDir)
|
||||
if err != nil || !versionsDirExists {
|
||||
t.Fatalf("versions directory was not created")
|
||||
}
|
||||
|
||||
// Check current symlink exists
|
||||
currentVersion, err := GetCurrentVersion(fs, secretDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current version: %v", err)
|
||||
}
|
||||
|
||||
// Check value.age exists in the version directory
|
||||
versionDir := filepath.Join(versionsDir, currentVersion)
|
||||
valueExists, err := afero.Exists(
|
||||
fs,
|
||||
filepath.Join(versionDir, "value.age"),
|
||||
)
|
||||
if err != nil || !valueExists {
|
||||
t.Fatalf("value.age file was not created in version directory")
|
||||
}
|
||||
|
||||
t.Logf("All expected files created successfully with versioning")
|
||||
})
|
||||
|
||||
// Create a Secret object to test with
|
||||
secret := NewSecret(vault, secretName)
|
||||
|
||||
// Test GetValue (this will need to be modified since we're using a mock vault)
|
||||
t.Run("GetSecret", func(t *testing.T) {
|
||||
// This test is simplified since we're not implementing the full encryption/decryption
|
||||
// in the mock. We just verify the Secret object is created correctly.
|
||||
if secret.Name != secretName {
|
||||
t.Fatalf("Secret name doesn't match. Expected: %s, Got: %s", secretName, secret.Name)
|
||||
}
|
||||
|
||||
if secret.vault != vault {
|
||||
t.Fatalf("Secret vault reference doesn't match expected vault")
|
||||
}
|
||||
|
||||
t.Logf("Successfully created Secret object with correct properties")
|
||||
})
|
||||
|
||||
// Test Exists
|
||||
t.Run("SecretExists", func(t *testing.T) {
|
||||
exists, err := secret.Exists()
|
||||
if err != nil {
|
||||
t.Fatalf("Error checking if secret exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("Secret should exist but Exists() returned false")
|
||||
}
|
||||
t.Logf("Secret.Exists() works correctly")
|
||||
})
|
||||
}
|
||||
|
||||
// For testing purposes only
|
||||
func isValidSecretName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
// Valid characters for secret names: lowercase letters, numbers, dash, dot, underscore, slash
|
||||
for _, char := range name {
|
||||
if (char < 'a' || char > 'z') && // lowercase letters
|
||||
(char < '0' || char > '9') && // numbers
|
||||
char != '-' && // dash
|
||||
char != '.' && // dot
|
||||
char != '_' && // underscore
|
||||
char != '/' { // slash
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestSecretNameValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
valid bool
|
||||
}{
|
||||
{"valid-name", true},
|
||||
{"valid.name", true},
|
||||
{"valid_name", true},
|
||||
{"valid/path/name", true},
|
||||
{"123valid", true},
|
||||
{"", false},
|
||||
{"Invalid-Name", false}, // uppercase not allowed
|
||||
{"invalid name", false}, // space not allowed
|
||||
{"invalid@name", false}, // @ not allowed
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := isValidSecretName(test.name)
|
||||
if result != test.valid {
|
||||
t.Errorf(
|
||||
"isValidSecretName(%q) = %v, want %v",
|
||||
test.name,
|
||||
result,
|
||||
test.valid,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecretGetValueWithEnvMnemonicUsesVaultDerivationIndex(t *testing.T) {
|
||||
// This test demonstrates the bug where GetValue uses hardcoded index 0
|
||||
// instead of the vault's actual derivation index when using environment mnemonic
|
||||
|
||||
// Set up test mnemonic
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
t.Setenv(EnvMnemonic, testMnemonic)
|
||||
|
||||
// Create temporary directory for vaults
|
||||
fs := afero.NewOsFs()
|
||||
tempDir, err := afero.TempDir(fs, "", "secret-test-")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = fs.RemoveAll(tempDir)
|
||||
}()
|
||||
|
||||
stateDir := filepath.Join(tempDir, ".secret")
|
||||
require.NoError(t, fs.MkdirAll(stateDir, 0o700))
|
||||
|
||||
// This test is now in the integration test file where it can use real vaults
|
||||
// The bug is demonstrated there - see test31EnvMnemonicUsesVaultDerivationIndex
|
||||
t.Log("This test demonstrates the bug in the integration test file")
|
||||
}
|
||||
15
internal/secret/unlocker.go
Normal file
15
internal/secret/unlocker.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"filippo.io/age"
|
||||
)
|
||||
|
||||
// Unlocker interface defines the methods all unlocker types must implement
|
||||
type Unlocker interface {
|
||||
GetIdentity() (*age.X25519Identity, error)
|
||||
GetType() string
|
||||
GetMetadata() UnlockerMetadata
|
||||
GetDirectory() string
|
||||
GetID() string // Generate ID based on unlocker type and data
|
||||
Remove() error // Remove the unlocker and any associated resources
|
||||
}
|
||||
297
internal/secret/validation_test.go
Normal file
297
internal/secret/validation_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateGPGKeyID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyID string
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid cases
|
||||
{
|
||||
name: "valid email address",
|
||||
keyID: "test@example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid email with dots and hyphens",
|
||||
keyID: "test.user-name@example-domain.co.uk",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid email with plus",
|
||||
keyID: "test+tag@example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid short key ID (8 hex chars)",
|
||||
keyID: "ABCDEF12",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid long key ID (16 hex chars)",
|
||||
keyID: "ABCDEF1234567890",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid fingerprint (40 hex chars)",
|
||||
keyID: "ABCDEF1234567890ABCDEF1234567890ABCDEF12",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid lowercase hex fingerprint",
|
||||
keyID: "abcdef1234567890abcdef1234567890abcdef12",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid mixed case hex",
|
||||
keyID: "AbCdEf1234567890",
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// Invalid cases
|
||||
{
|
||||
name: "empty key ID",
|
||||
keyID: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with spaces",
|
||||
keyID: "test user@example.com",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with semicolon (command injection)",
|
||||
keyID: "test@example.com; rm -rf /",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with pipe (command injection)",
|
||||
keyID: "test@example.com | cat /etc/passwd",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with backticks (command injection)",
|
||||
keyID: "test@example.com`whoami`",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with dollar sign (command injection)",
|
||||
keyID: "test@example.com$(whoami)",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with quotes",
|
||||
keyID: "test\"@example.com",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with single quotes",
|
||||
keyID: "test'@example.com",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with backslash",
|
||||
keyID: "test\\@example.com",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with newline",
|
||||
keyID: "test@example.com\nrm -rf /",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with carriage return",
|
||||
keyID: "test@example.com\rrm -rf /",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "hex with invalid length (7 chars)",
|
||||
keyID: "ABCDEF1",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "hex with invalid length (9 chars)",
|
||||
keyID: "ABCDEF123",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "hex with non-hex characters",
|
||||
keyID: "ABCDEFGH",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "mixed format (email with hex)",
|
||||
keyID: "test@ABCDEF12",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with ampersand",
|
||||
keyID: "test@example.com & echo test",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with redirect",
|
||||
keyID: "test@example.com > /tmp/test",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "key ID with null byte",
|
||||
keyID: "test@example.com\x00",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateGPGKeyID(tt.keyID)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateGPGKeyID() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateKeychainItemName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
itemName string
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid cases
|
||||
{
|
||||
name: "valid simple name",
|
||||
itemName: "my-secret-key",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid name with dots",
|
||||
itemName: "com.example.app.key",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid name with underscores",
|
||||
itemName: "my_secret_key_123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid alphanumeric",
|
||||
itemName: "Secret123Key",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid with hyphen at start",
|
||||
itemName: "-my-key",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid with dot at start",
|
||||
itemName: ".hidden-key",
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// Invalid cases
|
||||
{
|
||||
name: "empty item name",
|
||||
itemName: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with spaces",
|
||||
itemName: "my secret key",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with semicolon",
|
||||
itemName: "key;rm -rf /",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with pipe",
|
||||
itemName: "key|cat /etc/passwd",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with backticks",
|
||||
itemName: "key`whoami`",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with dollar sign",
|
||||
itemName: "key$(whoami)",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with quotes",
|
||||
itemName: "key\"name",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with single quotes",
|
||||
itemName: "key'name",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with backslash",
|
||||
itemName: "key\\name",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with newline",
|
||||
itemName: "key\nname",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with carriage return",
|
||||
itemName: "key\rname",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with ampersand",
|
||||
itemName: "key&echo test",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with redirect",
|
||||
itemName: "key>/tmp/test",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with null byte",
|
||||
itemName: "key\x00name",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with parentheses",
|
||||
itemName: "key(test)",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with brackets",
|
||||
itemName: "key[test]",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with asterisk",
|
||||
itemName: "key*",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "item name with question mark",
|
||||
itemName: "key?",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateKeychainItemName(tt.itemName)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateKeychainItemName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
487
internal/secret/version.go
Normal file
487
internal/secret/version.go
Normal file
@@ -0,0 +1,487 @@
|
||||
package secret
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
const (
|
||||
versionNameParts = 2
|
||||
maxVersionsPerDay = 999
|
||||
)
|
||||
|
||||
// VersionMetadata contains information about a secret version
|
||||
type VersionMetadata struct {
|
||||
ID string `json:"id"` // ULID
|
||||
CreatedAt *time.Time `json:"createdAt,omitempty"` // When version was created
|
||||
NotBefore *time.Time `json:"notBefore,omitempty"` // When this version becomes active
|
||||
NotAfter *time.Time `json:"notAfter,omitempty"` // When this version expires (nil = current)
|
||||
}
|
||||
|
||||
// Version represents a version of a secret
|
||||
type Version struct {
|
||||
SecretName string
|
||||
Version string
|
||||
Directory string
|
||||
Metadata VersionMetadata
|
||||
vault VaultInterface
|
||||
}
|
||||
|
||||
// NewVersion creates a new Version instance
|
||||
func NewVersion(vault VaultInterface, secretName string, version string) *Version {
|
||||
DebugWith("Creating new secret version instance",
|
||||
slog.String("secret_name", secretName),
|
||||
slog.String("version", version),
|
||||
slog.String("vault_name", vault.GetName()),
|
||||
)
|
||||
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
storageName := strings.ReplaceAll(secretName, "/", "%")
|
||||
versionDir := filepath.Join(vaultDir, "secrets.d", storageName, "versions", version)
|
||||
|
||||
DebugWith("Secret version storage details",
|
||||
slog.String("secret_name", secretName),
|
||||
slog.String("version", version),
|
||||
slog.String("version_dir", versionDir),
|
||||
)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
return &Version{
|
||||
SecretName: secretName,
|
||||
Version: version,
|
||||
Directory: versionDir,
|
||||
vault: vault,
|
||||
Metadata: VersionMetadata{
|
||||
ID: ulid.Make().String(),
|
||||
CreatedAt: &now,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateVersionName generates a new version name in YYYYMMDD.NNN format
|
||||
func GenerateVersionName(fs afero.Fs, secretDir string) (string, error) {
|
||||
today := time.Now().Format("20060102")
|
||||
versionsDir := filepath.Join(secretDir, "versions")
|
||||
|
||||
// Ensure versions directory exists
|
||||
if err := fs.MkdirAll(versionsDir, DirPerms); err != nil {
|
||||
return "", fmt.Errorf("failed to create versions directory: %w", err)
|
||||
}
|
||||
|
||||
// Find the highest serial number for today
|
||||
entries, err := afero.ReadDir(fs, versionsDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read versions directory: %w", err)
|
||||
}
|
||||
|
||||
maxSerial := 0
|
||||
prefix := today + "."
|
||||
|
||||
for _, entry := range entries {
|
||||
// Skip non-directories and those without correct prefix
|
||||
if !entry.IsDir() || !strings.HasPrefix(entry.Name(), prefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract serial number
|
||||
parts := strings.Split(entry.Name(), ".")
|
||||
if len(parts) != versionNameParts {
|
||||
continue
|
||||
}
|
||||
|
||||
var serial int
|
||||
if _, err := fmt.Sscanf(parts[1], "%03d", &serial); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if serial > maxSerial {
|
||||
maxSerial = serial
|
||||
}
|
||||
}
|
||||
|
||||
// Generate new version name
|
||||
newSerial := maxSerial + 1
|
||||
if newSerial > maxVersionsPerDay {
|
||||
return "", fmt.Errorf("exceeded maximum versions per day (999)")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s.%03d", today, newSerial), nil
|
||||
}
|
||||
|
||||
// Save saves the version metadata and value
|
||||
func (sv *Version) Save(value *memguard.LockedBuffer) error {
|
||||
if value == nil {
|
||||
return fmt.Errorf("value buffer is nil")
|
||||
}
|
||||
|
||||
DebugWith("Saving secret version",
|
||||
slog.String("secret_name", sv.SecretName),
|
||||
slog.String("version", sv.Version),
|
||||
slog.Int("value_length", value.Size()),
|
||||
)
|
||||
|
||||
fs := sv.vault.GetFilesystem()
|
||||
|
||||
// Create version directory
|
||||
if err := fs.MkdirAll(sv.Directory, DirPerms); err != nil {
|
||||
Debug("Failed to create version directory", "error", err, "dir", sv.Directory)
|
||||
|
||||
return fmt.Errorf("failed to create version directory: %w", err)
|
||||
}
|
||||
|
||||
// Step 1: Generate a new keypair for this version
|
||||
Debug("Generating version-specific keypair", "version", sv.Version)
|
||||
versionIdentity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
Debug("Failed to generate version keypair", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to generate version keypair: %w", err)
|
||||
}
|
||||
|
||||
versionPublicKey := versionIdentity.Recipient().String()
|
||||
// Store private key in memguard buffer immediately
|
||||
versionPrivateKeyBuffer := memguard.NewBufferFromBytes([]byte(versionIdentity.String()))
|
||||
defer versionPrivateKeyBuffer.Destroy()
|
||||
|
||||
DebugWith("Generated version keypair",
|
||||
slog.String("version", sv.Version),
|
||||
slog.String("public_key", versionPublicKey),
|
||||
)
|
||||
|
||||
// Step 2: Store the version's public key
|
||||
pubKeyPath := filepath.Join(sv.Directory, "pub.age")
|
||||
Debug("Writing version public key", "path", pubKeyPath)
|
||||
if err := afero.WriteFile(fs, pubKeyPath, []byte(versionPublicKey), FilePerms); err != nil {
|
||||
Debug("Failed to write version public key", "error", err, "path", pubKeyPath)
|
||||
|
||||
return fmt.Errorf("failed to write version public key: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Encrypt the value to the version's public key
|
||||
Debug("Encrypting value to version's public key", "version", sv.Version)
|
||||
encryptedValue, err := EncryptToRecipient(value, versionIdentity.Recipient())
|
||||
if err != nil {
|
||||
Debug("Failed to encrypt version value", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to encrypt version value: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Store the encrypted value
|
||||
valuePath := filepath.Join(sv.Directory, "value.age")
|
||||
Debug("Writing encrypted version value", "path", valuePath)
|
||||
if err := afero.WriteFile(fs, valuePath, encryptedValue, FilePerms); err != nil {
|
||||
Debug("Failed to write encrypted version value", "error", err, "path", valuePath)
|
||||
|
||||
return fmt.Errorf("failed to write encrypted version value: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Get vault's long-term public key for encrypting the version's private key
|
||||
vaultDir, _ := sv.vault.GetDirectory()
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
Debug("Reading long-term public key", "path", ltPubKeyPath)
|
||||
|
||||
ltPubKeyData, err := afero.ReadFile(fs, ltPubKeyPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read long-term public key", "error", err, "path", ltPubKeyPath)
|
||||
|
||||
return fmt.Errorf("failed to read long-term public key: %w", err)
|
||||
}
|
||||
|
||||
Debug("Parsing long-term public key")
|
||||
ltRecipient, err := age.ParseX25519Recipient(string(ltPubKeyData))
|
||||
if err != nil {
|
||||
Debug("Failed to parse long-term public key", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to parse long-term public key: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Encrypt the version's private key to the long-term public key
|
||||
Debug("Encrypting version private key to long-term public key", "version", sv.Version)
|
||||
encryptedPrivKey, err := EncryptToRecipient(versionPrivateKeyBuffer, ltRecipient)
|
||||
if err != nil {
|
||||
Debug("Failed to encrypt version private key", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to encrypt version private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: Store the encrypted private key
|
||||
privKeyPath := filepath.Join(sv.Directory, "priv.age")
|
||||
Debug("Writing encrypted version private key", "path", privKeyPath)
|
||||
if err := afero.WriteFile(fs, privKeyPath, encryptedPrivKey, FilePerms); err != nil {
|
||||
Debug("Failed to write encrypted version private key", "error", err, "path", privKeyPath)
|
||||
|
||||
return fmt.Errorf("failed to write encrypted version private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 8: Encrypt and store metadata
|
||||
Debug("Encrypting version metadata", "version", sv.Version)
|
||||
metadataBytes, err := json.MarshalIndent(sv.Metadata, "", " ")
|
||||
if err != nil {
|
||||
Debug("Failed to marshal version metadata", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to marshal version metadata: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt metadata to the version's public key
|
||||
metadataBuffer := memguard.NewBufferFromBytes(metadataBytes)
|
||||
defer metadataBuffer.Destroy()
|
||||
|
||||
encryptedMetadata, err := EncryptToRecipient(metadataBuffer, versionIdentity.Recipient())
|
||||
if err != nil {
|
||||
Debug("Failed to encrypt version metadata", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to encrypt version metadata: %w", err)
|
||||
}
|
||||
|
||||
metadataPath := filepath.Join(sv.Directory, "metadata.age")
|
||||
Debug("Writing encrypted version metadata", "path", metadataPath)
|
||||
if err := afero.WriteFile(fs, metadataPath, encryptedMetadata, FilePerms); err != nil {
|
||||
Debug("Failed to write encrypted version metadata", "error", err, "path", metadataPath)
|
||||
|
||||
return fmt.Errorf("failed to write encrypted version metadata: %w", err)
|
||||
}
|
||||
|
||||
Debug("Successfully saved secret version", "version", sv.Version, "secret_name", sv.SecretName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadMetadata loads and decrypts the version metadata
|
||||
func (sv *Version) LoadMetadata(ltIdentity *age.X25519Identity) error {
|
||||
DebugWith("Loading version metadata",
|
||||
slog.String("secret_name", sv.SecretName),
|
||||
slog.String("version", sv.Version),
|
||||
)
|
||||
|
||||
fs := sv.vault.GetFilesystem()
|
||||
|
||||
// Step 1: Read encrypted version private key
|
||||
encryptedPrivKeyPath := filepath.Join(sv.Directory, "priv.age")
|
||||
encryptedPrivKey, err := afero.ReadFile(fs, encryptedPrivKeyPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read encrypted version private key", "error", err, "path", encryptedPrivKeyPath)
|
||||
|
||||
return fmt.Errorf("failed to read encrypted version private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Decrypt version private key using long-term key
|
||||
versionPrivKeyData, err := DecryptWithIdentity(encryptedPrivKey, ltIdentity)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt version private key", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to decrypt version private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Parse version private key
|
||||
versionIdentity, err := age.ParseX25519Identity(string(versionPrivKeyData))
|
||||
if err != nil {
|
||||
Debug("Failed to parse version private key", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to parse version private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Read encrypted metadata
|
||||
encryptedMetadataPath := filepath.Join(sv.Directory, "metadata.age")
|
||||
encryptedMetadata, err := afero.ReadFile(fs, encryptedMetadataPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read encrypted version metadata", "error", err, "path", encryptedMetadataPath)
|
||||
|
||||
return fmt.Errorf("failed to read encrypted version metadata: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Decrypt metadata using version key
|
||||
metadataBytes, err := DecryptWithIdentity(encryptedMetadata, versionIdentity)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt version metadata", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to decrypt version metadata: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Unmarshal metadata
|
||||
var metadata VersionMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
Debug("Failed to unmarshal version metadata", "error", err, "version", sv.Version)
|
||||
|
||||
return fmt.Errorf("failed to unmarshal version metadata: %w", err)
|
||||
}
|
||||
|
||||
sv.Metadata = metadata
|
||||
Debug("Successfully loaded version metadata", "version", sv.Version)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValue retrieves and decrypts the version value
|
||||
func (sv *Version) GetValue(ltIdentity *age.X25519Identity) ([]byte, error) {
|
||||
DebugWith("Getting version value",
|
||||
slog.String("secret_name", sv.SecretName),
|
||||
slog.String("version", sv.Version),
|
||||
)
|
||||
|
||||
// Debug: Log the directory and long-term key info
|
||||
Debug("SecretVersion GetValue debug info",
|
||||
"secret_name", sv.SecretName,
|
||||
"version", sv.Version,
|
||||
"directory", sv.Directory,
|
||||
"lt_identity_public_key", ltIdentity.Recipient().String())
|
||||
|
||||
fs := sv.vault.GetFilesystem()
|
||||
|
||||
// Step 1: Read encrypted version private key
|
||||
encryptedPrivKeyPath := filepath.Join(sv.Directory, "priv.age")
|
||||
Debug("Reading encrypted version private key", "path", encryptedPrivKeyPath)
|
||||
encryptedPrivKey, err := afero.ReadFile(fs, encryptedPrivKeyPath)
|
||||
if err != nil {
|
||||
Debug("Failed to read encrypted version private key", "error", err, "path", encryptedPrivKeyPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read encrypted version private key: %w", err)
|
||||
}
|
||||
Debug("Successfully read encrypted version private key", "path", encryptedPrivKeyPath, "size", len(encryptedPrivKey))
|
||||
|
||||
// Step 2: Decrypt version private key using long-term key
|
||||
Debug("Decrypting version private key with long-term identity", "version", sv.Version)
|
||||
versionPrivKeyData, err := DecryptWithIdentity(encryptedPrivKey, ltIdentity)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt version private key", "error", err, "version", sv.Version)
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt version private key: %w", err)
|
||||
}
|
||||
Debug("Successfully decrypted version private key", "version", sv.Version, "size", len(versionPrivKeyData))
|
||||
|
||||
// Step 3: Parse version private key
|
||||
versionIdentity, err := age.ParseX25519Identity(string(versionPrivKeyData))
|
||||
if err != nil {
|
||||
Debug("Failed to parse version private key", "error", err, "version", sv.Version)
|
||||
|
||||
return nil, fmt.Errorf("failed to parse version private key: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Read encrypted value
|
||||
encryptedValuePath := filepath.Join(sv.Directory, "value.age")
|
||||
Debug("Reading encrypted value", "path", encryptedValuePath)
|
||||
encryptedValue, err := afero.ReadFile(fs, encryptedValuePath)
|
||||
if err != nil {
|
||||
Debug("Failed to read encrypted version value", "error", err, "path", encryptedValuePath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read encrypted version value: %w", err)
|
||||
}
|
||||
Debug("Successfully read encrypted value", "path", encryptedValuePath, "size", len(encryptedValue))
|
||||
|
||||
// Step 5: Decrypt value using version key
|
||||
Debug("Decrypting value with version identity", "version", sv.Version)
|
||||
value, err := DecryptWithIdentity(encryptedValue, versionIdentity)
|
||||
if err != nil {
|
||||
Debug("Failed to decrypt version value", "error", err, "version", sv.Version)
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt version value: %w", err)
|
||||
}
|
||||
|
||||
Debug("Successfully retrieved version value",
|
||||
"version", sv.Version,
|
||||
"value_length", len(value),
|
||||
"is_empty", len(value) == 0)
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ListVersions lists all versions of a secret
|
||||
func ListVersions(fs afero.Fs, secretDir string) ([]string, error) {
|
||||
versionsDir := filepath.Join(secretDir, "versions")
|
||||
|
||||
// Check if versions directory exists
|
||||
exists, err := afero.DirExists(fs, versionsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check versions directory: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// List all version directories
|
||||
entries, err := afero.ReadDir(fs, versionsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read versions directory: %w", err)
|
||||
}
|
||||
|
||||
var versions []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
versions = append(versions, entry.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// Sort versions in reverse chronological order
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(versions)))
|
||||
|
||||
return versions, nil
|
||||
}
|
||||
|
||||
// GetCurrentVersion returns the version that the "current" symlink points to
|
||||
func GetCurrentVersion(fs afero.Fs, secretDir string) (string, error) {
|
||||
currentPath := filepath.Join(secretDir, "current")
|
||||
|
||||
// Try to read as a real symlink first
|
||||
if _, ok := fs.(*afero.OsFs); ok {
|
||||
target, err := os.Readlink(currentPath)
|
||||
if err == nil {
|
||||
// Extract version from path (e.g., "versions/20231215.001" -> "20231215.001")
|
||||
parts := strings.Split(target, "/")
|
||||
if len(parts) >= 2 && parts[0] == "versions" {
|
||||
return parts[1], nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid current version symlink format: %s", target)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to reading as a file (for MemMapFs testing)
|
||||
fileData, err := afero.ReadFile(fs, currentPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read current version symlink: %w", err)
|
||||
}
|
||||
|
||||
target := strings.TrimSpace(string(fileData))
|
||||
|
||||
// Extract version from path
|
||||
parts := strings.Split(target, "/")
|
||||
if len(parts) >= 2 && parts[0] == "versions" {
|
||||
return parts[1], nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid current version symlink format: %s", target)
|
||||
}
|
||||
|
||||
// SetCurrentVersion updates the "current" symlink to point to a specific version
|
||||
func SetCurrentVersion(fs afero.Fs, secretDir string, version string) error {
|
||||
currentPath := filepath.Join(secretDir, "current")
|
||||
targetPath := filepath.Join("versions", version)
|
||||
|
||||
// Remove existing symlink if it exists
|
||||
_ = fs.Remove(currentPath)
|
||||
|
||||
// Try to create a real symlink first (works on Unix systems)
|
||||
if _, ok := fs.(*afero.OsFs); ok {
|
||||
if err := os.Symlink(targetPath, currentPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to creating a file with the target path (for MemMapFs testing)
|
||||
if err := afero.WriteFile(fs, currentPath, []byte(targetPath), FilePerms); err != nil {
|
||||
return fmt.Errorf("failed to create current version symlink: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
371
internal/secret/version_test.go
Normal file
371
internal/secret/version_test.go
Normal file
@@ -0,0 +1,371 @@
|
||||
// Version Support Test Suite Documentation
|
||||
//
|
||||
// This file contains core unit tests for version functionality:
|
||||
//
|
||||
// - TestGenerateVersionName: Tests version name generation with date and serial format
|
||||
// - TestGenerateVersionNameMaxSerial: Tests the 999 versions per day limit
|
||||
// - TestNewVersion: Tests secret version object creation
|
||||
// - TestSecretVersionSave: Tests saving a version with encryption
|
||||
// - TestSecretVersionLoadMetadata: Tests loading and decrypting version metadata
|
||||
// - TestSecretVersionGetValue: Tests retrieving and decrypting version values
|
||||
// - TestListVersions: Tests listing versions in reverse chronological order
|
||||
// - TestGetCurrentVersion: Tests retrieving the current version via symlink
|
||||
// - TestSetCurrentVersion: Tests updating the current version symlink
|
||||
// - TestVersionMetadataTimestamps: Tests timestamp pointer consistency
|
||||
//
|
||||
// Key Test Scenarios:
|
||||
// - Version Creation: First version gets notBefore = epoch + 1 second
|
||||
// - Subsequent versions update previous version's notAfter timestamp
|
||||
// - New version's notBefore equals previous version's notAfter
|
||||
// - Version names follow YYYYMMDD.NNN format
|
||||
// - Maximum 999 versions per day enforced
|
||||
//
|
||||
// Version Retrieval:
|
||||
// - Get current version via symlink
|
||||
// - Get specific version by name
|
||||
// - Empty version parameter returns current
|
||||
// - Non-existent versions return appropriate errors
|
||||
//
|
||||
// Data Integrity:
|
||||
// - Each version has independent encryption keys
|
||||
// - Metadata encryption protects version history
|
||||
// - Long-term key required for all operations
|
||||
// - Concurrent reads handled safely
|
||||
|
||||
package secret
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockVault implements VaultInterface for testing
|
||||
type MockVersionVault struct {
|
||||
Name string
|
||||
fs afero.Fs
|
||||
stateDir string
|
||||
longTermKey *age.X25519Identity
|
||||
}
|
||||
|
||||
func (m *MockVersionVault) GetDirectory() (string, error) {
|
||||
return filepath.Join(m.stateDir, "vaults.d", m.Name), nil
|
||||
}
|
||||
|
||||
func (m *MockVersionVault) AddSecret(_ string, _ *memguard.LockedBuffer, _ bool) error {
|
||||
return fmt.Errorf("not implemented in mock")
|
||||
}
|
||||
|
||||
func (m *MockVersionVault) GetName() string {
|
||||
return m.Name
|
||||
}
|
||||
|
||||
func (m *MockVersionVault) GetFilesystem() afero.Fs {
|
||||
return m.fs
|
||||
}
|
||||
|
||||
func (m *MockVersionVault) GetCurrentUnlocker() (Unlocker, error) {
|
||||
return nil, fmt.Errorf("not implemented in mock")
|
||||
}
|
||||
|
||||
func (m *MockVersionVault) CreatePassphraseUnlocker(_ *memguard.LockedBuffer) (*PassphraseUnlocker, error) {
|
||||
return nil, fmt.Errorf("not implemented in mock")
|
||||
}
|
||||
|
||||
func TestGenerateVersionName(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
secretDir := "/test/secret"
|
||||
|
||||
// Test first version generation
|
||||
version1, err := GenerateVersionName(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Regexp(t, `^\d{8}\.001$`, version1)
|
||||
|
||||
// Create the version directory
|
||||
versionDir := filepath.Join(secretDir, "versions", version1)
|
||||
err = fs.MkdirAll(versionDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test second version generation on same day
|
||||
version2, err := GenerateVersionName(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Regexp(t, `^\d{8}\.002$`, version2)
|
||||
|
||||
// Verify they have the same date prefix
|
||||
assert.Equal(t, version1[:8], version2[:8])
|
||||
assert.NotEqual(t, version1, version2)
|
||||
}
|
||||
|
||||
func TestGenerateVersionNameMaxSerial(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
secretDir := "/test/secret"
|
||||
versionsDir := filepath.Join(secretDir, "versions")
|
||||
|
||||
// Create 999 versions
|
||||
today := time.Now().Format("20060102")
|
||||
for i := 1; i <= 999; i++ {
|
||||
versionName := fmt.Sprintf("%s.%03d", today, i)
|
||||
err := fs.MkdirAll(filepath.Join(versionsDir, versionName), 0o755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Try to create one more - should fail
|
||||
_, err := GenerateVersionName(fs, secretDir)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeded maximum versions per day")
|
||||
}
|
||||
|
||||
func TestNewVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
vault := &MockVersionVault{
|
||||
Name: "test",
|
||||
fs: fs,
|
||||
stateDir: "/test",
|
||||
}
|
||||
|
||||
sv := NewVersion(vault, "test/secret", "20231215.001")
|
||||
|
||||
assert.Equal(t, "test/secret", sv.SecretName)
|
||||
assert.Equal(t, "20231215.001", sv.Version)
|
||||
assert.Contains(t, sv.Directory, "test%secret/versions/20231215.001")
|
||||
assert.NotEmpty(t, sv.Metadata.ID)
|
||||
assert.NotNil(t, sv.Metadata.CreatedAt)
|
||||
}
|
||||
|
||||
func TestSecretVersionSave(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
vault := &MockVersionVault{
|
||||
Name: "test",
|
||||
fs: fs,
|
||||
stateDir: "/test",
|
||||
}
|
||||
|
||||
// Create vault directory structure and long-term key
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
err := fs.MkdirAll(vaultDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store long-term public key
|
||||
ltIdentity, err := age.GenerateX25519Identity()
|
||||
require.NoError(t, err)
|
||||
vault.longTermKey = ltIdentity
|
||||
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, ltPubKeyPath, []byte(ltIdentity.Recipient().String()), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and save a version
|
||||
sv := NewVersion(vault, "test/secret", "20231215.001")
|
||||
testValue := []byte("test-secret-value")
|
||||
|
||||
testBuffer := memguard.NewBufferFromBytes(testValue)
|
||||
defer testBuffer.Destroy()
|
||||
err = sv.Save(testBuffer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify files were created
|
||||
assert.True(t, fileExists(fs, filepath.Join(sv.Directory, "pub.age")))
|
||||
assert.True(t, fileExists(fs, filepath.Join(sv.Directory, "priv.age")))
|
||||
assert.True(t, fileExists(fs, filepath.Join(sv.Directory, "value.age")))
|
||||
assert.True(t, fileExists(fs, filepath.Join(sv.Directory, "metadata.age")))
|
||||
}
|
||||
|
||||
func TestSecretVersionLoadMetadata(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
vault := &MockVersionVault{
|
||||
Name: "test",
|
||||
fs: fs,
|
||||
stateDir: "/test",
|
||||
}
|
||||
|
||||
// Setup vault with long-term key
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
err := fs.MkdirAll(vaultDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
ltIdentity, err := age.GenerateX25519Identity()
|
||||
require.NoError(t, err)
|
||||
vault.longTermKey = ltIdentity
|
||||
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, ltPubKeyPath, []byte(ltIdentity.Recipient().String()), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and save a version with custom metadata
|
||||
sv := NewVersion(vault, "test/secret", "20231215.001")
|
||||
now := time.Now()
|
||||
epochPlusOne := time.Unix(1, 0)
|
||||
sv.Metadata.NotBefore = &epochPlusOne
|
||||
sv.Metadata.NotAfter = &now
|
||||
|
||||
testBuffer := memguard.NewBufferFromBytes([]byte("test-value"))
|
||||
defer testBuffer.Destroy()
|
||||
err = sv.Save(testBuffer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create new version object and load metadata
|
||||
sv2 := NewVersion(vault, "test/secret", "20231215.001")
|
||||
err = sv2.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify loaded metadata
|
||||
assert.Equal(t, sv.Metadata.ID, sv2.Metadata.ID)
|
||||
assert.NotNil(t, sv2.Metadata.NotBefore)
|
||||
assert.Equal(t, epochPlusOne.Unix(), sv2.Metadata.NotBefore.Unix())
|
||||
assert.NotNil(t, sv2.Metadata.NotAfter)
|
||||
}
|
||||
|
||||
func TestSecretVersionGetValue(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
vault := &MockVersionVault{
|
||||
Name: "test",
|
||||
fs: fs,
|
||||
stateDir: "/test",
|
||||
}
|
||||
|
||||
// Setup vault with long-term key
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
err := fs.MkdirAll(vaultDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
ltIdentity, err := age.GenerateX25519Identity()
|
||||
require.NoError(t, err)
|
||||
vault.longTermKey = ltIdentity
|
||||
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, ltPubKeyPath, []byte(ltIdentity.Recipient().String()), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and save a version
|
||||
sv := NewVersion(vault, "test/secret", "20231215.001")
|
||||
originalValue := []byte("test-secret-value-12345")
|
||||
expectedValue := make([]byte, len(originalValue))
|
||||
copy(expectedValue, originalValue)
|
||||
|
||||
originalBuffer := memguard.NewBufferFromBytes(originalValue)
|
||||
defer originalBuffer.Destroy()
|
||||
err = sv.Save(originalBuffer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve the value
|
||||
retrievedValue, err := sv.GetValue(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, expectedValue, retrievedValue)
|
||||
}
|
||||
|
||||
func TestListVersions(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
secretDir := "/test/secret"
|
||||
versionsDir := filepath.Join(secretDir, "versions")
|
||||
|
||||
// No versions directory
|
||||
versions, err := ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, versions)
|
||||
|
||||
// Create some versions
|
||||
testVersions := []string{"20231215.001", "20231215.002", "20231216.001", "20231214.001"}
|
||||
for _, v := range testVersions {
|
||||
err := fs.MkdirAll(filepath.Join(versionsDir, v), 0o755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create a file (not directory) that should be ignored
|
||||
err = afero.WriteFile(fs, filepath.Join(versionsDir, "ignore.txt"), []byte("test"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List versions
|
||||
versions, err = ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be sorted in reverse chronological order
|
||||
expected := []string{"20231216.001", "20231215.002", "20231215.001", "20231214.001"}
|
||||
assert.Equal(t, expected, versions)
|
||||
}
|
||||
|
||||
func TestGetCurrentVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
secretDir := "/test/secret"
|
||||
|
||||
// Simulate symlink with file content (works for both OsFs and MemMapFs)
|
||||
currentPath := filepath.Join(secretDir, "current")
|
||||
err := fs.MkdirAll(secretDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = afero.WriteFile(fs, currentPath, []byte("versions/20231216.001"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
version, err := GetCurrentVersion(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "20231216.001", version)
|
||||
}
|
||||
|
||||
func TestSetCurrentVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
secretDir := "/test/secret"
|
||||
|
||||
err := fs.MkdirAll(secretDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set current version
|
||||
err = SetCurrentVersion(fs, secretDir, "20231216.002")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it was set
|
||||
version, err := GetCurrentVersion(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "20231216.002", version)
|
||||
|
||||
// Update to different version
|
||||
err = SetCurrentVersion(fs, secretDir, "20231217.001")
|
||||
require.NoError(t, err)
|
||||
|
||||
version, err = GetCurrentVersion(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "20231217.001", version)
|
||||
}
|
||||
|
||||
func TestVersionMetadataTimestamps(t *testing.T) {
|
||||
// Test that all timestamp fields behave consistently as pointers
|
||||
vm := VersionMetadata{
|
||||
ID: "test-id",
|
||||
}
|
||||
|
||||
// All should be nil initially
|
||||
assert.Nil(t, vm.CreatedAt)
|
||||
assert.Nil(t, vm.NotBefore)
|
||||
assert.Nil(t, vm.NotAfter)
|
||||
|
||||
// Set timestamps
|
||||
now := time.Now()
|
||||
epoch := time.Unix(1, 0)
|
||||
future := now.Add(time.Hour)
|
||||
|
||||
vm.CreatedAt = &now
|
||||
vm.NotBefore = &epoch
|
||||
vm.NotAfter = &future
|
||||
|
||||
// All should be non-nil
|
||||
assert.NotNil(t, vm.CreatedAt)
|
||||
assert.NotNil(t, vm.NotBefore)
|
||||
assert.NotNil(t, vm.NotAfter)
|
||||
|
||||
// Values should match
|
||||
assert.Equal(t, now.Unix(), vm.CreatedAt.Unix())
|
||||
assert.Equal(t, int64(1), vm.NotBefore.Unix())
|
||||
assert.Equal(t, future.Unix(), vm.NotAfter.Unix())
|
||||
}
|
||||
|
||||
// Helper function
|
||||
func fileExists(fs afero.Fs, path string) bool {
|
||||
exists, _ := afero.Exists(fs, path)
|
||||
return exists
|
||||
}
|
||||
421
internal/vault/integration_test.go
Normal file
421
internal/vault/integration_test.go
Normal file
@@ -0,0 +1,421 @@
|
||||
package vault_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/internal/vault"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
func TestVaultWithRealFilesystem(t *testing.T) {
|
||||
// Create a temporary directory for our tests
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Use the real filesystem
|
||||
fs := afero.NewOsFs()
|
||||
|
||||
// Test mnemonic
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
|
||||
// Set test environment variables
|
||||
t.Setenv(secret.EnvMnemonic, testMnemonic)
|
||||
t.Setenv(secret.EnvUnlockPassphrase, "test-passphrase")
|
||||
|
||||
// Test symlink handling
|
||||
t.Run("SymlinkHandling", func(t *testing.T) {
|
||||
stateDir := filepath.Join(tempDir, "symlink-test")
|
||||
if err := os.MkdirAll(stateDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create state dir: %v", err)
|
||||
}
|
||||
|
||||
// Create a test vault
|
||||
vlt, err := vault.CreateVault(fs, stateDir, "test-vault")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault: %v", err)
|
||||
}
|
||||
|
||||
// Get the vault directory
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Create a symlink to the vault directory in a different location
|
||||
symlinkPath := filepath.Join(tempDir, "test-symlink")
|
||||
if err := os.Symlink(vaultDir, symlinkPath); err != nil {
|
||||
t.Fatalf("Failed to create symlink: %v", err)
|
||||
}
|
||||
|
||||
// Test that we can resolve the symlink correctly
|
||||
resolvedPath, err := vault.ResolveVaultSymlink(fs, symlinkPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve symlink: %v", err)
|
||||
}
|
||||
|
||||
// On some platforms, the resolved path might have different case or format
|
||||
// We'll use filepath.EvalSymlinks to get the canonical path for comparison
|
||||
expectedPath, err := filepath.EvalSymlinks(vaultDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to evaluate symlink: %v", err)
|
||||
}
|
||||
actualPath, err := filepath.EvalSymlinks(resolvedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to evaluate resolved path: %v", err)
|
||||
}
|
||||
|
||||
if actualPath != expectedPath {
|
||||
t.Errorf("Expected symlink to resolve to %s, got %s", expectedPath, actualPath)
|
||||
}
|
||||
})
|
||||
|
||||
// Test secret operations with deeply nested paths
|
||||
t.Run("DeepPathSecrets", func(t *testing.T) {
|
||||
stateDir := filepath.Join(tempDir, "deep-path-test")
|
||||
if err := os.MkdirAll(stateDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create state dir: %v", err)
|
||||
}
|
||||
|
||||
// Create a test vault - CreateVault now handles public key when mnemonic is in env
|
||||
vlt, err := vault.CreateVault(fs, stateDir, "test-vault")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault: %v", err)
|
||||
}
|
||||
|
||||
// Load vault metadata to get its derivation index
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
vaultMetadata, err := vault.LoadVaultMetadata(fs, vaultDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load vault metadata: %v", err)
|
||||
}
|
||||
|
||||
// Derive long-term key from mnemonic using the vault's derivation index
|
||||
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, vaultMetadata.DerivationIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Unlock the vault
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Create a secret with a deeply nested path
|
||||
deepPath := "api/credentials/production/database/primary"
|
||||
secretValue := []byte("supersecretdbpassword")
|
||||
expectedValue := make([]byte, len(secretValue))
|
||||
copy(expectedValue, secretValue)
|
||||
|
||||
secretBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||
defer secretBuffer.Destroy()
|
||||
|
||||
err = vlt.AddSecret(deepPath, secretBuffer, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add secret with deep path: %v", err)
|
||||
}
|
||||
|
||||
// List secrets and verify our deep path secret is there
|
||||
secrets, err := vlt.ListSecrets()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list secrets: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, s := range secrets {
|
||||
if s == deepPath {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("Deep path secret not found in listed secrets")
|
||||
}
|
||||
|
||||
// Retrieve the secret and verify its value
|
||||
retrievedValue, err := vlt.GetSecret(deepPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve deep path secret: %v", err)
|
||||
}
|
||||
|
||||
if string(retrievedValue) != string(expectedValue) {
|
||||
t.Errorf("Retrieved value doesn't match. Expected %q, got %q",
|
||||
string(expectedValue), string(retrievedValue))
|
||||
}
|
||||
})
|
||||
|
||||
// Test key caching in GetOrDeriveLongTermKey
|
||||
t.Run("KeyCaching", func(t *testing.T) {
|
||||
stateDir := filepath.Join(tempDir, "key-cache-test")
|
||||
if err := os.MkdirAll(stateDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create state dir: %v", err)
|
||||
}
|
||||
|
||||
// Create a test vault - CreateVault now handles public key when mnemonic is in env
|
||||
vlt, err := vault.CreateVault(fs, stateDir, "test-vault")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault: %v", err)
|
||||
}
|
||||
|
||||
// Load vault metadata to get its derivation index
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
vaultMetadata, err := vault.LoadVaultMetadata(fs, vaultDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load vault metadata: %v", err)
|
||||
}
|
||||
|
||||
// Derive long-term key from mnemonic for verification using the vault's derivation index
|
||||
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, vaultMetadata.DerivationIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the vault is locked initially
|
||||
if !vlt.Locked() {
|
||||
t.Errorf("Vault should be locked initially")
|
||||
}
|
||||
|
||||
// First call to GetOrDeriveLongTermKey should derive and cache the key
|
||||
firstKey, err := vlt.GetOrDeriveLongTermKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the vault is now unlocked
|
||||
if vlt.Locked() {
|
||||
t.Errorf("Vault should be unlocked after GetOrDeriveLongTermKey")
|
||||
}
|
||||
|
||||
// Second call should return the cached key without re-deriving
|
||||
secondKey, err := vlt.GetOrDeriveLongTermKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get cached long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Verify both keys are the same instance
|
||||
if firstKey != secondKey {
|
||||
t.Errorf("Second key call should return same instance as first call")
|
||||
}
|
||||
|
||||
// Verify the public key matches what we expect
|
||||
expectedPubKey := ltIdentity.Recipient().String()
|
||||
actualPubKey := firstKey.Recipient().String()
|
||||
if actualPubKey != expectedPubKey {
|
||||
t.Errorf("Public key mismatch. Expected %s, got %s", expectedPubKey, actualPubKey)
|
||||
}
|
||||
|
||||
// Now clear the key and verify it's locked again
|
||||
vlt.ClearLongTermKey()
|
||||
if !vlt.Locked() {
|
||||
t.Errorf("Vault should be locked after clearing key")
|
||||
}
|
||||
|
||||
// Get the key again and verify it works
|
||||
thirdKey, err := vlt.GetOrDeriveLongTermKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to re-derive long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Verify the public key still matches
|
||||
actualPubKey = thirdKey.Recipient().String()
|
||||
if actualPubKey != expectedPubKey {
|
||||
t.Errorf("Re-derived public key mismatch. Expected %s, got %s", expectedPubKey, actualPubKey)
|
||||
}
|
||||
})
|
||||
|
||||
// Test vault name validation
|
||||
t.Run("VaultNameValidation", func(t *testing.T) {
|
||||
stateDir := filepath.Join(tempDir, "name-validation-test")
|
||||
if err := os.MkdirAll(stateDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create state dir: %v", err)
|
||||
}
|
||||
|
||||
// Test valid vault names
|
||||
validNames := []string{
|
||||
"default",
|
||||
"test-vault",
|
||||
"production.vault",
|
||||
"vault_123",
|
||||
"a-very-long-vault-name-with-dashes",
|
||||
}
|
||||
|
||||
for _, name := range validNames {
|
||||
_, err := vault.CreateVault(fs, stateDir, name)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create vault with valid name %q: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test invalid vault names
|
||||
invalidNames := []string{
|
||||
"", // Empty
|
||||
"UPPERCASE", // Uppercase not allowed
|
||||
"invalid/name", // Slashes not allowed in vault names
|
||||
"invalid name", // Spaces not allowed
|
||||
"invalid@name", // Special chars not allowed
|
||||
}
|
||||
|
||||
for _, name := range invalidNames {
|
||||
_, err := vault.CreateVault(fs, stateDir, name)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error creating vault with invalid name %q, but got none", name)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test multiple vaults and switching between them
|
||||
t.Run("MultipleVaults", func(t *testing.T) {
|
||||
stateDir := filepath.Join(tempDir, "multi-vault-test")
|
||||
if err := os.MkdirAll(stateDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create state dir: %v", err)
|
||||
}
|
||||
|
||||
// Create three vaults
|
||||
vaultNames := []string{"vault1", "vault2", "vault3"}
|
||||
for _, name := range vaultNames {
|
||||
_, err := vault.CreateVault(fs, stateDir, name)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// List vaults and verify all three are there
|
||||
vaults, err := vault.ListVaults(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list vaults: %v", err)
|
||||
}
|
||||
|
||||
if len(vaults) != 3 {
|
||||
t.Errorf("Expected 3 vaults, got %d", len(vaults))
|
||||
}
|
||||
|
||||
// Test switching between vaults
|
||||
for _, name := range vaultNames {
|
||||
// Select the vault
|
||||
if err := vault.SelectVault(fs, stateDir, name); err != nil {
|
||||
t.Fatalf("Failed to select vault %s: %v", name, err)
|
||||
}
|
||||
|
||||
// Get current vault and verify it's the one we selected
|
||||
currentVault, err := vault.GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current vault after selecting %s: %v", name, err)
|
||||
}
|
||||
|
||||
if currentVault.GetName() != name {
|
||||
t.Errorf("Expected current vault to be %s, got %s", name, currentVault.GetName())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test adding a secret in one vault and verifying it's not visible in another
|
||||
t.Run("VaultIsolation", func(t *testing.T) {
|
||||
stateDir := filepath.Join(tempDir, "isolation-test")
|
||||
if err := os.MkdirAll(stateDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create state dir: %v", err)
|
||||
}
|
||||
|
||||
// Create two vaults - CreateVault now handles public key when mnemonic is in env
|
||||
vault1, err := vault.CreateVault(fs, stateDir, "vault1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault1: %v", err)
|
||||
}
|
||||
|
||||
vault2, err := vault.CreateVault(fs, stateDir, "vault2")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault2: %v", err)
|
||||
}
|
||||
|
||||
// Derive long-term key from mnemonic
|
||||
// Note: Both vaults will have different derivation indexes due to GetNextDerivationIndex
|
||||
|
||||
// Load vault1 metadata to get its derivation index
|
||||
vault1Dir, err := vault1.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault1 directory: %v", err)
|
||||
}
|
||||
vault1Metadata, err := vault.LoadVaultMetadata(fs, vault1Dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load vault1 metadata: %v", err)
|
||||
}
|
||||
|
||||
ltIdentity1, err := agehd.DeriveIdentity(testMnemonic, vault1Metadata.DerivationIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term key for vault1: %v", err)
|
||||
}
|
||||
|
||||
// Load vault2 metadata to get its derivation index
|
||||
vault2Dir, err := vault2.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault2 directory: %v", err)
|
||||
}
|
||||
vault2Metadata, err := vault.LoadVaultMetadata(fs, vault2Dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load vault2 metadata: %v", err)
|
||||
}
|
||||
|
||||
ltIdentity2, err := agehd.DeriveIdentity(testMnemonic, vault2Metadata.DerivationIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term key for vault2: %v", err)
|
||||
}
|
||||
|
||||
// Unlock the vaults with their respective keys
|
||||
vault1.Unlock(ltIdentity1)
|
||||
vault2.Unlock(ltIdentity2)
|
||||
|
||||
// Add a secret to vault1
|
||||
secretName := "test-secret"
|
||||
secretValue := []byte("secret in vault1")
|
||||
|
||||
secretBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||
defer secretBuffer.Destroy()
|
||||
|
||||
if err := vault1.AddSecret(secretName, secretBuffer, false); err != nil {
|
||||
t.Fatalf("Failed to add secret to vault1: %v", err)
|
||||
}
|
||||
|
||||
// Verify the secret exists in vault1
|
||||
vault1Secrets, err := vault1.ListSecrets()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list secrets in vault1: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, s := range vault1Secrets {
|
||||
if s == secretName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("Secret not found in vault1")
|
||||
}
|
||||
|
||||
// Verify the secret does NOT exist in vault2
|
||||
vault2Secrets, err := vault2.ListSecrets()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list secrets in vault2: %v", err)
|
||||
}
|
||||
|
||||
found = false
|
||||
for _, s := range vault2Secrets {
|
||||
if s == secretName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
t.Errorf("Secret from vault1 should not be visible in vault2")
|
||||
}
|
||||
})
|
||||
}
|
||||
354
internal/vault/integration_version_test.go
Normal file
354
internal/vault/integration_version_test.go
Normal file
@@ -0,0 +1,354 @@
|
||||
// Version Support Integration Tests
|
||||
//
|
||||
// Comprehensive integration tests for version functionality:
|
||||
//
|
||||
// - TestVersionIntegrationWorkflow: End-to-end workflow testing
|
||||
// - Creating initial version with proper metadata
|
||||
// - Creating multiple versions with timestamp updates
|
||||
// - Retrieving specific versions by name
|
||||
// - Promoting old versions to current
|
||||
// - Testing version serial number limits (999/day)
|
||||
// - Error cases and edge conditions
|
||||
//
|
||||
// - TestVersionConcurrency: Tests concurrent read operations
|
||||
//
|
||||
// - TestVersionCompatibility: Tests handling of legacy non-versioned secrets
|
||||
//
|
||||
// Test Environment:
|
||||
// - Uses in-memory filesystem (afero.MemMapFs)
|
||||
// - Consistent test mnemonic for reproducible keys
|
||||
// - Proper cleanup and isolation between tests
|
||||
|
||||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper function to add a secret to vault with proper buffer protection
|
||||
func addTestSecret(t *testing.T, vault *Vault, name string, value []byte, force bool) {
|
||||
t.Helper()
|
||||
buffer := memguard.NewBufferFromBytes(value)
|
||||
defer buffer.Destroy()
|
||||
err := vault.AddSecret(name, buffer, force)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestVersionIntegrationWorkflow tests the complete version workflow
|
||||
func TestVersionIntegrationWorkflow(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Set mnemonic for testing
|
||||
t.Setenv(secret.EnvMnemonic,
|
||||
"abandon abandon abandon abandon abandon abandon "+
|
||||
"abandon abandon abandon abandon abandon about")
|
||||
|
||||
// Create vault
|
||||
vault, err := CreateVault(fs, stateDir, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Derive and store long-term key from mnemonic
|
||||
mnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
ltIdentity, err := agehd.DeriveIdentity(mnemonic, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store long-term public key in vault
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, ltPubKeyPath, []byte(ltIdentity.Recipient().String()), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unlock the vault
|
||||
vault.Unlock(ltIdentity)
|
||||
|
||||
secretName := "integration/test"
|
||||
|
||||
// Step 1: Create initial version
|
||||
t.Run("create_initial_version", func(t *testing.T) {
|
||||
addTestSecret(t, vault, secretName, []byte("version-1-data"), false)
|
||||
|
||||
// Verify secret can be retrieved
|
||||
value, err := vault.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-1-data"), value)
|
||||
|
||||
// Verify version directory structure
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, versions, 1)
|
||||
|
||||
// Verify current symlink exists
|
||||
currentVersion, err := secret.GetCurrentVersion(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, versions[0], currentVersion)
|
||||
|
||||
// Verify metadata
|
||||
version := secret.NewVersion(vault, secretName, versions[0])
|
||||
err = version.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, version.Metadata.CreatedAt)
|
||||
assert.NotNil(t, version.Metadata.NotBefore)
|
||||
assert.Equal(t, int64(1), version.Metadata.NotBefore.Unix()) // epoch + 1
|
||||
assert.Nil(t, version.Metadata.NotAfter) // should be nil for current version
|
||||
})
|
||||
|
||||
// Step 2: Create second version
|
||||
var firstVersionName string
|
||||
t.Run("create_second_version", func(t *testing.T) {
|
||||
// Small delay to ensure different timestamps
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Get first version name before creating second
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
firstVersionName = versions[0]
|
||||
|
||||
// Create second version
|
||||
addTestSecret(t, vault, secretName, []byte("version-2-data"), true)
|
||||
|
||||
// Verify new value is current
|
||||
value, err := vault.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-2-data"), value)
|
||||
|
||||
// Verify we now have two versions
|
||||
versions, err = secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, versions, 2)
|
||||
|
||||
// Verify first version metadata was updated with notAfter
|
||||
firstVersion := secret.NewVersion(vault, secretName, firstVersionName)
|
||||
err = firstVersion.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, firstVersion.Metadata.NotAfter)
|
||||
|
||||
// Verify second version metadata
|
||||
secondVersion := secret.NewVersion(vault, secretName, versions[0])
|
||||
err = secondVersion.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, secondVersion.Metadata.NotBefore)
|
||||
assert.Nil(t, secondVersion.Metadata.NotAfter)
|
||||
|
||||
// NotBefore of second should equal NotAfter of first
|
||||
assert.Equal(t, firstVersion.Metadata.NotAfter.Unix(), secondVersion.Metadata.NotBefore.Unix())
|
||||
})
|
||||
|
||||
// Step 3: Create third version
|
||||
t.Run("create_third_version", func(t *testing.T) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
addTestSecret(t, vault, secretName, []byte("version-3-data"), true)
|
||||
|
||||
// Verify we now have three versions
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, versions, 3)
|
||||
|
||||
// Current should be version-3
|
||||
value, err := vault.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-3-data"), value)
|
||||
})
|
||||
|
||||
// Step 4: Retrieve specific versions
|
||||
t.Run("retrieve_specific_versions", func(t *testing.T) {
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 3)
|
||||
|
||||
// Get each version by its name
|
||||
value1, err := vault.GetSecretVersion(secretName, versions[2]) // oldest
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-1-data"), value1)
|
||||
|
||||
value2, err := vault.GetSecretVersion(secretName, versions[1]) // middle
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-2-data"), value2)
|
||||
|
||||
value3, err := vault.GetSecretVersion(secretName, versions[0]) // newest
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-3-data"), value3)
|
||||
|
||||
// Empty version should return current
|
||||
valueCurrent, err := vault.GetSecretVersion(secretName, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-3-data"), valueCurrent)
|
||||
})
|
||||
|
||||
// Step 5: Promote old version to current
|
||||
t.Run("promote_old_version", func(t *testing.T) {
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", "integration%test")
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Promote the first version (oldest) to current
|
||||
oldestVersion := versions[2]
|
||||
err = secret.SetCurrentVersion(fs, secretDir, oldestVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify current now returns the old version's value
|
||||
value, err := vault.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-1-data"), value)
|
||||
|
||||
// Verify the version metadata hasn't changed
|
||||
// (promoting shouldn't modify timestamps)
|
||||
version := secret.NewVersion(vault, secretName, oldestVersion)
|
||||
err = version.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, version.Metadata.NotAfter) // should still have its old notAfter
|
||||
})
|
||||
|
||||
// Step 6: Test version limits
|
||||
t.Run("version_serial_limits", func(t *testing.T) {
|
||||
// Create a new secret for this test
|
||||
limitSecretName := "limit/test"
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", "limit%test", "versions")
|
||||
|
||||
// Create 998 versions (we already have one from the first AddSecret)
|
||||
addTestSecret(t, vault, limitSecretName, []byte("initial"), false)
|
||||
|
||||
// Get today's date for consistent version names
|
||||
today := time.Now().Format("20060102")
|
||||
|
||||
// Manually create many versions with same date
|
||||
for i := 2; i <= 998; i++ {
|
||||
versionName := fmt.Sprintf("%s.%03d", today, i)
|
||||
versionDir := filepath.Join(secretDir, versionName)
|
||||
err := fs.MkdirAll(versionDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Should be able to create one more (999)
|
||||
versionName, err := secret.GenerateVersionName(fs, filepath.Dir(secretDir))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, fmt.Sprintf("%s.999", today), versionName)
|
||||
|
||||
// Create the 999th version directory
|
||||
err = fs.MkdirAll(filepath.Join(secretDir, versionName), 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should fail to create 1000th version
|
||||
_, err = secret.GenerateVersionName(fs, filepath.Dir(secretDir))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "exceeded maximum versions per day")
|
||||
})
|
||||
|
||||
// Step 7: Test error cases
|
||||
t.Run("error_cases", func(t *testing.T) {
|
||||
// Try to get non-existent version
|
||||
_, err := vault.GetSecretVersion(secretName, "99991231.999")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
|
||||
// Try to get version of non-existent secret
|
||||
_, err = vault.GetSecretVersion("nonexistent/secret", "")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Try to add secret without force when it exists
|
||||
failBuffer := memguard.NewBufferFromBytes([]byte("should-fail"))
|
||||
defer failBuffer.Destroy()
|
||||
err = vault.AddSecret(secretName, failBuffer, false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
})
|
||||
}
|
||||
|
||||
// TestVersionConcurrency tests concurrent version operations
|
||||
func TestVersionConcurrency(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Set up vault
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
|
||||
secretName := "concurrent/test"
|
||||
|
||||
// Create initial version
|
||||
addTestSecret(t, vault, secretName, []byte("initial"), false)
|
||||
|
||||
// Test concurrent reads
|
||||
t.Run("concurrent_reads", func(t *testing.T) {
|
||||
done := make(chan bool, 10)
|
||||
errors := make(chan error, 10)
|
||||
|
||||
for range 10 {
|
||||
go func() {
|
||||
value, err := vault.GetSecret(secretName)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
} else if string(value) != "initial" {
|
||||
errors <- fmt.Errorf("unexpected value: %s", value)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for range 10 {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
select {
|
||||
case err := <-errors:
|
||||
t.Fatalf("concurrent read failed: %v", err)
|
||||
default:
|
||||
// No errors
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestVersionCompatibility tests that old secrets without versions still work
|
||||
func TestVersionCompatibility(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Set up vault
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
ltIdentity, err := vault.GetOrDeriveLongTermKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually create an old-style secret (no versions)
|
||||
secretName := "legacy/secret"
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", "legacy%secret")
|
||||
err = fs.MkdirAll(secretDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create old-style encrypted value directly in secret directory
|
||||
testValue := []byte("legacy-value")
|
||||
testValueBuffer := memguard.NewBufferFromBytes(testValue)
|
||||
defer testValueBuffer.Destroy()
|
||||
ltRecipient := ltIdentity.Recipient()
|
||||
encrypted, err := secret.EncryptToRecipient(testValueBuffer, ltRecipient)
|
||||
require.NoError(t, err)
|
||||
|
||||
valuePath := filepath.Join(secretDir, "value.age")
|
||||
err = afero.WriteFile(fs, valuePath, encrypted, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should fail to get with version-aware methods
|
||||
_, err = vault.GetSecret(secretName)
|
||||
assert.Error(t, err)
|
||||
|
||||
// List versions should return empty
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, versions)
|
||||
}
|
||||
363
internal/vault/management.go
Normal file
363
internal/vault/management.go
Normal file
@@ -0,0 +1,363 @@
|
||||
// Package vault provides functionality for managing encrypted vaults.
|
||||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// Register the GetCurrentVault function with the secret package
|
||||
func init() {
|
||||
secret.RegisterGetCurrentVaultFunc(func(fs afero.Fs, stateDir string) (secret.VaultInterface, error) {
|
||||
return GetCurrentVault(fs, stateDir)
|
||||
})
|
||||
}
|
||||
|
||||
// isValidVaultName validates vault names according to the format [a-z0-9\.\-\_]+
|
||||
// Note: We don't allow slashes in vault names unlike secret names
|
||||
func isValidVaultName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
matched, _ := regexp.MatchString(`^[a-z0-9\.\-\_]+$`, name)
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
// resolveRelativeSymlink resolves a relative symlink target to an absolute path
|
||||
func resolveRelativeSymlink(symlinkPath, _ string) (string, error) {
|
||||
// Get the current directory before changing
|
||||
originalDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current directory: %w", err)
|
||||
}
|
||||
secret.Debug("Got current directory", "original_dir", originalDir)
|
||||
|
||||
// Change to the symlink's directory
|
||||
symlinkDir := filepath.Dir(symlinkPath)
|
||||
secret.Debug("Changing to symlink directory", "symlink_path", symlinkDir)
|
||||
secret.Debug("About to call os.Chdir - this might hang if symlink is broken")
|
||||
if err := os.Chdir(symlinkDir); err != nil {
|
||||
return "", fmt.Errorf("failed to change to symlink directory: %w", err)
|
||||
}
|
||||
secret.Debug("Changed to symlink directory successfully - os.Chdir completed")
|
||||
|
||||
// Get the absolute path of the target
|
||||
secret.Debug("Getting absolute path of current directory")
|
||||
absolutePath, err := os.Getwd()
|
||||
if err != nil {
|
||||
// Try to restore original directory before returning error
|
||||
_ = os.Chdir(originalDir)
|
||||
|
||||
return "", fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
secret.Debug("Got absolute path", "absolute_path", absolutePath)
|
||||
|
||||
// Restore the original directory
|
||||
secret.Debug("Restoring original directory", "original_dir", originalDir)
|
||||
if err := os.Chdir(originalDir); err != nil {
|
||||
return "", fmt.Errorf("failed to restore original directory: %w", err)
|
||||
}
|
||||
secret.Debug("Restored original directory successfully")
|
||||
|
||||
return absolutePath, nil
|
||||
}
|
||||
|
||||
// ResolveVaultSymlink resolves the currentvault symlink by reading either the symlink target or file contents
|
||||
// This function is designed to work on both Unix and Windows systems, as well as with in-memory filesystems
|
||||
func ResolveVaultSymlink(fs afero.Fs, symlinkPath string) (string, error) {
|
||||
secret.Debug("resolveVaultSymlink starting", "symlink_path", symlinkPath)
|
||||
|
||||
// First try to handle the path as a real symlink (works on Unix systems)
|
||||
_, isOsFs := fs.(*afero.OsFs)
|
||||
if isOsFs {
|
||||
target, err := tryResolveOsSymlink(symlinkPath)
|
||||
if err == nil {
|
||||
secret.Debug("resolveVaultSymlink completed successfully", "result", target)
|
||||
|
||||
return target, nil
|
||||
}
|
||||
// Fall through to fallback if symlink resolution failed
|
||||
} else {
|
||||
secret.Debug("Not using OS filesystem, skipping symlink resolution")
|
||||
}
|
||||
|
||||
// Fallback: treat it as a regular file containing the target path
|
||||
secret.Debug("Fallback: trying to read regular file with target path")
|
||||
|
||||
fileData, err := afero.ReadFile(fs, symlinkPath)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to read target path file", "error", err)
|
||||
|
||||
return "", fmt.Errorf("failed to read vault symlink: %w", err)
|
||||
}
|
||||
|
||||
target := string(fileData)
|
||||
secret.Debug("Read target path from file", "target", target)
|
||||
|
||||
secret.Debug("resolveVaultSymlink completed via fallback", "result", target)
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// tryResolveOsSymlink attempts to resolve a symlink on OS filesystems
|
||||
func tryResolveOsSymlink(symlinkPath string) (string, error) {
|
||||
secret.Debug("Using real filesystem symlink resolution")
|
||||
|
||||
// Check if the symlink exists
|
||||
secret.Debug("Checking symlink target", "symlink_path", symlinkPath)
|
||||
target, err := os.Readlink(symlinkPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
secret.Debug("Symlink points to", "target", target)
|
||||
|
||||
// On real filesystem, we need to handle relative symlinks
|
||||
// by resolving them relative to the symlink's directory
|
||||
if !filepath.IsAbs(target) {
|
||||
return resolveRelativeSymlink(symlinkPath, target)
|
||||
}
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// GetCurrentVault gets the current vault from the file system
|
||||
func GetCurrentVault(fs afero.Fs, stateDir string) (*Vault, error) {
|
||||
secret.Debug("Getting current vault", "state_dir", stateDir)
|
||||
|
||||
// Check if the current vault symlink exists
|
||||
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||
|
||||
secret.Debug("Checking current vault symlink", "path", currentVaultPath)
|
||||
_, err := fs.Stat(currentVaultPath)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to stat current vault symlink", "error", err, "path", currentVaultPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read current vault symlink: %w", err)
|
||||
}
|
||||
|
||||
secret.Debug("Current vault symlink exists")
|
||||
|
||||
// Resolve the symlink to get the actual vault directory
|
||||
secret.Debug("Resolving vault symlink")
|
||||
targetPath, err := ResolveVaultSymlink(fs, currentVaultPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secret.Debug("Resolved vault symlink", "target_path", targetPath)
|
||||
|
||||
// Extract the vault name from the path
|
||||
// The path will be something like "/path/to/vaults.d/default"
|
||||
vaultName := filepath.Base(targetPath)
|
||||
secret.Debug("Extracted vault name", "vault_name", vaultName)
|
||||
|
||||
secret.Debug("Current vault resolved", "vault_name", vaultName, "target_path", targetPath)
|
||||
|
||||
// Create and return the vault
|
||||
return NewVault(fs, stateDir, vaultName), nil
|
||||
}
|
||||
|
||||
// ListVaults lists all vaults in the state directory
|
||||
func ListVaults(fs afero.Fs, stateDir string) ([]string, error) {
|
||||
vaultsDir := filepath.Join(stateDir, "vaults.d")
|
||||
|
||||
// Check if vaults directory exists
|
||||
exists, err := afero.DirExists(fs, vaultsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if vaults directory exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// Read the vaults directory
|
||||
entries, err := afero.ReadDir(fs, vaultsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read vaults directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract vault names
|
||||
var vaults []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
vaults = append(vaults, entry.Name())
|
||||
}
|
||||
}
|
||||
|
||||
return vaults, nil
|
||||
}
|
||||
|
||||
// processMnemonicForVault handles mnemonic processing for vault creation
|
||||
func processMnemonicForVault(fs afero.Fs, stateDir, vaultDir, vaultName string) (
|
||||
derivationIndex uint32, publicKeyHash string, familyHash string, err error) {
|
||||
// Check if mnemonic is available in environment
|
||||
mnemonic := os.Getenv(secret.EnvMnemonic)
|
||||
|
||||
if mnemonic == "" {
|
||||
secret.Debug("No mnemonic in environment, vault created without long-term key", "vault", vaultName)
|
||||
// Use 0 for derivation index when no mnemonic is provided
|
||||
return 0, "", "", nil
|
||||
}
|
||||
|
||||
secret.Debug("Mnemonic found in environment, deriving long-term key", "vault", vaultName)
|
||||
|
||||
// Get the next available derivation index for this mnemonic
|
||||
derivationIndex, err = GetNextDerivationIndex(fs, stateDir, mnemonic)
|
||||
if err != nil {
|
||||
return 0, "", "", fmt.Errorf("failed to get next derivation index: %w", err)
|
||||
}
|
||||
|
||||
// Derive the long-term key using the actual derivation index
|
||||
ltIdentity, err := agehd.DeriveIdentity(mnemonic, derivationIndex)
|
||||
if err != nil {
|
||||
return 0, "", "", fmt.Errorf("failed to derive long-term key: %w", err)
|
||||
}
|
||||
|
||||
// Write the public key
|
||||
ltPubKey := ltIdentity.Recipient().String()
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
if err := afero.WriteFile(fs, ltPubKeyPath, []byte(ltPubKey), secret.FilePerms); err != nil {
|
||||
return 0, "", "", fmt.Errorf("failed to write long-term public key: %w", err)
|
||||
}
|
||||
secret.Debug("Wrote long-term public key", "path", ltPubKeyPath)
|
||||
|
||||
// Compute verification hash from actual derivation index
|
||||
publicKeyHash = ComputeDoubleSHA256([]byte(ltIdentity.Recipient().String()))
|
||||
|
||||
// Compute family hash from index 0 (same for all vaults with this mnemonic)
|
||||
// This is used to identify which vaults belong to the same mnemonic family
|
||||
identity0, err := agehd.DeriveIdentity(mnemonic, 0)
|
||||
if err != nil {
|
||||
return 0, "", "", fmt.Errorf("failed to derive identity for index 0: %w", err)
|
||||
}
|
||||
familyHash = ComputeDoubleSHA256([]byte(identity0.Recipient().String()))
|
||||
|
||||
return derivationIndex, publicKeyHash, familyHash, nil
|
||||
}
|
||||
|
||||
// CreateVault creates a new vault
|
||||
func CreateVault(fs afero.Fs, stateDir string, name string) (*Vault, error) {
|
||||
secret.Debug("Creating new vault", "name", name, "state_dir", stateDir)
|
||||
|
||||
// Validate vault name
|
||||
if !isValidVaultName(name) {
|
||||
secret.Debug("Invalid vault name provided", "vault_name", name)
|
||||
|
||||
return nil, fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name)
|
||||
}
|
||||
secret.Debug("Vault name validation passed", "vault_name", name)
|
||||
|
||||
// Create vault directory structure
|
||||
vaultDir := filepath.Join(stateDir, "vaults.d", name)
|
||||
secret.Debug("Creating vault directory structure", "vault_dir", vaultDir)
|
||||
|
||||
// Create main vault directory
|
||||
if err := fs.MkdirAll(vaultDir, secret.DirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create vault directory: %w", err)
|
||||
}
|
||||
|
||||
// Create secrets directory
|
||||
secretsDir := filepath.Join(vaultDir, "secrets.d")
|
||||
if err := fs.MkdirAll(secretsDir, secret.DirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create secrets directory: %w", err)
|
||||
}
|
||||
|
||||
// Create unlockers directory
|
||||
unlockersDir := filepath.Join(vaultDir, "unlockers.d")
|
||||
if err := fs.MkdirAll(unlockersDir, secret.DirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create unlockers directory: %w", err)
|
||||
}
|
||||
|
||||
// Process mnemonic if available
|
||||
derivationIndex, publicKeyHash, familyHash, err := processMnemonicForVault(fs, stateDir, vaultDir, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Save vault metadata
|
||||
metadata := &Metadata{
|
||||
CreatedAt: time.Now(),
|
||||
DerivationIndex: derivationIndex,
|
||||
PublicKeyHash: publicKeyHash,
|
||||
MnemonicFamilyHash: familyHash,
|
||||
}
|
||||
if err := SaveVaultMetadata(fs, vaultDir, metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to save vault metadata: %w", err)
|
||||
}
|
||||
|
||||
// Select the newly created vault as current
|
||||
secret.Debug("Selecting newly created vault as current", "name", name)
|
||||
if err := SelectVault(fs, stateDir, name); err != nil {
|
||||
return nil, fmt.Errorf("failed to select vault: %w", err)
|
||||
}
|
||||
|
||||
// Create and return the vault
|
||||
secret.Debug("Successfully created vault", "name", name)
|
||||
|
||||
return NewVault(fs, stateDir, name), nil
|
||||
}
|
||||
|
||||
// SelectVault selects the given vault as the current vault
|
||||
func SelectVault(fs afero.Fs, stateDir string, name string) error {
|
||||
secret.Debug("Selecting vault", "vault_name", name, "state_dir", stateDir)
|
||||
|
||||
// Validate vault name
|
||||
if !isValidVaultName(name) {
|
||||
secret.Debug("Invalid vault name provided", "vault_name", name)
|
||||
|
||||
return fmt.Errorf("invalid vault name '%s': must match pattern [a-z0-9.\\-_]+", name)
|
||||
}
|
||||
secret.Debug("Vault name validation passed", "vault_name", name)
|
||||
|
||||
// Check if vault exists
|
||||
vaultDir := filepath.Join(stateDir, "vaults.d", name)
|
||||
exists, err := afero.DirExists(fs, vaultDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if vault exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("vault %s does not exist", name)
|
||||
}
|
||||
|
||||
// Create or update the current vault symlink/file
|
||||
currentVaultPath := filepath.Join(stateDir, "currentvault")
|
||||
targetPath := filepath.Join(stateDir, "vaults.d", name)
|
||||
|
||||
// First try to remove existing symlink if it exists
|
||||
if _, err := fs.Stat(currentVaultPath); err == nil {
|
||||
secret.Debug("Removing existing current vault symlink", "path", currentVaultPath)
|
||||
// Ignore errors from Remove as we'll try to create/update it anyway.
|
||||
// On some systems, removing a symlink may fail but the subsequent create may still succeed.
|
||||
_ = fs.Remove(currentVaultPath)
|
||||
}
|
||||
|
||||
// Try to create a real symlink first (works on Unix systems)
|
||||
if _, ok := fs.(*afero.OsFs); ok {
|
||||
secret.Debug("Creating vault symlink", "target", targetPath, "link", currentVaultPath)
|
||||
if err := os.Symlink(targetPath, currentVaultPath); err == nil {
|
||||
secret.Debug("Successfully selected vault", "vault_name", name)
|
||||
|
||||
return nil
|
||||
}
|
||||
// If symlink creation fails, fall back to regular file
|
||||
}
|
||||
|
||||
// Fallback: create a regular file with the target path
|
||||
secret.Debug("Fallback: creating regular file with target path", "target", targetPath)
|
||||
if err := afero.WriteFile(fs, currentVaultPath, []byte(targetPath), secret.FilePerms); err != nil {
|
||||
return fmt.Errorf("failed to select vault: %w", err)
|
||||
}
|
||||
|
||||
secret.Debug("Successfully selected vault", "vault_name", name)
|
||||
|
||||
return nil
|
||||
}
|
||||
131
internal/vault/metadata.go
Normal file
131
internal/vault/metadata.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// Metadata is an alias for secret.VaultMetadata
|
||||
type Metadata = secret.VaultMetadata
|
||||
|
||||
// UnlockerMetadata is an alias for secret.UnlockerMetadata
|
||||
type UnlockerMetadata = secret.UnlockerMetadata
|
||||
|
||||
// SecretMetadata is an alias for secret.Metadata
|
||||
type SecretMetadata = secret.Metadata
|
||||
|
||||
// Configuration is an alias for secret.Configuration
|
||||
type Configuration = secret.Configuration
|
||||
|
||||
// ComputeDoubleSHA256 computes the double SHA256 hash of data and returns it as hex
|
||||
func ComputeDoubleSHA256(data []byte) string {
|
||||
firstHash := sha256.Sum256(data)
|
||||
secondHash := sha256.Sum256(firstHash[:])
|
||||
|
||||
return hex.EncodeToString(secondHash[:])
|
||||
}
|
||||
|
||||
// GetNextDerivationIndex finds the next available derivation index for a given mnemonic
|
||||
// by deriving the public key for index 0 and using its hash to identify related vaults
|
||||
func GetNextDerivationIndex(fs afero.Fs, stateDir string, mnemonic string) (uint32, error) {
|
||||
// First, derive the public key for index 0 to get our identifier
|
||||
identity0, err := agehd.DeriveIdentity(mnemonic, 0)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to derive identity for index 0: %w", err)
|
||||
}
|
||||
pubKeyHash := ComputeDoubleSHA256([]byte(identity0.Recipient().String()))
|
||||
|
||||
vaultsDir := filepath.Join(stateDir, "vaults.d")
|
||||
|
||||
// Check if vaults directory exists
|
||||
exists, err := afero.DirExists(fs, vaultsDir)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to check if vaults directory exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
// No vaults yet, start with index 0
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read all vault directories
|
||||
entries, err := afero.ReadDir(fs, vaultsDir)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to read vaults directory: %w", err)
|
||||
}
|
||||
|
||||
// Track which indices are in use for this mnemonic
|
||||
usedIndices := make(map[uint32]bool)
|
||||
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to read vault metadata
|
||||
metadataPath := filepath.Join(vaultsDir, entry.Name(), "vault-metadata.json")
|
||||
metadataBytes, err := afero.ReadFile(fs, metadataPath)
|
||||
if err != nil {
|
||||
// Skip vaults without metadata
|
||||
continue
|
||||
}
|
||||
|
||||
var metadata Metadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
// Skip vaults with invalid metadata
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this vault uses the same mnemonic by comparing family hashes
|
||||
if metadata.MnemonicFamilyHash == pubKeyHash {
|
||||
usedIndices[metadata.DerivationIndex] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Find the first available index
|
||||
var index uint32
|
||||
for usedIndices[index] {
|
||||
index++
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// SaveVaultMetadata saves vault metadata to the vault directory
|
||||
func SaveVaultMetadata(fs afero.Fs, vaultDir string, metadata *Metadata) error {
|
||||
metadataPath := filepath.Join(vaultDir, "vault-metadata.json")
|
||||
|
||||
metadataBytes, err := json.MarshalIndent(metadata, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal vault metadata: %w", err)
|
||||
}
|
||||
|
||||
if err := afero.WriteFile(fs, metadataPath, metadataBytes, secret.FilePerms); err != nil {
|
||||
return fmt.Errorf("failed to write vault metadata: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadVaultMetadata loads vault metadata from the vault directory
|
||||
func LoadVaultMetadata(fs afero.Fs, vaultDir string) (*Metadata, error) {
|
||||
metadataPath := filepath.Join(vaultDir, "vault-metadata.json")
|
||||
|
||||
metadataBytes, err := afero.ReadFile(fs, metadataPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read vault metadata: %w", err)
|
||||
}
|
||||
|
||||
var metadata Metadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal vault metadata: %w", err)
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
414
internal/vault/metadata_test.go
Normal file
414
internal/vault/metadata_test.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
func TestVaultMetadata(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Test mnemonic for consistent testing
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
|
||||
t.Run("ComputeDoubleSHA256", func(t *testing.T) {
|
||||
// Test data
|
||||
data := []byte("test data")
|
||||
hash := ComputeDoubleSHA256(data)
|
||||
|
||||
// Verify it's a valid hex string of 64 characters (32 bytes * 2)
|
||||
if len(hash) != 64 {
|
||||
t.Errorf("Expected hash length of 64, got %d", len(hash))
|
||||
}
|
||||
|
||||
// Verify consistency
|
||||
hash2 := ComputeDoubleSHA256(data)
|
||||
if hash != hash2 {
|
||||
t.Errorf("Hash should be consistent for same input")
|
||||
}
|
||||
|
||||
// Verify different input produces different hash
|
||||
hash3 := ComputeDoubleSHA256([]byte("different data"))
|
||||
if hash == hash3 {
|
||||
t.Errorf("Different input should produce different hash")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetNextDerivationIndex", func(t *testing.T) {
|
||||
// Test with no existing vaults
|
||||
index, err := GetNextDerivationIndex(fs, stateDir, testMnemonic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get derivation index: %v", err)
|
||||
}
|
||||
if index != 0 {
|
||||
t.Errorf("Expected index 0 for first vault, got %d", index)
|
||||
}
|
||||
|
||||
// Create a vault with metadata and matching public key
|
||||
vaultDir := filepath.Join(stateDir, "vaults.d", "vault1")
|
||||
if err := fs.MkdirAll(vaultDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Derive identity for index 0
|
||||
identity0, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive identity: %v", err)
|
||||
}
|
||||
pubKey0 := identity0.Recipient().String()
|
||||
pubKeyHash0 := ComputeDoubleSHA256([]byte(pubKey0))
|
||||
|
||||
// Write public key
|
||||
if err := afero.WriteFile(fs, filepath.Join(vaultDir, "pub.age"), []byte(pubKey0), 0o600); err != nil {
|
||||
t.Fatalf("Failed to write public key: %v", err)
|
||||
}
|
||||
|
||||
metadata1 := &Metadata{
|
||||
DerivationIndex: 0,
|
||||
PublicKeyHash: pubKeyHash0, // Hash of the actual key (index 0)
|
||||
MnemonicFamilyHash: pubKeyHash0, // Hash of index 0 key (for family identification)
|
||||
}
|
||||
if err := SaveVaultMetadata(fs, vaultDir, metadata1); err != nil {
|
||||
t.Fatalf("Failed to save metadata: %v", err)
|
||||
}
|
||||
|
||||
// Next index for same mnemonic should be 1
|
||||
index, err = GetNextDerivationIndex(fs, stateDir, testMnemonic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get derivation index: %v", err)
|
||||
}
|
||||
if index != 1 {
|
||||
t.Errorf("Expected index 1 for second vault with same mnemonic, got %d", index)
|
||||
}
|
||||
|
||||
// Different mnemonic should start at 0
|
||||
differentMnemonic := "zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo wrong"
|
||||
index, err = GetNextDerivationIndex(fs, stateDir, differentMnemonic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get derivation index: %v", err)
|
||||
}
|
||||
if index != 0 {
|
||||
t.Errorf("Expected index 0 for first vault with different mnemonic, got %d", index)
|
||||
}
|
||||
|
||||
// Add another vault with same mnemonic but higher index
|
||||
vaultDir2 := filepath.Join(stateDir, "vaults.d", "vault2")
|
||||
if err := fs.MkdirAll(vaultDir2, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Derive identity for index 5
|
||||
identity5, err := agehd.DeriveIdentity(testMnemonic, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive identity: %v", err)
|
||||
}
|
||||
pubKey5 := identity5.Recipient().String()
|
||||
|
||||
// Write public key
|
||||
if err := afero.WriteFile(fs, filepath.Join(vaultDir2, "pub.age"), []byte(pubKey5), 0o600); err != nil {
|
||||
t.Fatalf("Failed to write public key: %v", err)
|
||||
}
|
||||
|
||||
// Compute the hash for index 5 key
|
||||
pubKeyHash5 := ComputeDoubleSHA256([]byte(pubKey5))
|
||||
|
||||
metadata2 := &Metadata{
|
||||
DerivationIndex: 5,
|
||||
PublicKeyHash: pubKeyHash5, // Hash of the actual key (index 5)
|
||||
MnemonicFamilyHash: pubKeyHash0, // Same family hash since it's from the same mnemonic
|
||||
}
|
||||
if err := SaveVaultMetadata(fs, vaultDir2, metadata2); err != nil {
|
||||
t.Fatalf("Failed to save metadata: %v", err)
|
||||
}
|
||||
|
||||
// Next index should be 1 (not 6) because we look for the first available slot
|
||||
index, err = GetNextDerivationIndex(fs, stateDir, testMnemonic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get derivation index: %v", err)
|
||||
}
|
||||
if index != 1 {
|
||||
t.Errorf("Expected index 1 (first available), got %d", index)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MetadataPersistence", func(t *testing.T) {
|
||||
vaultDir := filepath.Join(stateDir, "vaults.d", "test-vault")
|
||||
if err := fs.MkdirAll(vaultDir, 0o700); err != nil {
|
||||
t.Fatalf("Failed to create vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Create and save metadata
|
||||
metadata := &Metadata{
|
||||
DerivationIndex: 3,
|
||||
PublicKeyHash: "test-public-key-hash",
|
||||
}
|
||||
|
||||
if err := SaveVaultMetadata(fs, vaultDir, metadata); err != nil {
|
||||
t.Fatalf("Failed to save metadata: %v", err)
|
||||
}
|
||||
|
||||
// Load and verify
|
||||
loaded, err := LoadVaultMetadata(fs, vaultDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load metadata: %v", err)
|
||||
}
|
||||
|
||||
if loaded.DerivationIndex != metadata.DerivationIndex {
|
||||
t.Errorf("DerivationIndex mismatch: expected %d, got %d", metadata.DerivationIndex, loaded.DerivationIndex)
|
||||
}
|
||||
if loaded.PublicKeyHash != metadata.PublicKeyHash {
|
||||
t.Errorf("PublicKeyHash mismatch: expected %s, got %s", metadata.PublicKeyHash, loaded.PublicKeyHash)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DifferentKeysForDifferentIndices", func(t *testing.T) {
|
||||
// Derive keys with different indices
|
||||
identity0, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive identity with index 0: %v", err)
|
||||
}
|
||||
|
||||
identity1, err := agehd.DeriveIdentity(testMnemonic, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive identity with index 1: %v", err)
|
||||
}
|
||||
|
||||
// Compute public key hashes
|
||||
pubKey0 := identity0.Recipient().String()
|
||||
pubKey1 := identity1.Recipient().String()
|
||||
hash0 := ComputeDoubleSHA256([]byte(pubKey0))
|
||||
|
||||
// Verify different indices produce different public keys
|
||||
if pubKey0 == pubKey1 {
|
||||
t.Errorf("Different derivation indices should produce different public keys")
|
||||
}
|
||||
|
||||
// But the hash of index 0's public key should be the same for the same mnemonic
|
||||
// This is what we use as the identifier
|
||||
identity0Again, _ := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
pubKey0Again := identity0Again.Recipient().String()
|
||||
hash0Again := ComputeDoubleSHA256([]byte(pubKey0Again))
|
||||
|
||||
if hash0 != hash0Again {
|
||||
t.Errorf("Same mnemonic should produce same public key hash for index 0")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPublicKeyHashConsistency(t *testing.T) {
|
||||
// Use the same test mnemonic that the integration test uses
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
|
||||
// Derive identity from index 0 multiple times
|
||||
identity1, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive first identity: %v", err)
|
||||
}
|
||||
|
||||
identity2, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive second identity: %v", err)
|
||||
}
|
||||
|
||||
// Verify identities are the same
|
||||
if identity1.Recipient().String() != identity2.Recipient().String() {
|
||||
t.Errorf("Identity derivation is not deterministic")
|
||||
t.Logf("First: %s", identity1.Recipient().String())
|
||||
t.Logf("Second: %s", identity2.Recipient().String())
|
||||
}
|
||||
|
||||
// Compute public key hashes
|
||||
hash1 := ComputeDoubleSHA256([]byte(identity1.Recipient().String()))
|
||||
hash2 := ComputeDoubleSHA256([]byte(identity2.Recipient().String()))
|
||||
|
||||
// Verify hashes are the same
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Public key hash computation is not deterministic")
|
||||
t.Logf("First hash: %s", hash1)
|
||||
t.Logf("Second hash: %s", hash2)
|
||||
}
|
||||
|
||||
t.Logf("Test mnemonic public key hash (index 0): %s", hash1)
|
||||
}
|
||||
|
||||
func TestSampleHashCalculation(t *testing.T) {
|
||||
// Test with the exact mnemonic from integration test if available
|
||||
// We'll also test with a few different mnemonics to make sure they produce different hashes
|
||||
mnemonics := []string{
|
||||
"abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about",
|
||||
"legal winner thank year wave sausage worth useful legal winner thank yellow",
|
||||
"zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo wrong",
|
||||
}
|
||||
|
||||
for i, mnemonic := range mnemonics {
|
||||
identity, err := agehd.DeriveIdentity(mnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive identity for mnemonic %d: %v", i, err)
|
||||
}
|
||||
|
||||
hash := ComputeDoubleSHA256([]byte(identity.Recipient().String()))
|
||||
t.Logf("Mnemonic %d hash (index 0): %s", i, hash)
|
||||
t.Logf(" Recipient: %s", identity.Recipient().String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkflowMismatch(t *testing.T) {
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
|
||||
// Create a temporary directory for testing
|
||||
tempDir := t.TempDir()
|
||||
fs := afero.NewOsFs()
|
||||
|
||||
// Test Case 1: Create vault WITH mnemonic (like init command)
|
||||
t.Setenv("SB_SECRET_MNEMONIC", testMnemonic)
|
||||
_, err := CreateVault(fs, tempDir, "default")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault with mnemonic: %v", err)
|
||||
}
|
||||
|
||||
// Load metadata for vault1
|
||||
vault1Dir := filepath.Join(tempDir, "vaults.d", "default")
|
||||
metadata1, err := LoadVaultMetadata(fs, vault1Dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load vault1 metadata: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Vault1 (with mnemonic) - DerivationIndex: %d, PublicKeyHash: %s",
|
||||
metadata1.DerivationIndex, metadata1.PublicKeyHash)
|
||||
|
||||
// Test Case 2: Create vault WITHOUT mnemonic, then import (like work vault)
|
||||
t.Setenv("SB_SECRET_MNEMONIC", "")
|
||||
_, err = CreateVault(fs, tempDir, "work")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault without mnemonic: %v", err)
|
||||
}
|
||||
|
||||
vault2Dir := filepath.Join(tempDir, "vaults.d", "work")
|
||||
|
||||
// Simulate the vault import process
|
||||
t.Setenv("SB_SECRET_MNEMONIC", testMnemonic)
|
||||
|
||||
// Get the next available derivation index for this mnemonic
|
||||
derivationIndex, err := GetNextDerivationIndex(fs, tempDir, testMnemonic)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get next derivation index: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Next derivation index for import: %d", derivationIndex)
|
||||
|
||||
// Calculate public key hash from index 0 (same as in VaultImport)
|
||||
identity0, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive identity for index 0: %v", err)
|
||||
}
|
||||
publicKeyHash := ComputeDoubleSHA256([]byte(identity0.Recipient().String()))
|
||||
|
||||
// Load existing metadata and update it (same as in VaultImport)
|
||||
existingMetadata, err := LoadVaultMetadata(fs, vault2Dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load existing metadata: %v", err)
|
||||
}
|
||||
|
||||
// Update metadata with new derivation info
|
||||
existingMetadata.DerivationIndex = derivationIndex
|
||||
existingMetadata.PublicKeyHash = publicKeyHash
|
||||
|
||||
if err := SaveVaultMetadata(fs, vault2Dir, existingMetadata); err != nil {
|
||||
t.Fatalf("Failed to save vault metadata: %v", err)
|
||||
}
|
||||
|
||||
// Load updated metadata for vault2
|
||||
metadata2, err := LoadVaultMetadata(fs, vault2Dir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load vault2 metadata: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Vault2 (imported mnemonic) - DerivationIndex: %d, PublicKeyHash: %s",
|
||||
metadata2.DerivationIndex, metadata2.PublicKeyHash)
|
||||
|
||||
// Verify that both vaults have the same public key hash
|
||||
if metadata1.PublicKeyHash != metadata2.PublicKeyHash {
|
||||
t.Errorf("Public key hashes don't match!")
|
||||
t.Logf("Vault1 hash: %s", metadata1.PublicKeyHash)
|
||||
t.Logf("Vault2 hash: %s", metadata2.PublicKeyHash)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Both vaults have the same public key hash: %s", metadata1.PublicKeyHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseEngineerHash(t *testing.T) {
|
||||
// This is the hash that the work vault is getting in the failing test
|
||||
wrongHash := "e34a2f500e395d8934a90a99ee9311edcfffd68cb701079575e50cbac7bb9417"
|
||||
correctHash := "992552b00b3879dfae461fab9a084b47784a032771c7a9accaebdde05ec7a7d1"
|
||||
|
||||
// Test mnemonic from integration test
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
|
||||
// Calculate hash for test mnemonic
|
||||
identity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive identity: %v", err)
|
||||
}
|
||||
|
||||
calculatedHash := ComputeDoubleSHA256([]byte(identity.Recipient().String()))
|
||||
t.Logf("Test mnemonic hash: %s", calculatedHash)
|
||||
|
||||
if calculatedHash == correctHash {
|
||||
t.Logf("✓ Test mnemonic produces the correct hash")
|
||||
} else {
|
||||
t.Errorf("✗ Test mnemonic does not produce the correct hash")
|
||||
}
|
||||
|
||||
if calculatedHash == wrongHash {
|
||||
t.Logf("✗ Test mnemonic unexpectedly produces the wrong hash")
|
||||
}
|
||||
|
||||
// Let's try some other possibilities - maybe there's a string normalization issue?
|
||||
variations := []string{
|
||||
"abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about",
|
||||
" abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about ",
|
||||
"abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about\n",
|
||||
strings.TrimSpace("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"),
|
||||
}
|
||||
|
||||
for i, variation := range variations {
|
||||
identity, err := agehd.DeriveIdentity(variation, 0)
|
||||
if err != nil {
|
||||
t.Logf("Variation %d failed: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
hash := ComputeDoubleSHA256([]byte(identity.Recipient().String()))
|
||||
t.Logf("Variation %d hash: %s", i, hash)
|
||||
|
||||
if hash == wrongHash {
|
||||
t.Logf("✗ Found variation that produces wrong hash: '%s'", variation)
|
||||
}
|
||||
}
|
||||
|
||||
// Maybe let's try an empty mnemonic or something else?
|
||||
emptyMnemonics := []string{
|
||||
"",
|
||||
" ",
|
||||
}
|
||||
|
||||
for i, emptyMnemonic := range emptyMnemonics {
|
||||
identity, err := agehd.DeriveIdentity(emptyMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Logf("Empty mnemonic %d failed (expected): %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
hash := ComputeDoubleSHA256([]byte(identity.Recipient().String()))
|
||||
t.Logf("Empty mnemonic %d hash: %s", i, hash)
|
||||
|
||||
if hash == wrongHash {
|
||||
t.Logf("✗ Empty mnemonic produces wrong hash!")
|
||||
}
|
||||
}
|
||||
}
|
||||
473
internal/vault/secrets.go
Normal file
473
internal/vault/secrets.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// ListSecrets returns a list of secret names in this vault
|
||||
func (v *Vault) ListSecrets() ([]string, error) {
|
||||
secret.DebugWith("Listing secrets in vault", slog.String("vault_name", v.Name))
|
||||
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get vault directory for secret listing", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secretsDir := filepath.Join(vaultDir, "secrets.d")
|
||||
|
||||
// Check if secrets directory exists
|
||||
exists, err := afero.DirExists(v.fs, secretsDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to check secrets directory", "error", err, "secrets_dir", secretsDir)
|
||||
|
||||
return nil, fmt.Errorf("failed to check if secrets directory exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
secret.Debug("Secrets directory does not exist", "secrets_dir", secretsDir, "vault_name", v.Name)
|
||||
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// List directories in secrets.d
|
||||
files, err := afero.ReadDir(v.fs, secretsDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to read secrets directory", "error", err, "secrets_dir", secretsDir)
|
||||
|
||||
return nil, fmt.Errorf("failed to read secrets directory: %w", err)
|
||||
}
|
||||
|
||||
var secrets []string
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
// Convert storage name back to secret name
|
||||
secretName := strings.ReplaceAll(file.Name(), "%", "/")
|
||||
secrets = append(secrets, secretName)
|
||||
}
|
||||
}
|
||||
|
||||
secret.DebugWith("Found secrets in vault",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.Int("secret_count", len(secrets)),
|
||||
slog.Any("secret_names", secrets),
|
||||
)
|
||||
|
||||
return secrets, nil
|
||||
}
|
||||
|
||||
// isValidSecretName validates secret names according to the format [a-z0-9\.\-\_\/]+
|
||||
// but with additional restrictions:
|
||||
// - No leading or trailing slashes
|
||||
// - No double slashes
|
||||
// - No names starting with dots
|
||||
func isValidSecretName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for leading/trailing slashes
|
||||
if strings.HasPrefix(name, "/") || strings.HasSuffix(name, "/") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for double slashes
|
||||
if strings.Contains(name, "//") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for names starting with dot
|
||||
if strings.HasPrefix(name, ".") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check the basic pattern
|
||||
matched, _ := regexp.MatchString(`^[a-z0-9\.\-\_\/]+$`, name)
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
// AddSecret adds a secret to this vault
|
||||
func (v *Vault) AddSecret(name string, value *memguard.LockedBuffer, force bool) error {
|
||||
if value == nil {
|
||||
return fmt.Errorf("value buffer is nil")
|
||||
}
|
||||
|
||||
secret.DebugWith("Adding secret to vault",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("secret_name", name),
|
||||
slog.Int("value_length", value.Size()),
|
||||
slog.Bool("force", force),
|
||||
)
|
||||
|
||||
// Validate secret name
|
||||
if !isValidSecretName(name) {
|
||||
secret.Debug("Invalid secret name provided", "secret_name", name)
|
||||
|
||||
return fmt.Errorf("invalid secret name '%s': must match pattern [a-z0-9.\\-_/]+", name)
|
||||
}
|
||||
secret.Debug("Secret name validation passed", "secret_name", name)
|
||||
|
||||
secret.Debug("Getting vault directory")
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get vault directory for secret addition", "error", err, "vault_name", v.Name)
|
||||
|
||||
return err
|
||||
}
|
||||
secret.Debug("Got vault directory", "vault_dir", vaultDir)
|
||||
|
||||
// Convert slashes to percent signs for storage
|
||||
storageName := strings.ReplaceAll(name, "/", "%")
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", storageName)
|
||||
|
||||
secret.DebugWith("Secret storage details",
|
||||
slog.String("storage_name", storageName),
|
||||
slog.String("secret_dir", secretDir),
|
||||
)
|
||||
|
||||
// Check if secret already exists
|
||||
secret.Debug("Checking if secret already exists", "secret_dir", secretDir)
|
||||
exists, err := afero.DirExists(v.fs, secretDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to check if secret exists", "error", err, "secret_dir", secretDir)
|
||||
|
||||
return fmt.Errorf("failed to check if secret exists: %w", err)
|
||||
}
|
||||
secret.Debug("Secret existence check complete", "exists", exists)
|
||||
|
||||
// Handle existing secret case
|
||||
now := time.Now()
|
||||
var previousVersion *secret.Version
|
||||
|
||||
if exists {
|
||||
if !force {
|
||||
secret.Debug("Secret already exists and force not specified", "secret_name", name, "secret_dir", secretDir)
|
||||
|
||||
return fmt.Errorf("secret %s already exists (use --force to overwrite)", name)
|
||||
}
|
||||
|
||||
// Get the current version to update its notAfter timestamp
|
||||
currentVersionName, err := secret.GetCurrentVersion(v.fs, secretDir)
|
||||
if err == nil && currentVersionName != "" {
|
||||
previousVersion = secret.NewVersion(v, name, currentVersionName)
|
||||
// We'll need to load and update its metadata after we unlock the vault
|
||||
}
|
||||
} else {
|
||||
// Create secret directory for new secret
|
||||
secret.Debug("Creating secret directory", "secret_dir", secretDir)
|
||||
if err := v.fs.MkdirAll(secretDir, secret.DirPerms); err != nil {
|
||||
secret.Debug("Failed to create secret directory", "error", err, "secret_dir", secretDir)
|
||||
|
||||
return fmt.Errorf("failed to create secret directory: %w", err)
|
||||
}
|
||||
secret.Debug("Created secret directory successfully")
|
||||
}
|
||||
|
||||
// Generate new version name
|
||||
versionName, err := secret.GenerateVersionName(v.fs, secretDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to generate version name", "error", err, "secret_name", name)
|
||||
|
||||
return fmt.Errorf("failed to generate version name: %w", err)
|
||||
}
|
||||
|
||||
secret.Debug("Generated new version name", "version", versionName, "secret_name", name)
|
||||
|
||||
// Create new version
|
||||
newVersion := secret.NewVersion(v, name, versionName)
|
||||
|
||||
// Set version timestamps
|
||||
if previousVersion == nil {
|
||||
// First version: notBefore = epoch + 1 second
|
||||
epochPlusOne := time.Unix(1, 0)
|
||||
newVersion.Metadata.NotBefore = &epochPlusOne
|
||||
} else {
|
||||
// New version: notBefore = now
|
||||
newVersion.Metadata.NotBefore = &now
|
||||
|
||||
// We'll update the previous version's notAfter after we save the new version
|
||||
}
|
||||
|
||||
// Save the new version - pass the LockedBuffer directly
|
||||
if err := newVersion.Save(value); err != nil {
|
||||
secret.Debug("Failed to save new version", "error", err, "version", versionName)
|
||||
|
||||
return fmt.Errorf("failed to save version: %w", err)
|
||||
}
|
||||
|
||||
// Update previous version if it exists
|
||||
if previousVersion != nil {
|
||||
// Get long-term key to decrypt/encrypt metadata
|
||||
ltIdentity, err := v.GetOrDeriveLongTermKey()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get long-term key for metadata update", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to get long-term key: %w", err)
|
||||
}
|
||||
|
||||
// Load previous version metadata
|
||||
if err := previousVersion.LoadMetadata(ltIdentity); err != nil {
|
||||
secret.Debug("Failed to load previous version metadata", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to load previous version metadata: %w", err)
|
||||
}
|
||||
|
||||
// Update notAfter timestamp
|
||||
previousVersion.Metadata.NotAfter = &now
|
||||
|
||||
// Re-save the metadata (we need to implement an update method)
|
||||
if err := updateVersionMetadata(v.fs, previousVersion, ltIdentity); err != nil {
|
||||
secret.Debug("Failed to update previous version metadata", "error", err)
|
||||
|
||||
return fmt.Errorf("failed to update previous version metadata: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set current symlink to new version
|
||||
if err := secret.SetCurrentVersion(v.fs, secretDir, versionName); err != nil {
|
||||
secret.Debug("Failed to set current version", "error", err, "version", versionName)
|
||||
|
||||
return fmt.Errorf("failed to set current version: %w", err)
|
||||
}
|
||||
|
||||
secret.Debug("Successfully added secret version to vault",
|
||||
"secret_name", name, "version", versionName,
|
||||
"vault_name", v.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVersionMetadata updates the metadata of an existing version
|
||||
func updateVersionMetadata(fs afero.Fs, version *secret.Version, ltIdentity *age.X25519Identity) error {
|
||||
// Read the version's encrypted private key
|
||||
encryptedPrivKeyPath := filepath.Join(version.Directory, "priv.age")
|
||||
encryptedPrivKey, err := afero.ReadFile(fs, encryptedPrivKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read encrypted version private key: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt version private key using long-term key
|
||||
versionPrivKeyData, err := secret.DecryptWithIdentity(encryptedPrivKey, ltIdentity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt version private key: %w", err)
|
||||
}
|
||||
|
||||
// Parse version private key
|
||||
versionIdentity, err := age.ParseX25519Identity(string(versionPrivKeyData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse version private key: %w", err)
|
||||
}
|
||||
|
||||
// Marshal updated metadata
|
||||
metadataBytes, err := json.MarshalIndent(version.Metadata, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal version metadata: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt metadata to the version's public key
|
||||
metadataBuffer := memguard.NewBufferFromBytes(metadataBytes)
|
||||
defer metadataBuffer.Destroy()
|
||||
|
||||
encryptedMetadata, err := secret.EncryptToRecipient(metadataBuffer, versionIdentity.Recipient())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt version metadata: %w", err)
|
||||
}
|
||||
|
||||
// Write encrypted metadata
|
||||
metadataPath := filepath.Join(version.Directory, "metadata.age")
|
||||
if err := afero.WriteFile(fs, metadataPath, encryptedMetadata, secret.FilePerms); err != nil {
|
||||
return fmt.Errorf("failed to write encrypted version metadata: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSecret retrieves a secret from this vault
|
||||
func (v *Vault) GetSecret(name string) ([]byte, error) {
|
||||
secret.DebugWith("Getting secret from vault",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("secret_name", name),
|
||||
)
|
||||
|
||||
return v.GetSecretVersion(name, "")
|
||||
}
|
||||
|
||||
// GetSecretVersion retrieves a specific version of a secret (empty version means current)
|
||||
func (v *Vault) GetSecretVersion(name string, version string) ([]byte, error) {
|
||||
secret.DebugWith("Getting secret version from vault",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("secret_name", name),
|
||||
slog.String("version", version),
|
||||
)
|
||||
|
||||
// Get vault directory
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get vault directory", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert slashes to percent signs for storage
|
||||
storageName := strings.ReplaceAll(name, "/", "%")
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", storageName)
|
||||
|
||||
// Check if secret exists
|
||||
exists, err := afero.DirExists(v.fs, secretDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to check if secret exists", "error", err, "secret_name", name)
|
||||
|
||||
return nil, fmt.Errorf("failed to check if secret exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
secret.Debug("Secret not found in vault", "secret_name", name, "vault_name", v.Name)
|
||||
|
||||
return nil, fmt.Errorf("secret %s not found", name)
|
||||
}
|
||||
|
||||
// Determine which version to get
|
||||
if version == "" {
|
||||
// Get current version
|
||||
currentVersion, err := secret.GetCurrentVersion(v.fs, secretDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get current version", "error", err, "secret_name", name)
|
||||
|
||||
return nil, fmt.Errorf("failed to get current version: %w", err)
|
||||
}
|
||||
version = currentVersion
|
||||
secret.Debug("Using current version", "version", version, "secret_name", name)
|
||||
}
|
||||
|
||||
// Create version object
|
||||
secretVersion := secret.NewVersion(v, name, version)
|
||||
|
||||
// Check if version exists
|
||||
versionPath := filepath.Join(secretDir, "versions", version)
|
||||
exists, err = afero.DirExists(v.fs, versionPath)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to check if version exists", "error", err, "version", version)
|
||||
|
||||
return nil, fmt.Errorf("failed to check if version exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
secret.Debug("Version not found", "version", version, "secret_name", name)
|
||||
|
||||
return nil, fmt.Errorf("version %s not found for secret %s", version, name)
|
||||
}
|
||||
|
||||
secret.Debug("Version exists, proceeding with vault unlock and decryption", "version", version, "secret_name", name)
|
||||
|
||||
// Unlock the vault (get long-term key in memory)
|
||||
longTermIdentity, err := v.UnlockVault()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to unlock vault", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to unlock vault: %w", err)
|
||||
}
|
||||
|
||||
secret.DebugWith("Successfully unlocked vault",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("secret_name", name),
|
||||
slog.String("version", version),
|
||||
slog.String("long_term_public_key", longTermIdentity.Recipient().String()),
|
||||
)
|
||||
|
||||
// Get the version's value
|
||||
secret.Debug("About to call secretVersion.GetValue", "version", version, "secret_name", name)
|
||||
decryptedValue, err := secretVersion.GetValue(longTermIdentity)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to decrypt version value", "error", err, "version", version, "secret_name", name)
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt version: %w", err)
|
||||
}
|
||||
|
||||
secret.DebugWith("Successfully decrypted secret version",
|
||||
slog.String("secret_name", name),
|
||||
slog.String("version", version),
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.Int("decrypted_length", len(decryptedValue)),
|
||||
)
|
||||
|
||||
// Debug: Log metadata about the decrypted value without exposing the actual secret
|
||||
secret.Debug("Vault secret decryption debug info",
|
||||
"secret_name", name,
|
||||
"version", version,
|
||||
"decrypted_value_length", len(decryptedValue),
|
||||
"is_empty", len(decryptedValue) == 0)
|
||||
|
||||
return decryptedValue, nil
|
||||
}
|
||||
|
||||
// UnlockVault unlocks the vault and returns the long-term private key
|
||||
func (v *Vault) UnlockVault() (*age.X25519Identity, error) {
|
||||
secret.Debug("Unlocking vault", "vault_name", v.Name)
|
||||
|
||||
// If vault is already unlocked, return the cached key
|
||||
if !v.Locked() {
|
||||
secret.Debug("Vault already unlocked, returning cached long-term key", "vault_name", v.Name)
|
||||
|
||||
return v.longTermKey, nil
|
||||
}
|
||||
|
||||
// Get or derive the long-term key (but don't store it yet)
|
||||
longTermIdentity, err := v.GetOrDeriveLongTermKey()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get or derive long-term key", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to get long-term key: %w", err)
|
||||
}
|
||||
|
||||
// Now unlock the vault by storing the key in memory
|
||||
v.Unlock(longTermIdentity)
|
||||
|
||||
secret.DebugWith("Successfully unlocked vault",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("public_key", longTermIdentity.Recipient().String()),
|
||||
)
|
||||
|
||||
return longTermIdentity, nil
|
||||
}
|
||||
|
||||
// GetSecretObject retrieves a Secret object with metadata loaded from this vault
|
||||
func (v *Vault) GetSecretObject(name string) (*secret.Secret, error) {
|
||||
// First check if the secret exists by checking for the metadata file
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert slashes to percent signs for storage
|
||||
storageName := strings.ReplaceAll(name, "/", "%")
|
||||
secretDir := filepath.Join(vaultDir, "secrets.d", storageName)
|
||||
|
||||
// Check if secret directory exists
|
||||
exists, err := afero.DirExists(v.fs, secretDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if secret exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("secret %s not found", name)
|
||||
}
|
||||
|
||||
// Create a Secret object
|
||||
secretObj := secret.NewSecret(v, name)
|
||||
|
||||
// Load the metadata from disk
|
||||
if err := secretObj.LoadMetadata(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return secretObj, nil
|
||||
}
|
||||
310
internal/vault/secrets_version_test.go
Normal file
310
internal/vault/secrets_version_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
// Vault-Level Version Operation Tests
|
||||
//
|
||||
// Integration tests for vault-level version operations:
|
||||
//
|
||||
// - TestVaultAddSecretCreatesVersion: Tests that AddSecret creates proper version structure
|
||||
// - TestVaultAddSecretMultipleVersions: Tests creating multiple versions with force flag
|
||||
// - TestVaultGetSecretVersion: Tests retrieving specific versions and current version
|
||||
// - TestVaultVersionTimestamps: Tests timestamp logic (notBefore/notAfter) across versions
|
||||
// - TestVaultGetNonExistentVersion: Tests error handling for invalid versions
|
||||
// - TestUpdateVersionMetadata: Tests metadata update functionality
|
||||
//
|
||||
// Version Management:
|
||||
// - List versions in reverse chronological order
|
||||
// - Promote any version to current
|
||||
// - Promotion doesn't modify timestamps
|
||||
// - Metadata remains encrypted and intact
|
||||
|
||||
package vault
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper function to add a secret to vault with proper buffer protection
|
||||
func addTestSecretToVault(t *testing.T, vault *Vault, name string, value []byte, force bool) {
|
||||
t.Helper()
|
||||
buffer := memguard.NewBufferFromBytes(value)
|
||||
defer buffer.Destroy()
|
||||
err := vault.AddSecret(name, buffer, force)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Helper function to create a vault with long-term key set up
|
||||
func createTestVaultWithKey(t *testing.T, fs afero.Fs, stateDir, vaultName string) *Vault {
|
||||
// Set mnemonic for testing
|
||||
t.Setenv(secret.EnvMnemonic, "abandon abandon abandon abandon abandon abandon abandon abandon abandon about")
|
||||
|
||||
// Create vault
|
||||
vault, err := CreateVault(fs, stateDir, vaultName)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Derive and store long-term key from mnemonic
|
||||
mnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
ltIdentity, err := agehd.DeriveIdentity(mnemonic, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store long-term public key in vault
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
ltPubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, ltPubKeyPath, []byte(ltIdentity.Recipient().String()), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unlock the vault with the derived key
|
||||
vault.Unlock(ltIdentity)
|
||||
|
||||
return vault
|
||||
}
|
||||
|
||||
func TestVaultAddSecretCreatesVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Create vault with long-term key
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
|
||||
// Add a secret
|
||||
secretName := "test/secret"
|
||||
secretValue := []byte("initial-value")
|
||||
expectedValue := make([]byte, len(secretValue))
|
||||
copy(expectedValue, secretValue)
|
||||
|
||||
addTestSecretToVault(t, vault, secretName, secretValue, false)
|
||||
|
||||
// Check that version directory was created
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
secretDir := vaultDir + "/secrets.d/test%secret"
|
||||
versionsDir := secretDir + "/versions"
|
||||
|
||||
// Should have one version
|
||||
entries, err := afero.ReadDir(fs, versionsDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, entries, 1)
|
||||
|
||||
// Should have current symlink
|
||||
currentPath := secretDir + "/current"
|
||||
exists, err := afero.Exists(fs, currentPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Get the secret value
|
||||
retrievedValue, err := vault.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, retrievedValue)
|
||||
}
|
||||
|
||||
func TestVaultAddSecretMultipleVersions(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Create vault with long-term key
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
|
||||
secretName := "test/secret"
|
||||
|
||||
// Add first version
|
||||
addTestSecretToVault(t, vault, secretName, []byte("version-1"), false)
|
||||
|
||||
// Try to add again without force - should fail
|
||||
failBuffer := memguard.NewBufferFromBytes([]byte("version-2"))
|
||||
defer failBuffer.Destroy()
|
||||
err := vault.AddSecret(secretName, failBuffer, false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
|
||||
// Add with force - should create new version
|
||||
addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
|
||||
|
||||
// Check that we have two versions
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
versionsDir := vaultDir + "/secrets.d/test%secret/versions"
|
||||
entries, err := afero.ReadDir(fs, versionsDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, entries, 2)
|
||||
|
||||
// Current value should be version-2
|
||||
value, err := vault.GetSecret(secretName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-2"), value)
|
||||
}
|
||||
|
||||
func TestVaultGetSecretVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Create vault with long-term key
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
|
||||
secretName := "test/secret"
|
||||
|
||||
// Add multiple versions
|
||||
addTestSecretToVault(t, vault, secretName, []byte("version-1"), false)
|
||||
|
||||
// Small delay to ensure different version names
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
|
||||
|
||||
// Get versions list
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
secretDir := vaultDir + "/secrets.d/test%secret"
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 2)
|
||||
|
||||
// Get specific version (first one)
|
||||
firstVersion := versions[1] // Last in list is first created
|
||||
value, err := vault.GetSecretVersion(secretName, firstVersion)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-1"), value)
|
||||
|
||||
// Get specific version (second one)
|
||||
secondVersion := versions[0] // First in list is most recent
|
||||
value, err = vault.GetSecretVersion(secretName, secondVersion)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-2"), value)
|
||||
|
||||
// Get current (empty version)
|
||||
value, err = vault.GetSecretVersion(secretName, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("version-2"), value)
|
||||
}
|
||||
|
||||
func TestVaultVersionTimestamps(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Create vault with long-term key
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
|
||||
// Get long-term key
|
||||
ltIdentity, err := vault.GetOrDeriveLongTermKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
secretName := "test/secret"
|
||||
|
||||
// Add first version
|
||||
beforeFirst := time.Now()
|
||||
v1Buffer := memguard.NewBufferFromBytes([]byte("version-1"))
|
||||
defer v1Buffer.Destroy()
|
||||
err = vault.AddSecret(secretName, v1Buffer, false)
|
||||
require.NoError(t, err)
|
||||
afterFirst := time.Now()
|
||||
|
||||
// Get first version metadata
|
||||
vaultDir, _ := vault.GetDirectory()
|
||||
secretDir := vaultDir + "/secrets.d/test%secret"
|
||||
versions, err := secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 1)
|
||||
|
||||
firstVersion := secret.NewVersion(vault, secretName, versions[0])
|
||||
err = firstVersion.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check first version timestamps
|
||||
assert.NotNil(t, firstVersion.Metadata.CreatedAt)
|
||||
assert.True(t, firstVersion.Metadata.CreatedAt.After(beforeFirst.Add(-time.Second)))
|
||||
assert.True(t, firstVersion.Metadata.CreatedAt.Before(afterFirst.Add(time.Second)))
|
||||
|
||||
assert.NotNil(t, firstVersion.Metadata.NotBefore)
|
||||
assert.Equal(t, int64(1), firstVersion.Metadata.NotBefore.Unix()) // Epoch + 1
|
||||
assert.Nil(t, firstVersion.Metadata.NotAfter) // Still current
|
||||
|
||||
// Add second version
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
beforeSecond := time.Now()
|
||||
addTestSecretToVault(t, vault, secretName, []byte("version-2"), true)
|
||||
afterSecond := time.Now()
|
||||
|
||||
// Get updated versions
|
||||
versions, err = secret.ListVersions(fs, secretDir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 2)
|
||||
|
||||
// Reload first version metadata (should have notAfter now)
|
||||
firstVersion = secret.NewVersion(vault, secretName, versions[1])
|
||||
err = firstVersion.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, firstVersion.Metadata.NotAfter)
|
||||
assert.True(t, firstVersion.Metadata.NotAfter.After(beforeSecond.Add(-time.Second)))
|
||||
assert.True(t, firstVersion.Metadata.NotAfter.Before(afterSecond.Add(time.Second)))
|
||||
|
||||
// Check second version timestamps
|
||||
secondVersion := secret.NewVersion(vault, secretName, versions[0])
|
||||
err = secondVersion.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, secondVersion.Metadata.NotBefore)
|
||||
assert.True(t, secondVersion.Metadata.NotBefore.After(beforeSecond.Add(-time.Second)))
|
||||
assert.True(t, secondVersion.Metadata.NotBefore.Before(afterSecond.Add(time.Second)))
|
||||
assert.Nil(t, secondVersion.Metadata.NotAfter) // Current version
|
||||
}
|
||||
|
||||
func TestVaultGetNonExistentVersion(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Create vault with long-term key
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
|
||||
// Add a secret
|
||||
addTestSecretToVault(t, vault, "test/secret", []byte("value"), false)
|
||||
|
||||
// Try to get non-existent version
|
||||
_, err := vault.GetSecretVersion("test/secret", "20991231.999")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestUpdateVersionMetadata(t *testing.T) {
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Create vault with long-term key
|
||||
vault := createTestVaultWithKey(t, fs, stateDir, "test")
|
||||
|
||||
// Get long-term key
|
||||
ltIdentity, err := vault.GetOrDeriveLongTermKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a version manually to test updateVersionMetadata
|
||||
secretName := "test/secret"
|
||||
versionName := "20231215.001"
|
||||
version := secret.NewVersion(vault, secretName, versionName)
|
||||
|
||||
// Set initial metadata
|
||||
now := time.Now()
|
||||
epochPlusOne := time.Unix(1, 0)
|
||||
version.Metadata.NotBefore = &epochPlusOne
|
||||
version.Metadata.NotAfter = nil
|
||||
|
||||
// Save version
|
||||
testBuffer := memguard.NewBufferFromBytes([]byte("test-value"))
|
||||
defer testBuffer.Destroy()
|
||||
err = version.Save(testBuffer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update metadata
|
||||
version.Metadata.NotAfter = &now
|
||||
err = updateVersionMetadata(fs, version, ltIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load and verify
|
||||
version2 := secret.NewVersion(vault, secretName, versionName)
|
||||
err = version2.LoadMetadata(ltIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, version2.Metadata.NotAfter)
|
||||
assert.Equal(t, now.Unix(), version2.Metadata.NotAfter.Unix())
|
||||
}
|
||||
407
internal/vault/unlockers.go
Normal file
407
internal/vault/unlockers.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// GetCurrentUnlocker returns the current unlocker for this vault
|
||||
func (v *Vault) GetCurrentUnlocker() (secret.Unlocker, error) {
|
||||
secret.DebugWith("Getting current unlocker", slog.String("vault_name", v.Name))
|
||||
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get vault directory for unlocker", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentUnlockerPath := filepath.Join(vaultDir, "current-unlocker")
|
||||
|
||||
// Check if the symlink exists
|
||||
_, err = v.fs.Stat(currentUnlockerPath)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to stat current unlocker symlink", "error", err, "path", currentUnlockerPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read current unlocker: %w", err)
|
||||
}
|
||||
|
||||
// Resolve the symlink to get the target directory
|
||||
unlockerDir, err := v.resolveUnlockerDirectory(currentUnlockerPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secret.DebugWith("Resolved unlocker directory",
|
||||
slog.String("unlocker_dir", unlockerDir),
|
||||
slog.String("vault_name", v.Name),
|
||||
)
|
||||
|
||||
// Read unlocker metadata
|
||||
metadataPath := filepath.Join(unlockerDir, "unlocker-metadata.json")
|
||||
secret.Debug("Reading unlocker metadata", "path", metadataPath)
|
||||
|
||||
metadataBytes, err := afero.ReadFile(v.fs, metadataPath)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to read unlocker metadata", "error", err, "path", metadataPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read unlocker metadata: %w", err)
|
||||
}
|
||||
|
||||
var metadata UnlockerMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
secret.Debug("Failed to parse unlocker metadata", "error", err, "path", metadataPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to parse unlocker metadata: %w", err)
|
||||
}
|
||||
|
||||
secret.DebugWith("Parsed unlocker metadata",
|
||||
slog.String("unlocker_type", metadata.Type),
|
||||
slog.Time("created_at", metadata.CreatedAt),
|
||||
slog.Any("flags", metadata.Flags),
|
||||
)
|
||||
|
||||
// Create unlocker instance using direct constructors with filesystem
|
||||
var unlocker secret.Unlocker
|
||||
// Use metadata directly as it's already the correct type
|
||||
switch metadata.Type {
|
||||
case "passphrase":
|
||||
secret.Debug("Creating passphrase unlocker instance", "unlocker_type", metadata.Type)
|
||||
unlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDir, metadata)
|
||||
case "pgp":
|
||||
secret.Debug("Creating PGP unlocker instance", "unlocker_type", metadata.Type)
|
||||
unlocker = secret.NewPGPUnlocker(v.fs, unlockerDir, metadata)
|
||||
case "keychain":
|
||||
secret.Debug("Creating keychain unlocker instance", "unlocker_type", metadata.Type)
|
||||
unlocker = secret.NewKeychainUnlocker(v.fs, unlockerDir, metadata)
|
||||
default:
|
||||
secret.Debug("Unsupported unlocker type", "type", metadata.Type)
|
||||
|
||||
return nil, fmt.Errorf("unsupported unlocker type: %s", metadata.Type)
|
||||
}
|
||||
|
||||
secret.DebugWith("Successfully created unlocker instance",
|
||||
slog.String("unlocker_type", unlocker.GetType()),
|
||||
slog.String("unlocker_id", unlocker.GetID()),
|
||||
slog.String("vault_name", v.Name),
|
||||
)
|
||||
|
||||
return unlocker, nil
|
||||
}
|
||||
|
||||
// resolveUnlockerDirectory resolves the unlocker directory from a symlink or file
|
||||
func (v *Vault) resolveUnlockerDirectory(currentUnlockerPath string) (string, error) {
|
||||
linkReader, ok := v.fs.(afero.LinkReader)
|
||||
if !ok {
|
||||
// Fallback for filesystems that don't support symlinks
|
||||
return v.readUnlockerPathFromFile(currentUnlockerPath)
|
||||
}
|
||||
|
||||
secret.Debug("Resolving unlocker symlink using afero")
|
||||
// Try to read as symlink first
|
||||
unlockerDir, err := linkReader.ReadlinkIfPossible(currentUnlockerPath)
|
||||
if err == nil {
|
||||
return unlockerDir, nil
|
||||
}
|
||||
|
||||
secret.Debug("Failed to read symlink, falling back to file contents",
|
||||
"error", err, "symlink_path", currentUnlockerPath)
|
||||
// Fallback: read the path from file contents
|
||||
return v.readUnlockerPathFromFile(currentUnlockerPath)
|
||||
}
|
||||
|
||||
// readUnlockerPathFromFile reads the unlocker directory path from a file
|
||||
func (v *Vault) readUnlockerPathFromFile(path string) (string, error) {
|
||||
secret.Debug("Reading unlocker path from file", "path", path)
|
||||
unlockerDirBytes, err := afero.ReadFile(v.fs, path)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to read unlocker path file", "error", err, "path", path)
|
||||
|
||||
return "", fmt.Errorf("failed to read current unlocker: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(unlockerDirBytes)), nil
|
||||
}
|
||||
|
||||
// findUnlockerByID finds an unlocker by its ID and returns the unlocker instance and its directory path
|
||||
func (v *Vault) findUnlockerByID(unlockersDir, unlockerID string) (secret.Unlocker, string, error) {
|
||||
files, err := afero.ReadDir(v.fs, unlockersDir)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read unlockers directory: %w", err)
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if !file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read metadata file
|
||||
metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json")
|
||||
exists, err := afero.Exists(v.fs, metadataPath)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err)
|
||||
}
|
||||
if !exists {
|
||||
// Skip directories without metadata - they might not be unlockers
|
||||
continue
|
||||
}
|
||||
|
||||
metadataBytes, err := afero.ReadFile(v.fs, metadataPath)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err)
|
||||
}
|
||||
|
||||
var metadata UnlockerMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err)
|
||||
}
|
||||
|
||||
unlockerDirPath := filepath.Join(unlockersDir, file.Name())
|
||||
|
||||
// Create the appropriate unlocker instance
|
||||
var tempUnlocker secret.Unlocker
|
||||
switch metadata.Type {
|
||||
case "passphrase":
|
||||
tempUnlocker = secret.NewPassphraseUnlocker(v.fs, unlockerDirPath, metadata)
|
||||
case "pgp":
|
||||
tempUnlocker = secret.NewPGPUnlocker(v.fs, unlockerDirPath, metadata)
|
||||
case "keychain":
|
||||
tempUnlocker = secret.NewKeychainUnlocker(v.fs, unlockerDirPath, metadata)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this unlocker's ID matches
|
||||
if tempUnlocker.GetID() == unlockerID {
|
||||
return tempUnlocker, unlockerDirPath, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
// ListUnlockers returns a list of available unlockers for this vault
|
||||
func (v *Vault) ListUnlockers() ([]UnlockerMetadata, error) {
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
unlockersDir := filepath.Join(vaultDir, "unlockers.d")
|
||||
|
||||
// Check if unlockers directory exists
|
||||
exists, err := afero.DirExists(v.fs, unlockersDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if unlockers directory exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return []UnlockerMetadata{}, nil
|
||||
}
|
||||
|
||||
// List directories in unlockers.d
|
||||
files, err := afero.ReadDir(v.fs, unlockersDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read unlockers directory: %w", err)
|
||||
}
|
||||
|
||||
var unlockers []UnlockerMetadata
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
// Read metadata file
|
||||
metadataPath := filepath.Join(unlockersDir, file.Name(), "unlocker-metadata.json")
|
||||
exists, err := afero.Exists(v.fs, metadataPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if metadata exists for unlocker %s: %w", file.Name(), err)
|
||||
}
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unlocker directory %s is missing metadata file", file.Name())
|
||||
}
|
||||
|
||||
metadataBytes, err := afero.ReadFile(v.fs, metadataPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read metadata for unlocker %s: %w", file.Name(), err)
|
||||
}
|
||||
|
||||
var metadata UnlockerMetadata
|
||||
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse metadata for unlocker %s: %w", file.Name(), err)
|
||||
}
|
||||
|
||||
unlockers = append(unlockers, metadata)
|
||||
}
|
||||
}
|
||||
|
||||
return unlockers, nil
|
||||
}
|
||||
|
||||
// RemoveUnlocker removes an unlocker from this vault
|
||||
func (v *Vault) RemoveUnlocker(unlockerID string) error {
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Find the unlocker directory and create the unlocker instance
|
||||
unlockersDir := filepath.Join(vaultDir, "unlockers.d")
|
||||
|
||||
// Find the unlocker by ID
|
||||
unlocker, _, err := v.findUnlockerByID(unlockersDir, unlockerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if unlocker == nil {
|
||||
return fmt.Errorf("unlocker with ID %s not found", unlockerID)
|
||||
}
|
||||
|
||||
// Use the unlocker's Remove method
|
||||
return unlocker.Remove()
|
||||
}
|
||||
|
||||
// SelectUnlocker selects an unlocker as current for this vault
|
||||
func (v *Vault) SelectUnlocker(unlockerID string) error {
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Find the unlocker directory by ID
|
||||
unlockersDir := filepath.Join(vaultDir, "unlockers.d")
|
||||
|
||||
// Find the unlocker by ID
|
||||
_, targetUnlockerDir, err := v.findUnlockerByID(unlockersDir, unlockerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if targetUnlockerDir == "" {
|
||||
return fmt.Errorf("unlocker with ID %s not found", unlockerID)
|
||||
}
|
||||
|
||||
// Create/update current unlocker symlink
|
||||
currentUnlockerPath := filepath.Join(vaultDir, "current-unlocker")
|
||||
|
||||
// Remove existing symlink if it exists
|
||||
if exists, err := afero.Exists(v.fs, currentUnlockerPath); err != nil {
|
||||
return fmt.Errorf("failed to check if current unlocker symlink exists: %w", err)
|
||||
} else if exists {
|
||||
if err := v.fs.Remove(currentUnlockerPath); err != nil {
|
||||
return fmt.Errorf("failed to remove existing unlocker symlink: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new symlink using afero's SymlinkIfPossible
|
||||
if linker, ok := v.fs.(afero.Linker); ok {
|
||||
secret.Debug("Creating unlocker symlink", "target", targetUnlockerDir, "link", currentUnlockerPath)
|
||||
if err := linker.SymlinkIfPossible(targetUnlockerDir, currentUnlockerPath); err != nil {
|
||||
return fmt.Errorf("failed to create unlocker symlink: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Fallback: create a regular file with the target path for filesystems that don't support symlinks
|
||||
secret.Debug("Fallback: creating regular file with target path", "target", targetUnlockerDir)
|
||||
if err := afero.WriteFile(v.fs, currentUnlockerPath, []byte(targetUnlockerDir), secret.FilePerms); err != nil {
|
||||
return fmt.Errorf("failed to create unlocker symlink file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreatePassphraseUnlocker creates a new passphrase-protected unlocker
|
||||
// The passphrase must be provided as a LockedBuffer for security
|
||||
func (v *Vault) CreatePassphraseUnlocker(passphrase *memguard.LockedBuffer) (*secret.PassphraseUnlocker, error) {
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get vault directory: %w", err)
|
||||
}
|
||||
|
||||
// Create unlocker directory
|
||||
unlockerDir := filepath.Join(vaultDir, "unlockers.d", "passphrase")
|
||||
if err := v.fs.MkdirAll(unlockerDir, secret.DirPerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to create unlocker directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate new age keypair for unlocker
|
||||
unlockerIdentity, err := age.GenerateX25519Identity()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate unlocker: %w", err)
|
||||
}
|
||||
|
||||
// Write public key
|
||||
pubKeyPath := filepath.Join(unlockerDir, "pub.age")
|
||||
if err := afero.WriteFile(v.fs, pubKeyPath,
|
||||
[]byte(unlockerIdentity.Recipient().String()),
|
||||
secret.FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write unlocker public key: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt private key with passphrase
|
||||
privKeyStr := unlockerIdentity.String()
|
||||
encryptedPrivKey, err := secret.EncryptWithPassphrase([]byte(privKeyStr), passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt unlocker private key: %w", err)
|
||||
}
|
||||
|
||||
// Write encrypted private key
|
||||
privKeyPath := filepath.Join(unlockerDir, "priv.age")
|
||||
if err := afero.WriteFile(v.fs, privKeyPath, encryptedPrivKey, secret.FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write encrypted unlocker private key: %w", err)
|
||||
}
|
||||
|
||||
// Create metadata
|
||||
metadata := UnlockerMetadata{
|
||||
Type: "passphrase",
|
||||
CreatedAt: time.Now(),
|
||||
Flags: []string{},
|
||||
}
|
||||
|
||||
// Write metadata
|
||||
metadataBytes, err := json.MarshalIndent(metadata, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
metadataPath := filepath.Join(unlockerDir, "unlocker-metadata.json")
|
||||
if err := afero.WriteFile(v.fs, metadataPath, metadataBytes, secret.FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write unlocker metadata: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt long-term private key to this unlocker
|
||||
// We need to get the long-term key (either from memory if unlocked, or derive it)
|
||||
ltIdentity, err := v.GetOrDeriveLongTermKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get long-term key: %w", err)
|
||||
}
|
||||
|
||||
ltPrivKeyBuffer := memguard.NewBufferFromBytes([]byte(ltIdentity.String()))
|
||||
defer ltPrivKeyBuffer.Destroy()
|
||||
|
||||
encryptedLtPrivKey, err := secret.EncryptToRecipient(ltPrivKeyBuffer, unlockerIdentity.Recipient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt long-term private key: %w", err)
|
||||
}
|
||||
|
||||
ltPrivKeyPath := filepath.Join(unlockerDir, "longterm.age")
|
||||
if err := afero.WriteFile(v.fs, ltPrivKeyPath, encryptedLtPrivKey, secret.FilePerms); err != nil {
|
||||
return nil, fmt.Errorf("failed to write encrypted long-term private key: %w", err)
|
||||
}
|
||||
|
||||
// Create the unlocker instance
|
||||
unlocker := secret.NewPassphraseUnlocker(v.fs, unlockerDir, metadata)
|
||||
|
||||
// Select this unlocker as current
|
||||
if err := v.SelectUnlocker(unlocker.GetID()); err != nil {
|
||||
return nil, fmt.Errorf("failed to select new unlocker: %w", err)
|
||||
}
|
||||
|
||||
return unlocker, nil
|
||||
}
|
||||
209
internal/vault/vault.go
Normal file
209
internal/vault/vault.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// Vault represents a secrets vault
|
||||
type Vault struct {
|
||||
Name string
|
||||
fs afero.Fs
|
||||
stateDir string
|
||||
longTermKey *age.X25519Identity // In-memory long-term key when unlocked
|
||||
}
|
||||
|
||||
// NewVault creates a new Vault instance
|
||||
func NewVault(fs afero.Fs, stateDir string, name string) *Vault {
|
||||
secret.Debug("Creating NewVault instance")
|
||||
v := &Vault{
|
||||
Name: name,
|
||||
fs: fs,
|
||||
stateDir: stateDir,
|
||||
longTermKey: nil,
|
||||
}
|
||||
secret.Debug("Created NewVault instance successfully")
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// Locked returns true if the vault doesn't have a long-term key in memory
|
||||
func (v *Vault) Locked() bool {
|
||||
return v.longTermKey == nil
|
||||
}
|
||||
|
||||
// Unlock sets the long-term key in memory, unlocking the vault
|
||||
func (v *Vault) Unlock(key *age.X25519Identity) {
|
||||
v.longTermKey = key
|
||||
}
|
||||
|
||||
// GetLongTermKey returns the long-term key if available in memory
|
||||
func (v *Vault) GetLongTermKey() *age.X25519Identity {
|
||||
return v.longTermKey
|
||||
}
|
||||
|
||||
// ClearLongTermKey removes the long-term key from memory (locks the vault)
|
||||
func (v *Vault) ClearLongTermKey() {
|
||||
v.longTermKey = nil
|
||||
}
|
||||
|
||||
// GetOrDeriveLongTermKey gets the long-term key from memory or derives it from available sources
|
||||
func (v *Vault) GetOrDeriveLongTermKey() (*age.X25519Identity, error) {
|
||||
// If we have it in memory, return it
|
||||
if !v.Locked() {
|
||||
return v.longTermKey, nil
|
||||
}
|
||||
|
||||
secret.Debug("Vault is locked, attempting to unlock", "vault_name", v.Name)
|
||||
|
||||
// Try to derive from environment mnemonic first
|
||||
if envMnemonic := os.Getenv(secret.EnvMnemonic); envMnemonic != "" {
|
||||
secret.Debug("Using mnemonic from environment for long-term key derivation", "vault_name", v.Name)
|
||||
|
||||
// Load vault metadata to get the derivation index
|
||||
vaultDir, err := v.GetDirectory()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get vault directory: %w", err)
|
||||
}
|
||||
|
||||
metadata, err := LoadVaultMetadata(v.fs, vaultDir)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to load vault metadata", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to load vault metadata: %w", err)
|
||||
}
|
||||
|
||||
ltIdentity, err := agehd.DeriveIdentity(envMnemonic, metadata.DerivationIndex)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to derive long-term key from mnemonic", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to derive long-term key from mnemonic: %w", err)
|
||||
}
|
||||
|
||||
// Verify that the derived key matches the stored public key hash
|
||||
derivedPubKeyHash := ComputeDoubleSHA256([]byte(ltIdentity.Recipient().String()))
|
||||
if derivedPubKeyHash != metadata.PublicKeyHash {
|
||||
secret.Debug("Derived public key hash does not match stored hash",
|
||||
"vault_name", v.Name,
|
||||
"derived_hash", derivedPubKeyHash,
|
||||
"stored_hash", metadata.PublicKeyHash,
|
||||
"derivation_index", metadata.DerivationIndex)
|
||||
|
||||
return nil, fmt.Errorf("derived public key does not match vault: mnemonic may be incorrect")
|
||||
}
|
||||
|
||||
secret.DebugWith("Successfully derived long-term key from mnemonic",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("public_key", ltIdentity.Recipient().String()),
|
||||
slog.Uint64("derivation_index", uint64(metadata.DerivationIndex)),
|
||||
)
|
||||
|
||||
// Cache the derived key by unlocking the vault
|
||||
v.Unlock(ltIdentity)
|
||||
secret.Debug("Vault is unlocked (lt key in memory) via mnemonic", "vault_name", v.Name)
|
||||
|
||||
return ltIdentity, nil
|
||||
}
|
||||
|
||||
// No mnemonic available, try to use current unlocker
|
||||
secret.Debug("No mnemonic available, using current unlocker to unlock vault", "vault_name", v.Name)
|
||||
|
||||
// Get current unlocker
|
||||
unlocker, err := v.GetCurrentUnlocker()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get current unlocker", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to get current unlocker: %w", err)
|
||||
}
|
||||
|
||||
secret.DebugWith("Retrieved current unlocker for vault unlock",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("unlocker_type", unlocker.GetType()),
|
||||
slog.String("unlocker_id", unlocker.GetID()),
|
||||
)
|
||||
|
||||
// Get unlocker identity
|
||||
unlockerIdentity, err := unlocker.GetIdentity()
|
||||
if err != nil {
|
||||
secret.Debug("Failed to get unlocker identity", "error", err, "unlocker_type", unlocker.GetType())
|
||||
|
||||
return nil, fmt.Errorf("failed to get unlocker identity: %w", err)
|
||||
}
|
||||
|
||||
// Read encrypted long-term private key from unlocker directory
|
||||
unlockerDir := unlocker.GetDirectory()
|
||||
encryptedLtPrivKeyPath := filepath.Join(unlockerDir, "longterm.age")
|
||||
secret.Debug("Reading encrypted long-term private key", "path", encryptedLtPrivKeyPath)
|
||||
|
||||
encryptedLtPrivKey, err := afero.ReadFile(v.fs, encryptedLtPrivKeyPath)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to read encrypted long-term private key", "error", err, "path", encryptedLtPrivKeyPath)
|
||||
|
||||
return nil, fmt.Errorf("failed to read encrypted long-term private key: %w", err)
|
||||
}
|
||||
|
||||
secret.DebugWith("Read encrypted long-term private key",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("unlocker_type", unlocker.GetType()),
|
||||
slog.Int("encrypted_length", len(encryptedLtPrivKey)),
|
||||
)
|
||||
|
||||
// Decrypt long-term private key using unlocker
|
||||
secret.Debug("Decrypting long-term private key with unlocker", "unlocker_type", unlocker.GetType())
|
||||
ltPrivKeyData, err := secret.DecryptWithIdentity(encryptedLtPrivKey, unlockerIdentity)
|
||||
if err != nil {
|
||||
secret.Debug("Failed to decrypt long-term private key", "error", err, "unlocker_type", unlocker.GetType())
|
||||
|
||||
return nil, fmt.Errorf("failed to decrypt long-term private key: %w", err)
|
||||
}
|
||||
|
||||
secret.DebugWith("Successfully decrypted long-term private key",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("unlocker_type", unlocker.GetType()),
|
||||
slog.Int("decrypted_length", len(ltPrivKeyData)),
|
||||
)
|
||||
|
||||
// Parse long-term private key
|
||||
secret.Debug("Parsing long-term private key", "vault_name", v.Name)
|
||||
ltIdentity, err := age.ParseX25519Identity(string(ltPrivKeyData))
|
||||
if err != nil {
|
||||
secret.Debug("Failed to parse long-term private key", "error", err, "vault_name", v.Name)
|
||||
|
||||
return nil, fmt.Errorf("failed to parse long-term private key: %w", err)
|
||||
}
|
||||
|
||||
secret.DebugWith("Successfully obtained long-term identity via unlocker",
|
||||
slog.String("vault_name", v.Name),
|
||||
slog.String("unlocker_type", unlocker.GetType()),
|
||||
slog.String("public_key", ltIdentity.Recipient().String()),
|
||||
)
|
||||
|
||||
// Cache the derived key by unlocking the vault
|
||||
v.Unlock(ltIdentity)
|
||||
secret.Debug("Vault is unlocked (lt key in memory) via unlocker",
|
||||
"vault_name", v.Name, "unlocker_type", unlocker.GetType())
|
||||
|
||||
return ltIdentity, nil
|
||||
}
|
||||
|
||||
// GetDirectory returns the vault's directory path
|
||||
func (v *Vault) GetDirectory() (string, error) {
|
||||
return filepath.Join(v.stateDir, "vaults.d", v.Name), nil
|
||||
}
|
||||
|
||||
// GetName returns the vault's name (for VaultInterface compatibility)
|
||||
func (v *Vault) GetName() string {
|
||||
return v.Name
|
||||
}
|
||||
|
||||
// GetFilesystem returns the vault's filesystem (for VaultInterface compatibility)
|
||||
func (v *Vault) GetFilesystem() afero.Fs {
|
||||
return v.fs
|
||||
}
|
||||
227
internal/vault/vault_test.go
Normal file
227
internal/vault/vault_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/secret"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
func TestVaultOperations(t *testing.T) {
|
||||
// Test environment will be cleaned up automatically by t.Setenv
|
||||
|
||||
// Set test environment variables
|
||||
testMnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
||||
t.Setenv(secret.EnvMnemonic, testMnemonic)
|
||||
t.Setenv(secret.EnvUnlockPassphrase, "test-passphrase")
|
||||
|
||||
// Use in-memory filesystem
|
||||
fs := afero.NewMemMapFs()
|
||||
stateDir := "/test/state"
|
||||
|
||||
// Test vault creation
|
||||
t.Run("CreateVault", func(t *testing.T) {
|
||||
vlt, err := CreateVault(fs, stateDir, "test-vault")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create vault: %v", err)
|
||||
}
|
||||
|
||||
if vlt.GetName() != "test-vault" {
|
||||
t.Errorf("Expected vault name 'test-vault', got '%s'", vlt.GetName())
|
||||
}
|
||||
|
||||
// Check vault directory exists
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
|
||||
exists, err := afero.DirExists(fs, vaultDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check vault directory: %v", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
t.Errorf("Vault directory should exist")
|
||||
}
|
||||
})
|
||||
|
||||
// Test vault listing
|
||||
t.Run("ListVaults", func(t *testing.T) {
|
||||
vaults, err := ListVaults(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list vaults: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, vault := range vaults {
|
||||
if vault == "test-vault" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("Expected to find 'test-vault' in vault list")
|
||||
}
|
||||
})
|
||||
|
||||
// Test vault selection
|
||||
t.Run("SelectVault", func(t *testing.T) {
|
||||
err := SelectVault(fs, stateDir, "test-vault")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select vault: %v", err)
|
||||
}
|
||||
|
||||
// Test getting current vault
|
||||
currentVault, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current vault: %v", err)
|
||||
}
|
||||
|
||||
if currentVault.GetName() != "test-vault" {
|
||||
t.Errorf("Expected current vault 'test-vault', got '%s'", currentVault.GetName())
|
||||
}
|
||||
})
|
||||
|
||||
// Test secret operations
|
||||
t.Run("SecretOperations", func(t *testing.T) {
|
||||
vlt, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current vault: %v", err)
|
||||
}
|
||||
|
||||
// First, derive the long-term key from the test mnemonic
|
||||
ltIdentity, err := agehd.DeriveIdentity(testMnemonic, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to derive long-term key: %v", err)
|
||||
}
|
||||
|
||||
// Get the public key from the derived identity
|
||||
ltPublicKey := ltIdentity.Recipient().String()
|
||||
|
||||
// Get the vault directory
|
||||
vaultDir, err := vlt.GetDirectory()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get vault directory: %v", err)
|
||||
}
|
||||
|
||||
// Write the correct public key to the pub.age file
|
||||
pubKeyPath := filepath.Join(vaultDir, "pub.age")
|
||||
err = afero.WriteFile(fs, pubKeyPath, []byte(ltPublicKey), secret.FilePerms)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write long-term public key: %v", err)
|
||||
}
|
||||
|
||||
// Unlock the vault with the derived identity
|
||||
vlt.Unlock(ltIdentity)
|
||||
|
||||
// Now add a secret
|
||||
secretName := "test/secret"
|
||||
secretValue := []byte("test-secret-value")
|
||||
expectedValue := make([]byte, len(secretValue))
|
||||
copy(expectedValue, secretValue)
|
||||
|
||||
secretBuffer := memguard.NewBufferFromBytes(secretValue)
|
||||
defer secretBuffer.Destroy()
|
||||
|
||||
err = vlt.AddSecret(secretName, secretBuffer, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add secret: %v", err)
|
||||
}
|
||||
|
||||
// List secrets
|
||||
secrets, err := vlt.ListSecrets()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list secrets: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, secret := range secrets {
|
||||
if secret == secretName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("Expected to find secret '%s' in list", secretName)
|
||||
}
|
||||
|
||||
// Get secret value
|
||||
retrievedValue, err := vlt.GetSecret(secretName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get secret: %v", err)
|
||||
}
|
||||
|
||||
if string(retrievedValue) != string(expectedValue) {
|
||||
t.Errorf("Expected secret value '%s', got '%s'", string(expectedValue), string(retrievedValue))
|
||||
}
|
||||
})
|
||||
|
||||
// Test unlocker operations
|
||||
t.Run("UnlockerOperations", func(t *testing.T) {
|
||||
vlt, err := GetCurrentVault(fs, stateDir)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current vault: %v", err)
|
||||
}
|
||||
|
||||
// Test vault unlocking (should happen automatically via mnemonic)
|
||||
if vlt.Locked() {
|
||||
_, err := vlt.UnlockVault()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unlock vault: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a passphrase unlocker
|
||||
passphraseBuffer := memguard.NewBufferFromBytes([]byte("test-passphrase"))
|
||||
defer passphraseBuffer.Destroy()
|
||||
passphraseUnlocker, err := vlt.CreatePassphraseUnlocker(passphraseBuffer)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create passphrase unlocker: %v", err)
|
||||
}
|
||||
|
||||
// List unlockers
|
||||
unlockers, err := vlt.ListUnlockers()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list unlockers: %v", err)
|
||||
}
|
||||
|
||||
if len(unlockers) == 0 {
|
||||
t.Errorf("Expected at least one unlocker")
|
||||
}
|
||||
|
||||
// Check key type
|
||||
keyFound := false
|
||||
for _, key := range unlockers {
|
||||
if key.Type == "passphrase" {
|
||||
keyFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !keyFound {
|
||||
t.Errorf("Expected to find passphrase unlocker")
|
||||
}
|
||||
|
||||
// Test selecting unlocker
|
||||
err = vlt.SelectUnlocker(passphraseUnlocker.GetID())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to select unlocker: %v", err)
|
||||
}
|
||||
|
||||
// Test getting current unlocker
|
||||
currentUnlocker, err := vlt.GetCurrentUnlocker()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get current unlocker: %v", err)
|
||||
}
|
||||
|
||||
if currentUnlocker.GetID() != passphraseUnlocker.GetID() {
|
||||
t.Errorf("Expected current unlocker ID '%s', got '%s'", passphraseUnlocker.GetID(), currentUnlocker.GetID())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -35,7 +35,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/agehd"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -61,7 +61,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/agehd"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -87,7 +87,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/agehd"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -114,7 +114,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"git.eeqj.de/sneak/secret/internal/agehd"
|
||||
"git.eeqj.de/sneak/secret/pkg/agehd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"filippo.io/age"
|
||||
"git.eeqj.de/sneak/secret/internal/bip85"
|
||||
"git.eeqj.de/sneak/secret/pkg/bip85"
|
||||
"github.com/btcsuite/btcd/btcutil/hdkeychain"
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btcutil/bech32"
|
||||
@@ -25,6 +25,7 @@ const (
|
||||
vendorID = uint32(592366788) // berlin.sneak
|
||||
appID = uint32(733482323) // secret
|
||||
hrp = "age-secret-key-" // Bech32 HRP used by age
|
||||
x25519KeySize = 32 // 256-bit key size for X25519
|
||||
)
|
||||
|
||||
// clamp applies RFC-7748 clamping to a 32-byte scalar.
|
||||
@@ -37,16 +38,20 @@ func clamp(k []byte) {
|
||||
// IdentityFromEntropy converts 32 deterministic bytes into an
|
||||
// *age.X25519Identity by round-tripping through Bech32.
|
||||
func IdentityFromEntropy(ent []byte) (*age.X25519Identity, error) {
|
||||
if len(ent) != 32 {
|
||||
if len(ent) != x25519KeySize {
|
||||
return nil, fmt.Errorf("need 32-byte scalar, got %d", len(ent))
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original
|
||||
key := make([]byte, 32)
|
||||
key := make([]byte, x25519KeySize)
|
||||
copy(key, ent)
|
||||
clamp(key)
|
||||
|
||||
data, err := bech32.ConvertBits(key, 8, 5, true)
|
||||
const (
|
||||
bech32BitSize8 = 8 // Standard 8-bit encoding
|
||||
bech32BitSize5 = 5 // Bech32 5-bit encoding
|
||||
)
|
||||
data, err := bech32.ConvertBits(key, bech32BitSize8, bech32BitSize5, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bech32 convert: %w", err)
|
||||
}
|
||||
@@ -54,6 +59,7 @@ func IdentityFromEntropy(ent []byte) (*age.X25519Identity, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bech32 encode: %w", err)
|
||||
}
|
||||
|
||||
return age.ParseX25519Identity(strings.ToUpper(s))
|
||||
}
|
||||
|
||||
@@ -80,7 +86,7 @@ func DeriveEntropy(mnemonic string, n uint32) ([]byte, error) {
|
||||
|
||||
// Use BIP85 DRNG to generate deterministic 32 bytes for the age key
|
||||
drng := bip85.NewBIP85DRNG(entropy)
|
||||
key := make([]byte, 32)
|
||||
key := make([]byte, x25519KeySize)
|
||||
_, err = drng.Read(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read from DRNG: %w", err)
|
||||
@@ -109,7 +115,7 @@ func DeriveEntropyFromXPRV(xprv string, n uint32) ([]byte, error) {
|
||||
|
||||
// Use BIP85 DRNG to generate deterministic 32 bytes for the age key
|
||||
drng := bip85.NewBIP85DRNG(entropy)
|
||||
key := make([]byte, 32)
|
||||
key := make([]byte, x25519KeySize)
|
||||
_, err = drng.Read(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read from DRNG: %w", err)
|
||||
@@ -125,6 +131,7 @@ func DeriveIdentity(mnemonic string, n uint32) (*age.X25519Identity, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return IdentityFromEntropy(ent)
|
||||
}
|
||||
|
||||
@@ -135,5 +142,6 @@ func DeriveIdentityFromXPRV(xprv string, n uint32) (*age.X25519Identity, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return IdentityFromEntropy(ent)
|
||||
}
|
||||
1058
pkg/agehd/agehd_test.go
Normal file
1058
pkg/agehd/agehd_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@ BIP85 enables a variety of use cases:
|
||||
```go
|
||||
import (
|
||||
"fmt"
|
||||
"git.eeqj.de/sneak/secret/internal/bip85"
|
||||
"git.eeqj.de/sneak/secret/pkg/bip85"
|
||||
"github.com/btcsuite/btcd/btcutil/hdkeychain"
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ The implementation is also compatible with the Python reference implementation's
|
||||
Run the tests with verbose output to see the test vectors and results:
|
||||
|
||||
```
|
||||
go test -v git.eeqj.de/sneak/secret/internal/bip85
|
||||
go test -v git.eeqj.de/sneak/secret/pkg/bip85
|
||||
```
|
||||
|
||||
## References
|
||||
@@ -1,3 +1,4 @@
|
||||
// Package bip85 implements BIP85 deterministic entropy derivation.
|
||||
package bip85
|
||||
|
||||
import (
|
||||
@@ -22,52 +23,55 @@ import (
|
||||
|
||||
const (
|
||||
// BIP85_MASTER_PATH is the derivation path prefix for all BIP85 applications
|
||||
BIP85_MASTER_PATH = "m/83696968'"
|
||||
BIP85_MASTER_PATH = "m/83696968'" //nolint:revive // ALL_CAPS used for BIP85 constants
|
||||
|
||||
// BIP85_KEY_HMAC_KEY is the HMAC key used for deriving the entropy
|
||||
BIP85_KEY_HMAC_KEY = "bip-entropy-from-k"
|
||||
BIP85_KEY_HMAC_KEY = "bip-entropy-from-k" //nolint:revive // ALL_CAPS used for BIP85 constants
|
||||
|
||||
// Application numbers
|
||||
APP_BIP39 = 39 // BIP39 mnemonics
|
||||
APP_HD_WIF = 2 // WIF for Bitcoin Core
|
||||
APP_XPRV = 32 // Extended private key
|
||||
APP_HEX = 128169
|
||||
APP_PWD64 = 707764 // Base64 passwords
|
||||
APP_PWD85 = 707785 // Base85 passwords
|
||||
APP_RSA = 828365
|
||||
// AppBIP39 is the application number for BIP39 mnemonics
|
||||
AppBIP39 = 39
|
||||
// AppHDWIF is the application number for WIF (Wallet Import Format) for Bitcoin Core
|
||||
AppHDWIF = 2
|
||||
// AppXPRV is the application number for extended private key
|
||||
AppXPRV = 32
|
||||
APP_HEX = 128169 //nolint:revive // ALL_CAPS used for BIP85 constants
|
||||
APP_PWD64 = 707764 // Base64 passwords //nolint:revive // ALL_CAPS used for BIP85 constants
|
||||
AppPWD85 = 707785 // Base85 passwords
|
||||
APP_RSA = 828365 //nolint:revive // ALL_CAPS used for BIP85 constants
|
||||
)
|
||||
|
||||
// Version bytes for extended keys
|
||||
var (
|
||||
// MainNetPrivateKey is the version for mainnet private keys
|
||||
MainNetPrivateKey = []byte{0x04, 0x88, 0xAD, 0xE4}
|
||||
MainNetPrivateKey = []byte{0x04, 0x88, 0xAD, 0xE4} //nolint:gochecknoglobals // Standard BIP32 constant
|
||||
// TestNetPrivateKey is the version for testnet private keys
|
||||
TestNetPrivateKey = []byte{0x04, 0x35, 0x83, 0x94}
|
||||
TestNetPrivateKey = []byte{0x04, 0x35, 0x83, 0x94} //nolint:gochecknoglobals // Standard BIP32 constant
|
||||
)
|
||||
|
||||
// BIP85DRNG is a deterministic random number generator seeded by BIP85 entropy
|
||||
type BIP85DRNG struct {
|
||||
// DRNG is a deterministic random number generator seeded by BIP85 entropy
|
||||
type DRNG struct {
|
||||
shake io.Reader
|
||||
}
|
||||
|
||||
// NewBIP85DRNG creates a new DRNG seeded with BIP85 entropy
|
||||
func NewBIP85DRNG(entropy []byte) *BIP85DRNG {
|
||||
func NewBIP85DRNG(entropy []byte) *DRNG {
|
||||
const bip85EntropySize = 64 // 512 bits
|
||||
// The entropy must be exactly 64 bytes (512 bits)
|
||||
if len(entropy) != 64 {
|
||||
panic("BIP85DRNG entropy must be 64 bytes")
|
||||
if len(entropy) != bip85EntropySize {
|
||||
panic("DRNG entropy must be 64 bytes")
|
||||
}
|
||||
|
||||
// Initialize SHAKE256 with the entropy
|
||||
shake := sha3.NewShake256()
|
||||
shake.Write(entropy)
|
||||
_, _ = shake.Write(entropy) // Write to hash functions never returns an error
|
||||
|
||||
return &BIP85DRNG{
|
||||
return &DRNG{
|
||||
shake: shake,
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface
|
||||
func (d *BIP85DRNG) Read(p []byte) (n int, err error) {
|
||||
func (d *DRNG) Read(p []byte) (n int, err error) {
|
||||
return d.shake.Read(p)
|
||||
}
|
||||
|
||||
@@ -161,7 +165,7 @@ func deriveChildKey(parent *hdkeychain.ExtendedKey, path string) (*hdkeychain.Ex
|
||||
|
||||
// DeriveBIP39Entropy derives entropy for a BIP39 mnemonic
|
||||
func DeriveBIP39Entropy(masterKey *hdkeychain.ExtendedKey, language, words, index uint32) ([]byte, error) {
|
||||
path := fmt.Sprintf("%s/%d'/%d'/%d'/%d'", BIP85_MASTER_PATH, APP_BIP39, language, words, index)
|
||||
path := fmt.Sprintf("%s/%d'/%d'/%d'/%d'", BIP85_MASTER_PATH, AppBIP39, language, words, index)
|
||||
|
||||
entropy, err := DeriveBIP85Entropy(masterKey, path)
|
||||
if err != nil {
|
||||
@@ -169,17 +173,26 @@ func DeriveBIP39Entropy(masterKey *hdkeychain.ExtendedKey, language, words, inde
|
||||
}
|
||||
|
||||
// Determine how many bits of entropy to use based on the words
|
||||
// BIP39 defines specific word counts and their corresponding entropy bits
|
||||
const (
|
||||
words12 = 12 // 128 bits of entropy
|
||||
words15 = 15 // 160 bits of entropy
|
||||
words18 = 18 // 192 bits of entropy
|
||||
words21 = 21 // 224 bits of entropy
|
||||
words24 = 24 // 256 bits of entropy
|
||||
)
|
||||
|
||||
var bits int
|
||||
switch words {
|
||||
case 12:
|
||||
case words12:
|
||||
bits = 128
|
||||
case 15:
|
||||
case words15:
|
||||
bits = 160
|
||||
case 18:
|
||||
case words18:
|
||||
bits = 192
|
||||
case 21:
|
||||
case words21:
|
||||
bits = 224
|
||||
case 24:
|
||||
case words24:
|
||||
bits = 256
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid BIP39 word count: %d", words)
|
||||
@@ -193,7 +206,7 @@ func DeriveBIP39Entropy(masterKey *hdkeychain.ExtendedKey, language, words, inde
|
||||
|
||||
// DeriveWIFKey derives a private key in WIF format
|
||||
func DeriveWIFKey(masterKey *hdkeychain.ExtendedKey, index uint32) (string, error) {
|
||||
path := fmt.Sprintf("%s/%d'/%d'", BIP85_MASTER_PATH, APP_HD_WIF, index)
|
||||
path := fmt.Sprintf("%s/%d'/%d'", BIP85_MASTER_PATH, AppHDWIF, index)
|
||||
|
||||
entropy, err := DeriveBIP85Entropy(masterKey, path)
|
||||
if err != nil {
|
||||
@@ -215,7 +228,7 @@ func DeriveWIFKey(masterKey *hdkeychain.ExtendedKey, index uint32) (string, erro
|
||||
|
||||
// DeriveXPRV derives an extended private key (XPRV)
|
||||
func DeriveXPRV(masterKey *hdkeychain.ExtendedKey, index uint32) (*hdkeychain.ExtendedKey, error) {
|
||||
path := fmt.Sprintf("%s/%d'/%d'", BIP85_MASTER_PATH, APP_XPRV, index)
|
||||
path := fmt.Sprintf("%s/%d'/%d'", BIP85_MASTER_PATH, AppXPRV, index)
|
||||
|
||||
entropy, err := DeriveBIP85Entropy(masterKey, path)
|
||||
if err != nil {
|
||||
@@ -266,6 +279,7 @@ func DeriveXPRV(masterKey *hdkeychain.ExtendedKey, index uint32) (*hdkeychain.Ex
|
||||
func doubleSHA256(data []byte) []byte {
|
||||
hash1 := sha256.Sum256(data)
|
||||
hash2 := sha256.Sum256(hash1[:])
|
||||
|
||||
return hash2[:]
|
||||
}
|
||||
|
||||
@@ -308,7 +322,7 @@ func DeriveBase64Password(masterKey *hdkeychain.ExtendedKey, pwdLen, index uint3
|
||||
encodedStr = strings.TrimRight(encodedStr, "=")
|
||||
|
||||
// Slice to the desired password length
|
||||
if uint32(len(encodedStr)) < pwdLen {
|
||||
if len(encodedStr) < int(pwdLen) {
|
||||
return "", fmt.Errorf("derived password length %d is shorter than requested length %d", len(encodedStr), pwdLen)
|
||||
}
|
||||
|
||||
@@ -321,7 +335,7 @@ func DeriveBase85Password(masterKey *hdkeychain.ExtendedKey, pwdLen, index uint3
|
||||
return "", fmt.Errorf("pwdLen must be between 10 and 80")
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("%s/%d'/%d'/%d'", BIP85_MASTER_PATH, APP_PWD85, pwdLen, index)
|
||||
path := fmt.Sprintf("%s/%d'/%d'/%d'", BIP85_MASTER_PATH, AppPWD85, pwdLen, index)
|
||||
|
||||
entropy, err := DeriveBIP85Entropy(masterKey, path)
|
||||
if err != nil {
|
||||
@@ -332,7 +346,7 @@ func DeriveBase85Password(masterKey *hdkeychain.ExtendedKey, pwdLen, index uint3
|
||||
encoded := encodeBase85WithRFC1924Charset(entropy)
|
||||
|
||||
// Slice to the desired password length
|
||||
if uint32(len(encoded)) < pwdLen {
|
||||
if len(encoded) < int(pwdLen) {
|
||||
return "", fmt.Errorf("encoded length %d is less than requested length %d", len(encoded), pwdLen)
|
||||
}
|
||||
|
||||
@@ -344,24 +358,30 @@ func encodeBase85WithRFC1924Charset(data []byte) string {
|
||||
// RFC1924 character set
|
||||
charset := "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!#$%&()*+-;<=>?@^_`{|}~"
|
||||
|
||||
const (
|
||||
base85ChunkSize = 4 // Process 4 bytes at a time
|
||||
base85DigitCount = 5 // Each chunk produces 5 digits
|
||||
base85Base = 85 // Base85 encoding uses base 85
|
||||
)
|
||||
|
||||
// Pad data to multiple of 4
|
||||
padded := make([]byte, ((len(data)+3)/4)*4)
|
||||
padded := make([]byte, ((len(data)+base85ChunkSize-1)/base85ChunkSize)*base85ChunkSize)
|
||||
copy(padded, data)
|
||||
|
||||
var buf strings.Builder
|
||||
buf.Grow(len(padded) * 5 / 4) // Each 4 bytes becomes 5 Base85 characters
|
||||
buf.Grow(len(padded) * base85DigitCount / base85ChunkSize) // Each 4 bytes becomes 5 Base85 characters
|
||||
|
||||
// Process in 4-byte chunks
|
||||
for i := 0; i < len(padded); i += 4 {
|
||||
for i := 0; i < len(padded); i += base85ChunkSize {
|
||||
// Convert 4 bytes to uint32 (big-endian)
|
||||
chunk := binary.BigEndian.Uint32(padded[i : i+4])
|
||||
chunk := binary.BigEndian.Uint32(padded[i : i+base85ChunkSize])
|
||||
|
||||
// Convert to 5 base-85 digits
|
||||
digits := make([]byte, 5)
|
||||
for j := 4; j >= 0; j-- {
|
||||
idx := chunk % 85
|
||||
digits := make([]byte, base85DigitCount)
|
||||
for j := base85DigitCount - 1; j >= 0; j-- {
|
||||
idx := chunk % base85Base
|
||||
digits[j] = charset[idx]
|
||||
chunk /= 85
|
||||
chunk /= base85Base
|
||||
}
|
||||
|
||||
buf.Write(digits)
|
||||
@@ -1,5 +1,9 @@
|
||||
//nolint:gosec // G101: Test file contains BIP85 test vectors, not real credentials
|
||||
//nolint:lll // Test vectors contain long lines
|
||||
package bip85
|
||||
|
||||
//nolint:revive,unparam // Test file with BIP85 test vectors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
@@ -79,13 +83,13 @@ const (
|
||||
)
|
||||
|
||||
// logTestVector logs test information in a cleaner, more concise format
|
||||
func logTestVector(t *testing.T, title, description string) {
|
||||
func logTestVector(t *testing.T, title string) {
|
||||
t.Logf("=== TEST: %s ===", title)
|
||||
}
|
||||
|
||||
// TestDerivedKey is a helper function to test the derived key directly
|
||||
func TestDerivedKey(t *testing.T) {
|
||||
logTestVector(t, "Derived Child Keys", "Testing direct key derivation for BIP85 paths")
|
||||
logTestVector(t, "Derived Child Keys")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -129,7 +133,7 @@ func TestDerivedKey(t *testing.T) {
|
||||
|
||||
// TestCase1 tests the first test vector from the BIP85 specification
|
||||
func TestCase1(t *testing.T) {
|
||||
logTestVector(t, "Test Case 1", "Basic entropy derivation with path m/83696968'/0'/0'")
|
||||
logTestVector(t, "Test Case 1")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -155,7 +159,7 @@ func TestCase1(t *testing.T) {
|
||||
|
||||
// TestCase2 tests the second test vector from the BIP85 specification
|
||||
func TestCase2(t *testing.T) {
|
||||
logTestVector(t, "Test Case 2", "Basic entropy derivation with path m/83696968'/0'/1'")
|
||||
logTestVector(t, "Test Case 2")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -181,7 +185,7 @@ func TestCase2(t *testing.T) {
|
||||
|
||||
// TestBIP39_12EnglishWords tests the BIP39 12 English words test vector
|
||||
func TestBIP39_12EnglishWords(t *testing.T) {
|
||||
logTestVector(t, "BIP39 12 English Words", "Deriving a 12-word English BIP39 mnemonic")
|
||||
logTestVector(t, "BIP39 12 English Words")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -225,7 +229,7 @@ func TestBIP39_12EnglishWords(t *testing.T) {
|
||||
|
||||
// TestBIP39_18EnglishWords tests the BIP39 18 English words test vector
|
||||
func TestBIP39_18EnglishWords(t *testing.T) {
|
||||
logTestVector(t, "BIP39 18 English Words", "Deriving an 18-word English BIP39 mnemonic")
|
||||
logTestVector(t, "BIP39 18 English Words")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -269,7 +273,7 @@ func TestBIP39_18EnglishWords(t *testing.T) {
|
||||
|
||||
// TestBIP39_24EnglishWords tests the BIP39 24 English words test vector
|
||||
func TestBIP39_24EnglishWords(t *testing.T) {
|
||||
logTestVector(t, "BIP39 24 English Words", "Deriving a 24-word English BIP39 mnemonic")
|
||||
logTestVector(t, "BIP39 24 English Words")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -313,7 +317,7 @@ func TestBIP39_24EnglishWords(t *testing.T) {
|
||||
|
||||
// TestHD_WIF tests the WIF test vector
|
||||
func TestHD_WIF(t *testing.T) {
|
||||
logTestVector(t, "HD-Seed WIF", "Deriving a WIF-encoded private key for Bitcoin Core hdseed")
|
||||
logTestVector(t, "HD-Seed WIF")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -353,7 +357,7 @@ func TestHD_WIF(t *testing.T) {
|
||||
|
||||
// TestXPRV tests the XPRV test vector
|
||||
func TestXPRV(t *testing.T) {
|
||||
logTestVector(t, "XPRV", "Deriving an extended private key (XPRV)")
|
||||
logTestVector(t, "XPRV")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -379,7 +383,7 @@ func TestXPRV(t *testing.T) {
|
||||
|
||||
// TestDRNG_SHAKE256 tests the BIP85-DRNG-SHAKE256 test vector
|
||||
func TestDRNG_SHAKE256(t *testing.T) {
|
||||
logTestVector(t, "DRNG-SHAKE256", "Testing the deterministic random number generator with SHAKE256")
|
||||
logTestVector(t, "DRNG-SHAKE256")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -414,7 +418,7 @@ func TestDRNG_SHAKE256(t *testing.T) {
|
||||
|
||||
// TestPythonDRNGVectors tests the DRNG vectors from the Python implementation
|
||||
func TestPythonDRNGVectors(t *testing.T) {
|
||||
logTestVector(t, "Python DRNG Vectors", "Testing specific DRNG vectors from the Python implementation")
|
||||
logTestVector(t, "Python DRNG Vectors")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -489,7 +493,7 @@ func TestPythonDRNGVectors(t *testing.T) {
|
||||
|
||||
// TestDRNGDeterminism tests the deterministic behavior of the DRNG
|
||||
func TestDRNGDeterminism(t *testing.T) {
|
||||
logTestVector(t, "DRNG Determinism", "Testing deterministic behavior of the DRNG")
|
||||
logTestVector(t, "DRNG Determinism")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -573,7 +577,7 @@ func TestDRNGDeterminism(t *testing.T) {
|
||||
|
||||
// TestDRNGLengths tests the DRNG with different lengths
|
||||
func TestDRNGLengths(t *testing.T) {
|
||||
logTestVector(t, "DRNG Lengths", "Testing DRNG with different read lengths")
|
||||
logTestVector(t, "DRNG Lengths")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -606,7 +610,7 @@ func TestDRNGLengths(t *testing.T) {
|
||||
|
||||
// TestDRNGExceptions tests error handling in the DRNG
|
||||
func TestDRNGExceptions(t *testing.T) {
|
||||
logTestVector(t, "DRNG Exceptions", "Testing error handling in the DRNG")
|
||||
logTestVector(t, "DRNG Exceptions")
|
||||
|
||||
// Test with entropy of the wrong size
|
||||
testCases := []int{0, 1, 32, 63, 65, 128}
|
||||
@@ -638,7 +642,7 @@ func TestDRNGExceptions(t *testing.T) {
|
||||
|
||||
// TestDRNGDifferentSizes tests the DRNG with different buffer sizes
|
||||
func TestDRNGDifferentSizes(t *testing.T) {
|
||||
logTestVector(t, "DRNG Different Sizes", "Testing the DRNG with different buffer sizes")
|
||||
logTestVector(t, "DRNG Different Sizes")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -700,7 +704,7 @@ func TestDRNGDifferentSizes(t *testing.T) {
|
||||
|
||||
// TestMasterKeyParsing tests parsing of different master key formats
|
||||
func TestMasterKeyParsing(t *testing.T) {
|
||||
logTestVector(t, "Master Key Parsing", "Testing parsing of master keys in different formats")
|
||||
logTestVector(t, "Master Key Parsing")
|
||||
|
||||
// Test valid master key
|
||||
t.Logf("Testing valid master key")
|
||||
@@ -747,7 +751,7 @@ func TestMasterKeyParsing(t *testing.T) {
|
||||
|
||||
// TestDifferentPathFormats tests different path format expressions
|
||||
func TestDifferentPathFormats(t *testing.T) {
|
||||
logTestVector(t, "Path Formats", "Testing different path format expressions")
|
||||
logTestVector(t, "Path Formats")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -791,7 +795,7 @@ func TestDifferentPathFormats(t *testing.T) {
|
||||
|
||||
// TestDirectBase85Encoding tests direct Base85 encoding with the test vector entropy
|
||||
func TestDirectBase85Encoding(t *testing.T) {
|
||||
logTestVector(t, "Direct Base85 Encoding", "Testing Base85 encoding with BIP85 test vector")
|
||||
logTestVector(t, "Direct Base85 Encoding")
|
||||
|
||||
// Parse the master key
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
@@ -836,7 +840,7 @@ func TestDirectBase85Encoding(t *testing.T) {
|
||||
|
||||
// TestPWDBase64 tests the Base64 password test vector
|
||||
func TestPWDBase64(t *testing.T) {
|
||||
logTestVector(t, "PWD Base64", "Deriving a Base64-encoded password")
|
||||
logTestVector(t, "PWD Base64")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -881,7 +885,7 @@ func TestPWDBase64(t *testing.T) {
|
||||
|
||||
// TestPWDBase85 tests the Base85 password test vector
|
||||
func TestPWDBase85(t *testing.T) {
|
||||
logTestVector(t, "PWD Base85", "Deriving a Base85-encoded password")
|
||||
logTestVector(t, "PWD Base85")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -926,7 +930,7 @@ func TestPWDBase85(t *testing.T) {
|
||||
|
||||
// TestHexDerivation tests the HEX derivation test vector
|
||||
func TestHexDerivation(t *testing.T) {
|
||||
logTestVector(t, "HEX Derivation", "Testing HEX data derivation with BIP85 test vector")
|
||||
logTestVector(t, "HEX Derivation")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -963,7 +967,7 @@ func TestHexDerivation(t *testing.T) {
|
||||
|
||||
// TestInvalidParameters tests error conditions for parameter validation
|
||||
func TestInvalidParameters(t *testing.T) {
|
||||
logTestVector(t, "Invalid Parameters", "Testing error handling for invalid inputs")
|
||||
logTestVector(t, "Invalid Parameters")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
@@ -1041,7 +1045,7 @@ func TestInvalidParameters(t *testing.T) {
|
||||
|
||||
// TestAdditionalDeriveHex tests additional hex derivation scenarios
|
||||
func TestAdditionalDeriveHex(t *testing.T) {
|
||||
logTestVector(t, "Additional Hex Derivation", "Testing hex data derivation with various lengths")
|
||||
logTestVector(t, "Additional Hex Derivation")
|
||||
|
||||
masterKey, err := ParseMasterKey(testMasterKey)
|
||||
if err != nil {
|
||||
Reference in New Issue
Block a user