package agehd

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
	"strings"
	"testing"

	"filippo.io/age"
	"github.com/tyler-smith/go-bip39"
)

const (
	// mnemonic is the default mnemonic used by most tests: the 12-word BIP39
	// phrase for all-zero 128-bit entropy. It is defined in terms of
	// testMnemonic12 so the two literals can never drift apart.
	mnemonic = testMnemonic12

	// testXPRV is the master extended private key from the BIP85 test vectors.
	// It is not the root key of mnemonic, so results derived from the two are
	// not expected to match.
	testXPRV = "xprv9s21ZrQH143K2LBWUUQRFXhucrQqBpKdRRxNVq2zBqsx8HVqFk2uYo8kmbaLLHRdqtQpUm98uKfu3vca1LqdGhUtyoFnCNkfmXRyPXLjbKb"

	// Mnemonics of every standard BIP39 length. testMnemonic12,
	// testMnemonic18 and testMnemonic24 are the all-zero-entropy phrases from
	// the BIP39 reference test vectors; their final (checksum-bearing) words
	// are "about", "agent" and "art" respectively.
	// NOTE(review): testMnemonic15 and testMnemonic21 end in "about" without a
	// computed checksum and are probably not valid BIP39 phrases; check them
	// with bip39.IsMnemonicValid before relying on their validity.
	testMnemonic12 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
	testMnemonic15 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
	testMnemonic18 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon agent"
	testMnemonic21 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
	testMnemonic24 = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon art"

	// Plaintexts used by the encrypt/decrypt round-trip tests.
	testMessageHelloWorld    = "hello world"
	testMessageHelloFromXPRV = "hello from xprv"
	testMessageGeneric       = "test message"
	testMessageBoundary      = "boundary test"
	testMessageBenchmark     = "benchmark test message"
	// testMessageLargePattern is repeated to build messages of a given size.
	testMessageLargePattern = "A"

	// errorMsgNeed32Bytes is the prefix expected in IdentityFromEntropy's
	// error when the input is not exactly 32 bytes.
	errorMsgNeed32Bytes = "need 32-byte scalar, got"
	// errorMsgInvalidXPRV is, despite its name, a malformed xprv *input*
	// (not an error message) used to provoke a parse failure.
	errorMsgInvalidXPRV = "invalid-xprv"

	// testSkipMessage explains why TestMnemonicVsXPRVConsistency is skipped.
	testSkipMessage = "Skipping consistency test - test mnemonic and xprv are from different sources"

	// Sizing for TestConcurrentDerivation: each goroutine derives indices
	// 0..testNumIterations-1.
	testNumGoroutines = 10
	testNumIterations = 100

	// testDataSizeMegabyte is the payload size for the large-data round trip.
	testDataSizeMegabyte = 1024 * 1024 // 1 MB
)

// TestEncryptDecrypt derives the index-0 identity from the shared test
// mnemonic and checks that a short message survives an age encrypt/decrypt
// round trip addressed to that identity's recipient.
func TestEncryptDecrypt(t *testing.T) {
	identity, derr := DeriveIdentity(mnemonic, 0)
	if derr != nil {
		t.Fatalf("derive: %v", derr)
	}
	recipient := identity.Recipient()

	t.Logf("secret:    %s", identity.String())
	t.Logf("recipient: %s", recipient.String())

	sealed := new(bytes.Buffer)
	enc, err := age.Encrypt(sealed, recipient)
	if err != nil {
		t.Fatalf("encrypt init: %v", err)
	}
	if _, err := io.WriteString(enc, testMessageHelloWorld); err != nil {
		t.Fatalf("write: %v", err)
	}
	if err := enc.Close(); err != nil {
		t.Fatalf("encrypt close: %v", err)
	}

	opener, err := age.Decrypt(bytes.NewReader(sealed.Bytes()), identity)
	if err != nil {
		t.Fatalf("decrypt init: %v", err)
	}
	plain, err := io.ReadAll(opener)
	if err != nil {
		t.Fatalf("read: %v", err)
	}

	if got := string(plain); got != testMessageHelloWorld {
		t.Fatalf("round-trip mismatch: %q", got)
	}
}

