package imgcache import ( "context" "errors" "fmt" "io" "net" "net/http" "strings" "sync" "time" ) // Fetcher configuration constants. const ( DefaultFetchTimeout = 30 * time.Second DefaultMaxResponseSize = 50 << 20 // 50MB DefaultTLSTimeout = 10 * time.Second DefaultMaxIdleConns = 100 DefaultIdleConnTimeout = 90 * time.Second DefaultMaxRedirects = 10 DefaultMaxConnectionsPerHost = 20 ) // Fetcher errors. var ( ErrSSRFBlocked = errors.New("request blocked: private or internal IP") ErrInvalidHost = errors.New("invalid or unresolvable host") ErrUnsupportedScheme = errors.New("only HTTPS is supported") ErrResponseTooLarge = errors.New("response exceeds maximum size") ErrInvalidContentType = errors.New("invalid or unsupported content type") ErrUpstreamError = errors.New("upstream server error") ErrUpstreamTimeout = errors.New("upstream request timeout") ) // FetcherConfig holds configuration for the upstream fetcher. type FetcherConfig struct { // Timeout for upstream requests Timeout time.Duration // MaxResponseSize is the maximum allowed response body size MaxResponseSize int64 // UserAgent to send to upstream servers UserAgent string // AllowedContentTypes is a whitelist of MIME types to accept AllowedContentTypes []string // AllowHTTP allows non-TLS connections (for testing only) AllowHTTP bool // MaxConnectionsPerHost limits concurrent connections to each upstream host MaxConnectionsPerHost int } // DefaultFetcherConfig returns sensible defaults. func DefaultFetcherConfig() *FetcherConfig { return &FetcherConfig{ Timeout: DefaultFetchTimeout, MaxResponseSize: DefaultMaxResponseSize, UserAgent: "pixa/1.0", AllowedContentTypes: []string{ "image/jpeg", "image/png", "image/gif", "image/webp", "image/avif", "image/svg+xml", }, AllowHTTP: false, MaxConnectionsPerHost: DefaultMaxConnectionsPerHost, } } // HTTPFetcher implements the Fetcher interface with SSRF protection. type HTTPFetcher struct { client *http.Client config *FetcherConfig hostSems map[string]chan struct{} // per-host semaphores hostSemMu sync.Mutex // protects hostSems map } // NewHTTPFetcher creates a new fetcher with SSRF protection. func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher { if config == nil { config = DefaultFetcherConfig() } // Create transport with SSRF-safe dialer transport := &http.Transport{ DialContext: ssrfSafeDialer, TLSHandshakeTimeout: DefaultTLSTimeout, MaxIdleConns: DefaultMaxIdleConns, IdleConnTimeout: DefaultIdleConnTimeout, } client := &http.Client{ Transport: transport, Timeout: config.Timeout, // Don't follow redirects automatically - we need to validate each hop CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) >= DefaultMaxRedirects { return errors.New("too many redirects") } // Validate the redirect target if err := validateURL(req.URL.String(), config.AllowHTTP); err != nil { return fmt.Errorf("redirect blocked: %w", err) } return nil }, } return &HTTPFetcher{ client: client, config: config, hostSems: make(map[string]chan struct{}), } } // getHostSemaphore returns the semaphore for a host, creating it if necessary. func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} { f.hostSemMu.Lock() defer f.hostSemMu.Unlock() sem, ok := f.hostSems[host] if !ok { sem = make(chan struct{}, f.config.MaxConnectionsPerHost) f.hostSems[host] = sem } return sem } // Fetch retrieves content from the given URL with SSRF protection. func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) { // Validate URL before making request if err := validateURL(url, f.config.AllowHTTP); err != nil { return nil, err } // Extract host for rate limiting host := extractHost(url) // Acquire semaphore slot for this host sem := f.getHostSemaphore(host) select { case sem <- struct{}{}: // Acquired slot case <-ctx.Done(): return nil, ctx.Err() } // If we fail before returning a result, release the slot success := false defer func() { if !success { <-sem } }() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("User-Agent", f.config.UserAgent) req.Header.Set("Accept", strings.Join(f.config.AllowedContentTypes, ", ")) startTime := time.Now() resp, err := f.client.Do(req) fetchDuration := time.Since(startTime) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return nil, ErrUpstreamTimeout } return nil, fmt.Errorf("upstream request failed: %w", err) } // Get remote address if available var remoteAddr string if resp.Request != nil && resp.Request.URL != nil { remoteAddr = resp.Request.Host } // Check status code if resp.StatusCode < 200 || resp.StatusCode >= 300 { _ = resp.Body.Close() return nil, fmt.Errorf("%w: status %d", ErrUpstreamError, resp.StatusCode) } // Validate content type contentType := resp.Header.Get("Content-Type") if !f.isAllowedContentType(contentType) { _ = resp.Body.Close() return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType) } // Wrap body with size limiter and semaphore releaser limitedBody := &limitedReader{ reader: resp.Body, remaining: f.config.MaxResponseSize, } // Mark success so defer doesn't release the semaphore success = true return &FetchResult{ Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem}, ContentLength: resp.ContentLength, ContentType: contentType, Headers: resp.Header, StatusCode: resp.StatusCode, FetchDurationMs: fetchDuration.Milliseconds(), RemoteAddr: remoteAddr, }, nil } // isAllowedContentType checks if the content type is in the whitelist. func (f *HTTPFetcher) isAllowedContentType(contentType string) bool { // Extract the MIME type without parameters mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0]) for _, allowed := range f.config.AllowedContentTypes { if strings.EqualFold(mediaType, allowed) { return true } } return false } // validateURL checks if a URL is safe to fetch (not internal/private). func validateURL(rawURL string, allowHTTP bool) error { if !allowHTTP && !strings.HasPrefix(rawURL, "https://") { return ErrUnsupportedScheme } // Parse to extract host host := extractHost(rawURL) if host == "" { return ErrInvalidHost } // Remove port if present if h, _, err := net.SplitHostPort(host); err == nil { host = h } // Block obvious localhost patterns if isLocalhost(host) { return ErrSSRFBlocked } // Resolve the host to check IP addresses ips, err := net.LookupIP(host) if err != nil { return fmt.Errorf("%w: %s", ErrInvalidHost, host) } for _, ip := range ips { if isPrivateIP(ip) { return ErrSSRFBlocked } } return nil } // extractHost extracts the host from a URL string. func extractHost(rawURL string) string { // Simple extraction without full URL parsing url := rawURL if idx := strings.Index(url, "://"); idx != -1 { url = url[idx+3:] } if idx := strings.Index(url, "/"); idx != -1 { url = url[:idx] } if idx := strings.Index(url, "?"); idx != -1 { url = url[:idx] } return url } // isLocalhost checks if the host is localhost. func isLocalhost(host string) bool { host = strings.ToLower(host) return host == "localhost" || host == "127.0.0.1" || host == "::1" || host == "[::1]" || strings.HasSuffix(host, ".localhost") || strings.HasSuffix(host, ".local") } // isPrivateIP checks if an IP is private, loopback, or otherwise internal. func isPrivateIP(ip net.IP) bool { if ip == nil { return true } // Check for loopback if ip.IsLoopback() { return true } // Check for private ranges if ip.IsPrivate() { return true } // Check for link-local if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { return true } // Check for unspecified (0.0.0.0 or ::) if ip.IsUnspecified() { return true } // Check for multicast if ip.IsMulticast() { return true } // Additional checks for IPv4 if ip4 := ip.To4(); ip4 != nil { // 169.254.0.0/16 - Link local if ip4[0] == 169 && ip4[1] == 254 { return true } // 0.0.0.0/8 - Current network if ip4[0] == 0 { return true } } return false } // ssrfSafeDialer is a custom dialer that validates IP addresses before connecting. func ssrfSafeDialer(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err } // Resolve the address ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) if err != nil { return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host) } // Check all resolved IPs for _, ip := range ips { if isPrivateIP(ip) { return nil, ErrSSRFBlocked } } // Connect using the first valid IP var dialer net.Dialer for _, ip := range ips { addr := net.JoinHostPort(ip.String(), port) conn, err := dialer.DialContext(ctx, network, addr) if err == nil { return conn, nil } } return nil, fmt.Errorf("failed to connect to %s", host) } // limitedReader wraps a reader and limits the number of bytes read. type limitedReader struct { reader io.Reader remaining int64 } func (r *limitedReader) Read(p []byte) (int, error) { if r.remaining <= 0 { return 0, ErrResponseTooLarge } if int64(len(p)) > r.remaining { p = p[:r.remaining] } n, err := r.reader.Read(p) r.remaining -= int64(n) return n, err } // semaphoreReleasingReadCloser releases a semaphore slot when closed. type semaphoreReleasingReadCloser struct { *limitedReader closer io.Closer sem chan struct{} } func (r *semaphoreReleasingReadCloser) Close() error { err := r.closer.Close() <-r.sem // Release semaphore slot return err }