Add per-host connection limits for upstream fetching
- Add upstream_connections_per_host config option (default: 20) - Implement per-host semaphores to limit concurrent connections - Semaphore released when response body is closed - Prevents overwhelming origin servers with parallel requests
This commit is contained in:
@@ -22,6 +22,9 @@ whitelist_hosts:
|
|||||||
# Allow HTTP upstream (only for testing, always use HTTPS in production)
|
# Allow HTTP upstream (only for testing, always use HTTPS in production)
|
||||||
allow_http: false
|
allow_http: false
|
||||||
|
|
||||||
|
# Maximum concurrent connections per upstream host (default: 20)
|
||||||
|
upstream_connections_per_host: 20
|
||||||
|
|
||||||
# Sentry error reporting (optional)
|
# Sentry error reporting (optional)
|
||||||
sentry_dsn: ""
|
sentry_dsn: ""
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
DefaultPort = 8080
|
DefaultPort = 8080
|
||||||
DefaultStateDir = "./data"
|
DefaultStateDir = "./data"
|
||||||
|
DefaultUpstreamConnectionsPerHost = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
// Params defines dependencies for Config.
|
// Params defines dependencies for Config.
|
||||||
@@ -42,6 +43,7 @@ type Config struct {
|
|||||||
SigningKey string // HMAC signing key for URL signatures
|
SigningKey string // HMAC signing key for URL signatures
|
||||||
WhitelistHosts []string // Hosts that don't require signatures
|
WhitelistHosts []string // Hosts that don't require signatures
|
||||||
AllowHTTP bool // Allow non-TLS upstream (testing only)
|
AllowHTTP bool // Allow non-TLS upstream (testing only)
|
||||||
|
UpstreamConnectionsPerHost int // Max concurrent connections per upstream host
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Config instance by loading configuration from file.
|
// New creates a new Config instance by loading configuration from file.
|
||||||
@@ -69,6 +71,7 @@ func New(_ fx.Lifecycle, params Params) (*Config, error) {
|
|||||||
SigningKey: getString(sc, "signing_key", ""),
|
SigningKey: getString(sc, "signing_key", ""),
|
||||||
WhitelistHosts: getStringSlice(sc, "whitelist_hosts"),
|
WhitelistHosts: getStringSlice(sc, "whitelist_hosts"),
|
||||||
AllowHTTP: getBool(sc, "allow_http", false),
|
AllowHTTP: getBool(sc, "allow_http", false),
|
||||||
|
UpstreamConnectionsPerHost: getInt(sc, "upstream_connections_per_host", DefaultUpstreamConnectionsPerHost),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build DBURL from StateDir if not explicitly set
|
// Build DBURL from StateDir if not explicitly set
|
||||||
|
|||||||
@@ -72,6 +72,9 @@ func (s *Handlers) initImageService() error {
|
|||||||
// Create the fetcher config
|
// Create the fetcher config
|
||||||
fetcherCfg := imgcache.DefaultFetcherConfig()
|
fetcherCfg := imgcache.DefaultFetcherConfig()
|
||||||
fetcherCfg.AllowHTTP = s.config.AllowHTTP
|
fetcherCfg.AllowHTTP = s.config.AllowHTTP
|
||||||
|
if s.config.UpstreamConnectionsPerHost > 0 {
|
||||||
|
fetcherCfg.MaxConnectionsPerHost = s.config.UpstreamConnectionsPerHost
|
||||||
|
}
|
||||||
|
|
||||||
// Create the service
|
// Create the service
|
||||||
svc, err := imgcache.NewService(&imgcache.ServiceConfig{
|
svc, err := imgcache.NewService(&imgcache.ServiceConfig{
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,6 +20,7 @@ const (
|
|||||||
DefaultMaxIdleConns = 100
|
DefaultMaxIdleConns = 100
|
||||||
DefaultIdleConnTimeout = 90 * time.Second
|
DefaultIdleConnTimeout = 90 * time.Second
|
||||||
DefaultMaxRedirects = 10
|
DefaultMaxRedirects = 10
|
||||||
|
DefaultMaxConnectionsPerHost = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
// Fetcher errors.
|
// Fetcher errors.
|
||||||
@@ -44,6 +46,8 @@ type FetcherConfig struct {
|
|||||||
AllowedContentTypes []string
|
AllowedContentTypes []string
|
||||||
// AllowHTTP allows non-TLS connections (for testing only)
|
// AllowHTTP allows non-TLS connections (for testing only)
|
||||||
AllowHTTP bool
|
AllowHTTP bool
|
||||||
|
// MaxConnectionsPerHost limits concurrent connections to each upstream host
|
||||||
|
MaxConnectionsPerHost int
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultFetcherConfig returns sensible defaults.
|
// DefaultFetcherConfig returns sensible defaults.
|
||||||
@@ -61,6 +65,7 @@ func DefaultFetcherConfig() *FetcherConfig {
|
|||||||
"image/svg+xml",
|
"image/svg+xml",
|
||||||
},
|
},
|
||||||
AllowHTTP: false,
|
AllowHTTP: false,
|
||||||
|
MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,6 +73,8 @@ func DefaultFetcherConfig() *FetcherConfig {
|
|||||||
type HTTPFetcher struct {
|
type HTTPFetcher struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
config *FetcherConfig
|
config *FetcherConfig
|
||||||
|
hostSems map[string]chan struct{} // per-host semaphores
|
||||||
|
hostSemMu sync.Mutex // protects hostSems map
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPFetcher creates a new fetcher with SSRF protection.
|
// NewHTTPFetcher creates a new fetcher with SSRF protection.
|
||||||
@@ -104,9 +111,24 @@ func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
|
|||||||
return &HTTPFetcher{
|
return &HTTPFetcher{
|
||||||
client: client,
|
client: client,
|
||||||
config: config,
|
config: config,
|
||||||
|
hostSems: make(map[string]chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getHostSemaphore returns the semaphore for a host, creating it if necessary.
|
||||||
|
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
|
||||||
|
f.hostSemMu.Lock()
|
||||||
|
defer f.hostSemMu.Unlock()
|
||||||
|
|
||||||
|
sem, ok := f.hostSems[host]
|
||||||
|
if !ok {
|
||||||
|
sem = make(chan struct{}, f.config.MaxConnectionsPerHost)
|
||||||
|
f.hostSems[host] = sem
|
||||||
|
}
|
||||||
|
|
||||||
|
return sem
|
||||||
|
}
|
||||||
|
|
||||||
// Fetch retrieves content from the given URL with SSRF protection.
|
// Fetch retrieves content from the given URL with SSRF protection.
|
||||||
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
|
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
|
||||||
// Validate URL before making request
|
// Validate URL before making request
|
||||||
@@ -114,6 +136,26 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract host for rate limiting
|
||||||
|
host := extractHost(url)
|
||||||
|
|
||||||
|
// Acquire semaphore slot for this host
|
||||||
|
sem := f.getHostSemaphore(host)
|
||||||
|
select {
|
||||||
|
case sem <- struct{}{}:
|
||||||
|
// Acquired slot
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we fail before returning a result, release the slot
|
||||||
|
success := false
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
<-sem
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
@@ -146,14 +188,17 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
|
|||||||
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
|
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap body with size limiter
|
// Wrap body with size limiter and semaphore releaser
|
||||||
limitedBody := &limitedReader{
|
limitedBody := &limitedReader{
|
||||||
reader: resp.Body,
|
reader: resp.Body,
|
||||||
remaining: f.config.MaxResponseSize,
|
remaining: f.config.MaxResponseSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark success so defer doesn't release the semaphore
|
||||||
|
success = true
|
||||||
|
|
||||||
return &FetchResult{
|
return &FetchResult{
|
||||||
Content: &limitedReadCloser{limitedBody, resp.Body},
|
Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
|
||||||
ContentLength: resp.ContentLength,
|
ContentLength: resp.ContentLength,
|
||||||
ContentType: contentType,
|
ContentType: contentType,
|
||||||
Headers: resp.Header,
|
Headers: resp.Header,
|
||||||
@@ -340,12 +385,16 @@ func (r *limitedReader) Read(p []byte) (int, error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// limitedReadCloser combines the limited reader with the original closer.
|
// semaphoreReleasingReadCloser releases a semaphore slot when closed.
|
||||||
type limitedReadCloser struct {
|
type semaphoreReleasingReadCloser struct {
|
||||||
*limitedReader
|
*limitedReader
|
||||||
closer io.Closer
|
closer io.Closer
|
||||||
|
sem chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *limitedReadCloser) Close() error {
|
func (r *semaphoreReleasingReadCloser) Close() error {
|
||||||
return r.closer.Close()
|
err := r.closer.Close()
|
||||||
|
<-r.sem // Release semaphore slot
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user