Some checks failed
check / check (push) Has been cancelled
- Update Dockerfile base image from golang:1.24-alpine to golang:1.25.4-alpine (pinned by sha256 digest) to match go.mod requirement of go >= 1.25.4 - Fix gosec G703 (path traversal) false positives by adding filepath.Clean() at call sites with nolint annotations for internally-constructed paths - Fix gosec G704 (SSRF) false positive with nolint annotation; URL is already validated by validateURL() which checks scheme, resolves DNS, and blocks private IPs - All make check passes clean (lint + tests)
446 lines
11 KiB
Go
446 lines
11 KiB
Go
package imgcache
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptrace"
|
|
neturl "net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Fetcher configuration constants used by DefaultFetcherConfig and
// NewHTTPFetcher.
const (
	// Timeouts.
	DefaultFetchTimeout    = 30 * time.Second // overall per-request timeout (http.Client Timeout)
	DefaultTLSTimeout      = 10 * time.Second // TLS handshake timeout
	DefaultIdleConnTimeout = 90 * time.Second // how long idle keep-alive connections are retained

	// Size and count limits.
	DefaultMaxResponseSize       = 50 * 1024 * 1024 // 50 MiB
	DefaultMaxIdleConns          = 100
	DefaultMaxRedirects          = 10
	DefaultMaxConnectionsPerHost = 20
)
|
|
|
|
// Sentinel errors returned by the fetcher. Several are returned wrapped with
// extra context (status code, content type, host), so callers should match
// them with errors.Is rather than ==.
var (
	// URL validation and SSRF protection.
	ErrSSRFBlocked       = errors.New("request blocked: private or internal IP")
	ErrInvalidHost       = errors.New("invalid or unresolvable host")
	ErrUnsupportedScheme = errors.New("only HTTPS is supported")

	// Response validation.
	ErrResponseTooLarge   = errors.New("response exceeds maximum size")
	ErrInvalidContentType = errors.New("invalid or unsupported content type")

	// Upstream failures.
	ErrUpstreamError   = errors.New("upstream server error")
	ErrUpstreamTimeout = errors.New("upstream request timeout")
)
|
|
|
|
// FetcherConfig holds configuration for the upstream fetcher.
// DefaultFetcherConfig returns a populated instance, and NewHTTPFetcher falls
// back to it when given a nil config.
type FetcherConfig struct {
	// Timeout is the overall timeout for an upstream request; NewHTTPFetcher
	// uses it as the http.Client Timeout.
	Timeout time.Duration

	// MaxResponseSize is the maximum number of response body bytes that may
	// be read from a fetched result.
	MaxResponseSize int64

	// UserAgent is sent as the User-Agent header on upstream requests.
	UserAgent string

	// AllowedContentTypes is the allowlist of MIME types accepted from
	// upstream (matched case-insensitively, ignoring parameters). It is also
	// sent as the Accept header.
	AllowedContentTypes []string

	// AllowHTTP allows non-TLS (http://) URLs. Intended for testing only.
	AllowHTTP bool

	// MaxConnectionsPerHost caps how many fetches to the same host may be in
	// flight at once. A slot is held from the start of Fetch until the
	// returned content is closed.
	MaxConnectionsPerHost int
}
|
|
|
|
// DefaultFetcherConfig returns sensible defaults.
|
|
func DefaultFetcherConfig() *FetcherConfig {
|
|
return &FetcherConfig{
|
|
Timeout: DefaultFetchTimeout,
|
|
MaxResponseSize: DefaultMaxResponseSize,
|
|
UserAgent: "pixa/1.0",
|
|
AllowedContentTypes: []string{
|
|
"image/jpeg",
|
|
"image/png",
|
|
"image/gif",
|
|
"image/webp",
|
|
"image/avif",
|
|
"image/svg+xml",
|
|
},
|
|
AllowHTTP: false,
|
|
MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
|
|
}
|
|
}
|
|
|
|
// HTTPFetcher implements the Fetcher interface with SSRF protection.
// Construct it with NewHTTPFetcher: the zero value has a nil client and a
// nil hostSems map and is not usable.
type HTTPFetcher struct {
	client *http.Client
	config *FetcherConfig

	hostSems  map[string]chan struct{} // per-host semaphores keyed by extractHost(url); created by getHostSemaphore
	hostSemMu sync.Mutex               // protects hostSems map
}
|
|
|
|
// NewHTTPFetcher creates a new fetcher with SSRF protection.
|
|
func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
|
|
if config == nil {
|
|
config = DefaultFetcherConfig()
|
|
}
|
|
|
|
// Create transport with SSRF-safe dialer
|
|
transport := &http.Transport{
|
|
DialContext: ssrfSafeDialer,
|
|
TLSHandshakeTimeout: DefaultTLSTimeout,
|
|
MaxIdleConns: DefaultMaxIdleConns,
|
|
IdleConnTimeout: DefaultIdleConnTimeout,
|
|
}
|
|
|
|
client := &http.Client{
|
|
Transport: transport,
|
|
Timeout: config.Timeout,
|
|
// Don't follow redirects automatically - we need to validate each hop
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
if len(via) >= DefaultMaxRedirects {
|
|
return errors.New("too many redirects")
|
|
}
|
|
// Validate the redirect target
|
|
if err := validateURL(req.URL.String(), config.AllowHTTP); err != nil {
|
|
return fmt.Errorf("redirect blocked: %w", err)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
}
|
|
|
|
return &HTTPFetcher{
|
|
client: client,
|
|
config: config,
|
|
hostSems: make(map[string]chan struct{}),
|
|
}
|
|
}
|
|
|
|
// getHostSemaphore returns the semaphore for a host, creating it if necessary.
|
|
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
|
|
f.hostSemMu.Lock()
|
|
defer f.hostSemMu.Unlock()
|
|
|
|
sem, ok := f.hostSems[host]
|
|
if !ok {
|
|
sem = make(chan struct{}, f.config.MaxConnectionsPerHost)
|
|
f.hostSems[host] = sem
|
|
}
|
|
|
|
return sem
|
|
}
|
|
|
|
// Fetch retrieves content from the given URL with SSRF protection.
|
|
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
|
|
// Validate URL before making request
|
|
if err := validateURL(url, f.config.AllowHTTP); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Extract host for rate limiting
|
|
host := extractHost(url)
|
|
|
|
// Acquire semaphore slot for this host
|
|
sem := f.getHostSemaphore(host)
|
|
select {
|
|
case sem <- struct{}{}:
|
|
// Acquired slot
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
// If we fail before returning a result, release the slot
|
|
success := false
|
|
defer func() {
|
|
if !success {
|
|
<-sem
|
|
}
|
|
}()
|
|
|
|
parsedURL, err := neturl.Parse(url)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse URL: %w", err)
|
|
}
|
|
|
|
req := &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: parsedURL,
|
|
Header: make(http.Header),
|
|
}
|
|
req = req.WithContext(ctx)
|
|
|
|
req.Header.Set("User-Agent", f.config.UserAgent)
|
|
req.Header.Set("Accept", strings.Join(f.config.AllowedContentTypes, ", "))
|
|
|
|
// Use httptrace to capture connection details
|
|
var remoteAddr string
|
|
|
|
trace := &httptrace.ClientTrace{
|
|
GotConn: func(info httptrace.GotConnInfo) {
|
|
if info.Conn != nil {
|
|
remoteAddr = info.Conn.RemoteAddr().String()
|
|
}
|
|
},
|
|
}
|
|
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
|
|
|
startTime := time.Now()
|
|
|
|
//nolint:gosec // G704: URL validated by validateURL() above
|
|
resp, err := f.client.Do(req)
|
|
|
|
fetchDuration := time.Since(startTime)
|
|
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, ErrUpstreamTimeout
|
|
}
|
|
|
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
|
}
|
|
|
|
// Extract HTTP version (strip "HTTP/" prefix)
|
|
httpVersion := strings.TrimPrefix(resp.Proto, "HTTP/")
|
|
|
|
// Extract TLS info if available
|
|
var tlsVersion, tlsCipherSuite string
|
|
|
|
if resp.TLS != nil {
|
|
tlsVersion = tls.VersionName(resp.TLS.Version)
|
|
tlsCipherSuite = tls.CipherSuiteName(resp.TLS.CipherSuite)
|
|
}
|
|
|
|
// Check status code
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
_ = resp.Body.Close()
|
|
|
|
return nil, fmt.Errorf("%w: status %d", ErrUpstreamError, resp.StatusCode)
|
|
}
|
|
|
|
// Validate content type
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if !f.isAllowedContentType(contentType) {
|
|
_ = resp.Body.Close()
|
|
|
|
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
|
|
}
|
|
|
|
// Wrap body with size limiter and semaphore releaser
|
|
limitedBody := &limitedReader{
|
|
reader: resp.Body,
|
|
remaining: f.config.MaxResponseSize,
|
|
}
|
|
|
|
// Mark success so defer doesn't release the semaphore
|
|
success = true
|
|
|
|
return &FetchResult{
|
|
Content: &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
|
|
ContentLength: resp.ContentLength,
|
|
ContentType: contentType,
|
|
Headers: resp.Header,
|
|
StatusCode: resp.StatusCode,
|
|
FetchDurationMs: fetchDuration.Milliseconds(),
|
|
RemoteAddr: remoteAddr,
|
|
HTTPVersion: httpVersion,
|
|
TLSVersion: tlsVersion,
|
|
TLSCipherSuite: tlsCipherSuite,
|
|
}, nil
|
|
}
|
|
|
|
// isAllowedContentType checks if the content type is in the whitelist.
|
|
func (f *HTTPFetcher) isAllowedContentType(contentType string) bool {
|
|
// Extract the MIME type without parameters
|
|
mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0])
|
|
|
|
for _, allowed := range f.config.AllowedContentTypes {
|
|
if strings.EqualFold(mediaType, allowed) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// validateURL checks if a URL is safe to fetch (not internal/private).
|
|
func validateURL(rawURL string, allowHTTP bool) error {
|
|
if !allowHTTP && !strings.HasPrefix(rawURL, "https://") {
|
|
return ErrUnsupportedScheme
|
|
}
|
|
|
|
// Parse to extract host
|
|
host := extractHost(rawURL)
|
|
if host == "" {
|
|
return ErrInvalidHost
|
|
}
|
|
|
|
// Remove port if present
|
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
|
host = h
|
|
}
|
|
|
|
// Block obvious localhost patterns
|
|
if isLocalhost(host) {
|
|
return ErrSSRFBlocked
|
|
}
|
|
|
|
// Resolve the host to check IP addresses
|
|
ips, err := net.LookupIP(host)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %s", ErrInvalidHost, host)
|
|
}
|
|
|
|
for _, ip := range ips {
|
|
if isPrivateIP(ip) {
|
|
return ErrSSRFBlocked
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// extractHost returns the host[:port] authority of rawURL as parsed by
// net/url, lowercased so that differently-cased spellings of the same host
// map to the same key.
//
// Userinfo ("user:pass@"), path, query and fragment are excluded. The result
// therefore matches the host the HTTP client will contact for a URL parsed
// with net/url. It returns "" if rawURL cannot be parsed or has no authority
// (for example, no "scheme://" prefix).
func extractHost(rawURL string) string {
	u, err := neturl.Parse(rawURL)
	if err != nil {
		return ""
	}

	return strings.ToLower(u.Host)
}
|
|
|
|
// isLocalhost reports whether host, compared case-insensitively, is an
// obvious local name. That means "localhost", "127.0.0.1", "::1" (with or
// without brackets), or any name ending in ".localhost" or ".local". This is
// only a cheap string check; resolved addresses are checked separately by
// isPrivateIP.
func isLocalhost(host string) bool {
	name := strings.ToLower(host)

	switch name {
	case "localhost", "127.0.0.1", "::1", "[::1]":
		return true
	}

	for _, suffix := range [...]string{".localhost", ".local"} {
		if strings.HasSuffix(name, suffix) {
			return true
		}
	}

	return false
}
|
|
|
|
// isPrivateIP reports whether ip is an address the fetcher must not connect
// to. That covers loopback, private, link-local, unspecified and multicast
// addresses, plus the special-purpose IPv4 ranges checked below. IPv6
// addresses that embed an IPv4 address (NAT64 64:ff9b::/96, 6to4 2002::/16)
// are judged by the embedded IPv4 address.
//
// A nil or malformed IP (neither 4 nor 16 bytes) is reported as private so
// that callers fail closed.
func isPrivateIP(ip net.IP) bool {
	if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
		return true
	}

	// Standard-library classifications:
	//   - IsLinkLocalUnicast covers 169.254.0.0/16 (including the
	//     169.254.169.254 metadata address) and fe80::/10.
	//   - IsPrivate covers RFC 1918 and fc00::/7.
	if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsMulticast() {
		return true
	}

	// To4 is non-nil for both 4-byte and IPv4-mapped (::ffff:a.b.c.d) forms.
	if ip4 := ip.To4(); ip4 != nil {
		switch {
		case ip4[0] == 0: // 0.0.0.0/8 "this network"
			return true
		case ip4[0] == 100 && ip4[1]&0xc0 == 64: // 100.64.0.0/10 shared address space (CGNAT)
			return true
		case ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0: // 192.0.0.0/24 IETF protocol assignments
			return true
		case ip4[0] == 198 && ip4[1]&0xfe == 18: // 198.18.0.0/15 benchmarking
			return true
		case ip4[0] >= 240: // 240.0.0.0/4 reserved, incl. 255.255.255.255 broadcast
			return true
		}

		return false
	}

	// Here ip is a 16-byte, non-IPv4-mapped address. Traffic to these
	// prefixes may be translated or tunnelled to the embedded IPv4 address,
	// so judge them by that address.
	nat64 := net.IPNet{IP: net.ParseIP("64:ff9b::"), Mask: net.CIDRMask(96, 8*net.IPv6len)}
	if nat64.Contains(ip) {
		return isPrivateIP(ip[12:16])
	}

	sixToFour := net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 8*net.IPv6len)}
	if sixToFour.Contains(ip) {
		return isPrivateIP(ip[2:6])
	}

	return false
}
|
|
|
|
// ssrfSafeDialer is a custom dialer that validates IP addresses before connecting.
|
|
func ssrfSafeDialer(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
host, port, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Resolve the address
|
|
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host)
|
|
}
|
|
|
|
// Check all resolved IPs
|
|
for _, ip := range ips {
|
|
if isPrivateIP(ip) {
|
|
return nil, ErrSSRFBlocked
|
|
}
|
|
}
|
|
|
|
// Connect using the first valid IP
|
|
var dialer net.Dialer
|
|
for _, ip := range ips {
|
|
addr := net.JoinHostPort(ip.String(), port)
|
|
conn, err := dialer.DialContext(ctx, network, addr)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to connect to %s", host)
|
|
}
|
|
|
|
// limitedReader wraps a reader and limits the number of bytes read.
|
|
type limitedReader struct {
|
|
reader io.Reader
|
|
remaining int64
|
|
}
|
|
|
|
func (r *limitedReader) Read(p []byte) (int, error) {
|
|
if r.remaining <= 0 {
|
|
return 0, ErrResponseTooLarge
|
|
}
|
|
|
|
if int64(len(p)) > r.remaining {
|
|
p = p[:r.remaining]
|
|
}
|
|
|
|
n, err := r.reader.Read(p)
|
|
r.remaining -= int64(n)
|
|
|
|
return n, err
|
|
}
|
|
|
|
// semaphoreReleasingReadCloser releases a semaphore slot when closed.
|
|
type semaphoreReleasingReadCloser struct {
|
|
*limitedReader
|
|
closer io.Closer
|
|
sem chan struct{}
|
|
}
|
|
|
|
func (r *semaphoreReleasingReadCloser) Close() error {
|
|
err := r.closer.Close()
|
|
<-r.sem // Release semaphore slot
|
|
|
|
return err
|
|
}
|