// TestDeriveIdentityFromXPRV derives the index-0 identity from the BIP85 test
// xprv and checks that it can decrypt a message encrypted to its own
// recipient.
func TestDeriveIdentityFromXPRV(t *testing.T) {
	priv, err := DeriveIdentityFromXPRV(testXPRV, 0)
	if err != nil {
		t.Fatalf("derive from xprv: %v", err)
	}
	pub := priv.Recipient()

	t.Logf("xprv secret:    %s", priv.String())
	t.Logf("xprv recipient: %s", pub.String())

	// Encrypt to the xprv-derived recipient...
	ciphertext := &bytes.Buffer{}
	writer, err := age.Encrypt(ciphertext, pub)
	if err != nil {
		t.Fatalf("encrypt init: %v", err)
	}
	if _, werr := io.WriteString(writer, testMessageHelloFromXPRV); werr != nil {
		t.Fatalf("write: %v", werr)
	}
	if cerr := writer.Close(); cerr != nil {
		t.Fatalf("encrypt close: %v", cerr)
	}

	// ...and decrypt with the matching identity.
	reader, err := age.Decrypt(bytes.NewReader(ciphertext.Bytes()), priv)
	if err != nil {
		t.Fatalf("decrypt init: %v", err)
	}
	opened, err := io.ReadAll(reader)
	if err != nil {
		t.Fatalf("read: %v", err)
	}

	if got := string(opened); got != testMessageHelloFromXPRV {
		t.Fatalf("round-trip mismatch: %q", got)
	}
}

// TestDeterministicDerivation checks that DeriveIdentity depends only on its
// inputs: repeating a call with the same mnemonic and index yields the same
// secret key, while a different index yields a different one.
func TestDeterministicDerivation(t *testing.T) {
	first, err := DeriveIdentity(mnemonic, 0)
	if err != nil {
		t.Fatalf("derive 1: %v", err)
	}
	again, err := DeriveIdentity(mnemonic, 0)
	if err != nil {
		t.Fatalf("derive 2: %v", err)
	}

	firstKey, againKey := first.String(), again.String()
	if firstKey != againKey {
		t.Fatalf("identities should be deterministic: %s != %s", firstKey, againKey)
	}

	next, err := DeriveIdentity(mnemonic, 1)
	if err != nil {
		t.Fatalf("derive 3: %v", err)
	}
	nextKey := next.String()
	if nextKey == firstKey {
		t.Fatalf("different indices should produce different identities")
	}

	t.Logf("Index 0: %s", firstKey)
	t.Logf("Index 1: %s", nextKey)
}

// TestDeterministicXPRVDerivation is the xprv counterpart of
// TestDeterministicDerivation: the same xprv and index must always yield the
// same identity, and a different index must yield a different one.
func TestDeterministicXPRVDerivation(t *testing.T) {
	a, err := DeriveIdentityFromXPRV(testXPRV, 0)
	if err != nil {
		t.Fatalf("derive 1: %v", err)
	}
	b, err := DeriveIdentityFromXPRV(testXPRV, 0)
	if err != nil {
		t.Fatalf("derive 2: %v", err)
	}

	keyA, keyB := a.String(), b.String()
	if keyA != keyB {
		t.Fatalf("xprv identities should be deterministic: %s != %s", keyA, keyB)
	}

	c, err := DeriveIdentityFromXPRV(testXPRV, 1)
	if err != nil {
		t.Fatalf("derive 3: %v", err)
	}
	keyC := c.String()
	if keyC == keyA {
		t.Fatalf("different indices should produce different identities")
	}

	t.Logf("XPRV Index 0: %s", keyA)
	t.Logf("XPRV Index 1: %s", keyC)
}

// TestMnemonicVsXPRVConsistency is a placeholder for checking that deriving
// from a mnemonic and from that mnemonic's root xprv agree. It is always
// skipped, because the package's test mnemonic and test xprv come from
// different sources and are not expected to produce matching results.
func TestMnemonicVsXPRVConsistency(t *testing.T) {
	// TODO(review): re-enable with an xprv that is the root key of `mnemonic`.
	t.Skip(testSkipMessage)
}

// TestEntropyLength checks that both entropy derivation entry points
// (mnemonic-based and xprv-based) return exactly 32 bytes. The two values are
// deliberately not compared: the test mnemonic and the test xprv are
// unrelated, so their entropy is expected to differ.
func TestEntropyLength(t *testing.T) {
	const wantLen = 32

	fromMnemonic, err := DeriveEntropy(mnemonic, 0)
	if err != nil {
		t.Fatalf("derive entropy: %v", err)
	}
	if got := len(fromMnemonic); got != wantLen {
		t.Fatalf("expected 32 bytes of entropy, got %d", got)
	}
	t.Logf("Entropy (32 bytes): %x", fromMnemonic)

	fromXPRV, err := DeriveEntropyFromXPRV(testXPRV, 0)
	if err != nil {
		t.Fatalf("derive entropy from xprv: %v", err)
	}
	if got := len(fromXPRV); got != wantLen {
		t.Fatalf("expected 32 bytes of entropy from xprv, got %d", got)
	}
	t.Logf("XPRV Entropy (32 bytes): %x", fromXPRV)
}

