refactor: extract httpfetcher package from imgcache
All checks were successful
check / check (push) Successful in 57s
All checks were successful
check / check (push) Successful in 57s
Move HTTPFetcher, Config (was FetcherConfig), SSRF-safe dialer, rate limiting, content-type validation, and related error vars from internal/imgcache/fetcher.go into new internal/httpfetcher/ package. The Fetcher interface and FetchResult type also move to httpfetcher to avoid circular imports (imgcache imports httpfetcher, not the other way around). Renames to avoid stuttering: NewHTTPFetcher -> httpfetcher.New; FetcherConfig -> httpfetcher.Config; NewMockFetcher -> httpfetcher.NewMock. The ServiceConfig.FetcherConfig field is retained (it describes what kind of config it holds, not a stutter). Pure refactor — no behavior changes. Unit tests for the httpfetcher package are included. refs #39
This commit is contained in:
477
internal/httpfetcher/httpfetcher.go
Normal file
477
internal/httpfetcher/httpfetcher.go
Normal file
@@ -0,0 +1,477 @@
|
||||
// Package httpfetcher fetches content from upstream HTTP origins with SSRF
|
||||
// protection, per-host connection limits, and content-type validation.
|
||||
package httpfetcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
neturl "net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Fetcher configuration constants.
const (
	// DefaultFetchTimeout bounds an entire upstream request.
	DefaultFetchTimeout = 30 * time.Second
	// DefaultMaxResponseSize caps how many response-body bytes may be read.
	DefaultMaxResponseSize = 50 << 20 // 50MB
	// DefaultTLSTimeout bounds the TLS handshake.
	DefaultTLSTimeout = 10 * time.Second
	// DefaultMaxIdleConns caps idle connections kept for reuse.
	DefaultMaxIdleConns = 100
	// DefaultIdleConnTimeout is how long an idle connection is kept.
	DefaultIdleConnTimeout = 90 * time.Second
	// DefaultMaxRedirects is the per-request redirect-hop limit.
	DefaultMaxRedirects = 10
	// DefaultMaxConnectionsPerHost limits concurrent fetches to one upstream host.
	DefaultMaxConnectionsPerHost = 20
)
|
||||
|
||||
// Fetcher errors. Several are returned wrapped with extra context, so
// callers should compare with errors.Is.
var (
	ErrSSRFBlocked        = errors.New("request blocked: private or internal IP")
	ErrInvalidHost        = errors.New("invalid or unresolvable host")
	ErrUnsupportedScheme  = errors.New("only HTTPS is supported")
	ErrResponseTooLarge   = errors.New("response exceeds maximum size")
	ErrInvalidContentType = errors.New("invalid or unsupported content type")
	ErrUpstreamError      = errors.New("upstream server error")
	ErrUpstreamTimeout    = errors.New("upstream request timeout")
)
|
||||
|
||||
// Fetcher retrieves content from upstream origins.
type Fetcher interface {
	// Fetch retrieves content from the given URL. On success the caller
	// must close the returned FetchResult.Content; for HTTPFetcher this
	// also releases the per-host connection slot.
	Fetch(ctx context.Context, url string) (*FetchResult, error)
}
|
||||
|
||||
// FetchResult contains the result of fetching from upstream.
//
// Content must be closed by the caller.
type FetchResult struct {
	// Content is the raw image data.
	Content io.ReadCloser
	// ContentLength is the size in bytes (-1 if unknown).
	ContentLength int64
	// ContentType is the MIME type from upstream.
	ContentType string
	// Headers contains all response headers from upstream.
	Headers map[string][]string
	// StatusCode is the HTTP status code from upstream.
	StatusCode int
	// FetchDurationMs is how long the fetch took in milliseconds.
	FetchDurationMs int64
	// RemoteAddr is the IP:port of the upstream server.
	RemoteAddr string
	// HTTPVersion is the protocol version (e.g., "1.1", "2.0").
	HTTPVersion string
	// TLSVersion is the TLS protocol version (e.g., "TLS 1.3").
	TLSVersion string
	// TLSCipherSuite is the negotiated cipher suite name.
	TLSCipherSuite string
}
|
||||
|
||||
// Config holds configuration for the upstream fetcher.
type Config struct {
	// Timeout for upstream requests. It is passed straight to
	// http.Client.Timeout, so zero means no client-side timeout.
	Timeout time.Duration
	// MaxResponseSize is the maximum allowed response body size.
	MaxResponseSize int64
	// UserAgent to send to upstream servers.
	UserAgent string
	// AllowedContentTypes is an allow list of MIME types to accept.
	// Matching is case-insensitive and ignores parameters ("; charset=...").
	AllowedContentTypes []string
	// AllowHTTP allows non-TLS connections (for testing only).
	AllowHTTP bool
	// MaxConnectionsPerHost limits concurrent connections to each upstream host.
	MaxConnectionsPerHost int
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Timeout: DefaultFetchTimeout,
|
||||
MaxResponseSize: DefaultMaxResponseSize,
|
||||
UserAgent: "pixa/1.0",
|
||||
AllowedContentTypes: []string{
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/avif",
|
||||
"image/svg+xml",
|
||||
},
|
||||
AllowHTTP: false,
|
||||
MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPFetcher implements Fetcher with SSRF protection and per-host connection limits.
//
// The zero value is not usable; construct with New.
type HTTPFetcher struct {
	client    *http.Client
	config    *Config
	hostSems  map[string]chan struct{} // per-host semaphores (capacity = MaxConnectionsPerHost)
	hostSemMu sync.Mutex               // protects hostSems map
}
|
||||
|
||||
// New creates a new HTTPFetcher with SSRF protection.
|
||||
func New(config *Config) *HTTPFetcher {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
|
||||
// Create transport with SSRF-safe dialer
|
||||
transport := &http.Transport{
|
||||
DialContext: ssrfSafeDialer,
|
||||
TLSHandshakeTimeout: DefaultTLSTimeout,
|
||||
MaxIdleConns: DefaultMaxIdleConns,
|
||||
IdleConnTimeout: DefaultIdleConnTimeout,
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: config.Timeout,
|
||||
// Don't follow redirects automatically - we need to validate each hop
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= DefaultMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
// Validate the redirect target
|
||||
if err := validateURL(req.URL.String(), config.AllowHTTP); err != nil {
|
||||
return fmt.Errorf("redirect blocked: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return &HTTPFetcher{
|
||||
client: client,
|
||||
config: config,
|
||||
hostSems: make(map[string]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// getHostSemaphore returns the semaphore for a host, creating it if necessary.
|
||||
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
|
||||
f.hostSemMu.Lock()
|
||||
defer f.hostSemMu.Unlock()
|
||||
|
||||
sem, ok := f.hostSems[host]
|
||||
if !ok {
|
||||
sem = make(chan struct{}, f.config.MaxConnectionsPerHost)
|
||||
f.hostSems[host] = sem
|
||||
}
|
||||
|
||||
return sem
|
||||
}
|
||||
|
||||
// Fetch retrieves content from the given URL with SSRF protection.
//
// Sequence: validate the URL, acquire the per-host connection slot, issue a
// GET, validate the status code and content type, then hand back a
// size-limited body. On success the caller MUST close FetchResult.Content:
// closing it closes the HTTP body and releases the semaphore slot.
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
	// Validate URL before making request
	if err := validateURL(url, f.config.AllowHTTP); err != nil {
		return nil, err
	}

	// Extract host for rate limiting
	host := extractHost(url)

	// Acquire semaphore slot for this host
	sem := f.getHostSemaphore(host)
	select {
	case sem <- struct{}{}:
		// Acquired slot
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// If we fail before returning a result, release the slot
	success := false
	defer func() {
		if !success {
			<-sem
		}
	}()

	parsedURL, err := neturl.Parse(url)
	if err != nil {
		return nil, fmt.Errorf("failed to parse URL: %w", err)
	}

	// Build the request by hand (rather than http.NewRequest) so the
	// already-parsed URL is reused.
	req := &http.Request{
		Method: http.MethodGet,
		URL:    parsedURL,
		Header: make(http.Header),
	}
	req = req.WithContext(ctx)

	req.Header.Set("User-Agent", f.config.UserAgent)
	req.Header.Set("Accept", strings.Join(f.config.AllowedContentTypes, ", "))

	// Use httptrace to capture connection details
	var remoteAddr string

	trace := &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			if info.Conn != nil {
				remoteAddr = info.Conn.RemoteAddr().String()
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	startTime := time.Now()

	//nolint:gosec // G704: URL validated by validateURL() above
	resp, err := f.client.Do(req)

	fetchDuration := time.Since(startTime)

	if err != nil {
		// NOTE(review): only deadline expiry maps to ErrUpstreamTimeout;
		// a cancelled context surfaces as the wrapped error below.
		if errors.Is(err, context.DeadlineExceeded) {
			return nil, ErrUpstreamTimeout
		}

		return nil, fmt.Errorf("upstream request failed: %w", err)
	}

	// Extract HTTP version (strip "HTTP/" prefix)
	httpVersion := strings.TrimPrefix(resp.Proto, "HTTP/")

	// Extract TLS info if available
	var tlsVersion, tlsCipherSuite string

	if resp.TLS != nil {
		tlsVersion = tls.VersionName(resp.TLS.Version)
		tlsCipherSuite = tls.CipherSuiteName(resp.TLS.CipherSuite)
	}

	// Check status code
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		_ = resp.Body.Close()

		return nil, fmt.Errorf("%w: status %d", ErrUpstreamError, resp.StatusCode)
	}

	// Validate content type
	contentType := resp.Header.Get("Content-Type")
	if !f.isAllowedContentType(contentType) {
		_ = resp.Body.Close()

		return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
	}

	// Wrap body with size limiter and semaphore releaser
	limitedBody := &limitedReader{
		reader:    resp.Body,
		remaining: f.config.MaxResponseSize,
	}

	// Mark success so defer doesn't release the semaphore; from here on the
	// slot is owned by the returned Content and freed on Close.
	success = true

	return &FetchResult{
		Content:         &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
		ContentLength:   resp.ContentLength,
		ContentType:     contentType,
		Headers:         resp.Header,
		StatusCode:      resp.StatusCode,
		FetchDurationMs: fetchDuration.Milliseconds(),
		RemoteAddr:      remoteAddr,
		HTTPVersion:     httpVersion,
		TLSVersion:      tlsVersion,
		TLSCipherSuite:  tlsCipherSuite,
	}, nil
}
|
||||
|
||||
// isAllowedContentType checks if the content type is in the allow list.
|
||||
func (f *HTTPFetcher) isAllowedContentType(contentType string) bool {
|
||||
// Extract the MIME type without parameters
|
||||
mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0])
|
||||
|
||||
for _, allowed := range f.config.AllowedContentTypes {
|
||||
if strings.EqualFold(mediaType, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateURL checks if a URL is safe to fetch (not internal/private).
|
||||
func validateURL(rawURL string, allowHTTP bool) error {
|
||||
if !allowHTTP && !strings.HasPrefix(rawURL, "https://") {
|
||||
return ErrUnsupportedScheme
|
||||
}
|
||||
|
||||
// Parse to extract host
|
||||
host := extractHost(rawURL)
|
||||
if host == "" {
|
||||
return ErrInvalidHost
|
||||
}
|
||||
|
||||
// Remove port if present
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
// Block obvious localhost patterns
|
||||
if isLocalhost(host) {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
// Resolve the host to check IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrInvalidHost, host)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if isPrivateIP(ip) {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractHost extracts the authority (host, optionally host:port) from a URL
// string without a full parse: it strips the scheme, then truncates at the
// first path, query, or fragment delimiter.
//
// Fix: the original truncated only at "/" and "?", so a fragment-only URL
// such as "https://example.com#frag" leaked the fragment into the host.
func extractHost(rawURL string) string {
	host := rawURL
	if idx := strings.Index(host, "://"); idx != -1 {
		host = host[idx+3:]
	}

	// Truncate at whichever of path/query/fragment comes first.
	if idx := strings.IndexAny(host, "/?#"); idx != -1 {
		host = host[:idx]
	}

	return host
}
|
||||
|
||||
// isLocalhost reports whether host is a well-known loopback name or literal.
// Matching is case-insensitive; ".localhost" and ".local" subdomains count
// as local too.
func isLocalhost(host string) bool {
	h := strings.ToLower(host)

	switch h {
	case "localhost", "127.0.0.1", "::1", "[::1]":
		return true
	}

	return strings.HasSuffix(h, ".localhost") || strings.HasSuffix(h, ".local")
}
|
||||
|
||||
// isPrivateIP checks if an IP is private, loopback, or otherwise internal.
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for loopback
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for private ranges
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for link-local
|
||||
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unspecified (0.0.0.0 or ::)
|
||||
if ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for multicast
|
||||
if ip.IsMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Additional checks for IPv4
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// 169.254.0.0/16 - Link local
|
||||
if ip4[0] == 169 && ip4[1] == 254 {
|
||||
return true
|
||||
}
|
||||
// 0.0.0.0/8 - Current network
|
||||
if ip4[0] == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ssrfSafeDialer is a custom dialer that validates IP addresses before connecting.
|
||||
func ssrfSafeDialer(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Resolve the address
|
||||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host)
|
||||
}
|
||||
|
||||
// Check all resolved IPs
|
||||
for _, ip := range ips {
|
||||
if isPrivateIP(ip) {
|
||||
return nil, ErrSSRFBlocked
|
||||
}
|
||||
}
|
||||
|
||||
// Connect using the first valid IP
|
||||
var dialer net.Dialer
|
||||
for _, ip := range ips {
|
||||
addr := net.JoinHostPort(ip.String(), port)
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to connect to %s", host)
|
||||
}
|
||||
|
||||
// limitedReader wraps a reader and limits the number of bytes read.
|
||||
type limitedReader struct {
|
||||
reader io.Reader
|
||||
remaining int64
|
||||
}
|
||||
|
||||
func (r *limitedReader) Read(p []byte) (int, error) {
|
||||
if r.remaining <= 0 {
|
||||
return 0, ErrResponseTooLarge
|
||||
}
|
||||
|
||||
if int64(len(p)) > r.remaining {
|
||||
p = p[:r.remaining]
|
||||
}
|
||||
|
||||
n, err := r.reader.Read(p)
|
||||
r.remaining -= int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// semaphoreReleasingReadCloser releases a semaphore slot when closed.
// It couples the size-limited response body with the per-host connection
// slot acquired in Fetch: closing the body closes the underlying HTTP
// response and frees the slot.
//
// NOTE(review): Close is not idempotent — a second Close would receive from
// the semaphore again, potentially freeing a slot owned by another in-flight
// request (or blocking if none is held). Callers must close exactly once.
type semaphoreReleasingReadCloser struct {
	*limitedReader
	closer io.Closer     // the underlying HTTP response body
	sem    chan struct{} // per-host connection semaphore to release
}

// Close closes the underlying body, then frees the per-host slot.
func (r *semaphoreReleasingReadCloser) Close() error {
	err := r.closer.Close()
	<-r.sem // Release semaphore slot

	return err
}
|
||||
329
internal/httpfetcher/httpfetcher_test.go
Normal file
329
internal/httpfetcher/httpfetcher_test.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package httpfetcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
)
|
||||
|
||||
// TestDefaultConfig verifies the documented defaults of DefaultConfig.
func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()

	if cfg.Timeout != DefaultFetchTimeout {
		t.Errorf("Timeout = %v, want %v", cfg.Timeout, DefaultFetchTimeout)
	}

	if cfg.MaxResponseSize != DefaultMaxResponseSize {
		t.Errorf("MaxResponseSize = %d, want %d", cfg.MaxResponseSize, DefaultMaxResponseSize)
	}

	if cfg.MaxConnectionsPerHost != DefaultMaxConnectionsPerHost {
		t.Errorf("MaxConnectionsPerHost = %d, want %d",
			cfg.MaxConnectionsPerHost, DefaultMaxConnectionsPerHost)
	}

	if cfg.AllowHTTP {
		t.Error("AllowHTTP should default to false")
	}

	if len(cfg.AllowedContentTypes) == 0 {
		t.Error("AllowedContentTypes should not be empty")
	}
}
|
||||
|
||||
// TestNewWithNilConfigUsesDefaults checks that New(nil) falls back to
// DefaultConfig rather than panicking or leaving config unset.
func TestNewWithNilConfigUsesDefaults(t *testing.T) {
	f := New(nil)

	if f == nil {
		t.Fatal("New(nil) returned nil")
	}

	if f.config == nil {
		t.Fatal("config should be populated from DefaultConfig")
	}

	if f.config.Timeout != DefaultFetchTimeout {
		t.Errorf("Timeout = %v, want %v", f.config.Timeout, DefaultFetchTimeout)
	}
}
|
||||
|
||||
// TestIsAllowedContentType covers allow-list matching: case-insensitivity,
// parameter stripping, and rejection of non-image types.
func TestIsAllowedContentType(t *testing.T) {
	f := New(DefaultConfig())

	tests := []struct {
		contentType string
		want        bool
	}{
		{"image/jpeg", true},
		{"image/png", true},
		{"image/webp", true},
		{"image/jpeg; charset=utf-8", true},
		{"IMAGE/JPEG", true},
		{"text/html", false},
		{"application/octet-stream", false},
		{"", false},
	}

	for _, tc := range tests {
		t.Run(tc.contentType, func(t *testing.T) {
			got := f.isAllowedContentType(tc.contentType)
			if got != tc.want {
				t.Errorf("isAllowedContentType(%q) = %v, want %v", tc.contentType, got, tc.want)
			}
		})
	}
}
|
||||
|
||||
// TestExtractHost covers scheme stripping, port retention, and truncation
// at path/query boundaries.
func TestExtractHost(t *testing.T) {
	tests := []struct {
		url  string
		want string
	}{
		{"https://example.com/path", "example.com"},
		{"http://example.com:8080/path", "example.com:8080"},
		{"https://example.com", "example.com"},
		{"https://example.com?q=1", "example.com"},
		{"example.com/path", "example.com"},
		{"", ""},
	}

	for _, tc := range tests {
		t.Run(tc.url, func(t *testing.T) {
			got := extractHost(tc.url)
			if got != tc.want {
				t.Errorf("extractHost(%q) = %q, want %q", tc.url, got, tc.want)
			}
		})
	}
}
|
||||
|
||||
// TestIsLocalhost covers the string-level localhost checks, including
// case-insensitivity and the .localhost/.local suffixes.
func TestIsLocalhost(t *testing.T) {
	tests := []struct {
		host string
		want bool
	}{
		{"localhost", true},
		{"LOCALHOST", true},
		{"127.0.0.1", true},
		{"::1", true},
		{"[::1]", true},
		{"foo.localhost", true},
		{"foo.local", true},
		{"example.com", false},
		{"127.0.0.2", false}, // Handled by isPrivateIP, not isLocalhost string match
	}

	for _, tc := range tests {
		t.Run(tc.host, func(t *testing.T) {
			got := isLocalhost(tc.host)
			if got != tc.want {
				t.Errorf("isLocalhost(%q) = %v, want %v", tc.host, got, tc.want)
			}
		})
	}
}
|
||||
|
||||
// TestIsPrivateIP covers each blocked address class plus representative
// public IPv4/IPv6 addresses and the nil guard.
func TestIsPrivateIP(t *testing.T) {
	tests := []struct {
		ip   string
		want bool
	}{
		{"127.0.0.1", true},             // loopback
		{"10.0.0.1", true},              // private
		{"192.168.1.1", true},           // private
		{"172.16.0.1", true},            // private
		{"169.254.1.1", true},           // link-local
		{"0.0.0.0", true},               // unspecified
		{"224.0.0.1", true},             // multicast
		{"::1", true},                   // IPv6 loopback
		{"fe80::1", true},               // IPv6 link-local
		{"8.8.8.8", false},              // public
		{"2001:4860:4860::8888", false}, // public IPv6
	}

	for _, tc := range tests {
		t.Run(tc.ip, func(t *testing.T) {
			ip := net.ParseIP(tc.ip)
			if ip == nil {
				t.Fatalf("failed to parse IP %q", tc.ip)
			}

			got := isPrivateIP(ip)
			if got != tc.want {
				t.Errorf("isPrivateIP(%q) = %v, want %v", tc.ip, got, tc.want)
			}
		})
	}

	if !isPrivateIP(nil) {
		t.Error("isPrivateIP(nil) should return true")
	}
}
|
||||
|
||||
// TestValidateURL_RejectsNonHTTPS verifies the HTTPS-only default.
func TestValidateURL_RejectsNonHTTPS(t *testing.T) {
	err := validateURL("http://example.com/path", false)
	if !errors.Is(err, ErrUnsupportedScheme) {
		t.Errorf("validateURL http = %v, want ErrUnsupportedScheme", err)
	}
}
|
||||
|
||||
// TestValidateURL_AllowsHTTPWhenConfigured verifies AllowHTTP disables the
// scheme check without asserting on DNS behavior.
func TestValidateURL_AllowsHTTPWhenConfigured(t *testing.T) {
	// Use a host that won't resolve (explicit .invalid TLD) so we don't hit DNS.
	err := validateURL("http://nonexistent.invalid/path", true)
	// We expect a host resolution error, not ErrUnsupportedScheme.
	if errors.Is(err, ErrUnsupportedScheme) {
		t.Error("validateURL with AllowHTTP should not return ErrUnsupportedScheme")
	}
}
|
||||
|
||||
// TestValidateURL_RejectsLocalhost verifies the string-level localhost block
// fires before any DNS resolution.
func TestValidateURL_RejectsLocalhost(t *testing.T) {
	err := validateURL("https://localhost/path", false)
	if !errors.Is(err, ErrSSRFBlocked) {
		t.Errorf("validateURL localhost = %v, want ErrSSRFBlocked", err)
	}
}
|
||||
|
||||
// TestValidateURL_EmptyHost verifies a URL with no authority is rejected.
func TestValidateURL_EmptyHost(t *testing.T) {
	err := validateURL("https:///path", false)
	if !errors.Is(err, ErrInvalidHost) {
		t.Errorf("validateURL empty host = %v, want ErrInvalidHost", err)
	}
}
|
||||
|
||||
// TestMockFetcher_FetchesFile verifies the URL-to-filesystem mapping, the
// extension-based content type, and the returned body/length.
func TestMockFetcher_FetchesFile(t *testing.T) {
	mockFS := fstest.MapFS{
		"example.com/images/photo.jpg": &fstest.MapFile{Data: []byte("fake-jpeg-data")},
	}

	m := NewMock(mockFS)

	result, err := m.Fetch(context.Background(), "https://example.com/images/photo.jpg")
	if err != nil {
		t.Fatalf("Fetch() error = %v", err)
	}
	defer func() { _ = result.Content.Close() }()

	if result.ContentType != "image/jpeg" {
		t.Errorf("ContentType = %q, want image/jpeg", result.ContentType)
	}

	data, err := io.ReadAll(result.Content)
	if err != nil {
		t.Fatalf("read content: %v", err)
	}

	if string(data) != "fake-jpeg-data" {
		t.Errorf("Content = %q, want %q", string(data), "fake-jpeg-data")
	}

	if result.ContentLength != int64(len("fake-jpeg-data")) {
		t.Errorf("ContentLength = %d, want %d", result.ContentLength, len("fake-jpeg-data"))
	}
}
|
||||
|
||||
// TestMockFetcher_MissingFileReturnsUpstreamError verifies a missing file is
// reported as an upstream 404 via ErrUpstreamError, like the real fetcher.
func TestMockFetcher_MissingFileReturnsUpstreamError(t *testing.T) {
	mockFS := fstest.MapFS{}
	m := NewMock(mockFS)

	_, err := m.Fetch(context.Background(), "https://example.com/missing.jpg")
	if !errors.Is(err, ErrUpstreamError) {
		t.Errorf("Fetch() error = %v, want ErrUpstreamError", err)
	}
}
|
||||
|
||||
// TestMockFetcher_RespectsContextCancellation verifies an already-cancelled
// context short-circuits the mock fetch.
func TestMockFetcher_RespectsContextCancellation(t *testing.T) {
	mockFS := fstest.MapFS{
		"example.com/photo.jpg": &fstest.MapFile{Data: []byte("data")},
	}
	m := NewMock(mockFS)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := m.Fetch(ctx, "https://example.com/photo.jpg")
	if !errors.Is(err, context.Canceled) {
		t.Errorf("Fetch() error = %v, want context.Canceled", err)
	}
}
|
||||
|
||||
// TestDetectContentTypeFromPath covers each supported extension (including
// case-insensitivity) and the octet-stream fallback.
func TestDetectContentTypeFromPath(t *testing.T) {
	tests := []struct {
		path string
		want string
	}{
		{"foo/bar.jpg", "image/jpeg"},
		{"foo/bar.JPG", "image/jpeg"},
		{"foo/bar.jpeg", "image/jpeg"},
		{"foo/bar.png", "image/png"},
		{"foo/bar.gif", "image/gif"},
		{"foo/bar.webp", "image/webp"},
		{"foo/bar.avif", "image/avif"},
		{"foo/bar.svg", "image/svg+xml"},
		{"foo/bar.bin", "application/octet-stream"},
		{"foo/bar", "application/octet-stream"},
	}

	for _, tc := range tests {
		t.Run(tc.path, func(t *testing.T) {
			got := detectContentTypeFromPath(tc.path)
			if got != tc.want {
				t.Errorf("detectContentTypeFromPath(%q) = %q, want %q", tc.path, got, tc.want)
			}
		})
	}
}
|
||||
|
||||
// TestLimitedReader_EnforcesLimit checks that reads are capped at the
// configured budget and that reading past it (when more data exists)
// yields ErrResponseTooLarge.
func TestLimitedReader_EnforcesLimit(t *testing.T) {
	src := make([]byte, 100)
	r := &limitedReader{
		reader:    &byteReader{data: src},
		remaining: 50,
	}

	buf := make([]byte, 100)

	n, err := r.Read(buf)
	if err != nil {
		t.Fatalf("first Read error = %v", err)
	}

	if n > 50 {
		t.Errorf("read %d bytes, should be capped at 50", n)
	}

	// Drain until limit is exhausted.
	total := n
	for total < 50 {
		nn, err := r.Read(buf)
		total += nn
		if err != nil {
			t.Fatalf("during drain: %v", err)
		}
	}

	// Now the limit is exhausted — next read should error.
	_, err = r.Read(buf)
	if !errors.Is(err, ErrResponseTooLarge) {
		t.Errorf("exhausted Read error = %v, want ErrResponseTooLarge", err)
	}
}
|
||||
|
||||
// byteReader is a minimal io.Reader over a byte slice for testing.
|
||||
type byteReader struct {
|
||||
data []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (r *byteReader) Read(p []byte) (int, error) {
|
||||
if r.pos >= len(r.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n := copy(p, r.data[r.pos:])
|
||||
r.pos += n
|
||||
|
||||
return n, nil
|
||||
}
|
||||
115
internal/httpfetcher/mock.go
Normal file
115
internal/httpfetcher/mock.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package httpfetcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MockFetcher implements Fetcher using an embedded filesystem.
// Files are organized as: hostname/path/to/file.ext
// URLs like https://example.com/images/photo.jpg map to example.com/images/photo.jpg.
type MockFetcher struct {
	fs fs.FS // backing filesystem; paths mirror URL host + path
}

// NewMock creates a new mock fetcher backed by the given filesystem.
func NewMock(fsys fs.FS) *MockFetcher {
	return &MockFetcher{fs: fsys}
}
|
||||
|
||||
// Fetch retrieves content from the mock filesystem.
|
||||
func (m *MockFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
|
||||
// Check context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Parse URL to get filesystem path
|
||||
path, err := urlToFSPath(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Open the file
|
||||
f, err := m.fs.Open(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, fmt.Errorf("%w: status 404", ErrUpstreamError)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to open mock file: %w", err)
|
||||
}
|
||||
|
||||
// Get file info for content length
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
|
||||
return nil, fmt.Errorf("failed to stat mock file: %w", err)
|
||||
}
|
||||
|
||||
// Detect content type from extension
|
||||
contentType := detectContentTypeFromPath(path)
|
||||
|
||||
return &FetchResult{
|
||||
Content: f.(io.ReadCloser),
|
||||
ContentLength: stat.Size(),
|
||||
ContentType: contentType,
|
||||
Headers: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// urlToFSPath converts a URL to a filesystem path.
// https://example.com/images/photo.jpg -> example.com/images/photo.jpg
// The scheme, query string, and fragment are stripped; an empty result is
// an error.
func urlToFSPath(rawURL string) (string, error) {
	path := rawURL

	// Drop the scheme, if any.
	if _, rest, found := strings.Cut(path, "://"); found {
		path = rest
	}

	// Drop query string and fragment.
	path, _, _ = strings.Cut(path, "?")
	path, _, _ = strings.Cut(path, "#")

	if path == "" {
		return "", errors.New("empty URL path")
	}

	return path, nil
}
|
||||
|
||||
// detectContentTypeFromPath returns the MIME type implied by the file
// extension of path (case-insensitive), or application/octet-stream when
// the extension is missing or unrecognized.
func detectContentTypeFromPath(path string) string {
	ext := ""
	if idx := strings.LastIndex(path, "."); idx != -1 {
		ext = strings.ToLower(path[idx:])
	}

	switch ext {
	case ".jpg", ".jpeg":
		return "image/jpeg"
	case ".png":
		return "image/png"
	case ".gif":
		return "image/gif"
	case ".webp":
		return "image/webp"
	case ".avif":
		return "image/avif"
	case ".svg":
		return "image/svg+xml"
	default:
		return "application/octet-stream"
	}
}
|
||||
Reference in New Issue
Block a user