1 Commits

Author SHA1 Message Date
user
e241b99d22 remove suffix matching from host whitelist
All checks were successful
check / check (push) Successful in 1m50s
Signatures are per-URL, so the whitelist should only support exact host
matches. Remove the suffix/wildcard matching that allowed patterns like
'.example.com' to bypass signature requirements for entire domain trees.

Leading dots in existing config entries are now stripped, so '.example.com'
becomes 'example.com' as an exact match (backwards-compatible normalisation).
2026-03-17 01:55:19 -07:00
7 changed files with 135 additions and 367 deletions

View File

@@ -96,11 +96,10 @@ expiration 1704067200:
4. URL:
`/v1/image/cdn.example.com/photos/cat.jpg/800x600.webp?sig=<base64url>&exp=1704067200`
**Whitelist patterns:**
- **Exact match**: `cdn.example.com` — matches only that host
- **Suffix match**: `.example.com` — matches `cdn.example.com`,
`images.example.com`, and `example.com`
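**Whitelist entries** are exact host matches only (e.g.
`cdn.example.com`). Suffix/wildcard matching is not supported —
signatures are per-URL, so each allowed host must be listed
explicitly. A leading dot in a legacy entry is stripped, so
`.example.com` is treated as the exact host `example.com`.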
### Configuration

View File

@@ -13,8 +13,7 @@ state_dir: ./data
# Generate with: openssl rand -base64 32
signing_key: "CHANGE_ME_generate_with_openssl_rand_base64_32"
# Hosts that don't require signatures
# Use "." prefix for wildcard subdomain matching (e.g., ".example.com" matches "cdn.example.com")
# Hosts that don't require signatures (exact match only)
whitelist_hosts:
- s3.sneak.cloud
- static.sneak.cloud

View File

@@ -21,10 +21,6 @@ import (
//go:embed schema/*.sql
var schemaFS embed.FS
// bootstrapVersion is the migration that creates the schema_migrations
// table itself. It is applied before the normal migration loop.
const bootstrapVersion = "000"
// Params defines dependencies for Database.
type Params struct {
fx.In
@@ -88,36 +84,43 @@ func (s *Database) connect(ctx context.Context) error {
s.db = db
s.log.Info("database connected")
return applyMigrations(ctx, s.db, s.log)
return s.runMigrations(ctx)
}
// applyMigrations bootstraps the migrations table from 000.sql and then
// applies every remaining migration that has not been recorded yet.
func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
if err := bootstrapMigrationsTable(ctx, db, log); err != nil {
return err
func (s *Database) runMigrations(ctx context.Context) error {
// Create migrations tracking table
_, err := s.db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
@@ -126,24 +129,26 @@ func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
}
if count > 0 {
logDebug(log, "migration already applied", "version", version)
s.log.Debug("migration already applied", "version", version)
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
logInfo(log, "applying migration", "version", version)
s.log.Info("applying migration", "version", version)
_, err = db.ExecContext(ctx, string(content))
_, err = s.db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
_, err = db.ExecContext(ctx,
// Record migration as applied
_, err = s.db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
@@ -151,81 +156,12 @@ func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
logInfo(log, "migration applied successfully", "version", version)
s.log.Info("migration applied successfully", "version", version)
}
return nil
}
// bootstrapMigrationsTable ensures the schema_migrations table exists
// by applying 000.sql directly. For databases that already have the
// table (created by older code), it records version "000" for
// consistency.
func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error {
var tableExists int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
).Scan(&tableExists)
if err != nil {
return fmt.Errorf("failed to check for migrations table: %w", err)
}
if tableExists > 0 {
// Table already exists (from older inline-SQL code or a
// previous run). Make sure version "000" is recorded so the
// normal loop skips the bootstrap file.
var recorded int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
bootstrapVersion,
).Scan(&recorded)
if err != nil {
return fmt.Errorf("failed to check bootstrap migration status: %w", err)
}
if recorded == 0 {
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
bootstrapVersion,
)
if err != nil {
return fmt.Errorf("failed to record bootstrap migration: %w", err)
}
logInfo(log, "recorded bootstrap migration for existing table", "version", bootstrapVersion)
}
return nil
}
// Table does not exist — apply 000.sql to create it.
content, err := schemaFS.ReadFile("schema/000.sql")
if err != nil {
return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err)
}
logInfo(log, "applying bootstrap migration", "version", bootstrapVersion)
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply bootstrap migration: %w", err)
}
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
bootstrapVersion,
)
if err != nil {
return fmt.Errorf("failed to record bootstrap migration: %w", err)
}
logInfo(log, "bootstrap migration applied successfully", "version", bootstrapVersion)
return nil
}
// DB returns the underlying sql.DB.
func (s *Database) DB() *sql.DB {
return s.db
@@ -235,19 +171,72 @@ func (s *Database) DB() *sql.DB {
// This is useful for testing where you want to use the real schema
// without the full fx lifecycle.
func ApplyMigrations(db *sql.DB) error {
return applyMigrations(context.Background(), db, nil)
ctx := context.Background()
// Create migrations tracking table
_, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
// logInfo logs at info level when a logger is available.
func logInfo(log *slog.Logger, msg string, args ...any) {
if log != nil {
log.Info(msg, args...)
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if count > 0 {
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
// Record migration as applied
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
}
// logDebug logs at debug level when a logger is available.
func logDebug(log *slog.Logger, msg string, args ...any) {
if log != nil {
log.Debug(msg, args...)
}
return nil
}

View File

@@ -1,199 +0,0 @@
package database
import (
"context"
"database/sql"
"testing"
_ "modernc.org/sqlite" // SQLite driver registration
)
// openTestDB returns a fresh in-memory SQLite database that is closed
// automatically when the calling test finishes.
//
// The pool is capped at a single connection. Every new connection to
// ":memory:" opens its own, empty database, so a second pooled connection
// would not see the schema created through the first one.
func openTestDB(t *testing.T) *sql.DB {
	t.Helper()

	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		t.Fatalf("failed to open test db: %v", err)
	}

	// Keep all statements on the same connection (and therefore on the
	// same in-memory database).
	db.SetMaxOpenConns(1)

	t.Cleanup(func() {
		// Surface close failures instead of silently dropping them.
		if err := db.Close(); err != nil {
			t.Errorf("failed to close test db: %v", err)
		}
	})

	return db
}
// TestApplyMigrations_CreatesSchemaAndTables runs ApplyMigrations against an
// empty database and verifies that "000" and "001" are the first two recorded
// versions and that each expected application table exists afterwards.
func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) {
	db := openTestDB(t)

	if migrateErr := ApplyMigrations(db); migrateErr != nil {
		t.Fatalf("ApplyMigrations failed: %v", migrateErr)
	}

	// Read back every recorded version in ascending order; the bootstrap
	// ("000") and the initial schema ("001") must lead the list.
	rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version")
	if err != nil {
		t.Fatalf("failed to query schema_migrations: %v", err)
	}
	defer rows.Close()

	var recorded []string
	for rows.Next() {
		var version string
		if scanErr := rows.Scan(&version); scanErr != nil {
			t.Fatalf("failed to scan version: %v", scanErr)
		}
		recorded = append(recorded, version)
	}
	if iterErr := rows.Err(); iterErr != nil {
		t.Fatalf("row iteration error: %v", iterErr)
	}

	if n := len(recorded); n < 2 {
		t.Fatalf("expected at least 2 migrations recorded, got %d: %v", n, recorded)
	}
	if first := recorded[0]; first != "000" {
		t.Errorf("first recorded migration = %q, want %q", first, "000")
	}
	if second := recorded[1]; second != "001" {
		t.Errorf("second recorded migration = %q, want %q", second, "001")
	}

	// Each application table expected from 001.sql must be listed in
	// sqlite_master exactly once.
	const tableLookup = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?"
	appTables := []string{"source_content", "source_metadata", "output_content", "request_cache", "negative_cache", "cache_stats"}
	for _, name := range appTables {
		var matches int
		if lookupErr := db.QueryRow(tableLookup, name).Scan(&matches); lookupErr != nil {
			t.Fatalf("failed to check for table %s: %v", name, lookupErr)
		}
		if matches != 1 {
			t.Errorf("table %s does not exist after migrations", name)
		}
	}
}
// TestApplyMigrations_Idempotent applies the migrations twice to the same
// database and checks that the second run succeeds and that version "000"
// is still recorded exactly once.
func TestApplyMigrations_Idempotent(t *testing.T) {
	db := openTestDB(t)

	if firstErr := ApplyMigrations(db); firstErr != nil {
		t.Fatalf("first ApplyMigrations failed: %v", firstErr)
	}

	// A repeat run over an already-migrated database must be a no-op.
	if secondErr := ApplyMigrations(db); secondErr != nil {
		t.Fatalf("second ApplyMigrations failed: %v", secondErr)
	}

	// The bootstrap version must not have been recorded a second time.
	const bootstrapRows = "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'"
	var n int
	if err := db.QueryRow(bootstrapRows).Scan(&n); err != nil {
		t.Fatalf("failed to count 000 rows: %v", err)
	}
	if n != 1 {
		t.Errorf("expected exactly 1 row for version 000, got %d", n)
	}
}
// TestBootstrapMigrationsTable_FreshDatabase calls bootstrapMigrationsTable
// on an empty database and checks that the schema_migrations table now
// exists and that version "000" has been recorded in it.
func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) {
	db := openTestDB(t)

	if bootErr := bootstrapMigrationsTable(context.Background(), db, nil); bootErr != nil {
		t.Fatalf("bootstrapMigrationsTable failed: %v", bootErr)
	}

	// The tracking table has to be present in sqlite_master.
	const tableLookup = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'"
	var tables int
	if err := db.QueryRow(tableLookup).Scan(&tables); err != nil {
		t.Fatalf("failed to check for table: %v", err)
	}
	if tables != 1 {
		t.Fatalf("schema_migrations table not created")
	}

	// The bootstrap must also have recorded itself as version "000".
	const versionLookup = "SELECT COUNT(*) FROM schema_migrations WHERE version = '000'"
	var bootstrapRows int
	if err := db.QueryRow(versionLookup).Scan(&bootstrapRows); err != nil {
		t.Fatalf("failed to check version: %v", err)
	}
	if bootstrapRows != 1 {
		t.Errorf("expected version 000 to be recorded, got count %d", bootstrapRows)
	}
}
// TestBootstrapMigrationsTable_ExistingTableBackwardsCompat builds a database
// that already has a schema_migrations table holding only version "001",
// runs bootstrapMigrationsTable, and checks that "000" gets recorded while
// the existing "001" row is left in place.
func TestBootstrapMigrationsTable_ExistingTableBackwardsCompat(t *testing.T) {
	db := openTestDB(t)

	// Recreate the legacy layout: the table was created by inline SQL and
	// version "000" was never recorded.
	if _, err := db.Exec(`
CREATE TABLE schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`); err != nil {
		t.Fatalf("failed to create legacy table: %v", err)
	}

	// A pre-existing row proves the table was in use before the bootstrap.
	if _, err := db.Exec("INSERT INTO schema_migrations (version) VALUES ('001')"); err != nil {
		t.Fatalf("failed to insert legacy version: %v", err)
	}

	if bootErr := bootstrapMigrationsTable(context.Background(), db, nil); bootErr != nil {
		t.Fatalf("bootstrapMigrationsTable failed: %v", bootErr)
	}

	// The bootstrap version must have been back-filled.
	var bootstrapRows int
	if err := db.QueryRow(
		"SELECT COUNT(*) FROM schema_migrations WHERE version = '000'",
	).Scan(&bootstrapRows); err != nil {
		t.Fatalf("failed to check version: %v", err)
	}
	if bootstrapRows != 1 {
		t.Errorf("expected version 000 to be recorded for legacy DB, got count %d", bootstrapRows)
	}

	// The legacy row must have survived untouched.
	var legacyRows int
	if err := db.QueryRow(
		"SELECT COUNT(*) FROM schema_migrations WHERE version = '001'",
	).Scan(&legacyRows); err != nil {
		t.Fatalf("failed to check legacy version: %v", err)
	}
	if legacyRows != 1 {
		t.Errorf("legacy version 001 row missing after bootstrap")
	}
}

View File

@@ -1,9 +0,0 @@
-- Migration 000: Schema migrations tracking table
-- This must be the first migration applied. The bootstrap logic in
-- database.go applies it directly (bypassing the normal migration
-- loop) when the schema_migrations table does not yet exist.
-- IF NOT EXISTS keeps the statement harmless when the table is
-- already present.
CREATE TABLE IF NOT EXISTS schema_migrations (
-- Migration identifier; the primary key allows one row per version.
version TEXT PRIMARY KEY,
-- Set to the current timestamp when an insert omits the column.
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

View File

@@ -5,23 +5,20 @@ import (
"strings"
)
// HostWhitelist implements the Whitelist interface for checking allowed source hosts.
// HostWhitelist checks whether a source host is allowed without a signature.
// Only exact host matches are supported. Signatures are per-URL, so
// wildcard/suffix matching is intentionally not provided.
type HostWhitelist struct {
// exactHosts contains hosts that must match exactly (e.g., "cdn.example.com")
exactHosts map[string]struct{}
// suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com")
suffixHosts []string
}
// NewHostWhitelist creates a whitelist from a list of host patterns.
// Patterns starting with "." are treated as suffix matches.
// Examples:
// - "cdn.example.com" - exact match only
// - ".example.com" - matches cdn.example.com, images.example.com, etc.
// NewHostWhitelist creates a whitelist from a list of hostnames.
// Each entry is treated as an exact host match. Leading dots are
// stripped so that legacy ".example.com" entries become "example.com".
func NewHostWhitelist(patterns []string) *HostWhitelist {
w := &HostWhitelist{
exactHosts: make(map[string]struct{}),
suffixHosts: make([]string, 0),
}
for _, pattern := range patterns {
@@ -30,9 +27,11 @@ func NewHostWhitelist(patterns []string) *HostWhitelist {
continue
}
if strings.HasPrefix(pattern, ".") {
w.suffixHosts = append(w.suffixHosts, pattern)
} else {
// Strip leading dot — suffix matching is no longer supported;
// ".example.com" is normalised to "example.com" as an exact entry.
pattern = strings.TrimPrefix(pattern, ".")
if pattern != "" {
w.exactHosts[pattern] = struct{}{}
}
}
@@ -40,7 +39,7 @@ func NewHostWhitelist(patterns []string) *HostWhitelist {
return w
}
// IsWhitelisted checks if a URL's host is in the whitelist.
// IsWhitelisted checks if a URL's host is in the whitelist (exact match only).
func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
if u == nil {
return false
@@ -51,32 +50,17 @@ func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
return false
}
// Check exact match
if _, ok := w.exactHosts[host]; ok {
return true
}
_, ok := w.exactHosts[host]
// Check suffix match
for _, suffix := range w.suffixHosts {
if strings.HasSuffix(host, suffix) {
return true
}
// Also match if host equals the suffix without the leading dot
// e.g., pattern ".example.com" should match "example.com"
if host == strings.TrimPrefix(suffix, ".") {
return true
}
}
return false
return ok
}
// IsEmpty returns true if the whitelist has no entries.
func (w *HostWhitelist) IsEmpty() bool {
return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0
return len(w.exactHosts) == 0
}
// Count returns the total number of whitelist entries.
func (w *HostWhitelist) Count() int {
return len(w.exactHosts) + len(w.suffixHosts)
return len(w.exactHosts)
}

View File

@@ -31,41 +31,41 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
want: false,
},
{
name: "suffix match",
patterns: []string{".example.com"},
name: "no suffix matching for subdomains",
patterns: []string{"example.com"},
testURL: "https://cdn.example.com/image.jpg",
want: true,
want: false,
},
{
name: "suffix match deep subdomain",
patterns: []string{".example.com"},
testURL: "https://cdn.images.example.com/image.jpg",
want: true,
},
{
name: "suffix match apex domain",
name: "leading dot stripped to exact match",
patterns: []string{".example.com"},
testURL: "https://example.com/image.jpg",
want: true,
},
{
name: "suffix match not found",
name: "leading dot does not enable suffix matching",
patterns: []string{".example.com"},
testURL: "https://notexample.com/image.jpg",
testURL: "https://cdn.example.com/image.jpg",
want: false,
},
{
name: "suffix match partial not allowed",
name: "leading dot does not match deep subdomain",
patterns: []string{".example.com"},
testURL: "https://fakeexample.com/image.jpg",
testURL: "https://cdn.images.example.com/image.jpg",
want: false,
},
{
name: "multiple patterns",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
name: "multiple patterns exact only",
patterns: []string{"cdn.example.com", "photos.images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg",
want: true,
},
{
name: "multiple patterns no suffix leak",
patterns: []string{"cdn.example.com", "images.org"},
testURL: "https://photos.images.org/image.jpg",
want: false,
},
{
name: "empty whitelist",
patterns: []string{},
@@ -86,7 +86,7 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
},
{
name: "whitespace in patterns",
patterns: []string{" cdn.example.com ", " .other.com "},
patterns: []string{" cdn.example.com ", " other.com "},
testURL: "https://cdn.example.com/image.jpg",
want: true,
},
@@ -139,6 +139,11 @@ func TestHostWhitelist_IsEmpty(t *testing.T) {
patterns: []string{"example.com"},
want: false,
},
{
name: "leading dot normalised to entry",
patterns: []string{".example.com"},
want: false,
},
}
for _, tt := range tests {
@@ -168,14 +173,14 @@ func TestHostWhitelist_Count(t *testing.T) {
want: 3,
},
{
name: "suffix hosts only",
name: "leading dots normalised to exact",
patterns: []string{".a.com", ".b.com"},
want: 2,
},
{
name: "mixed",
patterns: []string{"exact.com", ".suffix.com"},
want: 2,
name: "mixed deduplication",
patterns: []string{"example.com", ".example.com"},
want: 1,
},
}