// TestIdentityFromEntropy checks that IdentityFromEntropy accepts a 32-byte
// caller-supplied scalar (here the bytes 0..31) and rejects inputs that are
// one byte too short or one byte too long.
func TestIdentityFromEntropy(t *testing.T) {
	scalar := make([]byte, 32)
	for i := range scalar {
		scalar[i] = byte(i)
	}

	id, err := IdentityFromEntropy(scalar)
	if err != nil {
		t.Fatalf("identity from entropy: %v", err)
	}
	t.Logf("Custom entropy identity: %s", id.String())

	for _, size := range []int{31, 33} {
		var candidate []byte
		if size <= len(scalar) {
			// Truncated view of the valid scalar.
			candidate = scalar[:size]
		} else {
			// The valid scalar followed by zero padding.
			candidate = make([]byte, size)
			copy(candidate, scalar)
		}
		if _, err := IdentityFromEntropy(candidate); err == nil {
			t.Fatalf("expected error for %d-byte entropy", size)
		}
	}
}

// TestInvalidXPRV checks that DeriveIdentityFromXPRV returns an error for a
// string that is not a parseable extended private key.
func TestInvalidXPRV(t *testing.T) {
	if _, err := DeriveIdentityFromXPRV(errorMsgInvalidXPRV, 0); err != nil {
		t.Logf("Got expected error for invalid xprv: %v", err)
		return
	}
	t.Fatalf("expected error for invalid xprv")
}

// TestClampFunction checks clamp against the RFC 7748 scalar clamping rules:
// the low three bits of byte 0 are cleared, the top bit of byte 31 is
// cleared, bit 6 of byte 31 is set, and every other bit is left unchanged.
// Each case pins the exact expected output, not just the clamped bits.
func TestClampFunction(t *testing.T) {
	tests := []struct {
		name     string
		input    []byte
		expected []byte
	}{
		{
			name:     "all zeros",
			input:    make([]byte, 32),
			expected: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64},
		},
		{
			name:     "all ones",
			input:    bytes.Repeat([]byte{255}, 32),
			expected: append([]byte{248}, append(bytes.Repeat([]byte{255}, 30), 127)...),
		},
		{
			// 0xAA = 10101010: byte 0 -> 0xAA&0xF8 = 0xA8; byte 31 ->
			// (0xAA&0x7F)|0x40 = 0x6A; the middle 30 bytes must be untouched.
			name:     "alternating bits",
			input:    bytes.Repeat([]byte{0xAA}, 32),
			expected: append([]byte{0xA8}, append(bytes.Repeat([]byte{0xAA}, 30), 0x6A)...),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// clamp mutates its argument, so work on a copy of the fixture.
			input := make([]byte, 32)
			copy(input, tt.input)
			clamp(input)

			// Check specific bits that should be clamped
			if input[0]&7 != 0 {
				t.Errorf("first byte should have bottom 3 bits cleared, got %08b", input[0])
			}
			if input[31]&128 != 0 {
				t.Errorf("last byte should have top bit cleared, got %08b", input[31])
			}
			if input[31]&64 == 0 {
				t.Errorf("last byte should have second-to-top bit set, got %08b", input[31])
			}

			// Check that no other bits were altered.
			if !bytes.Equal(input, tt.expected) {
				t.Errorf("clamp(%x) = %x, want %x", tt.input, input, tt.expected)
			}
		})
	}
}

