1 Commits

Author SHA1 Message Date
user
e241b99d22 remove suffix matching from host whitelist
All checks were successful
check / check (push) Successful in 1m50s
Signatures are per-URL, so the whitelist should only support exact host
matches. Remove the suffix/wildcard matching that allowed patterns like
'.example.com' to bypass signature requirements for entire domain trees.

Leading dots in existing config entries are now stripped, so '.example.com'
becomes 'example.com' as an exact match (backwards-compatible normalisation).
2026-03-17 01:55:19 -07:00
11 changed files with 161 additions and 530 deletions

View File

@@ -96,11 +96,10 @@ expiration 1704067200:
4. URL:
`/v1/image/cdn.example.com/photos/cat.jpg/800x600.webp?sig=<base64url>&exp=1704067200`
**Whitelist patterns:**
- **Exact match**: `cdn.example.com` — matches only that host
- **Suffix match**: `.example.com` — matches `cdn.example.com`,
`images.example.com`, and `example.com`
**Whitelist entries** are exact host matches only (e.g.
`cdn.example.com`). Suffix/wildcard matching is not supported —
signatures are per-URL, so each allowed host must be listed
explicitly.
### Configuration

View File

@@ -13,8 +13,7 @@ state_dir: ./data
# Generate with: openssl rand -base64 32
signing_key: "CHANGE_ME_generate_with_openssl_rand_base64_32"
# Hosts that don't require signatures
# Use "." prefix for wildcard subdomain matching (e.g., ".example.com" matches "cdn.example.com")
# Hosts that don't require signatures (exact match only)
whitelist_hosts:
- s3.sneak.cloud
- static.sneak.cloud

View File

