3 Commits

Author SHA1 Message Date
user
55bb620de0 docs: update README and config to reflect exact-match-only whitelist
All checks were successful
check / check (push) Successful in 5s
Remove suffix match documentation and config comments since whitelist
now only supports exact host matches.
2026-03-15 11:18:44 -07:00
user
215ddb7f72 fix: remove suffix matching from host whitelist
Whitelist entries now support exact host matches only. Leading dots
in patterns are stripped for backwards compatibility (.example.com
becomes an exact match for example.com). Suffix matching that would
match arbitrary subdomains is no longer supported.

Closes #27
2026-03-15 11:18:25 -07:00
user
27739da046 test: add failing tests for removing suffix matching from whitelist
Suffix matching (.example.com matching subdomains) should not be
supported. Whitelist entries should be exact host matches only.
Leading dots should be stripped and treated as exact matches.
2026-03-15 11:18:01 -07:00
12 changed files with 169 additions and 308 deletions

View File

@@ -98,9 +98,7 @@ expiration 1704067200:
**Whitelist patterns:**
- **Exact match**: `cdn.example.com` — matches only that host
- **Suffix match**: `.example.com` — matches `cdn.example.com`,
`images.example.com`, and `example.com`
- **Exact match only**: `cdn.example.com` — matches only that host
### Configuration

View File

