Compare commits


No commits in common. "899448e1da942cdb622ffa84752667e7be434e9d" and "cda0cf865ae69e0cffa9ed67567f491fcc3e12aa" have entirely different histories.

24 changed files with 451 additions and 1498 deletions

go.mod
View File

@ -37,7 +37,6 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect
github.com/adrg/xdg v0.5.3 // indirect
github.com/armon/go-metrics v0.4.1 // indirect
github.com/aws/aws-sdk-go v1.44.256 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.11 // indirect

go.sum
View File

@ -33,8 +33,6 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mo
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78=
github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=

View File

@ -2,11 +2,9 @@ package cli
import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
@ -14,11 +12,9 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/pidlock"
"git.eeqj.de/sneak/vaultik/internal/s3"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/vaultik"
"github.com/adrg/xdg"
"go.uber.org/fx"
)
@ -55,7 +51,7 @@ func NewApp(opts AppOptions) *fx.App {
config.Module,
database.Module,
log.Module,
storage.Module,
s3.Module,
snapshot.Module,
fx.Provide(vaultik.New),
fx.Invoke(setupGlobals),
@ -122,23 +118,7 @@ func RunApp(ctx context.Context, app *fx.App) error {
// RunWithApp is a helper that creates and runs an fx app with the given options.
// It combines NewApp and RunApp into a single convenient function. This is the
// preferred way to run CLI commands that need the full application context.
// It acquires a PID lock before starting to prevent concurrent instances.
func RunWithApp(ctx context.Context, opts AppOptions) error {
// Acquire PID lock to prevent concurrent instances
lockDir := filepath.Join(xdg.DataHome, "berlin.sneak.app.vaultik")
lock, err := pidlock.Acquire(lockDir)
if err != nil {
if errors.Is(err, pidlock.ErrAlreadyRunning) {
return fmt.Errorf("cannot start: %w", err)
}
return fmt.Errorf("failed to acquire lock: %w", err)
}
defer func() {
if err := lock.Release(); err != nil {
log.Warn("Failed to release PID lock", "error", err)
}
}()
app := NewApp(opts)
return RunApp(ctx, app)
}

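For reference, the RunWithApp helper documented above is just the composition of the two calls it names; a minimal sketch of the equivalent sequence under that reading (AppOptions construction is elided, since only part of that type is visible in this diff):

    // Sketch: equivalent to calling cli.RunWithApp(ctx, opts) after this change.
    app := cli.NewApp(opts)            // assemble the fx application
    if err := cli.RunApp(ctx, app); err != nil {
        return err
    }
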
View File

@ -8,8 +8,8 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/s3"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/storage"
"github.com/spf13/cobra"
"go.uber.org/fx"
)
@ -23,7 +23,7 @@ type FetchApp struct {
Globals *globals.Globals
Config *config.Config
Repositories *database.Repositories
Storage storage.Storer
S3Client *s3.Client
DB *database.DB
Shutdowner fx.Shutdowner
}
@ -61,14 +61,15 @@ The age_secret_key must be configured in the config file for decryption.`,
},
Modules: []fx.Option{
snapshot.Module,
s3.Module,
fx.Provide(fx.Annotate(
func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
storer storage.Storer, db *database.DB, shutdowner fx.Shutdowner) *FetchApp {
s3Client *s3.Client, db *database.DB, shutdowner fx.Shutdowner) *FetchApp {
return &FetchApp{
Globals: g,
Config: cfg,
Repositories: repos,
Storage: storer,
S3Client: s3Client,
DB: db,
Shutdowner: shutdowner,
}

View File

@ -8,8 +8,8 @@ import (
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/s3"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/storage"
"github.com/spf13/cobra"
"go.uber.org/fx"
)
@ -24,7 +24,7 @@ type RestoreApp struct {
Globals *globals.Globals
Config *config.Config
Repositories *database.Repositories
Storage storage.Storer
S3Client *s3.Client
DB *database.DB
Shutdowner fx.Shutdowner
}
@ -61,14 +61,15 @@ The age_secret_key must be configured in the config file for decryption.`,
},
Modules: []fx.Option{
snapshot.Module,
s3.Module,
fx.Provide(fx.Annotate(
func(g *globals.Globals, cfg *config.Config, repos *database.Repositories,
storer storage.Storer, db *database.DB, shutdowner fx.Shutdowner) *RestoreApp {
s3Client *s3.Client, db *database.DB, shutdowner fx.Shutdowner) *RestoreApp {
return &RestoreApp{
Globals: g,
Config: cfg,
Repositories: repos,
Storage: storer,
S3Client: s3Client,
DB: db,
Shutdowner: shutdowner,
}

View File

@ -7,14 +7,14 @@ import (
"time"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/s3"
"github.com/spf13/cobra"
"go.uber.org/fx"
)
// StoreApp contains dependencies for store commands
type StoreApp struct {
Storage storage.Storer
S3Client *s3.Client
Shutdowner fx.Shutdowner
}
@ -48,18 +48,19 @@ func newStoreInfoCommand() *cobra.Command {
// Info displays storage information
func (app *StoreApp) Info(ctx context.Context) error {
// Get storage info
storageInfo := app.Storage.Info()
// Get bucket info
bucketName := app.S3Client.BucketName()
endpoint := app.S3Client.Endpoint()
fmt.Printf("Storage Information\n")
fmt.Printf("==================\n\n")
fmt.Printf("Storage Configuration:\n")
fmt.Printf(" Type: %s\n", storageInfo.Type)
fmt.Printf(" Location: %s\n\n", storageInfo.Location)
fmt.Printf("S3 Configuration:\n")
fmt.Printf(" Endpoint: %s\n", endpoint)
fmt.Printf(" Bucket: %s\n\n", bucketName)
// Count snapshots by listing metadata/ prefix
snapshotCount := 0
snapshotCh := app.Storage.ListStream(ctx, "metadata/")
snapshotCh := app.S3Client.ListObjectsStream(ctx, "metadata/", true)
snapshotDirs := make(map[string]bool)
for object := range snapshotCh {
@ -78,7 +79,7 @@ func (app *StoreApp) Info(ctx context.Context) error {
blobCount := 0
var totalSize int64
blobCh := app.Storage.ListStream(ctx, "blobs/")
blobCh := app.S3Client.ListObjectsStream(ctx, "blobs/", false)
for object := range blobCh {
if object.Err != nil {
return fmt.Errorf("listing blobs: %w", object.Err)
@ -129,9 +130,10 @@ func runWithApp(ctx context.Context, fn func(*StoreApp) error) error {
Debug: rootFlags.Debug,
},
Modules: []fx.Option{
fx.Provide(func(storer storage.Storer, shutdowner fx.Shutdowner) *StoreApp {
s3.Module,
fx.Provide(func(s3Client *s3.Client, shutdowner fx.Shutdowner) *StoreApp {
return &StoreApp{
Storage: storer,
S3Client: s3Client,
Shutdowner: shutdowner,
}
}),

View File

@ -3,43 +3,16 @@ package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"git.eeqj.de/sneak/smartconfig"
"github.com/adrg/xdg"
"go.uber.org/fx"
"gopkg.in/yaml.v3"
)
const appName = "berlin.sneak.app.vaultik"
// expandTilde expands ~ at the start of a path to the user's home directory.
func expandTilde(path string) string {
if path == "~" {
home, _ := os.UserHomeDir()
return home
}
if strings.HasPrefix(path, "~/") {
home, _ := os.UserHomeDir()
return filepath.Join(home, path[2:])
}
return path
}
// expandTildeInURL expands ~ in file:// URLs.
func expandTildeInURL(url string) string {
if strings.HasPrefix(url, "file://~/") {
home, _ := os.UserHomeDir()
return "file://" + filepath.Join(home, url[9:])
}
return url
}
// Config represents the application configuration for Vaultik.
// It defines all settings for backup operations, including source directories,
// encryption recipients, storage configuration, and performance tuning parameters.
// encryption recipients, S3 storage configuration, and performance tuning parameters.
// Configuration is typically loaded from a YAML file.
type Config struct {
AgeRecipients []string `yaml:"age_recipients"`
@ -55,14 +28,6 @@ type Config struct {
S3 S3Config `yaml:"s3"`
SourceDirs []string `yaml:"source_dirs"`
CompressionLevel int `yaml:"compression_level"`
// StorageURL specifies the storage backend using a URL format.
// Takes precedence over S3Config if set.
// Supported formats:
// - s3://bucket/prefix?endpoint=host&region=us-east-1
// - file:///path/to/backup
// For S3 URLs, credentials are still read from s3.access_key_id and s3.secret_access_key.
StorageURL string `yaml:"storage_url"`
}
// S3Config represents S3 storage configuration for backup storage.
@ -119,7 +84,7 @@ func Load(path string) (*Config, error) {
BackupInterval: 1 * time.Hour,
FullScanInterval: 24 * time.Hour,
MinTimeBetweenRun: 15 * time.Minute,
IndexPath: filepath.Join(xdg.DataHome, appName, "index.sqlite"),
IndexPath: "/var/lib/vaultik/index.sqlite",
CompressionLevel: 3,
}
@ -134,16 +99,9 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("failed to parse config: %w", err)
}
// Expand tilde in all path fields
cfg.IndexPath = expandTilde(cfg.IndexPath)
cfg.StorageURL = expandTildeInURL(cfg.StorageURL)
for i, dir := range cfg.SourceDirs {
cfg.SourceDirs[i] = expandTilde(dir)
}
// Check for environment variable override for IndexPath
if envIndexPath := os.Getenv("VAULTIK_INDEX_PATH"); envIndexPath != "" {
cfg.IndexPath = expandTilde(envIndexPath)
cfg.IndexPath = envIndexPath
}
// Get hostname if not set
@ -174,7 +132,7 @@ func Load(path string) (*Config, error) {
// It ensures all required fields are present and have valid values:
// - At least one age recipient must be specified
// - At least one source directory must be configured
// - Storage must be configured (either storage_url or s3.* fields)
// - S3 credentials and endpoint must be provided
// - Chunk size must be at least 1MB
// - Blob size limit must be at least the chunk size
// - Compression level must be between 1 and 19
@ -188,9 +146,20 @@ func (c *Config) Validate() error {
return fmt.Errorf("at least one source directory is required")
}
// Validate storage configuration
if err := c.validateStorage(); err != nil {
return err
if c.S3.Endpoint == "" {
return fmt.Errorf("s3.endpoint is required")
}
if c.S3.Bucket == "" {
return fmt.Errorf("s3.bucket is required")
}
if c.S3.AccessKeyID == "" {
return fmt.Errorf("s3.access_key_id is required")
}
if c.S3.SecretAccessKey == "" {
return fmt.Errorf("s3.secret_access_key is required")
}
if c.ChunkSize.Int64() < 1024*1024 { // 1MB minimum
@ -208,50 +177,6 @@ func (c *Config) Validate() error {
return nil
}
// validateStorage validates storage configuration.
// If StorageURL is set, it takes precedence. S3 URLs require credentials.
// File URLs don't require any S3 configuration.
// If StorageURL is not set, legacy S3 configuration is required.
func (c *Config) validateStorage() error {
if c.StorageURL != "" {
// URL-based configuration
if strings.HasPrefix(c.StorageURL, "file://") {
// File storage doesn't need S3 credentials
return nil
}
if strings.HasPrefix(c.StorageURL, "s3://") {
// S3 storage needs credentials
if c.S3.AccessKeyID == "" {
return fmt.Errorf("s3.access_key_id is required for s3:// URLs")
}
if c.S3.SecretAccessKey == "" {
return fmt.Errorf("s3.secret_access_key is required for s3:// URLs")
}
return nil
}
return fmt.Errorf("storage_url must start with s3:// or file://")
}
// Legacy S3 configuration
if c.S3.Endpoint == "" {
return fmt.Errorf("s3.endpoint is required (or set storage_url)")
}
if c.S3.Bucket == "" {
return fmt.Errorf("s3.bucket is required (or set storage_url)")
}
if c.S3.AccessKeyID == "" {
return fmt.Errorf("s3.access_key_id is required")
}
if c.S3.SecretAccessKey == "" {
return fmt.Errorf("s3.secret_access_key is required")
}
return nil
}
// Module exports the config module for fx dependency injection.
// It provides the Config type to other modules in the application.
var Module = fx.Module("config",

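A brief sketch of how a caller would load configuration with the environment override described in Load above; the config file path and index path are placeholders, and calling Validate explicitly is an assumption (Load may already do so):

    // VAULTIK_INDEX_PATH, if set, overrides the configured IndexPath (see Load).
    _ = os.Setenv("VAULTIK_INDEX_PATH", "/tmp/vaultik-index.sqlite")
    cfg, err := config.Load("/etc/vaultik/config.yaml")
    if err != nil {
        return err
    }
    if err := cfg.Validate(); err != nil {
        return err
    }
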
View File

@ -1,108 +0,0 @@
// Package pidlock provides process-level locking using PID files.
// It prevents multiple instances of vaultik from running simultaneously,
// which would cause database locking conflicts.
package pidlock
import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
)
// ErrAlreadyRunning indicates another vaultik instance is running.
var ErrAlreadyRunning = errors.New("another vaultik instance is already running")
// Lock represents an acquired PID lock.
type Lock struct {
path string
}
// Acquire attempts to acquire a PID lock in the specified directory.
// If the lock file exists and the process is still running, it returns
// ErrAlreadyRunning with details about the existing process.
// On success, it writes the current PID to the lock file and returns
// a Lock that must be released with Release().
func Acquire(lockDir string) (*Lock, error) {
// Ensure lock directory exists
if err := os.MkdirAll(lockDir, 0700); err != nil {
return nil, fmt.Errorf("creating lock directory: %w", err)
}
lockPath := filepath.Join(lockDir, "vaultik.pid")
// Check for existing lock
existingPID, err := readPIDFile(lockPath)
if err == nil {
// Lock file exists, check if process is running
if isProcessRunning(existingPID) {
return nil, fmt.Errorf("%w (PID %d)", ErrAlreadyRunning, existingPID)
}
// Process is not running, stale lock file - we can take over
}
// Write our PID
pid := os.Getpid()
if err := os.WriteFile(lockPath, []byte(strconv.Itoa(pid)), 0600); err != nil {
return nil, fmt.Errorf("writing PID file: %w", err)
}
return &Lock{path: lockPath}, nil
}
// Release removes the PID lock file.
// It is safe to call Release multiple times.
func (l *Lock) Release() error {
if l == nil || l.path == "" {
return nil
}
// Verify we still own the lock (our PID is in the file)
existingPID, err := readPIDFile(l.path)
if err != nil {
// File already gone or unreadable - that's fine
return nil
}
if existingPID != os.Getpid() {
// Someone else wrote to our lock file - don't remove it
return nil
}
if err := os.Remove(l.path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("removing PID file: %w", err)
}
l.path = "" // Prevent double-release
return nil
}
// readPIDFile reads and parses the PID from a lock file.
func readPIDFile(path string) (int, error) {
data, err := os.ReadFile(path)
if err != nil {
return 0, err
}
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil {
return 0, fmt.Errorf("parsing PID: %w", err)
}
return pid, nil
}
// isProcessRunning checks if a process with the given PID is running.
func isProcessRunning(pid int) bool {
process, err := os.FindProcess(pid)
if err != nil {
return false
}
// On Unix, FindProcess always succeeds. We need to send signal 0 to check.
err = process.Signal(syscall.Signal(0))
return err == nil
}

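The intended usage of this package (shown being removed from RunWithApp earlier in this diff) is acquire-then-deferred-release; a minimal sketch, with the lock directory as a placeholder:

    lock, err := pidlock.Acquire("/var/lib/vaultik/locks")
    if err != nil {
        if errors.Is(err, pidlock.ErrAlreadyRunning) {
            return err // another instance already holds the lock
        }
        return fmt.Errorf("acquiring pid lock: %w", err)
    }
    defer func() {
        _ = lock.Release() // safe to call more than once
    }()
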
View File

@ -1,108 +0,0 @@
package pidlock
import (
"os"
"path/filepath"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAcquireAndRelease(t *testing.T) {
tmpDir := t.TempDir()
// Acquire lock
lock, err := Acquire(tmpDir)
require.NoError(t, err)
require.NotNil(t, lock)
// Verify PID file exists with our PID
data, err := os.ReadFile(filepath.Join(tmpDir, "vaultik.pid"))
require.NoError(t, err)
pid, err := strconv.Atoi(string(data))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
// Release lock
err = lock.Release()
require.NoError(t, err)
// Verify PID file is gone
_, err = os.Stat(filepath.Join(tmpDir, "vaultik.pid"))
assert.True(t, os.IsNotExist(err))
}
func TestAcquireBlocksSecondInstance(t *testing.T) {
tmpDir := t.TempDir()
// Acquire first lock
lock1, err := Acquire(tmpDir)
require.NoError(t, err)
require.NotNil(t, lock1)
defer func() { _ = lock1.Release() }()
// Try to acquire second lock - should fail
lock2, err := Acquire(tmpDir)
assert.ErrorIs(t, err, ErrAlreadyRunning)
assert.Nil(t, lock2)
}
func TestAcquireWithStaleLock(t *testing.T) {
tmpDir := t.TempDir()
// Write a stale PID file (PID that doesn't exist)
stalePID := 999999999 // Unlikely to be a real process
pidPath := filepath.Join(tmpDir, "vaultik.pid")
err := os.WriteFile(pidPath, []byte(strconv.Itoa(stalePID)), 0600)
require.NoError(t, err)
// Should be able to acquire lock (stale lock is cleaned up)
lock, err := Acquire(tmpDir)
require.NoError(t, err)
require.NotNil(t, lock)
defer func() { _ = lock.Release() }()
// Verify our PID is now in the file
data, err := os.ReadFile(pidPath)
require.NoError(t, err)
pid, err := strconv.Atoi(string(data))
require.NoError(t, err)
assert.Equal(t, os.Getpid(), pid)
}
func TestReleaseIsIdempotent(t *testing.T) {
tmpDir := t.TempDir()
lock, err := Acquire(tmpDir)
require.NoError(t, err)
// Release multiple times - should not error
err = lock.Release()
require.NoError(t, err)
err = lock.Release()
require.NoError(t, err)
}
func TestReleaseNilLock(t *testing.T) {
var lock *Lock
err := lock.Release()
assert.NoError(t, err)
}
func TestAcquireCreatesDirectory(t *testing.T) {
tmpDir := t.TempDir()
nestedDir := filepath.Join(tmpDir, "nested", "dir")
lock, err := Acquire(nestedDir)
require.NoError(t, err)
require.NotNil(t, lock)
defer func() { _ = lock.Release() }()
// Verify directory was created
info, err := os.Stat(nestedDir)
require.NoError(t, err)
assert.True(t, info.IsDir())
}

View File

@ -194,8 +194,8 @@ func TestMultipleFileChanges(t *testing.T) {
// First scan
result1, err := scanner.Scan(ctx, "/", snapshotID1)
require.NoError(t, err)
// Only regular files are counted, not directories
assert.Equal(t, 3, result1.FilesScanned)
// 4 files because root directory is also counted
assert.Equal(t, 4, result1.FilesScanned)
// Modify two files
time.Sleep(10 * time.Millisecond) // Ensure mtime changes
@ -221,8 +221,9 @@ func TestMultipleFileChanges(t *testing.T) {
result2, err := scanner.Scan(ctx, "/", snapshotID2)
require.NoError(t, err)
// Only regular files are counted, not directories
assert.Equal(t, 3, result2.FilesScanned)
// The scanner might examine more items than just our files (directories, etc.)
// We should verify that at least our expected files were scanned.
assert.GreaterOrEqual(t, result2.FilesScanned, 4, "Should scan at least 4 files (3 files + root dir)")
// Verify each file has exactly one set of chunks
for path := range files {

View File

@ -3,7 +3,7 @@ package snapshot
import (
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/s3"
"github.com/spf13/afero"
"go.uber.org/fx"
)
@ -27,13 +27,13 @@ var Module = fx.Module("backup",
// ScannerFactory creates scanners with custom parameters
type ScannerFactory func(params ScannerParams) *Scanner
func provideScannerFactory(cfg *config.Config, repos *database.Repositories, storer storage.Storer) ScannerFactory {
func provideScannerFactory(cfg *config.Config, repos *database.Repositories, s3Client *s3.Client) ScannerFactory {
return func(params ScannerParams) *Scanner {
return NewScanner(ScannerConfig{
FS: params.Fs,
ChunkSize: cfg.ChunkSize.Int64(),
Repositories: repos,
Storage: storer,
S3Client: s3Client,
MaxBlobSize: cfg.BlobSizeLimit.Int64(),
CompressionLevel: cfg.CompressionLevel,
AgeRecipients: cfg.AgeRecipients,

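A short sketch of how the factory provided above is meant to be consumed; only the Fs field of ScannerParams is visible in this diff, so any other fields are omitted:

    // factory is injected by the fx module above.
    func newScanner(factory snapshot.ScannerFactory) *snapshot.Scanner {
        return factory(snapshot.ScannerParams{
            Fs: afero.NewOsFs(), // back the scan with the real filesystem
        })
    }
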
View File

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"io"
"os"
"strings"
"sync"
@ -13,7 +14,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/chunker"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/s3"
"github.com/dustin/go-humanize"
"github.com/spf13/afero"
)
@ -31,17 +32,13 @@ type Scanner struct {
chunker *chunker.Chunker
packer *blob.Packer
repos *database.Repositories
storage storage.Storer
s3Client S3Client
maxBlobSize int64
compressionLevel int
ageRecipient string
snapshotID string // Current snapshot being processed
progress *ProgressReporter
// In-memory cache of known chunk hashes for fast existence checks
knownChunks map[string]struct{}
knownChunksMu sync.RWMutex
// Mutex for coordinating blob creation
packerMu sync.Mutex // Blocks chunk production during blob creation
@ -49,12 +46,19 @@ type Scanner struct {
scanCtx context.Context
}
// S3Client interface for blob storage operations
type S3Client interface {
PutObject(ctx context.Context, key string, data io.Reader) error
PutObjectWithProgress(ctx context.Context, key string, data io.Reader, size int64, progress s3.ProgressCallback) error
StatObject(ctx context.Context, key string) (*s3.ObjectInfo, error)
}
// ScannerConfig contains configuration for the scanner
type ScannerConfig struct {
FS afero.Fs
ChunkSize int64
Repositories *database.Repositories
Storage storage.Storer
S3Client S3Client
MaxBlobSize int64
CompressionLevel int
AgeRecipients []string // Optional, empty means no encryption
@ -107,7 +111,7 @@ func NewScanner(cfg ScannerConfig) *Scanner {
chunker: chunker.NewChunker(cfg.ChunkSize),
packer: packer,
repos: cfg.Repositories,
storage: cfg.Storage,
s3Client: cfg.S3Client,
maxBlobSize: cfg.MaxBlobSize,
compressionLevel: cfg.CompressionLevel,
ageRecipient: strings.Join(cfg.AgeRecipients, ","),
@ -124,11 +128,11 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
}
// Set blob handler for concurrent upload
if s.storage != nil {
log.Debug("Setting blob handler for storage uploads")
if s.s3Client != nil {
log.Debug("Setting blob handler for S3 uploads")
s.packer.SetBlobHandler(s.handleBlobReady)
} else {
log.Debug("No storage configured, blobs will not be uploaded")
log.Debug("No S3 client configured, blobs will not be uploaded")
}
// Start progress reporting if enabled
@ -137,41 +141,16 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
defer s.progress.Stop()
}
// Phase 0: Load known files and chunks from database into memory for fast lookup
fmt.Println("Loading known files from database...")
knownFiles, err := s.loadKnownFiles(ctx, path)
if err != nil {
return nil, fmt.Errorf("loading known files: %w", err)
}
fmt.Printf("Loaded %s known files from database\n", formatNumber(len(knownFiles)))
fmt.Println("Loading known chunks from database...")
if err := s.loadKnownChunks(ctx); err != nil {
return nil, fmt.Errorf("loading known chunks: %w", err)
}
fmt.Printf("Loaded %s known chunks from database\n", formatNumber(len(s.knownChunks)))
// Phase 1: Scan directory, collect files to process, and track existing files
// (builds existingFiles map during walk to avoid double traversal)
log.Info("Phase 1/3: Scanning directory structure")
existingFiles := make(map[string]struct{})
scanResult, err := s.scanPhase(ctx, path, result, existingFiles, knownFiles)
if err != nil {
return nil, fmt.Errorf("scan phase failed: %w", err)
}
filesToProcess := scanResult.FilesToProcess
// Phase 1b: Detect deleted files by comparing DB against scanned files
if err := s.detectDeletedFilesFromMap(ctx, knownFiles, existingFiles, result); err != nil {
// Phase 0: Check for deleted files from previous snapshots
if err := s.detectDeletedFiles(ctx, path, result); err != nil {
return nil, fmt.Errorf("detecting deleted files: %w", err)
}
// Phase 1c: Associate unchanged files with this snapshot (no new records needed)
if len(scanResult.UnchangedFileIDs) > 0 {
fmt.Printf("Associating %s unchanged files with snapshot...\n", formatNumber(len(scanResult.UnchangedFileIDs)))
if err := s.batchAddFilesToSnapshot(ctx, scanResult.UnchangedFileIDs); err != nil {
return nil, fmt.Errorf("associating unchanged files: %w", err)
}
// Phase 1: Scan directory and collect files to process
log.Info("Phase 1/3: Scanning directory structure")
filesToProcess, err := s.scanPhase(ctx, path, result)
if err != nil {
return nil, fmt.Errorf("scan phase failed: %w", err)
}
// Calculate total size to process
@ -237,83 +216,22 @@ func (s *Scanner) Scan(ctx context.Context, path string, snapshotID string) (*Sc
return result, nil
}
// loadKnownFiles loads all known files from the database into a map for fast lookup
// This avoids per-file database queries during the scan phase
func (s *Scanner) loadKnownFiles(ctx context.Context, path string) (map[string]*database.File, error) {
files, err := s.repos.Files.ListByPrefix(ctx, path)
if err != nil {
return nil, fmt.Errorf("listing files by prefix: %w", err)
}
result := make(map[string]*database.File, len(files))
for _, f := range files {
result[f.Path] = f
}
return result, nil
}
// loadKnownChunks loads all known chunk hashes from the database into a map for fast lookup
// This avoids per-chunk database queries during file processing
func (s *Scanner) loadKnownChunks(ctx context.Context) error {
chunks, err := s.repos.Chunks.List(ctx)
if err != nil {
return fmt.Errorf("listing chunks: %w", err)
}
s.knownChunksMu.Lock()
s.knownChunks = make(map[string]struct{}, len(chunks))
for _, c := range chunks {
s.knownChunks[c.ChunkHash] = struct{}{}
}
s.knownChunksMu.Unlock()
return nil
}
// chunkExists checks if a chunk hash exists in the in-memory cache
func (s *Scanner) chunkExists(hash string) bool {
s.knownChunksMu.RLock()
_, exists := s.knownChunks[hash]
s.knownChunksMu.RUnlock()
return exists
}
// addKnownChunk adds a chunk hash to the in-memory cache
func (s *Scanner) addKnownChunk(hash string) {
s.knownChunksMu.Lock()
s.knownChunks[hash] = struct{}{}
s.knownChunksMu.Unlock()
}
// ScanPhaseResult contains the results of the scan phase
type ScanPhaseResult struct {
FilesToProcess []*FileToProcess
UnchangedFileIDs []string // IDs of unchanged files to associate with snapshot
}
// scanPhase performs the initial directory scan to identify files to process
// It uses the pre-loaded knownFiles map for fast change detection without DB queries
// It also populates existingFiles map for deletion detection
// Returns files needing processing and IDs of unchanged files for snapshot association
func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult, existingFiles map[string]struct{}, knownFiles map[string]*database.File) (*ScanPhaseResult, error) {
// Use known file count as estimate for progress (accurate for subsequent backups)
estimatedTotal := int64(len(knownFiles))
func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult) ([]*FileToProcess, error) {
var filesToProcess []*FileToProcess
var unchangedFileIDs []string // Just IDs - no new records needed
var mu sync.Mutex
// Set up periodic status output
startTime := time.Now()
lastStatusTime := time.Now()
statusInterval := 15 * time.Second
var filesScanned int64
var bytesScanned int64
log.Debug("Starting directory walk", "path", path)
err := afero.Walk(s.fs, path, func(filePath string, info os.FileInfo, err error) error {
err := afero.Walk(s.fs, path, func(path string, info os.FileInfo, err error) error {
log.Debug("Scanning filesystem entry", "path", path)
if err != nil {
log.Debug("Error accessing filesystem entry", "path", filePath, "error", err)
log.Debug("Error accessing filesystem entry", "path", path, "error", err)
return err
}
@ -324,80 +242,42 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult
default:
}
// Skip non-regular files for processing (but still count them)
if !info.Mode().IsRegular() {
return nil
// Check file and update metadata
file, needsProcessing, err := s.checkFileAndUpdateMetadata(ctx, path, info, result)
if err != nil {
// Don't log context cancellation as an error
if err == context.Canceled {
return err
}
return fmt.Errorf("failed to check %s: %w", path, err)
}
// Track this file as existing (for deletion detection)
existingFiles[filePath] = struct{}{}
// Check file against in-memory map (no DB query!)
file, needsProcessing := s.checkFileInMemory(filePath, info, knownFiles)
mu.Lock()
if needsProcessing {
// New or changed file - will create record after processing
// If file needs processing, add to list
if needsProcessing && info.Mode().IsRegular() && info.Size() > 0 {
mu.Lock()
filesToProcess = append(filesToProcess, &FileToProcess{
Path: filePath,
Path: path,
FileInfo: info,
File: file,
})
} else if file.ID != "" {
// Unchanged file with existing ID - just need snapshot association
unchangedFileIDs = append(unchangedFileIDs, file.ID)
mu.Unlock()
}
filesScanned++
changedCount := len(filesToProcess)
mu.Unlock()
// Update result stats
if needsProcessing {
result.BytesScanned += info.Size()
} else {
result.FilesSkipped++
result.BytesSkipped += info.Size()
// Update scan statistics
if info.Mode().IsRegular() {
filesScanned++
bytesScanned += info.Size()
}
result.FilesScanned++
// Output periodic status
if time.Since(lastStatusTime) >= statusInterval {
elapsed := time.Since(startTime)
rate := float64(filesScanned) / elapsed.Seconds()
mu.Lock()
changedCount := len(filesToProcess)
mu.Unlock()
// Build status line - use estimate if available (not first backup)
if estimatedTotal > 0 {
// Show actual scanned vs estimate (may exceed estimate if files were added)
pct := float64(filesScanned) / float64(estimatedTotal) * 100
if pct > 100 {
pct = 100 // Cap at 100% for display
}
remaining := estimatedTotal - filesScanned
if remaining < 0 {
remaining = 0
}
var eta time.Duration
if rate > 0 && remaining > 0 {
eta = time.Duration(float64(remaining)/rate) * time.Second
}
fmt.Printf("Scan: %s files (~%.0f%%), %s changed/new, %.0f files/sec, %s elapsed",
formatNumber(int(filesScanned)),
pct,
formatNumber(changedCount),
rate,
elapsed.Round(time.Second))
if eta > 0 {
fmt.Printf(", ETA %s", eta.Round(time.Second))
}
fmt.Println()
} else {
// First backup - no estimate available
fmt.Printf("Scan: %s files, %s changed/new, %.0f files/sec, %s elapsed\n",
formatNumber(int(filesScanned)),
formatNumber(changedCount),
rate,
elapsed.Round(time.Second))
}
fmt.Printf("Scan progress: %s files examined, %s changed\n",
formatNumber(int(filesScanned)),
formatNumber(changedCount))
lastStatusTime = time.Now()
}
@ -408,129 +288,16 @@ func (s *Scanner) scanPhase(ctx context.Context, path string, result *ScanResult
return nil, err
}
return &ScanPhaseResult{
FilesToProcess: filesToProcess,
UnchangedFileIDs: unchangedFileIDs,
}, nil
}
// checkFileInMemory checks if a file needs processing using the in-memory map
// No database access is performed - this is purely CPU/memory work
func (s *Scanner) checkFileInMemory(path string, info os.FileInfo, knownFiles map[string]*database.File) (*database.File, bool) {
// Get file stats
stat, ok := info.Sys().(interface {
Uid() uint32
Gid() uint32
})
var uid, gid uint32
if ok {
uid = stat.Uid()
gid = stat.Gid()
}
// Create file record
file := &database.File{
Path: path,
MTime: info.ModTime(),
CTime: info.ModTime(), // afero doesn't provide ctime
Size: info.Size(),
Mode: uint32(info.Mode()),
UID: uid,
GID: gid,
}
// Check against in-memory map
existingFile, exists := knownFiles[path]
if !exists {
// New file
return file, true
}
// Reuse existing ID
file.ID = existingFile.ID
// Check if file has changed
if existingFile.Size != file.Size ||
existingFile.MTime.Unix() != file.MTime.Unix() ||
existingFile.Mode != file.Mode ||
existingFile.UID != file.UID ||
existingFile.GID != file.GID {
return file, true
}
// File unchanged
return file, false
}
// batchAddFilesToSnapshot adds existing file IDs to the snapshot association table
// This is used for unchanged files that already have records in the database
func (s *Scanner) batchAddFilesToSnapshot(ctx context.Context, fileIDs []string) error {
const batchSize = 1000
startTime := time.Now()
lastStatusTime := time.Now()
statusInterval := 5 * time.Second
for i := 0; i < len(fileIDs); i += batchSize {
// Check context cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
end := i + batchSize
if end > len(fileIDs) {
end = len(fileIDs)
}
batch := fileIDs[i:end]
err := s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
for _, fileID := range batch {
if err := s.repos.Snapshots.AddFileByID(ctx, tx, s.snapshotID, fileID); err != nil {
return fmt.Errorf("adding file to snapshot: %w", err)
}
}
return nil
})
if err != nil {
return err
}
// Periodic status
if time.Since(lastStatusTime) >= statusInterval {
elapsed := time.Since(startTime)
rate := float64(end) / elapsed.Seconds()
pct := float64(end) / float64(len(fileIDs)) * 100
fmt.Printf("Associating files: %s/%s (%.1f%%), %.0f files/sec\n",
formatNumber(end), formatNumber(len(fileIDs)), pct, rate)
lastStatusTime = time.Now()
}
}
elapsed := time.Since(startTime)
rate := float64(len(fileIDs)) / elapsed.Seconds()
fmt.Printf("Associated %s unchanged files in %s (%.0f files/sec)\n",
formatNumber(len(fileIDs)), elapsed.Round(time.Second), rate)
return nil
return filesToProcess, nil
}
// processPhase processes the files that need backing up
func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProcess, result *ScanResult) error {
// Calculate total bytes to process
var totalBytes int64
for _, f := range filesToProcess {
totalBytes += f.FileInfo.Size()
}
// Set up periodic status output
lastStatusTime := time.Now()
statusInterval := 15 * time.Second
startTime := time.Now()
filesProcessed := 0
var bytesProcessed int64
totalFiles := len(filesToProcess)
// Process each file
@ -551,33 +318,18 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc
}
filesProcessed++
bytesProcessed += fileToProcess.FileInfo.Size()
// Output periodic status
if time.Since(lastStatusTime) >= statusInterval {
elapsed := time.Since(startTime)
pct := float64(bytesProcessed) / float64(totalBytes) * 100
byteRate := float64(bytesProcessed) / elapsed.Seconds()
fileRate := float64(filesProcessed) / elapsed.Seconds()
// Calculate ETA based on bytes (more accurate than files)
remainingBytes := totalBytes - bytesProcessed
remaining := totalFiles - filesProcessed
var eta time.Duration
if byteRate > 0 {
eta = time.Duration(float64(remainingBytes)/byteRate) * time.Second
if filesProcessed > 0 {
eta = elapsed / time.Duration(filesProcessed) * time.Duration(remaining)
}
// Format: Progress [5.7k/610k] 6.7 GB/44 GB (15.4%), 106MB/sec, 500 files/sec, running for 1m30s, ETA: 5m49s
fmt.Printf("Progress [%s/%s] %s/%s (%.1f%%), %s/sec, %.0f files/sec, running for %s",
formatCompact(filesProcessed),
formatCompact(totalFiles),
humanize.Bytes(uint64(bytesProcessed)),
humanize.Bytes(uint64(totalBytes)),
pct,
humanize.Bytes(uint64(byteRate)),
fileRate,
elapsed.Round(time.Second))
if eta > 0 {
fmt.Printf("Progress: %s/%s files", formatNumber(filesProcessed), formatNumber(totalFiles))
if remaining > 0 && eta > 0 {
fmt.Printf(", ETA: %s", eta.Round(time.Second))
}
fmt.Println()
@ -593,8 +345,8 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc
}
s.packerMu.Unlock()
// If no storage configured, store any remaining blobs locally
if s.storage == nil {
// If no S3 client, store any remaining blobs
if s.s3Client == nil {
blobs := s.packer.GetFinishedBlobs()
for _, b := range blobs {
// Blob metadata is already stored incrementally during packing
@ -612,6 +364,205 @@ func (s *Scanner) processPhase(ctx context.Context, filesToProcess []*FileToProc
return nil
}
// checkFileAndUpdateMetadata checks if a file needs processing and updates metadata
func (s *Scanner) checkFileAndUpdateMetadata(ctx context.Context, path string, info os.FileInfo, result *ScanResult) (*database.File, bool, error) {
// Check context cancellation
select {
case <-ctx.Done():
return nil, false, ctx.Err()
default:
}
// Process file without holding a long transaction
return s.checkFile(ctx, path, info, result)
}
// checkFile checks if a file needs processing and updates metadata
func (s *Scanner) checkFile(ctx context.Context, path string, info os.FileInfo, result *ScanResult) (*database.File, bool, error) {
// Get file stats
stat, ok := info.Sys().(interface {
Uid() uint32
Gid() uint32
})
var uid, gid uint32
if ok {
uid = stat.Uid()
gid = stat.Gid()
}
// Check if it's a symlink
var linkTarget string
if info.Mode()&os.ModeSymlink != 0 {
// Read the symlink target
if linker, ok := s.fs.(afero.LinkReader); ok {
linkTarget, _ = linker.ReadlinkIfPossible(path)
}
}
// Create file record
file := &database.File{
Path: path,
MTime: info.ModTime(),
CTime: info.ModTime(), // afero doesn't provide ctime
Size: info.Size(),
Mode: uint32(info.Mode()),
UID: uid,
GID: gid,
LinkTarget: linkTarget,
}
// Check if file has changed since last backup (no transaction needed for read)
log.Debug("Querying database for existing file record", "path", path)
existingFile, err := s.repos.Files.GetByPath(ctx, path)
if err != nil {
return nil, false, fmt.Errorf("checking existing file: %w", err)
}
fileChanged := existingFile == nil || s.hasFileChanged(existingFile, file)
// Update file metadata and add to snapshot in a single transaction
log.Debug("Updating file record in database and adding to snapshot", "path", path, "changed", fileChanged, "snapshot", s.snapshotID)
err = s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
// First create/update the file
if err := s.repos.Files.Create(ctx, tx, file); err != nil {
return fmt.Errorf("creating file: %w", err)
}
// Then add it to the snapshot using the file ID
if err := s.repos.Snapshots.AddFileByID(ctx, tx, s.snapshotID, file.ID); err != nil {
return fmt.Errorf("adding file to snapshot: %w", err)
}
return nil
})
if err != nil {
return nil, false, err
}
log.Debug("File record added to snapshot association", "path", path)
result.FilesScanned++
// Update progress
if s.progress != nil {
stats := s.progress.GetStats()
stats.FilesScanned.Add(1)
stats.CurrentFile.Store(path)
}
// Track skipped files
if info.Mode().IsRegular() && info.Size() > 0 && !fileChanged {
result.FilesSkipped++
result.BytesSkipped += info.Size()
if s.progress != nil {
stats := s.progress.GetStats()
stats.FilesSkipped.Add(1)
stats.BytesSkipped.Add(info.Size())
}
// File hasn't changed, but we still need to associate existing chunks with this snapshot
log.Debug("File content unchanged, reusing existing chunks and blobs", "path", path)
if err := s.associateExistingChunks(ctx, path); err != nil {
return nil, false, fmt.Errorf("associating existing chunks: %w", err)
}
log.Debug("Existing chunks and blobs associated with snapshot", "path", path)
} else {
// File changed or is not a regular file
result.BytesScanned += info.Size()
if s.progress != nil {
s.progress.GetStats().BytesScanned.Add(info.Size())
}
}
return file, fileChanged, nil
}
// hasFileChanged determines if a file has changed since last backup
func (s *Scanner) hasFileChanged(existingFile, newFile *database.File) bool {
// Check if any metadata has changed
if existingFile.Size != newFile.Size {
return true
}
if existingFile.MTime.Unix() != newFile.MTime.Unix() {
return true
}
if existingFile.Mode != newFile.Mode {
return true
}
if existingFile.UID != newFile.UID {
return true
}
if existingFile.GID != newFile.GID {
return true
}
if existingFile.LinkTarget != newFile.LinkTarget {
return true
}
return false
}
// associateExistingChunks links existing chunks from an unchanged file to the current snapshot
func (s *Scanner) associateExistingChunks(ctx context.Context, path string) error {
log.Debug("associateExistingChunks start", "path", path)
// Get existing file chunks (no transaction needed for read)
log.Debug("Querying database for file's chunk associations", "path", path)
fileChunks, err := s.repos.FileChunks.GetByFile(ctx, path)
if err != nil {
return fmt.Errorf("getting existing file chunks: %w", err)
}
log.Debug("Retrieved file chunk associations from database", "path", path, "count", len(fileChunks))
// Collect unique blob IDs that need to be added to snapshot
blobsToAdd := make(map[string]string) // blob ID -> blob hash
for i, fc := range fileChunks {
log.Debug("Looking up blob containing chunk", "path", path, "chunk_index", i, "chunk_hash", fc.ChunkHash)
// Find which blob contains this chunk (no transaction needed for read)
log.Debug("Querying database for blob containing chunk", "chunk_hash", fc.ChunkHash)
blobChunk, err := s.repos.BlobChunks.GetByChunkHash(ctx, fc.ChunkHash)
if err != nil {
return fmt.Errorf("finding blob for chunk %s: %w", fc.ChunkHash, err)
}
if blobChunk == nil {
log.Warn("Chunk record exists in database but not associated with any blob", "chunk", fc.ChunkHash, "file", path)
continue
}
log.Debug("Found blob record containing chunk", "chunk_hash", fc.ChunkHash, "blob_id", blobChunk.BlobID)
// Track blob ID for later processing
if _, exists := blobsToAdd[blobChunk.BlobID]; !exists {
blobsToAdd[blobChunk.BlobID] = "" // We'll get the hash later
}
}
// Now get blob hashes outside of transaction operations
for blobID := range blobsToAdd {
blob, err := s.repos.Blobs.GetByID(ctx, blobID)
if err != nil {
return fmt.Errorf("getting blob %s: %w", blobID, err)
}
if blob == nil {
log.Warn("Blob record missing from database", "blob_id", blobID)
delete(blobsToAdd, blobID)
continue
}
blobsToAdd[blobID] = blob.Hash
}
// Add blobs to snapshot using short transactions
for blobID, blobHash := range blobsToAdd {
log.Debug("Adding blob reference to snapshot association", "blob_id", blobID, "blob_hash", blobHash, "snapshot", s.snapshotID)
err := s.repos.WithTx(ctx, func(ctx context.Context, tx *sql.Tx) error {
return s.repos.Snapshots.AddBlob(ctx, tx, s.snapshotID, blobID, blobHash)
})
if err != nil {
return fmt.Errorf("adding existing blob to snapshot: %w", err)
}
log.Debug("Created snapshot-blob association in database", "blob_id", blobID)
}
log.Debug("associateExistingChunks complete", "path", path, "blobs_processed", len(blobsToAdd))
return nil
}
// handleBlobReady is called by the packer when a blob is finalized
func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
startTime := time.Now().UTC()
@ -622,7 +573,7 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
s.progress.ReportUploadStart(finishedBlob.Hash, finishedBlob.Compressed)
}
// Upload to storage first (without holding any locks)
// Upload to S3 first (without holding any locks)
// Use scan context for cancellation support
ctx := s.scanCtx
if ctx == nil {
@ -634,6 +585,7 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
lastProgressBytes := int64(0)
progressCallback := func(uploaded int64) error {
// Calculate instantaneous speed
now := time.Now()
elapsed := now.Sub(lastProgressTime).Seconds()
@ -660,29 +612,19 @@ func (s *Scanner) handleBlobReady(blobWithReader *blob.BlobWithReader) error {
// Create sharded path: blobs/ca/fe/cafebabe...
blobPath := fmt.Sprintf("blobs/%s/%s/%s", finishedBlob.Hash[:2], finishedBlob.Hash[2:4], finishedBlob.Hash)
if err := s.storage.PutWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil {
return fmt.Errorf("uploading blob %s to storage: %w", finishedBlob.Hash, err)
if err := s.s3Client.PutObjectWithProgress(ctx, blobPath, blobWithReader.Reader, finishedBlob.Compressed, progressCallback); err != nil {
return fmt.Errorf("uploading blob %s to S3: %w", finishedBlob.Hash, err)
}
uploadDuration := time.Since(startTime)
// Calculate upload speed
uploadSpeedBps := float64(finishedBlob.Compressed) / uploadDuration.Seconds()
// Print blob stored message
fmt.Printf("Blob stored: %s (%s, %s/sec, %s)\n",
finishedBlob.Hash[:12]+"...",
humanize.Bytes(uint64(finishedBlob.Compressed)),
humanize.Bytes(uint64(uploadSpeedBps)),
uploadDuration.Round(time.Millisecond))
// Log upload stats
uploadSpeedBits := uploadSpeedBps * 8 // bits per second
log.Info("Successfully uploaded blob to storage",
uploadSpeed := float64(finishedBlob.Compressed) * 8 / uploadDuration.Seconds() // bits per second
log.Info("Successfully uploaded blob to S3 storage",
"path", blobPath,
"size", humanize.Bytes(uint64(finishedBlob.Compressed)),
"duration", uploadDuration,
"speed", humanize.SI(uploadSpeedBits, "bps"))
"speed", humanize.SI(uploadSpeed, "bps"))
// Report upload complete
if s.progress != nil {
@ -776,8 +718,12 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
"hash", chunk.Hash,
"size", chunk.Size)
// Check if chunk already exists (fast in-memory lookup)
chunkExists := s.chunkExists(chunk.Hash)
// Check if chunk already exists (outside of transaction)
existing, err := s.repos.Chunks.GetByHash(ctx, chunk.Hash)
if err != nil {
return fmt.Errorf("checking chunk existence: %w", err)
}
chunkExists := (existing != nil)
// Store chunk if new
if !chunkExists {
@ -794,8 +740,6 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
if err != nil {
return fmt.Errorf("storing chunk: %w", err)
}
// Add to in-memory cache for fast duplicate detection
s.addKnownChunk(chunk.Hash)
}
// Track file chunk association for later storage
@ -871,16 +815,9 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
"file_hash", fileHash,
"chunks", len(chunks))
// Store file record, chunk associations, and snapshot association in database
// This happens AFTER successful chunking to avoid orphaned records on interruption
// Store file-chunk associations and chunk-file mappings in database
err = s.repos.WithTx(ctx, func(txCtx context.Context, tx *sql.Tx) error {
// Create or update the file record
// Files.Create uses INSERT OR REPLACE, so it handles both new and changed files
if err := s.repos.Files.Create(txCtx, tx, fileToProcess.File); err != nil {
return fmt.Errorf("creating file record: %w", err)
}
// Delete any existing file_chunks and chunk_files for this file
// First, delete all existing file_chunks and chunk_files for this file
// This ensures old chunks are no longer associated when file content changes
if err := s.repos.FileChunks.DeleteByFileID(txCtx, tx, fileToProcess.File.ID); err != nil {
return fmt.Errorf("deleting old file chunks: %w", err)
@ -889,11 +826,6 @@ func (s *Scanner) processFileStreaming(ctx context.Context, fileToProcess *FileT
return fmt.Errorf("deleting old chunk files: %w", err)
}
// Update chunk associations with the file ID (now that we have it)
for i := range chunks {
chunks[i].fileChunk.FileID = fileToProcess.File.ID
}
for _, ci := range chunks {
// Create file-chunk mapping
if err := s.repos.FileChunks.Create(txCtx, tx, &ci.fileChunk); err != nil {
@ -928,35 +860,25 @@ func (s *Scanner) GetProgress() *ProgressReporter {
return s.progress
}
// detectDeletedFilesFromMap finds files that existed in previous snapshots but no longer exist
// Uses pre-loaded maps to avoid any filesystem or database access
func (s *Scanner) detectDeletedFilesFromMap(ctx context.Context, knownFiles map[string]*database.File, existingFiles map[string]struct{}, result *ScanResult) error {
if len(knownFiles) == 0 {
return nil
// detectDeletedFiles finds files that existed in previous snapshots but no longer exist
func (s *Scanner) detectDeletedFiles(ctx context.Context, path string, result *ScanResult) error {
// Get all files with this path prefix from the database
files, err := s.repos.Files.ListByPrefix(ctx, path)
if err != nil {
return fmt.Errorf("listing files by prefix: %w", err)
}
// Check each known file against the enumerated set (no filesystem access needed)
for path, file := range knownFiles {
// Check context cancellation periodically
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Check if the file exists in our enumerated set
if _, exists := existingFiles[path]; !exists {
for _, file := range files {
// Check if the file still exists on disk
_, err := s.fs.Stat(file.Path)
if os.IsNotExist(err) {
// File has been deleted
result.FilesDeleted++
result.BytesDeleted += file.Size
log.Debug("Detected deleted file", "path", path, "size", file.Size)
log.Debug("Detected deleted file", "path", file.Path, "size", file.Size)
}
}
if result.FilesDeleted > 0 {
fmt.Printf("Found %s deleted files\n", formatNumber(result.FilesDeleted))
}
return nil
}
@ -967,17 +889,3 @@ func formatNumber(n int) string {
}
return humanize.Comma(int64(n))
}
// formatCompact formats a number compactly with k/M suffixes (e.g., 5.7k, 1.2M)
func formatCompact(n int) string {
if n < 1000 {
return fmt.Sprintf("%d", n)
}
if n < 10000 {
return fmt.Sprintf("%.1fk", float64(n)/1000)
}
if n < 1000000 {
return fmt.Sprintf("%.0fk", float64(n)/1000)
}
return fmt.Sprintf("%.1fM", float64(n)/1000000)
}

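Because the scanner now depends on the small S3Client interface declared near the top of this file rather than on a concrete client, it can be exercised without a real bucket. A sketch of a no-op test double satisfying that interface; the imports (context, fmt, io, and the s3 package) and the choice to signal "not found" with a plain error are assumptions:

    // nopS3Client discards uploads and reports every object as absent.
    type nopS3Client struct{}

    func (nopS3Client) PutObject(ctx context.Context, key string, data io.Reader) error {
        _, err := io.Copy(io.Discard, data) // drain the reader as a real upload would
        return err
    }

    func (nopS3Client) PutObjectWithProgress(ctx context.Context, key string, data io.Reader, size int64, progress s3.ProgressCallback) error {
        n, err := io.Copy(io.Discard, data)
        if err == nil && progress != nil {
            return progress(n) // report completion once
        }
        return err
    }

    func (nopS3Client) StatObject(ctx context.Context, key string) (*s3.ObjectInfo, error) {
        return nil, fmt.Errorf("object %q not found", key) // assumption: callers treat an error as "absent"
    }
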
View File

@ -99,25 +99,26 @@ func TestScannerSimpleDirectory(t *testing.T) {
t.Fatalf("scan failed: %v", err)
}
// Verify results - we only scan regular files, not directories
if result.FilesScanned != 6 {
t.Errorf("expected 6 files scanned, got %d", result.FilesScanned)
// Verify results
// We now scan 6 files + 3 directories (source, subdir, subdir2) = 9 entries
if result.FilesScanned != 9 {
t.Errorf("expected 9 entries scanned, got %d", result.FilesScanned)
}
// Total bytes should be the sum of all file contents
// Directories have their own sizes, so the total will be more than just file content
if result.BytesScanned < 97 { // At minimum we have 97 bytes of file content
t.Errorf("expected at least 97 bytes scanned, got %d", result.BytesScanned)
}
// Verify files in database - only regular files are stored
// Verify files in database
files, err := repos.Files.ListByPrefix(ctx, "/source")
if err != nil {
t.Fatalf("failed to list files: %v", err)
}
// We should have 6 files (directories are not stored)
if len(files) != 6 {
t.Errorf("expected 6 files in database, got %d", len(files))
// We should have 6 files + 3 directories = 9 entries
if len(files) != 9 {
t.Errorf("expected 9 entries in database, got %d", len(files))
}
// Verify specific file
@ -234,9 +235,9 @@ func TestScannerLargeFile(t *testing.T) {
t.Fatalf("scan failed: %v", err)
}
// We scan only regular files, not directories
if result.FilesScanned != 1 {
t.Errorf("expected 1 file scanned, got %d", result.FilesScanned)
// We scan 1 file + 1 directory = 2 entries
if result.FilesScanned != 2 {
t.Errorf("expected 2 entries scanned, got %d", result.FilesScanned)
}
// The file size should be at least 1MB

View File

@ -52,7 +52,7 @@ import (
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/storage"
"git.eeqj.de/sneak/vaultik/internal/s3"
"github.com/dustin/go-humanize"
"github.com/spf13/afero"
"go.uber.org/fx"
@ -60,27 +60,27 @@ import (
// SnapshotManager handles snapshot creation and metadata export
type SnapshotManager struct {
repos *database.Repositories
storage storage.Storer
config *config.Config
fs afero.Fs
repos *database.Repositories
s3Client S3Client
config *config.Config
fs afero.Fs
}
// SnapshotManagerParams holds dependencies for NewSnapshotManager
type SnapshotManagerParams struct {
fx.In
Repos *database.Repositories
Storage storage.Storer
Config *config.Config
Repos *database.Repositories
S3Client *s3.Client
Config *config.Config
}
// NewSnapshotManager creates a new snapshot manager for dependency injection
func NewSnapshotManager(params SnapshotManagerParams) *SnapshotManager {
return &SnapshotManager{
repos: params.Repos,
storage: params.Storage,
config: params.Config,
repos: params.Repos,
s3Client: params.S3Client,
config: params.Config,
}
}
@ -268,7 +268,7 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
dbKey := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
dbUploadStart := time.Now()
if err := sm.storage.Put(ctx, dbKey, bytes.NewReader(finalData)); err != nil {
if err := sm.s3Client.PutObject(ctx, dbKey, bytes.NewReader(finalData)); err != nil {
return fmt.Errorf("uploading snapshot database: %w", err)
}
dbUploadDuration := time.Since(dbUploadStart)
@ -282,7 +282,7 @@ func (sm *SnapshotManager) ExportSnapshotMetadata(ctx context.Context, dbPath st
// Upload blob manifest (compressed only, not encrypted)
manifestKey := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
manifestUploadStart := time.Now()
if err := sm.storage.Put(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil {
if err := sm.s3Client.PutObject(ctx, manifestKey, bytes.NewReader(blobManifest)); err != nil {
return fmt.Errorf("uploading blob manifest: %w", err)
}
manifestUploadDuration := time.Since(manifestUploadStart)
@ -635,11 +635,11 @@ func (sm *SnapshotManager) CleanupIncompleteSnapshots(ctx context.Context, hostn
log.Info("Found incomplete snapshots", "count", len(incompleteSnapshots))
// Check each incomplete snapshot for metadata in storage
// Check each incomplete snapshot for metadata in S3
for _, snapshot := range incompleteSnapshots {
// Check if metadata exists in storage
// Check if metadata exists in S3
metadataKey := fmt.Sprintf("metadata/%s/db.zst", snapshot.ID)
_, err := sm.storage.Stat(ctx, metadataKey)
_, err := sm.s3Client.StatObject(ctx, metadataKey)
if err != nil {
// Metadata doesn't exist in S3 - this is an incomplete snapshot
@ -676,11 +676,6 @@ func (sm *SnapshotManager) deleteSnapshot(ctx context.Context, snapshotID string
return fmt.Errorf("deleting snapshot blobs: %w", err)
}
// Delete uploads entries (has foreign key to snapshots without CASCADE)
if err := sm.repos.Snapshots.DeleteSnapshotUploads(ctx, snapshotID); err != nil {
return fmt.Errorf("deleting snapshot uploads: %w", err)
}
// Delete the snapshot itself
if err := sm.repos.Snapshots.Delete(ctx, snapshotID); err != nil {
return fmt.Errorf("deleting snapshot: %w", err)

View File

@ -1,262 +0,0 @@
package storage
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/spf13/afero"
)
// FileStorer implements Storer using the local filesystem.
// It mirrors the S3 path structure for consistency.
type FileStorer struct {
fs afero.Fs
basePath string
}
// NewFileStorer creates a new filesystem storage backend.
// The basePath directory will be created if it doesn't exist.
// Uses the real OS filesystem by default; call SetFilesystem to override for testing.
func NewFileStorer(basePath string) (*FileStorer, error) {
fs := afero.NewOsFs()
// Ensure base path exists
if err := fs.MkdirAll(basePath, 0755); err != nil {
return nil, fmt.Errorf("creating base path: %w", err)
}
return &FileStorer{
fs: fs,
basePath: basePath,
}, nil
}
// SetFilesystem overrides the filesystem for testing.
func (f *FileStorer) SetFilesystem(fs afero.Fs) {
f.fs = fs
}
// fullPath returns the full filesystem path for a key.
func (f *FileStorer) fullPath(key string) string {
return filepath.Join(f.basePath, key)
}
// Put stores data at the specified key.
func (f *FileStorer) Put(ctx context.Context, key string, data io.Reader) error {
path := f.fullPath(key)
// Create parent directories
dir := filepath.Dir(path)
if err := f.fs.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("creating directories: %w", err)
}
file, err := f.fs.Create(path)
if err != nil {
return fmt.Errorf("creating file: %w", err)
}
defer func() { _ = file.Close() }()
if _, err := io.Copy(file, data); err != nil {
return fmt.Errorf("writing file: %w", err)
}
return nil
}
// PutWithProgress stores data with progress reporting.
func (f *FileStorer) PutWithProgress(ctx context.Context, key string, data io.Reader, size int64, progress ProgressCallback) error {
path := f.fullPath(key)
// Create parent directories
dir := filepath.Dir(path)
if err := f.fs.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("creating directories: %w", err)
}
file, err := f.fs.Create(path)
if err != nil {
return fmt.Errorf("creating file: %w", err)
}
defer func() { _ = file.Close() }()
// Wrap with progress tracking
pw := &progressWriter{
writer: file,
callback: progress,
}
if _, err := io.Copy(pw, data); err != nil {
return fmt.Errorf("writing file: %w", err)
}
return nil
}
// Get retrieves data from the specified key.
func (f *FileStorer) Get(ctx context.Context, key string) (io.ReadCloser, error) {
path := f.fullPath(key)
file, err := f.fs.Open(path)
if err != nil {
if os.IsNotExist(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("opening file: %w", err)
}
return file, nil
}
// Stat returns metadata about an object without retrieving its contents.
func (f *FileStorer) Stat(ctx context.Context, key string) (*ObjectInfo, error) {
path := f.fullPath(key)
info, err := f.fs.Stat(path)
if err != nil {
if os.IsNotExist(err) {
return nil, ErrNotFound
}
return nil, fmt.Errorf("stat file: %w", err)
}
return &ObjectInfo{
Key: key,
Size: info.Size(),
}, nil
}
// Delete removes an object.
func (f *FileStorer) Delete(ctx context.Context, key string) error {
path := f.fullPath(key)
err := f.fs.Remove(path)
if os.IsNotExist(err) {
		return nil // Match S3 behavior: no error if the object doesn't exist
}
if err != nil {
return fmt.Errorf("removing file: %w", err)
}
return nil
}
// List returns all keys with the given prefix.
func (f *FileStorer) List(ctx context.Context, prefix string) ([]string, error) {
var keys []string
basePath := f.fullPath(prefix)
// Check if base path exists
exists, err := afero.Exists(f.fs, basePath)
if err != nil {
return nil, fmt.Errorf("checking path: %w", err)
}
if !exists {
return keys, nil // Empty list for non-existent prefix
}
err = afero.Walk(f.fs, basePath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Check context cancellation
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if !info.IsDir() {
// Convert back to key (relative path from basePath)
relPath, err := filepath.Rel(f.basePath, path)
if err != nil {
return fmt.Errorf("computing relative path: %w", err)
}
// Normalize path separators to forward slashes for consistency
relPath = strings.ReplaceAll(relPath, string(filepath.Separator), "/")
keys = append(keys, relPath)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("walking directory: %w", err)
}
return keys, nil
}
// ListStream returns a channel of ObjectInfo for large result sets.
func (f *FileStorer) ListStream(ctx context.Context, prefix string) <-chan ObjectInfo {
ch := make(chan ObjectInfo)
go func() {
defer close(ch)
basePath := f.fullPath(prefix)
// Check if base path exists
exists, err := afero.Exists(f.fs, basePath)
if err != nil {
ch <- ObjectInfo{Err: fmt.Errorf("checking path: %w", err)}
return
}
if !exists {
return // Empty channel for non-existent prefix
}
_ = afero.Walk(f.fs, basePath, func(path string, info os.FileInfo, err error) error {
// Check context cancellation
select {
case <-ctx.Done():
ch <- ObjectInfo{Err: ctx.Err()}
return ctx.Err()
default:
}
if err != nil {
ch <- ObjectInfo{Err: err}
return nil // Continue walking despite errors
}
if !info.IsDir() {
relPath, err := filepath.Rel(f.basePath, path)
if err != nil {
ch <- ObjectInfo{Err: fmt.Errorf("computing relative path: %w", err)}
return nil
}
// Normalize path separators
relPath = strings.ReplaceAll(relPath, string(filepath.Separator), "/")
ch <- ObjectInfo{
Key: relPath,
Size: info.Size(),
}
}
return nil
})
}()
return ch
}
// Info returns human-readable storage location information.
func (f *FileStorer) Info() StorageInfo {
return StorageInfo{
Type: "file",
Location: f.basePath,
}
}
// progressWriter wraps an io.Writer to track write progress.
type progressWriter struct {
writer io.Writer
written int64
callback ProgressCallback
}
func (pw *progressWriter) Write(p []byte) (int, error) {
n, err := pw.writer.Write(p)
if n > 0 {
pw.written += int64(n)
if pw.callback != nil {
if callbackErr := pw.callback(pw.written); callbackErr != nil {
return n, callbackErr
}
}
}
return n, err
}
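
The progress reporting above works by wrapping the destination file in a writer that counts bytes and invokes the callback on every Write; an error returned from the callback aborts the copy. A minimal standalone sketch of the same pattern, stdlib only (all names here are illustrative and not part of this repository):

package main

import (
	"fmt"
	"io"
	"os"
	"strings"
)

// countingWriter mirrors the progressWriter above: it wraps an io.Writer,
// keeps a running total of bytes written, and reports it via a callback.
// A non-nil error from the callback aborts the surrounding io.Copy.
type countingWriter struct {
	w       io.Writer
	written int64
	report  func(total int64) error
}

func (c *countingWriter) Write(p []byte) (int, error) {
	n, err := c.w.Write(p)
	if n > 0 {
		c.written += int64(n)
		if c.report != nil {
			if cbErr := c.report(c.written); cbErr != nil {
				return n, cbErr
			}
		}
	}
	return n, err
}

func main() {
	src := strings.NewReader(strings.Repeat("x", 1<<20)) // 1 MiB of dummy data
	dst := &countingWriter{
		w: io.Discard,
		report: func(total int64) error {
			fmt.Fprintf(os.Stderr, "\rtransferred %d bytes", total)
			return nil
		},
	}
	if _, err := io.Copy(dst, src); err != nil {
		fmt.Fprintln(os.Stderr, "\ncopy failed:", err)
		os.Exit(1)
	}
	fmt.Fprintln(os.Stderr)
}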

View File

@ -1,110 +0,0 @@
package storage
import (
"context"
"fmt"
"strings"
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/s3"
"go.uber.org/fx"
)
// Module exports storage functionality as an fx module.
// It provides a Storer implementation based on the configured storage URL
// or falls back to legacy S3 configuration.
var Module = fx.Module("storage",
fx.Provide(NewStorer),
)
// NewStorer creates a Storer based on configuration.
// If StorageURL is set, it uses URL-based configuration.
// Otherwise, it falls back to legacy S3 configuration.
func NewStorer(cfg *config.Config) (Storer, error) {
if cfg.StorageURL != "" {
return storerFromURL(cfg.StorageURL, cfg)
}
return storerFromLegacyS3Config(cfg)
}
func storerFromURL(rawURL string, cfg *config.Config) (Storer, error) {
parsed, err := ParseStorageURL(rawURL)
if err != nil {
return nil, fmt.Errorf("parsing storage URL: %w", err)
}
switch parsed.Scheme {
case "file":
return NewFileStorer(parsed.Prefix)
case "s3":
// Build endpoint URL
endpoint := parsed.Endpoint
if endpoint == "" {
endpoint = "s3.amazonaws.com"
}
// Add protocol if not present
if parsed.UseSSL && !strings.HasPrefix(endpoint, "https://") && !strings.HasPrefix(endpoint, "http://") {
endpoint = "https://" + endpoint
} else if !parsed.UseSSL && !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
endpoint = "http://" + endpoint
}
region := parsed.Region
if region == "" {
region = cfg.S3.Region
if region == "" {
region = "us-east-1"
}
}
		// Credentials come from config, never from the URL, for security reasons
client, err := s3.NewClient(context.Background(), s3.Config{
Endpoint: endpoint,
Bucket: parsed.Bucket,
Prefix: parsed.Prefix,
AccessKeyID: cfg.S3.AccessKeyID,
SecretAccessKey: cfg.S3.SecretAccessKey,
Region: region,
})
if err != nil {
return nil, fmt.Errorf("creating S3 client: %w", err)
}
return NewS3Storer(client), nil
default:
return nil, fmt.Errorf("unsupported storage scheme: %s", parsed.Scheme)
}
}
func storerFromLegacyS3Config(cfg *config.Config) (Storer, error) {
endpoint := cfg.S3.Endpoint
// Ensure protocol is present
if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
if cfg.S3.UseSSL {
endpoint = "https://" + endpoint
} else {
endpoint = "http://" + endpoint
}
}
region := cfg.S3.Region
if region == "" {
region = "us-east-1"
}
client, err := s3.NewClient(context.Background(), s3.Config{
Endpoint: endpoint,
Bucket: cfg.S3.Bucket,
Prefix: cfg.S3.Prefix,
AccessKeyID: cfg.S3.AccessKeyID,
SecretAccessKey: cfg.S3.SecretAccessKey,
Region: region,
})
if err != nil {
return nil, fmt.Errorf("creating S3 client: %w", err)
}
return NewS3Storer(client), nil
}
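
NewStorer above dispatches on cfg.StorageURL and only falls back to the legacy S3 fields when it is unset. A hedged sketch of exercising the file branch; it would only compile inside this module (the packages are internal) and the path is purely illustrative:

package main

import (
	"fmt"
	"log"

	"git.eeqj.de/sneak/vaultik/internal/config"
	"git.eeqj.de/sneak/vaultik/internal/storage"
)

func main() {
	// With StorageURL set, NewStorer ignores the legacy S3 fields and
	// builds a FileStorer rooted at the given path.
	cfg := &config.Config{StorageURL: "file:///var/backups/vaultik"}

	store, err := storage.NewStorer(cfg)
	if err != nil {
		log.Fatalf("selecting backend: %v", err)
	}

	info := store.Info()
	fmt.Printf("backend=%s location=%s\n", info.Type, info.Location)
	// Expected per the code above: backend=file location=/var/backups/vaultik
}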

View File

@ -1,85 +0,0 @@
package storage
import (
"context"
"fmt"
"io"
"git.eeqj.de/sneak/vaultik/internal/s3"
)
// S3Storer wraps the existing s3.Client to implement Storer.
type S3Storer struct {
client *s3.Client
}
// NewS3Storer creates a new S3 storage backend.
func NewS3Storer(client *s3.Client) *S3Storer {
return &S3Storer{client: client}
}
// Put stores data at the specified key.
func (s *S3Storer) Put(ctx context.Context, key string, data io.Reader) error {
return s.client.PutObject(ctx, key, data)
}
// PutWithProgress stores data with progress reporting.
func (s *S3Storer) PutWithProgress(ctx context.Context, key string, data io.Reader, size int64, progress ProgressCallback) error {
// Convert storage.ProgressCallback to s3.ProgressCallback
var s3Progress s3.ProgressCallback
if progress != nil {
s3Progress = s3.ProgressCallback(progress)
}
return s.client.PutObjectWithProgress(ctx, key, data, size, s3Progress)
}
// Get retrieves data from the specified key.
func (s *S3Storer) Get(ctx context.Context, key string) (io.ReadCloser, error) {
return s.client.GetObject(ctx, key)
}
// Stat returns metadata about an object without retrieving its contents.
func (s *S3Storer) Stat(ctx context.Context, key string) (*ObjectInfo, error) {
info, err := s.client.StatObject(ctx, key)
if err != nil {
return nil, err
}
return &ObjectInfo{
Key: info.Key,
Size: info.Size,
}, nil
}
// Delete removes an object.
func (s *S3Storer) Delete(ctx context.Context, key string) error {
return s.client.DeleteObject(ctx, key)
}
// List returns all keys with the given prefix.
func (s *S3Storer) List(ctx context.Context, prefix string) ([]string, error) {
return s.client.ListObjects(ctx, prefix)
}
// ListStream returns a channel of ObjectInfo for large result sets.
func (s *S3Storer) ListStream(ctx context.Context, prefix string) <-chan ObjectInfo {
ch := make(chan ObjectInfo)
go func() {
defer close(ch)
for info := range s.client.ListObjectsStream(ctx, prefix, false) {
ch <- ObjectInfo{
Key: info.Key,
Size: info.Size,
Err: info.Err,
}
}
}()
return ch
}
// Info returns human-readable storage location information.
func (s *S3Storer) Info() StorageInfo {
return StorageInfo{
Type: "s3",
Location: fmt.Sprintf("%s/%s", s.client.Endpoint(), s.client.BucketName()),
}
}

View File

@ -1,74 +0,0 @@
// Package storage provides a unified interface for storage backends.
// It supports both S3-compatible object storage and local filesystem storage,
// allowing Vaultik to store backups in either location with the same API.
//
// Storage backends are selected via URL:
// - s3://bucket/prefix?endpoint=host&region=r - S3-compatible storage
// - file:///path/to/backup - Local filesystem storage
//
// Both backends implement the Storer interface and support progress reporting
// during upload/write operations.
package storage
import (
"context"
"errors"
"io"
)
// ErrNotFound is returned when an object does not exist.
var ErrNotFound = errors.New("object not found")
// ProgressCallback is called during storage operations with bytes transferred so far.
// Return an error to cancel the operation.
type ProgressCallback func(bytesTransferred int64) error
// ObjectInfo contains metadata about a stored object.
type ObjectInfo struct {
Key string // Object key/path
Size int64 // Size in bytes
Err error // Error for streaming results (nil on success)
}
// StorageInfo provides human-readable storage configuration.
type StorageInfo struct {
Type string // "s3" or "file"
Location string // endpoint/bucket for S3, base path for filesystem
}
// Storer defines the interface for storage backends.
// All paths are relative to the storage root (bucket/prefix for S3, base directory for filesystem).
type Storer interface {
// Put stores data at the specified key.
// Parent directories are created automatically for filesystem backends.
Put(ctx context.Context, key string, data io.Reader) error
// PutWithProgress stores data with progress reporting.
// Size must be the exact size of the data to store.
// The progress callback is called periodically with bytes transferred.
PutWithProgress(ctx context.Context, key string, data io.Reader, size int64, progress ProgressCallback) error
// Get retrieves data from the specified key.
// The caller must close the returned ReadCloser.
// Returns ErrNotFound if the object does not exist.
Get(ctx context.Context, key string) (io.ReadCloser, error)
// Stat returns metadata about an object without retrieving its contents.
// Returns ErrNotFound if the object does not exist.
Stat(ctx context.Context, key string) (*ObjectInfo, error)
// Delete removes an object. No error is returned if the object doesn't exist.
Delete(ctx context.Context, key string) error
// List returns all keys with the given prefix.
// For large result sets, prefer ListStream.
List(ctx context.Context, prefix string) ([]string, error)
// ListStream returns a channel of ObjectInfo for large result sets.
// The channel is closed when listing completes.
// If an error occurs during listing, the final item will have Err set.
ListStream(ctx context.Context, prefix string) <-chan ObjectInfo
// Info returns human-readable storage location information.
Info() StorageInfo
}
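
Because every backend satisfies this single interface, callers can stay backend-agnostic. A hedged, module-internal sketch (illustrative only, not part of this diff) of a helper that copies all objects under a prefix from one Storer to another using only the methods declared above:

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"git.eeqj.de/sneak/vaultik/internal/storage"
)

// mirrorPrefix copies every object under prefix from src to dst. It never
// needs to know which backend is behind either side.
func mirrorPrefix(ctx context.Context, src, dst storage.Storer, prefix string) error {
	keys, err := src.List(ctx, prefix)
	if err != nil {
		return fmt.Errorf("listing %q: %w", prefix, err)
	}
	for _, key := range keys {
		rc, err := src.Get(ctx, key)
		if err != nil {
			return fmt.Errorf("reading %q: %w", key, err)
		}
		putErr := dst.Put(ctx, key, rc)
		_ = rc.Close()
		if putErr != nil {
			return fmt.Errorf("writing %q: %w", key, putErr)
		}
	}
	return nil
}

func main() {
	// Illustrative wiring: two filesystem backends in temporary directories.
	srcDir, err := os.MkdirTemp("", "vaultik-src")
	if err != nil {
		log.Fatal(err)
	}
	dstDir, err := os.MkdirTemp("", "vaultik-dst")
	if err != nil {
		log.Fatal(err)
	}
	src, err := storage.NewFileStorer(srcDir)
	if err != nil {
		log.Fatal(err)
	}
	dst, err := storage.NewFileStorer(dstDir)
	if err != nil {
		log.Fatal(err)
	}
	if err := mirrorPrefix(context.Background(), src, dst, "blobs/"); err != nil {
		log.Fatal(err)
	}
	fmt.Println("mirrored blobs/ from", srcDir, "to", dstDir)
}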

View File

@ -1,90 +0,0 @@
package storage
import (
"fmt"
"net/url"
"strings"
)
// StorageURL represents a parsed storage URL.
type StorageURL struct {
Scheme string // "s3" or "file"
Bucket string // S3 bucket name (empty for file)
Prefix string // Path within bucket or filesystem base path
Endpoint string // S3 endpoint (optional, default AWS)
Region string // S3 region (optional)
UseSSL bool // Use HTTPS for S3 (default true)
}
// ParseStorageURL parses a storage URL string.
// Supported formats:
// - s3://bucket/prefix?endpoint=host&region=us-east-1&ssl=true
// - file:///absolute/path/to/backup
func ParseStorageURL(rawURL string) (*StorageURL, error) {
if rawURL == "" {
return nil, fmt.Errorf("storage URL is empty")
}
// Handle file:// URLs
if strings.HasPrefix(rawURL, "file://") {
path := strings.TrimPrefix(rawURL, "file://")
if path == "" {
return nil, fmt.Errorf("file URL path is empty")
}
return &StorageURL{
Scheme: "file",
Prefix: path,
}, nil
}
// Handle s3:// URLs
if strings.HasPrefix(rawURL, "s3://") {
u, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
bucket := u.Host
if bucket == "" {
return nil, fmt.Errorf("s3 URL missing bucket name")
}
prefix := strings.TrimPrefix(u.Path, "/")
query := u.Query()
useSSL := true
if query.Get("ssl") == "false" {
useSSL = false
}
return &StorageURL{
Scheme: "s3",
Bucket: bucket,
Prefix: prefix,
Endpoint: query.Get("endpoint"),
Region: query.Get("region"),
UseSSL: useSSL,
}, nil
}
return nil, fmt.Errorf("unsupported URL scheme: must start with s3:// or file://")
}
// String returns a human-readable representation of the storage URL.
func (u *StorageURL) String() string {
switch u.Scheme {
case "file":
return fmt.Sprintf("file://%s", u.Prefix)
case "s3":
endpoint := u.Endpoint
if endpoint == "" {
endpoint = "s3.amazonaws.com"
}
if u.Prefix != "" {
return fmt.Sprintf("s3://%s/%s (endpoint: %s)", u.Bucket, u.Prefix, endpoint)
}
return fmt.Sprintf("s3://%s (endpoint: %s)", u.Bucket, endpoint)
default:
return fmt.Sprintf("%s://?", u.Scheme)
}
}
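
For reference, the two supported URL forms parse into the fields above as follows. A hedged, module-internal sketch (values illustrative) exercising ParseStorageURL:

package main

import (
	"fmt"
	"log"

	"git.eeqj.de/sneak/vaultik/internal/storage"
)

func main() {
	u, err := storage.ParseStorageURL("s3://backups/host1?endpoint=minio.local:9000&region=us-east-1&ssl=false")
	if err != nil {
		log.Fatal(err)
	}
	// Per the parser above: Scheme="s3", Bucket="backups", Prefix="host1",
	// Endpoint="minio.local:9000", Region="us-east-1", UseSSL=false.
	// Credentials are never carried in the URL.
	fmt.Printf("%+v\n", *u)

	f, err := storage.ParseStorageURL("file:///var/backups/vaultik")
	if err != nil {
		log.Fatal(err)
	}
	// Per the parser above: Scheme="file", Prefix="/var/backups/vaultik".
	fmt.Println(f.String())
}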

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
"sync"
"testing"
@ -12,122 +13,100 @@ import (
"git.eeqj.de/sneak/vaultik/internal/config"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/log"
"git.eeqj.de/sneak/vaultik/internal/s3"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/storage"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// MockStorer implements storage.Storer for testing
type MockStorer struct {
mu sync.Mutex
data map[string][]byte
calls []string
// MockS3Client implements a mock S3 client for testing
type MockS3Client struct {
mu sync.Mutex
storage map[string][]byte
calls []string
}
func NewMockStorer() *MockStorer {
return &MockStorer{
data: make(map[string][]byte),
calls: make([]string, 0),
func NewMockS3Client() *MockS3Client {
return &MockS3Client{
storage: make(map[string][]byte),
calls: make([]string, 0),
}
}
func (m *MockStorer) Put(ctx context.Context, key string, reader io.Reader) error {
func (m *MockS3Client) PutObject(ctx context.Context, key string, reader io.Reader) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, "Put:"+key)
m.calls = append(m.calls, "PutObject:"+key)
data, err := io.ReadAll(reader)
if err != nil {
return err
}
m.data[key] = data
m.storage[key] = data
return nil
}
func (m *MockStorer) PutWithProgress(ctx context.Context, key string, reader io.Reader, size int64, progress storage.ProgressCallback) error {
return m.Put(ctx, key, reader)
func (m *MockS3Client) PutObjectWithProgress(ctx context.Context, key string, reader io.Reader, size int64, progress s3.ProgressCallback) error {
// For testing, just call PutObject
return m.PutObject(ctx, key, reader)
}
func (m *MockStorer) Get(ctx context.Context, key string) (io.ReadCloser, error) {
func (m *MockS3Client) GetObject(ctx context.Context, key string) (io.ReadCloser, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, "Get:"+key)
data, exists := m.data[key]
m.calls = append(m.calls, "GetObject:"+key)
data, exists := m.storage[key]
if !exists {
return nil, storage.ErrNotFound
return nil, fmt.Errorf("key not found: %s", key)
}
return io.NopCloser(bytes.NewReader(data)), nil
}
func (m *MockStorer) Stat(ctx context.Context, key string) (*storage.ObjectInfo, error) {
func (m *MockS3Client) StatObject(ctx context.Context, key string) (*s3.ObjectInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, "Stat:"+key)
data, exists := m.data[key]
m.calls = append(m.calls, "StatObject:"+key)
data, exists := m.storage[key]
if !exists {
return nil, storage.ErrNotFound
return nil, fmt.Errorf("key not found: %s", key)
}
return &storage.ObjectInfo{
return &s3.ObjectInfo{
Key: key,
Size: int64(len(data)),
}, nil
}
func (m *MockStorer) Delete(ctx context.Context, key string) error {
func (m *MockS3Client) DeleteObject(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, "Delete:"+key)
delete(m.data, key)
m.calls = append(m.calls, "DeleteObject:"+key)
delete(m.storage, key)
return nil
}
func (m *MockStorer) List(ctx context.Context, prefix string) ([]string, error) {
func (m *MockS3Client) ListObjects(ctx context.Context, prefix string) ([]*s3.ObjectInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, "List:"+prefix)
var keys []string
for key := range m.data {
m.calls = append(m.calls, "ListObjects:"+prefix)
var objects []*s3.ObjectInfo
for key, data := range m.storage {
if len(prefix) == 0 || (len(key) >= len(prefix) && key[:len(prefix)] == prefix) {
keys = append(keys, key)
objects = append(objects, &s3.ObjectInfo{
Key: key,
Size: int64(len(data)),
})
}
}
return keys, nil
return objects, nil
}
func (m *MockStorer) ListStream(ctx context.Context, prefix string) <-chan storage.ObjectInfo {
ch := make(chan storage.ObjectInfo)
go func() {
defer close(ch)
m.mu.Lock()
defer m.mu.Unlock()
for key, data := range m.data {
if len(prefix) == 0 || (len(key) >= len(prefix) && key[:len(prefix)] == prefix) {
ch <- storage.ObjectInfo{
Key: key,
Size: int64(len(data)),
}
}
}
}()
return ch
}
func (m *MockStorer) Info() storage.StorageInfo {
return storage.StorageInfo{
Type: "mock",
Location: "memory",
}
}
// GetCalls returns the list of operations that were called
func (m *MockStorer) GetCalls() []string {
// GetCalls returns the list of S3 operations that were called
func (m *MockS3Client) GetCalls() []string {
m.mu.Lock()
defer m.mu.Unlock()
@ -137,11 +116,11 @@ func (m *MockStorer) GetCalls() []string {
}
// GetStorageSize returns the number of objects in storage
func (m *MockStorer) GetStorageSize() int {
func (m *MockS3Client) GetStorageSize() int {
m.mu.Lock()
defer m.mu.Unlock()
return len(m.data)
return len(m.storage)
}
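
The mock records each operation it sees, so tests can assert on both the stored objects and the exact call sequence. An illustrative test, not part of this change, that uses only the mock methods defined above and imports already present in this file (bytes, context, testing):

func TestMockS3ClientRecordsCalls(t *testing.T) {
	ctx := context.Background()
	mock := NewMockS3Client()
	// Store one object and read its metadata back through the same API the scanner uses.
	if err := mock.PutObject(ctx, "blobs/ab/cd/abcdef", bytes.NewReader([]byte("payload"))); err != nil {
		t.Fatalf("PutObject: %v", err)
	}
	info, err := mock.StatObject(ctx, "blobs/ab/cd/abcdef")
	if err != nil {
		t.Fatalf("StatObject: %v", err)
	}
	if info.Size != int64(len("payload")) {
		t.Fatalf("unexpected size: %d", info.Size)
	}
	// The mock also records the order of operations for later assertions.
	calls := mock.GetCalls()
	if len(calls) != 2 || calls[0] != "PutObject:blobs/ab/cd/abcdef" {
		t.Fatalf("unexpected call log: %v", calls)
	}
}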
// TestEndToEndBackup tests the full backup workflow with mocked dependencies
@ -179,8 +158,8 @@ func TestEndToEndBackup(t *testing.T) {
}
}
// Create mock storage
mockStorage := NewMockStorer()
// Create mock S3 client
mockS3 := NewMockS3Client()
// Create test configuration
cfg := &config.Config{
@ -202,7 +181,7 @@ func TestEndToEndBackup(t *testing.T) {
}
// For a true end-to-end test, we'll create a simpler test that focuses on
// the core backup logic using the scanner directly with our mock storage
// the core backup logic using the scanner directly with our mock S3 client
ctx := context.Background()
// Create in-memory database
@ -216,12 +195,12 @@ func TestEndToEndBackup(t *testing.T) {
repos := database.NewRepositories(db)
// Create scanner with mock storage
// Create scanner with mock S3 client
scanner := snapshot.NewScanner(snapshot.ScannerConfig{
FS: fs,
ChunkSize: cfg.ChunkSize.Int64(),
Repositories: repos,
Storage: mockStorage,
S3Client: mockS3,
MaxBlobSize: cfg.BlobSizeLimit.Int64(),
CompressionLevel: cfg.CompressionLevel,
AgeRecipients: cfg.AgeRecipients,
@ -253,15 +232,15 @@ func TestEndToEndBackup(t *testing.T) {
assert.Greater(t, result.ChunksCreated, 0, "Should create chunks")
assert.Greater(t, result.BlobsCreated, 0, "Should create blobs")
// Verify storage operations
calls := mockStorage.GetCalls()
t.Logf("Storage operations performed: %v", calls)
// Verify S3 operations
calls := mockS3.GetCalls()
t.Logf("S3 operations performed: %v", calls)
// Should have uploaded at least one blob
blobUploads := 0
for _, call := range calls {
if len(call) > 4 && call[:4] == "Put:" {
if len(call) > 10 && call[4:10] == "blobs/" {
if len(call) > 10 && call[:10] == "PutObject:" {
if len(call) > 16 && call[10:16] == "blobs/" {
blobUploads++
}
}
@ -285,8 +264,8 @@ func TestEndToEndBackup(t *testing.T) {
require.NoError(t, err)
assert.Greater(t, len(fileChunks), 0, "Should have chunks for file1.txt")
// Verify blobs were uploaded to storage
assert.Greater(t, mockStorage.GetStorageSize(), 0, "Should have blobs in storage")
// Verify blobs were uploaded to S3
assert.Greater(t, mockS3.GetStorageSize(), 0, "Should have blobs in S3 storage")
// Complete the snapshot - just verify we got results
// In a real integration test, we'd update the snapshot record
@ -304,7 +283,7 @@ func TestEndToEndBackup(t *testing.T) {
t.Logf(" Bytes scanned: %d", result.BytesScanned)
t.Logf(" Chunks created: %d", result.ChunksCreated)
t.Logf(" Blobs created: %d", result.BlobsCreated)
t.Logf(" Storage size: %d objects", mockStorage.GetStorageSize())
t.Logf(" S3 storage size: %d objects", mockS3.GetStorageSize())
}
// TestBackupAndVerify tests backing up files and verifying the blobs
@ -322,8 +301,8 @@ func TestBackupAndVerify(t *testing.T) {
err = afero.WriteFile(fs, "/data/test.txt", []byte(testContent), 0644)
require.NoError(t, err)
// Create mock storage
mockStorage := NewMockStorer()
// Create mock S3 client
mockS3 := NewMockS3Client()
// Create test database
ctx := context.Background()
@ -342,7 +321,7 @@ func TestBackupAndVerify(t *testing.T) {
FS: fs,
ChunkSize: int64(1024 * 16), // 16KB chunks
Repositories: repos,
Storage: mockStorage,
S3Client: mockS3,
MaxBlobSize: int64(1024 * 1024), // 1MB blobs
CompressionLevel: 3,
AgeRecipients: []string{"age1ezrjmfpwsc95svdg0y54mums3zevgzu0x0ecq2f7tp8a05gl0sjq9q9wjg"}, // Test public key
@ -367,25 +346,25 @@ func TestBackupAndVerify(t *testing.T) {
// Verify backup created blobs
assert.Greater(t, result.BlobsCreated, 0, "Should create at least one blob")
assert.Equal(t, mockStorage.GetStorageSize(), result.BlobsCreated, "Storage should have the blobs")
assert.Equal(t, mockS3.GetStorageSize(), result.BlobsCreated, "S3 should have the blobs")
// Verify we can retrieve the blob from storage
objects, err := mockStorage.List(ctx, "blobs/")
// Verify we can retrieve the blob from S3
objects, err := mockS3.ListObjects(ctx, "blobs/")
require.NoError(t, err)
assert.Len(t, objects, result.BlobsCreated, "Should have correct number of blobs in storage")
assert.Len(t, objects, result.BlobsCreated, "Should have correct number of blobs in S3")
// Get the first blob and verify it exists
if len(objects) > 0 {
blobKey := objects[0]
blobKey := objects[0].Key
t.Logf("Verifying blob: %s", blobKey)
// Get blob info
blobInfo, err := mockStorage.Stat(ctx, blobKey)
blobInfo, err := mockS3.StatObject(ctx, blobKey)
require.NoError(t, err)
assert.Greater(t, blobInfo.Size, int64(0), "Blob should have content")
// Get blob content
reader, err := mockStorage.Get(ctx, blobKey)
reader, err := mockS3.GetObject(ctx, blobKey)
require.NoError(t, err)
defer func() { _ = reader.Close() }()

View File

@ -21,9 +21,9 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
allBlobsReferenced := make(map[string]bool)
manifestCount := 0
// List all snapshots in storage
// List all snapshots in S3
log.Info("Listing remote snapshots")
objectCh := v.Storage.ListStream(v.ctx, "metadata/")
objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)
var snapshotIDs []string
for object := range objectCh {
@ -73,10 +73,10 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
log.Info("Processed manifests", "count", manifestCount, "unique_blobs_referenced", len(allBlobsReferenced))
// List all blobs in storage
// List all blobs in S3
log.Info("Listing all blobs in storage")
allBlobs := make(map[string]int64) // hash -> size
blobObjectCh := v.Storage.ListStream(v.ctx, "blobs/")
blobObjectCh := v.S3Client.ListObjectsStream(v.ctx, "blobs/", true)
for object := range blobObjectCh {
if object.Err != nil {
@ -136,7 +136,7 @@ func (v *Vaultik) PruneBlobs(opts *PruneOptions) error {
for i, hash := range unreferencedBlobs {
blobPath := fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash)
if err := v.Storage.Delete(v.ctx, blobPath); err != nil {
if err := v.S3Client.RemoveObject(v.ctx, blobPath); err != nil {
log.Error("Failed to delete blob", "hash", hash, "error", err)
continue
}
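
The prune flow above is effectively a mark-and-sweep: hashes referenced by any snapshot manifest are marked, every blob present in storage is listed, and the difference is deleted. A standalone sketch of that set arithmetic (stdlib only, data illustrative):

package main

import (
	"fmt"
	"sort"
)

func main() {
	// Mark: hashes referenced by at least one snapshot manifest.
	referenced := map[string]bool{
		"aaaa": true,
		"bbbb": true,
	}
	// Present: every blob actually in storage, with its size in bytes.
	present := map[string]int64{
		"aaaa": 1 << 20,
		"bbbb": 2 << 20,
		"cccc": 512, // no manifest references this one
	}

	// Sweep: anything present but not referenced is eligible for deletion.
	var unreferenced []string
	var reclaimed int64
	for hash, size := range present {
		if !referenced[hash] {
			unreferenced = append(unreferenced, hash)
			reclaimed += size
		}
	}
	sort.Strings(unreferenced)

	fmt.Printf("would delete %d blob(s), reclaiming %d bytes: %v\n",
		len(unreferenced), reclaimed, unreferenced)
	// Output: would delete 1 blob(s), reclaiming 512 bytes: [cccc]
}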

View File

@ -265,7 +265,7 @@ func (v *Vaultik) CreateSnapshot(opts *SnapshotCreateOptions) error {
func (v *Vaultik) ListSnapshots(jsonOutput bool) error {
// Get all remote snapshots
remoteSnapshots := make(map[string]bool)
objectCh := v.Storage.ListStream(v.ctx, "metadata/")
objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)
for object := range objectCh {
if object.Err != nil {
@ -546,7 +546,7 @@ func (v *Vaultik) VerifySnapshot(snapshotID string, deep bool) error {
return nil
} else {
// Just check existence
_, err := v.Storage.Stat(v.ctx, blobPath)
_, err := v.S3Client.StatObject(v.ctx, blobPath)
if err != nil {
fmt.Printf(" Missing: %s (%s)\n", blob.Hash, humanize.Bytes(uint64(blob.CompressedSize)))
missing++
@ -581,7 +581,7 @@ func (v *Vaultik) VerifySnapshot(snapshotID string, deep bool) error {
func (v *Vaultik) getManifestSize(snapshotID string) (int64, error) {
manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
reader, err := v.Storage.Get(v.ctx, manifestPath)
reader, err := v.S3Client.GetObject(v.ctx, manifestPath)
if err != nil {
return 0, fmt.Errorf("downloading manifest: %w", err)
}
@ -598,7 +598,7 @@ func (v *Vaultik) getManifestSize(snapshotID string) (int64, error) {
func (v *Vaultik) downloadManifest(snapshotID string) (*snapshot.Manifest, error) {
manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
reader, err := v.Storage.Get(v.ctx, manifestPath)
reader, err := v.S3Client.GetObject(v.ctx, manifestPath)
if err != nil {
return nil, err
}
@ -613,10 +613,10 @@ func (v *Vaultik) downloadManifest(snapshotID string) (*snapshot.Manifest, error
}
func (v *Vaultik) deleteSnapshot(snapshotID string) error {
// First, delete from storage
// First, delete from S3
// List all objects under metadata/{snapshotID}/
prefix := fmt.Sprintf("metadata/%s/", snapshotID)
objectCh := v.Storage.ListStream(v.ctx, prefix)
objectCh := v.S3Client.ListObjectsStream(v.ctx, prefix, true)
var objectsToDelete []string
for object := range objectCh {
@ -628,7 +628,7 @@ func (v *Vaultik) deleteSnapshot(snapshotID string) error {
// Delete all objects
for _, key := range objectsToDelete {
if err := v.Storage.Delete(v.ctx, key); err != nil {
if err := v.S3Client.RemoveObject(v.ctx, key); err != nil {
return fmt.Errorf("removing %s: %w", key, err)
}
}
@ -658,7 +658,7 @@ func (v *Vaultik) syncWithRemote() error {
// Get all remote snapshot IDs
remoteSnapshots := make(map[string]bool)
objectCh := v.Storage.ListStream(v.ctx, "metadata/")
objectCh := v.S3Client.ListObjectsStream(v.ctx, "metadata/", false)
for object := range objectCh {
if object.Err != nil {

View File

@ -10,8 +10,8 @@ import (
"git.eeqj.de/sneak/vaultik/internal/crypto"
"git.eeqj.de/sneak/vaultik/internal/database"
"git.eeqj.de/sneak/vaultik/internal/globals"
"git.eeqj.de/sneak/vaultik/internal/s3"
"git.eeqj.de/sneak/vaultik/internal/snapshot"
"git.eeqj.de/sneak/vaultik/internal/storage"
"github.com/spf13/afero"
"go.uber.org/fx"
)
@ -22,7 +22,7 @@ type Vaultik struct {
Config *config.Config
DB *database.DB
Repositories *database.Repositories
Storage storage.Storer
S3Client *s3.Client
ScannerFactory snapshot.ScannerFactory
SnapshotManager *snapshot.SnapshotManager
Shutdowner fx.Shutdowner
@ -46,7 +46,7 @@ type VaultikParams struct {
Config *config.Config
DB *database.DB
Repositories *database.Repositories
Storage storage.Storer
S3Client *s3.Client
ScannerFactory snapshot.ScannerFactory
SnapshotManager *snapshot.SnapshotManager
Shutdowner fx.Shutdowner
@ -72,7 +72,7 @@ func New(params VaultikParams) *Vaultik {
Config: params.Config,
DB: params.DB,
Repositories: params.Repositories,
Storage: params.Storage,
S3Client: params.S3Client,
ScannerFactory: params.ScannerFactory,
SnapshotManager: params.SnapshotManager,
Shutdowner: params.Shutdowner,

View File

@ -36,7 +36,7 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
manifestPath := fmt.Sprintf("metadata/%s/manifest.json.zst", snapshotID)
log.Info("Downloading manifest", "path", manifestPath)
manifestReader, err := v.Storage.Get(v.ctx, manifestPath)
manifestReader, err := v.S3Client.GetObject(v.ctx, manifestPath)
if err != nil {
return fmt.Errorf("failed to download manifest: %w", err)
}
@ -57,7 +57,7 @@ func (v *Vaultik) RunDeepVerify(snapshotID string, opts *VerifyOptions) error {
dbPath := fmt.Sprintf("metadata/%s/db.zst.age", snapshotID)
log.Info("Downloading encrypted database", "path", dbPath)
dbReader, err := v.Storage.Get(v.ctx, dbPath)
dbReader, err := v.S3Client.GetObject(v.ctx, dbPath)
if err != nil {
return fmt.Errorf("failed to download database: %w", err)
}
@ -236,10 +236,10 @@ func (v *Vaultik) verifyBlobExistence(manifest *snapshot.Manifest) error {
// Construct blob path
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blob.Hash[:2], blob.Hash[2:4], blob.Hash)
// Check blob exists
stat, err := v.Storage.Stat(v.ctx, blobPath)
// Check blob exists with HeadObject
stat, err := v.S3Client.StatObject(v.ctx, blobPath)
if err != nil {
return fmt.Errorf("blob %s missing from storage: %w", blob.Hash, err)
return fmt.Errorf("blob %s missing from S3: %w", blob.Hash, err)
}
// Verify size matches
@ -258,7 +258,7 @@ func (v *Vaultik) verifyBlobExistence(manifest *snapshot.Manifest) error {
}
}
log.Info("✓ All blobs exist in storage")
log.Info("✓ All blobs exist in S3")
return nil
}
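
Blob keys use the two-level fan-out seen throughout this code (blobs/<hash[0:2]>/<hash[2:4]>/<hash>), and the existence check above also compares the stored size against the manifest. A standalone sketch of the key construction and size comparison (stdlib only, values illustrative):

package main

import (
	"fmt"
	"os"
)

// blobKey builds the object key for a blob hash using the same fan-out
// layout as the code above. It rejects hashes too short to split.
func blobKey(hash string) (string, error) {
	if len(hash) < 4 {
		return "", fmt.Errorf("hash too short: %q", hash)
	}
	return fmt.Sprintf("blobs/%s/%s/%s", hash[:2], hash[2:4], hash), nil
}

// checkSize mirrors the size comparison done after StatObject: a mismatch
// between the manifest and the stored object is treated as corruption.
func checkSize(hash string, manifestSize, storedSize int64) error {
	if manifestSize != storedSize {
		return fmt.Errorf("blob %s size mismatch: manifest=%d stored=%d",
			hash, manifestSize, storedSize)
	}
	return nil
}

func main() {
	key, err := blobKey("deadbeefcafe")
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println(key) // blobs/de/ad/deadbeefcafe

	if err := checkSize("deadbeefcafe", 1024, 1024); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("size ok")
}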
@ -295,7 +295,7 @@ func (v *Vaultik) performDeepVerification(manifest *snapshot.Manifest, db *sql.D
func (v *Vaultik) verifyBlob(blobInfo snapshot.BlobInfo, db *sql.DB) error {
// Download blob
blobPath := fmt.Sprintf("blobs/%s/%s/%s", blobInfo.Hash[:2], blobInfo.Hash[2:4], blobInfo.Hash)
reader, err := v.Storage.Get(v.ctx, blobPath)
reader, err := v.S3Client.GetObject(v.ctx, blobPath)
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}