@@ -21,10 +21,6 @@ import (
//go:embed schema/*.sql
var schemaFS embed.FS
// bootstrapVersion is the migration that creates the schema_migrations
// table itself. It is applied before the normal migration loop.
const bootstrapVersion = "000"
// Params defines dependencies for Database.
type Params struct {
fx.In
@@ -39,41 +35,6 @@ type Database struct {
config *config.Config
}
// ParseMigrationVersion extracts the numeric version prefix from a migration
// filename. Filenames must follow the pattern "<version>.sql" or
// "<version>_<description>.sql", where version is a zero-padded numeric
// string (e.g. "001", "002"). Returns the version string and an error if
// the filename does not match the expected pattern.
func ParseMigrationVersion(filename string) (string, error) {
name := strings.TrimSuffix(filename, filepath.Ext(filename))
if name == "" {
return "", fmt.Errorf("invalid migration filename %q: empty name", filename)
}
// Split on underscore to separate version from description.
// If there's no underscore, the entire stem is the version.
version := name
if idx := strings.IndexByte(name, '_'); idx >= 0 {
version = name[:idx]
}
if version == "" {
return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename)
}
// Validate the version is purely numeric.
for _, ch := range version {
if ch < '0' || ch > '9' {
return "", fmt.Errorf(
"invalid migration filename %q: version %q contains non-numeric character %q",
filename, version, string(ch),
)
}
}
return version, nil
}
// New creates a new Database instance.
func New(lc fx.Lifecycle, params Params) (*Database, error) {
s := &Database{
@@ -123,139 +84,43 @@ func (s *Database) connect(ctx context.Context) error {
s.db = db
s.log.Info("database connected")
return ApplyMigrations(ctx, s.db, s.log)
return s.runMigrations(ctx)
}
// collectMigrations reads the embedded schema directory and returns
// migration filenames sorted lexicographically.
func collectMigrations() ([]string, error) {
entries, err := schemaFS.ReadDir("schema")
func (s *Database) runMigrations(ctx context.Context) error {
// Create migrations tracking table
_, err := s.db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return nil, fmt.Errorf("failed to read schema directory: %w", err)
return fmt.Errorf("failed to create migrations table: %w", err)
}
var migrations []string
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
return migrations, nil
}
// bootstrapMigrationsTable ensures the schema_migrations table exists
// by applying 000.sql directly. For databases that already have the
// table (created by older code), it records version "000" for
// consistency.
func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error {
var tableExists int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
).Scan(&tableExists)
if err != nil {
return fmt.Errorf("failed to check for migrations table: %w", err)
}
if tableExists > 0 {
return ensureBootstrapVersionRecorded(ctx, db, log)
}
return applyBootstrapMigration(ctx, db, log)
}
// ensureBootstrapVersionRecorded checks whether version "000" is already
// recorded in an existing schema_migrations table and inserts it if not.
func ensureBootstrapVersionRecorded(ctx context.Context, db *sql.DB, log *slog.Logger) error {
var recorded int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
bootstrapVersion,
).Scan(&recorded)
if err != nil {
return fmt.Errorf("failed to check bootstrap migration status: %w", err)
}
if recorded > 0 {
return nil
}
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
bootstrapVersion,
)
if err != nil {
return fmt.Errorf("failed to record bootstrap migration: %w", err)
}
if log != nil {
log.Info("recorded bootstrap migration for existing table", "version", bootstrapVersion)
}
return nil
}
// applyBootstrapMigration reads and executes 000.sql to create the
// schema_migrations table on a fresh database.
func applyBootstrapMigration(ctx context.Context, db *sql.DB, log *slog.Logger) error {
content, err := schemaFS.ReadFile("schema/000.sql")
if err != nil {
return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err)
}
if log != nil {
log.Info("applying bootstrap migration", "version", bootstrapVersion)
}
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply bootstrap migration: %w", err)
}
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
bootstrapVersion,
)
if err != nil {
return fmt.Errorf("failed to record bootstrap migration: %w", err)
}
if log != nil {
log.Info("bootstrap migration applied successfully", "version", bootstrapVersion)
}
return nil
}
// ApplyMigrations applies all pending migrations to db. An optional logger
// may be provided for informational output; pass nil for silent operation.
// This is exported so tests can apply the real schema without the full fx
// lifecycle.
func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
if err := bootstrapMigrationsTable(ctx, db, log); err != nil {
return err
}
migrations, err := collectMigrations()
if err != nil {
return err
}
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version, parseErr := ParseMigrationVersion(migration)
if parseErr != nil {
return parseErr
}
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied.
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
@@ -264,40 +129,34 @@ func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
}
if count > 0 {
if log != nil {
log.Debug("migration already applied", "version", version)
}
s.log.Debug("migration already applied", "version", version)
continue
}
// Read and apply migration.
content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration))
if readErr != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, readErr)
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
if log != nil {
log.Info("applying migration", "version", version)
s.log.Info("applying migration", "version", version)
_, err = s.db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
_, execErr := db.ExecContext(ctx, string(content))
if execErr != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, execErr)
}
// Record migration as applied.
_, recErr := db.ExecContext(ctx,
// Record migration as applied
_, err = s.db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if recErr != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, recErr)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
if log != nil {
log.Info("migration applied successfully", "version", version)
}
s.log.Info("migration applied successfully", "version", version)
}
return nil
@@ -307,3 +166,77 @@ func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
func (s *Database) DB() *sql.DB {
return s.db
}
// ApplyMigrations applies all migrations to the given database.
// This is useful for testing where you want to use the real schema
// without the full fx lifecycle.
func ApplyMigrations(db *sql.DB) error {
ctx := context.Background()
// Create migrations tracking table
_, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if count > 0 {
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
// Record migration as applied
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
}
return nil
}

View File

