3 Commits

Author SHA1 Message Date
user
d36e511032 refactor: use Params struct for imageprocessor constructor
All checks were successful
check / check (push) Successful in 1m36s
Rename NewImageProcessor(maxInputBytes) to New(Params{}) with a Params
struct containing MaxInputBytes. Zero-value Params{} uses sensible
defaults (DefaultMaxInputBytes). All callers updated.

Addresses review feedback on PR #37.
2026-03-17 19:53:44 -07:00
user
18f218e039 bound imageprocessor.Process input read to prevent unbounded memory use
ImageProcessor.Process used io.ReadAll without a size limit, allowing
arbitrarily large inputs to exhaust memory. Add a configurable
maxInputBytes limit (default 50 MiB, matching the fetcher limit) and
reject inputs that exceed it with ErrInputDataTooLarge.

Also bound the cached source content read in the service layer to
prevent unexpectedly large cached files from consuming unbounded memory.

Extracted loadCachedSource helper to reduce nesting complexity.
2026-03-17 19:53:44 -07:00
9c29cb57df feat: parse version prefix from migration filenames (#33)
All checks were successful
check / check (push) Successful in 1m49s
Closes #28

Migration filenames now follow the pattern `<version>_<description>.sql` (e.g. `001_initial_schema.sql`). The version stored in `schema_migrations` is the numeric prefix only, not the full filename stem.

## Changes

- **`ParseMigrationVersion()`** — new exported function that extracts the numeric prefix from migration filenames. Validates that the prefix is purely numeric and rejects malformed filenames (empty prefix, non-numeric characters, leading underscore).
- **Renamed `001.sql` → `001_initial_schema.sql`** — migration files can now have descriptive names while the tracked version remains `001`. This is safe pre-1.0.0 (no installed base).
- **Deduplicated migration logic** — `runMigrations()` and `ApplyMigrations()` now share a single `applyMigrations()` implementation, plus extracted `collectMigrations()` and `ensureMigrationsTable()` helpers.
- **Unit tests** — `TestParseMigrationVersion` covers valid patterns (version-only, with description, multi-digit, multiple underscores) and error cases (empty, leading underscore, non-numeric, mixed alphanumeric). `TestApplyMigrations` and `TestApplyMigrationsIdempotent` verify end-to-end migration application against an in-memory SQLite database.

Co-authored-by: user <user@Mac.lan guest wan>
Reviewed-on: #33
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-18 03:18:38 +01:00
9 changed files with 284 additions and 125 deletions

View File

@@ -35,6 +35,41 @@ type Database struct {
config *config.Config config *config.Config
} }
// ParseMigrationVersion extracts the numeric version prefix from a migration
// filename. Filenames must follow the pattern "<version>.sql" or
// "<version>_<description>.sql", where version is a zero-padded numeric
// string (e.g. "001", "002"). Returns the version string and an error if
// the filename does not match the expected pattern.
func ParseMigrationVersion(filename string) (string, error) {
name := strings.TrimSuffix(filename, filepath.Ext(filename))
if name == "" {
return "", fmt.Errorf("invalid migration filename %q: empty name", filename)
}
// Split on underscore to separate version from description.
// If there's no underscore, the entire stem is the version.
version := name
if idx := strings.IndexByte(name, '_'); idx >= 0 {
version = name[:idx]
}
if version == "" {
return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename)
}
// Validate the version is purely numeric.
for _, ch := range version {
if ch < '0' || ch > '9' {
return "", fmt.Errorf(
"invalid migration filename %q: version %q contains non-numeric character %q",
filename, version, string(ch),
)
}
}
return version, nil
}
// New creates a new Database instance. // New creates a new Database instance.
func New(lc fx.Lifecycle, params Params) (*Database, error) { func New(lc fx.Lifecycle, params Params) (*Database, error) {
s := &Database{ s := &Database{
@@ -84,96 +119,33 @@ func (s *Database) connect(ctx context.Context) error {
s.db = db s.db = db
s.log.Info("database connected") s.log.Info("database connected")
return s.runMigrations(ctx) return ApplyMigrations(ctx, s.db, s.log)
} }
func (s *Database) runMigrations(ctx context.Context) error { // collectMigrations reads the embedded schema directory and returns
// Create migrations tracking table // migration filenames sorted lexicographically.
_, err := s.db.ExecContext(ctx, ` func collectMigrations() ([]string, error) {
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
// Get list of migration files
entries, err := schemaFS.ReadDir("schema") entries, err := schemaFS.ReadDir("schema")
if err != nil { if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err) return nil, fmt.Errorf("failed to read schema directory: %w", err)
} }
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string var migrations []string
for _, entry := range entries { for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name()) migrations = append(migrations, entry.Name())
} }
} }
sort.Strings(migrations) sort.Strings(migrations)
// Apply each migration that hasn't been applied yet return migrations, nil
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if count > 0 {
s.log.Debug("migration already applied", "version", version)
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
s.log.Info("applying migration", "version", version)
_, err = s.db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
// Record migration as applied
_, err = s.db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
s.log.Info("migration applied successfully", "version", version)
}
return nil
} }
// DB returns the underlying sql.DB. // ensureMigrationsTable creates the schema_migrations tracking table if
func (s *Database) DB() *sql.DB { // it does not already exist.
return s.db func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
}
// ApplyMigrations applies all migrations to the given database.
// This is useful for testing where you want to use the real schema
// without the full fx lifecycle.
func ApplyMigrations(db *sql.DB) error {
ctx := context.Background()
// Create migrations tracking table
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations ( CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY, version TEXT PRIMARY KEY,
@@ -184,27 +156,32 @@ func ApplyMigrations(db *sql.DB) error {
return fmt.Errorf("failed to create migrations table: %w", err) return fmt.Errorf("failed to create migrations table: %w", err)
} }
// Get list of migration files return nil
entries, err := schemaFS.ReadDir("schema") }
// ApplyMigrations applies all pending migrations to db. An optional logger
// may be provided for informational output; pass nil for silent operation.
// This is exported so tests can apply the real schema without the full fx
// lifecycle.
func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
if err := ensureMigrationsTable(ctx, db); err != nil {
return err
}
migrations, err := collectMigrations()
if err != nil { if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err) return err
} }
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Apply each migration that hasn't been applied yet
for _, migration := range migrations { for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration)) version, parseErr := ParseMigrationVersion(migration)
if parseErr != nil {
return parseErr
}
// Check if already applied // Check if already applied.
var count int var count int
err := db.QueryRowContext(ctx, err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?", "SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version, version,
@@ -214,29 +191,46 @@ func ApplyMigrations(db *sql.DB) error {
} }
if count > 0 { if count > 0 {
if log != nil {
log.Debug("migration already applied", "version", version)
}
continue continue
} }
// Read and apply migration // Read and apply migration.
content, err := schemaFS.ReadFile(filepath.Join("schema", migration)) content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil { if readErr != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err) return fmt.Errorf("failed to read migration %s: %w", migration, readErr)
} }
_, err = db.ExecContext(ctx, string(content)) if log != nil {
if err != nil { log.Info("applying migration", "version", version)
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
} }
// Record migration as applied _, execErr := db.ExecContext(ctx, string(content))
_, err = db.ExecContext(ctx, if execErr != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, execErr)
}
// Record migration as applied.
_, recErr := db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)", "INSERT INTO schema_migrations (version) VALUES (?)",
version, version,
) )
if err != nil { if recErr != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err) return fmt.Errorf("failed to record migration %s: %w", migration, recErr)
}
if log != nil {
log.Info("migration applied successfully", "version", version)
} }
} }
return nil return nil
} }
// DB returns the underlying sql.DB.
func (s *Database) DB() *sql.DB {
return s.db
}

View File

@@ -0,0 +1,155 @@
package database
import (
"context"
"database/sql"
"testing"
_ "modernc.org/sqlite" // SQLite driver registration
)
func TestParseMigrationVersion(t *testing.T) {
tests := []struct {
name string
filename string
want string
wantErr bool
}{
{
name: "version only",
filename: "001.sql",
want: "001",
},
{
name: "version with description",
filename: "001_initial_schema.sql",
want: "001",
},
{
name: "multi-digit version",
filename: "042_add_indexes.sql",
want: "042",
},
{
name: "long version number",
filename: "00001_long_prefix.sql",
want: "00001",
},
{
name: "description with multiple underscores",
filename: "003_add_user_auth_tables.sql",
want: "003",
},
{
name: "empty filename",
filename: ".sql",
wantErr: true,
},
{
name: "leading underscore",
filename: "_description.sql",
wantErr: true,
},
{
name: "non-numeric version",
filename: "abc_migration.sql",
wantErr: true,
},
{
name: "mixed alphanumeric version",
filename: "001a_migration.sql",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseMigrationVersion(tt.filename)
if tt.wantErr {
if err == nil {
t.Errorf("ParseMigrationVersion(%q) expected error, got %q", tt.filename, got)
}
return
}
if err != nil {
t.Errorf("ParseMigrationVersion(%q) unexpected error: %v", tt.filename, err)
return
}
if got != tt.want {
t.Errorf("ParseMigrationVersion(%q) = %q, want %q", tt.filename, got, tt.want)
}
})
}
}
func TestApplyMigrations(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("failed to open in-memory database: %v", err)
}
defer db.Close()
// Apply migrations should succeed.
if err := ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("ApplyMigrations failed: %v", err)
}
// Verify the schema_migrations table recorded the version.
var version string
err = db.QueryRowContext(context.Background(),
"SELECT version FROM schema_migrations LIMIT 1",
).Scan(&version)
if err != nil {
t.Fatalf("failed to query schema_migrations: %v", err)
}
if version != "001" {
t.Errorf("expected version %q, got %q", "001", version)
}
// Verify a table from the migration exists (source_content).
var tableName string
err = db.QueryRowContext(context.Background(),
"SELECT name FROM sqlite_master WHERE type='table' AND name='source_content'",
).Scan(&tableName)
if err != nil {
t.Fatalf("expected source_content table to exist: %v", err)
}
}
func TestApplyMigrationsIdempotent(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("failed to open in-memory database: %v", err)
}
defer db.Close()
// Apply twice should succeed (idempotent).
if err := ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("first ApplyMigrations failed: %v", err)
}
if err := ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("second ApplyMigrations failed: %v", err)
}
// Should still have exactly one migration recorded.
var count int
err = db.QueryRowContext(context.Background(),
"SELECT COUNT(*) FROM schema_migrations",
).Scan(&count)
if err != nil {
t.Fatalf("failed to count schema_migrations: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migration record, got %d", count)
}
}

View File

@@ -82,7 +82,7 @@ func setupTestDB(t *testing.T) *sql.DB {
t.Fatalf("failed to open test db: %v", err) t.Fatalf("failed to open test db: %v", err)
} }
if err := database.ApplyMigrations(db); err != nil { if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("failed to apply migrations: %v", err) t.Fatalf("failed to apply migrations: %v", err)
} }

View File

@@ -44,11 +44,20 @@ type ImageProcessor struct {
maxInputBytes int64 maxInputBytes int64
} }
// NewImageProcessor creates a new image processor with the given maximum input // Params holds configuration for creating an ImageProcessor.
// size in bytes. If maxInputBytes is <= 0, DefaultMaxInputBytes is used. // Zero values use sensible defaults (MaxInputBytes defaults to DefaultMaxInputBytes).
func NewImageProcessor(maxInputBytes int64) *ImageProcessor { type Params struct {
// MaxInputBytes is the maximum allowed input size in bytes.
// If <= 0, DefaultMaxInputBytes is used.
MaxInputBytes int64
}
// New creates a new image processor with the given parameters.
// A zero-value Params{} uses sensible defaults.
func New(params Params) *ImageProcessor {
initVips() initVips()
maxInputBytes := params.MaxInputBytes
if maxInputBytes <= 0 { if maxInputBytes <= 0 {
maxInputBytes = DefaultMaxInputBytes maxInputBytes = DefaultMaxInputBytes
} }

View File

@@ -71,7 +71,7 @@ func createTestPNG(t *testing.T, width, height int) []byte {
} }
func TestImageProcessor_ResizeJPEG(t *testing.T) { func TestImageProcessor_ResizeJPEG(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 800, 600) input := createTestJPEG(t, 800, 600)
@@ -118,7 +118,7 @@ func TestImageProcessor_ResizeJPEG(t *testing.T) {
} }
func TestImageProcessor_ConvertToPNG(t *testing.T) { func TestImageProcessor_ConvertToPNG(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 200, 150) input := createTestJPEG(t, 200, 150)
@@ -151,7 +151,7 @@ func TestImageProcessor_ConvertToPNG(t *testing.T) {
} }
func TestImageProcessor_OriginalSize(t *testing.T) { func TestImageProcessor_OriginalSize(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 640, 480) input := createTestJPEG(t, 640, 480)
@@ -179,7 +179,7 @@ func TestImageProcessor_OriginalSize(t *testing.T) {
} }
func TestImageProcessor_FitContain(t *testing.T) { func TestImageProcessor_FitContain(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
// 800x400 image (2:1 aspect) into 400x400 box with contain // 800x400 image (2:1 aspect) into 400x400 box with contain
@@ -206,7 +206,7 @@ func TestImageProcessor_FitContain(t *testing.T) {
} }
func TestImageProcessor_ProportionalScale_WidthOnly(t *testing.T) { func TestImageProcessor_ProportionalScale_WidthOnly(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
// 800x600 image, request width=400 height=0 // 800x600 image, request width=400 height=0
@@ -236,7 +236,7 @@ func TestImageProcessor_ProportionalScale_WidthOnly(t *testing.T) {
} }
func TestImageProcessor_ProportionalScale_HeightOnly(t *testing.T) { func TestImageProcessor_ProportionalScale_HeightOnly(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
// 800x600 image, request width=0 height=300 // 800x600 image, request width=0 height=300
@@ -266,7 +266,7 @@ func TestImageProcessor_ProportionalScale_HeightOnly(t *testing.T) {
} }
func TestImageProcessor_ProcessPNG(t *testing.T) { func TestImageProcessor_ProcessPNG(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
input := createTestPNG(t, 400, 300) input := createTestPNG(t, 400, 300)
@@ -298,7 +298,7 @@ func TestImageProcessor_ImplementsInterface(t *testing.T) {
} }
func TestImageProcessor_SupportedFormats(t *testing.T) { func TestImageProcessor_SupportedFormats(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
inputFormats := proc.SupportedInputFormats() inputFormats := proc.SupportedInputFormats()
if len(inputFormats) == 0 { if len(inputFormats) == 0 {
@@ -312,7 +312,7 @@ func TestImageProcessor_SupportedFormats(t *testing.T) {
} }
func TestImageProcessor_RejectsOversizedInput(t *testing.T) { func TestImageProcessor_RejectsOversizedInput(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
// Create an image that exceeds MaxInputDimension (e.g., 10000x100) // Create an image that exceeds MaxInputDimension (e.g., 10000x100)
@@ -337,7 +337,7 @@ func TestImageProcessor_RejectsOversizedInput(t *testing.T) {
} }
func TestImageProcessor_RejectsOversizedInputHeight(t *testing.T) { func TestImageProcessor_RejectsOversizedInputHeight(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
// Create an image with oversized height // Create an image with oversized height
@@ -361,7 +361,7 @@ func TestImageProcessor_RejectsOversizedInputHeight(t *testing.T) {
} }
func TestImageProcessor_AcceptsMaxDimensionInput(t *testing.T) { func TestImageProcessor_AcceptsMaxDimensionInput(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
// Create an image at exactly MaxInputDimension - should be accepted // Create an image at exactly MaxInputDimension - should be accepted
@@ -383,7 +383,7 @@ func TestImageProcessor_AcceptsMaxDimensionInput(t *testing.T) {
} }
func TestImageProcessor_EncodeWebP(t *testing.T) { func TestImageProcessor_EncodeWebP(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 200, 150) input := createTestJPEG(t, 200, 150)
@@ -426,7 +426,7 @@ func TestImageProcessor_EncodeWebP(t *testing.T) {
} }
func TestImageProcessor_DecodeAVIF(t *testing.T) { func TestImageProcessor_DecodeAVIF(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
// Load test AVIF file // Load test AVIF file
@@ -468,7 +468,7 @@ func TestImageProcessor_DecodeAVIF(t *testing.T) {
func TestImageProcessor_RejectsOversizedInputData(t *testing.T) { func TestImageProcessor_RejectsOversizedInputData(t *testing.T) {
// Create a processor with a very small byte limit // Create a processor with a very small byte limit
const limit = 1024 const limit = 1024
proc := NewImageProcessor(limit) proc := New(Params{MaxInputBytes: limit})
ctx := context.Background() ctx := context.Background()
// Create a valid JPEG that exceeds the byte limit // Create a valid JPEG that exceeds the byte limit
@@ -499,7 +499,7 @@ func TestImageProcessor_AcceptsInputWithinLimit(t *testing.T) {
input := createTestJPEG(t, 10, 10) input := createTestJPEG(t, 10, 10)
limit := int64(len(input)) * 10 // 10× headroom limit := int64(len(input)) * 10 // 10× headroom
proc := NewImageProcessor(limit) proc := New(Params{MaxInputBytes: limit})
ctx := context.Background() ctx := context.Background()
req := &ImageRequest{ req := &ImageRequest{
@@ -518,20 +518,20 @@ func TestImageProcessor_AcceptsInputWithinLimit(t *testing.T) {
func TestImageProcessor_DefaultMaxInputBytes(t *testing.T) { func TestImageProcessor_DefaultMaxInputBytes(t *testing.T) {
// Passing 0 should use the default // Passing 0 should use the default
proc := NewImageProcessor(0) proc := New(Params{})
if proc.maxInputBytes != DefaultMaxInputBytes { if proc.maxInputBytes != DefaultMaxInputBytes {
t.Errorf("maxInputBytes = %d, want %d", proc.maxInputBytes, DefaultMaxInputBytes) t.Errorf("maxInputBytes = %d, want %d", proc.maxInputBytes, DefaultMaxInputBytes)
} }
// Passing negative should also use the default // Passing negative should also use the default
proc = NewImageProcessor(-1) proc = New(Params{MaxInputBytes: -1})
if proc.maxInputBytes != DefaultMaxInputBytes { if proc.maxInputBytes != DefaultMaxInputBytes {
t.Errorf("maxInputBytes = %d, want %d", proc.maxInputBytes, DefaultMaxInputBytes) t.Errorf("maxInputBytes = %d, want %d", proc.maxInputBytes, DefaultMaxInputBytes)
} }
} }
func TestImageProcessor_EncodeAVIF(t *testing.T) { func TestImageProcessor_EncodeAVIF(t *testing.T) {
proc := NewImageProcessor(0) proc := New(Params{})
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 200, 150) input := createTestJPEG(t, 200, 150)

View File

@@ -82,7 +82,7 @@ func NewService(cfg *ServiceConfig) (*Service, error) {
return &Service{ return &Service{
cache: cfg.Cache, cache: cfg.Cache,
fetcher: fetcher, fetcher: fetcher,
processor: NewImageProcessor(maxResponseSize), processor: New(Params{MaxInputBytes: maxResponseSize}),
signer: signer, signer: signer,
whitelist: NewHostWhitelist(cfg.Whitelist), whitelist: NewHostWhitelist(cfg.Whitelist),
log: log, log: log,

View File

@@ -16,7 +16,7 @@ func setupStatsTestDB(t *testing.T) *sql.DB {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := database.ApplyMigrations(db); err != nil { if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(func() { db.Close() }) t.Cleanup(func() { db.Close() })

View File

@@ -2,6 +2,7 @@ package imgcache
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"image" "image"
"image/color" "image/color"
@@ -193,7 +194,7 @@ func setupServiceTestDB(t *testing.T) *sql.DB {
} }
// Use the real production schema via migrations // Use the real production schema via migrations
if err := database.ApplyMigrations(db); err != nil { if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("failed to apply migrations: %v", err) t.Fatalf("failed to apply migrations: %v", err)
} }