// TestIdentityFromEntropyEdgeCases is a table-driven check of
// IdentityFromEntropy's input validation. Any input that is not exactly 32
// bytes must produce an error mentioning the actual length and a nil
// identity. Any 32-byte input, all-zero or random, must produce an identity.
func TestIdentityFromEntropyEdgeCases(t *testing.T) {
	randomEntropy := func() []byte {
		buf := make([]byte, 32)
		if _, err := rand.Read(buf); err != nil {
			panic(err) // In test context, panic is acceptable for setup failures
		}
		return buf
	}

	cases := []struct {
		name    string
		entropy []byte
		wantErr bool
		wantMsg string
	}{
		{name: "nil entropy", entropy: nil, wantErr: true, wantMsg: errorMsgNeed32Bytes + " 0"},
		{name: "empty entropy", entropy: []byte{}, wantErr: true, wantMsg: errorMsgNeed32Bytes + " 0"},
		{name: "too short entropy", entropy: make([]byte, 31), wantErr: true, wantMsg: errorMsgNeed32Bytes + " 31"},
		{name: "too long entropy", entropy: make([]byte, 33), wantErr: true, wantMsg: errorMsgNeed32Bytes + " 33"},
		{name: "valid 32-byte entropy", entropy: make([]byte, 32)},
		{name: "random valid entropy", entropy: randomEntropy()},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := IdentityFromEntropy(tc.entropy)

			if !tc.wantErr {
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
				if got == nil {
					t.Errorf("expected valid identity, got nil")
				}
				return
			}

			switch {
			case err == nil:
				t.Errorf("expected error but got none")
			case !strings.Contains(err.Error(), tc.wantMsg):
				t.Errorf("expected error containing %q, got %q", tc.wantMsg, err.Error())
			}
			if got != nil {
				t.Errorf("expected nil identity on error, got %v", got)
			}
		})
	}
}

// TestDeriveEntropyInvalidMnemonic feeds malformed mnemonics to
// DeriveEntropy. Whether DeriveEntropy rejects them depends on mnemonic
// validation that is not visible in this file, so the test accepts either
// outcome: an error, or exactly 32 bytes of entropy. Neither outcome may
// panic.
func TestDeriveEntropyInvalidMnemonic(t *testing.T) {
	cases := []struct {
		name     string
		mnemonic string
	}{
		{name: "empty mnemonic", mnemonic: ""},
		{name: "single word", mnemonic: "abandon"},
		{name: "invalid word", mnemonic: "invalid word sequence that does not exist in bip39"},
		{name: "wrong word count", mnemonic: "abandon abandon abandon abandon abandon"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			out, err := DeriveEntropy(tc.mnemonic, 0)
			if err != nil {
				t.Logf("Got error for invalid mnemonic %q: %v", tc.name, err)
				return
			}
			if n := len(out); n != 32 {
				t.Errorf("expected 32 bytes even for invalid mnemonic, got %d", n)
			}
			t.Logf("Invalid mnemonic %q produced entropy: %x", tc.name, out)
		})
	}
}

// TestDeriveEntropyFromXPRVInvalidInputs checks DeriveEntropyFromXPRV against
// malformed extended keys (empty, non-base58, an xpub instead of an xprv, and
// a truncated xprv), all of which must fail. It also checks the valid test
// xprv, which must yield 32 bytes.
func TestDeriveEntropyFromXPRVInvalidInputs(t *testing.T) {
	cases := []struct {
		name    string
		xprv    string
		wantErr bool
	}{
		{name: "empty xprv", xprv: "", wantErr: true},
		{name: "invalid base58", xprv: "invalid-base58-string-!@#$%", wantErr: true},
		{name: "wrong prefix", xprv: "xpub661MyMwAqRbcFtXgS5sYJABqqG9YLmC4Q1Rdap9gSE8NqtwybGhePY2gZ29ESFjqJoCu1Rupje8YtGqsefD265TMg7usUDFdp6W1EGMcet8", wantErr: true},
		// testXPRV with its last three characters removed.
		{name: "truncated xprv", xprv: "xprv9s21ZrQH143K2LBWUUQRFXhucrQqBpKdRRxNVq2zBqsx8HVqFk2uYo8kmbaLLHRdqtQpUm98uKfu3vca1LqdGhUtyoFnCNkfmXRyPXLj", wantErr: true},
		{name: "valid xprv", xprv: testXPRV, wantErr: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			out, err := DeriveEntropyFromXPRV(tc.xprv, 0)

			if tc.wantErr {
				if err == nil {
					t.Errorf("expected error for invalid xprv %q", tc.name)
					return
				}
				t.Logf("Got expected error for %q: %v", tc.name, err)
				return
			}

			if err != nil {
				t.Errorf("unexpected error for valid xprv: %v", err)
			}
			if n := len(out); n != 32 {
				t.Errorf("expected 32 bytes of entropy, got %d", n)
			}
		})
	}
}