@@ -1,279 +0,0 @@
package database
import (
"context"
"database/sql"
"testing"
_ "modernc.org/sqlite" // SQLite driver registration
)
// openTestDB returns a fresh in-memory SQLite database.
func openTestDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("failed to open test db: %v", err)
}
t.Cleanup(func() { db.Close() })
return db
}
func TestParseMigrationVersion(t *testing.T) {
tests := []struct {
name string
filename string
want string
wantErr bool
}{
{
name: "version only",
filename: "001.sql",
want: "001",
},
{
name: "version with description",
filename: "001_initial_schema.sql",
want: "001",
},
{
name: "multi-digit version",
filename: "042_add_indexes.sql",
want: "042",
},
{
name: "long version number",
filename: "00001_long_prefix.sql",
want: "00001",
},
{
name: "description with multiple underscores",
filename: "003_add_user_auth_tables.sql",
want: "003",
},
{
name: "empty filename",
filename: ".sql",
wantErr: true,
},
{
name: "leading underscore",
filename: "_description.sql",
wantErr: true,
},
{
name: "non-numeric version",
filename: "abc_migration.sql",
wantErr: true,
},
{
name: "mixed alphanumeric version",
filename: "001a_migration.sql",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseMigrationVersion(tt.filename)
if tt.wantErr {
if err == nil {
t.Errorf("ParseMigrationVersion(%q) expected error, got %q", tt.filename, got)
}
return
}
if err != nil {
t.Errorf("ParseMigrationVersion(%q) unexpected error: %v", tt.filename, err)
return
}
if got != tt.want {
t.Errorf("ParseMigrationVersion(%q) = %q, want %q", tt.filename, got, tt.want)
}
})
}
}
func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
if err := ApplyMigrations(ctx, db, nil); err != nil {
t.Fatalf("ApplyMigrations failed: %v", err)
}
// The schema_migrations table must exist and contain at least
// version "000" (the bootstrap) and "001" (the initial schema).
rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version")
if err != nil {
t.Fatalf("failed to query schema_migrations: %v", err)
}
defer rows.Close()
var versions []string
for rows.Next() {
var v string
if err := rows.Scan(&v); err != nil {
t.Fatalf("failed to scan version: %v", err)
}
versions = append(versions, v)
}
if err := rows.Err(); err != nil {
t.Fatalf("row iteration error: %v", err)
}
if len(versions) < 2 {
t.Fatalf("expected at least 2 migrations recorded, got %d: %v", len(versions), versions)
}
if versions[0] != "000" {
t.Errorf("first recorded migration = %q, want %q", versions[0], "000")
}
if versions[1] != "001" {
t.Errorf("second recorded migration = %q, want %q", versions[1], "001")
}
// Verify that the application tables created by 001.sql exist.
for _, table := range []string{"source_content", "source_metadata", "output_content", "request_cache", "negative_cache", "cache_stats"} {
var count int
err := db.QueryRow(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
table,
).Scan(&count)
if err != nil {
t.Fatalf("failed to check for table %s: %v", table, err)
}
if count != 1 {
t.Errorf("table %s does not exist after migrations", table)
}
}
}
func TestApplyMigrations_Idempotent(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
if err := ApplyMigrations(ctx, db, nil); err != nil {
t.Fatalf("first ApplyMigrations failed: %v", err)
}
// Running a second time must succeed without errors.
if err := ApplyMigrations(ctx, db, nil); err != nil {
t.Fatalf("second ApplyMigrations failed: %v", err)
}
// Verify no duplicate rows in schema_migrations.
var count int
err := db.QueryRow("SELECT COUNT(*) FROM schema_migrations WHERE version = '000'").Scan(&count)
if err != nil {
t.Fatalf("failed to count 000 rows: %v", err)
}
if count != 1 {
t.Errorf("expected exactly 1 row for version 000, got %d", count)
}
}
func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
if err := bootstrapMigrationsTable(ctx, db, nil); err != nil {
t.Fatalf("bootstrapMigrationsTable failed: %v", err)
}
// schema_migrations table must exist.
var tableCount int
err := db.QueryRow(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
).Scan(&tableCount)
if err != nil {
t.Fatalf("failed to check for table: %v", err)
}
if tableCount != 1 {
t.Fatalf("schema_migrations table not created")
}
// Version "000" must be recorded.
var recorded int
err = db.QueryRow(
"SELECT COUNT(*) FROM schema_migrations WHERE version = '000'",
).Scan(&recorded)
if err != nil {
t.Fatalf("failed to check version: %v", err)
}
if recorded != 1 {
t.Errorf("expected version 000 to be recorded, got count %d", recorded)
}
}
func TestBootstrapMigrationsTable_ExistingTableBackwardsCompat(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
// Simulate an older database that created the table via inline SQL
// (without recording version "000").
_, err := db.Exec(`
CREATE TABLE schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
t.Fatalf("failed to create legacy table: %v", err)
}
// Insert a fake migration to prove the table already existed.
_, err = db.Exec("INSERT INTO schema_migrations (version) VALUES ('001')")
if err != nil {
t.Fatalf("failed to insert legacy version: %v", err)
}
if err := bootstrapMigrationsTable(ctx, db, nil); err != nil {
t.Fatalf("bootstrapMigrationsTable failed: %v", err)
}
// Version "000" must now be recorded.
var recorded int
err = db.QueryRow(
"SELECT COUNT(*) FROM schema_migrations WHERE version = '000'",
).Scan(&recorded)
if err != nil {
t.Fatalf("failed to check version: %v", err)
}
if recorded != 1 {
t.Errorf("expected version 000 to be recorded for legacy DB, got count %d", recorded)
}
// The existing "001" row must still be there.
var legacyCount int
err = db.QueryRow(
"SELECT COUNT(*) FROM schema_migrations WHERE version = '001'",
).Scan(&legacyCount)
if err != nil {
t.Fatalf("failed to check legacy version: %v", err)
}
if legacyCount != 1 {
t.Errorf("legacy version 001 row missing after bootstrap")
}
}

View File

@@ -1,9 +0,0 @@
-- Migration 000: Schema migrations tracking table
-- This must be the first migration applied. The bootstrap logic in
-- database.go applies it directly (bypassing the normal migration
-- loop) when the schema_migrations table does not yet exist.
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

View File

@@ -82,7 +82,7 @@ func setupTestDB(t *testing.T) *sql.DB {
t.Fatalf("failed to open test db: %v", err)
}
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
if err := database.ApplyMigrations(db); err != nil {
t.Fatalf("failed to apply migrations: %v", err)
}

View File

@@ -16,7 +16,7 @@ func setupStatsTestDB(t *testing.T) *sql.DB {
if err != nil {
t.Fatal(err)
}
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
if err := database.ApplyMigrations(db); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { db.Close() })

View File

@@ -2,7 +2,6 @@ package imgcache
import (
"bytes"
"context"
"database/sql"
"image"
"image/color"
@@ -194,7 +193,7 @@ func setupServiceTestDB(t *testing.T) *sql.DB {
}
// Use the real production schema via migrations
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
if err := database.ApplyMigrations(db); err != nil {
t.Fatalf("failed to apply migrations: %v", err)
}

View File

@@ -5,23 +5,20 @@ import (
"strings"
)
// HostWhitelist implements the Whitelist interface for checking allowed source hosts.
// HostWhitelist checks whether a source host is allowed without a signature.
// Only exact host matches are supported. Signatures are per-URL, so
// wildcard/suffix matching is intentionally not provided.
type HostWhitelist struct {
// exactHosts contains hosts that must match exactly (e.g., "cdn.example.com")
exactHosts map[string]struct{}
// suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com")
suffixHosts []string
}
// NewHostWhitelist creates a whitelist from a list of host patterns.
// Patterns starting with "." are treated as suffix matches.
// Examples:
// - "cdn.example.com" - exact match only
// - ".example.com" - matches cdn.example.com, images.example.com, etc.
// NewHostWhitelist creates a whitelist from a list of hostnames.
// Each entry is treated as an exact host match. Leading dots are
// stripped so that legacy ".example.com" entries become "example.com".
func NewHostWhitelist(patterns []string) *HostWhitelist {
w := &HostWhitelist{
exactHosts: make(map[string]struct{}),
suffixHosts: make([]string, 0),
}
for _, pattern := range patterns {
@@ -30,9 +27,11 @@ func NewHostWhitelist(patterns []string) *HostWhitelist {
continue
}
if strings.HasPrefix(pattern, ".") {
w.suffixHosts = append(w.suffixHosts, pattern)
} else {
// Strip leading dot — suffix matching is no longer supported;
// ".example.com" is normalised to "example.com" as an exact entry.
pattern = strings.TrimPrefix(pattern, ".")
if pattern != "" {
w.exactHosts[pattern] = struct{}{}
}
}
@@ -40,7 +39,7 @@ func NewHostWhitelist(patterns []string) *HostWhitelist {
return w
}
// IsWhitelisted checks if a URL's host is in the whitelist.
// IsWhitelisted checks if a URL's host is in the whitelist (exact match only).
func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
if u == nil {
return false
@@ -51,32 +50,17 @@ func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
return false
}
// Check exact match
if _, ok := w.exactHosts[host]; ok {
return true
}
_, ok := w.exactHosts[host]
// Check suffix match
for _, suffix := range w.suffixHosts {
if strings.HasSuffix(host, suffix) {
return true
}
// Also match if host equals the suffix without the leading dot
// e.g., pattern ".example.com" should match "example.com"
if host == strings.TrimPrefix(suffix, ".") {
return true
}
}
return false
return ok
}
// IsEmpty returns true if the whitelist has no entries.
func (w *HostWhitelist) IsEmpty() bool {
return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0
return len(w.exactHosts) == 0
}
// Count returns the total number of whitelist entries.
func (w *HostWhitelist) Count() int {
return len(w.exactHosts) + len(w.suffixHosts)
return len(w.exactHosts)
}

View File

@@ -31,41 +31,41 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
want: false,
},
{
name: "suffix match",
patterns: []string{".example.com"},
name: "no suffix matching for subdomains",
patterns: []string{"example.com"},
testURL: "https://cdn.example.com/image.jpg",
want: true,
want: false,
},
{
name: "suffix match deep subdomain",
patterns: []string{".example.com"},
testURL: "https://cdn.images.example.com/image.jpg",
want: true,
},
{
name: "suffix match apex domain",
name: "leading dot stripped to exact match",
patterns: []string{".example.com"},
testURL: "https://example.com/image.jpg",
want: true,
},
{
name: "suffix match not found",
name: "leading dot does not enable suffix matching",
patterns: []string{".example.com"},
testURL: "https://notexample.com/image.jpg",
testURL: "https://cdn.example.com/image.jpg",
want: false,
},
{
name: "suffix match partial not allowed",
name: "leading dot does not match deep subdomain",
patterns: []string{".example.com"},
testURL: "https://fakeexample.com/image.jpg",
testURL: "https://cdn.images.example.com/image.jpg",
want: false,
},
{
name: "multiple patterns",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
name: "multiple patterns exact only",
patterns: []string{"cdn.example.com", "photos.images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg",
want: true,
},
{
name: "multiple patterns no suffix leak",
patterns: []string{"cdn.example.com", "images.org"},
testURL: "https://photos.images.org/image.jpg",
want: false,
},
{
name: "empty whitelist",
patterns: []string{},
@@ -86,7 +86,7 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
},
{
name: "whitespace in patterns",
patterns: []string{" cdn.example.com ", " .other.com "},
patterns: []string{" cdn.example.com ", " other.com "},
testURL: "https://cdn.example.com/image.jpg",
want: true,
},
@@ -139,6 +139,11 @@ func TestHostWhitelist_IsEmpty(t *testing.T) {
patterns: []string{"example.com"},
want: false,
},
{
name: "leading dot normalised to entry",
patterns: []string{".example.com"},
want: false,
},
}
for _, tt := range tests {
@@ -168,14 +173,14 @@ func TestHostWhitelist_Count(t *testing.T) {
want: 3,
},
{
name: "suffix hosts only",
name: "leading dots normalised to exact",
patterns: []string{".a.com", ".b.com"},
want: 2,
},
{
name: "mixed",
patterns: []string{"exact.com", ".suffix.com"},
want: 2,
name: "mixed deduplication",
patterns: []string{"example.com", ".example.com"},
want: 1,
},
}