3 Commits

Author SHA1 Message Date
user
55bb620de0 docs: update README and config to reflect exact-match-only whitelist
All checks were successful
check / check (push) Successful in 5s
Remove suffix match documentation and config comments since whitelist
now only supports exact host matches.
2026-03-15 11:18:44 -07:00
user
215ddb7f72 fix: remove suffix matching from host whitelist
Whitelist entries now support exact host matches only. Leading dots
in patterns are stripped for backwards compatibility (.example.com
becomes an exact match for example.com). Suffix matching that would
match arbitrary subdomains is no longer supported.

Closes #27
2026-03-15 11:18:25 -07:00
user
27739da046 test: add failing tests for removing suffix matching from whitelist
Suffix matching (.example.com matching subdomains) should not be
supported. Whitelist entries should be exact host matches only.
Leading dots should be stripped and treated as exact matches.
2026-03-15 11:18:01 -07:00
7 changed files with 141 additions and 354 deletions

View File

@@ -98,9 +98,7 @@ expiration 1704067200:
**Whitelist patterns:**
- **Exact match**: `cdn.example.com` — matches only that host
- **Suffix match**: `.example.com` — matches `cdn.example.com`,
`images.example.com`, and `example.com`
- **Exact match only**: `cdn.example.com` — matches only that host
### Configuration

View File

@@ -13,8 +13,7 @@ state_dir: ./data
# Generate with: openssl rand -base64 32
signing_key: "CHANGE_ME_generate_with_openssl_rand_base64_32"
# Hosts that don't require signatures
# Use "." prefix for wildcard subdomain matching (e.g., ".example.com" matches "cdn.example.com")
# Hosts that don't require signatures (exact match only)
whitelist_hosts:
- s3.sneak.cloud
- static.sneak.cloud

View File