// TestDifferentMnemonicLengths derives an identity from a mnemonic of every
// standard BIP39 length (12, 15, 18, 21 and 24 words) and checks that each
// identity can round-trip an age-encrypted message.
//
// The phrases are generated from all-zero entropy with bip39.NewMnemonic, so
// every one carries a valid BIP39 checksum. Hand-written "abandon ... about"
// strings are only valid at 12 words. The 12- and 24-word phrases are also
// cross-checked against the known reference phrases. Cases run in a fixed
// order; iterating a map would randomize subtest order between runs.
func TestDifferentMnemonicLengths(t *testing.T) {
	cases := []struct {
		words        int
		entropyBytes int    // BIP39: 4 bytes of entropy per 3 words
		want         string // known reference phrase, if any
	}{
		{words: 12, entropyBytes: 16, want: testMnemonic12},
		{words: 15, entropyBytes: 20},
		{words: 18, entropyBytes: 24},
		{words: 21, entropyBytes: 28},
		{words: 24, entropyBytes: 32, want: testMnemonic24},
	}

	for _, tc := range cases {
		name := fmt.Sprintf("%d words", tc.words)
		t.Run(name, func(t *testing.T) {
			phrase, err := bip39.NewMnemonic(make([]byte, tc.entropyBytes))
			if err != nil {
				t.Fatalf("generate %s mnemonic: %v", name, err)
			}
			if got := len(strings.Fields(phrase)); got != tc.words {
				t.Fatalf("generated mnemonic has %d words, want %d", got, tc.words)
			}
			if tc.want != "" && phrase != tc.want {
				t.Fatalf("generated %s mnemonic %q, want %q", name, phrase, tc.want)
			}

			identity, err := DeriveIdentity(phrase, 0)
			if err != nil {
				t.Fatalf("failed to derive identity from %s: %v", name, err)
			}

			// Test that we can encrypt/decrypt
			var ct bytes.Buffer
			w, err := age.Encrypt(&ct, identity.Recipient())
			if err != nil {
				t.Fatalf("encrypt init: %v", err)
			}
			if _, err = io.WriteString(w, testMessageGeneric); err != nil {
				t.Fatalf("write: %v", err)
			}
			if err = w.Close(); err != nil {
				t.Fatalf("encrypt close: %v", err)
			}

			r, err := age.Decrypt(bytes.NewReader(ct.Bytes()), identity)
			if err != nil {
				t.Fatalf("decrypt init: %v", err)
			}
			dec, err := io.ReadAll(r)
			if err != nil {
				t.Fatalf("read: %v", err)
			}

			if string(dec) != testMessageGeneric {
				t.Fatalf("round-trip failed for %s", name)
			}

			t.Logf("%s identity: %s", name, identity.String())
		})
	}
}

// TestIndexBoundaries derives identities across the uint32 index range,
// including 2^31-1 and 2^32-1, and checks that each one can round-trip an
// age message.
func TestIndexBoundaries(t *testing.T) {
	boundaries := []uint32{
		0,
		1,
		100,
		1000,
		// 2^31-1: the largest i for which a BIP32 hardened child i' = 2^31+i
		// still fits in 32 bits.
		0x7FFFFFFF,
		// 2^32-1 (max uint32). NOTE(review): confirm how the package maps
		// indices at or above 2^31 onto BIP32 child numbers.
		0xFFFFFFFF,
	}

	for _, idx := range boundaries {
		t.Run(fmt.Sprintf("index_%d", idx), func(t *testing.T) {
			id, err := DeriveIdentity(mnemonic, idx)
			if err != nil {
				t.Fatalf("failed to derive identity at index %d: %v", idx, err)
			}

			// A usable identity must decrypt what is encrypted to it.
			sealed := new(bytes.Buffer)
			enc, err := age.Encrypt(sealed, id.Recipient())
			if err != nil {
				t.Fatalf("encrypt init at index %d: %v", idx, err)
			}
			if _, werr := io.WriteString(enc, testMessageBoundary); werr != nil {
				t.Fatalf("write at index %d: %v", idx, werr)
			}
			if cerr := enc.Close(); cerr != nil {
				t.Fatalf("encrypt close at index %d: %v", idx, cerr)
			}

			dec, err := age.Decrypt(bytes.NewReader(sealed.Bytes()), id)
			if err != nil {
				t.Fatalf("decrypt init at index %d: %v", idx, err)
			}
			plain, err := io.ReadAll(dec)
			if err != nil {
				t.Fatalf("read at index %d: %v", idx, err)
			}

			if got := string(plain); got != testMessageBoundary {
				t.Fatalf("round-trip failed at index %d", idx)
			}

			t.Logf("Index %d identity: %s", idx, id.String())
		})
	}
}

