Three coordinated changes cut restore wall-clock time by orders of
magnitude on real-world snapshots and bring memory use back under
control:
* Streaming download into the disk cache. The new
blobDiskCache.PutFromReader takes an io.Reader and streams it with
io.Copy straight into the cache file. The old downloadBlob path called
io.ReadAll on the decrypted plaintext stream — for a 50 GB blob
that meant holding 50 GB in RAM before the cache write. The whole chain
(Storage.Get → age.Decrypt → zstd.NewReader → io.Copy) is now
fully streaming; peak RAM per blob is bounded by ~64 KB of
internal age/zstd buffers plus the io.Copy buffer.
* Chunk extraction via ReadAt. Once a blob is on disk, restore
reads each chunk via blobDiskCache.ReadAt(hash, offset, length), so
only that chunk's bytes ever enter RAM. The previous code path
called blobCache.Get for every cache-hit chunk, which read the
entire blob (e.g. 10 GB) from disk into a []byte just to slice
out a few KB — on its own this cost ~900 ms per cache hit on the
user's photo snapshot.
* Locality-aware restore ordering. The new restorePlan indexes
file→blob_set and blob→file_set at restore start, then drives
the loop so that every file whose full blob set is on disk is
drained before any new blob is downloaded. After the drain queue
empties, the planner picks the pending file with the fewest
uncached blobs, downloads those blobs, and drains again.
A sweep is forced right before each download so the just-completed
blob is evicted before the next one is Put, keeping peak
disk-cache occupancy at one blob for path-order-adversarial
snapshots.
The restore hot path also moves onto a restoreSession struct so
restoreFile/restoreRegularFile/etc. take only the per-call file
argument instead of threading 9+ parameters through every signature.
The new BlobRepository.GetAll method lets the session build a single
blob-id → blob-hash map at start instead of doing one DB query per
chunk.
TestRestoreLocalityAndReadAt passes: peak_len ≤ 1, get_calls = 0,
readat_calls > 0, every blob fetched exactly once.
246 lines
5.8 KiB
Go
246 lines
5.8 KiB
Go
package database
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"sneak.berlin/go/vaultik/internal/log"
)
|
|
|
|
type BlobRepository struct {
|
|
db *DB
|
|
}
|
|
|
|
func NewBlobRepository(db *DB) *BlobRepository {
|
|
return &BlobRepository{db: db}
|
|
}
|
|
|
|
func (r *BlobRepository) Create(ctx context.Context, tx *sql.Tx, blob *Blob) error {
|
|
query := `
|
|
INSERT INTO blobs (id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
var finishedTS, uploadedTS *int64
|
|
if blob.FinishedTS != nil {
|
|
ts := blob.FinishedTS.Unix()
|
|
finishedTS = &ts
|
|
}
|
|
if blob.UploadedTS != nil {
|
|
ts := blob.UploadedTS.Unix()
|
|
uploadedTS = &ts
|
|
}
|
|
|
|
var err error
|
|
if tx != nil {
|
|
_, err = tx.ExecContext(ctx, query, blob.ID, blob.Hash, blob.CreatedTS.Unix(),
|
|
finishedTS, blob.UncompressedSize, blob.CompressedSize, uploadedTS)
|
|
} else {
|
|
_, err = r.db.ExecWithLog(ctx, query, blob.ID, blob.Hash, blob.CreatedTS.Unix(),
|
|
finishedTS, blob.UncompressedSize, blob.CompressedSize, uploadedTS)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("inserting blob: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *BlobRepository) GetByHash(ctx context.Context, hash string) (*Blob, error) {
|
|
query := `
|
|
SELECT id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts
|
|
FROM blobs
|
|
WHERE blob_hash = ?
|
|
`
|
|
|
|
var blob Blob
|
|
var createdTSUnix int64
|
|
var finishedTSUnix, uploadedTSUnix sql.NullInt64
|
|
|
|
err := r.db.conn.QueryRowContext(ctx, query, hash).Scan(
|
|
&blob.ID,
|
|
&blob.Hash,
|
|
&createdTSUnix,
|
|
&finishedTSUnix,
|
|
&blob.UncompressedSize,
|
|
&blob.CompressedSize,
|
|
&uploadedTSUnix,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying blob: %w", err)
|
|
}
|
|
|
|
blob.CreatedTS = time.Unix(createdTSUnix, 0).UTC()
|
|
if finishedTSUnix.Valid {
|
|
ts := time.Unix(finishedTSUnix.Int64, 0).UTC()
|
|
blob.FinishedTS = &ts
|
|
}
|
|
if uploadedTSUnix.Valid {
|
|
ts := time.Unix(uploadedTSUnix.Int64, 0).UTC()
|
|
blob.UploadedTS = &ts
|
|
}
|
|
return &blob, nil
|
|
}
|
|
|
|
// GetByID retrieves a blob by its ID
|
|
func (r *BlobRepository) GetByID(ctx context.Context, id string) (*Blob, error) {
|
|
query := `
|
|
SELECT id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts
|
|
FROM blobs
|
|
WHERE id = ?
|
|
`
|
|
|
|
var blob Blob
|
|
var createdTSUnix int64
|
|
var finishedTSUnix, uploadedTSUnix sql.NullInt64
|
|
|
|
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
|
|
&blob.ID,
|
|
&blob.Hash,
|
|
&createdTSUnix,
|
|
&finishedTSUnix,
|
|
&blob.UncompressedSize,
|
|
&blob.CompressedSize,
|
|
&uploadedTSUnix,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying blob: %w", err)
|
|
}
|
|
|
|
blob.CreatedTS = time.Unix(createdTSUnix, 0).UTC()
|
|
if finishedTSUnix.Valid {
|
|
ts := time.Unix(finishedTSUnix.Int64, 0).UTC()
|
|
blob.FinishedTS = &ts
|
|
}
|
|
if uploadedTSUnix.Valid {
|
|
ts := time.Unix(uploadedTSUnix.Int64, 0).UTC()
|
|
blob.UploadedTS = &ts
|
|
}
|
|
return &blob, nil
|
|
}
|
|
|
|
// GetAll returns every blob row keyed by blob ID. Useful at restore
|
|
// start to translate the per-chunk blob_id references in chunkToBlobMap
|
|
// into blob hashes without doing one GetByID query per chunk.
|
|
func (r *BlobRepository) GetAll(ctx context.Context) (map[string]*Blob, error) {
|
|
query := `
|
|
SELECT id, blob_hash, created_ts, finished_ts, uncompressed_size, compressed_size, uploaded_ts
|
|
FROM blobs
|
|
`
|
|
|
|
rows, err := r.db.conn.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying blobs: %w", err)
|
|
}
|
|
defer CloseRows(rows)
|
|
|
|
out := make(map[string]*Blob)
|
|
for rows.Next() {
|
|
var blob Blob
|
|
var createdTSUnix int64
|
|
var finishedTSUnix, uploadedTSUnix sql.NullInt64
|
|
if err := rows.Scan(
|
|
&blob.ID,
|
|
&blob.Hash,
|
|
&createdTSUnix,
|
|
&finishedTSUnix,
|
|
&blob.UncompressedSize,
|
|
&blob.CompressedSize,
|
|
&uploadedTSUnix,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scanning blob: %w", err)
|
|
}
|
|
blob.CreatedTS = time.Unix(createdTSUnix, 0).UTC()
|
|
if finishedTSUnix.Valid {
|
|
ts := time.Unix(finishedTSUnix.Int64, 0).UTC()
|
|
blob.FinishedTS = &ts
|
|
}
|
|
if uploadedTSUnix.Valid {
|
|
ts := time.Unix(uploadedTSUnix.Int64, 0).UTC()
|
|
blob.UploadedTS = &ts
|
|
}
|
|
out[blob.ID.String()] = &blob
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// UpdateFinished updates a blob when it's finalized
|
|
func (r *BlobRepository) UpdateFinished(ctx context.Context, tx *sql.Tx, id string, hash string, uncompressedSize, compressedSize int64) error {
|
|
query := `
|
|
UPDATE blobs
|
|
SET blob_hash = ?, finished_ts = ?, uncompressed_size = ?, compressed_size = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
now := time.Now().UTC().Unix()
|
|
var err error
|
|
if tx != nil {
|
|
_, err = tx.ExecContext(ctx, query, hash, now, uncompressedSize, compressedSize, id)
|
|
} else {
|
|
_, err = r.db.ExecWithLog(ctx, query, hash, now, uncompressedSize, compressedSize, id)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("updating blob: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateUploaded marks a blob as uploaded
|
|
func (r *BlobRepository) UpdateUploaded(ctx context.Context, tx *sql.Tx, id string) error {
|
|
query := `
|
|
UPDATE blobs
|
|
SET uploaded_ts = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
now := time.Now().UTC().Unix()
|
|
var err error
|
|
if tx != nil {
|
|
_, err = tx.ExecContext(ctx, query, now, id)
|
|
} else {
|
|
_, err = r.db.ExecWithLog(ctx, query, now, id)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("marking blob as uploaded: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteOrphaned deletes blobs that are not referenced by any snapshot
|
|
func (r *BlobRepository) DeleteOrphaned(ctx context.Context) error {
|
|
query := `
|
|
DELETE FROM blobs
|
|
WHERE NOT EXISTS (
|
|
SELECT 1 FROM snapshot_blobs
|
|
WHERE snapshot_blobs.blob_id = blobs.id
|
|
)
|
|
`
|
|
|
|
result, err := r.db.ExecWithLog(ctx, query)
|
|
if err != nil {
|
|
return fmt.Errorf("deleting orphaned blobs: %w", err)
|
|
}
|
|
|
|
rowsAffected, _ := result.RowsAffected()
|
|
if rowsAffected > 0 {
|
|
log.Debug("Deleted orphaned blobs", "count", rowsAffected)
|
|
}
|
|
|
|
return nil
|
|
}
|