Add per-host connection limits for upstream fetching

- Add upstream_connections_per_host config option (default: 20)
- Implement per-host semaphores to limit concurrent connections
- Release the semaphore slot when the response body is closed, or immediately if the fetch fails before a body is returned
- Prevent overwhelming origin servers with parallel requests
This commit is contained in:
2026-01-08 05:19:20 -08:00
parent 49ff72dfa8
commit f244d9c7e0
4 changed files with 90 additions and 32 deletions

View File

@@ -22,6 +22,9 @@ whitelist_hosts:
# Allow HTTP upstream (only for testing, always use HTTPS in production) # Allow HTTP upstream (only for testing, always use HTTPS in production)
allow_http: false allow_http: false
# Maximum concurrent connections per upstream host (default: 20)
upstream_connections_per_host: 20
# Sentry error reporting (optional) # Sentry error reporting (optional)
sentry_dsn: "" sentry_dsn: ""

View File

@@ -16,8 +16,9 @@ import (
// Default configuration values. // Default configuration values.
const ( const (
DefaultPort = 8080 DefaultPort = 8080
DefaultStateDir = "./data" DefaultStateDir = "./data"
DefaultUpstreamConnectionsPerHost = 20
) )
// Params defines dependencies for Config. // Params defines dependencies for Config.
@@ -39,9 +40,10 @@ type Config struct {
DBURL string DBURL string
// Image proxy settings // Image proxy settings
SigningKey string // HMAC signing key for URL signatures SigningKey string // HMAC signing key for URL signatures
WhitelistHosts []string // Hosts that don't require signatures WhitelistHosts []string // Hosts that don't require signatures
AllowHTTP bool // Allow non-TLS upstream (testing only) AllowHTTP bool // Allow non-TLS upstream (testing only)
UpstreamConnectionsPerHost int // Max concurrent connections per upstream host
} }
// New creates a new Config instance by loading configuration from file. // New creates a new Config instance by loading configuration from file.
@@ -59,16 +61,17 @@ func New(_ fx.Lifecycle, params Params) (*Config, error) {
} }
c := &Config{ c := &Config{
Debug: getBool(sc, "debug", false), Debug: getBool(sc, "debug", false),
MaintenanceMode: getBool(sc, "maintenance_mode", false), MaintenanceMode: getBool(sc, "maintenance_mode", false),
Port: getInt(sc, "port", DefaultPort), Port: getInt(sc, "port", DefaultPort),
StateDir: getString(sc, "state_dir", DefaultStateDir), StateDir: getString(sc, "state_dir", DefaultStateDir),
SentryDSN: getString(sc, "sentry_dsn", ""), SentryDSN: getString(sc, "sentry_dsn", ""),
MetricsUsername: getString(sc, "metrics.username", ""), MetricsUsername: getString(sc, "metrics.username", ""),
MetricsPassword: getString(sc, "metrics.password", ""), MetricsPassword: getString(sc, "metrics.password", ""),
SigningKey: getString(sc, "signing_key", ""), SigningKey: getString(sc, "signing_key", ""),
WhitelistHosts: getStringSlice(sc, "whitelist_hosts"), WhitelistHosts: getStringSlice(sc, "whitelist_hosts"),
AllowHTTP: getBool(sc, "allow_http", false), AllowHTTP: getBool(sc, "allow_http", false),
UpstreamConnectionsPerHost: getInt(sc, "upstream_connections_per_host", DefaultUpstreamConnectionsPerHost),
} }
// Build DBURL from StateDir if not explicitly set // Build DBURL from StateDir if not explicitly set

View File