@@ -17,7 +17,10 @@ import (
"sneak.berlin/go/pixa/internal/server"
)
var Version string //nolint:gochecknoglobals // set by ldflags
var (
Appname = "pixad" //nolint:gochecknoglobals // set by ldflags
Version string //nolint:gochecknoglobals // set by ldflags
)
var configPath string //nolint:gochecknoglobals // cobra flag
@@ -37,6 +40,7 @@ func main() {
}
func run(_ *cobra.Command, _ []string) {
globals.Appname = Appname
globals.Version = Version
// Set config path in environment if specified via flag

View File

@@ -13,8 +13,7 @@ state_dir: ./data
# Generate with: openssl rand -base64 32
signing_key: "CHANGE_ME_generate_with_openssl_rand_base64_32"
# Hosts that don't require signatures
# Use "." prefix for wildcard subdomain matching (e.g., ".example.com" matches "cdn.example.com")
# Hosts that don't require signatures (exact match only)
whitelist_hosts:
- s3.sneak.cloud
- static.sneak.cloud

View File

@@ -35,41 +35,6 @@ type Database struct {
config *config.Config
}
// ParseMigrationVersion extracts the numeric version prefix from a migration
// filename. Filenames must follow the pattern "<version>.sql" or
// "<version>_<description>.sql", where version is a zero-padded numeric
// string (e.g. "001", "002"). Returns the version string and an error if
// the filename does not match the expected pattern.
func ParseMigrationVersion(filename string) (string, error) {
name := strings.TrimSuffix(filename, filepath.Ext(filename))
if name == "" {
return "", fmt.Errorf("invalid migration filename %q: empty name", filename)
}
// Split on underscore to separate version from description.
// If there's no underscore, the entire stem is the version.
version := name
if idx := strings.IndexByte(name, '_'); idx >= 0 {
version = name[:idx]
}
if version == "" {
return "", fmt.Errorf("invalid migration filename %q: empty version prefix", filename)
}
// Validate the version is purely numeric.
for _, ch := range version {
if ch < '0' || ch > '9' {
return "", fmt.Errorf(
"invalid migration filename %q: version %q contains non-numeric character %q",
filename, version, string(ch),
)
}
}
return version, nil
}
// New creates a new Database instance.
func New(lc fx.Lifecycle, params Params) (*Database, error) {
s := &Database{
@@ -119,33 +84,96 @@ func (s *Database) connect(ctx context.Context) error {
s.db = db
s.log.Info("database connected")
return ApplyMigrations(ctx, s.db, s.log)
return s.runMigrations(ctx)
}
// collectMigrations reads the embedded schema directory and returns
// migration filenames sorted lexicographically.
func collectMigrations() ([]string, error) {
entries, err := schemaFS.ReadDir("schema")
func (s *Database) runMigrations(ctx context.Context) error {
// Create migrations tracking table
_, err := s.db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return nil, fmt.Errorf("failed to read schema directory: %w", err)
return fmt.Errorf("failed to create migrations table: %w", err)
}
var migrations []string
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return fmt.Errorf("failed to read schema directory: %w", err)
}
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
return migrations, nil
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := s.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
).Scan(&count)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if count > 0 {
s.log.Debug("migration already applied", "version", version)
continue
}
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
s.log.Info("applying migration", "version", version)
_, err = s.db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
// Record migration as applied
_, err = s.db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
s.log.Info("migration applied successfully", "version", version)
}
return nil
}
// ensureMigrationsTable creates the schema_migrations tracking table if
// it does not already exist.
func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
// DB returns the underlying sql.DB.
func (s *Database) DB() *sql.DB {
return s.db
}
// ApplyMigrations applies all migrations to the given database.
// This is useful for testing where you want to use the real schema
// without the full fx lifecycle.
func ApplyMigrations(db *sql.DB) error {
ctx := context.Background()
// Create migrations tracking table
_, err := db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
@@ -156,32 +184,27 @@ func ensureMigrationsTable(ctx context.Context, db *sql.DB) error {
return fmt.Errorf("failed to create migrations table: %w", err)
}
return nil
}
// ApplyMigrations applies all pending migrations to db. An optional logger
// may be provided for informational output; pass nil for silent operation.
// This is exported so tests can apply the real schema without the full fx
// lifecycle.
func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
if err := ensureMigrationsTable(ctx, db); err != nil {
return err
}
migrations, err := collectMigrations()
// Get list of migration files
entries, err := schemaFS.ReadDir("schema")
if err != nil {
return err
return fmt.Errorf("failed to read schema directory: %w", err)
}
for _, migration := range migrations {
version, parseErr := ParseMigrationVersion(migration)
if parseErr != nil {
return parseErr
// Sort migration files by name (001.sql, 002.sql, etc.)
var migrations []string
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
migrations = append(migrations, entry.Name())
}
}
sort.Strings(migrations)
// Check if already applied.
// Apply each migration that hasn't been applied yet
for _, migration := range migrations {
version := strings.TrimSuffix(migration, filepath.Ext(migration))
// Check if already applied
var count int
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?",
version,
@@ -191,46 +214,29 @@ func ApplyMigrations(ctx context.Context, db *sql.DB, log *slog.Logger) error {
}
if count > 0 {
if log != nil {
log.Debug("migration already applied", "version", version)
}
continue
}
// Read and apply migration.
content, readErr := schemaFS.ReadFile(filepath.Join("schema", migration))
if readErr != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, readErr)
// Read and apply migration
content, err := schemaFS.ReadFile(filepath.Join("schema", migration))
if err != nil {
return fmt.Errorf("failed to read migration %s: %w", migration, err)
}
if log != nil {
log.Info("applying migration", "version", version)
_, err = db.ExecContext(ctx, string(content))
if err != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, err)
}
_, execErr := db.ExecContext(ctx, string(content))
if execErr != nil {
return fmt.Errorf("failed to apply migration %s: %w", migration, execErr)
}
// Record migration as applied.
_, recErr := db.ExecContext(ctx,
// Record migration as applied
_, err = db.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)",
version,
)
if recErr != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, recErr)
}
if log != nil {
log.Info("migration applied successfully", "version", version)
if err != nil {
return fmt.Errorf("failed to record migration %s: %w", migration, err)
}
}
return nil
}
// DB returns the underlying sql.DB.
func (s *Database) DB() *sql.DB {
return s.db
}

View File

