Add per-host connection limits for upstream fetching

- Add upstream_connections_per_host config option (default: 20)
- Implement per-host semaphores to limit concurrent connections
- Semaphore released when response body is closed
- Prevents overwhelming origin servers with parallel requests
This commit is contained in:
2026-01-08 05:19:20 -08:00
parent 49ff72dfa8
commit f244d9c7e0
4 changed files with 90 additions and 32 deletions

View File

@@ -22,6 +22,9 @@ whitelist_hosts:
# Allow HTTP upstream (only for testing, always use HTTPS in production)
allow_http: false
# Maximum concurrent connections per upstream host (default: 20)
upstream_connections_per_host: 20
# Sentry error reporting (optional)
sentry_dsn: ""

View File

@@ -16,8 +16,9 @@ import (
// Default configuration values.
const (
DefaultPort = 8080
DefaultStateDir = "./data"
DefaultPort = 8080
DefaultStateDir = "./data"
DefaultUpstreamConnectionsPerHost = 20
)
// Params defines dependencies for Config.
@@ -39,9 +40,10 @@ type Config struct {
DBURL string
// Image proxy settings
SigningKey string // HMAC signing key for URL signatures
WhitelistHosts []string // Hosts that don't require signatures
AllowHTTP bool // Allow non-TLS upstream (testing only)
SigningKey string // HMAC signing key for URL signatures
WhitelistHosts []string // Hosts that don't require signatures
AllowHTTP bool // Allow non-TLS upstream (testing only)
UpstreamConnectionsPerHost int // Max concurrent connections per upstream host
}
// New creates a new Config instance by loading configuration from file.
@@ -59,16 +61,17 @@ func New(_ fx.Lifecycle, params Params) (*Config, error) {
}
c := &Config{
Debug: getBool(sc, "debug", false),
MaintenanceMode: getBool(sc, "maintenance_mode", false),
Port: getInt(sc, "port", DefaultPort),
StateDir: getString(sc, "state_dir", DefaultStateDir),
SentryDSN: getString(sc, "sentry_dsn", ""),
MetricsUsername: getString(sc, "metrics.username", ""),
MetricsPassword: getString(sc, "metrics.password", ""),
SigningKey: getString(sc, "signing_key", ""),
WhitelistHosts: getStringSlice(sc, "whitelist_hosts"),
AllowHTTP: getBool(sc, "allow_http", false),
Debug: getBool(sc, "debug", false),
MaintenanceMode: getBool(sc, "maintenance_mode", false),
Port: getInt(sc, "port", DefaultPort),
StateDir: getString(sc, "state_dir", DefaultStateDir),
SentryDSN: getString(sc, "sentry_dsn", ""),
MetricsUsername: getString(sc, "metrics.username", ""),
MetricsPassword: getString(sc, "metrics.password", ""),
SigningKey: getString(sc, "signing_key", ""),
WhitelistHosts: getStringSlice(sc, "whitelist_hosts"),
AllowHTTP: getBool(sc, "allow_http", false),
UpstreamConnectionsPerHost: getInt(sc, "upstream_connections_per_host", DefaultUpstreamConnectionsPerHost),
}
// Build DBURL from StateDir if not explicitly set

View File

@@ -72,6 +72,9 @@ func (s *Handlers) initImageService() error {
// Create the fetcher config
fetcherCfg := imgcache.DefaultFetcherConfig()
fetcherCfg.AllowHTTP = s.config.AllowHTTP
if s.config.UpstreamConnectionsPerHost > 0 {
fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost
}
// Create the service
svc, err := imgcache.NewService(&imgcache.ServiceConfig{

View File

@@ -8,17 +8,19 @@ import (
"net"
"net/http"
"strings"
"sync"
"time"
)
// Fetcher configuration constants.
const (
DefaultFetchTimeout = 30 * time.Second
DefaultMaxResponseSize = 50 << 20 // 50MB
DefaultTLSTimeout = 10 * time.Second
DefaultMaxIdleConns = 100
DefaultIdleConnTimeout = 90 * time.Second
DefaultMaxRedirects = 10
DefaultFetchTimeout = 30 * time.Second
DefaultMaxResponseSize = 50 << 20 // 50MB
DefaultTLSTimeout = 10 * time.Second
DefaultMaxIdleConns = 100
DefaultIdleConnTimeout = 90 * time.Second
DefaultMaxRedirects = 10
DefaultMaxConnectionsPerHost = 20
)
// Fetcher errors.
@@ -44,6 +46,8 @@ type FetcherConfig struct {
AllowedContentTypes []string
// AllowHTTP allows non-TLS connections (for testing only)
AllowHTTP bool
// MaxConnectionsPerHost limits concurrent connections to each upstream host
MaxConnectionsPerHost int
}
// DefaultFetcherConfig returns sensible defaults.
@@ -60,14 +64,17 @@ func DefaultFetcherConfig() *FetcherConfig {
"image/avif",
"image/svg+xml",
},
AllowHTTP: false,
AllowHTTP: false,
MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
}
}
// HTTPFetcher implements the Fetcher interface with SSRF protection.
type HTTPFetcher struct {
client *http.Client
config *FetcherConfig
client *http.Client
config *FetcherConfig
hostSems map[string]chan struct{} // per-host semaphores
hostSemMu sync.Mutex // protects hostSems map
}
// NewHTTPFetcher creates a new fetcher with SSRF protection.
@@ -102,11 +109,26 @@ func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
}
return &HTTPFetcher{
client: client,
config: config,
client: client,
config: config,
hostSems: make(map[string]chan struct{}),
}
}
// getHostSemaphore returns the semaphore for a host, creating it if necessary.
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
f.hostSemMu.Lock()
defer f.hostSemMu.Unlock()
sem, ok := f.hostSems[host]
if !ok {
sem = make(chan struct{}, f.config.MaxConnectionsPerHost)
f.hostSems[host] = sem
}
return sem
}
// Fetch retrieves content from the given URL with SSRF protection.
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
// Validate URL before making request
@@ -114,6 +136,26 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
return nil, err
}
// Extract host for rate limiting
host := extractHost(url)
// Acquire semaphore slot for this host
sem := f.getHostSemaphore(host)
select {
case sem <- struct{}{}:
// Acquired slot
case <-ctx.Done():
return nil, ctx.Err()
}
// If we fail before returning a result, release the slot
success := false
defer func() {
if !success {
<-sem
}
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
@@ -146,14 +188,17 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
}
// Wrap body with size limiter
// Wrap body with size limiter and semaphore releaser
limitedBody := &limitedReader{
reader: resp.Body,
remaining: f.config.MaxResponseSize,
}
// Mark success so defer doesn't release the semaphore
success = true
return &FetchResult{
Content: &limitedReadCloser{limitedBody, resp.Body},
Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
ContentLength: resp.ContentLength,
ContentType: contentType,
Headers: resp.Header,
@@ -340,12 +385,16 @@ func (r *limitedReader) Read(p []byte) (int, error) {
return n, err
}
// limitedReadCloser combines the limited reader with the original closer.
type limitedReadCloser struct {
// semaphoreReleasingReadCloser releases a semaphore slot when closed.
type semaphoreReleasingReadCloser struct {
*limitedReader
closer io.Closer
sem chan struct{}
}
func (r *limitedReadCloser) Close() error {
return r.closer.Close()
func (r *semaphoreReleasingReadCloser) Close() error {
err := r.closer.Close()
<-r.sem // Release semaphore slot
return err
}