diff --git a/internal/imgcache/whitelist.go b/internal/imgcache/whitelist.go new file mode 100644 index 0000000..df24be2 --- /dev/null +++ b/internal/imgcache/whitelist.go @@ -0,0 +1,82 @@ +package imgcache + +import ( + "net/url" + "strings" +) + +// HostWhitelist implements the Whitelist interface for checking allowed source hosts. +type HostWhitelist struct { + // exactHosts contains hosts that must match exactly (e.g., "cdn.example.com") + exactHosts map[string]struct{} + // suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com") + suffixHosts []string +} + +// NewHostWhitelist creates a whitelist from a list of host patterns. +// Patterns starting with "." are treated as suffix matches. +// Examples: +// - "cdn.example.com" - exact match only +// - ".example.com" - matches cdn.example.com, images.example.com, etc. +func NewHostWhitelist(patterns []string) *HostWhitelist { + w := &HostWhitelist{ + exactHosts: make(map[string]struct{}), + suffixHosts: make([]string, 0), + } + + for _, pattern := range patterns { + pattern = strings.ToLower(strings.TrimSpace(pattern)) + if pattern == "" { + continue + } + + if strings.HasPrefix(pattern, ".") { + w.suffixHosts = append(w.suffixHosts, pattern) + } else { + w.exactHosts[pattern] = struct{}{} + } + } + + return w +} + +// IsWhitelisted checks if a URL's host is in the whitelist. +func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool { + if u == nil { + return false + } + + host := strings.ToLower(u.Hostname()) + if host == "" { + return false + } + + // Check exact match + if _, ok := w.exactHosts[host]; ok { + return true + } + + // Check suffix match + for _, suffix := range w.suffixHosts { + if strings.HasSuffix(host, suffix) { + return true + } + // Also match if host equals the suffix without the leading dot + // e.g., pattern ".example.com" should match "example.com" + if host == strings.TrimPrefix(suffix, ".") { + return true + } + } + + return false +} + +// IsEmpty returns true if the whitelist has no entries. +func (w *HostWhitelist) IsEmpty() bool { + return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0 +} + +// Count returns the total number of whitelist entries. +func (w *HostWhitelist) Count() int { + return len(w.exactHosts) + len(w.suffixHosts) +} diff --git a/internal/imgcache/whitelist_test.go b/internal/imgcache/whitelist_test.go new file mode 100644 index 0000000..3e33b66 --- /dev/null +++ b/internal/imgcache/whitelist_test.go @@ -0,0 +1,190 @@ +package imgcache + +import ( + "net/url" + "testing" +) + +func TestHostWhitelist_IsWhitelisted(t *testing.T) { + tests := []struct { + name string + patterns []string + testURL string + want bool + }{ + { + name: "exact match", + patterns: []string{"cdn.example.com"}, + testURL: "https://cdn.example.com/image.jpg", + want: true, + }, + { + name: "exact match case insensitive", + patterns: []string{"CDN.Example.COM"}, + testURL: "https://cdn.example.com/image.jpg", + want: true, + }, + { + name: "exact match not found", + patterns: []string{"cdn.example.com"}, + testURL: "https://other.example.com/image.jpg", + want: false, + }, + { + name: "suffix match", + patterns: []string{".example.com"}, + testURL: "https://cdn.example.com/image.jpg", + want: true, + }, + { + name: "suffix match deep subdomain", + patterns: []string{".example.com"}, + testURL: "https://cdn.images.example.com/image.jpg", + want: true, + }, + { + name: "suffix match apex domain", + patterns: []string{".example.com"}, + testURL: "https://example.com/image.jpg", + want: true, + }, + { + name: "suffix match not found", + patterns: []string{".example.com"}, + testURL: "https://notexample.com/image.jpg", + want: false, + }, + { + name: "suffix match partial not allowed", + patterns: []string{".example.com"}, + testURL: "https://fakeexample.com/image.jpg", + want: false, + }, + { + name: "multiple patterns", + patterns: []string{"cdn.example.com", ".images.org", "static.test.net"}, + testURL: "https://photos.images.org/image.jpg", + want: true, + }, + { + name: "empty whitelist", + patterns: []string{}, + testURL: "https://cdn.example.com/image.jpg", + want: false, + }, + { + name: "nil url", + patterns: []string{"cdn.example.com"}, + testURL: "", + want: false, + }, + { + name: "url with port", + patterns: []string{"cdn.example.com"}, + testURL: "https://cdn.example.com:443/image.jpg", + want: true, + }, + { + name: "whitespace in patterns", + patterns: []string{" cdn.example.com ", " .other.com "}, + testURL: "https://cdn.example.com/image.jpg", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := NewHostWhitelist(tt.patterns) + + var u *url.URL + if tt.testURL != "" { + var err error + u, err = url.Parse(tt.testURL) + if err != nil { + t.Fatalf("failed to parse test URL: %v", err) + } + } + + got := w.IsWhitelisted(u) + if got != tt.want { + t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHostWhitelist_IsEmpty(t *testing.T) { + tests := []struct { + name string + patterns []string + want bool + }{ + { + name: "empty", + patterns: []string{}, + want: true, + }, + { + name: "nil", + patterns: nil, + want: true, + }, + { + name: "whitespace only", + patterns: []string{" ", ""}, + want: true, + }, + { + name: "has entries", + patterns: []string{"example.com"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := NewHostWhitelist(tt.patterns) + if got := w.IsEmpty(); got != tt.want { + t.Errorf("IsEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHostWhitelist_Count(t *testing.T) { + tests := []struct { + name string + patterns []string + want int + }{ + { + name: "empty", + patterns: []string{}, + want: 0, + }, + { + name: "exact hosts only", + patterns: []string{"a.com", "b.com", "c.com"}, + want: 3, + }, + { + name: "suffix hosts only", + patterns: []string{".a.com", ".b.com"}, + want: 2, + }, + { + name: "mixed", + patterns: []string{"exact.com", ".suffix.com"}, + want: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := NewHostWhitelist(tt.patterns) + if got := w.Count(); got != tt.want { + t.Errorf("Count() = %v, want %v", got, tt.want) + } + }) + } +}