@@ -1,155 +0,0 @@
package database
import (
"context"
"database/sql"
"testing"
_ "modernc.org/sqlite" // SQLite driver registration
)
func TestParseMigrationVersion(t *testing.T) {
tests := []struct {
name string
filename string
want string
wantErr bool
}{
{
name: "version only",
filename: "001.sql",
want: "001",
},
{
name: "version with description",
filename: "001_initial_schema.sql",
want: "001",
},
{
name: "multi-digit version",
filename: "042_add_indexes.sql",
want: "042",
},
{
name: "long version number",
filename: "00001_long_prefix.sql",
want: "00001",
},
{
name: "description with multiple underscores",
filename: "003_add_user_auth_tables.sql",
want: "003",
},
{
name: "empty filename",
filename: ".sql",
wantErr: true,
},
{
name: "leading underscore",
filename: "_description.sql",
wantErr: true,
},
{
name: "non-numeric version",
filename: "abc_migration.sql",
wantErr: true,
},
{
name: "mixed alphanumeric version",
filename: "001a_migration.sql",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseMigrationVersion(tt.filename)
if tt.wantErr {
if err == nil {
t.Errorf("ParseMigrationVersion(%q) expected error, got %q", tt.filename, got)
}
return
}
if err != nil {
t.Errorf("ParseMigrationVersion(%q) unexpected error: %v", tt.filename, err)
return
}
if got != tt.want {
t.Errorf("ParseMigrationVersion(%q) = %q, want %q", tt.filename, got, tt.want)
}
})
}
}
func TestApplyMigrations(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("failed to open in-memory database: %v", err)
}
defer db.Close()
// Apply migrations should succeed.
if err := ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("ApplyMigrations failed: %v", err)
}
// Verify the schema_migrations table recorded the version.
var version string
err = db.QueryRowContext(context.Background(),
"SELECT version FROM schema_migrations LIMIT 1",
).Scan(&version)
if err != nil {
t.Fatalf("failed to query schema_migrations: %v", err)
}
if version != "001" {
t.Errorf("expected version %q, got %q", "001", version)
}
// Verify a table from the migration exists (source_content).
var tableName string
err = db.QueryRowContext(context.Background(),
"SELECT name FROM sqlite_master WHERE type='table' AND name='source_content'",
).Scan(&tableName)
if err != nil {
t.Fatalf("expected source_content table to exist: %v", err)
}
}
func TestApplyMigrationsIdempotent(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("failed to open in-memory database: %v", err)
}
defer db.Close()
// Apply twice should succeed (idempotent).
if err := ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("first ApplyMigrations failed: %v", err)
}
if err := ApplyMigrations(context.Background(), db, nil); err != nil {
t.Fatalf("second ApplyMigrations failed: %v", err)
}
// Should still have exactly one migration recorded.
var count int
err = db.QueryRowContext(context.Background(),
"SELECT COUNT(*) FROM schema_migrations",
).Scan(&count)
if err != nil {
t.Fatalf("failed to count schema_migrations: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migration record, got %d", count)
}
}

View File

