Add bounds checking before converting int to uint32 to prevent potential integer overflow in benchmark and concurrent test functions
1057 lines
26 KiB
Go
1057 lines
26 KiB
Go
package agehd
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
|
|
"filippo.io/age"
|
|
"github.com/tyler-smith/go-bip39"
|
|
)
|
|
|
|
const (
|
|
mnemonic = "abandon abandon abandon abandon abandon " +
|
|
"abandon abandon abandon abandon abandon abandon about"
|
|
|
|
// Test xprv from BIP85 test vectors
|
|
testXPRV = "xprv9s21ZrQH143K2LBWUUQRFXhucrQqBpKdRRxNVq2zBqsx8HVqFk2uYo8kmbaLLHRdqtQpUm98uKfu3vca1LqdGhUtyoFnCNkfmXRyPXLjbKb"
|
|
|
|
// Additional test mnemonics for comprehensive testing
|
|
testMnemonic12 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
|
testMnemonic15 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
|
testMnemonic18 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
|
testMnemonic21 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
|
|
testMnemonic24 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon art"
|
|
|
|
// Test messages used throughout the tests
|
|
testMessageHelloWorld = "hello world"
|
|
testMessageHelloFromXPRV = "hello from xprv"
|
|
testMessageGeneric = "test message"
|
|
testMessageBoundary = "boundary test"
|
|
testMessageBenchmark = "benchmark test message"
|
|
testMessageLargePattern = "A"
|
|
|
|
// Error messages for validation
|
|
errorMsgNeed32Bytes = "need 32-byte scalar, got"
|
|
errorMsgInvalidXPRV = "invalid-xprv"
|
|
|
|
// Test constants for various scenarios
|
|
// Removed testSkipMessage as tests are no longer skipped
|
|
|
|
// Numeric constants for testing
|
|
testNumGoroutines = 10
|
|
testNumIterations = 100
|
|
|
|
// Large data test constants
|
|
testDataSizeMegabyte = 1024 * 1024 // 1 MB
|
|
)
|
|
|
|
func TestEncryptDecrypt(t *testing.T) {
|
|
id, err := DeriveIdentity(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive: %v", err)
|
|
}
|
|
|
|
t.Logf("secret: %s", id.String())
|
|
t.Logf("recipient: %s", id.Recipient().String())
|
|
|
|
var ct bytes.Buffer
|
|
w, err := age.Encrypt(&ct, id.Recipient())
|
|
if err != nil {
|
|
t.Fatalf("encrypt init: %v", err)
|
|
}
|
|
if _, err = io.WriteString(w, testMessageHelloWorld); err != nil {
|
|
t.Fatalf("write: %v", err)
|
|
}
|
|
if err = w.Close(); err != nil {
|
|
t.Fatalf("encrypt close: %v", err)
|
|
}
|
|
|
|
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), id)
|
|
if err != nil {
|
|
t.Fatalf("decrypt init: %v", err)
|
|
}
|
|
dec, err := io.ReadAll(r)
|
|
if err != nil {
|
|
t.Fatalf("read: %v", err)
|
|
}
|
|
|
|
if got := string(dec); got != testMessageHelloWorld {
|
|
t.Fatalf("round-trip mismatch: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestDeriveIdentityFromXPRV(t *testing.T) {
|
|
id, err := DeriveIdentityFromXPRV(testXPRV, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive from xprv: %v", err)
|
|
}
|
|
|
|
t.Logf("xprv secret: %s", id.String())
|
|
t.Logf("xprv recipient: %s", id.Recipient().String())
|
|
|
|
// Test encryption/decryption with xprv-derived identity
|
|
var ct bytes.Buffer
|
|
w, err := age.Encrypt(&ct, id.Recipient())
|
|
if err != nil {
|
|
t.Fatalf("encrypt init: %v", err)
|
|
}
|
|
if _, err = io.WriteString(w, testMessageHelloFromXPRV); err != nil {
|
|
t.Fatalf("write: %v", err)
|
|
}
|
|
if err = w.Close(); err != nil {
|
|
t.Fatalf("encrypt close: %v", err)
|
|
}
|
|
|
|
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), id)
|
|
if err != nil {
|
|
t.Fatalf("decrypt init: %v", err)
|
|
}
|
|
dec, err := io.ReadAll(r)
|
|
if err != nil {
|
|
t.Fatalf("read: %v", err)
|
|
}
|
|
|
|
if got := string(dec); got != testMessageHelloFromXPRV {
|
|
t.Fatalf("round-trip mismatch: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestDeterministicDerivation(t *testing.T) {
|
|
// Test that the same mnemonic and index always produce the same identity
|
|
id1, err := DeriveIdentity(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive 1: %v", err)
|
|
}
|
|
|
|
id2, err := DeriveIdentity(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive 2: %v", err)
|
|
}
|
|
|
|
if id1.String() != id2.String() {
|
|
t.Fatalf(
|
|
"identities should be deterministic: %s != %s",
|
|
id1.String(),
|
|
id2.String(),
|
|
)
|
|
}
|
|
|
|
// Test that different indices produce different identities
|
|
id3, err := DeriveIdentity(mnemonic, 1)
|
|
if err != nil {
|
|
t.Fatalf("derive 3: %v", err)
|
|
}
|
|
|
|
if id1.String() == id3.String() {
|
|
t.Fatalf("different indices should produce different identities")
|
|
}
|
|
|
|
t.Logf("Index 0: %s", id1.String())
|
|
t.Logf("Index 1: %s", id3.String())
|
|
}
|
|
|
|
func TestDeterministicXPRVDerivation(t *testing.T) {
|
|
// Test that the same xprv and index always produce the same identity
|
|
id1, err := DeriveIdentityFromXPRV(testXPRV, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive 1: %v", err)
|
|
}
|
|
|
|
id2, err := DeriveIdentityFromXPRV(testXPRV, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive 2: %v", err)
|
|
}
|
|
|
|
if id1.String() != id2.String() {
|
|
t.Fatalf(
|
|
"xprv identities should be deterministic: %s != %s",
|
|
id1.String(),
|
|
id2.String(),
|
|
)
|
|
}
|
|
|
|
// Test that different indices with same xprv produce different identities
|
|
id3, err := DeriveIdentityFromXPRV(testXPRV, 1)
|
|
if err != nil {
|
|
t.Fatalf("derive 3: %v", err)
|
|
}
|
|
|
|
if id1.String() == id3.String() {
|
|
t.Fatalf("different indices should produce different identities")
|
|
}
|
|
|
|
t.Logf("XPRV Index 0: %s", id1.String())
|
|
t.Logf("XPRV Index 1: %s", id3.String())
|
|
}
|
|
|
|
func TestMnemonicVsXPRVConsistency(t *testing.T) {
|
|
// FIXME This test is missing!
|
|
}
|
|
|
|
func TestEntropyLength(t *testing.T) {
|
|
// Test that DeriveEntropy returns exactly 32 bytes
|
|
entropy, err := DeriveEntropy(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive entropy: %v", err)
|
|
}
|
|
|
|
if len(entropy) != 32 {
|
|
t.Fatalf("expected 32 bytes of entropy, got %d", len(entropy))
|
|
}
|
|
|
|
t.Logf("Entropy (32 bytes): %x", entropy)
|
|
|
|
// Test that DeriveEntropyFromXPRV returns exactly 32 bytes
|
|
entropyXPRV, err := DeriveEntropyFromXPRV(testXPRV, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive entropy from xprv: %v", err)
|
|
}
|
|
|
|
if len(entropyXPRV) != 32 {
|
|
t.Fatalf(
|
|
"expected 32 bytes of entropy from xprv, got %d",
|
|
len(entropyXPRV),
|
|
)
|
|
}
|
|
|
|
t.Logf("XPRV Entropy (32 bytes): %x", entropyXPRV)
|
|
|
|
// Note: We don't compare the entropy values since the test mnemonic and test xprv
|
|
// are from different sources and should produce different entropy values.
|
|
}
|
|
|
|
func TestIdentityFromEntropy(t *testing.T) {
|
|
// Test that IdentityFromEntropy works with custom entropy
|
|
entropy := make([]byte, 32)
|
|
for i := range entropy {
|
|
entropy[i] = byte(i)
|
|
}
|
|
|
|
id, err := IdentityFromEntropy(entropy)
|
|
if err != nil {
|
|
t.Fatalf("identity from entropy: %v", err)
|
|
}
|
|
|
|
t.Logf("Custom entropy identity: %s", id.String())
|
|
|
|
// Test that it rejects wrong-sized entropy
|
|
_, err = IdentityFromEntropy(entropy[:31])
|
|
if err == nil {
|
|
t.Fatalf("expected error for 31-byte entropy")
|
|
}
|
|
|
|
// Create a 33-byte slice to test rejection
|
|
entropy33 := make([]byte, 33)
|
|
copy(entropy33, entropy)
|
|
_, err = IdentityFromEntropy(entropy33)
|
|
if err == nil {
|
|
t.Fatalf("expected error for 33-byte entropy")
|
|
}
|
|
}
|
|
|
|
func TestInvalidXPRV(t *testing.T) {
|
|
// Test with invalid xprv
|
|
_, err := DeriveIdentityFromXPRV(errorMsgInvalidXPRV, 0)
|
|
if err == nil {
|
|
t.Fatalf("expected error for invalid xprv")
|
|
}
|
|
|
|
t.Logf("Got expected error for invalid xprv: %v", err)
|
|
}
|
|
|
|
// TestClampFunction tests the RFC-7748 clamping function
|
|
func TestClampFunction(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
expected []byte
|
|
}{
|
|
{
|
|
name: "all zeros",
|
|
input: make([]byte, 32),
|
|
expected: []byte{
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
0,
|
|
64,
|
|
},
|
|
},
|
|
{
|
|
name: "all ones",
|
|
input: bytes.Repeat([]byte{255}, 32),
|
|
expected: append(
|
|
[]byte{248},
|
|
append(bytes.Repeat([]byte{255}, 30), 127)...),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
input := make([]byte, 32)
|
|
copy(input, tt.input)
|
|
clamp(input)
|
|
|
|
// Check specific bits that should be clamped
|
|
if input[0]&7 != 0 {
|
|
t.Errorf(
|
|
"first byte should have bottom 3 bits cleared, got %08b",
|
|
input[0],
|
|
)
|
|
}
|
|
if input[31]&128 != 0 {
|
|
t.Errorf(
|
|
"last byte should have top bit cleared, got %08b",
|
|
input[31],
|
|
)
|
|
}
|
|
if input[31]&64 == 0 {
|
|
t.Errorf(
|
|
"last byte should have second-to-top bit set, got %08b",
|
|
input[31],
|
|
)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestIdentityFromEntropyEdgeCases tests edge cases for IdentityFromEntropy
|
|
func TestIdentityFromEntropyEdgeCases(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
entropy []byte
|
|
expectError bool
|
|
errorMsg string
|
|
}{
|
|
{
|
|
name: "nil entropy",
|
|
entropy: nil,
|
|
expectError: true,
|
|
errorMsg: errorMsgNeed32Bytes + " 0",
|
|
},
|
|
{
|
|
name: "empty entropy",
|
|
entropy: []byte{},
|
|
expectError: true,
|
|
errorMsg: errorMsgNeed32Bytes + " 0",
|
|
},
|
|
{
|
|
name: "too short entropy",
|
|
entropy: make([]byte, 31),
|
|
expectError: true,
|
|
errorMsg: errorMsgNeed32Bytes + " 31",
|
|
},
|
|
{
|
|
name: "too long entropy",
|
|
entropy: make([]byte, 33),
|
|
expectError: true,
|
|
errorMsg: errorMsgNeed32Bytes + " 33",
|
|
},
|
|
{
|
|
name: "valid 32-byte entropy",
|
|
entropy: make([]byte, 32),
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "random valid entropy",
|
|
entropy: func() []byte {
|
|
b := make([]byte, 32)
|
|
if _, err := rand.Read(b); err != nil {
|
|
panic(
|
|
err,
|
|
) // In test context, panic is acceptable for setup failures
|
|
}
|
|
return b
|
|
}(),
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
identity, err := IdentityFromEntropy(tt.entropy)
|
|
|
|
if tt.expectError {
|
|
if err == nil {
|
|
t.Errorf("expected error but got none")
|
|
} else if !strings.Contains(err.Error(), tt.errorMsg) {
|
|
t.Errorf("expected error containing %q, got %q", tt.errorMsg, err.Error())
|
|
}
|
|
if identity != nil {
|
|
t.Errorf(
|
|
"expected nil identity on error, got %v",
|
|
identity,
|
|
)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
if identity == nil {
|
|
t.Errorf("expected valid identity, got nil")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDeriveEntropyInvalidMnemonic tests error handling for invalid mnemonics
|
|
func TestDeriveEntropyInvalidMnemonic(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
mnemonic string
|
|
}{
|
|
{
|
|
name: "empty mnemonic",
|
|
mnemonic: "",
|
|
},
|
|
{
|
|
name: "single word",
|
|
mnemonic: "abandon",
|
|
},
|
|
{
|
|
name: "invalid word",
|
|
mnemonic: "invalid word sequence that does not exist in bip39",
|
|
},
|
|
{
|
|
name: "wrong word count",
|
|
mnemonic: "abandon abandon abandon abandon abandon",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Note: BIP39 library is quite permissive and doesn't validate
|
|
// mnemonic words strictly, so we mainly test that the function
|
|
// doesn't panic and produces some result
|
|
entropy, err := DeriveEntropy(tt.mnemonic, 0)
|
|
if err != nil {
|
|
t.Logf("Got error for invalid mnemonic %q: %v", tt.name, err)
|
|
} else {
|
|
if len(entropy) != 32 {
|
|
t.Errorf("expected 32 bytes even for invalid mnemonic, got %d", len(entropy))
|
|
}
|
|
t.Logf("Invalid mnemonic %q produced entropy: %x", tt.name, entropy)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDeriveEntropyFromXPRVInvalidInputs tests error handling for invalid XPRVs
|
|
func TestDeriveEntropyFromXPRVInvalidInputs(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
xprv string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "empty xprv",
|
|
xprv: "",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid base58",
|
|
xprv: "invalid-base58-string-!@#$%",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "wrong prefix",
|
|
xprv: "xpub661MyMwAqRbcFtXgS5sYJABqqG9YLmC4Q1Rdap9gSE8NqtwybGhePY2gZ29ESFjqJoCu1Rupje8YtGqsefD265TMg7usUDFdp6W1EGMcet8",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "truncated xprv",
|
|
xprv: "xprv9s21ZrQH143K2LBWUUQRFXhucrQqBpKdRRxNVq2zBqsx8HVqFk2uYo8kmbaLLHRdqtQpUm98uKfu3vca1LqdGhUtyoFnCNkfmXRyPXLj",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "valid xprv",
|
|
xprv: testXPRV,
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
entropy, err := DeriveEntropyFromXPRV(tt.xprv, 0)
|
|
|
|
if tt.expectError {
|
|
if err == nil {
|
|
t.Errorf("expected error for invalid xprv %q", tt.name)
|
|
} else {
|
|
t.Logf("Got expected error for %q: %v", tt.name, err)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("unexpected error for valid xprv: %v", err)
|
|
}
|
|
if len(entropy) != 32 {
|
|
t.Errorf("expected 32 bytes of entropy, got %d", len(entropy))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDifferentMnemonicLengths tests derivation with different mnemonic lengths
|
|
func TestDifferentMnemonicLengths(t *testing.T) {
|
|
mnemonics := map[string]string{
|
|
"12 words": testMnemonic12,
|
|
"15 words": testMnemonic15,
|
|
"18 words": testMnemonic18,
|
|
"21 words": testMnemonic21,
|
|
"24 words": testMnemonic24,
|
|
}
|
|
|
|
for name, mnemonic := range mnemonics {
|
|
t.Run(name, func(t *testing.T) {
|
|
identity, err := DeriveIdentity(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("failed to derive identity from %s: %v", name, err)
|
|
}
|
|
|
|
// Test that we can encrypt/decrypt
|
|
var ct bytes.Buffer
|
|
w, err := age.Encrypt(&ct, identity.Recipient())
|
|
if err != nil {
|
|
t.Fatalf("encrypt init: %v", err)
|
|
}
|
|
if _, err = io.WriteString(w, testMessageGeneric); err != nil {
|
|
t.Fatalf("write: %v", err)
|
|
}
|
|
if err = w.Close(); err != nil {
|
|
t.Fatalf("encrypt close: %v", err)
|
|
}
|
|
|
|
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), identity)
|
|
if err != nil {
|
|
t.Fatalf("decrypt init: %v", err)
|
|
}
|
|
dec, err := io.ReadAll(r)
|
|
if err != nil {
|
|
t.Fatalf("read: %v", err)
|
|
}
|
|
|
|
if string(dec) != testMessageGeneric {
|
|
t.Fatalf("round-trip failed for %s", name)
|
|
}
|
|
|
|
t.Logf("%s identity: %s", name, identity.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestIndexBoundaries tests derivation with various index values
|
|
func TestIndexBoundaries(t *testing.T) {
|
|
indices := []uint32{
|
|
0, // minimum
|
|
1, // basic
|
|
100, // moderate
|
|
1000, // larger
|
|
0x7FFFFFFF, // maximum hardened index
|
|
0xFFFFFFFF, // maximum uint32
|
|
}
|
|
|
|
for _, index := range indices {
|
|
t.Run(fmt.Sprintf("index_%d", index), func(t *testing.T) {
|
|
identity, err := DeriveIdentity(mnemonic, index)
|
|
if err != nil {
|
|
t.Fatalf(
|
|
"failed to derive identity at index %d: %v",
|
|
index,
|
|
err,
|
|
)
|
|
}
|
|
|
|
// Verify the identity is valid by testing encryption/decryption
|
|
var ct bytes.Buffer
|
|
w, err := age.Encrypt(&ct, identity.Recipient())
|
|
if err != nil {
|
|
t.Fatalf("encrypt init at index %d: %v", index, err)
|
|
}
|
|
if _, err = io.WriteString(w, testMessageBoundary); err != nil {
|
|
t.Fatalf("write at index %d: %v", index, err)
|
|
}
|
|
if err = w.Close(); err != nil {
|
|
t.Fatalf("encrypt close at index %d: %v", index, err)
|
|
}
|
|
|
|
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), identity)
|
|
if err != nil {
|
|
t.Fatalf("decrypt init at index %d: %v", index, err)
|
|
}
|
|
dec, err := io.ReadAll(r)
|
|
if err != nil {
|
|
t.Fatalf("read at index %d: %v", index, err)
|
|
}
|
|
|
|
if string(dec) != testMessageBoundary {
|
|
t.Fatalf("round-trip failed at index %d", index)
|
|
}
|
|
|
|
t.Logf("Index %d identity: %s", index, identity.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestEntropyUniqueness tests that different inputs produce different entropy
|
|
func TestEntropyUniqueness(t *testing.T) {
|
|
// Test different indices with same mnemonic
|
|
entropy1, err := DeriveEntropy(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive entropy 1: %v", err)
|
|
}
|
|
|
|
entropy2, err := DeriveEntropy(mnemonic, 1)
|
|
if err != nil {
|
|
t.Fatalf("derive entropy 2: %v", err)
|
|
}
|
|
|
|
if bytes.Equal(entropy1, entropy2) {
|
|
t.Fatalf("different indices should produce different entropy")
|
|
}
|
|
|
|
// Test different mnemonics with same index
|
|
entropy3, err := DeriveEntropy(testMnemonic24, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive entropy 3: %v", err)
|
|
}
|
|
|
|
if bytes.Equal(entropy1, entropy3) {
|
|
t.Fatalf("different mnemonics should produce different entropy")
|
|
}
|
|
|
|
t.Logf("Entropy uniqueness verified across indices and mnemonics")
|
|
}
|
|
|
|
// TestConcurrentDerivation tests that derivation is safe for concurrent use
|
|
func TestConcurrentDerivation(t *testing.T) {
|
|
results := make(chan string, testNumGoroutines*testNumIterations)
|
|
errors := make(chan error, testNumGoroutines*testNumIterations)
|
|
|
|
for i := 0; i < testNumGoroutines; i++ {
|
|
go func(goroutineID int) {
|
|
for j := 0; j < testNumIterations; j++ {
|
|
if j < 0 || j > 1000000 {
|
|
errors <- fmt.Errorf("index out of safe range")
|
|
return
|
|
}
|
|
identity, err := DeriveIdentity(mnemonic, uint32(j))
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
results <- identity.String()
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Collect results
|
|
resultMap := make(map[string]int)
|
|
for i := 0; i < testNumGoroutines*testNumIterations; i++ {
|
|
select {
|
|
case result := <-results:
|
|
resultMap[result]++
|
|
case err := <-errors:
|
|
t.Fatalf("concurrent derivation error: %v", err)
|
|
}
|
|
}
|
|
|
|
// Verify that each index produced the same result across all goroutines
|
|
expectedResults := testNumGoroutines
|
|
for result, count := range resultMap {
|
|
if count != expectedResults {
|
|
t.Errorf(
|
|
"result %s appeared %d times, expected %d",
|
|
result,
|
|
count,
|
|
expectedResults,
|
|
)
|
|
}
|
|
}
|
|
|
|
t.Logf(
|
|
"Concurrent derivation test passed with %d unique results",
|
|
len(resultMap),
|
|
)
|
|
}
|
|
|
|
// Benchmark tests
|
|
func BenchmarkDeriveIdentity(b *testing.B) {
|
|
for i := 0; i < b.N; i++ {
|
|
index := i % 1000
|
|
if index < 0 || index > 1000000 {
|
|
b.Fatalf("index out of safe range: %d", index)
|
|
}
|
|
_, err := DeriveIdentity(mnemonic, uint32(index))
|
|
if err != nil {
|
|
b.Fatalf("derive identity: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkDeriveIdentityFromXPRV(b *testing.B) {
|
|
for i := 0; i < b.N; i++ {
|
|
index := i % 1000
|
|
if index < 0 || index > 1000000 {
|
|
b.Fatalf("index out of safe range: %d", index)
|
|
}
|
|
_, err := DeriveIdentityFromXPRV(testXPRV, uint32(index))
|
|
if err != nil {
|
|
b.Fatalf("derive identity from xprv: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkDeriveEntropy(b *testing.B) {
|
|
for i := 0; i < b.N; i++ {
|
|
index := i % 1000
|
|
if index < 0 || index > 1000000 {
|
|
b.Fatalf("index out of safe range: %d", index)
|
|
}
|
|
_, err := DeriveEntropy(mnemonic, uint32(index))
|
|
if err != nil {
|
|
b.Fatalf("derive entropy: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkIdentityFromEntropy(b *testing.B) {
|
|
entropy := make([]byte, 32)
|
|
if _, err := rand.Read(entropy); err != nil {
|
|
b.Fatalf("failed to generate random entropy: %v", err)
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
_, err := IdentityFromEntropy(entropy)
|
|
if err != nil {
|
|
b.Fatalf("identity from entropy: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkEncryptDecrypt(b *testing.B) {
|
|
identity, err := DeriveIdentity(mnemonic, 0)
|
|
if err != nil {
|
|
b.Fatalf("derive identity: %v", err)
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
var ct bytes.Buffer
|
|
w, err := age.Encrypt(&ct, identity.Recipient())
|
|
if err != nil {
|
|
b.Fatalf("encrypt init: %v", err)
|
|
}
|
|
if _, err = io.WriteString(w, testMessageBenchmark); err != nil {
|
|
b.Fatalf("write: %v", err)
|
|
}
|
|
if err = w.Close(); err != nil {
|
|
b.Fatalf("encrypt close: %v", err)
|
|
}
|
|
|
|
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), identity)
|
|
if err != nil {
|
|
b.Fatalf("decrypt init: %v", err)
|
|
}
|
|
_, err = io.ReadAll(r)
|
|
if err != nil {
|
|
b.Fatalf("read: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestConstants verifies the hardcoded constants
|
|
func TestConstants(t *testing.T) {
|
|
if purpose != 83696968 {
|
|
t.Errorf(
|
|
"purpose constant mismatch: expected 83696968, got %d",
|
|
purpose,
|
|
)
|
|
}
|
|
if vendorID != 592366788 {
|
|
t.Errorf(
|
|
"vendorID constant mismatch: expected 592366788, got %d",
|
|
vendorID,
|
|
)
|
|
}
|
|
if appID != 733482323 {
|
|
t.Errorf(
|
|
"appID constant mismatch: expected 733482323, got %d",
|
|
appID,
|
|
)
|
|
}
|
|
if hrp != "age-secret-key-" {
|
|
t.Errorf(
|
|
"hrp constant mismatch: expected 'age-secret-key-', got %q",
|
|
hrp,
|
|
)
|
|
}
|
|
}
|
|
|
|
// TestIdentityStringFormat tests that generated identities have the correct format
|
|
func TestIdentityStringFormat(t *testing.T) {
|
|
identity, err := DeriveIdentity(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive identity: %v", err)
|
|
}
|
|
|
|
secretKey := identity.String()
|
|
recipient := identity.Recipient().String()
|
|
|
|
// Check secret key format
|
|
if !strings.HasPrefix(secretKey, "AGE-SECRET-KEY-") {
|
|
t.Errorf(
|
|
"secret key should start with 'AGE-SECRET-KEY-', got: %s",
|
|
secretKey,
|
|
)
|
|
}
|
|
|
|
// Check recipient format
|
|
if !strings.HasPrefix(recipient, "age1") {
|
|
t.Errorf("recipient should start with 'age1', got: %s", recipient)
|
|
}
|
|
|
|
// Check that they're different
|
|
if secretKey == recipient {
|
|
t.Errorf("secret key and recipient should be different")
|
|
}
|
|
|
|
t.Logf("Secret key format: %s", secretKey)
|
|
t.Logf("Recipient format: %s", recipient)
|
|
}
|
|
|
|
// TestLargeMessageEncryption tests encryption/decryption of larger messages
|
|
func TestLargeMessageEncryption(t *testing.T) {
|
|
identity, err := DeriveIdentity(mnemonic, 0)
|
|
if err != nil {
|
|
t.Fatalf("derive identity: %v", err)
|
|
}
|
|
|
|
// Test with different message sizes
|
|
sizes := []int{1, 100, 1024, 10240, 100000}
|
|
|
|
for _, size := range sizes {
|
|
t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
|
|
message := strings.Repeat(testMessageLargePattern, size)
|
|
|
|
var ct bytes.Buffer
|
|
w, err := age.Encrypt(&ct, identity.Recipient())
|
|
if err != nil {
|
|
t.Fatalf("encrypt init: %v", err)
|
|
}
|
|
if _, err = io.WriteString(w, message); err != nil {
|
|
t.Fatalf("write: %v", err)
|
|
}
|
|
if err = w.Close(); err != nil {
|
|
t.Fatalf("encrypt close: %v", err)
|
|
}
|
|
|
|
r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), identity)
|
|
if err != nil {
|
|
t.Fatalf("decrypt init: %v", err)
|
|
}
|
|
dec, err := io.ReadAll(r)
|
|
if err != nil {
|
|
t.Fatalf("read: %v", err)
|
|
}
|
|
|
|
if string(dec) != message {
|
|
t.Fatalf("message size %d: round-trip failed", size)
|
|
}
|
|
|
|
t.Logf("Successfully encrypted/decrypted %d byte message", size)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRandomMnemonicDeterministicGeneration tests that:
|
|
// 1. A random mnemonic generates the same keys deterministically
|
|
// 2. Large data (1MB) can be encrypted and decrypted successfully
|
|
func TestRandomMnemonicDeterministicGeneration(t *testing.T) {
|
|
// Generate a random mnemonic using the BIP39 library
|
|
entropy := make([]byte, 32) // 256 bits for 24-word mnemonic
|
|
if _, err := rand.Read(entropy); err != nil {
|
|
t.Fatalf("failed to generate random entropy: %v", err)
|
|
}
|
|
|
|
randomMnemonic, err := bip39.NewMnemonic(entropy)
|
|
if err != nil {
|
|
t.Fatalf("failed to generate random mnemonic: %v", err)
|
|
}
|
|
|
|
t.Logf("Generated random mnemonic: %s", randomMnemonic)
|
|
|
|
// Test index for key derivation
|
|
testIndex := uint32(42)
|
|
|
|
// Generate the first identity
|
|
identity1, err := DeriveIdentity(randomMnemonic, testIndex)
|
|
if err != nil {
|
|
t.Fatalf("failed to derive first identity: %v", err)
|
|
}
|
|
|
|
// Generate the second identity with the same mnemonic and index
|
|
identity2, err := DeriveIdentity(randomMnemonic, testIndex)
|
|
if err != nil {
|
|
t.Fatalf("failed to derive second identity: %v", err)
|
|
}
|
|
|
|
// Verify that both private keys are identical
|
|
privateKey1 := identity1.String()
|
|
privateKey2 := identity2.String()
|
|
if privateKey1 != privateKey2 {
|
|
t.Fatalf(
|
|
"private keys should be identical:\nFirst: %s\nSecond: %s",
|
|
privateKey1,
|
|
privateKey2,
|
|
)
|
|
}
|
|
|
|
// Verify that both public keys (recipients) are identical
|
|
publicKey1 := identity1.Recipient().String()
|
|
publicKey2 := identity2.Recipient().String()
|
|
if publicKey1 != publicKey2 {
|
|
t.Fatalf(
|
|
"public keys should be identical:\nFirst: %s\nSecond: %s",
|
|
publicKey1,
|
|
publicKey2,
|
|
)
|
|
}
|
|
|
|
t.Logf("✓ Deterministic generation verified")
|
|
t.Logf("Private key: %s", privateKey1)
|
|
t.Logf("Public key: %s", publicKey1)
|
|
|
|
// Generate 1 MB of random data for encryption test
|
|
testData := make([]byte, testDataSizeMegabyte)
|
|
if _, err := rand.Read(testData); err != nil {
|
|
t.Fatalf("failed to generate random test data: %v", err)
|
|
}
|
|
|
|
t.Logf("Generated %d bytes of random test data", len(testData))
|
|
|
|
// Encrypt the data using the public key (recipient)
|
|
var ciphertext bytes.Buffer
|
|
encryptor, err := age.Encrypt(&ciphertext, identity1.Recipient())
|
|
if err != nil {
|
|
t.Fatalf("failed to create encryptor: %v", err)
|
|
}
|
|
|
|
_, err = encryptor.Write(testData)
|
|
if err != nil {
|
|
t.Fatalf("failed to write data to encryptor: %v", err)
|
|
}
|
|
|
|
err = encryptor.Close()
|
|
if err != nil {
|
|
t.Fatalf("failed to close encryptor: %v", err)
|
|
}
|
|
|
|
t.Logf(
|
|
"✓ Encrypted %d bytes into %d bytes of ciphertext",
|
|
len(testData),
|
|
ciphertext.Len(),
|
|
)
|
|
|
|
// Decrypt the data using the private key
|
|
decryptor, err := age.Decrypt(
|
|
bytes.NewReader(ciphertext.Bytes()),
|
|
identity1,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("failed to create decryptor: %v", err)
|
|
}
|
|
|
|
decryptedData, err := io.ReadAll(decryptor)
|
|
if err != nil {
|
|
t.Fatalf("failed to read decrypted data: %v", err)
|
|
}
|
|
|
|
t.Logf("✓ Decrypted %d bytes", len(decryptedData))
|
|
|
|
// Verify that the decrypted data matches the original
|
|
if len(decryptedData) != len(testData) {
|
|
t.Fatalf(
|
|
"decrypted data length mismatch: expected %d, got %d",
|
|
len(testData),
|
|
len(decryptedData),
|
|
)
|
|
}
|
|
|
|
if !bytes.Equal(testData, decryptedData) {
|
|
t.Fatalf("decrypted data does not match original data")
|
|
}
|
|
|
|
t.Logf("✓ Large data encryption/decryption test passed successfully")
|
|
|
|
// Additional verification: test with the second identity (should work identically)
|
|
var ciphertext2 bytes.Buffer
|
|
encryptor2, err := age.Encrypt(&ciphertext2, identity2.Recipient())
|
|
if err != nil {
|
|
t.Fatalf("failed to create second encryptor: %v", err)
|
|
}
|
|
|
|
_, err = encryptor2.Write(testData)
|
|
if err != nil {
|
|
t.Fatalf("failed to write data to second encryptor: %v", err)
|
|
}
|
|
|
|
err = encryptor2.Close()
|
|
if err != nil {
|
|
t.Fatalf("failed to close second encryptor: %v", err)
|
|
}
|
|
|
|
// Decrypt with the second identity
|
|
decryptor2, err := age.Decrypt(
|
|
bytes.NewReader(ciphertext2.Bytes()),
|
|
identity2,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("failed to create second decryptor: %v", err)
|
|
}
|
|
|
|
decryptedData2, err := io.ReadAll(decryptor2)
|
|
if err != nil {
|
|
t.Fatalf("failed to read second decrypted data: %v", err)
|
|
}
|
|
|
|
if !bytes.Equal(testData, decryptedData2) {
|
|
t.Fatalf("second decrypted data does not match original data")
|
|
}
|
|
|
|
t.Logf("✓ Cross-verification with second identity successful")
|
|
}
|