@@ -72,6 +72,9 @@ func (s *Handlers) initImageService() error {
// Create the fetcher config // Create the fetcher config
fetcherCfg := imgcache.DefaultFetcherConfig() fetcherCfg := imgcache.DefaultFetcherConfig()
fetcherCfg.AllowHTTP = s.config.AllowHTTP fetcherCfg.AllowHTTP = s.config.AllowHTTP
if s.config.UpstreamConnectionsPerHost > 0 {
fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost
}
// Create the service // Create the service
svc, err := imgcache.NewService(&imgcache.ServiceConfig{ svc, err := imgcache.NewService(&imgcache.ServiceConfig{

View File

@@ -8,17 +8,19 @@ import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
"sync"
"time" "time"
) )
// Fetcher configuration constants. // Fetcher configuration constants.
const ( const (
DefaultFetchTimeout = 30 * time.Second DefaultFetchTimeout = 30 * time.Second
DefaultMaxResponseSize = 50 << 20 // 50MB DefaultMaxResponseSize = 50 << 20 // 50MB
DefaultTLSTimeout = 10 * time.Second DefaultTLSTimeout = 10 * time.Second
DefaultMaxIdleConns = 100 DefaultMaxIdleConns = 100
DefaultIdleConnTimeout = 90 * time.Second DefaultIdleConnTimeout = 90 * time.Second
DefaultMaxRedirects = 10 DefaultMaxRedirects = 10
DefaultMaxConnectionsPerHost = 20
) )
// Fetcher errors. // Fetcher errors.
@@ -44,6 +46,8 @@ type FetcherConfig struct {
AllowedContentTypes []string AllowedContentTypes []string
// AllowHTTP allows non-TLS connections (for testing only) // AllowHTTP allows non-TLS connections (for testing only)
AllowHTTP bool AllowHTTP bool
// MaxConnectionsPerHost limits concurrent connections to each upstream host
MaxConnectionsPerHost int
} }
// DefaultFetcherConfig returns sensible defaults. // DefaultFetcherConfig returns sensible defaults.
@@ -60,14 +64,17 @@ func DefaultFetcherConfig() *FetcherConfig {
"image/avif", "image/avif",
"image/svg+xml", "image/svg+xml",
}, },
AllowHTTP: false, AllowHTTP: false,
MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
} }
} }
// HTTPFetcher implements the Fetcher interface with SSRF protection. // HTTPFetcher implements the Fetcher interface with SSRF protection.
type HTTPFetcher struct { type HTTPFetcher struct {
client *http.Client client *http.Client
config *FetcherConfig config *FetcherConfig
hostSems map[string]chan struct{} // per-host semaphores
hostSemMu sync.Mutex // protects hostSems map
} }
// NewHTTPFetcher creates a new fetcher with SSRF protection. // NewHTTPFetcher creates a new fetcher with SSRF protection.
@@ -102,11 +109,26 @@ func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
} }
return &HTTPFetcher{ return &HTTPFetcher{
client: client, client: client,
config: config, config: config,
hostSems: make(map[string]chan struct{}),
} }
} }
// getHostSemaphore returns the semaphore that bounds concurrent fetches to
// host, creating it on first use.
//
// The semaphore is a buffered channel: a send acquires a slot and a receive
// releases one, so the channel capacity is the per-host concurrency limit.
// A non-positive MaxConnectionsPerHost falls back to
// DefaultMaxConnectionsPerHost. Without that guard a limit of zero would
// create an unbuffered channel on which every acquire blocks until its
// context is cancelled, and a negative limit would make the make call panic.
//
// NOTE(review): nothing visible here ever removes entries from hostSems, so
// the map grows with the number of distinct hosts seen — confirm that this is
// bounded for the expected workload.
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
	f.hostSemMu.Lock()
	defer f.hostSemMu.Unlock()

	if sem, ok := f.hostSems[host]; ok {
		return sem
	}

	limit := f.config.MaxConnectionsPerHost
	if limit <= 0 {
		limit = DefaultMaxConnectionsPerHost
	}

	sem := make(chan struct{}, limit)
	f.hostSems[host] = sem

	return sem
}
// Fetch retrieves content from the given URL with SSRF protection. // Fetch retrieves content from the given URL with SSRF protection.
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) { func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
// Validate URL before making request // Validate URL before making request
@@ -114,6 +136,26 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
return nil, err return nil, err
} }
// Extract host for rate limiting
host := extractHost(url)
// Acquire semaphore slot for this host
sem := f.getHostSemaphore(host)
select {
case sem <- struct{}{}:
// Acquired slot
case <-ctx.Done():
return nil, ctx.Err()
}
// If we fail before returning a result, release the slot
success := false
defer func() {
if !success {
<-sem
}
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err) return nil, fmt.Errorf("failed to create request: %w", err)
@@ -146,14 +188,17 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType) return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
} }
// Wrap body with size limiter // Wrap body with size limiter and semaphore releaser
limitedBody := &limitedReader{ limitedBody := &limitedReader{
reader: resp.Body, reader: resp.Body,
remaining: f.config.MaxResponseSize, remaining: f.config.MaxResponseSize,
} }
// Mark success so defer doesn't release the semaphore
success = true
return &FetchResult{ return &FetchResult{
Content: &limitedReadCloser{limitedBody, resp.Body}, Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
ContentLength: resp.ContentLength, ContentLength: resp.ContentLength,
ContentType: contentType, ContentType: contentType,
Headers: resp.Header, Headers: resp.Header,
@@ -340,12 +385,16 @@ func (r *limitedReader) Read(p []byte) (int, error) {
return n, err return n, err
} }
// limitedReadCloser combines the limited reader with the original closer. // semaphoreReleasingReadCloser releases a semaphore slot when closed.
type limitedReadCloser struct { type semaphoreReleasingReadCloser struct {
*limitedReader *limitedReader
closer io.Closer closer io.Closer
sem chan struct{}
} }
func (r *limitedReadCloser) Close() error { func (r *semaphoreReleasingReadCloser) Close() error {
return r.closer.Close() err := r.closer.Close()
<-r.sem // Release semaphore slot
return err
} }