@@ -5,10 +5,11 @@ import (
"go.uber.org/fx"
)
const appname = "pixad"
// Version is populated from main() via ldflags.
var Version string //nolint:gochecknoglobals // set from main
// Build-time variables populated from main() via ldflags.
var (
Appname string //nolint:gochecknoglobals // set from main
Version string //nolint:gochecknoglobals // set from main
)
// Globals holds application-wide constants.
type Globals struct {
@@ -19,7 +20,7 @@ type Globals struct {
// New creates a new Globals instance from build-time variables.
func New(_ fx.Lifecycle) (*Globals, error) {
return &Globals{
Appname: appname,
Appname: Appname,
Version: Version,
}, nil
}

View File

@@ -82,7 +82,7 @@ func setupTestDB(t *testing.T) *sql.DB {
t.Fatalf("failed to open test db: %v", err)
}
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
if err := database.ApplyMigrations(db); err != nil {
t.Fatalf("failed to apply migrations: %v", err)
}

View File

@@ -16,7 +16,7 @@ func setupStatsTestDB(t *testing.T) *sql.DB {
if err != nil {
t.Fatal(err)
}
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
if err := database.ApplyMigrations(db); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { db.Close() })

View File

@@ -2,7 +2,6 @@ package imgcache
import (
"bytes"
"context"
"database/sql"
"image"
"image/color"
@@ -194,7 +193,7 @@ func setupServiceTestDB(t *testing.T) *sql.DB {
}
// Use the real production schema via migrations
if err := database.ApplyMigrations(context.Background(), db, nil); err != nil {
if err := database.ApplyMigrations(db); err != nil {
t.Fatalf("failed to apply migrations: %v", err)
}

View File

@@ -6,22 +6,19 @@ import (
)
// HostWhitelist implements the Whitelist interface for checking allowed source hosts.
// Only exact host matches are supported. Leading dots in patterns are stripped
// (e.g. ".example.com" becomes an exact match for "example.com").
type HostWhitelist struct {
// exactHosts contains hosts that must match exactly (e.g., "cdn.example.com")
exactHosts map[string]struct{}
// suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com")
suffixHosts []string
// hosts contains hosts that must match exactly (e.g., "cdn.example.com")
hosts map[string]struct{}
}
// NewHostWhitelist creates a whitelist from a list of host patterns.
// Patterns starting with "." are treated as suffix matches.
// Examples:
// - "cdn.example.com" - exact match only
// - ".example.com" - matches cdn.example.com, images.example.com, etc.
// All patterns are treated as exact matches. Leading dots are stripped
// for backwards compatibility (e.g. ".example.com" matches "example.com" only).
func NewHostWhitelist(patterns []string) *HostWhitelist {
w := &HostWhitelist{
exactHosts: make(map[string]struct{}),
suffixHosts: make([]string, 0),
hosts: make(map[string]struct{}),
}
for _, pattern := range patterns {
@@ -30,17 +27,22 @@ func NewHostWhitelist(patterns []string) *HostWhitelist {
continue
}
if strings.HasPrefix(pattern, ".") {
w.suffixHosts = append(w.suffixHosts, pattern)
} else {
w.exactHosts[pattern] = struct{}{}
// Strip leading dot — suffix matching is not supported.
// ".example.com" is treated as exact match for "example.com".
pattern = strings.TrimPrefix(pattern, ".")
if pattern == "" {
continue
}
w.hosts[pattern] = struct{}{}
}
return w
}
// IsWhitelisted checks if a URL's host is in the whitelist.
// Only exact host matches are supported.
func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
if u == nil {
return false
@@ -51,32 +53,17 @@ func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
return false
}
// Check exact match
if _, ok := w.exactHosts[host]; ok {
return true
}
_, ok := w.hosts[host]
// Check suffix match
for _, suffix := range w.suffixHosts {
if strings.HasSuffix(host, suffix) {
return true
}
// Also match if host equals the suffix without the leading dot
// e.g., pattern ".example.com" should match "example.com"
if host == strings.TrimPrefix(suffix, ".") {
return true
}
}
return false
return ok
}
// IsEmpty returns true if the whitelist has no entries.
func (w *HostWhitelist) IsEmpty() bool {
return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0
return len(w.hosts) == 0
}
// Count returns the total number of whitelist entries.
func (w *HostWhitelist) Count() int {
return len(w.exactHosts) + len(w.suffixHosts)
return len(w.hosts)
}

View File

@@ -31,41 +31,47 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
want: false,
},
{
name: "suffix match",
name: "dot prefix does not enable suffix matching",
patterns: []string{".example.com"},
testURL: "https://cdn.example.com/image.jpg",
want: true,
want: false,
},
{
name: "suffix match deep subdomain",
name: "dot prefix does not match deep subdomain",
patterns: []string{".example.com"},
testURL: "https://cdn.images.example.com/image.jpg",
want: true,
want: false,
},
{
name: "suffix match apex domain",
name: "dot prefix stripped matches apex domain exactly",
patterns: []string{".example.com"},
testURL: "https://example.com/image.jpg",
want: true,
},
{
name: "suffix match not found",
name: "dot prefix does not match unrelated domain",
patterns: []string{".example.com"},
testURL: "https://notexample.com/image.jpg",
want: false,
},
{
name: "suffix match partial not allowed",
name: "dot prefix does not match partial domain",
patterns: []string{".example.com"},
testURL: "https://fakeexample.com/image.jpg",
want: false,
},
{
name: "multiple patterns",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
name: "multiple patterns exact only",
patterns: []string{"cdn.example.com", "photos.images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg",
want: true,
},
{
name: "multiple patterns no suffix match",
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
testURL: "https://photos.images.org/image.jpg",
want: false,
},
{
name: "empty whitelist",
patterns: []string{},
@@ -90,6 +96,12 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
testURL: "https://cdn.example.com/image.jpg",
want: true,
},
{
name: "whitespace dot prefix stripped matches exactly",
patterns: []string{" .other.com "},
testURL: "https://other.com/image.jpg",
want: true,
},
}
for _, tt := range tests {
@@ -139,6 +151,11 @@ func TestHostWhitelist_IsEmpty(t *testing.T) {
patterns: []string{"example.com"},
want: false,
},
{
name: "dot prefix entry still counts",
patterns: []string{".example.com"},
want: false,
},
}
for _, tt := range tests {
@@ -168,7 +185,7 @@ func TestHostWhitelist_Count(t *testing.T) {
want: 3,
},
{
name: "suffix hosts only",
name: "dot prefix hosts treated as exact",
patterns: []string{".a.com", ".b.com"},
want: 2,
},
@@ -177,6 +194,11 @@ func TestHostWhitelist_Count(t *testing.T) {
patterns: []string{"exact.com", ".suffix.com"},
want: 2,
},
{
name: "dot prefix deduplicates with exact",
patterns: []string{"example.com", ".example.com"},
want: 1,
},
}
for _, tt := range tests {