@@ -21,10 +21,6 @@ import (
//go:embed schema/*.sql
var schemaFS embed.FS
// bootstrapVersion is the migration that creates the schema_migrations
// table itself. It is applied before the normal migration loop.
const bootstrapVersion = "000"
// Params defines dependencies for Database.
type Params struct {
fx.In
@@ -88,36 +84,43 @@ func (s *Database) connect(ctx context.Context) error {
s.db = db
s.log.Info("database connected")
return applyMigrations(ctx, s.db, s.log)
return s.runMigrations(ctx)
}
// applyMigrations bootstraps the migrations table from 000.sql and then
// applies every remaining migration that has not been recorded yet.
func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
if err := bootstrapMigrationsTable(ctx, db, log); err != nil {
return err
func (s *Database) runMigrations(ctx context.Context) error {
// Create migrations tracking table
_, err := s.db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
@@ -126,24 +129,26 @@ func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
}
if count > 0 {
logDebug(log, "migration already applied", "version", version)
s.log.Debug("migration already applied", "version", version)
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
logInfo(log, "applying migration", "version", version)
s.log.Info("applying migration", "version", version)
_, err = db.ExecContext(ctx, string(content))
_, err = s.db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
_, err = db.ExecContext(ctx,
// Record migration as applied
_, err = s.db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
@@ -151,81 +156,12 @@ func applyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
logInfo(log, "migration applied successfully", "version", version)
s.log.Info("migration applied successfully", "version", version)
}
return nil
}
// bootstrapMigrationsTable ensures the schema_migrations table exists
// by applying 000.sql directly. For databases that already have the
// table (created by older code), it records version "000" for
// consistency.
func bootstrapMigrationsTable(ctx context.Context, db *sql.DB, log *slog.Logger) error {
var tableExists int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
).Scan(&tableExists)
if err != nil {
return fmt.Errorf("failed to check for migrations table: %w", err)
}
if tableExists > 0 {
// Table already exists (from older inline-SQL code or a
// previous run). Make sure version "000" is recorded so the
// normal loop skips the bootstrap file.
var recorded int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
bootstrapVersion,
).Scan(&recorded)
if err != nil {
return fmt.Errorf("failed to check bootstrap migration status: %w", err)
}
if recorded == 0 {
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
bootstrapVersion,
)
if err != nil {
return fmt.Errorf("failed to record bootstrap migration: %w", err)
}
logInfo(log, "recorded bootstrap migration for existing table", "version", bootstrapVersion)
}
return nil
}
// Table does not exist — apply 000.sql to create it.
content, err := schemaFS.ReadFile("schema/000.sql")
if err != nil {
return fmt.Errorf("failed to read bootstrap migration 000.sql: %w", err)
}
logInfo(log, "applying bootstrap migration", "version", bootstrapVersion)
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply bootstrap migration: %w", err)
}
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
bootstrapVersion,
)
if err != nil {
return fmt.Errorf("failed to record bootstrap migration: %w", err)
}
logInfo(log, "bootstrap migration applied successfully", "version", bootstrapVersion)
return nil
}
// DB returns the underlying sql.DB.
func (s *Database) DB() *sql.DB {
return s.db
@@ -235,19 +171,72 @@ func (s *Database) DB() *sql.DB {
// This is useful for testing where you want to use the real schema
// without the full fx lifecycle.
func ApplyMigrations(db *sql.DB) error {
return applyMigrations(context.Background(), db, nil)
}
ctx := context.Background()
// logInfo logs at info level when a logger is available.
func logInfo(log *slog.Logger, msg string, args ...any) {
if log != nil {
log.Info(msg, args...)
// Create migrations tracking table
_, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
}
}
// logDebug logs at debug level when a logger is available.
func logDebug(log *slog.Logger, msg string, args ...any) {
if log != nil {
log.Debug(msg, args...)
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if count > 0 {
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
// Record migration as applied
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
}
return nil
}

View File

@@ -1,199 +0,0 @@
package database
import (
"context"
"database/sql"
"testing"
_ "modernc.org/sqlite" // SQLite driver registration
)
// openTestDB returns a fresh in-memory SQLite database that is closed
// automatically when the test (and its subtests) finish.
//
// The pool is limited to a single connection. With a plain ":memory:"
// DSN every new SQLite connection opens its own private, empty database,
// so if database/sql were to open a second connection the test would
// silently query a database without the schema it just created.
func openTestDB(t *testing.T) *sql.DB {
	t.Helper()

	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		t.Fatalf("failed to open test db: %v", err)
	}
	db.SetMaxOpenConns(1)

	t.Cleanup(func() {
		// Best-effort close: the database lives only in memory, so a
		// failure here leaves nothing behind to clean up.
		_ = db.Close()
	})

	return db
}
// TestApplyMigrations_CreatesSchemaAndTables runs ApplyMigrations against an
// empty database and checks both the recorded migration versions and the
// application tables that should exist afterwards.
func TestApplyMigrations_CreatesSchemaAndTables(t *testing.T) {
	conn := openTestDB(t)

	if err := ApplyMigrations(conn); err != nil {
		t.Fatalf("ApplyMigrations failed: %v", err)
	}

	// Collect every recorded version in sorted order; the first two must
	// be the bootstrap ("000") and the initial schema ("001").
	result, err := conn.Query("SELECT version FROM schema_migrations ORDER BY version")
	if err != nil {
		t.Fatalf("failed to query schema_migrations: %v", err)
	}
	defer result.Close()

	var applied []string
	for result.Next() {
		var version string
		if scanErr := result.Scan(&version); scanErr != nil {
			t.Fatalf("failed to scan version: %v", scanErr)
		}
		applied = append(applied, version)
	}
	if iterErr := result.Err(); iterErr != nil {
		t.Fatalf("row iteration error: %v", iterErr)
	}

	if len(applied) < 2 {
		t.Fatalf("expected at least 2 migrations recorded, got %d: %v", len(applied), applied)
	}

	leading := []struct {
		ordinal string
		want    string
	}{
		{ordinal: "first", want: "000"},
		{ordinal: "second", want: "001"},
	}
	for i, exp := range leading {
		if got := applied[i]; got != exp.want {
			t.Errorf("%s recorded migration = %q, want %q", exp.ordinal, got, exp.want)
		}
	}

	// Every application table must be present in sqlite_master.
	appTables := []string{
		"source_content",
		"source_metadata",
		"output_content",
		"request_cache",
		"negative_cache",
		"cache_stats",
	}
	for _, name := range appTables {
		var matches int
		lookupErr := conn.QueryRow(
			"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
			name,
		).Scan(&matches)

		switch {
		case lookupErr != nil:
			t.Fatalf("failed to check for table %s: %v", name, lookupErr)
		case matches != 1:
			t.Errorf("table %s does not exist after migrations", name)
		}
	}
}
// TestApplyMigrations_Idempotent applies the migrations twice to the same
// database and checks that the second run succeeds without duplicating the
// bootstrap row.
func TestApplyMigrations_Idempotent(t *testing.T) {
	conn := openTestDB(t)

	// The second pass must be a no-op rather than an error.
	for _, attempt := range []string{"first", "second"} {
		if err := ApplyMigrations(conn); err != nil {
			t.Fatalf("%s ApplyMigrations failed: %v", attempt, err)
		}
	}

	// Exactly one row may exist for the bootstrap version.
	var bootstrapRows int
	if err := conn.QueryRow(
		"SELECT COUNT(*) FROM schema_migrations WHERE version = '000'",
	).Scan(&bootstrapRows); err != nil {
		t.Fatalf("failed to count 000 rows: %v", err)
	}
	if bootstrapRows != 1 {
		t.Errorf("expected exactly 1 row for version 000, got %d", bootstrapRows)
	}
}
// TestBootstrapMigrationsTable_FreshDatabase checks that bootstrapping an
// empty database creates schema_migrations and records version "000".
func TestBootstrapMigrationsTable_FreshDatabase(t *testing.T) {
	conn := openTestDB(t)

	if err := bootstrapMigrationsTable(context.Background(), conn, nil); err != nil {
		t.Fatalf("bootstrapMigrationsTable failed: %v", err)
	}

	// The tracking table has to be listed in sqlite_master.
	var tables int
	if err := conn.QueryRow(
		"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
	).Scan(&tables); err != nil {
		t.Fatalf("failed to check for table: %v", err)
	}
	if tables != 1 {
		t.Fatalf("schema_migrations table not created")
	}

	// The bootstrap version has to be recorded exactly once.
	var bootstrapRows int
	if err := conn.QueryRow(
		"SELECT COUNT(*) FROM schema_migrations WHERE version = '000'",
	).Scan(&bootstrapRows); err != nil {
		t.Fatalf("failed to check version: %v", err)
	}
	if bootstrapRows != 1 {
		t.Errorf("expected version 000 to be recorded, got count %d", bootstrapRows)
	}
}
// TestBootstrapMigrationsTable_ExistingTableBackwardsCompat simulates a
// database whose schema_migrations table was created by older inline SQL
// (with no "000" row) and checks that bootstrapping records "000" while
// leaving the existing "001" row untouched.
func TestBootstrapMigrationsTable_ExistingTableBackwardsCompat(t *testing.T) {
	conn := openTestDB(t)

	// Build the legacy state: the table plus one pre-existing migration row.
	legacySetup := []struct {
		stmt    string
		failure string
	}{
		{
			stmt: `
		CREATE TABLE schema_migrations (
			version TEXT PRIMARY KEY,
			applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)
	`,
			failure: "failed to create legacy table",
		},
		{
			stmt:    "INSERT INTO schema_migrations (version) VALUES ('001')",
			failure: "failed to insert legacy version",
		},
	}
	for _, step := range legacySetup {
		if _, err := conn.Exec(step.stmt); err != nil {
			t.Fatalf("%s: %v", step.failure, err)
		}
	}

	if err := bootstrapMigrationsTable(context.Background(), conn, nil); err != nil {
		t.Fatalf("bootstrapMigrationsTable failed: %v", err)
	}

	// The bootstrap version must have been added.
	var bootstrapRows int
	if err := conn.QueryRow(
		"SELECT COUNT(*) FROM schema_migrations WHERE version = '000'",
	).Scan(&bootstrapRows); err != nil {
		t.Fatalf("failed to check version: %v", err)
	}
	if bootstrapRows != 1 {
		t.Errorf("expected version 000 to be recorded for legacy DB, got count %d", bootstrapRows)
	}

	// The row that existed beforehand must survive.
	var legacyRows int
	if err := conn.QueryRow(
		"SELECT COUNT(*) FROM schema_migrations WHERE version = '001'",
	).Scan(&legacyRows); err != nil {
		t.Fatalf("failed to check legacy version: %v", err)
	}
	if legacyRows != 1 {
		t.Errorf("legacy version 001 row missing after bootstrap")
	}
}

View File

@@ -1,9 +0,0 @@
-- Migration 000: Schema migrations tracking table
-- This must be the first migration applied. The bootstrap logic in
-- database.go applies it directly (bypassing the normal migration
-- loop) when the schema_migrations table does not yet exist.
--
-- IF NOT EXISTS keeps the statement harmless if the table is already
-- present.
CREATE TABLE IF NOT EXISTS schema_migrations (
-- Migration identifier (e.g. "000"); the primary key allows at most one
-- row per version.
version TEXT PRIMARY KEY,
-- Set to the current timestamp by default when the inserting statement
-- does not supply a value.
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

View File

@@ -6,22 +6,19 @@ import (
)
// HostWhitelist implements the Whitelist interface for checking allowed source hosts.
// Only exact host matches are supported. Leading dots in patterns are stripped
// (e.g. ".example.com" becomes an exact match for "example.com").
type HostWhitelist struct {
// exactHosts contains hosts that must match exactly (e.g., "cdn.example.com")
exactHosts map[string]struct{}
// suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com")
suffixHosts []string
// hosts contains hosts that must match exactly (e.g., "cdn.example.com")
hosts map[string]struct{}
}
// NewHostWhitelist creates a whitelist from a list of host patterns.
// Patterns starting with "." are treated as suffix matches.
// Examples:
// - "cdn.example.com" - exact match only
// - ".example.com" - matches cdn.example.com, images.example.com, etc.
// All patterns are treated as exact matches. Leading dots are stripped
// for backwards compatibility (e.g. ".example.com" matches "example.com" only).
func NewHostWhitelist(patterns []string) *HostWhitelist {
w := &HostWhitelist{
exactHosts: make(map[string]struct{}),
suffixHosts: make([]string, 0),
hosts: make(map[string]struct{}),
}
for _, pattern := range patterns {
@@ -30,17 +27,22 @@ func NewHostWhitelist(patterns []string) *HostWhitelist {
continue
}
if strings.HasPrefix(pattern, ".") {
w.suffixHosts = append(w.suffixHosts, pattern)
} else {
w.exactHosts[pattern] = struct{}{}
// Strip leading dot — suffix matching is not supported.
// ".example.com" is treated as exact match for "example.com".
pattern = strings.TrimPrefix(pattern, ".")
if pattern == "" {
continue
}
w.hosts[pattern] = struct{}{}
}
return w
}
// IsWhitelisted checks if a URL's host is in the whitelist.
// Only exact host matches are supported.
func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
if u == nil {
return false
@@ -51,32 +53,17 @@ func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
return false
}
// Check exact match
if _, ok := w.exactHosts[host]; ok {
return true
}
_, ok := w.hosts[host]
// Check suffix match
for _, suffix := range w.suffixHosts {
if strings.HasSuffix(host, suffix) {
return true
}
// Also match if host equals the suffix without the leading dot
// e.g., pattern ".example.com" should match "example.com"
if host == strings.TrimPrefix(suffix, ".") {
return true
}
}
return false
return ok
}
// IsEmpty returns true if the whitelist has no entries.
func (w *HostWhitelist) IsEmpty() bool {
return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0
return len(w.hosts) == 0
}
// Count returns the total number of whitelist entries.
func (w *HostWhitelist) Count() int {
return len(w.exactHosts) + len(w.suffixHosts)
return len(w.hosts)
}

View File

@@ -31,41 +31,47 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
want: false,
},
{
name: "suffix match",
name: "dot prefix does not enable suffix matching",
patterns: []string{".example.com"},
testURL: "https://cdn.example.com/image.jpg",
want: true,
want: false,
},
{
name: "suffix match deep subdomain",
name: "dot prefix does not match deep subdomain",
patterns: []string{".example.com"},
testURL: "https://cdn.images.example.com/image.jpg",
want: true,
want: false,
},
{
name: "suffix match apex domain",
name: "dot prefix stripped matches apex domain exactly",
patterns: []string{".example.com"},
testURL: "https://example.com/image.jpg",
want: true,
},
{
name: "suffix match not found",
name: "dot prefix does not match unrelated domain",
patterns: []string{".example.com"},
testURL: "https://notexample.com/image.jpg",
want: false,
},
{
name: "suffix match partial not allowed",
name: "dot prefix does not match partial domain",
patterns: []string{".example.com"},
testURL: "https://fakeexample.com/image.jpg",
want: false,
},
{
name: "multiple patterns",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
name: "multiple patterns exact only",
patterns: []string{"cdn.example.com", "photos.images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg",
want: true,
},
{
name: "multiple patterns no suffix match",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg",
want: false,
},
{
name: "empty whitelist",
patterns: []string{},
@@ -90,6 +96,12 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
testURL: "https://cdn.example.com/image.jpg",
want: true,
},
{
name: "whitespace dot prefix stripped matches exactly",
patterns: []string{" .other.com "},
testURL: "https://other.com/image.jpg",
want: true,
},
}
for _, tt := range tests {
@@ -139,6 +151,11 @@ func TestHostWhitelist_IsEmpty(t *testing.T) {
patterns: []string{"example.com"},
want: false,
},
{
name: "dot prefix entry still counts",
patterns: []string{".example.com"},
want: false,
},
}
for _, tt := range tests {
@@ -168,7 +185,7 @@ func TestHostWhitelist_Count(t *testing.T) {
want: 3,
},
{
name: "suffix hosts only",
name: "dot prefix hosts treated as exact",
patterns: []string{".a.com", ".b.com"},
want: 2,
},
@@ -177,6 +194,11 @@ func TestHostWhitelist_Count(t *testing.T) {
patterns: []string{"exact.com", ".suffix.com"},
want: 2,
},
{
name: "dot prefix deduplicates with exact",
patterns: []string{"example.com", ".example.com"},
want: 1,
},
}
for _, tt := range tests {