Add per-host connection limits for upstream fetching
- Add upstream_connections_per_host config option (default: 20) - Implement per-host semaphores to limit concurrent connections - Semaphore released when response body is closed - Prevents overwhelming origin servers with parallel requests
This commit is contained in:
@@ -16,8 +16,9 @@ import (
|
||||
|
||||
// Default configuration values.
|
||||
const (
|
||||
DefaultPort = 8080
|
||||
DefaultStateDir = "./data"
|
||||
DefaultPort = 8080
|
||||
DefaultStateDir = "./data"
|
||||
DefaultUpstreamConnectionsPerHost = 20
|
||||
)
|
||||
|
||||
// Params defines dependencies for Config.
|
||||
@@ -39,9 +40,10 @@ type Config struct {
|
||||
DBURL string
|
||||
|
||||
// Image proxy settings
|
||||
SigningKey string // HMAC signing key for URL signatures
|
||||
WhitelistHosts []string // Hosts that don't require signatures
|
||||
AllowHTTP bool // Allow non-TLS upstream (testing only)
|
||||
SigningKey string // HMAC signing key for URL signatures
|
||||
WhitelistHosts []string // Hosts that don't require signatures
|
||||
AllowHTTP bool // Allow non-TLS upstream (testing only)
|
||||
UpstreamConnectionsPerHost int // Max concurrent connections per upstream host
|
||||
}
|
||||
|
||||
// New creates a new Config instance by loading configuration from file.
|
||||
@@ -59,16 +61,17 @@ func New(_ fx.Lifecycle, params Params) (*Config, error) {
|
||||
}
|
||||
|
||||
c := &Config{
|
||||
Debug: getBool(sc, "debug", false),
|
||||
MaintenanceMode: getBool(sc, "maintenance_mode", false),
|
||||
Port: getInt(sc, "port", DefaultPort),
|
||||
StateDir: getString(sc, "state_dir", DefaultStateDir),
|
||||
SentryDSN: getString(sc, "sentry_dsn", ""),
|
||||
MetricsUsername: getString(sc, "metrics.username", ""),
|
||||
MetricsPassword: getString(sc, "metrics.password", ""),
|
||||
SigningKey: getString(sc, "signing_key", ""),
|
||||
WhitelistHosts: getStringSlice(sc, "whitelist_hosts"),
|
||||
AllowHTTP: getBool(sc, "allow_http", false),
|
||||
Debug: getBool(sc, "debug", false),
|
||||
MaintenanceMode: getBool(sc, "maintenance_mode", false),
|
||||
Port: getInt(sc, "port", DefaultPort),
|
||||
StateDir: getString(sc, "state_dir", DefaultStateDir),
|
||||
SentryDSN: getString(sc, "sentry_dsn", ""),
|
||||
MetricsUsername: getString(sc, "metrics.username", ""),
|
||||
MetricsPassword: getString(sc, "metrics.password", ""),
|
||||
SigningKey: getString(sc, "signing_key", ""),
|
||||
WhitelistHosts: getStringSlice(sc, "whitelist_hosts"),
|
||||
AllowHTTP: getBool(sc, "allow_http", false),
|
||||
UpstreamConnectionsPerHost: getInt(sc, "upstream_connections_per_host", DefaultUpstreamConnectionsPerHost),
|
||||
}
|
||||
|
||||
// Build DBURL from StateDir if not explicitly set
|
||||
|
||||
@@ -72,6 +72,9 @@ func (s *Handlers) initImageService() error {
|
||||
// Create the fetcher config
|
||||
fetcherCfg := imgcache.DefaultFetcherConfig()
|
||||
fetcherCfg.AllowHTTP = s.config.AllowHTTP
|
||||
if s.config.UpstreamConnectionsPerHost > 0 {
|
||||
fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost
|
||||
}
|
||||
|
||||
// Create the service
|
||||
svc, err := imgcache.NewService(&imgcache.ServiceConfig{
|
||||
|
||||
@@ -8,17 +8,19 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Fetcher configuration constants.
|
||||
const (
|
||||
DefaultFetchTimeout = 30 * time.Second
|
||||
DefaultMaxResponseSize = 50 << 20 // 50MB
|
||||
DefaultTLSTimeout = 10 * time.Second
|
||||
DefaultMaxIdleConns = 100
|
||||
DefaultIdleConnTimeout = 90 * time.Second
|
||||
DefaultMaxRedirects = 10
|
||||
DefaultFetchTimeout = 30 * time.Second
|
||||
DefaultMaxResponseSize = 50 << 20 // 50MB
|
||||
DefaultTLSTimeout = 10 * time.Second
|
||||
DefaultMaxIdleConns = 100
|
||||
DefaultIdleConnTimeout = 90 * time.Second
|
||||
DefaultMaxRedirects = 10
|
||||
DefaultMaxConnectionsPerHost = 20
|
||||
)
|
||||
|
||||
// Fetcher errors.
|
||||
@@ -44,6 +46,8 @@ type FetcherConfig struct {
|
||||
AllowedContentTypes []string
|
||||
// AllowHTTP allows non-TLS connections (for testing only)
|
||||
AllowHTTP bool
|
||||
// MaxConnectionsPerHost limits concurrent connections to each upstream host
|
||||
MaxConnectionsPerHost int
|
||||
}
|
||||
|
||||
// DefaultFetcherConfig returns sensible defaults.
|
||||
@@ -60,14 +64,17 @@ func DefaultFetcherConfig() *FetcherConfig {
|
||||
"image/avif",
|
||||
"image/svg+xml",
|
||||
},
|
||||
AllowHTTP: false,
|
||||
AllowHTTP: false,
|
||||
MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPFetcher implements the Fetcher interface with SSRF protection.
|
||||
type HTTPFetcher struct {
|
||||
client *http.Client
|
||||
config *FetcherConfig
|
||||
client *http.Client
|
||||
config *FetcherConfig
|
||||
hostSems map[string]chan struct{} // per-host semaphores
|
||||
hostSemMu sync.Mutex // protects hostSems map
|
||||
}
|
||||
|
||||
// NewHTTPFetcher creates a new fetcher with SSRF protection.
|
||||
@@ -102,11 +109,26 @@ func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
|
||||
}
|
||||
|
||||
return &HTTPFetcher{
|
||||
client: client,
|
||||
config: config,
|
||||
client: client,
|
||||
config: config,
|
||||
hostSems: make(map[string]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// getHostSemaphore returns the semaphore for a host, creating it if necessary.
|
||||
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
|
||||
f.hostSemMu.Lock()
|
||||
defer f.hostSemMu.Unlock()
|
||||
|
||||
sem, ok := f.hostSems[host]
|
||||
if !ok {
|
||||
sem = make(chan struct{}, f.config.MaxConnectionsPerHost)
|
||||
f.hostSems[host] = sem
|
||||
}
|
||||
|
||||
return sem
|
||||
}
|
||||
|
||||
// Fetch retrieves content from the given URL with SSRF protection.
|
||||
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
|
||||
// Validate URL before making request
|
||||
@@ -114,6 +136,26 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract host for rate limiting
|
||||
host := extractHost(url)
|
||||
|
||||
// Acquire semaphore slot for this host
|
||||
sem := f.getHostSemaphore(host)
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
// Acquired slot
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// If we fail before returning a result, release the slot
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
<-sem
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
@@ -146,14 +188,17 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
|
||||
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
|
||||
}
|
||||
|
||||
// Wrap body with size limiter
|
||||
// Wrap body with size limiter and semaphore releaser
|
||||
limitedBody := &limitedReader{
|
||||
reader: resp.Body,
|
||||
remaining: f.config.MaxResponseSize,
|
||||
}
|
||||
|
||||
// Mark success so defer doesn't release the semaphore
|
||||
success = true
|
||||
|
||||
return &FetchResult{
|
||||
Content: &limitedReadCloser{limitedBody, resp.Body},
|
||||
Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
|
||||
ContentLength: resp.ContentLength,
|
||||
ContentType: contentType,
|
||||
Headers: resp.Header,
|
||||
@@ -340,12 +385,16 @@ func (r *limitedReader) Read(p []byte) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// limitedReadCloser combines the limited reader with the original closer.
|
||||
type limitedReadCloser struct {
|
||||
// semaphoreReleasingReadCloser releases a semaphore slot when closed.
|
||||
type semaphoreReleasingReadCloser struct {
|
||||
*limitedReader
|
||||
closer io.Closer
|
||||
sem chan struct{}
|
||||
}
|
||||
|
||||
func (r *limitedReadCloser) Close() error {
|
||||
return r.closer.Close()
|
||||
func (r *semaphoreReleasingReadCloser) Close() error {
|
||||
err := r.closer.Close()
|
||||
<-r.sem // Release semaphore slot
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user