// TestEntropyUniqueness checks that DeriveEntropy separates its inputs.
// Changing only the index must change the output, and so must changing only
// the mnemonic (the 12-word default versus the 24-word testMnemonic24).
func TestEntropyUniqueness(t *testing.T) {
	derive := func(step int, phrase string, index uint32) []byte {
		t.Helper()
		out, err := DeriveEntropy(phrase, index)
		if err != nil {
			t.Fatalf("derive entropy %d: %v", step, err)
		}
		return out
	}

	base := derive(1, mnemonic, 0)
	otherIndex := derive(2, mnemonic, 1)
	if bytes.Equal(base, otherIndex) {
		t.Fatalf("different indices should produce different entropy")
	}

	otherMnemonic := derive(3, testMnemonic24, 0)
	if bytes.Equal(base, otherMnemonic) {
		t.Fatalf("different mnemonics should produce different entropy")
	}

	t.Logf("Entropy uniqueness verified across indices and mnemonics")
}

// TestConcurrentDerivation tests that derivation is safe for concurrent use
func TestConcurrentDerivation(t *testing.T) {
	results := make(chan string, testNumGoroutines*testNumIterations)
	errors := make(chan error, testNumGoroutines*testNumIterations)

	for i := 0; i < testNumGoroutines; i++ {
		go func(goroutineID int) {
			for j := 0; j < testNumIterations; j++ {
				identity, err := DeriveIdentity(mnemonic, uint32(j))
				if err != nil {
					errors <- err
					return
				}
				results <- identity.String()
			}
		}(i)
	}

	// Collect results
	resultMap := make(map[string]int)
	for i := 0; i < testNumGoroutines*testNumIterations; i++ {
		select {
		case result := <-results:
			resultMap[result]++
		case err := <-errors:
			t.Fatalf("concurrent derivation error: %v", err)
		}
	}

	// Verify that each index produced the same result across all goroutines
	expectedResults := testNumGoroutines
	for result, count := range resultMap {
		if count != expectedResults {
			t.Errorf("result %s appeared %d times, expected %d", result, count, expectedResults)
		}
	}

	t.Logf("Concurrent derivation test passed with %d unique results", len(resultMap))
}

// Benchmark tests

// BenchmarkDeriveIdentity measures DeriveIdentity on the shared test
// mnemonic, cycling the index through 0-999.
func BenchmarkDeriveIdentity(b *testing.B) {
	for n := 0; n < b.N; n++ {
		if _, err := DeriveIdentity(mnemonic, uint32(n%1000)); err != nil {
			b.Fatalf("derive identity: %v", err)
		}
	}
}

// BenchmarkDeriveIdentityFromXPRV measures DeriveIdentityFromXPRV on the test
// xprv, cycling the index through 0-999.
func BenchmarkDeriveIdentityFromXPRV(b *testing.B) {
	for n := 0; n < b.N; n++ {
		if _, err := DeriveIdentityFromXPRV(testXPRV, uint32(n%1000)); err != nil {
			b.Fatalf("derive identity from xprv: %v", err)
		}
	}
}

// BenchmarkDeriveEntropy measures raw entropy derivation from the shared test
// mnemonic, cycling the index through 0-999.
func BenchmarkDeriveEntropy(b *testing.B) {
	for n := 0; n < b.N; n++ {
		idx := uint32(n % 1000)
		if _, err := DeriveEntropy(mnemonic, idx); err != nil {
			b.Fatalf("derive entropy: %v", err)
		}
	}
}

// BenchmarkIdentityFromEntropy measures turning a fixed random 32-byte scalar
// into an identity. Generating the scalar is excluded from the timed region.
func BenchmarkIdentityFromEntropy(b *testing.B) {
	scalar := make([]byte, 32)
	if _, err := rand.Read(scalar); err != nil {
		b.Fatalf("failed to generate random entropy: %v", err)
	}
	b.ResetTimer()

	for n := 0; n < b.N; n++ {
		if _, err := IdentityFromEntropy(scalar); err != nil {
			b.Fatalf("identity from entropy: %v", err)
		}
	}
}

