Add per-host connection limits for upstream fetching
- Add upstream_connections_per_host config option (default: 20)
- Implement per-host semaphores to limit concurrent connections
- Release the semaphore slot when the response body is closed
- Prevent overwhelming origin servers with parallel requests
This commit is contained in:
@@ -8,17 +8,19 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Fetcher configuration constants.
//
// DefaultMaxConnectionsPerHost is the value DefaultFetcherConfig assigns to
// FetcherConfig.MaxConnectionsPerHost.
const (
	DefaultFetchTimeout          = 30 * time.Second
	DefaultMaxResponseSize       = 50 << 20 // 50MB
	DefaultTLSTimeout            = 10 * time.Second
	DefaultMaxIdleConns          = 100
	DefaultIdleConnTimeout       = 90 * time.Second
	DefaultMaxRedirects          = 10
	DefaultMaxConnectionsPerHost = 20
)
|
||||
|
||||
// Fetcher errors.
|
||||
@@ -44,6 +46,8 @@ type FetcherConfig struct {
|
||||
AllowedContentTypes []string
|
||||
// AllowHTTP allows non-TLS connections (for testing only)
|
||||
AllowHTTP bool
|
||||
// MaxConnectionsPerHost limits concurrent connections to each upstream host
|
||||
MaxConnectionsPerHost int
|
||||
}
|
||||
|
||||
// DefaultFetcherConfig returns sensible defaults.
|
||||
@@ -60,14 +64,17 @@ func DefaultFetcherConfig() *FetcherConfig {
|
||||
"image/avif",
|
||||
"image/svg+xml",
|
||||
},
|
||||
AllowHTTP: false,
|
||||
AllowHTTP: false,
|
||||
MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPFetcher implements the Fetcher interface with SSRF protection.
|
||||
type HTTPFetcher struct {
|
||||
client *http.Client
|
||||
config *FetcherConfig
|
||||
client *http.Client
|
||||
config *FetcherConfig
|
||||
hostSems map[string]chan struct{} // per-host semaphores
|
||||
hostSemMu sync.Mutex // protects hostSems map
|
||||
}
|
||||
|
||||
// NewHTTPFetcher creates a new fetcher with SSRF protection.
|
||||
@@ -102,11 +109,26 @@ func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
|
||||
}
|
||||
|
||||
return &HTTPFetcher{
|
||||
client: client,
|
||||
config: config,
|
||||
client: client,
|
||||
config: config,
|
||||
hostSems: make(map[string]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// getHostSemaphore returns the semaphore for a host, creating it if necessary.
|
||||
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
|
||||
f.hostSemMu.Lock()
|
||||
defer f.hostSemMu.Unlock()
|
||||
|
||||
sem, ok := f.hostSems[host]
|
||||
if !ok {
|
||||
sem = make(chan struct{}, f.config.MaxConnectionsPerHost)
|
||||
f.hostSems[host] = sem
|
||||
}
|
||||
|
||||
return sem
|
||||
}
|
||||
|
||||
// Fetch retrieves content from the given URL with SSRF protection.
|
||||
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
|
||||
// Validate URL before making request
|
||||
@@ -114,6 +136,26 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract host for rate limiting
|
||||
host := extractHost(url)
|
||||
|
||||
// Acquire semaphore slot for this host
|
||||
sem := f.getHostSemaphore(host)
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
// Acquired slot
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// If we fail before returning a result, release the slot
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
<-sem
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
@@ -146,14 +188,17 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, erro
|
||||
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
|
||||
}
|
||||
|
||||
// Wrap body with size limiter
|
||||
// Wrap body with size limiter and semaphore releaser
|
||||
limitedBody := &limitedReader{
|
||||
reader: resp.Body,
|
||||
remaining: f.config.MaxResponseSize,
|
||||
}
|
||||
|
||||
// Mark success so defer doesn't release the semaphore
|
||||
success = true
|
||||
|
||||
return &FetchResult{
|
||||
Content: &limitedReadCloser{limitedBody, resp.Body},
|
||||
Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
|
||||
ContentLength: resp.ContentLength,
|
||||
ContentType: contentType,
|
||||
Headers: resp.Header,
|
||||
@@ -340,12 +385,16 @@ func (r *limitedReader) Read(p []byte) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// limitedReadCloser combines the limited reader with the original closer.
|
||||
type limitedReadCloser struct {
|
||||
// semaphoreReleasingReadCloser releases a semaphore slot when closed.
|
||||
type semaphoreReleasingReadCloser struct {
|
||||
*limitedReader
|
||||
closer io.Closer
|
||||
sem chan struct{}
|
||||
}
|
||||
|
||||
func (r *limitedReadCloser) Close() error {
|
||||
return r.closer.Close()
|
||||
func (r *semaphoreReleasingReadCloser) Close() error {
|
||||
err := r.closer.Close()
|
||||
<-r.sem // Release semaphore slot
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user