3 Commits

Author SHA1 Message Date
user
55bb620de0 docs: update README and config to reflect exact-match-only whitelist
All checks were successful
check / check (push) Successful in 5s
Remove suffix match documentation and config comments since whitelist
now only supports exact host matches.
2026-03-15 11:18:44 -07:00
user
215ddb7f72 fix: remove suffix matching from host whitelist
Whitelist entries now support exact host matches only. Leading dots
in patterns are stripped for backwards compatibility (.example.com
becomes an exact match for example.com). Suffix matching that would
match arbitrary subdomains is no longer supported.

Closes #27
2026-03-15 11:18:25 -07:00
user
27739da046 test: add failing tests for removing suffix matching from whitelist
Suffix matching (.example.com matching subdomains) should not be
supported. Whitelist entries should be exact host matches only.
Leading dots should be stripped and treated as exact matches.
2026-03-15 11:18:01 -07:00
32 changed files with 475 additions and 1511 deletions

View File

@@ -67,10 +67,7 @@ hosts require an HMAC-SHA256 signature.
#### Signature Specification #### Signature Specification
Signatures use HMAC-SHA256 and include an expiration timestamp to Signatures use HMAC-SHA256 and include an expiration timestamp to
prevent replay attacks. Signatures are **exact match only**: every prevent replay attacks.
component (host, path, query, dimensions, format, expiration) must
match exactly what was signed. No suffix matching, wildcard matching,
or partial matching is supported.
**Signed data format** (colon-separated): **Signed data format** (colon-separated):
@@ -101,9 +98,7 @@ expiration 1704067200:
**Whitelist patterns:** **Whitelist patterns:**
- **Exact match**: `cdn.example.com` — matches only that host - **Exact match only**: `cdn.example.com` — matches only that host
- **Suffix match**: `.example.com` — matches `cdn.example.com`,
`images.example.com`, and `example.com`
### Configuration ### Configuration

View File