// BenchmarkEncryptDecrypt measures a full age encrypt + decrypt round trip
// of a short message for a single pre-derived identity. Derivation is
// excluded from the timed region.
func BenchmarkEncryptDecrypt(b *testing.B) {
	id, err := DeriveIdentity(mnemonic, 0)
	if err != nil {
		b.Fatalf("derive identity: %v", err)
	}

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		sealed := new(bytes.Buffer)
		enc, err := age.Encrypt(sealed, id.Recipient())
		if err != nil {
			b.Fatalf("encrypt init: %v", err)
		}
		if _, werr := io.WriteString(enc, testMessageBenchmark); werr != nil {
			b.Fatalf("write: %v", werr)
		}
		if cerr := enc.Close(); cerr != nil {
			b.Fatalf("encrypt close: %v", cerr)
		}

		dec, err := age.Decrypt(bytes.NewReader(sealed.Bytes()), id)
		if err != nil {
			b.Fatalf("decrypt init: %v", err)
		}
		if _, rerr := io.ReadAll(dec); rerr != nil {
			b.Fatalf("read: %v", rerr)
		}
	}
}

// TestConstants pins the package-level constants purpose, vendorID, appID
// and hrp to their expected literal values, so any accidental edit fails
// loudly.
func TestConstants(t *testing.T) {
	const (
		wantPurpose  = 83696968
		wantVendorID = 592366788
		wantAppID    = 733482323
		wantHRP      = "age-secret-key-"
	)

	if purpose != wantPurpose {
		t.Errorf("purpose constant mismatch: expected 83696968, got %d", purpose)
	}
	if vendorID != wantVendorID {
		t.Errorf("vendorID constant mismatch: expected 592366788, got %d", vendorID)
	}
	if appID != wantAppID {
		t.Errorf("appID constant mismatch: expected 733482323, got %d", appID)
	}
	if hrp != wantHRP {
		t.Errorf("hrp constant mismatch: expected 'age-secret-key-', got %q", hrp)
	}
}

// TestIdentityStringFormat checks the textual encodings of a derived
// identity. The secret key must start with "AGE-SECRET-KEY-", the recipient
// must start with "age1", and the two strings must differ.
func TestIdentityStringFormat(t *testing.T) {
	id, err := DeriveIdentity(mnemonic, 0)
	if err != nil {
		t.Fatalf("derive identity: %v", err)
	}

	secret, public := id.String(), id.Recipient().String()

	if ok := strings.HasPrefix(secret, "AGE-SECRET-KEY-"); !ok {
		t.Errorf("secret key should start with 'AGE-SECRET-KEY-', got: %s", secret)
	}
	if ok := strings.HasPrefix(public, "age1"); !ok {
		t.Errorf("recipient should start with 'age1', got: %s", public)
	}
	if secret == public {
		t.Errorf("secret key and recipient should be different")
	}

	t.Logf("Secret key format: %s", secret)
	t.Logf("Recipient format: %s", public)
}

// TestLargeMessageEncryption round-trips messages of increasing size (1 byte
// up to 100,000 bytes) through age using a single derived identity.
func TestLargeMessageEncryption(t *testing.T) {
	id, err := DeriveIdentity(mnemonic, 0)
	if err != nil {
		t.Fatalf("derive identity: %v", err)
	}

	for _, n := range [...]int{1, 100, 1024, 10240, 100000} {
		t.Run(fmt.Sprintf("size_%d", n), func(t *testing.T) {
			payload := strings.Repeat(testMessageLargePattern, n)

			sealed := &bytes.Buffer{}
			enc, err := age.Encrypt(sealed, id.Recipient())
			if err != nil {
				t.Fatalf("encrypt init: %v", err)
			}
			if _, werr := io.WriteString(enc, payload); werr != nil {
				t.Fatalf("write: %v", werr)
			}
			if cerr := enc.Close(); cerr != nil {
				t.Fatalf("encrypt close: %v", cerr)
			}

			dec, err := age.Decrypt(bytes.NewReader(sealed.Bytes()), id)
			if err != nil {
				t.Fatalf("decrypt init: %v", err)
			}
			plain, err := io.ReadAll(dec)
			if err != nil {
				t.Fatalf("read: %v", err)
			}

			if got := string(plain); got != payload {
				t.Fatalf("message size %d: round-trip failed", n)
			}

			t.Logf("Successfully encrypted/decrypted %d byte message", n)
		})
	}
}

