From f244d9c7e0fa58970e28d5b0e9d73bf48b10f6f0 Mon Sep 17 00:00:00 2001 From: sneak Date: Thu, 8 Jan 2026 05:19:20 -0800 Subject: [PATCH] Add per-host connection limits for upstream fetching - Add upstream_connections_per_host config option (default: 20) - Implement per-host semaphores to limit concurrent connections - Semaphore released when response body is closed - Prevents overwhelming origin servers with parallel requests --- example-config.yml | 3 ++ internal/config/config.go | 33 +++++++------- internal/handlers/handlers.go | 3 ++ internal/imgcache/fetcher.go | 83 ++++++++++++++++++++++++++++------- 4 files changed, 90 insertions(+), 32 deletions(-) diff --git a/example-config.yml b/example-config.yml index cbfa387..0209cee 100644 --- a/example-config.yml +++ b/example-config.yml @@ -22,6 +22,9 @@ whitelist_hosts: # Allow HTTP upstream (only for testing, always use HTTPS in production) allow_http: false +# Maximum concurrent connections per upstream host (default: 20) +upstream_connections_per_host: 20 + # Sentry error reporting (optional) sentry_dsn: "" diff --git a/internal/config/config.go b/internal/config/config.go index 93367f9..8230d5a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,8 +16,9 @@ import ( // Default configuration values. const ( - DefaultPort = 8080 - DefaultStateDir = "./data" + DefaultPort = 8080 + DefaultStateDir = "./data" + DefaultUpstreamConnectionsPerHost = 20 ) // Params defines dependencies for Config. @@ -39,9 +40,10 @@ type Config struct { DBURL string // Image proxy settings - SigningKey string // HMAC signing key for URL signatures - WhitelistHosts []string // Hosts that don't require signatures - AllowHTTP bool // Allow non-TLS upstream (testing only) + SigningKey string // HMAC signing key for URL signatures + WhitelistHosts []string // Hosts that don't require signatures + AllowHTTP bool // Allow non-TLS upstream (testing only) + UpstreamConnectionsPerHost int // Max concurrent connections per upstream host } // New creates a new Config instance by loading configuration from file. @@ -59,16 +61,17 @@ func New(_ fx.Lifecycle, params Params) (*Config, error) { } c := &Config{ - Debug: getBool(sc, "debug", false), - MaintenanceMode: getBool(sc, "maintenance_mode", false), - Port: getInt(sc, "port", DefaultPort), - StateDir: getString(sc, "state_dir", DefaultStateDir), - SentryDSN: getString(sc, "sentry_dsn", ""), - MetricsUsername: getString(sc, "metrics.username", ""), - MetricsPassword: getString(sc, "metrics.password", ""), - SigningKey: getString(sc, "signing_key", ""), - WhitelistHosts: getStringSlice(sc, "whitelist_hosts"), - AllowHTTP: getBool(sc, "allow_http", false), + Debug: getBool(sc, "debug", false), + MaintenanceMode: getBool(sc, "maintenance_mode", false), + Port: getInt(sc, "port", DefaultPort), + StateDir: getString(sc, "state_dir", DefaultStateDir), + SentryDSN: getString(sc, "sentry_dsn", ""), + MetricsUsername: getString(sc, "metrics.username", ""), + MetricsPassword: getString(sc, "metrics.password", ""), + SigningKey: getString(sc, "signing_key", ""), + WhitelistHosts: getStringSlice(sc, "whitelist_hosts"), + AllowHTTP: getBool(sc, "allow_http", false), + UpstreamConnectionsPerHost: getInt(sc, "upstream_connections_per_host", DefaultUpstreamConnectionsPerHost), } // Build DBURL from StateDir if not explicitly set diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 97eecbb..48935a2 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -72,6 +72,9 @@ func (s *Handlers) initImageService() error { // Create the fetcher config fetcherCfg := imgcache.DefaultFetcherConfig() fetcherCfg.AllowHTTP = s.config.AllowHTTP + if s.config.UpstreamConnectionsPerHost > 0 { + fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost + } // Create the service svc, err := imgcache.NewService(&imgcache.ServiceConfig{ diff --git a/internal/imgcache/fetcher.go b/internal/imgcache/fetcher.go index dd9f75e..3ab6acc 100644 --- a/internal/imgcache/fetcher.go +++ b/internal/imgcache/fetcher.go @@ -8,17 +8,19 @@ import ( "net" "net/http" "strings" + "sync" "time" ) // Fetcher configuration constants. const ( - DefaultFetchTimeout = 30 * time.Second - DefaultMaxResponseSize = 50 << 20 // 50MB - DefaultTLSTimeout = 10 * time.Second - DefaultMaxIdleConns = 100 - DefaultIdleConnTimeout = 90 * time.Second - DefaultMaxRedirects = 10 + DefaultFetchTimeout = 30 * time.Second + DefaultMaxResponseSize = 50 << 20 // 50MB + DefaultTLSTimeout = 10 * time.Second + DefaultMaxIdleConns = 100 + DefaultIdleConnTimeout = 90 * time.Second + DefaultMaxRedirects = 10 + DefaultMaxConnectionsPerHost = 20 ) // Fetcher errors. @@ -44,6 +46,8 @@ type FetcherConfig struct { AllowedContentTypes []string // AllowHTTP allows non-TLS connections (for testing only) AllowHTTP bool + // MaxConnectionsPerHost limits concurrent connections to each upstream host + MaxConnectionsPerHost int } // DefaultFetcherConfig returns sensible defaults. @@ -60,14 +64,17 @@ func DefaultFetcherConfig() *FetcherConfig { "image/avif", "image/svg+xml", }, - AllowHTTP: false, + AllowHTTP: false, + MaxConnectionsPerHost: DefaultMaxConnectionsPerHost, } } // HTTPFetcher implements the Fetcher interface with SSRF protection. type HTTPFetcher struct { - client *http.Client - config *FetcherConfig + client *http.Client + config *FetcherConfig + hostSems map[string]chan struct{} // per-host semaphores + hostSemMu sync.Mutex // protects hostSems map } // NewHTTPFetcher creates a new fetcher with SSRF protection. @@ -102,11 +109,26 @@ func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher { } return &HTTPFetcher{ - client: client, - config: config, + client: client, + config: config, + hostSems: make(map[string]chan struct{}), } } +// getHostSemaphore returns the semaphore for a host, creating it if necessary. +func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} { + f.hostSemMu.Lock() + defer f.hostSemMu.Unlock() + + sem, ok := f.hostSems[host] + if !ok { + sem = make(chan struct{}, f.config.MaxConnectionsPerHost) + f.hostSems[host] = sem + } + + return sem +} + // Fetch retrieves content from the given URL with SSRF protection. func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) { // Validate URL before making request @@ -114,6 +136,26 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro return nil, err } + // Extract host for rate limiting + host := extractHost(url) + + // Acquire semaphore slot for this host + sem := f.getHostSemaphore(host) + select { + case sem <- struct{}{}: + // Acquired slot + case <-ctx.Done(): + return nil, ctx.Err() + } + + // If we fail before returning a result, release the slot + success := false + defer func() { + if !success { + <-sem + } + }() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -146,14 +188,17 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType) } - // Wrap body with size limiter + // Wrap body with size limiter and semaphore releaser limitedBody := &limitedReader{ reader: resp.Body, remaining: f.config.MaxResponseSize, } + // Mark success so defer doesn't release the semaphore + success = true + return &FetchResult{ - Content: &limitedReadCloser{limitedBody, resp.Body}, + Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem}, ContentLength: resp.ContentLength, ContentType: contentType, Headers: resp.Header, @@ -340,12 +385,16 @@ func (r *limitedReader) Read(p []byte) (int, error) { return n, err } -// limitedReadCloser combines the limited reader with the original closer. -type limitedReadCloser struct { +// semaphoreReleasingReadCloser releases a semaphore slot when closed. +type semaphoreReleasingReadCloser struct { *limitedReader closer io.Closer + sem chan struct{} } -func (r *limitedReadCloser) Close() error { - return r.closer.Close() +func (r *semaphoreReleasingReadCloser) Close() error { + err := r.closer.Close() + <-r.sem // Release semaphore slot + + return err }