refactor: extract whitelist package from internal/imgcache #41

Merged
sneak merged 2 commits from refactor/extract-whitelist-package into main 2026-03-25 20:44:57 +01:00
3 changed files with 31 additions and 27 deletions

View File

@@ -1,25 +1,26 @@
package imgcache
// Package allowlist provides host-based URL allow-listing for the image proxy.
package allowlist
import (
"net/url"
"strings"
)
// HostWhitelist implements the Whitelist interface for checking allowed source hosts.
type HostWhitelist struct {
// HostAllowList checks whether source hosts are permitted.
type HostAllowList struct {
// exactHosts contains hosts that must match exactly (e.g., "cdn.example.com")
exactHosts map[string]struct{}
// suffixHosts contains domain suffixes to match (e.g., ".example.com" matches "cdn.example.com")
suffixHosts []string
}
// NewHostWhitelist creates a whitelist from a list of host patterns.
// New creates a HostAllowList from a list of host patterns.
// Patterns starting with "." are treated as suffix matches.
// Examples:
// - "cdn.example.com" - exact match only
// - ".example.com" - matches cdn.example.com, images.example.com, etc.
func NewHostWhitelist(patterns []string) *HostWhitelist {
w := &HostWhitelist{
func New(patterns []string) *HostAllowList {
w := &HostAllowList{
exactHosts: make(map[string]struct{}),
suffixHosts: make([]string, 0),
}
@@ -40,8 +41,8 @@ func NewHostWhitelist(patterns []string) *HostWhitelist {
return w
}
// IsWhitelisted checks if a URL's host is in the whitelist.
func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
// IsAllowed checks if a URL's host is in the allow list.
func (w *HostAllowList) IsAllowed(u *url.URL) bool {
if u == nil {
return false
}
@@ -71,12 +72,12 @@ func (w *HostWhitelist) IsWhitelisted(u *url.URL) bool {
return false
}
// IsEmpty returns true if the whitelist has no entries.
func (w *HostWhitelist) IsEmpty() bool {
// IsEmpty returns true if the allow list has no entries.
func (w *HostAllowList) IsEmpty() bool {
return len(w.exactHosts) == 0 && len(w.suffixHosts) == 0
}
// Count returns the total number of whitelist entries.
func (w *HostWhitelist) Count() int {
// Count returns the total number of allow list entries.
func (w *HostAllowList) Count() int {
return len(w.exactHosts) + len(w.suffixHosts)
}

View File

@@ -1,11 +1,13 @@
package imgcache
package allowlist_test
import (
"net/url"
"testing"
"sneak.berlin/go/pixa/internal/allowlist"
)
func TestHostWhitelist_IsWhitelisted(t *testing.T) {
func TestHostAllowList_IsAllowed(t *testing.T) {
tests := []struct {
name string
patterns []string
@@ -67,7 +69,7 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
want: true,
},
{
name: "empty whitelist",
name: "empty allow list",
patterns: []string{},
testURL: "https://cdn.example.com/image.jpg",
want: false,
@@ -94,7 +96,7 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := NewHostWhitelist(tt.patterns)
w := allowlist.New(tt.patterns)
var u *url.URL
if tt.testURL != "" {
@@ -105,15 +107,15 @@ func TestHostWhitelist_IsWhitelisted(t *testing.T) {
}
}
got := w.IsWhitelisted(u)
got := w.IsAllowed(u)
if got != tt.want {
t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want)
t.Errorf("IsAllowed() = %v, want %v", got, tt.want)
}
})
}
}
func TestHostWhitelist_IsEmpty(t *testing.T) {
func TestHostAllowList_IsEmpty(t *testing.T) {
tests := []struct {
name string
patterns []string
@@ -143,7 +145,7 @@ func TestHostWhitelist_IsEmpty(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := NewHostWhitelist(tt.patterns)
w := allowlist.New(tt.patterns)
if got := w.IsEmpty(); got != tt.want {
t.Errorf("IsEmpty() = %v, want %v", got, tt.want)
}
@@ -151,7 +153,7 @@ func TestHostWhitelist_IsEmpty(t *testing.T) {
}
}
func TestHostWhitelist_Count(t *testing.T) {
func TestHostAllowList_Count(t *testing.T) {
tests := []struct {
name string
patterns []string
@@ -181,7 +183,7 @@ func TestHostWhitelist_Count(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := NewHostWhitelist(tt.patterns)
w := allowlist.New(tt.patterns)
if got := w.Count(); got != tt.want {
t.Errorf("Count() = %v, want %v", got, tt.want)
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/dustin/go-humanize"
"sneak.berlin/go/pixa/internal/allowlist"
"sneak.berlin/go/pixa/internal/imageprocessor"
)
@@ -20,7 +21,7 @@ type Service struct {
fetcher Fetcher
processor *imageprocessor.ImageProcessor
signer *Signer
whitelist *HostWhitelist
allowlist *allowlist.HostAllowList
log *slog.Logger
allowHTTP bool
maxResponseSize int64
@@ -85,7 +86,7 @@ func NewService(cfg *ServiceConfig) (*Service, error) {
fetcher: fetcher,
processor: imageprocessor.New(imageprocessor.Params{MaxInputBytes: maxResponseSize}),
signer: signer,
whitelist: NewHostWhitelist(cfg.Whitelist),
allowlist: allowlist.New(cfg.Whitelist),
log: log,
allowHTTP: allowHTTP,
maxResponseSize: maxResponseSize,
@@ -381,7 +382,7 @@ func (s *Service) Stats(ctx context.Context) (*CacheStats, error) {
// ValidateRequest validates the request signature if required.
func (s *Service) ValidateRequest(req *ImageRequest) error {
// Check if host is whitelisted (no signature required)
// Check if host is allowed (no signature required)
sourceURL := req.SourceURL()
parsedURL, err := url.Parse(sourceURL)
@@ -389,11 +390,11 @@ func (s *Service) ValidateRequest(req *ImageRequest) error {
return fmt.Errorf("invalid source URL: %w", err)
}
if s.whitelist.IsWhitelisted(parsedURL) {
if s.allowlist.IsAllowed(parsedURL) {
return nil
}
// Signature required for non-whitelisted hosts
// Signature required for non-allowed hosts
return s.signer.Verify(req)
}