// TestRandomMnemonicDeterministicGeneration checks, for a freshly generated
// random 24-word mnemonic, that:
//  1. deriving the same index twice yields identical secret and public keys;
//  2. 1 MB of random data round-trips through age encryption; and
//  3. the two independently derived identities are interchangeable: data
//     encrypted to one identity's recipient decrypts with the other.
//
// The mnemonic is logged so a failure can be reproduced.
func TestRandomMnemonicDeterministicGeneration(t *testing.T) {
	// Generate a random mnemonic using the BIP39 library
	entropy := make([]byte, 32) // 256 bits for 24-word mnemonic
	if _, err := rand.Read(entropy); err != nil {
		t.Fatalf("failed to generate random entropy: %v", err)
	}

	randomMnemonic, err := bip39.NewMnemonic(entropy)
	if err != nil {
		t.Fatalf("failed to generate random mnemonic: %v", err)
	}

	t.Logf("Generated random mnemonic: %s", randomMnemonic)

	const testIndex = uint32(42)

	identity1, err := DeriveIdentity(randomMnemonic, testIndex)
	if err != nil {
		t.Fatalf("failed to derive first identity: %v", err)
	}
	identity2, err := DeriveIdentity(randomMnemonic, testIndex)
	if err != nil {
		t.Fatalf("failed to derive second identity: %v", err)
	}

	// Both derivations must produce the same private key...
	privateKey1, privateKey2 := identity1.String(), identity2.String()
	if privateKey1 != privateKey2 {
		t.Fatalf("private keys should be identical:\nFirst:  %s\nSecond: %s", privateKey1, privateKey2)
	}

	// ...and the same public key (recipient).
	publicKey1, publicKey2 := identity1.Recipient().String(), identity2.Recipient().String()
	if publicKey1 != publicKey2 {
		t.Fatalf("public keys should be identical:\nFirst:  %s\nSecond: %s", publicKey1, publicKey2)
	}

	t.Logf("✓ Deterministic generation verified")
	t.Logf("Private key: %s", privateKey1)
	t.Logf("Public key:  %s", publicKey1)

	// Generate 1 MB of random data for encryption test
	testData := make([]byte, testDataSizeMegabyte)
	if _, err := rand.Read(testData); err != nil {
		t.Fatalf("failed to generate random test data: %v", err)
	}

	t.Logf("Generated %d bytes of random test data", len(testData))

	// encrypt seals testData to r and returns the ciphertext.
	encrypt := func(label string, r age.Recipient) []byte {
		t.Helper()
		var ct bytes.Buffer
		w, err := age.Encrypt(&ct, r)
		if err != nil {
			t.Fatalf("%s: failed to create encryptor: %v", label, err)
		}
		if _, err := w.Write(testData); err != nil {
			t.Fatalf("%s: failed to write data to encryptor: %v", label, err)
		}
		if err := w.Close(); err != nil {
			t.Fatalf("%s: failed to close encryptor: %v", label, err)
		}
		return ct.Bytes()
	}

	// decryptAndCheck opens ct with id and verifies the plaintext is testData.
	decryptAndCheck := func(label string, id age.Identity, ct []byte) {
		t.Helper()
		r, err := age.Decrypt(bytes.NewReader(ct), id)
		if err != nil {
			t.Fatalf("%s: failed to create decryptor: %v", label, err)
		}
		got, err := io.ReadAll(r)
		if err != nil {
			t.Fatalf("%s: failed to read decrypted data: %v", label, err)
		}
		if len(got) != len(testData) {
			t.Fatalf("%s: decrypted data length mismatch: expected %d, got %d", label, len(testData), len(got))
		}
		if !bytes.Equal(testData, got) {
			t.Fatalf("%s: decrypted data does not match original data", label)
		}
	}

	ciphertext1 := encrypt("identity1", identity1.Recipient())
	t.Logf("✓ Encrypted %d bytes into %d bytes of ciphertext", len(testData), len(ciphertext1))

	decryptAndCheck("identity1 -> identity1", identity1, ciphertext1)
	t.Logf("✓ Large data encryption/decryption test passed successfully")

	// Cross-verification: each identity must open data sealed to the other,
	// proving the two derivations are interchangeable, not merely equal as
	// strings.
	decryptAndCheck("identity1 -> identity2", identity2, ciphertext1)
	ciphertext2 := encrypt("identity2", identity2.Recipient())
	decryptAndCheck("identity2 -> identity1", identity1, ciphertext2)

	t.Logf("✓ Cross-verification with second identity successful")
}