diff --git a/internal/seal/crypto.go b/internal/seal/crypto.go new file mode 100644 index 0000000..f405022 --- /dev/null +++ b/internal/seal/crypto.go @@ -0,0 +1,84 @@ +// Package seal provides authenticated encryption utilities for pixa. +// It uses NaCl secretbox (XSalsa20-Poly1305) for sealing data +// and HKDF-SHA256 for key derivation. +package seal + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "io" + + "golang.org/x/crypto/hkdf" + "golang.org/x/crypto/nacl/secretbox" +) + +// Key sizes for NaCl secretbox. +const ( + KeySize = 32 // NaCl secretbox key size + NonceSize = 24 // NaCl secretbox nonce size +) + +// Errors returned by crypto operations. +var ( + ErrDecryptionFailed = errors.New("decryption failed: invalid ciphertext or key") + ErrInvalidPayload = errors.New("invalid encrypted payload") + ErrKeyDerivation = errors.New("key derivation failed") +) + +// DeriveKey uses HKDF-SHA256 to derive a key from master key material. +// The salt parameter provides domain separation between different key usages. +func DeriveKey(masterKey []byte, salt string) ([KeySize]byte, error) { + var key [KeySize]byte + + hkdfReader := hkdf.New(sha256.New, masterKey, []byte(salt), nil) + + if _, err := io.ReadFull(hkdfReader, key[:]); err != nil { + return key, ErrKeyDerivation + } + + return key, nil +} + +// Encrypt encrypts plaintext using NaCl secretbox (XSalsa20-Poly1305). +// Returns base64url-encoded ciphertext with the nonce prepended. +func Encrypt(key [KeySize]byte, plaintext []byte) (string, error) { + // Generate random nonce + var nonce [NonceSize]byte + if _, err := rand.Read(nonce[:]); err != nil { + return "", err + } + + // Encrypt: nonce + secretbox.Seal(plaintext) + encrypted := secretbox.Seal(nonce[:], plaintext, &nonce, &key) + + return base64.RawURLEncoding.EncodeToString(encrypted), nil +} + +// Decrypt decrypts base64url-encoded ciphertext using NaCl secretbox. +// Expects the nonce to be prepended to the ciphertext. +func Decrypt(key [KeySize]byte, ciphertext string) ([]byte, error) { + // Decode base64url + data, err := base64.RawURLEncoding.DecodeString(ciphertext) + if err != nil { + return nil, ErrInvalidPayload + } + + // Check minimum length (nonce + at least some ciphertext + auth tag) + if len(data) < NonceSize+secretbox.Overhead { + return nil, ErrInvalidPayload + } + + // Extract nonce + var nonce [NonceSize]byte + copy(nonce[:], data[:NonceSize]) + + // Decrypt + plaintext, ok := secretbox.Open(nil, data[NonceSize:], &nonce, &key) + if !ok { + return nil, ErrDecryptionFailed + } + + return plaintext, nil +} diff --git a/internal/seal/crypto_test.go b/internal/seal/crypto_test.go new file mode 100644 index 0000000..4bf38f1 --- /dev/null +++ b/internal/seal/crypto_test.go @@ -0,0 +1,184 @@ +package seal + +import ( + "bytes" + "testing" +) + +func TestDeriveKey_Consistent(t *testing.T) { + masterKey := []byte("test-master-key-12345") + salt := "test-salt-v1" + + key1, err := DeriveKey(masterKey, salt) + if err != nil { + t.Fatalf("DeriveKey() error = %v", err) + } + + key2, err := DeriveKey(masterKey, salt) + if err != nil { + t.Fatalf("DeriveKey() error = %v", err) + } + + if key1 != key2 { + t.Error("DeriveKey() should produce consistent keys for same input") + } +} + +func TestDeriveKey_DifferentSalts(t *testing.T) { + masterKey := []byte("test-master-key-12345") + + key1, err := DeriveKey(masterKey, "salt-1") + if err != nil { + t.Fatalf("DeriveKey() error = %v", err) + } + + key2, err := DeriveKey(masterKey, "salt-2") + if err != nil { + t.Fatalf("DeriveKey() error = %v", err) + } + + if key1 == key2 { + t.Error("DeriveKey() should produce different keys for different salts") + } +} + +func TestDeriveKey_DifferentMasterKeys(t *testing.T) { + salt := "test-salt" + + key1, err := DeriveKey([]byte("master-key-1"), salt) + if err != nil { + t.Fatalf("DeriveKey() error = %v", err) + } + + key2, err := DeriveKey([]byte("master-key-2"), salt) + if err != nil { + t.Fatalf("DeriveKey() error = %v", err) + } + + if key1 == key2 { + t.Error("DeriveKey() should produce different keys for different master keys") + } +} + +func TestEncryptDecrypt_RoundTrip(t *testing.T) { + key, err := DeriveKey([]byte("test-key"), "test-salt") + if err != nil { + t.Fatalf("DeriveKey() error = %v", err) + } + + plaintext := []byte("hello, world! this is a test message.") + + ciphertext, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + decrypted, err := Decrypt(key, ciphertext) + if err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Errorf("Decrypt() = %q, want %q", decrypted, plaintext) + } +} + +func TestEncryptDecrypt_EmptyPlaintext(t *testing.T) { + key, _ := DeriveKey([]byte("test-key"), "test-salt") + plaintext := []byte{} + + ciphertext, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + decrypted, err := Decrypt(key, ciphertext) + if err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + + if !bytes.Equal(plaintext, decrypted) { + t.Errorf("Decrypt() = %q, want %q", decrypted, plaintext) + } +} + +func TestDecrypt_WrongKey(t *testing.T) { + key1, _ := DeriveKey([]byte("key-1"), "salt") + key2, _ := DeriveKey([]byte("key-2"), "salt") + + plaintext := []byte("secret message") + + ciphertext, err := Encrypt(key1, plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + _, err = Decrypt(key2, ciphertext) + if err == nil { + t.Error("Decrypt() should fail with wrong key") + } + + if err != ErrDecryptionFailed { + t.Errorf("Decrypt() error = %v, want %v", err, ErrDecryptionFailed) + } +} + +func TestDecrypt_TamperedCiphertext(t *testing.T) { + key, _ := DeriveKey([]byte("test-key"), "test-salt") + plaintext := []byte("secret message") + + ciphertext, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + // Tamper with the ciphertext (flip a bit in the middle) + tampered := []byte(ciphertext) + if len(tampered) > 10 { + tampered[10] ^= 0x01 + } + + _, err = Decrypt(key, string(tampered)) + if err == nil { + t.Error("Decrypt() should fail with tampered ciphertext") + } +} + +func TestDecrypt_InvalidBase64(t *testing.T) { + key, _ := DeriveKey([]byte("test-key"), "test-salt") + + _, err := Decrypt(key, "not-valid-base64!!!") + if err == nil { + t.Error("Decrypt() should fail with invalid base64") + } + + if err != ErrInvalidPayload { + t.Errorf("Decrypt() error = %v, want %v", err, ErrInvalidPayload) + } +} + +func TestDecrypt_TooShort(t *testing.T) { + key, _ := DeriveKey([]byte("test-key"), "test-salt") + + // Create a base64 string that's too short to contain nonce + auth tag + _, err := Decrypt(key, "dG9vLXNob3J0") + if err == nil { + t.Error("Decrypt() should fail with too-short ciphertext") + } + + if err != ErrInvalidPayload { + t.Errorf("Decrypt() error = %v, want %v", err, ErrInvalidPayload) + } +} + +func TestEncrypt_ProducesDifferentCiphertexts(t *testing.T) { + key, _ := DeriveKey([]byte("test-key"), "test-salt") + plaintext := []byte("same message") + + ciphertext1, _ := Encrypt(key, plaintext) + ciphertext2, _ := Encrypt(key, plaintext) + + if ciphertext1 == ciphertext2 { + t.Error("Encrypt() should produce different ciphertexts due to random nonce") + } +}