Add custom types, version command, and restore --verify flag

- Add internal/types package with type-safe wrappers for IDs, hashes,
  paths, and credentials (FileID, BlobID, ChunkHash, etc.)
- Implement driver.Valuer and sql.Scanner for UUID-based types
- Add `vaultik version` command showing version, commit, Go version, and OS/arch
- Add `--verify` flag to restore command that checksums all restored
  files against expected chunk hashes with progress bar
- Remove fetch.go (dead code, functionality in restore)
- Clean up TODO.md, remove completed items
- Update all database and snapshot code to use new custom types
This commit is contained in:
2026-01-14 17:11:52 -08:00
parent 2afd54d693
commit 417b25a5f5
53 changed files with 2330 additions and 1581 deletions

View File

@@ -26,6 +26,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/blobgen"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/google/uuid"
"github.com/spf13/afero"
)
@@ -262,19 +263,22 @@ func (p *Packer) startNewBlob() error {
// Create blob record in database
if p.repos != nil {
blobIDTyped, err := types.ParseBlobID(blobID)
if err != nil {
return fmt.Errorf("parsing blob ID: %w", err)
}
blob := &database.Blob{
ID: blobID,
Hash: "temp-placeholder-" + blobID, // Temporary placeholder until finalized
ID: blobIDTyped,
Hash: types.BlobHash("temp-placeholder-" + blobID), // Temporary placeholder until finalized
CreatedTS: time.Now().UTC(),
FinishedTS: nil,
UncompressedSize: 0,
CompressedSize: 0,
UploadedTS: nil,
}
err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
if err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
return p.repos.Blobs.Create(ctx, tx, blob)
})
if err != nil {
}); err != nil {
return fmt.Errorf("creating blob record: %w", err)
}
}
@@ -403,11 +407,16 @@ func (p *Packer) finalizeCurrentBlob() error {
// Insert pending chunks, blob_chunks, and update blob in a single transaction
if p.repos != nil {
blobIDTyped, parseErr := types.ParseBlobID(p.currentBlob.id)
if parseErr != nil {
p.cleanupTempFile()
return fmt.Errorf("parsing blob ID: %w", parseErr)
}
err := p.repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
// First insert all pending chunks (required for blob_chunks FK)
for _, chunk := range chunksToInsert {
dbChunk := &database.Chunk{
ChunkHash: chunk.Hash,
ChunkHash: types.ChunkHash(chunk.Hash),
Size: chunk.Size,
}
if err := p.repos.Chunks.Create(ctx, tx, dbChunk); err != nil {
@@ -418,8 +427,8 @@ func (p *Packer) finalizeCurrentBlob() error {
// Insert all blob_chunk records in batch
for _, chunk := range p.currentBlob.chunks {
blobChunk := &database.BlobChunk{
BlobID: p.currentBlob.id,
ChunkHash: chunk.Hash,
BlobID: blobIDTyped,
ChunkHash: types.ChunkHash(chunk.Hash),
Offset: chunk.Offset,
Length: chunk.Size,
}

View File

@@ -12,6 +12,7 @@ import (
"filippo.io/age"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/klauspost/compress/zstd"
"github.com/spf13/afero"
)
@@ -60,7 +61,7 @@ func TestPacker(t *testing.T) {
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
ChunkHash: types.ChunkHash(hashStr),
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
@@ -152,7 +153,7 @@ func TestPacker(t *testing.T) {
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
ChunkHash: types.ChunkHash(hashStr),
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
@@ -235,7 +236,7 @@ func TestPacker(t *testing.T) {
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
ChunkHash: types.ChunkHash(hashStr),
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {
@@ -322,7 +323,7 @@ func TestPacker(t *testing.T) {
// Create chunk in database first
dbChunk := &database.Chunk{
ChunkHash: hashStr,
ChunkHash: types.ChunkHash(hashStr),
Size: int64(len(data)),
}
err = repos.WithTx(context.Background(), func(ctx context.Context, tx *sql.Tx) error {

View File

@@ -18,7 +18,7 @@ func TestCLIEntry(t *testing.T) {
}
// Verify all subcommands are registered
expectedCommands := []string{"snapshot", "store", "restore", "prune", "verify", "fetch"}
expectedCommands := []string{"snapshot", "store", "restore", "prune", "verify", "info", "version"}
for _, expected := range expectedCommands {
found := false
for _, cmd := range cmd.Commands() {

View File

@@ -1,138 +0,0 @@
package cli
import (
"context"
"fmt"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/storage"
"github.com/spf13/cobra"
"go.uber.org/fx"
)
// FetchOptions contains options for the fetch command
type FetchOptions struct {
}
// FetchApp contains all dependencies needed for fetch
type FetchApp struct {
Globals *globals.Globals
Config *config.Config
Repositories *database.Repositories
Storage storage.Storer
DB *database.DB
Shutdowner fx.Shutdowner
}
// NewFetchCommand creates the fetch command
//
// The command wires up an fx application (config, database, storage and the
// snapshot module), runs the fetch asynchronously from an fx lifecycle hook,
// and shuts the app down when the operation completes.
func NewFetchCommand() *cobra.Command {
	opts := &FetchOptions{}
	cmd := &cobra.Command{
		Use:   "fetch <snapshot-id> <file-path> <target-path>",
		Short: "Extract single file from backup",
		Long: `Download and decrypt a single file from a backup snapshot.
This command extracts a specific file from the snapshot and saves it to the target path.
The age_secret_key must be configured in the config file for decryption.`,
		Args: cobra.ExactArgs(3),
		RunE: func(cmd *cobra.Command, args []string) error {
			// Positional argument count is guaranteed by cobra.ExactArgs(3).
			snapshotID := args[0]
			filePath := args[1]
			targetPath := args[2]
			// Use unified config resolution
			configPath, err := ResolveConfigPath()
			if err != nil {
				return err
			}
			// Use the app framework like other commands
			rootFlags := GetRootFlags()
			return RunWithApp(cmd.Context(), AppOptions{
				ConfigPath: configPath,
				LogOptions: log.LogOptions{
					Verbose: rootFlags.Verbose,
					Debug:   rootFlags.Debug,
				},
				Modules: []fx.Option{
					snapshot.Module,
					// Bundle all fetch dependencies into a single FetchApp value
					// so the invoke hook below has one injection point.
					fx.Provide(fx.Annotate(
						func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
							storer storage.Storer, db *database.DB, shutdowner fx.Shutdowner) *FetchApp {
							return &FetchApp{
								Globals:      g,
								Config:       cfg,
								Repositories: repos,
								Storage:      storer,
								DB:           db,
								Shutdowner:   shutdowner,
							}
						},
					)),
				},
				Invokes: []fx.Option{
					fx.Invoke(func(app *FetchApp, lc fx.Lifecycle) {
						lc.Append(fx.Hook{
							OnStart: func(ctx context.Context) error {
								// Start the fetch operation in a goroutine
								// so OnStart returns immediately; the goroutine
								// triggers app shutdown itself when done.
								go func() {
									// Run the fetch operation
									if err := app.runFetch(ctx, snapshotID, filePath, targetPath, opts); err != nil {
										// Cancellation is expected on interrupt; only log real failures.
										if err != context.Canceled {
											log.Error("Fetch operation failed", "error", err)
										}
									}
									// Shutdown the app when fetch completes
									if err := app.Shutdowner.Shutdown(); err != nil {
										log.Error("Failed to shutdown", "error", err)
									}
								}()
								return nil
							},
							OnStop: func(ctx context.Context) error {
								log.Debug("Stopping fetch operation")
								return nil
							},
						})
					}),
				},
			})
		},
	}
	return cmd
}
// runFetch executes the fetch operation
//
// NOTE(review): this is a stub — it validates that an age secret key is
// configured and logs the request, but the actual download/decrypt/
// reconstruct pipeline is still the TODO list below.
func (app *FetchApp) runFetch(ctx context.Context, snapshotID, filePath, targetPath string, opts *FetchOptions) error {
	// Check for age_secret_key
	if app.Config.AgeSecretKey == "" {
		return fmt.Errorf("age_secret_key missing from config - required for fetch")
	}
	log.Info("Starting fetch operation",
		"snapshot_id", snapshotID,
		"file_path", filePath,
		"target_path", targetPath,
		"bucket", app.Config.S3.Bucket,
		"prefix", app.Config.S3.Prefix,
	)
	// TODO: Implement fetch logic
	// 1. Download and decrypt database from S3
	// 2. Find the file metadata and chunk list
	// 3. Download and decrypt only the necessary blobs
	// 4. Reconstruct the file from chunks
	// 5. Write file to target path with proper metadata
	fmt.Printf("Fetching %s from snapshot %s to %s\n", filePath, snapshotID, targetPath)
	fmt.Println("TODO: Implement fetch logic")
	return nil
}

View File

@@ -2,13 +2,12 @@ package cli
import (
"context"
"fmt"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/vaultik"
"github.com/spf13/cobra"
"go.uber.org/fx"
)
@@ -16,16 +15,17 @@ import (
// RestoreOptions contains options for the restore command
type RestoreOptions struct {
TargetDir string
Paths []string // Optional paths to restore (empty = all)
Verify bool // Verify restored files after restore
}
// RestoreApp contains all dependencies needed for restore
type RestoreApp struct {
Globals *globals.Globals
Config *config.Config
Repositories *database.Repositories
Storage storage.Storer
DB *database.DB
Shutdowner fx.Shutdowner
Globals *globals.Globals
Config *config.Config
Storage storage.Storer
Vaultik *vaultik.Vaultik
Shutdowner fx.Shutdowner
}
// NewRestoreCommand creates the restore command
@@ -33,16 +33,35 @@ func NewRestoreCommand() *cobra.Command {
opts := &RestoreOptions{}
cmd := &cobra.Command{
Use: "restore <snapshot-id> <target-dir>",
Use: "restore <snapshot-id> <target-dir> [paths...]",
Short: "Restore files from backup",
Long: `Download and decrypt files from a backup snapshot.
This command will restore all files from the specified snapshot to the target directory.
The age_secret_key must be configured in the config file for decryption.`,
Args: cobra.ExactArgs(2),
This command will restore files from the specified snapshot to the target directory.
If no paths are specified, all files are restored.
If paths are specified, only matching files/directories are restored.
Requires the VAULTIK_AGE_SECRET_KEY environment variable to be set with the age private key.
Examples:
# Restore entire snapshot
vaultik restore myhost_docs_2025-01-01T12:00:00Z /restore
# Restore specific file
vaultik restore myhost_docs_2025-01-01T12:00:00Z /restore /home/user/important.txt
# Restore specific directory
vaultik restore myhost_docs_2025-01-01T12:00:00Z /restore /home/user/documents/
# Restore and verify all files
vaultik restore --verify myhost_docs_2025-01-01T12:00:00Z /restore`,
Args: cobra.MinimumNArgs(2),
RunE: func(cmd *cobra.Command, args []string) error {
snapshotID := args[0]
opts.TargetDir = args[1]
if len(args) > 2 {
opts.Paths = args[2:]
}
// Use unified config resolution
configPath, err := ResolveConfigPath()
@@ -60,15 +79,14 @@ The age_secret_key must be configured in the config file for decryption.`,
},
Modules: []fx.Option{
fx.Provide(fx.Annotate(
func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
storer storage.Storer, db *database.DB, shutdowner fx.Shutdowner) *RestoreApp {
func(g *globals.Globals, cfg *config.Config,
storer storage.Storer, v *vaultik.Vaultik, shutdowner fx.Shutdowner) *RestoreApp {
return &RestoreApp{
Globals: g,
Config: cfg,
Repositories: repos,
Storage: storer,
DB: db,
Shutdowner: shutdowner,
Globals: g,
Config: cfg,
Storage: storer,
Vaultik: v,
Shutdowner: shutdowner,
}
},
)),
@@ -80,7 +98,13 @@ The age_secret_key must be configured in the config file for decryption.`,
// Start the restore operation in a goroutine
go func() {
// Run the restore operation
if err := app.runRestore(ctx, snapshotID, opts); err != nil {
restoreOpts := &vaultik.RestoreOptions{
SnapshotID: snapshotID,
TargetDir: opts.TargetDir,
Paths: opts.Paths,
Verify: opts.Verify,
}
if err := app.Vaultik.Restore(restoreOpts); err != nil {
if err != context.Canceled {
log.Error("Restore operation failed", "error", err)
}
@@ -95,6 +119,7 @@ The age_secret_key must be configured in the config file for decryption.`,
},
OnStop: func(ctx context.Context) error {
log.Debug("Stopping restore operation")
app.Vaultik.Cancel()
return nil
},
})
@@ -104,31 +129,7 @@ The age_secret_key must be configured in the config file for decryption.`,
},
}
cmd.Flags().BoolVar(&opts.Verify, "verify", false, "Verify restored files by checking chunk hashes")
return cmd
}
// runRestore executes the restore operation
//
// NOTE(review): this is a stub — it validates that an age secret key is
// configured and logs the request; the restore pipeline itself is still
// the TODO list below.
func (app *RestoreApp) runRestore(ctx context.Context, snapshotID string, opts *RestoreOptions) error {
	// Check for age_secret_key
	if app.Config.AgeSecretKey == "" {
		return fmt.Errorf("age_secret_key required for restore - set in config file or VAULTIK_AGE_SECRET_KEY environment variable")
	}
	log.Info("Starting restore operation",
		"snapshot_id", snapshotID,
		"target_dir", opts.TargetDir,
		"bucket", app.Config.S3.Bucket,
		"prefix", app.Config.S3.Prefix,
	)
	// TODO: Implement restore logic
	// 1. Download and decrypt database from S3
	// 2. Download and decrypt blobs
	// 3. Reconstruct files from chunks
	// 4. Write files to target directory with proper metadata
	fmt.Printf("Restoring snapshot %s to %s\n", snapshotID, opts.TargetDir)
	fmt.Println("TODO: Implement restore logic")
	return nil
}

View File

@@ -40,10 +40,10 @@ on the source system.`,
NewRestoreCommand(),
NewPruneCommand(),
NewVerifyCommand(),
NewFetchCommand(),
NewStoreCommand(),
NewSnapshotCommand(),
NewInfoCommand(),
NewVersionCommand(),
)
return cmd

View File

@@ -35,14 +35,19 @@ func newSnapshotCreateCommand() *cobra.Command {
opts := &vaultik.SnapshotCreateOptions{}
cmd := &cobra.Command{
Use: "create",
Short: "Create a new snapshot",
Long: `Creates a new snapshot of the configured directories.
Use: "create [snapshot-names...]",
Short: "Create new snapshots",
Long: `Creates new snapshots of the configured directories.
Config is located at /etc/vaultik/config.yml by default, but can be overridden by
If snapshot names are provided, only those snapshots are created.
If no names are provided, all configured snapshots are created.
Config is located at /etc/vaultik/config.yml by default, but can be overridden by
specifying a path using --config or by setting VAULTIK_CONFIG to a path.`,
Args: cobra.NoArgs,
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
// Pass snapshot names from args
opts.Snapshots = args
// Use unified config resolution
configPath, err := ResolveConfigPath()
if err != nil {
@@ -95,6 +100,7 @@ specifying a path using --config or by setting VAULTIK_CONFIG to a path.`,
cmd.Flags().BoolVar(&opts.Daemon, "daemon", false, "Run in daemon mode with inotify monitoring")
cmd.Flags().BoolVar(&opts.Cron, "cron", false, "Run in cron mode (silent unless error)")
cmd.Flags().BoolVar(&opts.Prune, "prune", false, "Delete all previous snapshots and unreferenced blobs after backup")
cmd.Flags().BoolVar(&opts.SkipErrors, "skip-errors", false, "Skip file read errors (log them loudly but continue)")
return cmd
}

27
internal/cli/version.go Normal file
View File

@@ -0,0 +1,27 @@
package cli
import (
"fmt"
"runtime"
"git.eeqj.de/sneak/vaultik/internal/globals"
"github.com/spf13/cobra"
)
// NewVersionCommand creates the version command, which reports build
// identity (version and commit, from the globals package) plus the Go
// runtime version and target platform on stdout.
func NewVersionCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "version",
		Short: "Print version information",
		Long:  `Print version, git commit, and build information for vaultik.`,
		Args:  cobra.NoArgs,
		Run: func(_ *cobra.Command, _ []string) {
			fmt.Printf("vaultik %s\n", globals.Version)
			fmt.Printf(" commit: %s\n", globals.Commit)
			fmt.Printf(" go: %s\n", runtime.Version())
			fmt.Printf(" os/arch: %s/%s\n", runtime.GOOS, runtime.GOARCH)
		},
	}
}

View File

@@ -4,9 +4,11 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
"filippo.io/age"
"git.eeqj.de/sneak/smartconfig"
"github.com/adrg/xdg"
"go.uber.org/fx"
@@ -37,24 +39,62 @@ func expandTildeInURL(url string) string {
return url
}
// SnapshotConfig represents configuration for a named snapshot.
// Each snapshot backs up one or more paths and can have its own exclude patterns
// in addition to the global excludes.
type SnapshotConfig struct {
	Paths   []string `yaml:"paths"`   // Filesystem paths included in this snapshot
	Exclude []string `yaml:"exclude"` // Additional excludes for this snapshot
}
// GetExcludes returns the combined exclude patterns for a named snapshot.
// It merges the global excludes with the snapshot-specific excludes; an
// unknown snapshot name, or one without extra patterns, yields the global
// list unchanged.
func (c *Config) GetExcludes(snapshotName string) []string {
	snap, ok := c.Snapshots[snapshotName]
	if !ok || len(snap.Exclude) == 0 {
		return c.Exclude
	}
	// Merge into a fresh slice so neither source list is mutated.
	merged := make([]string, 0, len(c.Exclude)+len(snap.Exclude))
	merged = append(merged, c.Exclude...)
	return append(merged, snap.Exclude...)
}
// SnapshotNames returns the names of all configured snapshots in sorted order.
func (c *Config) SnapshotNames() []string {
	out := make([]string, 0, len(c.Snapshots))
	for n := range c.Snapshots {
		out = append(out, n)
	}
	// Map iteration order is random; sort for deterministic output.
	sort.Strings(out)
	return out
}
// Config represents the application configuration for Vaultik.
// It defines all settings for backup operations, including source directories,
// encryption recipients, storage configuration, and performance tuning parameters.
// Configuration is typically loaded from a YAML file.
type Config struct {
AgeRecipients []string `yaml:"age_recipients"`
AgeSecretKey string `yaml:"age_secret_key"`
BackupInterval time.Duration `yaml:"backup_interval"`
BlobSizeLimit Size `yaml:"blob_size_limit"`
ChunkSize Size `yaml:"chunk_size"`
Exclude []string `yaml:"exclude"`
FullScanInterval time.Duration `yaml:"full_scan_interval"`
Hostname string `yaml:"hostname"`
IndexPath string `yaml:"index_path"`
MinTimeBetweenRun time.Duration `yaml:"min_time_between_run"`
S3 S3Config `yaml:"s3"`
SourceDirs []string `yaml:"source_dirs"`
CompressionLevel int `yaml:"compression_level"`
AgeRecipients []string `yaml:"age_recipients"`
AgeSecretKey string `yaml:"age_secret_key"`
BackupInterval time.Duration `yaml:"backup_interval"`
BlobSizeLimit Size `yaml:"blob_size_limit"`
ChunkSize Size `yaml:"chunk_size"`
Exclude []string `yaml:"exclude"` // Global excludes applied to all snapshots
FullScanInterval time.Duration `yaml:"full_scan_interval"`
Hostname string `yaml:"hostname"`
IndexPath string `yaml:"index_path"`
MinTimeBetweenRun time.Duration `yaml:"min_time_between_run"`
S3 S3Config `yaml:"s3"`
Snapshots map[string]SnapshotConfig `yaml:"snapshots"`
CompressionLevel int `yaml:"compression_level"`
// StorageURL specifies the storage backend using a URL format.
// Takes precedence over S3Config if set.
@@ -137,8 +177,13 @@ func Load(path string) (*Config, error) {
// Expand tilde in all path fields
cfg.IndexPath = expandTilde(cfg.IndexPath)
cfg.StorageURL = expandTildeInURL(cfg.StorageURL)
for i, dir := range cfg.SourceDirs {
cfg.SourceDirs[i] = expandTilde(dir)
// Expand tildes in snapshot paths
for name, snap := range cfg.Snapshots {
for i, path := range snap.Paths {
snap.Paths[i] = expandTilde(path)
}
cfg.Snapshots[name] = snap
}
// Check for environment variable override for IndexPath
@@ -148,7 +193,7 @@ func Load(path string) (*Config, error) {
// Check for environment variable override for AgeSecretKey
if envAgeSecretKey := os.Getenv("VAULTIK_AGE_SECRET_KEY"); envAgeSecretKey != "" {
cfg.AgeSecretKey = envAgeSecretKey
cfg.AgeSecretKey = extractAgeSecretKey(envAgeSecretKey)
}
// Get hostname if not set
@@ -178,7 +223,7 @@ func Load(path string) (*Config, error) {
// Validate checks if the configuration is valid and complete.
// It ensures all required fields are present and have valid values:
// - At least one age recipient must be specified
// - At least one source directory must be configured
// - At least one snapshot must be configured with at least one path
// - Storage must be configured (either storage_url or s3.* fields)
// - Chunk size must be at least 1MB
// - Blob size limit must be at least the chunk size
@@ -189,8 +234,14 @@ func (c *Config) Validate() error {
return fmt.Errorf("at least one age_recipient is required")
}
if len(c.SourceDirs) == 0 {
return fmt.Errorf("at least one source directory is required")
if len(c.Snapshots) == 0 {
return fmt.Errorf("at least one snapshot must be configured")
}
for name, snap := range c.Snapshots {
if len(snap.Paths) == 0 {
return fmt.Errorf("snapshot %q must have at least one path", name)
}
}
// Validate storage configuration
@@ -257,6 +308,21 @@ func (c *Config) validateStorage() error {
return nil
}
// extractAgeSecretKey extracts the AGE-SECRET-KEY from the input using
// the age library's parser, which handles comments and whitespace.
// When parsing fails, or the first identity is not an X25519 identity,
// the whitespace-trimmed raw input is returned as a best-effort fallback.
func extractAgeSecretKey(input string) string {
	fallback := strings.TrimSpace(input)
	identities, err := age.ParseIdentities(strings.NewReader(input))
	if err != nil || len(identities) == 0 {
		return fallback
	}
	x25519, ok := identities[0].(*age.X25519Identity)
	if !ok {
		return fallback
	}
	return x25519.String()
}
// Module exports the config module for fx dependency injection.
// It provides the Config type to other modules in the application.
var Module = fx.Module("config",

View File

@@ -45,12 +45,21 @@ func TestConfigLoad(t *testing.T) {
t.Errorf("Expected first age recipient to be %s, got '%s'", TEST_SNEAK_AGE_PUBLIC_KEY, cfg.AgeRecipients[0])
}
if len(cfg.SourceDirs) != 2 {
t.Errorf("Expected 2 source dirs, got %d", len(cfg.SourceDirs))
if len(cfg.Snapshots) != 1 {
t.Errorf("Expected 1 snapshot, got %d", len(cfg.Snapshots))
}
if cfg.SourceDirs[0] != "/tmp/vaultik-test-source" {
t.Errorf("Expected first source dir to be '/tmp/vaultik-test-source', got '%s'", cfg.SourceDirs[0])
testSnap, ok := cfg.Snapshots["test"]
if !ok {
t.Fatal("Expected 'test' snapshot to exist")
}
if len(testSnap.Paths) != 2 {
t.Errorf("Expected 2 paths in test snapshot, got %d", len(testSnap.Paths))
}
if testSnap.Paths[0] != "/tmp/vaultik-test-source" {
t.Errorf("Expected first path to be '/tmp/vaultik-test-source', got '%s'", testSnap.Paths[0])
}
if cfg.S3.Bucket != "vaultik-test-bucket" {
@@ -74,3 +83,65 @@ func TestConfigFromEnv(t *testing.T) {
t.Errorf("Config file does not exist at path from VAULTIK_CONFIG: %s", configPath)
}
}
// TestExtractAgeSecretKey tests extraction of AGE-SECRET-KEY from various inputs
//
// Table-driven: each case feeds a raw string (as a user might set
// VAULTIK_AGE_SECRET_KEY) and expects either the canonical key string or,
// when no identity can be parsed, the TrimSpace'd input as a fallback.
func TestExtractAgeSecretKey(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "plain key",
			input:    "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",
			expected: "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",
		},
		{
			name:     "key with trailing newline",
			input:    "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5\n",
			expected: "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",
		},
		{
			// Mirrors the exact format emitted by `age-keygen`: two comment
			// lines followed by the secret key.
			name: "full age-keygen output",
			input: `# created: 2025-01-14T12:00:00Z
# public key: age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg
AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5
`,
			expected: "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",
		},
		{
			// NOTE(review): the fixture name mentions extra blank lines, but
			// the input here looks identical to the previous case — confirm
			// the blank lines were not lost in a paste/format pass.
			name: "age-keygen output with extra blank lines",
			input: `# created: 2025-01-14T12:00:00Z
# public key: age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg
AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5
`,
			expected: "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",
		},
		{
			name:     "key with leading whitespace",
			input:    " AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5 ",
			expected: "AGE-SECRET-KEY-19CR5YSFW59HM4TLD6GXVEDMZFTVVF7PPHKUT68TXSFPK7APHXA2QS2NJA5",
		},
		{
			name:     "empty input",
			input:    "",
			expected: "",
		},
		{
			// No parsable identity: the fallback returns the (trimmed) input,
			// which here is unchanged since it has no surrounding whitespace.
			name:     "only comments",
			input:    "# this is a comment\n# another comment",
			expected: "# this is a comment\n# another comment",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := extractAgeSecretKey(tt.input)
			if result != tt.expected {
				t.Errorf("extractAgeSecretKey(%q) = %q, want %q", tt.input, result, tt.expected)
			}
		})
	}
}

View File

@@ -5,6 +5,8 @@ import (
"strings"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
func TestBlobChunkRepository(t *testing.T) {
@@ -16,8 +18,8 @@ func TestBlobChunkRepository(t *testing.T) {
// Create blob first
blob := &Blob{
ID: "blob1-uuid",
Hash: "blob1-hash",
ID: types.NewBlobID(),
Hash: types.BlobHash("blob1-hash"),
CreatedTS: time.Now(),
}
err := repos.Blobs.Create(ctx, nil, blob)
@@ -26,7 +28,7 @@ func TestBlobChunkRepository(t *testing.T) {
}
// Create chunks
chunks := []string{"chunk1", "chunk2", "chunk3"}
chunks := []types.ChunkHash{"chunk1", "chunk2", "chunk3"}
for _, chunkHash := range chunks {
chunk := &Chunk{
ChunkHash: chunkHash,
@@ -41,7 +43,7 @@ func TestBlobChunkRepository(t *testing.T) {
// Test Create
bc1 := &BlobChunk{
BlobID: blob.ID,
ChunkHash: "chunk1",
ChunkHash: types.ChunkHash("chunk1"),
Offset: 0,
Length: 1024,
}
@@ -54,7 +56,7 @@ func TestBlobChunkRepository(t *testing.T) {
// Add more chunks to the same blob
bc2 := &BlobChunk{
BlobID: blob.ID,
ChunkHash: "chunk2",
ChunkHash: types.ChunkHash("chunk2"),
Offset: 1024,
Length: 2048,
}
@@ -65,7 +67,7 @@ func TestBlobChunkRepository(t *testing.T) {
bc3 := &BlobChunk{
BlobID: blob.ID,
ChunkHash: "chunk3",
ChunkHash: types.ChunkHash("chunk3"),
Offset: 3072,
Length: 512,
}
@@ -75,7 +77,7 @@ func TestBlobChunkRepository(t *testing.T) {
}
// Test GetByBlobID
blobChunks, err := repos.BlobChunks.GetByBlobID(ctx, blob.ID)
blobChunks, err := repos.BlobChunks.GetByBlobID(ctx, blob.ID.String())
if err != nil {
t.Fatalf("failed to get blob chunks: %v", err)
}
@@ -134,13 +136,13 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
// Create blobs
blob1 := &Blob{
ID: "blob1-uuid",
Hash: "blob1-hash",
ID: types.NewBlobID(),
Hash: types.BlobHash("blob1-hash"),
CreatedTS: time.Now(),
}
blob2 := &Blob{
ID: "blob2-uuid",
Hash: "blob2-hash",
ID: types.NewBlobID(),
Hash: types.BlobHash("blob2-hash"),
CreatedTS: time.Now(),
}
@@ -154,7 +156,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
}
// Create chunks
chunkHashes := []string{"chunk1", "chunk2", "chunk3"}
chunkHashes := []types.ChunkHash{"chunk1", "chunk2", "chunk3"}
for _, chunkHash := range chunkHashes {
chunk := &Chunk{
ChunkHash: chunkHash,
@@ -169,10 +171,10 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
// Create chunks across multiple blobs
// Some chunks are shared between blobs (deduplication scenario)
blobChunks := []BlobChunk{
{BlobID: blob1.ID, ChunkHash: "chunk1", Offset: 0, Length: 1024},
{BlobID: blob1.ID, ChunkHash: "chunk2", Offset: 1024, Length: 1024},
{BlobID: blob2.ID, ChunkHash: "chunk2", Offset: 0, Length: 1024}, // chunk2 is shared
{BlobID: blob2.ID, ChunkHash: "chunk3", Offset: 1024, Length: 1024},
{BlobID: blob1.ID, ChunkHash: types.ChunkHash("chunk1"), Offset: 0, Length: 1024},
{BlobID: blob1.ID, ChunkHash: types.ChunkHash("chunk2"), Offset: 1024, Length: 1024},
{BlobID: blob2.ID, ChunkHash: types.ChunkHash("chunk2"), Offset: 0, Length: 1024}, // chunk2 is shared
{BlobID: blob2.ID, ChunkHash: types.ChunkHash("chunk3"), Offset: 1024, Length: 1024},
}
for _, bc := range blobChunks {
@@ -183,7 +185,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
}
// Verify blob1 chunks
chunks, err := repos.BlobChunks.GetByBlobID(ctx, blob1.ID)
chunks, err := repos.BlobChunks.GetByBlobID(ctx, blob1.ID.String())
if err != nil {
t.Fatalf("failed to get blob1 chunks: %v", err)
}
@@ -192,7 +194,7 @@ func TestBlobChunkRepositoryMultipleBlobs(t *testing.T) {
}
// Verify blob2 chunks
chunks, err = repos.BlobChunks.GetByBlobID(ctx, blob2.ID)
chunks, err = repos.BlobChunks.GetByBlobID(ctx, blob2.ID.String())
if err != nil {
t.Fatalf("failed to get blob2 chunks: %v", err)
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
func TestBlobRepository(t *testing.T) {
@@ -15,8 +17,8 @@ func TestBlobRepository(t *testing.T) {
// Test Create
blob := &Blob{
ID: "test-blob-id-123",
Hash: "blobhash123",
ID: types.NewBlobID(),
Hash: types.BlobHash("blobhash123"),
CreatedTS: time.Now().Truncate(time.Second),
}
@@ -26,7 +28,7 @@ func TestBlobRepository(t *testing.T) {
}
// Test GetByHash
retrieved, err := repo.GetByHash(ctx, blob.Hash)
retrieved, err := repo.GetByHash(ctx, blob.Hash.String())
if err != nil {
t.Fatalf("failed to get blob: %v", err)
}
@@ -41,7 +43,7 @@ func TestBlobRepository(t *testing.T) {
}
// Test GetByID
retrievedByID, err := repo.GetByID(ctx, blob.ID)
retrievedByID, err := repo.GetByID(ctx, blob.ID.String())
if err != nil {
t.Fatalf("failed to get blob by ID: %v", err)
}
@@ -54,8 +56,8 @@ func TestBlobRepository(t *testing.T) {
// Test with second blob
blob2 := &Blob{
ID: "test-blob-id-456",
Hash: "blobhash456",
ID: types.NewBlobID(),
Hash: types.BlobHash("blobhash456"),
CreatedTS: time.Now().Truncate(time.Second),
}
err = repo.Create(ctx, nil, blob2)
@@ -65,13 +67,13 @@ func TestBlobRepository(t *testing.T) {
// Test UpdateFinished
now := time.Now()
err = repo.UpdateFinished(ctx, nil, blob.ID, blob.Hash, 1000, 500)
err = repo.UpdateFinished(ctx, nil, blob.ID.String(), blob.Hash.String(), 1000, 500)
if err != nil {
t.Fatalf("failed to update blob as finished: %v", err)
}
// Verify update
updated, err := repo.GetByID(ctx, blob.ID)
updated, err := repo.GetByID(ctx, blob.ID.String())
if err != nil {
t.Fatalf("failed to get updated blob: %v", err)
}
@@ -86,13 +88,13 @@ func TestBlobRepository(t *testing.T) {
}
// Test UpdateUploaded
err = repo.UpdateUploaded(ctx, nil, blob.ID)
err = repo.UpdateUploaded(ctx, nil, blob.ID.String())
if err != nil {
t.Fatalf("failed to update blob as uploaded: %v", err)
}
// Verify upload update
uploaded, err := repo.GetByID(ctx, blob.ID)
uploaded, err := repo.GetByID(ctx, blob.ID.String())
if err != nil {
t.Fatalf("failed to get uploaded blob: %v", err)
}
@@ -113,8 +115,8 @@ func TestBlobRepositoryDuplicate(t *testing.T) {
repo := NewBlobRepository(db)
blob := &Blob{
ID: "duplicate-test-id",
Hash: "duplicate_blob",
ID: types.NewBlobID(),
Hash: types.BlobHash("duplicate_blob"),
CreatedTS: time.Now().Truncate(time.Second),
}

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
// TestCascadeDeleteDebug tests cascade delete with debug output
@@ -42,7 +44,7 @@ func TestCascadeDeleteDebug(t *testing.T) {
// Create chunks and file-chunk mappings
for i := 0; i < 3; i++ {
chunk := &Chunk{
ChunkHash: fmt.Sprintf("cascade-chunk-%d", i),
ChunkHash: types.ChunkHash(fmt.Sprintf("cascade-chunk-%d", i)),
Size: 1024,
}
err = repos.Chunks.Create(ctx, nil, chunk)

View File

@@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"fmt"
"git.eeqj.de/sneak/vaultik/internal/types"
)
type ChunkFileRepository struct {
@@ -23,9 +25,9 @@ func (r *ChunkFileRepository) Create(ctx context.Context, tx *sql.Tx, cf *ChunkF
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, cf.ChunkHash, cf.FileID, cf.FileOffset, cf.Length)
_, err = tx.ExecContext(ctx, query, cf.ChunkHash.String(), cf.FileID.String(), cf.FileOffset, cf.Length)
} else {
_, err = r.db.ExecWithLog(ctx, query, cf.ChunkHash, cf.FileID, cf.FileOffset, cf.Length)
_, err = r.db.ExecWithLog(ctx, query, cf.ChunkHash.String(), cf.FileID.String(), cf.FileOffset, cf.Length)
}
if err != nil {
@@ -35,30 +37,20 @@ func (r *ChunkFileRepository) Create(ctx context.Context, tx *sql.Tx, cf *ChunkF
return nil
}
func (r *ChunkFileRepository) GetByChunkHash(ctx context.Context, chunkHash string) ([]*ChunkFile, error) {
func (r *ChunkFileRepository) GetByChunkHash(ctx context.Context, chunkHash types.ChunkHash) ([]*ChunkFile, error) {
query := `
SELECT chunk_hash, file_id, file_offset, length
FROM chunk_files
WHERE chunk_hash = ?
`
rows, err := r.db.conn.QueryContext(ctx, query, chunkHash)
rows, err := r.db.conn.QueryContext(ctx, query, chunkHash.String())
if err != nil {
return nil, fmt.Errorf("querying chunk files: %w", err)
}
defer CloseRows(rows)
var chunkFiles []*ChunkFile
for rows.Next() {
var cf ChunkFile
err := rows.Scan(&cf.ChunkHash, &cf.FileID, &cf.FileOffset, &cf.Length)
if err != nil {
return nil, fmt.Errorf("scanning chunk file: %w", err)
}
chunkFiles = append(chunkFiles, &cf)
}
return chunkFiles, rows.Err()
return r.scanChunkFiles(rows)
}
func (r *ChunkFileRepository) GetByFilePath(ctx context.Context, filePath string) ([]*ChunkFile, error) {
@@ -75,40 +67,41 @@ func (r *ChunkFileRepository) GetByFilePath(ctx context.Context, filePath string
}
defer CloseRows(rows)
var chunkFiles []*ChunkFile
for rows.Next() {
var cf ChunkFile
err := rows.Scan(&cf.ChunkHash, &cf.FileID, &cf.FileOffset, &cf.Length)
if err != nil {
return nil, fmt.Errorf("scanning chunk file: %w", err)
}
chunkFiles = append(chunkFiles, &cf)
}
return chunkFiles, rows.Err()
return r.scanChunkFiles(rows)
}
// GetByFileID retrieves chunk files by file ID
func (r *ChunkFileRepository) GetByFileID(ctx context.Context, fileID string) ([]*ChunkFile, error) {
func (r *ChunkFileRepository) GetByFileID(ctx context.Context, fileID types.FileID) ([]*ChunkFile, error) {
query := `
SELECT chunk_hash, file_id, file_offset, length
FROM chunk_files
WHERE file_id = ?
`
rows, err := r.db.conn.QueryContext(ctx, query, fileID)
rows, err := r.db.conn.QueryContext(ctx, query, fileID.String())
if err != nil {
return nil, fmt.Errorf("querying chunk files: %w", err)
}
defer CloseRows(rows)
return r.scanChunkFiles(rows)
}
// scanChunkFiles is a helper that scans chunk file rows
func (r *ChunkFileRepository) scanChunkFiles(rows *sql.Rows) ([]*ChunkFile, error) {
var chunkFiles []*ChunkFile
for rows.Next() {
var cf ChunkFile
err := rows.Scan(&cf.ChunkHash, &cf.FileID, &cf.FileOffset, &cf.Length)
var chunkHashStr, fileIDStr string
err := rows.Scan(&chunkHashStr, &fileIDStr, &cf.FileOffset, &cf.Length)
if err != nil {
return nil, fmt.Errorf("scanning chunk file: %w", err)
}
cf.ChunkHash = types.ChunkHash(chunkHashStr)
cf.FileID, err = types.ParseFileID(fileIDStr)
if err != nil {
return nil, fmt.Errorf("parsing file ID: %w", err)
}
chunkFiles = append(chunkFiles, &cf)
}
@@ -116,14 +109,14 @@ func (r *ChunkFileRepository) GetByFileID(ctx context.Context, fileID string) ([
}
// DeleteByFileID deletes all chunk_files entries for a given file ID
func (r *ChunkFileRepository) DeleteByFileID(ctx context.Context, tx *sql.Tx, fileID string) error {
func (r *ChunkFileRepository) DeleteByFileID(ctx context.Context, tx *sql.Tx, fileID types.FileID) error {
query := `DELETE FROM chunk_files WHERE file_id = ?`
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, fileID)
_, err = tx.ExecContext(ctx, query, fileID.String())
} else {
_, err = r.db.ExecWithLog(ctx, query, fileID)
_, err = r.db.ExecWithLog(ctx, query, fileID.String())
}
if err != nil {
@@ -134,7 +127,7 @@ func (r *ChunkFileRepository) DeleteByFileID(ctx context.Context, tx *sql.Tx, fi
}
// DeleteByFileIDs deletes all chunk_files for multiple files in a single statement.
func (r *ChunkFileRepository) DeleteByFileIDs(ctx context.Context, tx *sql.Tx, fileIDs []string) error {
func (r *ChunkFileRepository) DeleteByFileIDs(ctx context.Context, tx *sql.Tx, fileIDs []types.FileID) error {
if len(fileIDs) == 0 {
return nil
}
@@ -152,7 +145,7 @@ func (r *ChunkFileRepository) DeleteByFileIDs(ctx context.Context, tx *sql.Tx, f
query := "DELETE FROM chunk_files WHERE file_id IN (?" + repeatPlaceholder(len(batch)-1) + ")"
args := make([]interface{}, len(batch))
for j, id := range batch {
args[j] = id
args[j] = id.String()
}
var err error
@@ -192,7 +185,7 @@ func (r *ChunkFileRepository) CreateBatch(ctx context.Context, tx *sql.Tx, cfs [
query += ", "
}
query += "(?, ?, ?, ?)"
args = append(args, cf.ChunkHash, cf.FileID, cf.FileOffset, cf.Length)
args = append(args, cf.ChunkHash.String(), cf.FileID.String(), cf.FileOffset, cf.Length)
}
query += " ON CONFLICT(chunk_hash, file_id) DO NOTHING"

View File

@@ -4,6 +4,8 @@ import (
"context"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
func TestChunkFileRepository(t *testing.T) {
@@ -49,7 +51,7 @@ func TestChunkFileRepository(t *testing.T) {
// Create chunk first
chunk := &Chunk{
ChunkHash: "chunk1",
ChunkHash: types.ChunkHash("chunk1"),
Size: 1024,
}
err = chunksRepo.Create(ctx, nil, chunk)
@@ -59,7 +61,7 @@ func TestChunkFileRepository(t *testing.T) {
// Test Create
cf1 := &ChunkFile{
ChunkHash: "chunk1",
ChunkHash: types.ChunkHash("chunk1"),
FileID: file1.ID,
FileOffset: 0,
Length: 1024,
@@ -72,7 +74,7 @@ func TestChunkFileRepository(t *testing.T) {
// Add same chunk in different file (deduplication scenario)
cf2 := &ChunkFile{
ChunkHash: "chunk1",
ChunkHash: types.ChunkHash("chunk1"),
FileID: file2.ID,
FileOffset: 2048,
Length: 1024,
@@ -114,7 +116,7 @@ func TestChunkFileRepository(t *testing.T) {
if len(chunkFiles) != 1 {
t.Errorf("expected 1 chunk for file, got %d", len(chunkFiles))
}
if chunkFiles[0].ChunkHash != "chunk1" {
if chunkFiles[0].ChunkHash != types.ChunkHash("chunk1") {
t.Errorf("wrong chunk hash: expected chunk1, got %s", chunkFiles[0].ChunkHash)
}
@@ -151,7 +153,7 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {
}
// Create chunks first
chunks := []string{"chunk1", "chunk2", "chunk3", "chunk4"}
chunks := []types.ChunkHash{"chunk1", "chunk2", "chunk3", "chunk4"}
for _, chunkHash := range chunks {
chunk := &Chunk{
ChunkHash: chunkHash,
@@ -170,16 +172,16 @@ func TestChunkFileRepositoryComplexDeduplication(t *testing.T) {
chunkFiles := []ChunkFile{
// File1
{ChunkHash: "chunk1", FileID: file1.ID, FileOffset: 0, Length: 1024},
{ChunkHash: "chunk2", FileID: file1.ID, FileOffset: 1024, Length: 1024},
{ChunkHash: "chunk3", FileID: file1.ID, FileOffset: 2048, Length: 1024},
{ChunkHash: types.ChunkHash("chunk1"), FileID: file1.ID, FileOffset: 0, Length: 1024},
{ChunkHash: types.ChunkHash("chunk2"), FileID: file1.ID, FileOffset: 1024, Length: 1024},
{ChunkHash: types.ChunkHash("chunk3"), FileID: file1.ID, FileOffset: 2048, Length: 1024},
// File2
{ChunkHash: "chunk2", FileID: file2.ID, FileOffset: 0, Length: 1024},
{ChunkHash: "chunk3", FileID: file2.ID, FileOffset: 1024, Length: 1024},
{ChunkHash: "chunk4", FileID: file2.ID, FileOffset: 2048, Length: 1024},
{ChunkHash: types.ChunkHash("chunk2"), FileID: file2.ID, FileOffset: 0, Length: 1024},
{ChunkHash: types.ChunkHash("chunk3"), FileID: file2.ID, FileOffset: 1024, Length: 1024},
{ChunkHash: types.ChunkHash("chunk4"), FileID: file2.ID, FileOffset: 2048, Length: 1024},
// File3
{ChunkHash: "chunk1", FileID: file3.ID, FileOffset: 0, Length: 1024},
{ChunkHash: "chunk4", FileID: file3.ID, FileOffset: 1024, Length: 1024},
{ChunkHash: types.ChunkHash("chunk1"), FileID: file3.ID, FileOffset: 0, Length: 1024},
{ChunkHash: types.ChunkHash("chunk4"), FileID: file3.ID, FileOffset: 1024, Length: 1024},
}
for _, cf := range chunkFiles {

View File

@@ -3,6 +3,8 @@ package database
import (
"context"
"testing"
"git.eeqj.de/sneak/vaultik/internal/types"
)
func TestChunkRepository(t *testing.T) {
@@ -14,7 +16,7 @@ func TestChunkRepository(t *testing.T) {
// Test Create
chunk := &Chunk{
ChunkHash: "chunkhash123",
ChunkHash: types.ChunkHash("chunkhash123"),
Size: 4096,
}
@@ -24,7 +26,7 @@ func TestChunkRepository(t *testing.T) {
}
// Test GetByHash
retrieved, err := repo.GetByHash(ctx, chunk.ChunkHash)
retrieved, err := repo.GetByHash(ctx, chunk.ChunkHash.String())
if err != nil {
t.Fatalf("failed to get chunk: %v", err)
}
@@ -46,7 +48,7 @@ func TestChunkRepository(t *testing.T) {
// Test GetByHashes
chunk2 := &Chunk{
ChunkHash: "chunkhash456",
ChunkHash: types.ChunkHash("chunkhash456"),
Size: 8192,
}
err = repo.Create(ctx, nil, chunk2)
@@ -54,7 +56,7 @@ func TestChunkRepository(t *testing.T) {
t.Fatalf("failed to create second chunk: %v", err)
}
chunks, err := repo.GetByHashes(ctx, []string{chunk.ChunkHash, chunk2.ChunkHash})
chunks, err := repo.GetByHashes(ctx, []string{chunk.ChunkHash.String(), chunk2.ChunkHash.String()})
if err != nil {
t.Fatalf("failed to get chunks by hashes: %v", err)
}

View File

@@ -154,6 +154,11 @@ func (db *DB) Conn() *sql.DB {
return db.conn
}
// Path returns the path to the database file.
func (db *DB) Path() string {
return db.path
}
// BeginTx starts a new database transaction with the given options.
// The caller is responsible for committing or rolling back the transaction.
// For write transactions, consider using the Repositories.WithTx method instead,

View File

@@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"fmt"
"git.eeqj.de/sneak/vaultik/internal/types"
)
type FileChunkRepository struct {
@@ -23,9 +25,9 @@ func (r *FileChunkRepository) Create(ctx context.Context, tx *sql.Tx, fc *FileCh
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, fc.FileID, fc.Idx, fc.ChunkHash)
_, err = tx.ExecContext(ctx, query, fc.FileID.String(), fc.Idx, fc.ChunkHash.String())
} else {
_, err = r.db.ExecWithLog(ctx, query, fc.FileID, fc.Idx, fc.ChunkHash)
_, err = r.db.ExecWithLog(ctx, query, fc.FileID.String(), fc.Idx, fc.ChunkHash.String())
}
if err != nil {
@@ -50,21 +52,11 @@ func (r *FileChunkRepository) GetByPath(ctx context.Context, path string) ([]*Fi
}
defer CloseRows(rows)
var fileChunks []*FileChunk
for rows.Next() {
var fc FileChunk
err := rows.Scan(&fc.FileID, &fc.Idx, &fc.ChunkHash)
if err != nil {
return nil, fmt.Errorf("scanning file chunk: %w", err)
}
fileChunks = append(fileChunks, &fc)
}
return fileChunks, rows.Err()
return r.scanFileChunks(rows)
}
// GetByFileID retrieves file chunks by file ID
func (r *FileChunkRepository) GetByFileID(ctx context.Context, fileID string) ([]*FileChunk, error) {
func (r *FileChunkRepository) GetByFileID(ctx context.Context, fileID types.FileID) ([]*FileChunk, error) {
query := `
SELECT file_id, idx, chunk_hash
FROM file_chunks
@@ -72,23 +64,13 @@ func (r *FileChunkRepository) GetByFileID(ctx context.Context, fileID string) ([
ORDER BY idx
`
rows, err := r.db.conn.QueryContext(ctx, query, fileID)
rows, err := r.db.conn.QueryContext(ctx, query, fileID.String())
if err != nil {
return nil, fmt.Errorf("querying file chunks: %w", err)
}
defer CloseRows(rows)
var fileChunks []*FileChunk
for rows.Next() {
var fc FileChunk
err := rows.Scan(&fc.FileID, &fc.Idx, &fc.ChunkHash)
if err != nil {
return nil, fmt.Errorf("scanning file chunk: %w", err)
}
fileChunks = append(fileChunks, &fc)
}
return fileChunks, rows.Err()
return r.scanFileChunks(rows)
}
// GetByPathTx retrieves file chunks within a transaction
@@ -108,16 +90,28 @@ func (r *FileChunkRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path
}
defer CloseRows(rows)
fileChunks, err := r.scanFileChunks(rows)
LogSQL("GetByPathTx", "Complete", path, "count", len(fileChunks))
return fileChunks, err
}
// scanFileChunks is a helper that scans file chunk rows
func (r *FileChunkRepository) scanFileChunks(rows *sql.Rows) ([]*FileChunk, error) {
var fileChunks []*FileChunk
for rows.Next() {
var fc FileChunk
err := rows.Scan(&fc.FileID, &fc.Idx, &fc.ChunkHash)
var fileIDStr, chunkHashStr string
err := rows.Scan(&fileIDStr, &fc.Idx, &chunkHashStr)
if err != nil {
return nil, fmt.Errorf("scanning file chunk: %w", err)
}
fc.FileID, err = types.ParseFileID(fileIDStr)
if err != nil {
return nil, fmt.Errorf("parsing file ID: %w", err)
}
fc.ChunkHash = types.ChunkHash(chunkHashStr)
fileChunks = append(fileChunks, &fc)
}
LogSQL("GetByPathTx", "Complete", path, "count", len(fileChunks))
return fileChunks, rows.Err()
}
@@ -140,14 +134,14 @@ func (r *FileChunkRepository) DeleteByPath(ctx context.Context, tx *sql.Tx, path
}
// DeleteByFileID deletes all chunks for a file by its UUID
func (r *FileChunkRepository) DeleteByFileID(ctx context.Context, tx *sql.Tx, fileID string) error {
func (r *FileChunkRepository) DeleteByFileID(ctx context.Context, tx *sql.Tx, fileID types.FileID) error {
query := `DELETE FROM file_chunks WHERE file_id = ?`
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, fileID)
_, err = tx.ExecContext(ctx, query, fileID.String())
} else {
_, err = r.db.ExecWithLog(ctx, query, fileID)
_, err = r.db.ExecWithLog(ctx, query, fileID.String())
}
if err != nil {
@@ -158,7 +152,7 @@ func (r *FileChunkRepository) DeleteByFileID(ctx context.Context, tx *sql.Tx, fi
}
// DeleteByFileIDs deletes all chunks for multiple files in a single statement.
func (r *FileChunkRepository) DeleteByFileIDs(ctx context.Context, tx *sql.Tx, fileIDs []string) error {
func (r *FileChunkRepository) DeleteByFileIDs(ctx context.Context, tx *sql.Tx, fileIDs []types.FileID) error {
if len(fileIDs) == 0 {
return nil
}
@@ -176,7 +170,7 @@ func (r *FileChunkRepository) DeleteByFileIDs(ctx context.Context, tx *sql.Tx, f
query := "DELETE FROM file_chunks WHERE file_id IN (?" + repeatPlaceholder(len(batch)-1) + ")"
args := make([]interface{}, len(batch))
for j, id := range batch {
args[j] = id
args[j] = id.String()
}
var err error
@@ -219,7 +213,7 @@ func (r *FileChunkRepository) CreateBatch(ctx context.Context, tx *sql.Tx, fcs [
query += ", "
}
query += "(?, ?, ?)"
args = append(args, fc.FileID, fc.Idx, fc.ChunkHash)
args = append(args, fc.FileID.String(), fc.Idx, fc.ChunkHash.String())
}
query += " ON CONFLICT(file_id, idx) DO NOTHING"

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
func TestFileChunkRepository(t *testing.T) {
@@ -33,7 +35,7 @@ func TestFileChunkRepository(t *testing.T) {
}
// Create chunks first
chunks := []string{"chunk1", "chunk2", "chunk3"}
chunks := []types.ChunkHash{"chunk1", "chunk2", "chunk3"}
chunkRepo := NewChunkRepository(db)
for _, chunkHash := range chunks {
chunk := &Chunk{
@@ -50,7 +52,7 @@ func TestFileChunkRepository(t *testing.T) {
fc1 := &FileChunk{
FileID: file.ID,
Idx: 0,
ChunkHash: "chunk1",
ChunkHash: types.ChunkHash("chunk1"),
}
err = repo.Create(ctx, nil, fc1)
@@ -62,7 +64,7 @@ func TestFileChunkRepository(t *testing.T) {
fc2 := &FileChunk{
FileID: file.ID,
Idx: 1,
ChunkHash: "chunk2",
ChunkHash: types.ChunkHash("chunk2"),
}
err = repo.Create(ctx, nil, fc2)
if err != nil {
@@ -72,7 +74,7 @@ func TestFileChunkRepository(t *testing.T) {
fc3 := &FileChunk{
FileID: file.ID,
Idx: 2,
ChunkHash: "chunk3",
ChunkHash: types.ChunkHash("chunk3"),
}
err = repo.Create(ctx, nil, fc3)
if err != nil {
@@ -131,7 +133,7 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) {
for i, path := range filePaths {
file := &File{
Path: path,
Path: types.FilePath(path),
MTime: testTime,
CTime: testTime,
Size: 2048,
@@ -151,7 +153,7 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) {
chunkRepo := NewChunkRepository(db)
for i := range files {
for j := 0; j < 2; j++ {
chunkHash := fmt.Sprintf("file%d_chunk%d", i, j)
chunkHash := types.ChunkHash(fmt.Sprintf("file%d_chunk%d", i, j))
chunk := &Chunk{
ChunkHash: chunkHash,
Size: 1024,
@@ -169,7 +171,7 @@ func TestFileChunkRepositoryMultipleFiles(t *testing.T) {
fc := &FileChunk{
FileID: file.ID,
Idx: j,
ChunkHash: fmt.Sprintf("file%d_chunk%d", i, j),
ChunkHash: types.ChunkHash(fmt.Sprintf("file%d_chunk%d", i, j)),
}
err := repo.Create(ctx, nil, fc)
if err != nil {

View File

@@ -7,7 +7,7 @@ import (
"time"
"git.eeqj.de/sneak/vaultik/internal/log"
"github.com/google/uuid"
"git.eeqj.de/sneak/vaultik/internal/types"
)
type FileRepository struct {
@@ -20,14 +20,15 @@ func NewFileRepository(db *DB) *FileRepository {
func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) error {
// Generate UUID if not provided
if file.ID == "" {
file.ID = uuid.New().String()
if file.ID.IsZero() {
file.ID = types.NewFileID()
}
query := `
INSERT INTO files (id, path, mtime, ctime, size, mode, uid, gid, link_target)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO files (id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(path) DO UPDATE SET
source_path = excluded.source_path,
mtime = excluded.mtime,
ctime = excluded.ctime,
size = excluded.size,
@@ -38,44 +39,36 @@ func (r *FileRepository) Create(ctx context.Context, tx *sql.Tx, file *File) err
RETURNING id
`
var idStr string
var err error
if tx != nil {
LogSQL("Execute", query, file.ID, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget)
err = tx.QueryRowContext(ctx, query, file.ID, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget).Scan(&file.ID)
LogSQL("Execute", query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String())
err = tx.QueryRowContext(ctx, query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()).Scan(&idStr)
} else {
err = r.db.QueryRowWithLog(ctx, query, file.ID, file.Path, file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget).Scan(&file.ID)
err = r.db.QueryRowWithLog(ctx, query, file.ID.String(), file.Path.String(), file.SourcePath.String(), file.MTime.Unix(), file.CTime.Unix(), file.Size, file.Mode, file.UID, file.GID, file.LinkTarget.String()).Scan(&idStr)
}
if err != nil {
return fmt.Errorf("inserting file: %w", err)
}
// Parse the returned ID
file.ID, err = types.ParseFileID(idStr)
if err != nil {
return fmt.Errorf("parsing file ID: %w", err)
}
return nil
}
func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, error) {
query := `
SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target
FROM files
WHERE path = ?
`
var file File
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
err := r.db.conn.QueryRowContext(ctx, query, path).Scan(
&file.ID,
&file.Path,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
file, err := r.scanFile(r.db.conn.QueryRowContext(ctx, query, path))
if err == sql.ErrNoRows {
return nil, nil
}
@@ -83,39 +76,18 @@ func (r *FileRepository) GetByPath(ctx context.Context, path string) (*File, err
return nil, fmt.Errorf("querying file: %w", err)
}
file.MTime = time.Unix(mtimeUnix, 0).UTC()
file.CTime = time.Unix(ctimeUnix, 0).UTC()
if linkTarget.Valid {
file.LinkTarget = linkTarget.String
}
return &file, nil
return file, nil
}
// GetByID retrieves a file by its UUID
func (r *FileRepository) GetByID(ctx context.Context, id string) (*File, error) {
func (r *FileRepository) GetByID(ctx context.Context, id types.FileID) (*File, error) {
query := `
SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target
FROM files
WHERE id = ?
`
var file File
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&file.ID,
&file.Path,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
file, err := r.scanFile(r.db.conn.QueryRowContext(ctx, query, id.String()))
if err == sql.ErrNoRows {
return nil, nil
}
@@ -123,38 +95,18 @@ func (r *FileRepository) GetByID(ctx context.Context, id string) (*File, error)
return nil, fmt.Errorf("querying file: %w", err)
}
file.MTime = time.Unix(mtimeUnix, 0).UTC()
file.CTime = time.Unix(ctimeUnix, 0).UTC()
if linkTarget.Valid {
file.LinkTarget = linkTarget.String
}
return &file, nil
return file, nil
}
func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path string) (*File, error) {
query := `
SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target
FROM files
WHERE path = ?
`
var file File
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
LogSQL("GetByPathTx QueryRowContext", query, path)
err := tx.QueryRowContext(ctx, query, path).Scan(
&file.ID,
&file.Path,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
file, err := r.scanFile(tx.QueryRowContext(ctx, query, path))
LogSQL("GetByPathTx Scan complete", query, path)
if err == sql.ErrNoRows {
@@ -164,10 +116,80 @@ func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path strin
return nil, fmt.Errorf("querying file: %w", err)
}
return file, nil
}
// scanFile is a helper that scans a single file row
func (r *FileRepository) scanFile(row *sql.Row) (*File, error) {
var file File
var idStr, pathStr, sourcePathStr string
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
err := row.Scan(
&idStr,
&pathStr,
&sourcePathStr,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
if err != nil {
return nil, err
}
file.ID, err = types.ParseFileID(idStr)
if err != nil {
return nil, fmt.Errorf("parsing file ID: %w", err)
}
file.Path = types.FilePath(pathStr)
file.SourcePath = types.SourcePath(sourcePathStr)
file.MTime = time.Unix(mtimeUnix, 0).UTC()
file.CTime = time.Unix(ctimeUnix, 0).UTC()
if linkTarget.Valid {
file.LinkTarget = linkTarget.String
file.LinkTarget = types.FilePath(linkTarget.String)
}
return &file, nil
}
// scanFileRows is a helper that scans a file row from rows iterator
func (r *FileRepository) scanFileRows(rows *sql.Rows) (*File, error) {
var file File
var idStr, pathStr, sourcePathStr string
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
err := rows.Scan(
&idStr,
&pathStr,
&sourcePathStr,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
if err != nil {
return nil, err
}
file.ID, err = types.ParseFileID(idStr)
if err != nil {
return nil, fmt.Errorf("parsing file ID: %w", err)
}
file.Path = types.FilePath(pathStr)
file.SourcePath = types.SourcePath(sourcePathStr)
file.MTime = time.Unix(mtimeUnix, 0).UTC()
file.CTime = time.Unix(ctimeUnix, 0).UTC()
if linkTarget.Valid {
file.LinkTarget = types.FilePath(linkTarget.String)
}
return &file, nil
@@ -175,7 +197,7 @@ func (r *FileRepository) GetByPathTx(ctx context.Context, tx *sql.Tx, path strin
func (r *FileRepository) ListModifiedSince(ctx context.Context, since time.Time) ([]*File, error) {
query := `
SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target
FROM files
WHERE mtime >= ?
ORDER BY path
@@ -189,32 +211,11 @@ func (r *FileRepository) ListModifiedSince(ctx context.Context, since time.Time)
var files []*File
for rows.Next() {
var file File
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
err := rows.Scan(
&file.ID,
&file.Path,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
file, err := r.scanFileRows(rows)
if err != nil {
return nil, fmt.Errorf("scanning file: %w", err)
}
file.MTime = time.Unix(mtimeUnix, 0)
file.CTime = time.Unix(ctimeUnix, 0)
if linkTarget.Valid {
file.LinkTarget = linkTarget.String
}
files = append(files, &file)
files = append(files, file)
}
return files, rows.Err()
@@ -238,14 +239,14 @@ func (r *FileRepository) Delete(ctx context.Context, tx *sql.Tx, path string) er
}
// DeleteByID deletes a file by its UUID
func (r *FileRepository) DeleteByID(ctx context.Context, tx *sql.Tx, id string) error {
func (r *FileRepository) DeleteByID(ctx context.Context, tx *sql.Tx, id types.FileID) error {
query := `DELETE FROM files WHERE id = ?`
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, id)
_, err = tx.ExecContext(ctx, query, id.String())
} else {
_, err = r.db.ExecWithLog(ctx, query, id)
_, err = r.db.ExecWithLog(ctx, query, id.String())
}
if err != nil {
@@ -257,7 +258,7 @@ func (r *FileRepository) DeleteByID(ctx context.Context, tx *sql.Tx, id string)
func (r *FileRepository) ListByPrefix(ctx context.Context, prefix string) ([]*File, error) {
query := `
SELECT id, path, mtime, ctime, size, mode, uid, gid, link_target
SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target
FROM files
WHERE path LIKE ? || '%'
ORDER BY path
@@ -271,32 +272,37 @@ func (r *FileRepository) ListByPrefix(ctx context.Context, prefix string) ([]*Fi
var files []*File
for rows.Next() {
var file File
var mtimeUnix, ctimeUnix int64
var linkTarget sql.NullString
err := rows.Scan(
&file.ID,
&file.Path,
&mtimeUnix,
&ctimeUnix,
&file.Size,
&file.Mode,
&file.UID,
&file.GID,
&linkTarget,
)
file, err := r.scanFileRows(rows)
if err != nil {
return nil, fmt.Errorf("scanning file: %w", err)
}
files = append(files, file)
}
file.MTime = time.Unix(mtimeUnix, 0)
file.CTime = time.Unix(ctimeUnix, 0)
if linkTarget.Valid {
file.LinkTarget = linkTarget.String
return files, rows.Err()
}
// ListAll returns all files in the database
func (r *FileRepository) ListAll(ctx context.Context) ([]*File, error) {
query := `
SELECT id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target
FROM files
ORDER BY path
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("querying files: %w", err)
}
defer CloseRows(rows)
var files []*File
for rows.Next() {
file, err := r.scanFileRows(rows)
if err != nil {
return nil, fmt.Errorf("scanning file: %w", err)
}
files = append(files, &file)
files = append(files, file)
}
return files, rows.Err()
@@ -309,7 +315,7 @@ func (r *FileRepository) CreateBatch(ctx context.Context, tx *sql.Tx, files []*F
return nil
}
// Each File has 9 values, so batch at 100 to be safe with SQLite's variable limit
// Each File has 10 values, so batch at 100 to be safe with SQLite's variable limit
const batchSize = 100
for i := 0; i < len(files); i += batchSize {
@@ -319,16 +325,17 @@ func (r *FileRepository) CreateBatch(ctx context.Context, tx *sql.Tx, files []*F
}
batch := files[i:end]
query := `INSERT INTO files (id, path, mtime, ctime, size, mode, uid, gid, link_target) VALUES `
args := make([]interface{}, 0, len(batch)*9)
query := `INSERT INTO files (id, path, source_path, mtime, ctime, size, mode, uid, gid, link_target) VALUES `
args := make([]interface{}, 0, len(batch)*10)
for j, f := range batch {
if j > 0 {
query += ", "
}
query += "(?, ?, ?, ?, ?, ?, ?, ?, ?)"
args = append(args, f.ID, f.Path, f.MTime.Unix(), f.CTime.Unix(), f.Size, f.Mode, f.UID, f.GID, f.LinkTarget)
query += "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
args = append(args, f.ID.String(), f.Path.String(), f.SourcePath.String(), f.MTime.Unix(), f.CTime.Unix(), f.Size, f.Mode, f.UID, f.GID, f.LinkTarget.String())
}
query += ` ON CONFLICT(path) DO UPDATE SET
source_path = excluded.source_path,
mtime = excluded.mtime,
ctime = excluded.ctime,
size = excluded.size,
@@ -354,9 +361,9 @@ func (r *FileRepository) CreateBatch(ctx context.Context, tx *sql.Tx, files []*F
// DeleteOrphaned deletes files that are not referenced by any snapshot
func (r *FileRepository) DeleteOrphaned(ctx context.Context) error {
query := `
DELETE FROM files
DELETE FROM files
WHERE NOT EXISTS (
SELECT 1 FROM snapshot_files
SELECT 1 FROM snapshot_files
WHERE snapshot_files.file_id = files.id
)
`

View File

@@ -53,7 +53,7 @@ func TestFileRepository(t *testing.T) {
}
// Test GetByPath
retrieved, err := repo.GetByPath(ctx, file.Path)
retrieved, err := repo.GetByPath(ctx, file.Path.String())
if err != nil {
t.Fatalf("failed to get file: %v", err)
}
@@ -81,7 +81,7 @@ func TestFileRepository(t *testing.T) {
t.Fatalf("failed to update file: %v", err)
}
retrieved, err = repo.GetByPath(ctx, file.Path)
retrieved, err = repo.GetByPath(ctx, file.Path.String())
if err != nil {
t.Fatalf("failed to get updated file: %v", err)
}
@@ -99,12 +99,12 @@ func TestFileRepository(t *testing.T) {
}
// Test Delete
err = repo.Delete(ctx, nil, file.Path)
err = repo.Delete(ctx, nil, file.Path.String())
if err != nil {
t.Fatalf("failed to delete file: %v", err)
}
retrieved, err = repo.GetByPath(ctx, file.Path)
retrieved, err = repo.GetByPath(ctx, file.Path.String())
if err != nil {
t.Fatalf("error getting deleted file: %v", err)
}
@@ -137,7 +137,7 @@ func TestFileRepositorySymlink(t *testing.T) {
t.Fatalf("failed to create symlink: %v", err)
}
retrieved, err := repo.GetByPath(ctx, symlink.Path)
retrieved, err := repo.GetByPath(ctx, symlink.Path.String())
if err != nil {
t.Fatalf("failed to get symlink: %v", err)
}

View File

@@ -2,22 +2,27 @@
// It includes types for files, chunks, blobs, snapshots, and their relationships.
package database
import "time"
import (
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
// File represents a file or directory in the backup system.
// It stores metadata about files including timestamps, permissions, ownership,
// and symlink targets. This information is used to restore files with their
// original attributes.
type File struct {
ID string // UUID primary key
Path string
ID types.FileID // UUID primary key
Path types.FilePath // Absolute path of the file
SourcePath types.SourcePath // The source directory this file came from (for restore path stripping)
MTime time.Time
CTime time.Time
Size int64
Mode uint32
UID uint32
GID uint32
LinkTarget string // empty for regular files, target path for symlinks
LinkTarget types.FilePath // empty for regular files, target path for symlinks
}
// IsSymlink returns true if this file is a symbolic link.
@@ -30,16 +35,16 @@ func (f *File) IsSymlink() bool {
// Large files are split into multiple chunks for efficient deduplication and storage.
// The Idx field maintains the order of chunks within a file.
type FileChunk struct {
FileID string
FileID types.FileID
Idx int
ChunkHash string
ChunkHash types.ChunkHash
}
// Chunk represents a data chunk in the deduplication system.
// Files are split into chunks which are content-addressed by their hash.
// The ChunkHash is the SHA256 hash of the chunk content, used for deduplication.
type Chunk struct {
ChunkHash string
ChunkHash types.ChunkHash
Size int64
}
@@ -51,13 +56,13 @@ type Chunk struct {
// The blob creation process is: chunks are accumulated -> compressed with zstd
// -> encrypted with age -> hashed -> uploaded to S3 with the hash as filename.
type Blob struct {
ID string // UUID assigned when blob creation starts
Hash string // SHA256 of final compressed+encrypted content (empty until finalized)
CreatedTS time.Time // When blob creation started
FinishedTS *time.Time // When blob was finalized (nil if still packing)
UncompressedSize int64 // Total size of raw chunks before compression
CompressedSize int64 // Size after compression and encryption
UploadedTS *time.Time // When blob was uploaded to S3 (nil if not uploaded)
ID types.BlobID // UUID assigned when blob creation starts
Hash types.BlobHash // SHA256 of final compressed+encrypted content (empty until finalized)
CreatedTS time.Time // When blob creation started
FinishedTS *time.Time // When blob was finalized (nil if still packing)
UncompressedSize int64 // Total size of raw chunks before compression
CompressedSize int64 // Size after compression and encryption
UploadedTS *time.Time // When blob was uploaded to S3 (nil if not uploaded)
}
// BlobChunk represents the mapping between blobs and the chunks they contain.
@@ -65,8 +70,8 @@ type Blob struct {
// their position and size within the blob. The offset and length fields
// enable extracting specific chunks from a blob without processing the entire blob.
type BlobChunk struct {
BlobID string
ChunkHash string
BlobID types.BlobID
ChunkHash types.ChunkHash
Offset int64
Length int64
}
@@ -75,18 +80,18 @@ type BlobChunk struct {
// This is used during deduplication to identify all files that share a chunk,
// which is important for garbage collection and integrity verification.
type ChunkFile struct {
ChunkHash string
FileID string
ChunkHash types.ChunkHash
FileID types.FileID
FileOffset int64
Length int64
}
// Snapshot represents a snapshot record in the database
type Snapshot struct {
ID string
Hostname string
VaultikVersion string
VaultikGitRevision string
ID types.SnapshotID
Hostname types.Hostname
VaultikVersion types.Version
VaultikGitRevision types.GitRevision
StartedAt time.Time
CompletedAt *time.Time // nil if still in progress
FileCount int64
@@ -108,13 +113,13 @@ func (s *Snapshot) IsComplete() bool {
// SnapshotFile represents the mapping between snapshots and files
type SnapshotFile struct {
SnapshotID string
FileID string
SnapshotID types.SnapshotID
FileID types.FileID
}
// SnapshotBlob represents the mapping between snapshots and blobs
type SnapshotBlob struct {
SnapshotID string
BlobID string
BlobHash string // Denormalized for easier manifest generation
SnapshotID types.SnapshotID
BlobID types.BlobID
BlobHash types.BlobHash // Denormalized for easier manifest generation
}

View File

@@ -75,6 +75,11 @@ func (r *Repositories) WithTx(ctx context.Context, fn TxFunc) error {
return tx.Commit()
}
// DB returns the underlying database for direct queries
func (r *Repositories) DB() *DB {
return r.db
}
// WithReadTx executes a function within a read-only transaction.
// Read transactions can run concurrently with other read transactions
// but will be blocked by write transactions. The transaction is

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
func TestRepositoriesTransaction(t *testing.T) {
@@ -33,7 +35,7 @@ func TestRepositoriesTransaction(t *testing.T) {
// Create chunks
chunk1 := &Chunk{
ChunkHash: "tx_chunk1",
ChunkHash: types.ChunkHash("tx_chunk1"),
Size: 512,
}
if err := repos.Chunks.Create(ctx, tx, chunk1); err != nil {
@@ -41,7 +43,7 @@ func TestRepositoriesTransaction(t *testing.T) {
}
chunk2 := &Chunk{
ChunkHash: "tx_chunk2",
ChunkHash: types.ChunkHash("tx_chunk2"),
Size: 512,
}
if err := repos.Chunks.Create(ctx, tx, chunk2); err != nil {
@@ -69,8 +71,8 @@ func TestRepositoriesTransaction(t *testing.T) {
// Create blob
blob := &Blob{
ID: "tx-blob-id-1",
Hash: "tx_blob1",
ID: types.NewBlobID(),
Hash: types.BlobHash("tx_blob1"),
CreatedTS: time.Now().Truncate(time.Second),
}
if err := repos.Blobs.Create(ctx, tx, blob); err != nil {
@@ -156,7 +158,7 @@ func TestRepositoriesTransactionRollback(t *testing.T) {
// Create a chunk
chunk := &Chunk{
ChunkHash: "rollback_chunk",
ChunkHash: types.ChunkHash("rollback_chunk"),
Size: 1024,
}
if err := repos.Chunks.Create(ctx, tx, chunk); err != nil {

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
// TestFileRepositoryUUIDGeneration tests that files get unique UUIDs
@@ -46,15 +48,15 @@ func TestFileRepositoryUUIDGeneration(t *testing.T) {
}
// Check UUID was generated
if file.ID == "" {
if file.ID.IsZero() {
t.Error("file ID was not generated")
}
// Check UUID is unique
if uuids[file.ID] {
if uuids[file.ID.String()] {
t.Errorf("duplicate UUID generated: %s", file.ID)
}
uuids[file.ID] = true
uuids[file.ID.String()] = true
}
}
@@ -96,7 +98,8 @@ func TestFileRepositoryGetByID(t *testing.T) {
}
// Test non-existent ID
nonExistent, err := repo.GetByID(ctx, "non-existent-uuid")
nonExistentID := types.NewFileID() // Generate a new UUID that won't exist in the database
nonExistent, err := repo.GetByID(ctx, nonExistentID)
if err != nil {
t.Fatalf("GetByID should not return error for non-existent ID: %v", err)
}
@@ -154,7 +157,7 @@ func TestOrphanedFileCleanup(t *testing.T) {
}
// Add file2 to snapshot
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file2.ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID.String(), file2.ID)
if err != nil {
t.Fatalf("failed to add file to snapshot: %v", err)
}
@@ -194,11 +197,11 @@ func TestOrphanedChunkCleanup(t *testing.T) {
// Create chunks
chunk1 := &Chunk{
ChunkHash: "orphaned-chunk",
ChunkHash: types.ChunkHash("orphaned-chunk"),
Size: 1024,
}
chunk2 := &Chunk{
ChunkHash: "referenced-chunk",
ChunkHash: types.ChunkHash("referenced-chunk"),
Size: 1024,
}
@@ -244,7 +247,7 @@ func TestOrphanedChunkCleanup(t *testing.T) {
}
// Check that orphaned chunk is gone
orphanedChunk, err := repos.Chunks.GetByHash(ctx, chunk1.ChunkHash)
orphanedChunk, err := repos.Chunks.GetByHash(ctx, chunk1.ChunkHash.String())
if err != nil {
t.Fatalf("error getting chunk: %v", err)
}
@@ -253,7 +256,7 @@ func TestOrphanedChunkCleanup(t *testing.T) {
}
// Check that referenced chunk still exists
referencedChunk, err := repos.Chunks.GetByHash(ctx, chunk2.ChunkHash)
referencedChunk, err := repos.Chunks.GetByHash(ctx, chunk2.ChunkHash.String())
if err != nil {
t.Fatalf("error getting chunk: %v", err)
}
@@ -272,13 +275,13 @@ func TestOrphanedBlobCleanup(t *testing.T) {
// Create blobs
blob1 := &Blob{
ID: "orphaned-blob-id",
Hash: "orphaned-blob",
ID: types.NewBlobID(),
Hash: types.BlobHash("orphaned-blob"),
CreatedTS: time.Now().Truncate(time.Second),
}
blob2 := &Blob{
ID: "referenced-blob-id",
Hash: "referenced-blob",
ID: types.NewBlobID(),
Hash: types.BlobHash("referenced-blob"),
CreatedTS: time.Now().Truncate(time.Second),
}
@@ -303,7 +306,7 @@ func TestOrphanedBlobCleanup(t *testing.T) {
}
// Add blob2 to snapshot
err = repos.Snapshots.AddBlob(ctx, nil, snapshot.ID, blob2.ID, blob2.Hash)
err = repos.Snapshots.AddBlob(ctx, nil, snapshot.ID.String(), blob2.ID, blob2.Hash)
if err != nil {
t.Fatalf("failed to add blob to snapshot: %v", err)
}
@@ -315,7 +318,7 @@ func TestOrphanedBlobCleanup(t *testing.T) {
}
// Check that orphaned blob is gone
orphanedBlob, err := repos.Blobs.GetByID(ctx, blob1.ID)
orphanedBlob, err := repos.Blobs.GetByID(ctx, blob1.ID.String())
if err != nil {
t.Fatalf("error getting blob: %v", err)
}
@@ -324,7 +327,7 @@ func TestOrphanedBlobCleanup(t *testing.T) {
}
// Check that referenced blob still exists
referencedBlob, err := repos.Blobs.GetByID(ctx, blob2.ID)
referencedBlob, err := repos.Blobs.GetByID(ctx, blob2.ID.String())
if err != nil {
t.Fatalf("error getting blob: %v", err)
}
@@ -357,7 +360,7 @@ func TestFileChunkRepositoryWithUUIDs(t *testing.T) {
}
// Create chunks
chunks := []string{"chunk1", "chunk2", "chunk3"}
chunks := []types.ChunkHash{"chunk1", "chunk2", "chunk3"}
for i, chunkHash := range chunks {
chunk := &Chunk{
ChunkHash: chunkHash,
@@ -443,7 +446,7 @@ func TestChunkFileRepositoryWithUUIDs(t *testing.T) {
// Create a chunk that appears in both files (deduplication)
chunk := &Chunk{
ChunkHash: "shared-chunk",
ChunkHash: types.ChunkHash("shared-chunk"),
Size: 1024,
}
err = repos.Chunks.Create(ctx, nil, chunk)
@@ -526,7 +529,7 @@ func TestSnapshotRepositoryExtendedFields(t *testing.T) {
}
// Retrieve and verify
retrieved, err := repo.GetByID(ctx, snapshot.ID)
retrieved, err := repo.GetByID(ctx, snapshot.ID.String())
if err != nil {
t.Fatalf("failed to get snapshot: %v", err)
}
@@ -581,7 +584,7 @@ func TestComplexOrphanedDataScenario(t *testing.T) {
files := make([]*File, 3)
for i := range files {
files[i] = &File{
Path: fmt.Sprintf("/file%d.txt", i),
Path: types.FilePath(fmt.Sprintf("/file%d.txt", i)),
MTime: time.Now().Truncate(time.Second),
CTime: time.Now().Truncate(time.Second),
Size: 1024,
@@ -601,29 +604,29 @@ func TestComplexOrphanedDataScenario(t *testing.T) {
// file0: only in snapshot1
// file1: in both snapshots
// file2: only in snapshot2
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot1.ID, files[0].ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot1.ID.String(), files[0].ID)
if err != nil {
t.Fatal(err)
}
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot1.ID, files[1].ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot1.ID.String(), files[1].ID)
if err != nil {
t.Fatal(err)
}
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot2.ID, files[1].ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot2.ID.String(), files[1].ID)
if err != nil {
t.Fatal(err)
}
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot2.ID, files[2].ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot2.ID.String(), files[2].ID)
if err != nil {
t.Fatal(err)
}
// Delete snapshot1
err = repos.Snapshots.DeleteSnapshotFiles(ctx, snapshot1.ID)
err = repos.Snapshots.DeleteSnapshotFiles(ctx, snapshot1.ID.String())
if err != nil {
t.Fatal(err)
}
err = repos.Snapshots.Delete(ctx, snapshot1.ID)
err = repos.Snapshots.Delete(ctx, snapshot1.ID.String())
if err != nil {
t.Fatal(err)
}
@@ -689,7 +692,7 @@ func TestCascadeDelete(t *testing.T) {
// Create chunks and file-chunk mappings
for i := 0; i < 3; i++ {
chunk := &Chunk{
ChunkHash: fmt.Sprintf("cascade-chunk-%d", i),
ChunkHash: types.ChunkHash(fmt.Sprintf("cascade-chunk-%d", i)),
Size: 1024,
}
err = repos.Chunks.Create(ctx, nil, chunk)
@@ -807,7 +810,7 @@ func TestConcurrentOrphanedCleanup(t *testing.T) {
// Create many files, some orphaned
for i := 0; i < 20; i++ {
file := &File{
Path: fmt.Sprintf("/concurrent-%d.txt", i),
Path: types.FilePath(fmt.Sprintf("/concurrent-%d.txt", i)),
MTime: time.Now().Truncate(time.Second),
CTime: time.Now().Truncate(time.Second),
Size: 1024,
@@ -822,7 +825,7 @@ func TestConcurrentOrphanedCleanup(t *testing.T) {
// Add even-numbered files to snapshot
if i%2 == 0 {
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file.ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID.String(), file.ID)
if err != nil {
t.Fatal(err)
}
@@ -860,7 +863,7 @@ func TestConcurrentOrphanedCleanup(t *testing.T) {
// Verify all remaining files are even-numbered
for _, file := range files {
var num int
_, err := fmt.Sscanf(file.Path, "/concurrent-%d.txt", &num)
_, err := fmt.Sscanf(file.Path.String(), "/concurrent-%d.txt", &num)
if err != nil {
t.Logf("failed to parse file number from %s: %v", file.Path, err)
}

View File

@@ -67,7 +67,7 @@ func TestOrphanedFileCleanupDebug(t *testing.T) {
t.Logf("snapshot_files count before add: %d", count)
// Add file2 to snapshot
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file2.ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID.String(), file2.ID)
if err != nil {
t.Fatalf("failed to add file to snapshot: %v", err)
}

View File

@@ -6,6 +6,8 @@ import (
"strings"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
// TestFileRepositoryEdgeCases tests edge cases for file repository
@@ -38,7 +40,7 @@ func TestFileRepositoryEdgeCases(t *testing.T) {
{
name: "very long path",
file: &File{
Path: "/" + strings.Repeat("a", 4096),
Path: types.FilePath("/" + strings.Repeat("a", 4096)),
MTime: time.Now(),
CTime: time.Now(),
Size: 1024,
@@ -94,7 +96,7 @@ func TestFileRepositoryEdgeCases(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
// Add a unique suffix to paths to avoid UNIQUE constraint violations
if tt.file.Path != "" {
tt.file.Path = fmt.Sprintf("%s_%d_%d", tt.file.Path, i, time.Now().UnixNano())
tt.file.Path = types.FilePath(fmt.Sprintf("%s_%d_%d", tt.file.Path, i, time.Now().UnixNano()))
}
err := repo.Create(ctx, nil, tt.file)
@@ -169,7 +171,7 @@ func TestDuplicateHandling(t *testing.T) {
// Test duplicate chunk hashes
t.Run("duplicate chunk hashes", func(t *testing.T) {
chunk := &Chunk{
ChunkHash: "duplicate-chunk",
ChunkHash: types.ChunkHash("duplicate-chunk"),
Size: 1024,
}
@@ -202,7 +204,7 @@ func TestDuplicateHandling(t *testing.T) {
}
chunk := &Chunk{
ChunkHash: "test-chunk-dup",
ChunkHash: types.ChunkHash("test-chunk-dup"),
Size: 1024,
}
err = repos.Chunks.Create(ctx, nil, chunk)
@@ -279,7 +281,7 @@ func TestNullHandling(t *testing.T) {
t.Fatal(err)
}
retrieved, err := repos.Snapshots.GetByID(ctx, snapshot.ID)
retrieved, err := repos.Snapshots.GetByID(ctx, snapshot.ID.String())
if err != nil {
t.Fatal(err)
}
@@ -292,8 +294,8 @@ func TestNullHandling(t *testing.T) {
// Test blob with NULL uploaded_ts
t.Run("blob not uploaded", func(t *testing.T) {
blob := &Blob{
ID: "not-uploaded",
Hash: "test-hash",
ID: types.NewBlobID(),
Hash: types.BlobHash("test-hash"),
CreatedTS: time.Now(),
UploadedTS: nil, // Not uploaded yet
}
@@ -303,7 +305,7 @@ func TestNullHandling(t *testing.T) {
t.Fatal(err)
}
retrieved, err := repos.Blobs.GetByID(ctx, blob.ID)
retrieved, err := repos.Blobs.GetByID(ctx, blob.ID.String())
if err != nil {
t.Fatal(err)
}
@@ -339,13 +341,13 @@ func TestLargeDatasets(t *testing.T) {
// Create many files
const fileCount = 1000
fileIDs := make([]string, fileCount)
fileIDs := make([]types.FileID, fileCount)
t.Run("create many files", func(t *testing.T) {
start := time.Now()
for i := 0; i < fileCount; i++ {
file := &File{
Path: fmt.Sprintf("/large/file%05d.txt", i),
Path: types.FilePath(fmt.Sprintf("/large/file%05d.txt", i)),
MTime: time.Now(),
CTime: time.Now(),
Size: int64(i * 1024),
@@ -361,7 +363,7 @@ func TestLargeDatasets(t *testing.T) {
// Add half to snapshot
if i%2 == 0 {
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID, file.ID)
err = repos.Snapshots.AddFileByID(ctx, nil, snapshot.ID.String(), file.ID)
if err != nil {
t.Fatal(err)
}
@@ -413,7 +415,7 @@ func TestErrorPropagation(t *testing.T) {
// Test GetByID with non-existent ID
t.Run("GetByID non-existent", func(t *testing.T) {
file, err := repos.Files.GetByID(ctx, "non-existent-uuid")
file, err := repos.Files.GetByID(ctx, types.NewFileID())
if err != nil {
t.Errorf("GetByID should not return error for non-existent ID, got: %v", err)
}
@@ -436,9 +438,9 @@ func TestErrorPropagation(t *testing.T) {
// Test invalid foreign key reference
t.Run("invalid foreign key", func(t *testing.T) {
fc := &FileChunk{
FileID: "non-existent-file-id",
FileID: types.NewFileID(),
Idx: 0,
ChunkHash: "some-chunk",
ChunkHash: types.ChunkHash("some-chunk"),
}
err := repos.FileChunks.Create(ctx, nil, fc)
if err == nil {
@@ -470,7 +472,7 @@ func TestQueryInjection(t *testing.T) {
t.Run("injection attempt", func(t *testing.T) {
// Try injection in file path
file := &File{
Path: injection,
Path: types.FilePath(injection),
MTime: time.Now(),
CTime: time.Now(),
Size: 1024,

View File

@@ -6,6 +6,7 @@
CREATE TABLE IF NOT EXISTS files (
id TEXT PRIMARY KEY, -- UUID
path TEXT NOT NULL UNIQUE,
source_path TEXT NOT NULL DEFAULT '', -- The source directory this file came from (for restore path stripping)
mtime INTEGER NOT NULL,
ctime INTEGER NOT NULL,
size INTEGER NOT NULL,

View File

@@ -5,6 +5,8 @@ import (
"database/sql"
"fmt"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
type SnapshotRepository struct {
@@ -269,7 +271,7 @@ func (r *SnapshotRepository) AddFile(ctx context.Context, tx *sql.Tx, snapshotID
}
// AddFileByID adds a file to a snapshot by file ID
func (r *SnapshotRepository) AddFileByID(ctx context.Context, tx *sql.Tx, snapshotID string, fileID string) error {
func (r *SnapshotRepository) AddFileByID(ctx context.Context, tx *sql.Tx, snapshotID string, fileID types.FileID) error {
query := `
INSERT OR IGNORE INTO snapshot_files (snapshot_id, file_id)
VALUES (?, ?)
@@ -277,9 +279,9 @@ func (r *SnapshotRepository) AddFileByID(ctx context.Context, tx *sql.Tx, snapsh
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, snapshotID, fileID)
_, err = tx.ExecContext(ctx, query, snapshotID, fileID.String())
} else {
_, err = r.db.ExecWithLog(ctx, query, snapshotID, fileID)
_, err = r.db.ExecWithLog(ctx, query, snapshotID, fileID.String())
}
if err != nil {
@@ -290,7 +292,7 @@ func (r *SnapshotRepository) AddFileByID(ctx context.Context, tx *sql.Tx, snapsh
}
// AddFilesByIDBatch adds multiple files to a snapshot in batched inserts
func (r *SnapshotRepository) AddFilesByIDBatch(ctx context.Context, tx *sql.Tx, snapshotID string, fileIDs []string) error {
func (r *SnapshotRepository) AddFilesByIDBatch(ctx context.Context, tx *sql.Tx, snapshotID string, fileIDs []types.FileID) error {
if len(fileIDs) == 0 {
return nil
}
@@ -312,7 +314,7 @@ func (r *SnapshotRepository) AddFilesByIDBatch(ctx context.Context, tx *sql.Tx,
query += ", "
}
query += "(?, ?)"
args = append(args, snapshotID, fileID)
args = append(args, snapshotID, fileID.String())
}
var err error
@@ -330,7 +332,7 @@ func (r *SnapshotRepository) AddFilesByIDBatch(ctx context.Context, tx *sql.Tx,
}
// AddBlob adds a blob to a snapshot
func (r *SnapshotRepository) AddBlob(ctx context.Context, tx *sql.Tx, snapshotID string, blobID string, blobHash string) error {
func (r *SnapshotRepository) AddBlob(ctx context.Context, tx *sql.Tx, snapshotID string, blobID types.BlobID, blobHash types.BlobHash) error {
query := `
INSERT OR IGNORE INTO snapshot_blobs (snapshot_id, blob_id, blob_hash)
VALUES (?, ?, ?)
@@ -338,9 +340,9 @@ func (r *SnapshotRepository) AddBlob(ctx context.Context, tx *sql.Tx, snapshotID
var err error
if tx != nil {
_, err = tx.ExecContext(ctx, query, snapshotID, blobID, blobHash)
_, err = tx.ExecContext(ctx, query, snapshotID, blobID.String(), blobHash.String())
} else {
_, err = r.db.ExecWithLog(ctx, query, snapshotID, blobID, blobHash)
_, err = r.db.ExecWithLog(ctx, query, snapshotID, blobID.String(), blobHash.String())
}
if err != nil {

View File

@@ -6,6 +6,8 @@ import (
"math"
"testing"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
const (
@@ -46,7 +48,7 @@ func TestSnapshotRepository(t *testing.T) {
}
// Test GetByID
retrieved, err := repo.GetByID(ctx, snapshot.ID)
retrieved, err := repo.GetByID(ctx, snapshot.ID.String())
if err != nil {
t.Fatalf("failed to get snapshot: %v", err)
}
@@ -64,12 +66,12 @@ func TestSnapshotRepository(t *testing.T) {
}
// Test UpdateCounts
err = repo.UpdateCounts(ctx, nil, snapshot.ID, 200, 1000, 20, twoHundredMebibytes, sixtyMebibytes)
err = repo.UpdateCounts(ctx, nil, snapshot.ID.String(), 200, 1000, 20, twoHundredMebibytes, sixtyMebibytes)
if err != nil {
t.Fatalf("failed to update counts: %v", err)
}
retrieved, err = repo.GetByID(ctx, snapshot.ID)
retrieved, err = repo.GetByID(ctx, snapshot.ID.String())
if err != nil {
t.Fatalf("failed to get updated snapshot: %v", err)
}
@@ -97,7 +99,7 @@ func TestSnapshotRepository(t *testing.T) {
// Add more snapshots
for i := 2; i <= 5; i++ {
s := &Snapshot{
ID: fmt.Sprintf("2024-01-0%dT12:00:00Z", i),
ID: types.SnapshotID(fmt.Sprintf("2024-01-0%dT12:00:00Z", i)),
Hostname: "test-host",
VaultikVersion: "1.0.0",
StartedAt: time.Now().Add(time.Duration(i) * time.Hour).Truncate(time.Second),

View File

@@ -14,6 +14,7 @@ import (
"time"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/types"
)
// MockS3Client is a mock implementation of S3 operations for testing
@@ -138,13 +139,13 @@ func TestBackupWithInMemoryFS(t *testing.T) {
}
for _, file := range files {
if !expectedFiles[file.Path] {
if !expectedFiles[file.Path.String()] {
t.Errorf("Unexpected file in database: %s", file.Path)
}
delete(expectedFiles, file.Path)
delete(expectedFiles, file.Path.String())
// Verify file metadata
fsFile := testFS[file.Path]
fsFile := testFS[file.Path.String()]
if fsFile == nil {
t.Errorf("File %s not found in test filesystem", file.Path)
continue
@@ -294,8 +295,8 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
hostname, _ := os.Hostname()
snapshotID := time.Now().Format(time.RFC3339)
snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: hostname,
ID: types.SnapshotID(snapshotID),
Hostname: types.Hostname(hostname),
VaultikVersion: "test",
StartedAt: time.Now(),
CompletedAt: nil,
@@ -340,7 +341,7 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
// Create file record in a short transaction
file := &database.File{
Path: path,
Path: types.FilePath(path),
Size: info.Size(),
Mode: uint32(info.Mode()),
MTime: info.ModTime(),
@@ -392,7 +393,7 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
// Create new chunk in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
chunk := &database.Chunk{
ChunkHash: chunkHash,
ChunkHash: types.ChunkHash(chunkHash),
Size: int64(n),
}
return b.repos.Chunks.Create(ctx, tx, chunk)
@@ -408,7 +409,7 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
fileChunk := &database.FileChunk{
FileID: file.ID,
Idx: chunkIndex,
ChunkHash: chunkHash,
ChunkHash: types.ChunkHash(chunkHash),
}
return b.repos.FileChunks.Create(ctx, tx, fileChunk)
})
@@ -419,7 +420,7 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
// Create chunk-file mapping in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
chunkFile := &database.ChunkFile{
ChunkHash: chunkHash,
ChunkHash: types.ChunkHash(chunkHash),
FileID: file.ID,
FileOffset: int64(chunkIndex * defaultChunkSize),
Length: int64(n),
@@ -463,10 +464,11 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
}
// Create blob entry in a short transaction
blobID := types.NewBlobID()
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
blob := &database.Blob{
ID: "test-blob-" + blobHash[:8],
Hash: blobHash,
ID: blobID,
Hash: types.BlobHash(blobHash),
CreatedTS: time.Now(),
}
return b.repos.Blobs.Create(ctx, tx, blob)
@@ -481,8 +483,8 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
// Create blob-chunk mapping in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
blobChunk := &database.BlobChunk{
BlobID: "test-blob-" + blobHash[:8],
ChunkHash: chunkHash,
BlobID: blobID,
ChunkHash: types.ChunkHash(chunkHash),
Offset: 0,
Length: chunk.Size,
}
@@ -494,7 +496,7 @@ func (b *BackupEngine) Backup(ctx context.Context, fsys fs.FS, root string) (str
// Add blob to snapshot in a short transaction
err = b.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return b.repos.Snapshots.AddBlob(ctx, tx, snapshotID, "test-blob-"+blobHash[:8], blobHash)
return b.repos.Snapshots.AddBlob(ctx, tx, snapshotID, blobID, types.BlobHash(blobHash))
})
if err != nil {
return "", err

View File

@@ -10,6 +10,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
)
@@ -108,7 +109,7 @@ func createSnapshotRecord(t *testing.T, ctx context.Context, repos *database.Rep
t.Helper()
err := repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snap := &database.Snapshot{
ID: snapshotID,
ID: types.SnapshotID(snapshotID),
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),

View File

@@ -9,6 +9,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -53,7 +54,7 @@ func TestFileContentChange(t *testing.T) {
snapshotID1 := "snapshot1"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID1,
ID: types.SnapshotID(snapshotID1),
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),
@@ -87,7 +88,7 @@ func TestFileContentChange(t *testing.T) {
snapshotID2 := "snapshot2"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID2,
ID: types.SnapshotID(snapshotID2),
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),
@@ -117,12 +118,12 @@ func TestFileContentChange(t *testing.T) {
assert.Equal(t, newChunkHash, chunkFiles2[0].ChunkHash)
// Verify old chunk still exists (it's still valid data)
oldChunk, err := repos.Chunks.GetByHash(ctx, oldChunkHash)
oldChunk, err := repos.Chunks.GetByHash(ctx, oldChunkHash.String())
require.NoError(t, err)
assert.NotNil(t, oldChunk)
// Verify new chunk exists
newChunk, err := repos.Chunks.GetByHash(ctx, newChunkHash)
newChunk, err := repos.Chunks.GetByHash(ctx, newChunkHash.String())
require.NoError(t, err)
assert.NotNil(t, newChunk)
@@ -182,7 +183,7 @@ func TestMultipleFileChanges(t *testing.T) {
snapshotID1 := "snapshot1"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID1,
ID: types.SnapshotID(snapshotID1),
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),
@@ -208,7 +209,7 @@ func TestMultipleFileChanges(t *testing.T) {
snapshotID2 := "snapshot2"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID2,
ID: types.SnapshotID(snapshotID2),
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),

View File

@@ -12,6 +12,8 @@ import (
type ScannerParams struct {
EnableProgress bool
Fs afero.Fs
Exclude []string // Exclude patterns (combined global + snapshot-specific)
SkipErrors bool // Skip file read errors (log loudly but continue)
}
// Module exports backup functionality as an fx module.
@@ -29,6 +31,12 @@ type ScannerFactory func(params ScannerParams) *Scanner
func provideScannerFactory(cfg *config.Config, repos *database.Repositories, storer storage.Storer) ScannerFactory {
return func(params ScannerParams) *Scanner {
// Use provided excludes, or fall back to global config excludes
excludes := params.Exclude
if len(excludes) == 0 {
excludes = cfg.Exclude
}
return NewScanner(ScannerConfig{
FS: params.Fs,
ChunkSize: cfg.ChunkSize.Int64(),
@@ -38,7 +46,8 @@ func provideScannerFactory(cfg *config.Config, repos *database.Repositories, sto
CompressionLevel: cfg.CompressionLevel,
AgeRecipients: cfg.AgeRecipients,
EnableProgress: params.EnableProgress,
Exclude: cfg.Exclude,
Exclude: excludes,
SkipErrors: params.SkipErrors,
})
}
}

View File

@@ -16,9 +16,9 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/dustin/go-humanize"
"github.com/gobwas/glob"
"github.com/google/uuid"
"github.com/spf13/afero"
)
@@ -45,18 +45,20 @@ type compiledPattern struct {
// Scanner scans directories and populates the database with file and chunk information
type Scanner struct {
fs afero.Fs
chunker *chunker.Chunker
packer *blob.Packer
repos *database.Repositories
storage storage.Storer
maxBlobSize int64
compressionLevel int
ageRecipient string
snapshotID string // Current snapshot being processed
exclude []string // Glob patterns for files/directories to exclude
compiledExclude []compiledPattern // Compiled glob patterns
progress *ProgressReporter
fs afero.Fs
chunker *chunker.Chunker
packer *blob.Packer
repos *database.Repositories
storage storage.Storer
maxBlobSize int64
compressionLevel int
ageRecipient string
snapshotID string // Current snapshot being processed
currentSourcePath string // Current source directory being scanned (for restore path stripping)
exclude []string // Glob patterns for files/directories to exclude
compiledExclude []compiledPattern // Compiled glob patterns
progress *ProgressReporter
skipErrors bool // Skip file read errors (log loudly but continue)
// In-memory cache of known chunk hashes for fast existence checks
knownChunks map[string]struct{}
@@ -90,6 +92,7 @@ type ScannerConfig struct {
AgeRecipients []string // Optional, empty means no encryption
EnableProgress bool // Enable progress reporting
Exclude []string // Glob patterns for files/directories to exclude
SkipErrors bool // Skip file read errors (log loudly but continue)
}
// ScanResult contains the results of a scan operation
@@ -148,6 +151,7 @@ func NewScanner(cfg ScannerConfig) *Scanner {
exclude: cfg.Exclude,
compiledExclude: compiledExclude,
progress: progress,
skipErrors: cfg.SkipErrors,
pendingChunkHashes: make(map[string]struct{}),
}
}
@@ -155,6 +159,7 @@ func NewScanner(cfg ScannerConfig) *Scanner {
// Scan scans a directory and populates the database
func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*ScanResult, error) {
s.snapshotID = snapshotID
s.currentSourcePath = path // Store source path for file records (used during restore)
s.scanCtx = ctx
result := &ScanResult{
StartTime: time.Now().UTC(),
@@ -284,7 +289,7 @@ func (s *Scanner) loadKnownFiles(ctx context.Context, path string) (map[string]*
result := make(map[string]*database.File, len(files))
for _, f := range files {
result[f.Path] = f
result[f.Path.String()] = f
}
return result, nil
@@ -301,7 +306,7 @@ func (s *Scanner) loadKnownChunks(ctx context.Context) error {
s.knownChunksMu.Lock()
s.knownChunks = make(map[string]struct{}, len(chunks))
for _, c := range chunks {
s.knownChunks[c.ChunkHash] = struct{}{}
s.knownChunks[c.ChunkHash.String()] = struct{}{}
}
s.knownChunksMu.Unlock()
@@ -432,7 +437,7 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error {
for _, data := range s.pendingFiles {
allChunksCommitted := true
for _, fc := range data.fileChunks {
if s.isChunkPending(fc.ChunkHash) {
if s.isChunkPending(fc.ChunkHash.String()) {
allChunksCommitted = false
break
}
@@ -463,7 +468,7 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error {
collectStart := time.Now()
var allFileChunks []database.FileChunk
var allChunkFiles []database.ChunkFile
var allFileIDs []string
var allFileIDs []types.FileID
var allFiles []*database.File
for _, data := range canFlush {
@@ -542,7 +547,7 @@ func (s *Scanner) flushCompletedPendingFiles(ctx context.Context) error {
// ScanPhaseResult contains the results of the scan phase
type ScanPhaseResult struct {
FilesToProcess []*FileToProcess
UnchangedFileIDs []string // IDs of unchanged files to associate with snapshot
UnchangedFileIDs []types.FileID // IDs of unchanged files to associate with snapshot
}
// scanPhase performs the initial directory scan to identify files to process
@@ -554,7 +559,7 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult
estimatedTotal := int64(len(knownFiles))
var filesToProcess []*FileToProcess
var unchangedFileIDs []string // Just IDs - no new records needed
var unchangedFileIDs []types.FileID // Just IDs - no new records needed
var mu sync.Mutex
// Set up periodic status output
@@ -566,6 +571,11 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult
log.Debug("Starting directory walk", "path", path)
err := afero.Walk(s.fs, path, func(filePath string, info os.FileInfo, err error) error {
if err != nil {
if s.skipErrors {
log.Error("ERROR: Failed to access file (skipping due to --skip-errors)", "path", filePath, "error", err)
fmt.Printf("ERROR: Failed to access %s: %v (skipping)\n", filePath, err)
return nil // Continue scanning
}
log.Debug("Error accessing filesystem entry", "path", filePath, "error", err)
return err
}
@@ -604,7 +614,7 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult
FileInfo: info,
File: file,
})
} else if file.ID != "" {
} else if !file.ID.IsZero() {
// Unchanged file with existing ID - just need snapshot association
unchangedFileIDs = append(unchangedFileIDs, file.ID)
}
@@ -696,22 +706,23 @@ func (s *Scanner) checkFileInMemory(path string, info os.FileInfo, knownFiles ma
// Create file record with ID set upfront
// For new files, generate UUID immediately so it's available for chunk associations
// For existing files, reuse the existing ID
var fileID string
var fileID types.FileID
if exists {
fileID = existingFile.ID
} else {
fileID = uuid.New().String()
fileID = types.NewFileID()
}
file := &database.File{
ID: fileID,
Path: path,
MTime: info.ModTime(),
CTime: info.ModTime(), // afero doesn't provide ctime
Size: info.Size(),
Mode: uint32(info.Mode()),
UID: uid,
GID: gid,
ID: fileID,
Path: types.FilePath(path),
SourcePath: types.SourcePath(s.currentSourcePath), // Store source directory for restore path stripping
MTime: info.ModTime(),
CTime: info.ModTime(), // afero doesn't provide ctime
Size: info.Size(),
Mode: uint32(info.Mode()),
UID: uid,
GID: gid,
}
// New file - needs processing
@@ -734,7 +745,7 @@ func (s *Scanner) checkFileInMemory(path string, info os.FileInfo, knownFiles ma
// batchAddFilesToSnapshot adds existing file IDs to the snapshot association table
// This is used for unchanged files that already have records in the database
func (s *Scanner) batchAddFilesToSnapshot(ctx context.Context, fileIDs []string) error {
func (s *Scanner) batchAddFilesToSnapshot(ctx context.Context, fileIDs []types.FileID) error {
const batchSize = 1000
startTime := time.Now()
@@ -817,6 +828,13 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc
result.FilesSkipped++
continue
}
// Skip file read errors if --skip-errors is enabled
if s.skipErrors {
log.Error("ERROR: Failed to process file (skipping due to --skip-errors)", "path", fileToProcess.Path, "error", err)
fmt.Printf("ERROR: Failed to process %s: %v (skipping)\n", fileToProcess.Path, err)
result.FilesSkipped++
continue
}
return fmt.Errorf("processing file %s: %w", fileToProcess.Path, err)
}
@@ -881,8 +899,12 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc
for _, b := range blobs {
// Blob metadata is already stored incrementally during packing
// Just add the blob to the snapshot
err := s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, b.ID, b.Hash)
blobID, err := types.ParseBlobID(b.ID)
if err != nil {
return fmt.Errorf("parsing blob ID: %w", err)
}
err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, blobID, types.BlobHash(b.Hash))
})
if err != nil {
return fmt.Errorf("storing blob metadata: %w", err)
@@ -984,14 +1006,21 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
if dbCtx == nil {
dbCtx = context.Background()
}
err := s.repos.WithTx(dbCtx, func(ctx context.Context, tx *sql.Tx) error {
// Parse blob ID for typed operations
finishedBlobID, err := types.ParseBlobID(finishedBlob.ID)
if err != nil {
return fmt.Errorf("parsing finished blob ID: %w", err)
}
err = s.repos.WithTx(dbCtx, func(ctx context.Context, tx *sql.Tx) error {
// Update blob upload timestamp
if err := s.repos.Blobs.UpdateUploaded(ctx, tx, finishedBlob.ID); err != nil {
return fmt.Errorf("updating blob upload timestamp: %w", err)
}
// Add the blob to the snapshot
if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, finishedBlob.ID, finishedBlob.Hash); err != nil {
if err := s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, finishedBlobID, types.BlobHash(finishedBlob.Hash)); err != nil {
return fmt.Errorf("adding blob to snapshot: %w", err)
}
@@ -1094,7 +1123,7 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
fileChunk: database.FileChunk{
FileID: fileToProcess.File.ID,
Idx: chunkIndex,
ChunkHash: chunk.Hash,
ChunkHash: types.ChunkHash(chunk.Hash),
},
offset: chunk.Offset,
size: chunk.Size,

View File

@@ -10,6 +10,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/spf13/afero"
)
@@ -74,7 +75,7 @@ func TestScannerSimpleDirectory(t *testing.T) {
snapshotID := "test-snapshot-001"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID,
ID: types.SnapshotID(snapshotID),
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),
@@ -209,7 +210,7 @@ func TestScannerLargeFile(t *testing.T) {
snapshotID := "test-snapshot-001"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID,
ID: types.SnapshotID(snapshotID),
Hostname: "test-host",
VaultikVersion: "test",
StartedAt: time.Now(),

View File

@@ -54,6 +54,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/dustin/go-humanize"
"github.com/spf13/afero"
"go.uber.org/fx"
@@ -90,20 +91,35 @@ func (sm *SnapshotManager) SetFilesystem(fs afero.Fs) {
sm.fs = fs
}
// CreateSnapshot creates a new snapshot record in the database at the start of a backup
// CreateSnapshot creates a new snapshot record in the database at the start of a backup.
// Deprecated: Use CreateSnapshotWithName instead for multi-snapshot support.
func (sm *SnapshotManager) CreateSnapshot(ctx context.Context, hostname, version, gitRevision string) (string, error) {
return sm.CreateSnapshotWithName(ctx, hostname, "", version, gitRevision)
}
// CreateSnapshotWithName creates a new snapshot record with an optional snapshot name.
// The snapshot ID format is: hostname_name_timestamp or hostname_timestamp if name is empty.
func (sm *SnapshotManager) CreateSnapshotWithName(ctx context.Context, hostname, name, version, gitRevision string) (string, error) {
// Use short hostname (strip domain if present)
shortHostname := hostname
if idx := strings.Index(hostname, "."); idx != -1 {
shortHostname = hostname[:idx]
}
snapshotID := fmt.Sprintf("%s_%s", shortHostname, time.Now().UTC().Format("2006-01-02T15:04:05Z"))
// Build snapshot ID with optional name
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
var snapshotID string
if name != "" {
snapshotID = fmt.Sprintf("%s_%s_%s", shortHostname, name, timestamp)
} else {
snapshotID = fmt.Sprintf("%s_%s", shortHostname, timestamp)
}
snapshot := &database.Snapshot{
ID: snapshotID,
Hostname: hostname,
VaultikVersion: version,
VaultikGitRevision: gitRevision,
ID: types.SnapshotID(snapshotID),
Hostname: types.Hostname(hostname),
VaultikVersion: types.Version(version),
VaultikGitRevision: types.GitRevision(gitRevision),
StartedAt: time.Now().UTC(),
CompletedAt: nil, // Not completed yet
FileCount: 0,
@@ -652,7 +668,7 @@ func (sm *SnapshotManager) CleanupIncompleteSnapshots(ctx context.Context, hostn
log.Info("Cleaning up incomplete snapshot record", "snapshot_id", snapshot.ID, "started_at", snapshot.StartedAt)
// Delete the snapshot and all its associations
if err := sm.deleteSnapshot(ctx, snapshot.ID); err != nil {
if err := sm.deleteSnapshot(ctx, snapshot.ID.String()); err != nil {
return fmt.Errorf("deleting incomplete snapshot %s: %w", snapshot.ID, err)
}
@@ -661,7 +677,7 @@ func (sm *SnapshotManager) CleanupIncompleteSnapshots(ctx context.Context, hostn
// Metadata exists - this snapshot was completed but database wasn't updated
// This shouldn't happen in normal operation, but mark it complete
log.Warn("Found snapshot with S3 metadata but incomplete in database", "snapshot_id", snapshot.ID)
if err := sm.repos.Snapshots.MarkComplete(ctx, nil, snapshot.ID); err != nil {
if err := sm.repos.Snapshots.MarkComplete(ctx, nil, snapshot.ID.String()); err != nil {
log.Error("Failed to mark snapshot as complete in database", "snapshot_id", snapshot.ID, "error", err)
}
}

View File

@@ -101,7 +101,7 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
config: cfg,
fs: fs,
}
if _, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshot.ID); err != nil {
if _, err := sm.cleanSnapshotDB(ctx, tempDBPath, snapshot.ID.String()); err != nil {
t.Fatalf("failed to clean snapshot database: %v", err)
}
@@ -119,7 +119,7 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
cleanedRepos := database.NewRepositories(cleanedDB)
// Verify snapshot exists
verifySnapshot, err := cleanedRepos.Snapshots.GetByID(ctx, snapshot.ID)
verifySnapshot, err := cleanedRepos.Snapshots.GetByID(ctx, snapshot.ID.String())
if err != nil {
t.Fatalf("failed to get snapshot: %v", err)
}
@@ -128,7 +128,7 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
}
// Verify orphan file is gone
f, err := cleanedRepos.Files.GetByPath(ctx, file.Path)
f, err := cleanedRepos.Files.GetByPath(ctx, file.Path.String())
if err != nil {
t.Fatalf("failed to check file: %v", err)
}
@@ -137,7 +137,7 @@ func TestCleanSnapshotDBEmptySnapshot(t *testing.T) {
}
// Verify orphan chunk is gone
c, err := cleanedRepos.Chunks.GetByHash(ctx, chunk.ChunkHash)
c, err := cleanedRepos.Chunks.GetByHash(ctx, chunk.ChunkHash.String())
if err != nil {
t.Fatalf("failed to check chunk: %v", err)
}

203
internal/types/types.go Normal file
View File

@@ -0,0 +1,203 @@
// Package types provides custom types for better type safety across the vaultik codebase.
// Using distinct types for IDs, hashes, paths, and credentials prevents accidental
// mixing of semantically different values that happen to share the same underlying type.
package types
import (
"database/sql/driver"
"fmt"
"github.com/google/uuid"
)
// FileID is a UUID identifying a file record in the database.
type FileID uuid.UUID

// NewFileID generates a new random FileID.
func NewFileID() FileID {
	return FileID(uuid.New())
}

// ParseFileID parses a string into a FileID.
func ParseFileID(s string) (FileID, error) {
	parsed, err := uuid.Parse(s)
	if err != nil {
		return FileID{}, err
	}
	return FileID(parsed), nil
}

// IsZero returns true if the FileID is the zero value.
func (id FileID) IsZero() bool {
	return uuid.UUID(id) == uuid.Nil
}

// Value implements driver.Valuer for database serialization.
// The ID is stored as its canonical string form.
func (id FileID) Value() (driver.Value, error) {
	return uuid.UUID(id).String(), nil
}

// Scan implements sql.Scanner for database deserialization.
// Accepts TEXT or BLOB columns holding a UUID string; NULL scans to the
// zero FileID.
func (id *FileID) Scan(src interface{}) error {
	if src == nil {
		*id = FileID{}
		return nil
	}
	var text string
	switch value := src.(type) {
	case string:
		text = value
	case []byte:
		text = string(value)
	default:
		return fmt.Errorf("cannot scan %T into FileID", src)
	}
	u, err := uuid.Parse(text)
	if err != nil {
		return fmt.Errorf("invalid FileID: %w", err)
	}
	*id = FileID(u)
	return nil
}
// BlobID is a UUID identifying a blob record in the database.
// This is distinct from BlobHash which is the content-addressed hash of the blob.
type BlobID uuid.UUID

// NewBlobID generates a new random BlobID.
func NewBlobID() BlobID {
	return BlobID(uuid.New())
}

// ParseBlobID parses a string into a BlobID.
func ParseBlobID(s string) (BlobID, error) {
	parsed, err := uuid.Parse(s)
	if err != nil {
		return BlobID{}, err
	}
	return BlobID(parsed), nil
}

// IsZero returns true if the BlobID is the zero value.
func (id BlobID) IsZero() bool {
	return uuid.UUID(id) == uuid.Nil
}

// Value implements driver.Valuer for database serialization.
// The ID is stored as its canonical string form.
func (id BlobID) Value() (driver.Value, error) {
	return uuid.UUID(id).String(), nil
}

// Scan implements sql.Scanner for database deserialization.
// Accepts TEXT or BLOB columns holding a UUID string; NULL scans to the
// zero BlobID.
func (id *BlobID) Scan(src interface{}) error {
	if src == nil {
		*id = BlobID{}
		return nil
	}
	var text string
	switch value := src.(type) {
	case string:
		text = value
	case []byte:
		text = string(value)
	default:
		return fmt.Errorf("cannot scan %T into BlobID", src)
	}
	u, err := uuid.Parse(text)
	if err != nil {
		return fmt.Errorf("invalid BlobID: %w", err)
	}
	*id = BlobID(u)
	return nil
}
// SnapshotID identifies a snapshot, typically in format "hostname_name_timestamp"
// (or "hostname_timestamp" when the snapshot was created without a name).
type SnapshotID string

// ChunkHash is the SHA256 hash of a chunk's content.
// Used for content-addressing and deduplication of file chunks.
type ChunkHash string

// BlobHash is the SHA256 hash of a blob's compressed and encrypted content.
// This is used as the filename in S3 storage for content-addressed retrieval.
type BlobHash string

// FilePath represents an absolute path to a file or directory.
type FilePath string

// SourcePath represents the root directory from which files are backed up.
// Used during restore to strip the source prefix from paths.
type SourcePath string

// AgeRecipient is an age public key used for encryption.
// Format: age1... (Bech32-encoded X25519 public key)
type AgeRecipient string

// AgeSecretKey is an age private key used for decryption.
// Format: AGE-SECRET-KEY-... (Bech32-encoded X25519 private key)
// This type should never be logged or serialized in plaintext;
// its String method redacts the value, and Raw() must be called
// to obtain the actual key material.
type AgeSecretKey string

// S3Endpoint is the URL of an S3-compatible storage endpoint.
type S3Endpoint string

// BucketName is the name of an S3 bucket.
type BucketName string

// S3Prefix is the path prefix within an S3 bucket.
type S3Prefix string

// AWSRegion is an AWS region identifier (e.g., "us-east-1").
type AWSRegion string

// AWSAccessKeyID is an AWS access key ID for authentication.
type AWSAccessKeyID string

// AWSSecretAccessKey is an AWS secret access key for authentication.
// This type should never be logged or serialized in plaintext;
// its String method redacts the value, and Raw() must be called
// to obtain the actual key material.
type AWSSecretAccessKey string

// Hostname identifies a host machine.
type Hostname string

// Version is a semantic version string.
type Version string

// GitRevision is a git commit SHA.
type GitRevision string

// GlobPattern is a glob pattern for file matching (e.g., "*.log", "node_modules").
type GlobPattern string
// String methods implementing fmt.Stringer for the non-sensitive types,
// so they format naturally with %v/%s and structured loggers.
func (id FileID) String() string       { return uuid.UUID(id).String() }
func (id BlobID) String() string       { return uuid.UUID(id).String() }
func (id SnapshotID) String() string   { return string(id) }
func (h ChunkHash) String() string     { return string(h) }
func (h BlobHash) String() string      { return string(h) }
func (p FilePath) String() string      { return string(p) }
func (p SourcePath) String() string    { return string(p) }
func (r AgeRecipient) String() string  { return string(r) }
func (e S3Endpoint) String() string    { return string(e) }
func (b BucketName) String() string    { return string(b) }
func (p S3Prefix) String() string      { return string(p) }
func (r AWSRegion) String() string     { return string(r) }
func (k AWSAccessKeyID) String() string { return string(k) }
func (h Hostname) String() string      { return string(h) }
func (v Version) String() string       { return string(v) }
func (r GitRevision) String() string   { return string(r) }
func (p GlobPattern) String() string   { return string(p) }
// Redacted String methods for sensitive types - prevents accidental logging
// via %v/%s and any logger that formats through fmt.Stringer.
func (k AgeSecretKey) String() string       { return "[REDACTED]" }
func (k AWSSecretAccessKey) String() string { return "[REDACTED]" }

// GoString implements fmt.GoStringer so that %#v is redacted as well.
// Without these, fmt.Sprintf("%#v", key) would bypass String() and print
// the raw secret in Go-syntax form.
func (k AgeSecretKey) GoString() string       { return `types.AgeSecretKey("[REDACTED]")` }
func (k AWSSecretAccessKey) GoString() string { return `types.AWSSecretAccessKey("[REDACTED]")` }

// Raw returns the actual value for sensitive types when explicitly needed
// (e.g. handing credentials to an SDK). Never log or persist the result.
func (k AgeSecretKey) Raw() string       { return string(k) }
func (k AWSSecretAccessKey) Raw() string { return string(k) }

View File

@@ -5,13 +5,15 @@ import (
"strconv"
"strings"
"time"
"git.eeqj.de/sneak/vaultik/internal/types"
)
// SnapshotInfo contains information about a snapshot
type SnapshotInfo struct {
ID string `json:"id"`
Timestamp time.Time `json:"timestamp"`
CompressedSize int64 `json:"compressed_size"`
ID types.SnapshotID `json:"id"`
Timestamp time.Time `json:"timestamp"`
CompressedSize int64 `json:"compressed_size"`
}
// formatNumber formats a number with commas
@@ -60,27 +62,18 @@ func formatBytes(bytes int64) string {
}
// parseSnapshotTimestamp extracts the timestamp from a snapshot ID
// Format: hostname_snapshotname_2026-01-12T14:41:15Z
func parseSnapshotTimestamp(snapshotID string) (time.Time, error) {
// Format: hostname-YYYYMMDD-HHMMSSZ
parts := strings.Split(snapshotID, "-")
if len(parts) < 3 {
return time.Time{}, fmt.Errorf("invalid snapshot ID format")
parts := strings.Split(snapshotID, "_")
if len(parts) < 2 {
return time.Time{}, fmt.Errorf("invalid snapshot ID format: expected hostname_snapshotname_timestamp")
}
dateStr := parts[len(parts)-2]
timeStr := parts[len(parts)-1]
if len(dateStr) != 8 || len(timeStr) != 7 || !strings.HasSuffix(timeStr, "Z") {
return time.Time{}, fmt.Errorf("invalid timestamp format")
}
// Remove Z suffix
timeStr = timeStr[:6]
// Parse the timestamp
timestamp, err := time.Parse("20060102150405", dateStr+timeStr)
// Last part is the RFC3339 timestamp
timestampStr := parts[len(parts)-1]
timestamp, err := time.Parse(time.RFC3339, timestampStr)
if err != nil {
return time.Time{}, fmt.Errorf("failed to parse timestamp: %w", err)
return time.Time{}, fmt.Errorf("invalid timestamp: %w", err)
}
return timestamp.UTC(), nil

View File

@@ -30,14 +30,23 @@ func (v *Vaultik) ShowInfo() error {
// Backup Settings
fmt.Printf("=== Backup Settings ===\n")
fmt.Printf("Source Directories:\n")
for _, dir := range v.Config.SourceDirs {
fmt.Printf(" - %s\n", dir)
// Show configured snapshots
fmt.Printf("Snapshots:\n")
for _, name := range v.Config.SnapshotNames() {
snap := v.Config.Snapshots[name]
fmt.Printf(" %s:\n", name)
for _, path := range snap.Paths {
fmt.Printf(" - %s\n", path)
}
if len(snap.Exclude) > 0 {
fmt.Printf(" exclude: %s\n", strings.Join(snap.Exclude, ", "))
}
}
// Global exclude patterns
if len(v.Config.Exclude) > 0 {
fmt.Printf("Exclude Patterns: %s\n", strings.Join(v.Config.Exclude, ", "))
fmt.Printf("Global Exclude: %s\n", strings.Join(v.Config.Exclude, ", "))
}
fmt.Printf("Compression: zstd level %d\n", v.Config.CompressionLevel)

View File

@@ -14,6 +14,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -184,7 +185,11 @@ func TestEndToEndBackup(t *testing.T) {
// Create test configuration
cfg := &config.Config{
SourceDirs: []string{"/home/user"},
Snapshots: map[string]config.SnapshotConfig{
"test": {
Paths: []string{"/home/user"},
},
},
Exclude: []string{"*.tmp", "*.log"},
ChunkSize: config.Size(16 * 1024), // 16KB chunks
BlobSizeLimit: config.Size(100 * 1024), // 100KB blobs
@@ -232,7 +237,7 @@ func TestEndToEndBackup(t *testing.T) {
snapshotID := "test-snapshot-001"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID,
ID: types.SnapshotID(snapshotID),
Hostname: "test-host",
VaultikVersion: "test-version",
StartedAt: time.Now(),
@@ -352,7 +357,7 @@ func TestBackupAndVerify(t *testing.T) {
snapshotID := "test-snapshot-001"
err = repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
snapshot := &database.Snapshot{
ID: snapshotID,
ID: types.SnapshotID(snapshotID),
Hostname: "test-host",
VaultikVersion: "test-version",
StartedAt: time.Now(),

675
internal/vaultik/restore.go Normal file
View File

@@ -0,0 +1,675 @@
package vaultik
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"time"
"filippo.io/age"
"git.eeqj.de/sneak/vaultik/internal/blobgen"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/dustin/go-humanize"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
"golang.org/x/term"
)
// RestoreOptions contains options for the restore operation
type RestoreOptions struct {
	SnapshotID string   // ID of the snapshot to restore from
	TargetDir  string   // Directory to restore into; original paths are recreated beneath it
	Paths      []string // Optional paths to restore (empty = all)
	Verify     bool     // Verify restored files by checking chunk hashes
}

// RestoreResult contains statistics from a restore operation
type RestoreResult struct {
	FilesRestored   int   // Files (including symlinks and directories) written
	BytesRestored   int64 // Plaintext bytes written for regular files
	BlobsDownloaded int   // Distinct blobs fetched from storage
	BytesDownloaded int64 // Decrypted blob bytes fetched (cache misses only)
	Duration        time.Duration
	// Verification results (only populated if Verify option is set)
	FilesVerified int
	BytesVerified int64
	FilesFailed   int
	FailedFiles   []string // Paths of files that failed verification
}
// Restore restores files from a snapshot to the target directory.
//
// It (1) downloads and decrypts the snapshot's metadata database,
// (2) selects the files matching opts.Paths (all files when empty),
// (3) reconstructs each file from its chunks, downloading and caching the
// containing blobs, and (4) when opts.Verify is set, re-checksums the
// restored files. Failures restoring individual files are logged and
// skipped; any verification failure makes the call return an error.
// Requires v.Config.AgeSecretKey to be set for decryption.
func (v *Vaultik) Restore(opts *RestoreOptions) error {
	startTime := time.Now()
	// Check for age_secret_key
	if v.Config.AgeSecretKey == "" {
		return fmt.Errorf("decryption key required for restore\n\nSet the VAULTIK_AGE_SECRET_KEY environment variable to your age private key:\n export VAULTIK_AGE_SECRET_KEY='AGE-SECRET-KEY-...'")
	}
	// Parse the age identity
	identity, err := age.ParseX25519Identity(v.Config.AgeSecretKey)
	if err != nil {
		return fmt.Errorf("parsing age secret key: %w", err)
	}
	log.Info("Starting restore operation",
		"snapshot_id", opts.SnapshotID,
		"target_dir", opts.TargetDir,
		"paths", opts.Paths,
	)
	// Step 1: Download and decrypt the snapshot metadata database
	log.Info("Downloading snapshot metadata...")
	tempDB, err := v.downloadSnapshotDB(opts.SnapshotID, identity)
	if err != nil {
		return fmt.Errorf("downloading snapshot database: %w", err)
	}
	defer func() {
		if err := tempDB.Close(); err != nil {
			log.Debug("Failed to close temp database", "error", err)
		}
		// Clean up temp file
		if err := v.Fs.Remove(tempDB.Path()); err != nil {
			log.Debug("Failed to remove temp database", "error", err)
		}
	}()
	repos := database.NewRepositories(tempDB)
	// Step 2: Get list of files to restore
	files, err := v.getFilesToRestore(v.ctx, repos, opts.Paths)
	if err != nil {
		return fmt.Errorf("getting files to restore: %w", err)
	}
	if len(files) == 0 {
		log.Warn("No files found to restore")
		return nil
	}
	log.Info("Found files to restore", "count", len(files))
	// Step 3: Create target directory
	if err := v.Fs.MkdirAll(opts.TargetDir, 0755); err != nil {
		return fmt.Errorf("creating target directory: %w", err)
	}
	// Step 4: Build a map of chunks to blobs for efficient restoration
	// (one table scan up front instead of a query per chunk)
	chunkToBlobMap, err := v.buildChunkToBlobMap(v.ctx, repos)
	if err != nil {
		return fmt.Errorf("building chunk-to-blob map: %w", err)
	}
	// Step 5: Restore files
	result := &RestoreResult{}
	blobCache := make(map[string][]byte) // Cache downloaded and decrypted blobs
	for i, file := range files {
		// Honor cancellation between files
		if v.ctx.Err() != nil {
			return v.ctx.Err()
		}
		if err := v.restoreFile(v.ctx, repos, file, opts.TargetDir, identity, chunkToBlobMap, blobCache, result); err != nil {
			log.Error("Failed to restore file", "path", file.Path, "error", err)
			// Continue with other files
			continue
		}
		// Progress logging
		if (i+1)%100 == 0 || i+1 == len(files) {
			log.Info("Restore progress",
				"files", fmt.Sprintf("%d/%d", i+1, len(files)),
				"bytes", humanize.Bytes(uint64(result.BytesRestored)),
			)
		}
	}
	result.Duration = time.Since(startTime)
	log.Info("Restore complete",
		"files_restored", result.FilesRestored,
		"bytes_restored", humanize.Bytes(uint64(result.BytesRestored)),
		"blobs_downloaded", result.BlobsDownloaded,
		"bytes_downloaded", humanize.Bytes(uint64(result.BytesDownloaded)),
		"duration", result.Duration,
	)
	_, _ = fmt.Fprintf(v.Stdout, "Restored %d files (%s) in %s\n",
		result.FilesRestored,
		humanize.Bytes(uint64(result.BytesRestored)),
		result.Duration.Round(time.Second),
	)
	// Run verification if requested
	if opts.Verify {
		if err := v.verifyRestoredFiles(v.ctx, repos, files, opts.TargetDir, result); err != nil {
			return fmt.Errorf("verification failed: %w", err)
		}
		if result.FilesFailed > 0 {
			_, _ = fmt.Fprintf(v.Stdout, "\nVerification FAILED: %d files did not match expected checksums\n", result.FilesFailed)
			for _, path := range result.FailedFiles {
				_, _ = fmt.Fprintf(v.Stdout, "  - %s\n", path)
			}
			return fmt.Errorf("%d files failed verification", result.FilesFailed)
		}
		_, _ = fmt.Fprintf(v.Stdout, "Verified %d files (%s)\n",
			result.FilesVerified,
			humanize.Bytes(uint64(result.BytesVerified)),
		)
	}
	return nil
}
// downloadSnapshotDB downloads and decrypts the snapshot metadata database.
//
// The compressed, encrypted SQL dump is fetched from object storage at
// metadata/<snapshotID>/db.zst.age, decrypted with the given age identity,
// and replayed through the sqlite3 CLI into a fresh temporary database file.
// The caller owns the returned DB: it must Close it and remove the file at
// its Path(). On any error, temp files created here are removed before
// returning, and the sqlite3 subprocess is bound to v.ctx so cancellation
// aborts it.
func (v *Vaultik) downloadSnapshotDB(snapshotID string, identity age.Identity) (*database.DB, error) {
	// Download encrypted database from S3
	dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
	reader, err := v.Storage.Get(v.ctx, dbKey)
	if err != nil {
		return nil, fmt.Errorf("downloading %s: %w", dbKey, err)
	}
	defer func() { _ = reader.Close() }()
	// Read all data
	encryptedData, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("reading encrypted data: %w", err)
	}
	log.Debug("Downloaded encrypted database", "size", humanize.Bytes(uint64(len(encryptedData))))
	// Decrypt and decompress using blobgen.Reader
	blobReader, err := blobgen.NewReader(bytes.NewReader(encryptedData), identity)
	if err != nil {
		return nil, fmt.Errorf("creating decryption reader: %w", err)
	}
	defer func() { _ = blobReader.Close() }()
	// Read the SQL dump
	sqlDump, err := io.ReadAll(blobReader)
	if err != nil {
		return nil, fmt.Errorf("decrypting and decompressing: %w", err)
	}
	log.Debug("Decrypted database SQL dump", "size", humanize.Bytes(uint64(len(sqlDump))))
	// Create a temporary database file
	tempFile, err := afero.TempFile(v.Fs, "", "vaultik-restore-*.db")
	if err != nil {
		return nil, fmt.Errorf("creating temp file: %w", err)
	}
	tempPath := tempFile.Name()
	if err := tempFile.Close(); err != nil {
		_ = v.Fs.Remove(tempPath)
		return nil, fmt.Errorf("closing temp file: %w", err)
	}
	// Remove the database file on any error below; on success the caller
	// takes ownership and removes it via db.Path().
	handedOff := false
	defer func() {
		if !handedOff {
			_ = v.Fs.Remove(tempPath)
		}
	}()
	// Write SQL to a temp file for sqlite3 to read. Register cleanup
	// immediately so the SQL file is removed even if writing fails.
	sqlTempFile, err := afero.TempFile(v.Fs, "", "vaultik-restore-*.sql")
	if err != nil {
		return nil, fmt.Errorf("creating SQL temp file: %w", err)
	}
	sqlTempPath := sqlTempFile.Name()
	defer func() { _ = v.Fs.Remove(sqlTempPath) }()
	if _, err := sqlTempFile.Write(sqlDump); err != nil {
		_ = sqlTempFile.Close()
		return nil, fmt.Errorf("writing SQL dump: %w", err)
	}
	if err := sqlTempFile.Close(); err != nil {
		return nil, fmt.Errorf("closing SQL temp file: %w", err)
	}
	// Execute the SQL dump to create the database. CommandContext ties the
	// subprocess lifetime to the restore context so cancellation kills it.
	cmd := exec.CommandContext(v.ctx, "sqlite3", tempPath, ".read "+sqlTempPath)
	if output, err := cmd.CombinedOutput(); err != nil {
		return nil, fmt.Errorf("executing SQL dump: %w\nOutput: %s", err, output)
	}
	log.Debug("Created restore database", "path", tempPath)
	// Open the database
	db, err := database.New(v.ctx, tempPath)
	if err != nil {
		return nil, fmt.Errorf("opening restore database: %w", err)
	}
	handedOff = true
	return db, nil
}
// getFilesToRestore returns the list of files to restore based on path
// filters. With no filters, every file in the snapshot database is returned;
// otherwise files are matched by path prefix and deduplicated by file ID,
// since overlapping filters can yield the same record more than once.
func (v *Vaultik) getFilesToRestore(ctx context.Context, repos *database.Repositories, pathFilters []string) ([]*database.File, error) {
	// No filters means restore everything.
	if len(pathFilters) == 0 {
		return repos.Files.ListAll(ctx)
	}

	var matched []*database.File
	seen := make(map[string]bool)
	for _, rawFilter := range pathFilters {
		// Normalize the filter before using it as a prefix.
		prefix := filepath.Clean(rawFilter)

		batch, err := repos.Files.ListByPrefix(ctx, prefix)
		if err != nil {
			return nil, fmt.Errorf("listing files with prefix %s: %w", prefix, err)
		}
		for _, f := range batch {
			key := f.ID.String()
			if seen[key] {
				continue
			}
			seen[key] = true
			matched = append(matched, f)
		}
	}
	return matched, nil
}
// buildChunkToBlobMap creates a mapping from chunk hash to blob information
// by scanning the blob_chunks table once, so per-file restoration can find a
// chunk's containing blob without issuing further queries.
func (v *Vaultik) buildChunkToBlobMap(ctx context.Context, repos *database.Repositories) (map[string]*database.BlobChunk, error) {
	query := `SELECT blob_id, chunk_hash, offset, length FROM blob_chunks`
	rows, err := repos.DB().Conn().QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("querying blob_chunks: %w", err)
	}
	defer func() { _ = rows.Close() }()

	chunkMap := make(map[string]*database.BlobChunk)
	for rows.Next() {
		// Declared inside the loop so each map entry points at its own value.
		var entry database.BlobChunk
		var rawBlobID, rawChunkHash string
		if err := rows.Scan(&rawBlobID, &rawChunkHash, &entry.Offset, &entry.Length); err != nil {
			return nil, fmt.Errorf("scanning blob_chunk: %w", err)
		}
		parsedID, err := types.ParseBlobID(rawBlobID)
		if err != nil {
			return nil, fmt.Errorf("parsing blob ID: %w", err)
		}
		entry.BlobID = parsedID
		entry.ChunkHash = types.ChunkHash(rawChunkHash)
		chunkMap[rawChunkHash] = &entry
	}
	return chunkMap, rows.Err()
}
// restoreFile restores a single file, symlink, or directory entry into
// targetDir at its full original path, dispatching to the kind-specific
// restore helper after ensuring the parent directory chain exists.
func (v *Vaultik) restoreFile(
	ctx context.Context,
	repos *database.Repositories,
	file *database.File,
	targetDir string,
	identity age.Identity,
	chunkToBlobMap map[string]*database.BlobChunk,
	blobCache map[string][]byte,
	result *RestoreResult,
) error {
	// The entry is recreated at its full original path rooted at targetDir.
	destination := filepath.Join(targetDir, file.Path.String())

	// Make sure the parent directory exists before writing anything.
	if err := v.Fs.MkdirAll(filepath.Dir(destination), 0755); err != nil {
		return fmt.Errorf("creating parent directory: %w", err)
	}

	// Dispatch on the entry kind.
	switch {
	case file.IsSymlink():
		return v.restoreSymlink(file, destination, result)
	case file.Mode&uint32(os.ModeDir) != 0:
		return v.restoreDirectory(file, destination, result)
	default:
		return v.restoreRegularFile(ctx, repos, file, destination, identity, chunkToBlobMap, blobCache, result)
	}
}
// restoreSymlink restores a symbolic link at targetPath pointing at the
// stored link target.
//
// afero's core Fs interface has no symlink operation. Instead of checking
// for *afero.OsFs specifically, we type-assert for the SymlinkIfPossible
// capability (implemented by OsFs and any other symlink-capable backend),
// so symlinks work on every filesystem that supports them. Backends without
// symlink support (e.g. MemMapFs used in tests) log and skip, as before.
func (v *Vaultik) restoreSymlink(file *database.File, targetPath string, result *RestoreResult) error {
	// Remove any existing entry so the symlink can be created fresh.
	_ = v.Fs.Remove(targetPath)
	// Capability interface matching afero.Linker's SymlinkIfPossible.
	type symlinker interface {
		SymlinkIfPossible(oldname, newname string) error
	}
	if lfs, ok := v.Fs.(symlinker); ok {
		if err := lfs.SymlinkIfPossible(file.LinkTarget.String(), targetPath); err != nil {
			return fmt.Errorf("creating symlink: %w", err)
		}
	} else {
		log.Debug("Symlink creation not supported on this filesystem", "path", file.Path, "target", file.LinkTarget)
	}
	result.FilesRestored++
	log.Debug("Restored symlink", "path", file.Path, "target", file.LinkTarget)
	return nil
}
// restoreDirectory restores a directory entry, recreating it with the
// recorded permissions, ownership (real filesystems only, needs root), and
// modification time. Metadata failures are logged but never fatal.
func (v *Vaultik) restoreDirectory(file *database.File, targetPath string, result *RestoreResult) error {
	mode := os.FileMode(file.Mode)

	// Create the directory itself; parents were handled by the caller.
	if err := v.Fs.MkdirAll(targetPath, mode); err != nil {
		return fmt.Errorf("creating directory: %w", err)
	}

	// Apply permissions explicitly (MkdirAll leaves existing dirs untouched).
	if err := v.Fs.Chmod(targetPath, mode); err != nil {
		log.Debug("Failed to set directory permissions", "path", targetPath, "error", err)
	}

	// Ownership can only be restored on the real filesystem, and typically
	// only when running as root; best effort.
	if _, isRealFs := v.Fs.(*afero.OsFs); isRealFs {
		if err := os.Chown(targetPath, int(file.UID), int(file.GID)); err != nil {
			log.Debug("Failed to set directory ownership", "path", targetPath, "error", err)
		}
	}

	// Restore the recorded modification time; best effort.
	if err := v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil {
		log.Debug("Failed to set directory mtime", "path", targetPath, "error", err)
	}

	result.FilesRestored++
	return nil
}
// restoreRegularFile restores a regular file by reconstructing it from chunks.
//
// The file's ordered chunk list is read from the metadata DB; each chunk is
// located inside its containing blob via chunkToBlobMap, and blobs are
// downloaded/decrypted at most once per restore run via blobCache (the cache
// is shared across files and grows unboundedly for the run). After writing,
// permissions, ownership (real filesystem + root only), and mtime are
// applied best-effort, and counters in result are updated.
func (v *Vaultik) restoreRegularFile(
	ctx context.Context,
	repos *database.Repositories,
	file *database.File,
	targetPath string,
	identity age.Identity,
	chunkToBlobMap map[string]*database.BlobChunk,
	blobCache map[string][]byte,
	result *RestoreResult,
) error {
	// Get file chunks in order
	fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		return fmt.Errorf("getting file chunks: %w", err)
	}
	// Create output file
	outFile, err := v.Fs.Create(targetPath)
	if err != nil {
		return fmt.Errorf("creating output file: %w", err)
	}
	// Safety net for error returns; the explicit Close below runs first on
	// the success path and the duplicate close error is deliberately ignored.
	defer func() { _ = outFile.Close() }()
	// Write chunks in order
	var bytesWritten int64
	for _, fc := range fileChunks {
		// Find which blob contains this chunk
		chunkHashStr := fc.ChunkHash.String()
		blobChunk, ok := chunkToBlobMap[chunkHashStr]
		if !ok {
			return fmt.Errorf("chunk %s not found in any blob", chunkHashStr[:16])
		}
		// Get the blob's hash from the database
		blob, err := repos.Blobs.GetByID(ctx, blobChunk.BlobID.String())
		if err != nil {
			return fmt.Errorf("getting blob %s: %w", blobChunk.BlobID, err)
		}
		// Download and decrypt blob if not cached
		blobHashStr := blob.Hash.String()
		blobData, ok := blobCache[blobHashStr]
		if !ok {
			blobData, err = v.downloadBlob(ctx, blobHashStr, identity)
			if err != nil {
				return fmt.Errorf("downloading blob %s: %w", blobHashStr[:16], err)
			}
			blobCache[blobHashStr] = blobData
			result.BlobsDownloaded++
			result.BytesDownloaded += int64(len(blobData))
		}
		// Extract chunk from blob, bounds-checking the recorded offsets
		// against the actual decrypted blob size.
		if blobChunk.Offset+blobChunk.Length > int64(len(blobData)) {
			return fmt.Errorf("chunk %s extends beyond blob data (offset=%d, length=%d, blob_size=%d)",
				fc.ChunkHash[:16], blobChunk.Offset, blobChunk.Length, len(blobData))
		}
		chunkData := blobData[blobChunk.Offset : blobChunk.Offset+blobChunk.Length]
		// Write chunk to output file
		n, err := outFile.Write(chunkData)
		if err != nil {
			return fmt.Errorf("writing chunk: %w", err)
		}
		bytesWritten += int64(n)
	}
	// Close file before setting metadata
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("closing output file: %w", err)
	}
	// Set permissions; failures are non-fatal
	if err := v.Fs.Chmod(targetPath, os.FileMode(file.Mode)); err != nil {
		log.Debug("Failed to set file permissions", "path", targetPath, "error", err)
	}
	// Set ownership (requires root); only meaningful on the real filesystem
	if osFs, ok := v.Fs.(*afero.OsFs); ok {
		_ = osFs
		if err := os.Chown(targetPath, int(file.UID), int(file.GID)); err != nil {
			log.Debug("Failed to set file ownership", "path", targetPath, "error", err)
		}
	}
	// Set mtime; failures are non-fatal
	if err := v.Fs.Chtimes(targetPath, file.MTime, file.MTime); err != nil {
		log.Debug("Failed to set file mtime", "path", targetPath, "error", err)
	}
	result.FilesRestored++
	result.BytesRestored += bytesWritten
	log.Debug("Restored file", "path", file.Path, "size", humanize.Bytes(uint64(bytesWritten)))
	return nil
}
// downloadBlob fetches a blob from object storage by its content hash, then
// decrypts and decompresses it, returning the plaintext blob bytes.
func (v *Vaultik) downloadBlob(ctx context.Context, blobHash string, identity age.Identity) ([]byte, error) {
	// Blobs are sharded in storage by the first two hex byte-pairs of the hash.
	key := fmt.Sprintf("blobs/%s/%s/%s", blobHash[:2], blobHash[2:4], blobHash)

	reader, err := v.Storage.Get(ctx, key)
	if err != nil {
		return nil, fmt.Errorf("downloading blob: %w", err)
	}
	defer func() { _ = reader.Close() }()

	// Pull the whole encrypted object into memory.
	encryptedData, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("reading blob data: %w", err)
	}

	// Decrypt with the age identity and decompress in one pass.
	blobReader, err := blobgen.NewReader(bytes.NewReader(encryptedData), identity)
	if err != nil {
		return nil, fmt.Errorf("creating decryption reader: %w", err)
	}
	defer func() { _ = blobReader.Close() }()

	plaintext, err := io.ReadAll(blobReader)
	if err != nil {
		return nil, fmt.Errorf("decrypting blob: %w", err)
	}

	log.Debug("Downloaded and decrypted blob",
		"hash", blobHash[:16],
		"encrypted_size", humanize.Bytes(uint64(len(encryptedData))),
		"decrypted_size", humanize.Bytes(uint64(len(plaintext))),
	)
	return plaintext, nil
}
// verifyRestoredFiles re-reads every restored regular file from targetDir and
// checks its content chunk-by-chunk against the chunk hashes recorded in the
// database, accumulating verified/failed counts into result. Symlinks and
// directories carry no chunk data and are skipped.
func (v *Vaultik) verifyRestoredFiles(
	ctx context.Context,
	repos *database.Repositories,
	files []*database.File,
	targetDir string,
	result *RestoreResult,
) error {
	// Collect the regular files and sum their sizes so the progress bar
	// has an accurate total.
	toVerify := make([]*database.File, 0, len(files))
	var verifyBytes int64
	for _, f := range files {
		if f.IsSymlink() || f.Mode&uint32(os.ModeDir) != 0 {
			continue
		}
		toVerify = append(toVerify, f)
		verifyBytes += f.Size
	}

	if len(toVerify) == 0 {
		log.Info("No regular files to verify")
		return nil
	}

	log.Info("Verifying restored files",
		"files", len(toVerify),
		"bytes", humanize.Bytes(uint64(verifyBytes)),
	)
	_, _ = fmt.Fprintf(v.Stdout, "\nVerifying %d files (%s)...\n",
		len(toVerify),
		humanize.Bytes(uint64(verifyBytes)),
	)

	// Only render a progress bar when attached to a terminal.
	var pb *progressbar.ProgressBar
	if isTerminal() {
		pb = progressbar.NewOptions64(
			verifyBytes,
			progressbar.OptionSetDescription("Verifying"),
			progressbar.OptionSetWriter(os.Stderr),
			progressbar.OptionShowBytes(true),
			progressbar.OptionShowCount(),
			progressbar.OptionSetWidth(40),
			progressbar.OptionThrottle(100*time.Millisecond),
			progressbar.OptionOnCompletion(func() {
				fmt.Fprint(os.Stderr, "\n")
			}),
			progressbar.OptionSetRenderBlankState(true),
		)
	}

	for _, f := range toVerify {
		// Stop promptly on cancellation.
		if err := ctx.Err(); err != nil {
			return err
		}

		path := filepath.Join(targetDir, f.Path.String())
		n, err := v.verifyFile(ctx, repos, f, path)
		if err != nil {
			log.Error("File verification failed", "path", f.Path, "error", err)
			result.FilesFailed++
			result.FailedFiles = append(result.FailedFiles, f.Path.String())
		} else {
			result.FilesVerified++
			result.BytesVerified += n
		}

		// Advance by the full file size even on failure so the bar can
		// still complete.
		if pb != nil {
			_ = pb.Add64(f.Size)
		}
	}

	if pb != nil {
		_ = pb.Finish()
	}

	log.Info("Verification complete",
		"files_verified", result.FilesVerified,
		"bytes_verified", humanize.Bytes(uint64(result.BytesVerified)),
		"files_failed", result.FilesFailed,
	)
	return nil
}
// verifyFile verifies a single restored file by streaming through it on disk
// and comparing each chunk's SHA-256 hash against the hash recorded in the
// database. It returns the number of bytes successfully verified; on error
// the count reflects only the chunks verified before the failure.
//
// After the final chunk it also confirms the file ends exactly there, so a
// file with extra data appended after restore is reported as a failure.
func (v *Vaultik) verifyFile(
	ctx context.Context,
	repos *database.Repositories,
	file *database.File,
	targetPath string,
) (int64, error) {
	// Chunks come back in file order, so the file can be read sequentially.
	fileChunks, err := repos.FileChunks.GetByFileID(ctx, file.ID)
	if err != nil {
		return 0, fmt.Errorf("getting file chunks: %w", err)
	}

	// Open the restored file for re-reading.
	f, err := v.Fs.Open(targetPath)
	if err != nil {
		return 0, fmt.Errorf("opening file: %w", err)
	}
	defer func() { _ = f.Close() }()

	var bytesVerified int64
	for _, fc := range fileChunks {
		// The chunk record holds the expected size for this hash.
		chunk, err := repos.Chunks.GetByHash(ctx, fc.ChunkHash.String())
		if err != nil {
			return bytesVerified, fmt.Errorf("getting chunk %s: %w", fc.ChunkHash.String()[:16], err)
		}

		// Read exactly one chunk's worth of data; a truncated file
		// surfaces here as io.ErrUnexpectedEOF from ReadFull.
		chunkData := make([]byte, chunk.Size)
		n, err := io.ReadFull(f, chunkData)
		if err != nil {
			return bytesVerified, fmt.Errorf("reading chunk data: %w", err)
		}
		if int64(n) != chunk.Size {
			return bytesVerified, fmt.Errorf("short read: expected %d bytes, got %d", chunk.Size, n)
		}

		// Hash the on-disk chunk and compare to the recorded hash.
		hash := sha256.Sum256(chunkData)
		actualHash := hex.EncodeToString(hash[:])
		expectedHash := fc.ChunkHash.String()
		if actualHash != expectedHash {
			return bytesVerified, fmt.Errorf("chunk %d hash mismatch: expected %s, got %s",
				fc.Idx, expectedHash[:16], actualHash[:16])
		}
		bytesVerified += int64(n)
	}

	// Bug fix: previously a file with extra bytes appended after the last
	// chunk passed verification. Probe for EOF so the file must contain
	// exactly the recorded chunks and nothing more. io.ReadFull on a
	// one-byte buffer returns io.EOF when the file is fully consumed.
	var probe [1]byte
	switch _, err := io.ReadFull(f, probe[:]); err {
	case io.EOF:
		// Expected: no trailing data.
	case nil:
		return bytesVerified, fmt.Errorf("file has trailing data beyond expected %d bytes", bytesVerified)
	default:
		return bytesVerified, fmt.Errorf("checking for trailing data: %w", err)
	}

	log.Debug("File verified", "path", file.Path, "bytes", bytesVerified, "chunks", len(fileChunks))
	return bytesVerified, nil
}
// isTerminal reports whether stdout is attached to a terminal.
//
// NOTE(review): the verification progress bar writes to os.Stderr while this
// gates on os.Stdout — confirm that mixing the two streams is intentional
// (e.g. when stdout is piped but stderr is a tty, or vice versa).
func isTerminal() bool {
	fd := int(os.Stdout.Fd())
	return term.IsTerminal(fd)
}

View File

@@ -13,19 +13,22 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/types"
"github.com/dustin/go-humanize"
)
// SnapshotCreateOptions contains options for the snapshot create command
type SnapshotCreateOptions struct {
Daemon bool
Cron bool
Prune bool
Daemon bool
Cron bool
Prune bool
SkipErrors bool // Skip file read errors (log them loudly but continue)
Snapshots []string // Optional list of snapshot names to process (empty = all)
}
// CreateSnapshot executes the snapshot creation operation
func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
snapshotStartTime := time.Now()
overallStartTime := time.Now()
log.Info("Starting snapshot creation",
"version", v.Globals.Version,
@@ -57,9 +60,51 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
return fmt.Errorf("daemon mode not yet implemented")
}
// Determine which snapshots to process
snapshotNames := opts.Snapshots
if len(snapshotNames) == 0 {
snapshotNames = v.Config.SnapshotNames()
} else {
// Validate requested snapshot names exist
for _, name := range snapshotNames {
if _, ok := v.Config.Snapshots[name]; !ok {
return fmt.Errorf("snapshot %q not found in config", name)
}
}
}
if len(snapshotNames) == 0 {
return fmt.Errorf("no snapshots configured")
}
// Process each named snapshot
for snapIdx, snapName := range snapshotNames {
if err := v.createNamedSnapshot(opts, hostname, snapName, snapIdx+1, len(snapshotNames)); err != nil {
return err
}
}
// Print overall summary if multiple snapshots
if len(snapshotNames) > 1 {
_, _ = fmt.Fprintf(v.Stdout, "\nAll %d snapshots completed in %s\n", len(snapshotNames), time.Since(overallStartTime).Round(time.Second))
}
return nil
}
// createNamedSnapshot creates a single named snapshot
func (v *Vaultik) createNamedSnapshot(opts *SnapshotCreateOptions, hostname, snapName string, idx, total int) error {
snapshotStartTime := time.Now()
snapConfig := v.Config.Snapshots[snapName]
if total > 1 {
_, _ = fmt.Fprintf(v.Stdout, "\n=== Snapshot %d/%d: %s ===\n", idx, total, snapName)
}
// Resolve source directories to absolute paths
resolvedDirs := make([]string, 0, len(v.Config.SourceDirs))
for _, dir := range v.Config.SourceDirs {
resolvedDirs := make([]string, 0, len(snapConfig.Paths))
for _, dir := range snapConfig.Paths {
absPath, err := filepath.Abs(dir)
if err != nil {
return fmt.Errorf("failed to resolve absolute path for %s: %w", dir, err)
@@ -80,9 +125,12 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
}
// Create scanner with progress enabled (unless in cron mode)
// Pass the combined excludes for this snapshot
scanner := v.ScannerFactory(snapshot.ScannerParams{
EnableProgress: !opts.Cron,
Fs: v.Fs,
Exclude: v.Config.GetExcludes(snapName),
SkipErrors: opts.SkipErrors,
})
// Statistics tracking
@@ -98,12 +146,12 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
totalBlobsUploaded := 0
uploadDuration := time.Duration(0)
// Create a new snapshot at the beginning
snapshotID, err := v.SnapshotManager.CreateSnapshot(v.ctx, hostname, v.Globals.Version, v.Globals.Commit)
// Create a new snapshot at the beginning (with snapshot name in ID)
snapshotID, err := v.SnapshotManager.CreateSnapshotWithName(v.ctx, hostname, snapName, v.Globals.Version, v.Globals.Commit)
if err != nil {
return fmt.Errorf("creating snapshot: %w", err)
}
log.Info("Beginning snapshot", "snapshot_id", snapshotID)
log.Info("Beginning snapshot", "snapshot_id", snapshotID, "name", snapName)
_, _ = fmt.Fprintf(v.Stdout, "Beginning snapshot: %s\n", snapshotID)
for i, dir := range resolvedDirs {
@@ -292,31 +340,32 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
// Build a map of local snapshots for quick lookup
localSnapshotMap := make(map[string]*database.Snapshot)
for _, s := range localSnapshots {
localSnapshotMap[s.ID] = s
localSnapshotMap[s.ID.String()] = s
}
// Remove local snapshots that don't exist remotely
for _, snapshot := range localSnapshots {
if !remoteSnapshots[snapshot.ID] {
snapshotIDStr := snapshot.ID.String()
if !remoteSnapshots[snapshotIDStr] {
log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID)
// Delete related records first to avoid foreign key constraints
if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err)
}
if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err)
}
if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err)
}
// Now delete the snapshot itself
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err)
} else {
log.Info("Deleted local snapshot not found in remote", "snapshot_id", snapshot.ID)
delete(localSnapshotMap, snapshot.ID)
delete(localSnapshotMap, snapshotIDStr)
}
}
}
@@ -355,7 +404,7 @@ func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
}
snapshots = append(snapshots, SnapshotInfo{
ID: snapshotID,
ID: types.SnapshotID(snapshotID),
Timestamp: timestamp,
CompressedSize: totalSize,
})
@@ -481,7 +530,7 @@ func (v *Vaultik) PurgeSnapshots(keepLatest bool, olderThan string, force bool)
// Delete snapshots
for _, snap := range toDelete {
log.Info("Deleting snapshot", "id", snap.ID)
if err := v.deleteSnapshot(snap.ID); err != nil {
if err := v.deleteSnapshot(snap.ID.String()); err != nil {
return fmt.Errorf("deleting snapshot %s: %w", snap.ID, err)
}
}
@@ -689,9 +738,10 @@ func (v *Vaultik) syncWithRemote() error {
// Remove local snapshots that don't exist remotely
removedCount := 0
for _, snapshot := range localSnapshots {
if !remoteSnapshots[snapshot.ID] {
snapshotIDStr := snapshot.ID.String()
if !remoteSnapshots[snapshotIDStr] {
log.Info("Removing local snapshot not found in remote", "snapshot_id", snapshot.ID)
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete local snapshot", "snapshot_id", snapshot.ID, "error", err)
} else {
removedCount++
@@ -911,18 +961,19 @@ func (v *Vaultik) PruneDatabase() (*PruneResult, error) {
}
for _, snapshot := range incompleteSnapshots {
snapshotIDStr := snapshot.ID.String()
log.Info("Deleting incomplete snapshot", "snapshot_id", snapshot.ID)
// Delete related records first
if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.DeleteSnapshotFiles(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot files", "snapshot_id", snapshot.ID, "error", err)
}
if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.DeleteSnapshotBlobs(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot blobs", "snapshot_id", snapshot.ID, "error", err)
}
if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.DeleteSnapshotUploads(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot uploads", "snapshot_id", snapshot.ID, "error", err)
}
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshot.ID); err != nil {
if err := v.Repositories.Snapshots.Delete(v.ctx, snapshotIDStr); err != nil {
log.Error("Failed to delete snapshot", "snapshot_id", snapshot.ID, "error", err)
} else {
result.SnapshotsDeleted++