Implement host whitelist for source domains
This commit is contained in:
82
internal/imgcache/whitelist.go
Normal file
82
internal/imgcache/whitelist.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package imgcache
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HostWhitelist implements the Whitelist interface for checking allowed source hosts.
|
||||
type HostWhitelist struct {
|
||||
// exactHosts contains hosts that must match exactly (e.g., "cdn.example.com")
|
||||
exactHosts map[string]struct{}
|
||||
// suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com")
|
||||
suffixHosts []string
|
||||
}
|
||||
|
||||
// NewHostWhitelist creates a whitelist from a list of host patterns.
|
||||
// Patterns starting with "." are treated as suffix matches.
|
||||
// Examples:
|
||||
// - "cdn.example.com" - exact match only
|
||||
// - ".example.com" - matches cdn.example.com, images.example.com, etc.
|
||||
func NewHostWhitelist(patterns []string) *HostWhitelist {
|
||||
w := &HostWhitelist{
|
||||
exactHosts: make(map[string]struct{}),
|
||||
suffixHosts: make([]string, 0),
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
pattern = strings.ToLower(strings.TrimSpace(pattern))
|
||||
if pattern == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(pattern, ".") {
|
||||
w.suffixHosts = append(w.suffixHosts, pattern)
|
||||
} else {
|
||||
w.exactHosts[pattern] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
// IsWhitelisted checks if a URL's host is in the whitelist.
|
||||
func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
|
||||
if u == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
host := strings.ToLower(u.Hostname())
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check exact match
|
||||
if _, ok := w.exactHosts[host]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check suffix match
|
||||
for _, suffix := range w.suffixHosts {
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
return true
|
||||
}
|
||||
// Also match if host equals the suffix without the leading dot
|
||||
// e.g., pattern ".example.com" should match "example.com"
|
||||
if host == strings.TrimPrefix(suffix, ".") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the whitelist has no entries.
|
||||
func (w *HostWhitelist) IsEmpty() bool {
|
||||
return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0
|
||||
}
|
||||
|
||||
// Count returns the total number of whitelist entries.
|
||||
func (w *HostWhitelist) Count() int {
|
||||
return len(w.exactHosts) + len(w.suffixHosts)
|
||||
}
|
||||
190
internal/imgcache/whitelist_test.go
Normal file
190
internal/imgcache/whitelist_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package imgcache
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHostWhitelist_IsWhitelisted(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
patterns []string
|
||||
testURL string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
patterns: []string{"cdn.example.com"},
|
||||
testURL: "https://cdn.example.com/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact match case insensitive",
|
||||
patterns: []string{"CDN.Example.COM"},
|
||||
testURL: "https://cdn.example.com/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact match not found",
|
||||
patterns: []string{"cdn.example.com"},
|
||||
testURL: "https://other.example.com/image.jpg",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "suffix match",
|
||||
patterns: []string{".example.com"},
|
||||
testURL: "https://cdn.example.com/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "suffix match deep subdomain",
|
||||
patterns: []string{".example.com"},
|
||||
testURL: "https://cdn.images.example.com/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "suffix match apex domain",
|
||||
patterns: []string{".example.com"},
|
||||
testURL: "https://example.com/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "suffix match not found",
|
||||
patterns: []string{".example.com"},
|
||||
testURL: "https://notexample.com/image.jpg",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "suffix match partial not allowed",
|
||||
patterns: []string{".example.com"},
|
||||
testURL: "https://fakeexample.com/image.jpg",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multiple patterns",
|
||||
patterns: []string{"cdn.example.com", ".images.org", "static.test.net"},
|
||||
testURL: "https://photos.images.org/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty whitelist",
|
||||
patterns: []string{},
|
||||
testURL: "https://cdn.example.com/image.jpg",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil url",
|
||||
patterns: []string{"cdn.example.com"},
|
||||
testURL: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "url with port",
|
||||
patterns: []string{"cdn.example.com"},
|
||||
testURL: "https://cdn.example.com:443/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace in patterns",
|
||||
patterns: []string{" cdn.example.com ", " .other.com "},
|
||||
testURL: "https://cdn.example.com/image.jpg",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := NewHostWhitelist(tt.patterns)
|
||||
|
||||
var u *url.URL
|
||||
if tt.testURL != "" {
|
||||
var err error
|
||||
u, err = url.Parse(tt.testURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse test URL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
got := w.IsWhitelisted(u)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostWhitelist_IsEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
patterns []string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
patterns: []string{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
patterns: nil,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
patterns: []string{" ", ""},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "has entries",
|
||||
patterns: []string{"example.com"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := NewHostWhitelist(tt.patterns)
|
||||
if got := w.IsEmpty(); got != tt.want {
|
||||
t.Errorf("IsEmpty() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostWhitelist_Count(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
patterns []string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
patterns: []string{},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "exact hosts only",
|
||||
patterns: []string{"a.com", "b.com", "c.com"},
|
||||
want: 3,
|
||||
},
|
||||
{
|
||||
name: "suffix hosts only",
|
||||
patterns: []string{".a.com", ".b.com"},
|
||||
want: 2,
|
||||
},
|
||||
{
|
||||
name: "mixed",
|
||||
patterns: []string{"exact.com", ".suffix.com"},
|
||||
want: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := NewHostWhitelist(tt.patterns)
|
||||
if got := w.Count(); got != tt.want {
|
||||
t.Errorf("Count() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user