@@ -17,7 +17,10 @@ import (
"sneak.berlin/go/pixa/internal/server" "sneak.berlin/go/pixa/internal/server"
) )
var Version string //nolint:gochecknoglobals // set by ldflags var (
Appname = "pixad" //nolint:gochecknoglobals // set by ldflags
Version string //nolint:gochecknoglobals // set by ldflags
)
var configPath string //nolint:gochecknoglobals // cobra flag var configPath string //nolint:gochecknoglobals // cobra flag
@@ -37,6 +40,7 @@ func main() {
} }
func run(_ *cobra.Command, _ []string) { func run(_ *cobra.Command, _ []string) {
globals.Appname = Appname
globals.Version = Version globals.Version = Version
// Set config path in environment if specified via flag // Set config path in environment if specified via flag

View File

@@ -13,8 +13,7 @@ state_dir: ./data
# Generate with: openssl rand -base64 32 # Generate with: openssl rand -base64 32
signing_key: "CHANGE_ME_generate_with_openssl_rand_base64_32" signing_key: "CHANGE_ME_generate_with_openssl_rand_base64_32"
# Hosts that don't require signatures # Hosts that don't require signatures (exact match only)
# Use "." prefix for wildcard subdomain matching (e.g., ".example.com" matches "cdn.example.com")
whitelist_hosts: whitelist_hosts:
- s3.sneak.cloud - s3.sneak.cloud
- static.sneak.cloud - static.sneak.cloud

View File

@@ -1,83 +0,0 @@
// Package allowlist provides host-based URL allow-listing for the image proxy.
package allowlist
import (
"net/url"
"strings"
)
// HostAllowList checks whether source hosts are permitted.
type HostAllowList struct {
// exactHosts contains hosts that must match exactly (e.g., "cdn.example.com")
exactHosts map[string]struct{}
// suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com")
suffixHosts []string
}
// New creates a HostAllowList from a list of host patterns.
// Patterns starting with "." are treated as suffix matches.
// Examples:
// - "cdn.example.com" - exact match only
// - ".example.com" - matches cdn.example.com, images.example.com, etc.
func New(patterns []string) *HostAllowList {
w := &HostAllowList{
exactHosts: make(map[string]struct{}),
suffixHosts: make([]string, 0),
}
for _, pattern := range patterns {
pattern = strings.ToLower(strings.TrimSpace(pattern))
if pattern == "" {
continue
}
if strings.HasPrefix(pattern, ".") {
w.suffixHosts = append(w.suffixHosts, pattern)
} else {
w.exactHosts[pattern] = struct{}{}
}
}
return w
}
// IsAllowed checks if a URL's host is in the allow list.
func (w *HostAllowList) IsAllowed(u *url.URL) bool {
if u == nil {
return false
}
host := strings.ToLower(u.Hostname())
if host == "" {
return false
}
// Check exact match
if _, ok := w.exactHosts[host]; ok {
return true
}
// Check suffix match
for _, suffix := range w.suffixHosts {
if strings.HasSuffix(host, suffix) {
return true
}
// Also match if host equals the suffix without the leading dot
// e.g., pattern ".example.com" should match "example.com"
if host == strings.TrimPrefix(suffix, ".") {
return true
}
}
return false
}
// IsEmpty returns true if the allow list has no entries.
func (w *HostAllowList) IsEmpty() bool {
return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0
}
// Count returns the total number of allow list entries.
func (w *HostAllowList) Count() int {
return len(w.exactHosts) + len(w.suffixHosts)
}

View File

@@ -9,7 +9,6 @@ import (
"log/slog" "log/slog"
"path/filepath" "path/filepath"
"sort" "sort"
"strconv"
"strings" "strings"
"go.uber.org/fx" "go.uber.org/fx"
@@ -22,10 +21,6 @@ import (
//go:embed schema/*.sql //go:embed schema/*.sql
var schemaFS embed.FS var schemaFS embed.FS
// bootstrapVersion is the migration that creates the schema_migrations
// table itself. It is applied before the normal migration loop.
const bootstrapVersion = 0
// Params defines dependencies for Database. // Params defines dependencies for Database.
type Params struct { type Params struct {
fx.In fx.In
@@ -40,46 +35,6 @@ type Database struct {
config *config.Config config *config.Config
} }
// ParseMigrationVersion extracts the numeric version prefix from a migration
// filename. Filenames must follow the pattern "<version>.sql" or
// "<version>_<description>.sql", where version is a zero-padded numeric
// string (e.g. "001", "002"). Returns the version as an integer and an
// error if the filename does not match the expected pattern.
func ParseMigrationVersion(filename string) (int, error) {
name := strings.TrimSuffix(filename, filepath.Ext(filename))
if name == "" {
return 0, fmt.Errorf("invalid migration filename %q: empty name", filename)
}
// Split on underscore to separate version from description.
// If there's no underscore, the entire stem is the version.
versionStr := name
if idx := strings.IndexByte(name, '_'); idx >= 0 {
versionStr = name[:idx]
}
if versionStr == "" {
return 0, fmt.Errorf("invalid migration filename %q: empty version prefix", filename)
}
// Validate the version is purely numeric.
for _, ch := range versionStr {
if ch < '0' || ch > '9' {
return 0, fmt.Errorf(
"invalid migration filename %q: version %q contains non-numeric character %q",
filename, versionStr, string(ch),
)
}
}
version, err := strconv.Atoi(versionStr)
if err != nil {
return 0, fmt.Errorf("invalid migration filename %q: %w", filename, err)
}
return version, nil
}
// New creates a new Database instance. // New creates a new Database instance.
func New(lc fx.Lifecycle, params Params) (*Database, error) { func New(lc fx.Lifecycle, params Params) (*Database, error) {
s := &Database{ s := &Database{
@@ -129,87 +84,43 @@ func (s *Database) connect(ctx context.Context) error {
s.db = db s.db = db
s.log.Info("database connected") s.log.Info("database connected")
return ApplyMigrations(ctx, s.db, s.log) return s.runMigrations(ctx)
} }
// collectMigrations reads the embedded schema directory and returns func (s *Database) runMigrations(ctx context.Context) error {
// migration filenames sorted lexicographically. // Create migrations tracking table
func collectMigrations() ([]string, error) { _, err := s.db.ExecContext(ctx, `
entries, err := schemaFS.ReadDir("schema") CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read schema directory: %w", err) return fmt.Errorf("failed to create migrations table: %w", err)
} }
var migrations []string // Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries { for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name()) migrations = append(migrations, entry.Name())
} }
} }
sort.Strings(migrations) sort.Strings(migrations)
return migrations, nil // Apply each migration that hasn't been applied yet
}
// bootstrapMigrationsTable ensures the schema_migrations table exists
// by applying 000.sql if the table is missing.
func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error {
var tableExists int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
).Scan(&tableExists)
if err != nil {
return fmt.Errorf("failed to check for migrations table: %w", err)
}
if tableExists > 0 {
return nil
}
content, err := schemaFS.ReadFile("schema/000.sql")
if err != nil {
return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err)
}
if log != nil {
log.Info("applying bootstrap migration", "version", bootstrapVersion)
}
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply bootstrap migration: %w", err)
}
return nil
}
// ApplyMigrations applies all pending migrations to db. An optional logger
// may be provided for informational output; pass nil for silent operation.
// This is exported so tests can apply the real schema without the full fx
// lifecycle.
func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
if err := bootstrapMigrationsTable(ctx, db, log); err != nil {
return err
}
migrations, err := collectMigrations()
if err != nil {
return err
}
for _, migration := range migrations { for _, migration := range migrations {
version, parseErr := ParseMigrationVersion(migration) version := strings.TrimSuffix(migration, filepath.Ext(migration))
if parseErr != nil {
return parseErr
}
// Check if already applied. // Check if already applied
var count int var count int
err := s.db.QueryRowContext(ctx,
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?", "SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version, version,
).Scan(&count) ).Scan(&count)
@@ -218,40 +129,34 @@ func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
} }
if count > 0 { if count > 0 {
if log != nil { s.log.Debug("migration already applied", "version", version)
log.Debug("migration already applied", "version", version)
}
continue continue
} }
// Read and apply migration. // Read and apply migration
content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration)) content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if readErr != nil { if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, readErr) return fmt.Errorf("failed to read migration %s: %w", migration, err)
} }
if log != nil { s.log.Info("applying migration", "version", version)
log.Info("applying migration", "version", version)
_, err = s.db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
} }
_, execErr := db.ExecContext(ctx, string(content)) // Record migration as applied
if execErr != nil { _, err = s.db.ExecContext(ctx,
return fmt.Errorf("failed to apply migration %s: %w", migration, execErr)
}
// Record migration as applied.
_, recErr := db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)", "INSERT INTO schema_migrations (version) VALUES (?)",
version, version,
) )
if recErr != nil { if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, recErr) return fmt.Errorf("failed to record migration %s: %w", migration, err)
} }
if log != nil { s.log.Info("migration applied successfully", "version", version)
log.Info("migration applied successfully", "version", version)
}
} }
return nil return nil
@@ -261,3 +166,77 @@ func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
func (s *Database) DB() *sql.DB { func (s *Database) DB() *sql.DB {
return s.db return s.db
} }
// ApplyMigrations applies all migrations to the given database.
// This is useful for testing where you want to use the real schema
// without the full fx lifecycle.
func ApplyMigrations(db *sql.DB) error {
ctx := context.Background()
// Create migrations tracking table
_, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if count > 0 {
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
// Record migration as applied
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
}
return nil
}

View File

@@ -1,224 +0,0 @@
package database
import (
"context"
"database/sql"
"testing"
_ "modernc.org/sqlite" // SQLite driver registration
)
// openTestDB returns a fresh in-memory SQLite database.
func openTestDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("failed to open test db: %v", err)
}
t.Cleanup(func() { db.Close() })
return db
}
func TestParseMigrationVersion(t *testing.T) {
tests := []struct {
name string
filename string
want int
wantErr bool
}{
{
name: "version only",
filename: "001.sql",
want: 1,
},
{
name: "version with description",
filename: "001_initial_schema.sql",
want: 1,
},
{
name: "multi-digit version",
filename: "042_add_indexes.sql",
want: 42,
},
{
name: "long version number",
filename: "00001_long_prefix.sql",
want: 1,
},
{
name: "description with multiple underscores",
filename: "003_add_user_auth_tables.sql",
want: 3,
},
{
name: "empty filename",
filename: ".sql",
wantErr: true,
},
{
name: "leading underscore",
filename: "_description.sql",
wantErr: true,
},
{
name: "non-numeric version",
filename: "abc_migration.sql",
wantErr: true,
},
{
name: "mixed alphanumeric version",
filename: "001a_migration.sql",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseMigrationVersion(tt.filename)
if tt.wantErr {
if err == nil {
t.Errorf("ParseMigrationVersion(%q) expected error, got %d", tt.filename, got)
}
return
}
if err != nil {
t.Errorf("ParseMigrationVersion(%q) unexpected error: %v", tt.filename, err)
return
}
if got != tt.want {
t.Errorf("ParseMigrationVersion(%q) = %d, want %d", tt.filename, got, tt.want)
}
})
}
}
func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
if err := ApplyMigrations(ctx, db, nil); err != nil {
t.Fatalf("ApplyMigrations failed: %v", err)
}
// The schema_migrations table must exist and contain at least
// version 0 (the bootstrap) and 1 (the initial schema).
rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version")
if err != nil {
t.Fatalf("failed to query schema_migrations: %v", err)
}
defer rows.Close()
var versions []int
for rows.Next() {
var v int
if err := rows.Scan(&v); err != nil {
t.Fatalf("failed to scan version: %v", err)
}
versions = append(versions, v)
}
if err := rows.Err(); err != nil {
t.Fatalf("row iteration error: %v", err)
}
if len(versions) < 2 {
t.Fatalf("expected at least 2 migrations recorded, got %d: %v", len(versions), versions)
}
if versions[0] != 0 {
t.Errorf("first recorded migration = %d, want %d", versions[0], 0)
}
if versions[1] != 1 {
t.Errorf("second recorded migration = %d, want %d", versions[1], 1)
}
// Verify that the application tables created by 001.sql exist.
for _, table := range []string{"source_content", "source_metadata", "output_content", "request_cache", "negative_cache", "cache_stats"} {
var count int
err := db.QueryRow(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
table,
).Scan(&count)
if err != nil {
t.Fatalf("failed to check for table %s: %v", table, err)
}
if count != 1 {
t.Errorf("table %s does not exist after migrations", table)
}
}
}
func TestApplyMigrations_Idempotent(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
if err := ApplyMigrations(ctx, db, nil); err != nil {
t.Fatalf("first ApplyMigrations failed: %v", err)
}
// Running a second time must succeed without errors.
if err := ApplyMigrations(ctx, db, nil); err != nil {
t.Fatalf("second ApplyMigrations failed: %v", err)
}
// Verify no duplicate rows in schema_migrations.
var count int
err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = 0").Scan(&count)
if err != nil {
t.Fatalf("failed to count version 0 rows: %v", err)
}
if count != 1 {
t.Errorf("expected exactly 1 row for version 0, got %d", count)
}
}
func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
if err := bootstrapMigrationsTable(ctx, db, nil); err != nil {
t.Fatalf("bootstrapMigrationsTable failed: %v", err)
}
// schema_migrations table must exist.
var tableCount int
err := db.QueryRow(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
).Scan(&tableCount)
if err != nil {
t.Fatalf("failed to check for table: %v", err)
}
if tableCount != 1 {
t.Fatalf("schema_migrations table not created")
}
// Version 0 must be recorded.
var recorded int
err = db.QueryRow(
"SELECT COUNT(*) FROM schema_migrations WHERE version = 0",
).Scan(&recorded)
if err != nil {
t.Fatalf("failed to check version: %v", err)
}
if recorded != 1 {
t.Errorf("expected version 0 to be recorded, got count %d", recorded)
}
}

View File

@@ -1,9 +0,0 @@
-- Migration 000: Schema migrations tracking table
-- Applied as a bootstrap step before the normal migration loop.
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
INSERT OR IGNORE INTO schema_migrations (version) VALUES (0);

View File

@@ -5,10 +5,11 @@ import (
"go.uber.org/fx" "go.uber.org/fx"
) )
const appname = "pixad" // Build-time variables populated from main() via ldflags.
var (
// Version is populated from main() via ldflags. Appname string //nolint:gochecknoglobals // set from main
var Version string //nolint:gochecknoglobals // set from main Version string //nolint:gochecknoglobals // set from main
)
// Globals holds application-wide constants. // Globals holds application-wide constants.
type Globals struct { type Globals struct {
@@ -19,7 +20,7 @@ type Globals struct {
// New creates a new Globals instance from build-time variables. // New creates a new Globals instance from build-time variables.
func New(_ fx.Lifecycle) (*Globals, error) { func New(_ fx.Lifecycle) (*Globals, error) {
return &Globals{ return &Globals{
Appname: appname, Appname: Appname,
Version: Version, Version: Version,
}, nil }, nil
} }

View File

@@ -13,7 +13,6 @@ import (
"sneak.berlin/go/pixa/internal/database" "sneak.berlin/go/pixa/internal/database"
"sneak.berlin/go/pixa/internal/encurl" "sneak.berlin/go/pixa/internal/encurl"
"sneak.berlin/go/pixa/internal/healthcheck" "sneak.berlin/go/pixa/internal/healthcheck"
"sneak.berlin/go/pixa/internal/httpfetcher"
"sneak.berlin/go/pixa/internal/imgcache" "sneak.berlin/go/pixa/internal/imgcache"
"sneak.berlin/go/pixa/internal/logger" "sneak.berlin/go/pixa/internal/logger"
"sneak.berlin/go/pixa/internal/session" "sneak.berlin/go/pixa/internal/session"
@@ -73,7 +72,7 @@ func (s *Handlers) initImageService() error {
s.imgCache = cache s.imgCache = cache
// Create the fetcher config // Create the fetcher config
fetcherCfg := httpfetcher.DefaultConfig() fetcherCfg := imgcache.DefaultFetcherConfig()
fetcherCfg.AllowHTTP = s.config.AllowHTTP fetcherCfg.AllowHTTP = s.config.AllowHTTP
if s.config.UpstreamConnectionsPerHost > 0 { if s.config.UpstreamConnectionsPerHost > 0 {
fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost

View File

@@ -18,7 +18,6 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"sneak.berlin/go/pixa/internal/database" "sneak.berlin/go/pixa/internal/database"
"sneak.berlin/go/pixa/internal/httpfetcher"
"sneak.berlin/go/pixa/internal/imgcache" "sneak.berlin/go/pixa/internal/imgcache"
) )
@@ -83,7 +82,7 @@ func setupTestDB(t *testing.T) *sql.DB {
t.Fatalf("failed to open test db: %v", err) t.Fatalf("failed to open test db: %v", err)
} }
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil { if err := database.ApplyMigrations(db); err != nil {
t.Fatalf("failed to apply migrations: %v", err) t.Fatalf("failed to apply migrations: %v", err)
} }
@@ -117,16 +116,16 @@ func newMockFetcher(fs fs.FS) *mockFetcher {
return &mockFetcher{fs: fs} return &mockFetcher{fs: fs}
} }
func (f *mockFetcher) Fetch(ctx context.Context, url string) (*httpfetcher.FetchResult, error) { func (f *mockFetcher) Fetch(ctx context.Context, url string) (*imgcache.FetchResult, error) {
// Remove https:// prefix // Remove https:// prefix
path := url[8:] // Remove "https://" path := url[8:] // Remove "https://"
data, err := fs.ReadFile(f.fs, path) data, err := fs.ReadFile(f.fs, path)
if err != nil { if err != nil {
return nil, httpfetcher.ErrUpstreamError return nil, imgcache.ErrUpstreamError
} }
return &httpfetcher.FetchResult{ return &imgcache.FetchResult{
Content: io.NopCloser(bytes.NewReader(data)), Content: io.NopCloser(bytes.NewReader(data)),
ContentLength: int64(len(data)), ContentLength: int64(len(data)),
ContentType: "image/jpeg", ContentType: "image/jpeg",

View File

@@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"sneak.berlin/go/pixa/internal/httpfetcher"
"sneak.berlin/go/pixa/internal/imgcache" "sneak.berlin/go/pixa/internal/imgcache"
) )
@@ -98,13 +97,13 @@ func (s *Handlers) HandleImage() http.HandlerFunc {
) )
// Check for specific error types // Check for specific error types
if errors.Is(err, httpfetcher.ErrSSRFBlocked) { if errors.Is(err, imgcache.ErrSSRFBlocked) {
s.respondError(w, "forbidden", http.StatusForbidden) s.respondError(w, "forbidden", http.StatusForbidden)
return return
} }
if errors.Is(err, httpfetcher.ErrUpstreamError) { if errors.Is(err, imgcache.ErrUpstreamError) {
s.respondError(w, "upstream error", http.StatusBadGateway) s.respondError(w, "upstream error", http.StatusBadGateway)
return return

View File

@@ -11,7 +11,6 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"sneak.berlin/go/pixa/internal/encurl" "sneak.berlin/go/pixa/internal/encurl"
"sneak.berlin/go/pixa/internal/httpfetcher"
"sneak.berlin/go/pixa/internal/imgcache" "sneak.berlin/go/pixa/internal/imgcache"
) )
@@ -101,11 +100,11 @@ func (s *Handlers) HandleImageEnc() http.HandlerFunc {
// handleImageError converts image service errors to HTTP responses. // handleImageError converts image service errors to HTTP responses.
func (s *Handlers) handleImageError(w http.ResponseWriter, err error) { func (s *Handlers) handleImageError(w http.ResponseWriter, err error) {
switch { switch {
case errors.Is(err, httpfetcher.ErrSSRFBlocked): case errors.Is(err, imgcache.ErrSSRFBlocked):
s.respondError(w, "forbidden", http.StatusForbidden) s.respondError(w, "forbidden", http.StatusForbidden)
case errors.Is(err, httpfetcher.ErrUpstreamError): case errors.Is(err, imgcache.ErrUpstreamError):
s.respondError(w, "upstream error", http.StatusBadGateway) s.respondError(w, "upstream error", http.StatusBadGateway)
case errors.Is(err, httpfetcher.ErrUpstreamTimeout): case errors.Is(err, imgcache.ErrUpstreamTimeout):
s.respondError(w, "upstream timeout", http.StatusGatewayTimeout) s.respondError(w, "upstream timeout", http.StatusGatewayTimeout)
default: default:
s.log.Error("image request failed", "error", err) s.log.Error("image request failed", "error", err)

View File

@@ -1,329 +0,0 @@
package httpfetcher
import (
"context"
"errors"
"io"
"net"
"testing"
"testing/fstest"
)
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
if cfg.Timeout != DefaultFetchTimeout {
t.Errorf("Timeout = %v, want %v", cfg.Timeout, DefaultFetchTimeout)
}
if cfg.MaxResponseSize != DefaultMaxResponseSize {
t.Errorf("MaxResponseSize = %d, want %d", cfg.MaxResponseSize, DefaultMaxResponseSize)
}
if cfg.MaxConnectionsPerHost != DefaultMaxConnectionsPerHost {
t.Errorf("MaxConnectionsPerHost = %d, want %d",
cfg.MaxConnectionsPerHost, DefaultMaxConnectionsPerHost)
}
if cfg.AllowHTTP {
t.Error("AllowHTTP should default to false")
}
if len(cfg.AllowedContentTypes) == 0 {
t.Error("AllowedContentTypes should not be empty")
}
}
func TestNewWithNilConfigUsesDefaults(t *testing.T) {
f := New(nil)
if f == nil {
t.Fatal("New(nil) returned nil")
}
if f.config == nil {
t.Fatal("config should be populated from DefaultConfig")
}
if f.config.Timeout != DefaultFetchTimeout {
t.Errorf("Timeout = %v, want %v", f.config.Timeout, DefaultFetchTimeout)
}
}
func TestIsAllowedContentType(t *testing.T) {
f := New(DefaultConfig())
tests := []struct {
contentType string
want bool
}{
{"image/jpeg", true},
{"image/png", true},
{"image/webp", true},
{"image/jpeg; charset=utf-8", true},
{"IMAGE/JPEG", true},
{"text/html", false},
{"application/octet-stream", false},
{"", false},
}
for _, tc := range tests {
t.Run(tc.contentType, func(t *testing.T) {
got := f.isAllowedContentType(tc.contentType)
if got != tc.want {
t.Errorf("isAllowedContentType(%q) = %v, want %v", tc.contentType, got, tc.want)
}
})
}
}
func TestExtractHost(t *testing.T) {
tests := []struct {
url string
want string
}{
{"https://example.com/path", "example.com"},
{"http://example.com:8080/path", "example.com:8080"},
{"https://example.com", "example.com"},
{"https://example.com?q=1", "example.com"},
{"example.com/path", "example.com"},
{"", ""},
}
for _, tc := range tests {
t.Run(tc.url, func(t *testing.T) {
got := extractHost(tc.url)
if got != tc.want {
t.Errorf("extractHost(%q) = %q, want %q", tc.url, got, tc.want)
}
})
}
}
func TestIsLocalhost(t *testing.T) {
tests := []struct {
host string
want bool
}{
{"localhost", true},
{"LOCALHOST", true},
{"127.0.0.1", true},
{"::1", true},
{"[::1]", true},
{"foo.localhost", true},
{"foo.local", true},
{"example.com", false},
{"127.0.0.2", false}, // Handled by isPrivateIP, not isLocalhost string match
}
for _, tc := range tests {
t.Run(tc.host, func(t *testing.T) {
got := isLocalhost(tc.host)
if got != tc.want {
t.Errorf("isLocalhost(%q) = %v, want %v", tc.host, got, tc.want)
}
})
}
}
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
ip string
want bool
}{
{"127.0.0.1", true}, // loopback
{"10.0.0.1", true}, // private
{"192.168.1.1", true}, // private
{"172.16.0.1", true}, // private
{"169.254.1.1", true}, // link-local
{"0.0.0.0", true}, // unspecified
{"224.0.0.1", true}, // multicast
{"::1", true}, // IPv6 loopback
{"fe80::1", true}, // IPv6 link-local
{"8.8.8.8", false}, // public
{"2001:4860:4860::8888", false}, // public IPv6
}
for _, tc := range tests {
t.Run(tc.ip, func(t *testing.T) {
ip := net.ParseIP(tc.ip)
if ip == nil {
t.Fatalf("failed to parse IP %q", tc.ip)
}
got := isPrivateIP(ip)
if got != tc.want {
t.Errorf("isPrivateIP(%q) = %v, want %v", tc.ip, got, tc.want)
}
})
}
if !isPrivateIP(nil) {
t.Error("isPrivateIP(nil) should return true")
}
}
func TestValidateURL_RejectsNonHTTPS(t *testing.T) {
err := validateURL("http://example.com/path", false)
if !errors.Is(err, ErrUnsupportedScheme) {
t.Errorf("validateURL http = %v, want ErrUnsupportedScheme", err)
}
}
func TestValidateURL_AllowsHTTPWhenConfigured(t *testing.T) {
// Use a host that won't resolve (explicit .invalid TLD) so we don't hit DNS.
err := validateURL("http://nonexistent.invalid/path", true)
// We expect a host resolution error, not ErrUnsupportedScheme.
if errors.Is(err, ErrUnsupportedScheme) {
t.Error("validateURL with AllowHTTP should not return ErrUnsupportedScheme")
}
}
func TestValidateURL_RejectsLocalhost(t *testing.T) {
err := validateURL("https://localhost/path", false)
if !errors.Is(err, ErrSSRFBlocked) {
t.Errorf("validateURL localhost = %v, want ErrSSRFBlocked", err)
}
}
func TestValidateURL_EmptyHost(t *testing.T) {
err := validateURL("https:///path", false)
if !errors.Is(err, ErrInvalidHost) {
t.Errorf("validateURL empty host = %v, want ErrInvalidHost", err)
}
}
func TestMockFetcher_FetchesFile(t *testing.T) {
mockFS := fstest.MapFS{
"example.com/images/photo.jpg": &fstest.MapFile{Data: []byte("fake-jpeg-data")},
}
m := NewMock(mockFS)
result, err := m.Fetch(context.Background(), "https://example.com/images/photo.jpg")
if err != nil {
t.Fatalf("Fetch() error = %v", err)
}
defer func() { _ = result.Content.Close() }()
if result.ContentType != "image/jpeg" {
t.Errorf("ContentType = %q, want image/jpeg", result.ContentType)
}
data, err := io.ReadAll(result.Content)
if err != nil {
t.Fatalf("read content: %v", err)
}
if string(data) != "fake-jpeg-data" {
t.Errorf("Content = %q, want %q", string(data), "fake-jpeg-data")
}
if result.ContentLength != int64(len("fake-jpeg-data")) {
t.Errorf("ContentLength = %d, want %d", result.ContentLength, len("fake-jpeg-data"))
}
}
func TestMockFetcher_MissingFileReturnsUpstreamError(t *testing.T) {
mockFS := fstest.MapFS{}
m := NewMock(mockFS)
_, err := m.Fetch(context.Background(), "https://example.com/missing.jpg")
if !errors.Is(err, ErrUpstreamError) {
t.Errorf("Fetch() error = %v, want ErrUpstreamError", err)
}
}
func TestMockFetcher_RespectsContextCancellation(t *testing.T) {
mockFS := fstest.MapFS{
"example.com/photo.jpg": &fstest.MapFile{Data: []byte("data")},
}
m := NewMock(mockFS)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := m.Fetch(ctx, "https://example.com/photo.jpg")
if !errors.Is(err, context.Canceled) {
t.Errorf("Fetch() error = %v, want context.Canceled", err)
}
}
func TestDetectContentTypeFromPath(t *testing.T) {
tests := []struct {
path string
want string
}{
{"foo/bar.jpg", "image/jpeg"},
{"foo/bar.JPG", "image/jpeg"},
{"foo/bar.jpeg", "image/jpeg"},
{"foo/bar.png", "image/png"},
{"foo/bar.gif", "image/gif"},
{"foo/bar.webp", "image/webp"},
{"foo/bar.avif", "image/avif"},
{"foo/bar.svg", "image/svg+xml"},
{"foo/bar.bin", "application/octet-stream"},
{"foo/bar", "application/octet-stream"},
}
for _, tc := range tests {
t.Run(tc.path, func(t *testing.T) {
got := detectContentTypeFromPath(tc.path)
if got != tc.want {
t.Errorf("detectContentTypeFromPath(%q) = %q, want %q", tc.path, got, tc.want)
}
})
}
}
func TestLimitedReader_EnforcesLimit(t *testing.T) {
src := make([]byte, 100)
r := &limitedReader{
reader: &byteReader{data: src},
remaining: 50,
}
buf := make([]byte, 100)
n, err := r.Read(buf)
if err != nil {
t.Fatalf("first Read error = %v", err)
}
if n > 50 {
t.Errorf("read %d bytes, should be capped at 50", n)
}
// Drain until limit is exhausted.
total := n
for total < 50 {
nn, err := r.Read(buf)
total += nn
if err != nil {
t.Fatalf("during drain: %v", err)
}
}
// Now the limit is exhausted — next read should error.
_, err = r.Read(buf)
if !errors.Is(err, ErrResponseTooLarge) {
t.Errorf("exhausted Read error = %v, want ErrResponseTooLarge", err)
}
}
// byteReader is a minimal io.Reader over a byte slice for testing.
type byteReader struct {
data []byte
pos int
}
func (r *byteReader) Read(p []byte) (int, error) {
if r.pos >= len(r.data) {
return 0, io.EOF
}
n := copy(p, r.data[r.pos:])
r.pos += n
return n, nil
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 281 B

View File

@@ -9,8 +9,6 @@ import (
"io" "io"
"path/filepath" "path/filepath"
"time" "time"
"sneak.berlin/go/pixa/internal/httpfetcher"
) )
// Cache errors. // Cache errors.
@@ -113,7 +111,7 @@ func (c *Cache) StoreSource(
ctx context.Context, ctx context.Context,
req *ImageRequest, req *ImageRequest,
content io.Reader, content io.Reader,
result *httpfetcher.FetchResult, result *FetchResult,
) (ContentHash, error) { ) (ContentHash, error) {
// Store content // Store content
contentHash, size, err := c.srcContent.Store(content) contentHash, size, err := c.srcContent.Store(content)

View File

@@ -9,7 +9,6 @@ import (
"time" "time"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
"sneak.berlin/go/pixa/internal/httpfetcher"
) )
func setupTestDB(t *testing.T) *sql.DB { func setupTestDB(t *testing.T) *sql.DB {
@@ -153,7 +152,7 @@ func TestCache_StoreAndLookup(t *testing.T) {
// Store source content // Store source content
sourceContent := []byte("fake jpeg data") sourceContent := []byte("fake jpeg data")
fetchResult := &httpfetcher.FetchResult{ fetchResult := &FetchResult{
ContentType: "image/jpeg", ContentType: "image/jpeg",
Headers: map[string][]string{"Content-Type": {"image/jpeg"}}, Headers: map[string][]string{"Content-Type": {"image/jpeg"}},
} }

View File

@@ -1,6 +1,4 @@
// Package httpfetcher fetches content from upstream HTTP origins with SSRF package imgcache
// protection, per-host connection limits, and content-type validation.
package httpfetcher
import ( import (
"context" "context"
@@ -39,55 +37,25 @@ var (
ErrUpstreamTimeout = errors.New("upstream request timeout") ErrUpstreamTimeout = errors.New("upstream request timeout")
) )
// Fetcher retrieves content from upstream origins. // FetcherConfig holds configuration for the upstream fetcher.
type Fetcher interface { type FetcherConfig struct {
// Fetch retrieves content from the given URL. // Timeout for upstream requests
Fetch(ctx context.Context, url string) (*FetchResult, error)
}
// FetchResult contains the result of fetching from upstream.
type FetchResult struct {
// Content is the raw image data.
Content io.ReadCloser
// ContentLength is the size in bytes (-1 if unknown).
ContentLength int64
// ContentType is the MIME type from upstream.
ContentType string
// Headers contains all response headers from upstream.
Headers map[string][]string
// StatusCode is the HTTP status code from upstream.
StatusCode int
// FetchDurationMs is how long the fetch took in milliseconds.
FetchDurationMs int64
// RemoteAddr is the IP:port of the upstream server.
RemoteAddr string
// HTTPVersion is the protocol version (e.g., "1.1", "2.0").
HTTPVersion string
// TLSVersion is the TLS protocol version (e.g., "TLS 1.3").
TLSVersion string
// TLSCipherSuite is the negotiated cipher suite name.
TLSCipherSuite string
}
// Config holds configuration for the upstream fetcher.
type Config struct {
// Timeout for upstream requests.
Timeout time.Duration Timeout time.Duration
// MaxResponseSize is the maximum allowed response body size. // MaxResponseSize is the maximum allowed response body size
MaxResponseSize int64 MaxResponseSize int64
// UserAgent to send to upstream servers. // UserAgent to send to upstream servers
UserAgent string UserAgent string
// AllowedContentTypes is an allow list of MIME types to accept. // AllowedContentTypes is a whitelist of MIME types to accept
AllowedContentTypes []string AllowedContentTypes []string
// AllowHTTP allows non-TLS connections (for testing only). // AllowHTTP allows non-TLS connections (for testing only)
AllowHTTP bool AllowHTTP bool
// MaxConnectionsPerHost limits concurrent connections to each upstream host. // MaxConnectionsPerHost limits concurrent connections to each upstream host
MaxConnectionsPerHost int MaxConnectionsPerHost int
} }
// DefaultConfig returns a Config with sensible defaults. // DefaultFetcherConfig returns sensible defaults.
func DefaultConfig() *Config { func DefaultFetcherConfig() *FetcherConfig {
return &Config{ return &FetcherConfig{
Timeout: DefaultFetchTimeout, Timeout: DefaultFetchTimeout,
MaxResponseSize: DefaultMaxResponseSize, MaxResponseSize: DefaultMaxResponseSize,
UserAgent: "pixa/1.0", UserAgent: "pixa/1.0",
@@ -104,18 +72,18 @@ func DefaultConfig() *Config {
} }
} }
// HTTPFetcher implements Fetcher with SSRF protection and per-host connection limits. // HTTPFetcher implements the Fetcher interface with SSRF protection.
type HTTPFetcher struct { type HTTPFetcher struct {
client *http.Client client *http.Client
config *Config config *FetcherConfig
hostSems map[string]chan struct{} // per-host semaphores hostSems map[string]chan struct{} // per-host semaphores
hostSemMu sync.Mutex // protects hostSems map hostSemMu sync.Mutex // protects hostSems map
} }
// New creates a new HTTPFetcher with SSRF protection. // NewHTTPFetcher creates a new fetcher with SSRF protection.
func New(config *Config) *HTTPFetcher { func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
if config == nil { if config == nil {
config = DefaultConfig() config = DefaultFetcherConfig()
} }
// Create transport with SSRF-safe dialer // Create transport with SSRF-safe dialer
@@ -282,7 +250,7 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
}, nil }, nil
} }
// isAllowedContentType checks if the content type is in the allow list. // isAllowedContentType checks if the content type is in the whitelist.
func (f *HTTPFetcher) isAllowedContentType(contentType string) bool { func (f *HTTPFetcher) isAllowedContentType(contentType string) bool {
// Extract the MIME type without parameters // Extract the MIME type without parameters
mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0]) mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0])

View File

@@ -169,6 +169,66 @@ type Whitelist interface {
IsWhitelisted(u *url.URL) bool IsWhitelisted(u *url.URL) bool
} }
// Fetcher fetches images from upstream origins
type Fetcher interface {
// Fetch retrieves an image from the origin
Fetch(ctx context.Context, url string) (*FetchResult, error)
}
// FetchResult contains the result of fetching from upstream
type FetchResult struct {
// Content is the raw image data
Content io.ReadCloser
// ContentLength is the size in bytes (-1 if unknown)
ContentLength int64
// ContentType is the MIME type from upstream
ContentType string
// Headers contains all response headers from upstream
Headers map[string][]string
// StatusCode is the HTTP status code from upstream
StatusCode int
// FetchDurationMs is how long the fetch took in milliseconds
FetchDurationMs int64
// RemoteAddr is the IP:port of the upstream server
RemoteAddr string
// HTTPVersion is the protocol version (e.g., "1.1", "2.0")
HTTPVersion string
// TLSVersion is the TLS protocol version (e.g., "TLS 1.3")
TLSVersion string
// TLSCipherSuite is the negotiated cipher suite name
TLSCipherSuite string
}
// Processor handles image transformation (resize, format conversion)
type Processor interface {
// Process transforms an image according to the request
Process(ctx context.Context, input io.Reader, req *ImageRequest) (*ProcessResult, error)
// SupportedInputFormats returns MIME types this processor can read
SupportedInputFormats() []string
// SupportedOutputFormats returns formats this processor can write
SupportedOutputFormats() []ImageFormat
}
// ProcessResult contains the result of image processing
type ProcessResult struct {
// Content is the processed image data
Content io.ReadCloser
// ContentLength is the size in bytes
ContentLength int64
// ContentType is the MIME type of the output
ContentType string
// Width is the output image width
Width int
// Height is the output image height
Height int
// InputWidth is the original image width before processing
InputWidth int
// InputHeight is the original image height before processing
InputHeight int
// InputFormat is the detected input format (e.g., "jpeg", "png")
InputFormat string
}
// Storage handles persistent storage of cached content // Storage handles persistent storage of cached content
type Storage interface { type Storage interface {
// Store saves content and returns its hash // Store saves content and returns its hash

View File

@@ -1,6 +1,4 @@
// Package magic detects image formats from magic bytes and validates package imgcache
// content against declared MIME types.
package magic
import ( import (
"bytes" "bytes"
@@ -29,20 +27,6 @@ const (
MIMETypeSVG = MIMEType("image/svg+xml") MIMETypeSVG = MIMEType("image/svg+xml")
) )
// ImageFormat represents supported output image formats.
// This mirrors the type in imgcache to avoid circular imports.
type ImageFormat string
// Supported image output formats.
const (
FormatOriginal ImageFormat = "orig"
FormatJPEG ImageFormat = "jpeg"
FormatPNG ImageFormat = "png"
FormatWebP ImageFormat = "webp"
FormatAVIF ImageFormat = "avif"
FormatGIF ImageFormat = "gif"
)
// MinMagicBytes is the minimum number of bytes needed to detect format. // MinMagicBytes is the minimum number of bytes needed to detect format.
const MinMagicBytes = 12 const MinMagicBytes = 12
@@ -205,7 +189,7 @@ func PeekAndValidate(r io.Reader, declaredType string) (io.Reader, error) {
return io.MultiReader(bytes.NewReader(buf), r), nil return io.MultiReader(bytes.NewReader(buf), r), nil
} }
// MIMEToImageFormat converts a MIME type to an ImageFormat. // MIMEToImageFormat converts a MIME type to our ImageFormat type.
func MIMEToImageFormat(mimeType string) (ImageFormat, bool) { func MIMEToImageFormat(mimeType string) (ImageFormat, bool) {
normalized := normalizeMIMEType(mimeType) normalized := normalizeMIMEType(mimeType)
switch MIMEType(normalized) { switch MIMEType(normalized) {
@@ -224,7 +208,7 @@ func MIMEToImageFormat(mimeType string) (ImageFormat, bool) {
} }
} }
// ImageFormatToMIME converts an ImageFormat to a MIME type string. // ImageFormatToMIME converts our ImageFormat to a MIME type string.
func ImageFormatToMIME(format ImageFormat) string { func ImageFormatToMIME(format ImageFormat) string {
switch format { switch format {
case FormatJPEG: case FormatJPEG:

View File

@@ -1,4 +1,4 @@
package magic package imgcache
import ( import (
"bytes" "bytes"

View File

@@ -1,4 +1,4 @@
package httpfetcher package imgcache
import ( import (
"context" "context"
@@ -10,15 +10,15 @@ import (
"strings" "strings"
) )
// MockFetcher implements Fetcher using an embedded filesystem. // MockFetcher implements the Fetcher interface using an embedded filesystem.
// Files are organized as: hostname/path/to/file.ext // Files are organized as: hostname/path/to/file.ext
// URLs like https://example.com/images/photo.jpg map to example.com/images/photo.jpg. // URLs like https://example.com/images/photo.jpg map to example.com/images/photo.jpg
type MockFetcher struct { type MockFetcher struct {
fs fs.FS fs fs.FS
} }
// NewMock creates a new mock fetcher backed by the given filesystem. // NewMockFetcher creates a new mock fetcher backed by the given filesystem.
func NewMock(fsys fs.FS) *MockFetcher { func NewMockFetcher(fsys fs.FS) *MockFetcher {
return &MockFetcher{fs: fsys} return &MockFetcher{fs: fsys}
} }

View File

@@ -1,5 +1,4 @@
// Package imageprocessor provides image format conversion and resizing using libvips. package imgcache
package imageprocessor
import ( import (
"bytes" "bytes"
@@ -23,133 +22,38 @@ func initVips() {
}) })
} }
// Format represents supported output image formats.
type Format string
// Supported image output formats.
const (
FormatOriginal Format = "orig"
FormatJPEG Format = "jpeg"
FormatPNG Format = "png"
FormatWebP Format = "webp"
FormatAVIF Format = "avif"
FormatGIF Format = "gif"
)
// FitMode represents how to fit an image into requested dimensions.
type FitMode string
// Supported image fit modes.
const (
FitCover FitMode = "cover"
FitContain FitMode = "contain"
FitFill FitMode = "fill"
FitInside FitMode = "inside"
FitOutside FitMode = "outside"
)
// ErrInvalidFitMode is returned when an invalid fit mode is provided.
var ErrInvalidFitMode = errors.New("invalid fit mode")
// Size represents requested image dimensions.
type Size struct {
Width int
Height int
}
// Request holds the parameters for image processing.
type Request struct {
Size Size
Format Format
Quality int
FitMode FitMode
}
// Result contains the output of image processing.
type Result struct {
// Content is the processed image data.
Content io.ReadCloser
// ContentLength is the size in bytes.
ContentLength int64
// ContentType is the MIME type of the output.
ContentType string
// Width is the output image width.
Width int
// Height is the output image height.
Height int
// InputWidth is the original image width before processing.
InputWidth int
// InputHeight is the original image height before processing.
InputHeight int
// InputFormat is the detected input format (e.g., "jpeg", "png").
InputFormat string
}
// MaxInputDimension is the maximum allowed width or height for input images. // MaxInputDimension is the maximum allowed width or height for input images.
// Images larger than this are rejected to prevent DoS via decompression bombs. // Images larger than this are rejected to prevent DoS via decompression bombs.
const MaxInputDimension = 8192 const MaxInputDimension = 8192
// DefaultMaxInputBytes is the default maximum input size in bytes (50 MiB).
// This matches the default upstream fetcher limit.
const DefaultMaxInputBytes = 50 << 20
// ErrInputTooLarge is returned when input image dimensions exceed MaxInputDimension. // ErrInputTooLarge is returned when input image dimensions exceed MaxInputDimension.
var ErrInputTooLarge = errors.New("input image dimensions exceed maximum") var ErrInputTooLarge = errors.New("input image dimensions exceed maximum")
// ErrInputDataTooLarge is returned when the raw input data exceeds the configured byte limit.
var ErrInputDataTooLarge = errors.New("input data exceeds maximum allowed size")
// ErrUnsupportedOutputFormat is returned when the requested output format is not supported. // ErrUnsupportedOutputFormat is returned when the requested output format is not supported.
var ErrUnsupportedOutputFormat = errors.New("unsupported output format") var ErrUnsupportedOutputFormat = errors.New("unsupported output format")
// ImageProcessor implements image transformation using libvips via govips. // ImageProcessor implements the Processor interface using libvips via govips.
type ImageProcessor struct { type ImageProcessor struct{}
maxInputBytes int64
}
// Params holds configuration for creating an ImageProcessor. // NewImageProcessor creates a new image processor.
// Zero values use sensible defaults (MaxInputBytes defaults to DefaultMaxInputBytes). func NewImageProcessor() *ImageProcessor {
type Params struct {
// MaxInputBytes is the maximum allowed input size in bytes.
// If <= 0, DefaultMaxInputBytes is used.
MaxInputBytes int64
}
// New creates a new image processor with the given parameters.
// A zero-value Params{} uses sensible defaults.
func New(params Params) *ImageProcessor {
initVips() initVips()
maxInputBytes := params.MaxInputBytes return &ImageProcessor{}
if maxInputBytes <= 0 {
maxInputBytes = DefaultMaxInputBytes
}
return &ImageProcessor{
maxInputBytes: maxInputBytes,
}
} }
// Process transforms an image according to the request. // Process transforms an image according to the request.
func (p *ImageProcessor) Process( func (p *ImageProcessor) Process(
_ context.Context, _ context.Context,
input io.Reader, input io.Reader,
req *Request, req *ImageRequest,
) (*Result, error) { ) (*ProcessResult, error) {
// Read input with a size limit to prevent unbounded memory consumption. // Read input
// We read at most maxInputBytes+1 so we can detect if the input exceeds data, err := io.ReadAll(input)
// the limit without consuming additional memory.
limited := io.LimitReader(input, p.maxInputBytes+1)
data, err := io.ReadAll(limited)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read input: %w", err) return nil, fmt.Errorf("failed to read input: %w", err)
} }
if int64(len(data)) > p.maxInputBytes {
return nil, ErrInputDataTooLarge
}
// Decode image // Decode image
img, err := vips.NewImageFromBuffer(data) img, err := vips.NewImageFromBuffer(data)
if err != nil { if err != nil {
@@ -205,10 +109,10 @@ func (p *ImageProcessor) Process(
return nil, fmt.Errorf("failed to encode: %w", err) return nil, fmt.Errorf("failed to encode: %w", err)
} }
return &Result{ return &ProcessResult{
Content: io.NopCloser(bytes.NewReader(output)), Content: io.NopCloser(bytes.NewReader(output)),
ContentLength: int64(len(output)), ContentLength: int64(len(output)),
ContentType: FormatToMIME(outputFormat), ContentType: ImageFormatToMIME(outputFormat),
Width: img.Width(), Width: img.Width(),
Height: img.Height(), Height: img.Height(),
InputWidth: origWidth, InputWidth: origWidth,
@@ -220,17 +124,17 @@ func (p *ImageProcessor) Process(
// SupportedInputFormats returns MIME types this processor can read. // SupportedInputFormats returns MIME types this processor can read.
func (p *ImageProcessor) SupportedInputFormats() []string { func (p *ImageProcessor) SupportedInputFormats() []string {
return []string{ return []string{
"image/jpeg", string(MIMETypeJPEG),
"image/png", string(MIMETypePNG),
"image/gif", string(MIMETypeGIF),
"image/webp", string(MIMETypeWebP),
"image/avif", string(MIMETypeAVIF),
} }
} }
// SupportedOutputFormats returns formats this processor can write. // SupportedOutputFormats returns formats this processor can write.
func (p *ImageProcessor) SupportedOutputFormats() []Format { func (p *ImageProcessor) SupportedOutputFormats() []ImageFormat {
return []Format{ return []ImageFormat{
FormatJPEG, FormatJPEG,
FormatPNG, FormatPNG,
FormatGIF, FormatGIF,
@@ -239,24 +143,6 @@ func (p *ImageProcessor) SupportedOutputFormats() []Format {
} }
} }
// FormatToMIME converts a Format to its MIME type string.
func FormatToMIME(format Format) string {
switch format {
case FormatJPEG:
return "image/jpeg"
case FormatPNG:
return "image/png"
case FormatWebP:
return "image/webp"
case FormatGIF:
return "image/gif"
case FormatAVIF:
return "image/avif"
default:
return "application/octet-stream"
}
}
// detectFormat returns the format string from a vips image. // detectFormat returns the format string from a vips image.
func (p *ImageProcessor) detectFormat(img *vips.ImageRef) string { func (p *ImageProcessor) detectFormat(img *vips.ImageRef) string {
format := img.Format() format := img.Format()
@@ -285,6 +171,7 @@ func (p *ImageProcessor) resize(img *vips.ImageRef, width, height int, fit FitMo
case FitContain: case FitContain:
// Resize to fit within dimensions, maintaining aspect ratio // Resize to fit within dimensions, maintaining aspect ratio
// Calculate target dimensions maintaining aspect ratio
imgW, imgH := img.Width(), img.Height() imgW, imgH := img.Width(), img.Height()
scaleW := float64(width) / float64(imgW) scaleW := float64(width) / float64(imgW)
scaleH := float64(height) / float64(imgH) scaleH := float64(height) / float64(imgH)
@@ -295,7 +182,7 @@ func (p *ImageProcessor) resize(img *vips.ImageRef, width, height int, fit FitMo
return img.Thumbnail(newW, newH, vips.InterestingNone) return img.Thumbnail(newW, newH, vips.InterestingNone)
case FitFill: case FitFill:
// Resize to exact dimensions (may distort) // Resize to exact dimensions (may distort) - use ThumbnailWithSize with Force
return img.ThumbnailWithSize(width, height, vips.InterestingNone, vips.SizeForce) return img.ThumbnailWithSize(width, height, vips.InterestingNone, vips.SizeForce)
case FitInside: case FitInside:
@@ -331,7 +218,7 @@ func (p *ImageProcessor) resize(img *vips.ImageRef, width, height int, fit FitMo
const defaultQuality = 85 const defaultQuality = 85
// encode encodes an image to the specified format. // encode encodes an image to the specified format.
func (p *ImageProcessor) encode(img *vips.ImageRef, format Format, quality int) ([]byte, error) { func (p *ImageProcessor) encode(img *vips.ImageRef, format ImageFormat, quality int) ([]byte, error) {
if quality <= 0 { if quality <= 0 {
quality = defaultQuality quality = defaultQuality
} }
@@ -379,8 +266,8 @@ func (p *ImageProcessor) encode(img *vips.ImageRef, format Format, quality int)
return output, nil return output, nil
} }
// formatFromString converts a format string to Format. // formatFromString converts a format string to ImageFormat.
func (p *ImageProcessor) formatFromString(format string) Format { func (p *ImageProcessor) formatFromString(format string) ImageFormat {
switch format { switch format {
case "jpeg": case "jpeg":
return FormatJPEG return FormatJPEG

View File

@@ -1,4 +1,4 @@
package imageprocessor package imgcache
import ( import (
"bytes" "bytes"
@@ -70,36 +70,13 @@ func createTestPNG(t *testing.T, width, height int) []byte {
return buf.Bytes() return buf.Bytes()
} }
// detectMIME is a minimal magic-byte detector for test assertions.
func detectMIME(data []byte) string {
if len(data) >= 3 && data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF {
return "image/jpeg"
}
if len(data) >= 8 && string(data[:8]) == "\x89PNG\r\n\x1a\n" {
return "image/png"
}
if len(data) >= 4 && string(data[:4]) == "GIF8" {
return "image/gif"
}
if len(data) >= 12 && string(data[:4]) == "RIFF" && string(data[8:12]) == "WEBP" {
return "image/webp"
}
if len(data) >= 12 && string(data[4:8]) == "ftyp" {
brand := string(data[8:12])
if brand == "avif" || brand == "avis" {
return "image/avif"
}
}
return ""
}
func TestImageProcessor_ResizeJPEG(t *testing.T) { func TestImageProcessor_ResizeJPEG(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 800, 600) input := createTestJPEG(t, 800, 600)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 400, Height: 300}, Size: Size{Width: 400, Height: 300},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -130,19 +107,23 @@ func TestImageProcessor_ResizeJPEG(t *testing.T) {
t.Fatalf("failed to read result: %v", err) t.Fatalf("failed to read result: %v", err)
} }
mime := detectMIME(data) mime, err := DetectFormat(data)
if mime != "image/jpeg" { if err != nil {
t.Errorf("Output format = %v, want image/jpeg", mime) t.Fatalf("DetectFormat() error = %v", err)
}
if mime != MIMETypeJPEG {
t.Errorf("Output format = %v, want %v", mime, MIMETypeJPEG)
} }
} }
func TestImageProcessor_ConvertToPNG(t *testing.T) { func TestImageProcessor_ConvertToPNG(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 200, 150) input := createTestJPEG(t, 200, 150)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 200, Height: 150}, Size: Size{Width: 200, Height: 150},
Format: FormatPNG, Format: FormatPNG,
FitMode: FitCover, FitMode: FitCover,
@@ -159,19 +140,23 @@ func TestImageProcessor_ConvertToPNG(t *testing.T) {
t.Fatalf("failed to read result: %v", err) t.Fatalf("failed to read result: %v", err)
} }
mime := detectMIME(data) mime, err := DetectFormat(data)
if mime != "image/png" { if err != nil {
t.Errorf("Output format = %v, want image/png", mime) t.Fatalf("DetectFormat() error = %v", err)
}
if mime != MIMETypePNG {
t.Errorf("Output format = %v, want %v", mime, MIMETypePNG)
} }
} }
func TestImageProcessor_OriginalSize(t *testing.T) { func TestImageProcessor_OriginalSize(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 640, 480) input := createTestJPEG(t, 640, 480)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 0, Height: 0}, // Original size Size: Size{Width: 0, Height: 0}, // Original size
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -194,14 +179,14 @@ func TestImageProcessor_OriginalSize(t *testing.T) {
} }
func TestImageProcessor_FitContain(t *testing.T) { func TestImageProcessor_FitContain(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
// 800x400 image (2:1 aspect) into 400x400 box with contain // 800x400 image (2:1 aspect) into 400x400 box with contain
// Should result in 400x200 (maintaining aspect ratio) // Should result in 400x200 (maintaining aspect ratio)
input := createTestJPEG(t, 800, 400) input := createTestJPEG(t, 800, 400)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 400, Height: 400}, Size: Size{Width: 400, Height: 400},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -221,14 +206,14 @@ func TestImageProcessor_FitContain(t *testing.T) {
} }
func TestImageProcessor_ProportionalScale_WidthOnly(t *testing.T) { func TestImageProcessor_ProportionalScale_WidthOnly(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
// 800x600 image, request width=400 height=0 // 800x600 image, request width=400 height=0
// Should scale proportionally to 400x300 // Should scale proportionally to 400x300
input := createTestJPEG(t, 800, 600) input := createTestJPEG(t, 800, 600)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 400, Height: 0}, Size: Size{Width: 400, Height: 0},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -251,14 +236,14 @@ func TestImageProcessor_ProportionalScale_WidthOnly(t *testing.T) {
} }
func TestImageProcessor_ProportionalScale_HeightOnly(t *testing.T) { func TestImageProcessor_ProportionalScale_HeightOnly(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
// 800x600 image, request width=0 height=300 // 800x600 image, request width=0 height=300
// Should scale proportionally to 400x300 // Should scale proportionally to 400x300
input := createTestJPEG(t, 800, 600) input := createTestJPEG(t, 800, 600)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 0, Height: 300}, Size: Size{Width: 0, Height: 300},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -281,12 +266,12 @@ func TestImageProcessor_ProportionalScale_HeightOnly(t *testing.T) {
} }
func TestImageProcessor_ProcessPNG(t *testing.T) { func TestImageProcessor_ProcessPNG(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
input := createTestPNG(t, 400, 300) input := createTestPNG(t, 400, 300)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 200, Height: 150}, Size: Size{Width: 200, Height: 150},
Format: FormatPNG, Format: FormatPNG,
FitMode: FitCover, FitMode: FitCover,
@@ -307,8 +292,13 @@ func TestImageProcessor_ProcessPNG(t *testing.T) {
} }
} }
func TestImageProcessor_ImplementsInterface(t *testing.T) {
// Verify ImageProcessor implements Processor interface
var _ Processor = (*ImageProcessor)(nil)
}
func TestImageProcessor_SupportedFormats(t *testing.T) { func TestImageProcessor_SupportedFormats(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
inputFormats := proc.SupportedInputFormats() inputFormats := proc.SupportedInputFormats()
if len(inputFormats) == 0 { if len(inputFormats) == 0 {
@@ -322,14 +312,14 @@ func TestImageProcessor_SupportedFormats(t *testing.T) {
} }
func TestImageProcessor_RejectsOversizedInput(t *testing.T) { func TestImageProcessor_RejectsOversizedInput(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
// Create an image that exceeds MaxInputDimension (e.g., 10000x100) // Create an image that exceeds MaxInputDimension (e.g., 10000x100)
// This should be rejected before processing to prevent DoS // This should be rejected before processing to prevent DoS
input := createTestJPEG(t, 10000, 100) input := createTestJPEG(t, 10000, 100)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 100, Height: 100}, Size: Size{Width: 100, Height: 100},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -347,13 +337,13 @@ func TestImageProcessor_RejectsOversizedInput(t *testing.T) {
} }
func TestImageProcessor_RejectsOversizedInputHeight(t *testing.T) { func TestImageProcessor_RejectsOversizedInputHeight(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
// Create an image with oversized height // Create an image with oversized height
input := createTestJPEG(t, 100, 10000) input := createTestJPEG(t, 100, 10000)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 100, Height: 100}, Size: Size{Width: 100, Height: 100},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -371,13 +361,14 @@ func TestImageProcessor_RejectsOversizedInputHeight(t *testing.T) {
} }
func TestImageProcessor_AcceptsMaxDimensionInput(t *testing.T) { func TestImageProcessor_AcceptsMaxDimensionInput(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
// Create an image at exactly MaxInputDimension - should be accepted // Create an image at exactly MaxInputDimension - should be accepted
// Using smaller dimensions to keep test fast
input := createTestJPEG(t, MaxInputDimension, 100) input := createTestJPEG(t, MaxInputDimension, 100)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 100, Height: 100}, Size: Size{Width: 100, Height: 100},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -392,12 +383,12 @@ func TestImageProcessor_AcceptsMaxDimensionInput(t *testing.T) {
} }
func TestImageProcessor_EncodeWebP(t *testing.T) { func TestImageProcessor_EncodeWebP(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 200, 150) input := createTestJPEG(t, 200, 150)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 100, Height: 75}, Size: Size{Width: 100, Height: 75},
Format: FormatWebP, Format: FormatWebP,
Quality: 80, Quality: 80,
@@ -416,9 +407,13 @@ func TestImageProcessor_EncodeWebP(t *testing.T) {
t.Fatalf("failed to read result: %v", err) t.Fatalf("failed to read result: %v", err)
} }
mime := detectMIME(data) mime, err := DetectFormat(data)
if mime != "image/webp" { if err != nil {
t.Errorf("Output format = %v, want image/webp", mime) t.Fatalf("DetectFormat() error = %v", err)
}
if mime != MIMETypeWebP {
t.Errorf("Output format = %v, want %v", mime, MIMETypeWebP)
} }
// Verify dimensions // Verify dimensions
@@ -431,7 +426,7 @@ func TestImageProcessor_EncodeWebP(t *testing.T) {
} }
func TestImageProcessor_DecodeAVIF(t *testing.T) { func TestImageProcessor_DecodeAVIF(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
// Load test AVIF file // Load test AVIF file
@@ -441,7 +436,7 @@ func TestImageProcessor_DecodeAVIF(t *testing.T) {
} }
// Request resize and convert to JPEG // Request resize and convert to JPEG
req := &Request{ req := &ImageRequest{
Size: Size{Width: 2, Height: 2}, Size: Size{Width: 2, Height: 2},
Format: FormatJPEG, Format: FormatJPEG,
Quality: 85, Quality: 85,
@@ -460,84 +455,23 @@ func TestImageProcessor_DecodeAVIF(t *testing.T) {
t.Fatalf("failed to read result: %v", err) t.Fatalf("failed to read result: %v", err)
} }
mime := detectMIME(data) mime, err := DetectFormat(data)
if mime != "image/jpeg" {
t.Errorf("Output format = %v, want image/jpeg", mime)
}
}
func TestImageProcessor_RejectsOversizedInputData(t *testing.T) {
// Create a processor with a very small byte limit
const limit = 1024
proc := New(Params{MaxInputBytes: limit})
ctx := context.Background()
// Create a valid JPEG that exceeds the byte limit
input := createTestJPEG(t, 800, 600) // will be well over 1 KiB
if int64(len(input)) <= limit {
t.Fatalf("test JPEG must exceed %d bytes, got %d", limit, len(input))
}
req := &Request{
Size: Size{Width: 100, Height: 75},
Format: FormatJPEG,
Quality: 85,
FitMode: FitCover,
}
_, err := proc.Process(ctx, bytes.NewReader(input), req)
if err == nil {
t.Fatal("Process() should reject input exceeding maxInputBytes")
}
if err != ErrInputDataTooLarge {
t.Errorf("Process() error = %v, want ErrInputDataTooLarge", err)
}
}
func TestImageProcessor_AcceptsInputWithinLimit(t *testing.T) {
// Create a small image and set limit well above its size
input := createTestJPEG(t, 10, 10)
limit := int64(len(input)) * 10 // 10× headroom
proc := New(Params{MaxInputBytes: limit})
ctx := context.Background()
req := &Request{
Size: Size{Width: 10, Height: 10},
Format: FormatJPEG,
Quality: 85,
FitMode: FitCover,
}
result, err := proc.Process(ctx, bytes.NewReader(input), req)
if err != nil { if err != nil {
t.Fatalf("Process() error = %v, want nil", err) t.Fatalf("DetectFormat() error = %v", err)
}
defer result.Content.Close()
}
func TestImageProcessor_DefaultMaxInputBytes(t *testing.T) {
// Passing 0 should use the default
proc := New(Params{})
if proc.maxInputBytes != DefaultMaxInputBytes {
t.Errorf("maxInputBytes = %d, want %d", proc.maxInputBytes, DefaultMaxInputBytes)
} }
// Passing negative should also use the default if mime != MIMETypeJPEG {
proc = New(Params{MaxInputBytes: -1}) t.Errorf("Output format = %v, want %v", mime, MIMETypeJPEG)
if proc.maxInputBytes != DefaultMaxInputBytes {
t.Errorf("maxInputBytes = %d, want %d", proc.maxInputBytes, DefaultMaxInputBytes)
} }
} }
func TestImageProcessor_EncodeAVIF(t *testing.T) { func TestImageProcessor_EncodeAVIF(t *testing.T) {
proc := New(Params{}) proc := NewImageProcessor()
ctx := context.Background() ctx := context.Background()
input := createTestJPEG(t, 200, 150) input := createTestJPEG(t, 200, 150)
req := &Request{ req := &ImageRequest{
Size: Size{Width: 100, Height: 75}, Size: Size{Width: 100, Height: 75},
Format: FormatAVIF, Format: FormatAVIF,
Quality: 85, Quality: 85,
@@ -556,9 +490,13 @@ func TestImageProcessor_EncodeAVIF(t *testing.T) {
t.Fatalf("failed to read result: %v", err) t.Fatalf("failed to read result: %v", err)
} }
mime := detectMIME(data) mime, err := DetectFormat(data)
if mime != "image/avif" { if err != nil {
t.Errorf("Output format = %v, want image/avif", mime) t.Fatalf("DetectFormat() error = %v", err)
}
if mime != MIMETypeAVIF {
t.Errorf("Output format = %v, want %v", mime, MIMETypeAVIF)
} }
// Verify dimensions // Verify dimensions

View File

@@ -11,22 +11,17 @@ import (
"time" "time"
"github.com/dustin/go-humanize" "github.com/dustin/go-humanize"
"sneak.berlin/go/pixa/internal/allowlist"
"sneak.berlin/go/pixa/internal/httpfetcher"
"sneak.berlin/go/pixa/internal/imageprocessor"
"sneak.berlin/go/pixa/internal/magic"
) )
// Service implements the ImageCache interface, orchestrating cache, fetcher, and processor. // Service implements the ImageCache interface, orchestrating cache, fetcher, and processor.
type Service struct { type Service struct {
cache *Cache cache *Cache
fetcher httpfetcher.Fetcher fetcher Fetcher
processor *imageprocessor.ImageProcessor processor Processor
signer *Signer signer *Signer
allowlist *allowlist.HostAllowList whitelist *HostWhitelist
log *slog.Logger log *slog.Logger
allowHTTP bool allowHTTP bool
maxResponseSize int64
} }
// ServiceConfig holds configuration for the image service. // ServiceConfig holds configuration for the image service.
@@ -34,9 +29,9 @@ type ServiceConfig struct {
// Cache is the cache instance // Cache is the cache instance
Cache *Cache Cache *Cache
// FetcherConfig configures the upstream fetcher (ignored if Fetcher is set) // FetcherConfig configures the upstream fetcher (ignored if Fetcher is set)
FetcherConfig *httpfetcher.Config FetcherConfig *FetcherConfig
// Fetcher is an optional custom fetcher (for testing) // Fetcher is an optional custom fetcher (for testing)
Fetcher httpfetcher.Fetcher Fetcher Fetcher
// SigningKey is the HMAC signing key (empty disables signing) // SigningKey is the HMAC signing key (empty disables signing)
SigningKey string SigningKey string
// Whitelist is the list of hosts that don't require signatures // Whitelist is the list of hosts that don't require signatures
@@ -55,18 +50,16 @@ func NewService(cfg *ServiceConfig) (*Service, error) {
return nil, errors.New("signing key is required") return nil, errors.New("signing key is required")
} }
// Resolve fetcher config for defaults
fetcherCfg := cfg.FetcherConfig
if fetcherCfg == nil {
fetcherCfg = httpfetcher.DefaultConfig()
}
// Use custom fetcher if provided, otherwise create HTTP fetcher // Use custom fetcher if provided, otherwise create HTTP fetcher
var fetcher httpfetcher.Fetcher var fetcher Fetcher
if cfg.Fetcher != nil { if cfg.Fetcher != nil {
fetcher = cfg.Fetcher fetcher = cfg.Fetcher
} else { } else {
fetcher = httpfetcher.New(fetcherCfg) fetcherCfg := cfg.FetcherConfig
if fetcherCfg == nil {
fetcherCfg = DefaultFetcherConfig()
}
fetcher = NewHTTPFetcher(fetcherCfg)
} }
signer := NewSigner(cfg.SigningKey) signer := NewSigner(cfg.SigningKey)
@@ -81,17 +74,14 @@ func NewService(cfg *ServiceConfig) (*Service, error) {
allowHTTP = cfg.FetcherConfig.AllowHTTP allowHTTP = cfg.FetcherConfig.AllowHTTP
} }
maxResponseSize := fetcherCfg.MaxResponseSize
return &Service{ return &Service{
cache: cfg.Cache, cache: cfg.Cache,
fetcher: fetcher, fetcher: fetcher,
processor: imageprocessor.New(imageprocessor.Params{MaxInputBytes: maxResponseSize}), processor: NewImageProcessor(),
signer: signer, signer: signer,
allowlist: allowlist.New(cfg.Whitelist), whitelist: NewHostWhitelist(cfg.Whitelist),
log: log, log: log,
allowHTTP: allowHTTP, allowHTTP: allowHTTP,
maxResponseSize: maxResponseSize,
}, nil }, nil
} }
@@ -114,7 +104,7 @@ func (s *Service) Get(ctx context.Context, req *ImageRequest) (*ImageResponse, e
"path", req.SourcePath, "path", req.SourcePath,
) )
return nil, fmt.Errorf("%w: %w", httpfetcher.ErrUpstreamError, ErrNegativeCached) return nil, fmt.Errorf("%w: %w", ErrUpstreamError, ErrNegativeCached)
} }
// Check variant cache first (disk only, no DB) // Check variant cache first (disk only, no DB)
@@ -156,40 +146,6 @@ func (s *Service) Get(ctx context.Context, req *ImageRequest) (*ImageResponse, e
return response, nil return response, nil
} }
// loadCachedSource attempts to load source content from cache, returning nil
// if the cached data is unavailable or exceeds maxResponseSize.
func (s *Service) loadCachedSource(contentHash ContentHash) []byte {
reader, err := s.cache.GetSourceContent(contentHash)
if err != nil {
s.log.Warn("failed to load cached source, fetching", "error", err)
return nil
}
// Bound the read to maxResponseSize to prevent unbounded memory use
// from unexpectedly large cached files.
limited := io.LimitReader(reader, s.maxResponseSize+1)
data, err := io.ReadAll(limited)
_ = reader.Close()
if err != nil {
s.log.Warn("failed to read cached source, fetching", "error", err)
return nil
}
if int64(len(data)) > s.maxResponseSize {
s.log.Warn("cached source exceeds max response size, discarding",
"hash", contentHash,
"max_bytes", s.maxResponseSize,
)
return nil
}
return data
}
// processFromSourceOrFetch processes an image, using cached source content if available. // processFromSourceOrFetch processes an image, using cached source content if available.
func (s *Service) processFromSourceOrFetch( func (s *Service) processFromSourceOrFetch(
ctx context.Context, ctx context.Context,
@@ -206,8 +162,22 @@ func (s *Service) processFromSourceOrFetch(
var fetchBytes int64 var fetchBytes int64
if contentHash != "" { if contentHash != "" {
// We have cached source - load it
s.log.Debug("using cached source", "hash", contentHash) s.log.Debug("using cached source", "hash", contentHash)
sourceData = s.loadCachedSource(contentHash)
reader, err := s.cache.GetSourceContent(contentHash)
if err != nil {
s.log.Warn("failed to load cached source, fetching", "error", err)
// Fall through to fetch
} else {
sourceData, err = io.ReadAll(reader)
_ = reader.Close()
if err != nil {
s.log.Warn("failed to read cached source, fetching", "error", err)
// Fall through to fetch
}
}
} }
// Fetch from upstream if we don't have source data or it's empty // Fetch from upstream if we don't have source data or it's empty
@@ -279,7 +249,7 @@ func (s *Service) fetchAndProcess(
) )
// Validate magic bytes match content type // Validate magic bytes match content type
if err := magic.ValidateMagicBytes(sourceData, fetchResult.ContentType); err != nil { if err := ValidateMagicBytes(sourceData, fetchResult.ContentType); err != nil {
return nil, fmt.Errorf("content validation failed: %w", err) return nil, fmt.Errorf("content validation failed: %w", err)
} }
@@ -304,14 +274,7 @@ func (s *Service) processAndStore(
// Process the image // Process the image
processStart := time.Now() processStart := time.Now()
processReq := &imageprocessor.Request{ processResult, err := s.processor.Process(ctx, bytes.NewReader(sourceData), req)
Size: imageprocessor.Size{Width: req.Size.Width, Height: req.Size.Height},
Format: imageprocessor.Format(req.Format),
Quality: req.Quality,
FitMode: imageprocessor.FitMode(req.FitMode),
}
processResult, err := s.processor.Process(ctx, bytes.NewReader(sourceData), processReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("image processing failed: %w", err) return nil, fmt.Errorf("image processing failed: %w", err)
} }
@@ -384,7 +347,7 @@ func (s *Service) Stats(ctx context.Context) (*CacheStats, error) {
// ValidateRequest validates the request signature if required. // ValidateRequest validates the request signature if required.
func (s *Service) ValidateRequest(req *ImageRequest) error { func (s *Service) ValidateRequest(req *ImageRequest) error {
// Check if host is allowed (no signature required) // Check if host is whitelisted (no signature required)
sourceURL := req.SourceURL() sourceURL := req.SourceURL()
parsedURL, err := url.Parse(sourceURL) parsedURL, err := url.Parse(sourceURL)
@@ -392,11 +355,11 @@ func (s *Service) ValidateRequest(req *ImageRequest) error {
return fmt.Errorf("invalid source URL: %w", err) return fmt.Errorf("invalid source URL: %w", err)
} }
if s.allowlist.IsAllowed(parsedURL) { if s.whitelist.IsWhitelisted(parsedURL) {
return nil return nil
} }
// Signature required for non-allowed hosts // Signature required for non-whitelisted hosts
return s.signer.Verify(req) return s.signer.Verify(req)
} }
@@ -419,13 +382,13 @@ const (
// isNegativeCacheable returns true if the error should be cached. // isNegativeCacheable returns true if the error should be cached.
func isNegativeCacheable(err error) bool { func isNegativeCacheable(err error) bool {
return errors.Is(err, httpfetcher.ErrUpstreamError) return errors.Is(err, ErrUpstreamError)
} }
// extractStatusCode extracts HTTP status code from error message. // extractStatusCode extracts HTTP status code from error message.
func extractStatusCode(err error) int { func extractStatusCode(err error) int {
// Default to 502 Bad Gateway for upstream errors // Default to 502 Bad Gateway for upstream errors
if errors.Is(err, httpfetcher.ErrUpstreamError) { if errors.Is(err, ErrUpstreamError) {
return httpStatusBadGateway return httpStatusBadGateway
} }

View File

@@ -5,8 +5,6 @@ import (
"io" "io"
"testing" "testing"
"time" "time"
"sneak.berlin/go/pixa/internal/magic"
) )
func TestService_Get_WhitelistedHost(t *testing.T) { func TestService_Get_WhitelistedHost(t *testing.T) {
@@ -153,74 +151,6 @@ func TestService_Get_NonWhitelistedHost_InvalidSignature(t *testing.T) {
} }
} }
// TestService_ValidateRequest_SignatureExactHostMatch verifies that
// ValidateRequest enforces exact host matching for signatures. A
// signature for one host must not verify for a different host, even
// if they share a domain suffix.
func TestService_ValidateRequest_SignatureExactHostMatch(t *testing.T) {
signingKey := "test-signing-key-must-be-32-chars"
svc, _ := SetupTestService(t,
WithSigningKey(signingKey),
WithNoWhitelist(),
)
signer := NewSigner(signingKey)
// Sign a request for "cdn.example.com"
signedReq := &ImageRequest{
SourceHost: "cdn.example.com",
SourcePath: "/photos/cat.jpg",
Size: Size{Width: 50, Height: 50},
Format: FormatJPEG,
Quality: 85,
FitMode: FitCover,
Expires: time.Now().Add(time.Hour),
}
signedReq.Signature = signer.Sign(signedReq)
// The original request should pass validation
t.Run("exact host passes", func(t *testing.T) {
err := svc.ValidateRequest(signedReq)
if err != nil {
t.Errorf("ValidateRequest() exact host failed: %v", err)
}
})
// Try to reuse the signature with different hosts
tests := []struct {
name string
host string
}{
{"parent domain", "example.com"},
{"sibling subdomain", "images.example.com"},
{"deeper subdomain", "a.cdn.example.com"},
{"evil suffix domain", "cdn.example.com.evil.com"},
{"prefixed host", "evilcdn.example.com"},
}
for _, tt := range tests {
t.Run(tt.name+" rejected", func(t *testing.T) {
req := &ImageRequest{
SourceHost: tt.host,
SourcePath: signedReq.SourcePath,
SourceQuery: signedReq.SourceQuery,
Size: signedReq.Size,
Format: signedReq.Format,
Quality: signedReq.Quality,
FitMode: signedReq.FitMode,
Expires: signedReq.Expires,
Signature: signedReq.Signature,
}
err := svc.ValidateRequest(req)
if err == nil {
t.Errorf("ValidateRequest() should reject signature for host %q (signed for %q)",
tt.host, signedReq.SourceHost)
}
})
}
}
func TestService_Get_InvalidFile(t *testing.T) { func TestService_Get_InvalidFile(t *testing.T) {
svc, fixtures := SetupTestService(t) svc, fixtures := SetupTestService(t)
ctx := context.Background() ctx := context.Background()
@@ -317,17 +247,17 @@ func TestService_Get_FormatConversion(t *testing.T) {
t.Fatalf("failed to read response: %v", err) t.Fatalf("failed to read response: %v", err)
} }
detectedMIME, err := magic.DetectFormat(data) detectedMIME, err := DetectFormat(data)
if err != nil { if err != nil {
t.Fatalf("failed to detect format: %v", err) t.Fatalf("failed to detect format: %v", err)
} }
expectedFormat, ok := magic.MIMEToImageFormat(tt.wantMIME) expectedFormat, ok := MIMEToImageFormat(tt.wantMIME)
if !ok { if !ok {
t.Fatalf("unknown format for MIME type: %s", tt.wantMIME) t.Fatalf("unknown format for MIME type: %s", tt.wantMIME)
} }
detectedFormat, ok := magic.MIMEToImageFormat(string(detectedMIME)) detectedFormat, ok := MIMEToImageFormat(string(detectedMIME))
if !ok { if !ok {
t.Fatalf("unknown format for detected MIME type: %s", detectedMIME) t.Fatalf("unknown format for detected MIME type: %s", detectedMIME)
} }

View File

@@ -43,11 +43,6 @@ func (s *Signer) Sign(req *ImageRequest) string {
} }
// Verify checks if the signature on the request is valid and not expired. // Verify checks if the signature on the request is valid and not expired.
// Signatures are exact-match only: every component of the signed data
// (host, path, query, dimensions, format, expiration) must match exactly.
// No suffix matching, wildcard matching, or partial matching is supported.
// A signature for "cdn.example.com" will NOT verify for "example.com" or
// "other.cdn.example.com", and vice versa.
func (s *Signer) Verify(req *ImageRequest) error { func (s *Signer) Verify(req *ImageRequest) error {
// Check expiration first // Check expiration first
if req.Expires.IsZero() { if req.Expires.IsZero() {
@@ -71,8 +66,6 @@ func (s *Signer) Verify(req *ImageRequest) error {
// buildSignatureData creates the string to be signed. // buildSignatureData creates the string to be signed.
// Format: "host:path:query:width:height:format:expiration" // Format: "host:path:query:width:height:format:expiration"
// All components are used verbatim (exact match). No normalization,
// suffix matching, or wildcard expansion is performed.
func (s *Signer) buildSignatureData(req *ImageRequest) string { func (s *Signer) buildSignatureData(req *ImageRequest) string {
return fmt.Sprintf("%s:%s:%s:%d:%d:%s:%d", return fmt.Sprintf("%s:%s:%s:%d:%d:%s:%d",
req.SourceHost, req.SourceHost,

View File

@@ -152,178 +152,6 @@ func TestSigner_Verify(t *testing.T) {
} }
} }
// TestSigner_Verify_ExactMatchOnly verifies that signatures enforce exact
// matching on every URL component. No suffix matching, wildcard matching,
// or partial matching is supported.
func TestSigner_Verify_ExactMatchOnly(t *testing.T) {
signer := NewSigner("test-secret-key")
// Base request that we'll sign, then tamper with individual fields.
baseReq := func() *ImageRequest {
req := &ImageRequest{
SourceHost: "cdn.example.com",
SourcePath: "/photos/cat.jpg",
SourceQuery: "token=abc",
Size: Size{Width: 800, Height: 600},
Format: FormatWebP,
Expires: time.Now().Add(1 * time.Hour),
}
req.Signature = signer.Sign(req)
return req
}
tests := []struct {
name string
tamper func(req *ImageRequest)
}{
{
name: "parent domain does not match subdomain",
tamper: func(req *ImageRequest) {
// Signed for cdn.example.com, try example.com
req.SourceHost = "example.com"
},
},
{
name: "subdomain does not match parent domain",
tamper: func(req *ImageRequest) {
// Signed for cdn.example.com, try images.cdn.example.com
req.SourceHost = "images.cdn.example.com"
},
},
{
name: "sibling subdomain does not match",
tamper: func(req *ImageRequest) {
// Signed for cdn.example.com, try images.example.com
req.SourceHost = "images.example.com"
},
},
{
name: "host with suffix appended does not match",
tamper: func(req *ImageRequest) {
// Signed for cdn.example.com, try cdn.example.com.evil.com
req.SourceHost = "cdn.example.com.evil.com"
},
},
{
name: "host with prefix does not match",
tamper: func(req *ImageRequest) {
// Signed for cdn.example.com, try evilcdn.example.com
req.SourceHost = "evilcdn.example.com"
},
},
{
name: "different path does not match",
tamper: func(req *ImageRequest) {
req.SourcePath = "/photos/dog.jpg"
},
},
{
name: "path suffix does not match",
tamper: func(req *ImageRequest) {
req.SourcePath = "/photos/cat.jpg/extra"
},
},
{
name: "path prefix does not match",
tamper: func(req *ImageRequest) {
req.SourcePath = "/other/photos/cat.jpg"
},
},
{
name: "different query does not match",
tamper: func(req *ImageRequest) {
req.SourceQuery = "token=xyz"
},
},
{
name: "added query does not match empty query",
tamper: func(req *ImageRequest) {
req.SourceQuery = "extra=1"
},
},
{
name: "removed query does not match",
tamper: func(req *ImageRequest) {
req.SourceQuery = ""
},
},
{
name: "different width does not match",
tamper: func(req *ImageRequest) {
req.Size.Width = 801
},
},
{
name: "different height does not match",
tamper: func(req *ImageRequest) {
req.Size.Height = 601
},
},
{
name: "different format does not match",
tamper: func(req *ImageRequest) {
req.Format = FormatPNG
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := baseReq()
tt.tamper(req)
err := signer.Verify(req)
if err != ErrSignatureInvalid {
t.Errorf("Verify() = %v, want %v", err, ErrSignatureInvalid)
}
})
}
// Verify the unmodified base request still passes
t.Run("unmodified request passes", func(t *testing.T) {
req := baseReq()
if err := signer.Verify(req); err != nil {
t.Errorf("Verify() unmodified request failed: %v", err)
}
})
}
// TestSigner_Sign_ExactHostInData verifies that Sign uses the exact host
// string in the signature data, producing different signatures for
// suffix-related hosts.
func TestSigner_Sign_ExactHostInData(t *testing.T) {
signer := NewSigner("test-secret-key")
hosts := []string{
"cdn.example.com",
"example.com",
"images.example.com",
"images.cdn.example.com",
"cdn.example.com.evil.com",
}
sigs := make(map[string]string)
for _, host := range hosts {
req := &ImageRequest{
SourceHost: host,
SourcePath: "/photos/cat.jpg",
SourceQuery: "",
Size: Size{Width: 800, Height: 600},
Format: FormatWebP,
Expires: time.Unix(1704067200, 0),
}
sig := signer.Sign(req)
if existing, ok := sigs[sig]; ok {
t.Errorf("hosts %q and %q produced the same signature", existing, host)
}
sigs[sig] = host
}
}
func TestSigner_DifferentKeys(t *testing.T) { func TestSigner_DifferentKeys(t *testing.T) {
signer1 := NewSigner("secret-key-1") signer1 := NewSigner("secret-key-1")
signer2 := NewSigner("secret-key-2") signer2 := NewSigner("secret-key-2")

View File

@@ -16,7 +16,7 @@ func setupStatsTestDB(t *testing.T) *sql.DB {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil { if err := database.ApplyMigrations(db); err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(func() { db.Close() }) t.Cleanup(func() { db.Close() })

View File

@@ -2,7 +2,6 @@ package imgcache
import ( import (
"bytes" "bytes"
"context"
"database/sql" "database/sql"
"image" "image"
"image/color" "image/color"
@@ -15,7 +14,6 @@ import (
"time" "time"
"sneak.berlin/go/pixa/internal/database" "sneak.berlin/go/pixa/internal/database"
"sneak.berlin/go/pixa/internal/httpfetcher"
) )
// TestFixtures contains paths to test files in the mock filesystem. // TestFixtures contains paths to test files in the mock filesystem.
@@ -173,7 +171,7 @@ func SetupTestService(t *testing.T, opts ...TestServiceOption) (*Service, *TestF
svc, err := NewService(&ServiceConfig{ svc, err := NewService(&ServiceConfig{
Cache: cache, Cache: cache,
Fetcher: httpfetcher.NewMock(mockFS), Fetcher: NewMockFetcher(mockFS),
SigningKey: cfg.signingKey, SigningKey: cfg.signingKey,
Whitelist: cfg.whitelist, Whitelist: cfg.whitelist,
}) })
@@ -195,7 +193,7 @@ func setupServiceTestDB(t *testing.T) *sql.DB {
} }
// Use the real production schema via migrations // Use the real production schema via migrations
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil { if err := database.ApplyMigrations(db); err != nil {
t.Fatalf("failed to apply migrations: %v", err) t.Fatalf("failed to apply migrations: %v", err)
} }

View File

@@ -0,0 +1,69 @@
package imgcache
import (
"net/url"
"strings"
)
// HostWhitelist implements the Whitelist interface for checking allowed source hosts.
// Only exact host matches are supported. Leading dots in patterns are stripped
// (e.g. ".example.com" becomes an exact match for "example.com").
type HostWhitelist struct {
// hosts contains hosts that must match exactly (e.g., "cdn.example.com")
hosts map[string]struct{}
}
// NewHostWhitelist creates a whitelist from a list of host patterns.
// All patterns are treated as exact matches. Leading dots are stripped
// for backwards compatibility (e.g. ".example.com" matches "example.com" only).
func NewHostWhitelist(patterns []string) *HostWhitelist {
w := &HostWhitelist{
hosts: make(map[string]struct{}),
}
for _, pattern := range patterns {
pattern = strings.ToLower(strings.TrimSpace(pattern))
if pattern == "" {
continue
}
// Strip leading dot — suffix matching is not supported.
// ".example.com" is treated as exact match for "example.com".
pattern = strings.TrimPrefix(pattern, ".")
if pattern == "" {
continue
}
w.hosts[pattern] = struct{}{}
}
return w
}
// IsWhitelisted checks if a URL's host is in the whitelist.
// Only exact host matches are supported.
func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
if u == nil {
return false
}
host := strings.ToLower(u.Hostname())
if host == "" {
return false
}
_, ok := w.hosts[host]
return ok
}
// IsEmpty returns true if the whitelist has no entries.
func (w *HostWhitelist) IsEmpty() bool {
return len(w.hosts) == 0
}
// Count returns the total number of whitelist entries.
func (w *HostWhitelist) Count() int {
return len(w.hosts)
}

View File

@@ -1,13 +1,11 @@
package allowlist_test package imgcache
import ( import (
"net/url" "net/url"
"testing" "testing"
"sneak.berlin/go/pixa/internal/allowlist"
) )
func TestHostAllowList_IsAllowed(t *testing.T) { func TestHostWhitelist_IsWhitelisted(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
patterns []string patterns []string
@@ -33,43 +31,49 @@ func TestHostAllowList_IsAllowed(t *testing.T) {
want: false, want: false,
}, },
{ {
name: "suffix match", name: "dot prefix does not enable suffix matching",
patterns: []string{".example.com"}, patterns: []string{".example.com"},
testURL: "https://cdn.example.com/image.jpg", testURL: "https://cdn.example.com/image.jpg",
want: true, want: false,
}, },
{ {
name: "suffix match deep subdomain", name: "dot prefix does not match deep subdomain",
patterns: []string{".example.com"}, patterns: []string{".example.com"},
testURL: "https://cdn.images.example.com/image.jpg", testURL: "https://cdn.images.example.com/image.jpg",
want: true, want: false,
}, },
{ {
name: "suffix match apex domain", name: "dot prefix stripped matches apex domain exactly",
patterns: []string{".example.com"}, patterns: []string{".example.com"},
testURL: "https://example.com/image.jpg", testURL: "https://example.com/image.jpg",
want: true, want: true,
}, },
{ {
name: "suffix match not found", name: "dot prefix does not match unrelated domain",
patterns: []string{".example.com"}, patterns: []string{".example.com"},
testURL: "https://notexample.com/image.jpg", testURL: "https://notexample.com/image.jpg",
want: false, want: false,
}, },
{ {
name: "suffix match partial not allowed", name: "dot prefix does not match partial domain",
patterns: []string{".example.com"}, patterns: []string{".example.com"},
testURL: "https://fakeexample.com/image.jpg", testURL: "https://fakeexample.com/image.jpg",
want: false, want: false,
}, },
{ {
name: "multiple patterns", name: "multiple patterns exact only",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"}, patterns: []string{"cdn.example.com", "photos.images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg", testURL: "https://photos.images.org/image.jpg",
want: true, want: true,
}, },
{ {
name: "empty allow list", name: "multiple patterns no suffix match",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg",
want: false,
},
{
name: "empty whitelist",
patterns: []string{}, patterns: []string{},
testURL: "https://cdn.example.com/image.jpg", testURL: "https://cdn.example.com/image.jpg",
want: false, want: false,
@@ -92,11 +96,17 @@ func TestHostAllowList_IsAllowed(t *testing.T) {
testURL: "https://cdn.example.com/image.jpg", testURL: "https://cdn.example.com/image.jpg",
want: true, want: true,
}, },
{
name: "whitespace dot prefix stripped matches exactly",
patterns: []string{" .other.com "},
testURL: "https://other.com/image.jpg",
want: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
w := allowlist.New(tt.patterns) w := NewHostWhitelist(tt.patterns)
var u *url.URL var u *url.URL
if tt.testURL != "" { if tt.testURL != "" {
@@ -107,15 +117,15 @@ func TestHostAllowList_IsAllowed(t *testing.T) {
} }
} }
got := w.IsAllowed(u) got := w.IsWhitelisted(u)
if got != tt.want { if got != tt.want {
t.Errorf("IsAllowed() = %v, want %v", got, tt.want) t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want)
} }
}) })
} }
} }
func TestHostAllowList_IsEmpty(t *testing.T) { func TestHostWhitelist_IsEmpty(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
patterns []string patterns []string
@@ -141,11 +151,16 @@ func TestHostAllowList_IsEmpty(t *testing.T) {
patterns: []string{"example.com"}, patterns: []string{"example.com"},
want: false, want: false,
}, },
{
name: "dot prefix entry still counts",
patterns: []string{".example.com"},
want: false,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
w := allowlist.New(tt.patterns) w := NewHostWhitelist(tt.patterns)
if got := w.IsEmpty(); got != tt.want { if got := w.IsEmpty(); got != tt.want {
t.Errorf("IsEmpty() = %v, want %v", got, tt.want) t.Errorf("IsEmpty() = %v, want %v", got, tt.want)
} }
@@ -153,7 +168,7 @@ func TestHostAllowList_IsEmpty(t *testing.T) {
} }
} }
func TestHostAllowList_Count(t *testing.T) { func TestHostWhitelist_Count(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
patterns []string patterns []string
@@ -170,7 +185,7 @@ func TestHostAllowList_Count(t *testing.T) {
want: 3, want: 3,
}, },
{ {
name: "suffix hosts only", name: "dot prefix hosts treated as exact",
patterns: []string{".a.com", ".b.com"}, patterns: []string{".a.com", ".b.com"},
want: 2, want: 2,
}, },
@@ -179,11 +194,16 @@ func TestHostAllowList_Count(t *testing.T) {
patterns: []string{"exact.com", ".suffix.com"}, patterns: []string{"exact.com", ".suffix.com"},
want: 2, want: 2,
}, },
{
name: "dot prefix deduplicates with exact",
patterns: []string{"example.com", ".example.com"},
want: 1,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
w := allowlist.New(tt.patterns) w := NewHostWhitelist(tt.patterns)
if got := w.Count(); got != tt.want { if got := w.Count(); got != tt.want {
t.Errorf("Count() = %v, want %v", got, tt.want) t.Errorf("Count() = %v, want %v", got, tt.want)
} }