diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index d80ee34..f47d165 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -13,6 +13,7 @@ import ( "sneak.berlin/go/pixa/internal/database" "sneak.berlin/go/pixa/internal/encurl" "sneak.berlin/go/pixa/internal/healthcheck" + "sneak.berlin/go/pixa/internal/httpfetcher" "sneak.berlin/go/pixa/internal/imgcache" "sneak.berlin/go/pixa/internal/logger" "sneak.berlin/go/pixa/internal/session" @@ -72,7 +73,7 @@ func (s *Handlers) initImageService() error { s.imgCache = cache // Create the fetcher config - fetcherCfg := imgcache.DefaultFetcherConfig() + fetcherCfg := httpfetcher.DefaultConfig() fetcherCfg.AllowHTTP = s.config.AllowHTTP if s.config.UpstreamConnectionsPerHost > 0 { fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go index 63861de..4994646 100644 --- a/internal/handlers/handlers_test.go +++ b/internal/handlers/handlers_test.go @@ -18,6 +18,7 @@ import ( "github.com/go-chi/chi/v5" "sneak.berlin/go/pixa/internal/database" + "sneak.berlin/go/pixa/internal/httpfetcher" "sneak.berlin/go/pixa/internal/imgcache" ) @@ -116,16 +117,16 @@ func newMockFetcher(fs fs.FS) *mockFetcher { return &mockFetcher{fs: fs} } -func (f *mockFetcher) Fetch(ctx context.Context, url string) (*imgcache.FetchResult, error) { +func (f *mockFetcher) Fetch(ctx context.Context, url string) (*httpfetcher.FetchResult, error) { // Remove https:// prefix path := url[8:] // Remove "https://" data, err := fs.ReadFile(f.fs, path) if err != nil { - return nil, imgcache.ErrUpstreamError + return nil, httpfetcher.ErrUpstreamError } - return &imgcache.FetchResult{ + return &httpfetcher.FetchResult{ Content: io.NopCloser(bytes.NewReader(data)), ContentLength: int64(len(data)), ContentType: "image/jpeg", diff --git a/internal/handlers/image.go b/internal/handlers/image.go index d82c34f..d7248bf 100644 --- a/internal/handlers/image.go +++ b/internal/handlers/image.go @@ -8,6 +8,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "sneak.berlin/go/pixa/internal/httpfetcher" "sneak.berlin/go/pixa/internal/imgcache" ) @@ -97,13 +98,13 @@ func (s *Handlers) HandleImage() http.HandlerFunc { ) // Check for specific error types - if errors.Is(err, imgcache.ErrSSRFBlocked) { + if errors.Is(err, httpfetcher.ErrSSRFBlocked) { s.respondError(w, "forbidden", http.StatusForbidden) return } - if errors.Is(err, imgcache.ErrUpstreamError) { + if errors.Is(err, httpfetcher.ErrUpstreamError) { s.respondError(w, "upstream error", http.StatusBadGateway) return diff --git a/internal/handlers/imageenc.go b/internal/handlers/imageenc.go index 60294a0..3a95cab 100644 --- a/internal/handlers/imageenc.go +++ b/internal/handlers/imageenc.go @@ -11,6 +11,7 @@ import ( "github.com/go-chi/chi/v5" "sneak.berlin/go/pixa/internal/encurl" + "sneak.berlin/go/pixa/internal/httpfetcher" "sneak.berlin/go/pixa/internal/imgcache" ) @@ -100,11 +101,11 @@ func (s *Handlers) HandleImageEnc() http.HandlerFunc { // handleImageError converts image service errors to HTTP responses. func (s *Handlers) handleImageError(w http.ResponseWriter, err error) { switch { - case errors.Is(err, imgcache.ErrSSRFBlocked): + case errors.Is(err, httpfetcher.ErrSSRFBlocked): s.respondError(w, "forbidden", http.StatusForbidden) - case errors.Is(err, imgcache.ErrUpstreamError): + case errors.Is(err, httpfetcher.ErrUpstreamError): s.respondError(w, "upstream error", http.StatusBadGateway) - case errors.Is(err, imgcache.ErrUpstreamTimeout): + case errors.Is(err, httpfetcher.ErrUpstreamTimeout): s.respondError(w, "upstream timeout", http.StatusGatewayTimeout) default: s.log.Error("image request failed", "error", err) diff --git a/internal/imgcache/fetcher.go b/internal/httpfetcher/httpfetcher.go similarity index 83% rename from internal/imgcache/fetcher.go rename to internal/httpfetcher/httpfetcher.go index dc70b32..d199e8d 100644 --- a/internal/imgcache/fetcher.go +++ b/internal/httpfetcher/httpfetcher.go @@ -1,4 +1,6 @@ -package imgcache +// Package httpfetcher fetches content from upstream HTTP origins with SSRF +// protection, per-host connection limits, and content-type validation. +package httpfetcher import ( "context" @@ -37,25 +39,55 @@ var ( ErrUpstreamTimeout = errors.New("upstream request timeout") ) -// FetcherConfig holds configuration for the upstream fetcher. -type FetcherConfig struct { - // Timeout for upstream requests +// Fetcher retrieves content from upstream origins. +type Fetcher interface { + // Fetch retrieves content from the given URL. + Fetch(ctx context.Context, url string) (*FetchResult, error) +} + +// FetchResult contains the result of fetching from upstream. +type FetchResult struct { + // Content is the raw image data. + Content io.ReadCloser + // ContentLength is the size in bytes (-1 if unknown). + ContentLength int64 + // ContentType is the MIME type from upstream. + ContentType string + // Headers contains all response headers from upstream. + Headers map[string][]string + // StatusCode is the HTTP status code from upstream. + StatusCode int + // FetchDurationMs is how long the fetch took in milliseconds. + FetchDurationMs int64 + // RemoteAddr is the IP:port of the upstream server. + RemoteAddr string + // HTTPVersion is the protocol version (e.g., "1.1", "2.0"). + HTTPVersion string + // TLSVersion is the TLS protocol version (e.g., "TLS 1.3"). + TLSVersion string + // TLSCipherSuite is the negotiated cipher suite name. + TLSCipherSuite string +} + +// Config holds configuration for the upstream fetcher. +type Config struct { + // Timeout for upstream requests. Timeout time.Duration - // MaxResponseSize is the maximum allowed response body size + // MaxResponseSize is the maximum allowed response body size. MaxResponseSize int64 - // UserAgent to send to upstream servers + // UserAgent to send to upstream servers. UserAgent string - // AllowedContentTypes is a whitelist of MIME types to accept + // AllowedContentTypes is an allow list of MIME types to accept. AllowedContentTypes []string - // AllowHTTP allows non-TLS connections (for testing only) + // AllowHTTP allows non-TLS connections (for testing only). AllowHTTP bool - // MaxConnectionsPerHost limits concurrent connections to each upstream host + // MaxConnectionsPerHost limits concurrent connections to each upstream host. MaxConnectionsPerHost int } -// DefaultFetcherConfig returns sensible defaults. -func DefaultFetcherConfig() *FetcherConfig { - return &FetcherConfig{ +// DefaultConfig returns a Config with sensible defaults. +func DefaultConfig() *Config { + return &Config{ Timeout: DefaultFetchTimeout, MaxResponseSize: DefaultMaxResponseSize, UserAgent: "pixa/1.0", @@ -72,18 +104,18 @@ func DefaultFetcherConfig() *FetcherConfig { } } -// HTTPFetcher implements the Fetcher interface with SSRF protection. +// HTTPFetcher implements Fetcher with SSRF protection and per-host connection limits. type HTTPFetcher struct { client *http.Client - config *FetcherConfig + config *Config hostSems map[string]chan struct{} // per-host semaphores hostSemMu sync.Mutex // protects hostSems map } -// NewHTTPFetcher creates a new fetcher with SSRF protection. -func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher { +// New creates a new HTTPFetcher with SSRF protection. +func New(config *Config) *HTTPFetcher { if config == nil { - config = DefaultFetcherConfig() + config = DefaultConfig() } // Create transport with SSRF-safe dialer @@ -250,7 +282,7 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro }, nil } -// isAllowedContentType checks if the content type is in the whitelist. +// isAllowedContentType checks if the content type is in the allow list. func (f *HTTPFetcher) isAllowedContentType(contentType string) bool { // Extract the MIME type without parameters mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0]) diff --git a/internal/httpfetcher/httpfetcher_test.go b/internal/httpfetcher/httpfetcher_test.go new file mode 100644 index 0000000..6a0cee2 --- /dev/null +++ b/internal/httpfetcher/httpfetcher_test.go @@ -0,0 +1,329 @@ +package httpfetcher + +import ( + "context" + "errors" + "io" + "net" + "testing" + "testing/fstest" +) + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Timeout != DefaultFetchTimeout { + t.Errorf("Timeout = %v, want %v", cfg.Timeout, DefaultFetchTimeout) + } + + if cfg.MaxResponseSize != DefaultMaxResponseSize { + t.Errorf("MaxResponseSize = %d, want %d", cfg.MaxResponseSize, DefaultMaxResponseSize) + } + + if cfg.MaxConnectionsPerHost != DefaultMaxConnectionsPerHost { + t.Errorf("MaxConnectionsPerHost = %d, want %d", + cfg.MaxConnectionsPerHost, DefaultMaxConnectionsPerHost) + } + + if cfg.AllowHTTP { + t.Error("AllowHTTP should default to false") + } + + if len(cfg.AllowedContentTypes) == 0 { + t.Error("AllowedContentTypes should not be empty") + } +} + +func TestNewWithNilConfigUsesDefaults(t *testing.T) { + f := New(nil) + + if f == nil { + t.Fatal("New(nil) returned nil") + } + + if f.config == nil { + t.Fatal("config should be populated from DefaultConfig") + } + + if f.config.Timeout != DefaultFetchTimeout { + t.Errorf("Timeout = %v, want %v", f.config.Timeout, DefaultFetchTimeout) + } +} + +func TestIsAllowedContentType(t *testing.T) { + f := New(DefaultConfig()) + + tests := []struct { + contentType string + want bool + }{ + {"image/jpeg", true}, + {"image/png", true}, + {"image/webp", true}, + {"image/jpeg; charset=utf-8", true}, + {"IMAGE/JPEG", true}, + {"text/html", false}, + {"application/octet-stream", false}, + {"", false}, + } + + for _, tc := range tests { + t.Run(tc.contentType, func(t *testing.T) { + got := f.isAllowedContentType(tc.contentType) + if got != tc.want { + t.Errorf("isAllowedContentType(%q) = %v, want %v", tc.contentType, got, tc.want) + } + }) + } +} + +func TestExtractHost(t *testing.T) { + tests := []struct { + url string + want string + }{ + {"https://example.com/path", "example.com"}, + {"http://example.com:8080/path", "example.com:8080"}, + {"https://example.com", "example.com"}, + {"https://example.com?q=1", "example.com"}, + {"example.com/path", "example.com"}, + {"", ""}, + } + + for _, tc := range tests { + t.Run(tc.url, func(t *testing.T) { + got := extractHost(tc.url) + if got != tc.want { + t.Errorf("extractHost(%q) = %q, want %q", tc.url, got, tc.want) + } + }) + } +} + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + host string + want bool + }{ + {"localhost", true}, + {"LOCALHOST", true}, + {"127.0.0.1", true}, + {"::1", true}, + {"[::1]", true}, + {"foo.localhost", true}, + {"foo.local", true}, + {"example.com", false}, + {"127.0.0.2", false}, // Handled by isPrivateIP, not isLocalhost string match + } + + for _, tc := range tests { + t.Run(tc.host, func(t *testing.T) { + got := isLocalhost(tc.host) + if got != tc.want { + t.Errorf("isLocalhost(%q) = %v, want %v", tc.host, got, tc.want) + } + }) + } +} + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + ip string + want bool + }{ + {"127.0.0.1", true}, // loopback + {"10.0.0.1", true}, // private + {"192.168.1.1", true}, // private + {"172.16.0.1", true}, // private + {"169.254.1.1", true}, // link-local + {"0.0.0.0", true}, // unspecified + {"224.0.0.1", true}, // multicast + {"::1", true}, // IPv6 loopback + {"fe80::1", true}, // IPv6 link-local + {"8.8.8.8", false}, // public + {"2001:4860:4860::8888", false}, // public IPv6 + } + + for _, tc := range tests { + t.Run(tc.ip, func(t *testing.T) { + ip := net.ParseIP(tc.ip) + if ip == nil { + t.Fatalf("failed to parse IP %q", tc.ip) + } + + got := isPrivateIP(ip) + if got != tc.want { + t.Errorf("isPrivateIP(%q) = %v, want %v", tc.ip, got, tc.want) + } + }) + } + + if !isPrivateIP(nil) { + t.Error("isPrivateIP(nil) should return true") + } +} + +func TestValidateURL_RejectsNonHTTPS(t *testing.T) { + err := validateURL("http://example.com/path", false) + if !errors.Is(err, ErrUnsupportedScheme) { + t.Errorf("validateURL http = %v, want ErrUnsupportedScheme", err) + } +} + +func TestValidateURL_AllowsHTTPWhenConfigured(t *testing.T) { + // Use a host that won't resolve (explicit .invalid TLD) so we don't hit DNS. + err := validateURL("http://nonexistent.invalid/path", true) + // We expect a host resolution error, not ErrUnsupportedScheme. + if errors.Is(err, ErrUnsupportedScheme) { + t.Error("validateURL with AllowHTTP should not return ErrUnsupportedScheme") + } +} + +func TestValidateURL_RejectsLocalhost(t *testing.T) { + err := validateURL("https://localhost/path", false) + if !errors.Is(err, ErrSSRFBlocked) { + t.Errorf("validateURL localhost = %v, want ErrSSRFBlocked", err) + } +} + +func TestValidateURL_EmptyHost(t *testing.T) { + err := validateURL("https:///path", false) + if !errors.Is(err, ErrInvalidHost) { + t.Errorf("validateURL empty host = %v, want ErrInvalidHost", err) + } +} + +func TestMockFetcher_FetchesFile(t *testing.T) { + mockFS := fstest.MapFS{ + "example.com/images/photo.jpg": &fstest.MapFile{Data: []byte("fake-jpeg-data")}, + } + + m := NewMock(mockFS) + + result, err := m.Fetch(context.Background(), "https://example.com/images/photo.jpg") + if err != nil { + t.Fatalf("Fetch() error = %v", err) + } + defer func() { _ = result.Content.Close() }() + + if result.ContentType != "image/jpeg" { + t.Errorf("ContentType = %q, want image/jpeg", result.ContentType) + } + + data, err := io.ReadAll(result.Content) + if err != nil { + t.Fatalf("read content: %v", err) + } + + if string(data) != "fake-jpeg-data" { + t.Errorf("Content = %q, want %q", string(data), "fake-jpeg-data") + } + + if result.ContentLength != int64(len("fake-jpeg-data")) { + t.Errorf("ContentLength = %d, want %d", result.ContentLength, len("fake-jpeg-data")) + } +} + +func TestMockFetcher_MissingFileReturnsUpstreamError(t *testing.T) { + mockFS := fstest.MapFS{} + m := NewMock(mockFS) + + _, err := m.Fetch(context.Background(), "https://example.com/missing.jpg") + if !errors.Is(err, ErrUpstreamError) { + t.Errorf("Fetch() error = %v, want ErrUpstreamError", err) + } +} + +func TestMockFetcher_RespectsContextCancellation(t *testing.T) { + mockFS := fstest.MapFS{ + "example.com/photo.jpg": &fstest.MapFile{Data: []byte("data")}, + } + m := NewMock(mockFS) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := m.Fetch(ctx, "https://example.com/photo.jpg") + if !errors.Is(err, context.Canceled) { + t.Errorf("Fetch() error = %v, want context.Canceled", err) + } +} + +func TestDetectContentTypeFromPath(t *testing.T) { + tests := []struct { + path string + want string + }{ + {"foo/bar.jpg", "image/jpeg"}, + {"foo/bar.JPG", "image/jpeg"}, + {"foo/bar.jpeg", "image/jpeg"}, + {"foo/bar.png", "image/png"}, + {"foo/bar.gif", "image/gif"}, + {"foo/bar.webp", "image/webp"}, + {"foo/bar.avif", "image/avif"}, + {"foo/bar.svg", "image/svg+xml"}, + {"foo/bar.bin", "application/octet-stream"}, + {"foo/bar", "application/octet-stream"}, + } + + for _, tc := range tests { + t.Run(tc.path, func(t *testing.T) { + got := detectContentTypeFromPath(tc.path) + if got != tc.want { + t.Errorf("detectContentTypeFromPath(%q) = %q, want %q", tc.path, got, tc.want) + } + }) + } +} + +func TestLimitedReader_EnforcesLimit(t *testing.T) { + src := make([]byte, 100) + r := &limitedReader{ + reader: &byteReader{data: src}, + remaining: 50, + } + + buf := make([]byte, 100) + + n, err := r.Read(buf) + if err != nil { + t.Fatalf("first Read error = %v", err) + } + + if n > 50 { + t.Errorf("read %d bytes, should be capped at 50", n) + } + + // Drain until limit is exhausted. + total := n + for total < 50 { + nn, err := r.Read(buf) + total += nn + if err != nil { + t.Fatalf("during drain: %v", err) + } + } + + // Now the limit is exhausted — next read should error. + _, err = r.Read(buf) + if !errors.Is(err, ErrResponseTooLarge) { + t.Errorf("exhausted Read error = %v, want ErrResponseTooLarge", err) + } +} + +// byteReader is a minimal io.Reader over a byte slice for testing. +type byteReader struct { + data []byte + pos int +} + +func (r *byteReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + + n := copy(p, r.data[r.pos:]) + r.pos += n + + return n, nil +} diff --git a/internal/imgcache/mock_fetcher.go b/internal/httpfetcher/mock.go similarity index 90% rename from internal/imgcache/mock_fetcher.go rename to internal/httpfetcher/mock.go index a309c4a..c54b944 100644 --- a/internal/imgcache/mock_fetcher.go +++ b/internal/httpfetcher/mock.go @@ -1,4 +1,4 @@ -package imgcache +package httpfetcher import ( "context" @@ -10,15 +10,15 @@ import ( "strings" ) -// MockFetcher implements the Fetcher interface using an embedded filesystem. +// MockFetcher implements Fetcher using an embedded filesystem. // Files are organized as: hostname/path/to/file.ext -// URLs like https://example.com/images/photo.jpg map to example.com/images/photo.jpg +// URLs like https://example.com/images/photo.jpg map to example.com/images/photo.jpg. type MockFetcher struct { fs fs.FS } -// NewMockFetcher creates a new mock fetcher backed by the given filesystem. -func NewMockFetcher(fsys fs.FS) *MockFetcher { +// NewMock creates a new mock fetcher backed by the given filesystem. +func NewMock(fsys fs.FS) *MockFetcher { return &MockFetcher{fs: fsys} } diff --git a/internal/imgcache/cache.go b/internal/imgcache/cache.go index 7bd578b..5e15873 100644 --- a/internal/imgcache/cache.go +++ b/internal/imgcache/cache.go @@ -9,6 +9,8 @@ import ( "io" "path/filepath" "time" + + "sneak.berlin/go/pixa/internal/httpfetcher" ) // Cache errors. @@ -111,7 +113,7 @@ func (c *Cache) StoreSource( ctx context.Context, req *ImageRequest, content io.Reader, - result *FetchResult, + result *httpfetcher.FetchResult, ) (ContentHash, error) { // Store content contentHash, size, err := c.srcContent.Store(content) diff --git a/internal/imgcache/cache_test.go b/internal/imgcache/cache_test.go index dda75ef..8645b82 100644 --- a/internal/imgcache/cache_test.go +++ b/internal/imgcache/cache_test.go @@ -9,6 +9,7 @@ import ( "time" _ "modernc.org/sqlite" + "sneak.berlin/go/pixa/internal/httpfetcher" ) func setupTestDB(t *testing.T) *sql.DB { @@ -152,7 +153,7 @@ func TestCache_StoreAndLookup(t *testing.T) { // Store source content sourceContent := []byte("fake jpeg data") - fetchResult := &FetchResult{ + fetchResult := &httpfetcher.FetchResult{ ContentType: "image/jpeg", Headers: map[string][]string{"Content-Type": {"image/jpeg"}}, } diff --git a/internal/imgcache/imgcache.go b/internal/imgcache/imgcache.go index 605db01..105afb4 100644 --- a/internal/imgcache/imgcache.go +++ b/internal/imgcache/imgcache.go @@ -169,36 +169,6 @@ type Whitelist interface { IsWhitelisted(u *url.URL) bool } -// Fetcher fetches images from upstream origins -type Fetcher interface { - // Fetch retrieves an image from the origin - Fetch(ctx context.Context, url string) (*FetchResult, error) -} - -// FetchResult contains the result of fetching from upstream -type FetchResult struct { - // Content is the raw image data - Content io.ReadCloser - // ContentLength is the size in bytes (-1 if unknown) - ContentLength int64 - // ContentType is the MIME type from upstream - ContentType string - // Headers contains all response headers from upstream - Headers map[string][]string - // StatusCode is the HTTP status code from upstream - StatusCode int - // FetchDurationMs is how long the fetch took in milliseconds - FetchDurationMs int64 - // RemoteAddr is the IP:port of the upstream server - RemoteAddr string - // HTTPVersion is the protocol version (e.g., "1.1", "2.0") - HTTPVersion string - // TLSVersion is the TLS protocol version (e.g., "TLS 1.3") - TLSVersion string - // TLSCipherSuite is the negotiated cipher suite name - TLSCipherSuite string -} - // Storage handles persistent storage of cached content type Storage interface { // Store saves content and returns its hash diff --git a/internal/imgcache/service.go b/internal/imgcache/service.go index 849a90d..4229911 100644 --- a/internal/imgcache/service.go +++ b/internal/imgcache/service.go @@ -12,6 +12,7 @@ import ( "github.com/dustin/go-humanize" "sneak.berlin/go/pixa/internal/allowlist" + "sneak.berlin/go/pixa/internal/httpfetcher" "sneak.berlin/go/pixa/internal/imageprocessor" "sneak.berlin/go/pixa/internal/magic" ) @@ -19,7 +20,7 @@ import ( // Service implements the ImageCache interface, orchestrating cache, fetcher, and processor. type Service struct { cache *Cache - fetcher Fetcher + fetcher httpfetcher.Fetcher processor *imageprocessor.ImageProcessor signer *Signer allowlist *allowlist.HostAllowList @@ -33,9 +34,9 @@ type ServiceConfig struct { // Cache is the cache instance Cache *Cache // FetcherConfig configures the upstream fetcher (ignored if Fetcher is set) - FetcherConfig *FetcherConfig + FetcherConfig *httpfetcher.Config // Fetcher is an optional custom fetcher (for testing) - Fetcher Fetcher + Fetcher httpfetcher.Fetcher // SigningKey is the HMAC signing key (empty disables signing) SigningKey string // Whitelist is the list of hosts that don't require signatures @@ -57,15 +58,15 @@ func NewService(cfg *ServiceConfig) (*Service, error) { // Resolve fetcher config for defaults fetcherCfg := cfg.FetcherConfig if fetcherCfg == nil { - fetcherCfg = DefaultFetcherConfig() + fetcherCfg = httpfetcher.DefaultConfig() } // Use custom fetcher if provided, otherwise create HTTP fetcher - var fetcher Fetcher + var fetcher httpfetcher.Fetcher if cfg.Fetcher != nil { fetcher = cfg.Fetcher } else { - fetcher = NewHTTPFetcher(fetcherCfg) + fetcher = httpfetcher.New(fetcherCfg) } signer := NewSigner(cfg.SigningKey) @@ -113,7 +114,7 @@ func (s *Service) Get(ctx context.Context, req *ImageRequest) (*ImageResponse, e "path", req.SourcePath, ) - return nil, fmt.Errorf("%w: %w", ErrUpstreamError, ErrNegativeCached) + return nil, fmt.Errorf("%w: %w", httpfetcher.ErrUpstreamError, ErrNegativeCached) } // Check variant cache first (disk only, no DB) @@ -418,13 +419,13 @@ const ( // isNegativeCacheable returns true if the error should be cached. func isNegativeCacheable(err error) bool { - return errors.Is(err, ErrUpstreamError) + return errors.Is(err, httpfetcher.ErrUpstreamError) } // extractStatusCode extracts HTTP status code from error message. func extractStatusCode(err error) int { // Default to 502 Bad Gateway for upstream errors - if errors.Is(err, ErrUpstreamError) { + if errors.Is(err, httpfetcher.ErrUpstreamError) { return httpStatusBadGateway } diff --git a/internal/imgcache/testutil_test.go b/internal/imgcache/testutil_test.go index 3b85ece..2e23df5 100644 --- a/internal/imgcache/testutil_test.go +++ b/internal/imgcache/testutil_test.go @@ -15,6 +15,7 @@ import ( "time" "sneak.berlin/go/pixa/internal/database" + "sneak.berlin/go/pixa/internal/httpfetcher" ) // TestFixtures contains paths to test files in the mock filesystem. @@ -172,7 +173,7 @@ func SetupTestService(t *testing.T, opts ...TestServiceOption) (*Service, *TestF svc, err := NewService(&ServiceConfig{ Cache: cache, - Fetcher: NewMockFetcher(mockFS), + Fetcher: httpfetcher.NewMock(mockFS), SigningKey: cfg.signingKey, Whitelist: cfg.whitelist, })