From 018c280267235f8b6da7fb2fb1234136d0e55446 Mon Sep 17 00:00:00 2001 From: sneak Date: Thu, 8 Jan 2026 02:59:48 -0800 Subject: [PATCH] Add ParseImagePath for chi wildcard and upstream fetcher with SSRF protection --- internal/imgcache/fetcher.go | 351 ++++++++++++++++++++++++++++ internal/imgcache/urlparser.go | 27 ++- internal/imgcache/urlparser_test.go | 57 +++++ 3 files changed, 430 insertions(+), 5 deletions(-) create mode 100644 internal/imgcache/fetcher.go diff --git a/internal/imgcache/fetcher.go b/internal/imgcache/fetcher.go new file mode 100644 index 0000000..cfc1e7d --- /dev/null +++ b/internal/imgcache/fetcher.go @@ -0,0 +1,351 @@ +package imgcache + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" +) + +// Fetcher configuration constants. +const ( + DefaultFetchTimeout = 30 * time.Second + DefaultMaxResponseSize = 50 << 20 // 50MB + DefaultTLSTimeout = 10 * time.Second + DefaultMaxIdleConns = 100 + DefaultIdleConnTimeout = 90 * time.Second + DefaultMaxRedirects = 10 +) + +// Fetcher errors. +var ( + ErrSSRFBlocked = errors.New("request blocked: private or internal IP") + ErrInvalidHost = errors.New("invalid or unresolvable host") + ErrUnsupportedScheme = errors.New("only HTTPS is supported") + ErrResponseTooLarge = errors.New("response exceeds maximum size") + ErrInvalidContentType = errors.New("invalid or unsupported content type") + ErrUpstreamError = errors.New("upstream server error") + ErrUpstreamTimeout = errors.New("upstream request timeout") +) + +// FetcherConfig holds configuration for the upstream fetcher. +type FetcherConfig struct { + // Timeout for upstream requests + Timeout time.Duration + // MaxResponseSize is the maximum allowed response body size + MaxResponseSize int64 + // UserAgent to send to upstream servers + UserAgent string + // AllowedContentTypes is a whitelist of MIME types to accept + AllowedContentTypes []string + // AllowHTTP allows non-TLS connections (for testing only) + AllowHTTP bool +} + +// DefaultFetcherConfig returns sensible defaults. +func DefaultFetcherConfig() *FetcherConfig { + return &FetcherConfig{ + Timeout: DefaultFetchTimeout, + MaxResponseSize: DefaultMaxResponseSize, + UserAgent: "pixa/1.0", + AllowedContentTypes: []string{ + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + "image/avif", + "image/svg+xml", + }, + AllowHTTP: false, + } +} + +// HTTPFetcher implements the Fetcher interface with SSRF protection. +type HTTPFetcher struct { + client *http.Client + config *FetcherConfig +} + +// NewHTTPFetcher creates a new fetcher with SSRF protection. +func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher { + if config == nil { + config = DefaultFetcherConfig() + } + + // Create transport with SSRF-safe dialer + transport := &http.Transport{ + DialContext: ssrfSafeDialer, + TLSHandshakeTimeout: DefaultTLSTimeout, + MaxIdleConns: DefaultMaxIdleConns, + IdleConnTimeout: DefaultIdleConnTimeout, + } + + client := &http.Client{ + Transport: transport, + Timeout: config.Timeout, + // Don't follow redirects automatically - we need to validate each hop + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= DefaultMaxRedirects { + return errors.New("too many redirects") + } + // Validate the redirect target + if err := validateURL(req.URL.String(), config.AllowHTTP); err != nil { + return fmt.Errorf("redirect blocked: %w", err) + } + + return nil + }, + } + + return &HTTPFetcher{ + client: client, + config: config, + } +} + +// Fetch retrieves content from the given URL with SSRF protection. +func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) { + // Validate URL before making request + if err := validateURL(url, f.config.AllowHTTP); err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("User-Agent", f.config.UserAgent) + req.Header.Set("Accept", strings.Join(f.config.AllowedContentTypes, ", ")) + + resp, err := f.client.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil, ErrUpstreamTimeout + } + + return nil, fmt.Errorf("upstream request failed: %w", err) + } + + // Check status code + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + _ = resp.Body.Close() + + return nil, fmt.Errorf("%w: status %d", ErrUpstreamError, resp.StatusCode) + } + + // Validate content type + contentType := resp.Header.Get("Content-Type") + if !f.isAllowedContentType(contentType) { + _ = resp.Body.Close() + + return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType) + } + + // Wrap body with size limiter + limitedBody := &limitedReader{ + reader: resp.Body, + remaining: f.config.MaxResponseSize, + } + + return &FetchResult{ + Content: &limitedReadCloser{limitedBody, resp.Body}, + ContentLength: resp.ContentLength, + ContentType: contentType, + Headers: resp.Header, + }, nil +} + +// isAllowedContentType checks if the content type is in the whitelist. +func (f *HTTPFetcher) isAllowedContentType(contentType string) bool { + // Extract the MIME type without parameters + mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0]) + + for _, allowed := range f.config.AllowedContentTypes { + if strings.EqualFold(mediaType, allowed) { + return true + } + } + + return false +} + +// validateURL checks if a URL is safe to fetch (not internal/private). +func validateURL(rawURL string, allowHTTP bool) error { + if !allowHTTP && !strings.HasPrefix(rawURL, "https://") { + return ErrUnsupportedScheme + } + + // Parse to extract host + host := extractHost(rawURL) + if host == "" { + return ErrInvalidHost + } + + // Remove port if present + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + + // Block obvious localhost patterns + if isLocalhost(host) { + return ErrSSRFBlocked + } + + // Resolve the host to check IP addresses + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("%w: %s", ErrInvalidHost, host) + } + + for _, ip := range ips { + if isPrivateIP(ip) { + return ErrSSRFBlocked + } + } + + return nil +} + +// extractHost extracts the host from a URL string. +func extractHost(rawURL string) string { + // Simple extraction without full URL parsing + url := rawURL + if idx := strings.Index(url, "://"); idx != -1 { + url = url[idx+3:] + } + if idx := strings.Index(url, "/"); idx != -1 { + url = url[:idx] + } + if idx := strings.Index(url, "?"); idx != -1 { + url = url[:idx] + } + + return url +} + +// isLocalhost checks if the host is localhost. +func isLocalhost(host string) bool { + host = strings.ToLower(host) + + return host == "localhost" || + host == "127.0.0.1" || + host == "::1" || + host == "[::1]" || + strings.HasSuffix(host, ".localhost") || + strings.HasSuffix(host, ".local") +} + +// isPrivateIP checks if an IP is private, loopback, or otherwise internal. +func isPrivateIP(ip net.IP) bool { + if ip == nil { + return true + } + + // Check for loopback + if ip.IsLoopback() { + return true + } + + // Check for private ranges + if ip.IsPrivate() { + return true + } + + // Check for link-local + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + // Check for unspecified (0.0.0.0 or ::) + if ip.IsUnspecified() { + return true + } + + // Check for multicast + if ip.IsMulticast() { + return true + } + + // Additional checks for IPv4 + if ip4 := ip.To4(); ip4 != nil { + // 169.254.0.0/16 - Link local + if ip4[0] == 169 && ip4[1] == 254 { + return true + } + // 0.0.0.0/8 - Current network + if ip4[0] == 0 { + return true + } + } + + return false +} + +// ssrfSafeDialer is a custom dialer that validates IP addresses before connecting. +func ssrfSafeDialer(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + // Resolve the address + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host) + } + + // Check all resolved IPs + for _, ip := range ips { + if isPrivateIP(ip) { + return nil, ErrSSRFBlocked + } + } + + // Connect using the first valid IP + var dialer net.Dialer + for _, ip := range ips { + addr := net.JoinHostPort(ip.String(), port) + conn, err := dialer.DialContext(ctx, network, addr) + if err == nil { + return conn, nil + } + } + + return nil, fmt.Errorf("failed to connect to %s", host) +} + +// limitedReader wraps a reader and limits the number of bytes read. +type limitedReader struct { + reader io.Reader + remaining int64 +} + +func (r *limitedReader) Read(p []byte) (int, error) { + if r.remaining <= 0 { + return 0, ErrResponseTooLarge + } + + if int64(len(p)) > r.remaining { + p = p[:r.remaining] + } + + n, err := r.reader.Read(p) + r.remaining -= int64(n) + + return n, err +} + +// limitedReadCloser combines the limited reader with the original closer. +type limitedReadCloser struct { + *limitedReader + closer io.Closer +} + +func (r *limitedReadCloser) Close() error { + return r.closer.Close() +} diff --git a/internal/imgcache/urlparser.go b/internal/imgcache/urlparser.go index 7bf70b9..30c9356 100644 --- a/internal/imgcache/urlparser.go +++ b/internal/imgcache/urlparser.go @@ -38,12 +38,24 @@ type ParsedURL struct { Format ImageFormat } -// ParseImageURL parses a URL path like /v1/image///. +// ParseImagePath parses the path captured by chi's wildcard: //. +// This is the primary entry point when using chi routing. // Examples: -// - /v1/image/cdn.example.com/photos/cat.jpg/800x600.webp -// - /v1/image/cdn.example.com/photos/cat.jpg/0x0.jpeg -// - /v1/image/cdn.example.com/photos/cat.jpg/orig.png -// - /v1/image/cdn.example.com/photos/cat.jpg?q=1/800x600.webp +// - cdn.example.com/photos/cat.jpg/800x600.webp +// - cdn.example.com/photos/cat.jpg/0x0.jpeg +// - cdn.example.com/photos/cat.jpg/orig.png +func ParseImagePath(path string) (*ParsedURL, error) { + // Strip leading slash if present (chi may include it) + path = strings.TrimPrefix(path, "/") + if path == "" { + return nil, ErrMissingHost + } + + return parseImageComponents(path) +} + +// ParseImageURL parses a full URL path like /v1/image///. +// Use ParseImagePath instead when working with chi's wildcard capture. func ParseImageURL(urlPath string) (*ParsedURL, error) { // Remove the /v1/image/ prefix const prefix = "/v1/image/" @@ -56,6 +68,11 @@ func ParseImageURL(urlPath string) (*ParsedURL, error) { return nil, ErrMissingHost } + return parseImageComponents(remainder) +} + +// parseImageComponents parses //. structure. +func parseImageComponents(remainder string) (*ParsedURL, error) { // Find the last path segment which contains size.format lastSlash := strings.LastIndex(remainder, "/") if lastSlash == -1 { diff --git a/internal/imgcache/urlparser_test.go b/internal/imgcache/urlparser_test.go index 136d277..a7cf141 100644 --- a/internal/imgcache/urlparser_test.go +++ b/internal/imgcache/urlparser_test.go @@ -162,6 +162,63 @@ func TestParseImageURL(t *testing.T) { } } +func TestParseImagePath(t *testing.T) { + // ParseImagePath is for chi wildcard capture (no /v1/image/ prefix) + tests := []struct { + name string + input string + want *ParsedURL + wantErr bool + }{ + { + name: "chi wildcard capture", + input: "cdn.example.com/photos/cat.jpg/800x600.webp", + want: &ParsedURL{ + Host: "cdn.example.com", + Path: "/photos/cat.jpg", + Size: Size{Width: 800, Height: 600}, + Format: FormatWebP, + }, + }, + { + name: "with leading slash from chi", + input: "/cdn.example.com/photos/cat.jpg/800x600.webp", + want: &ParsedURL{ + Host: "cdn.example.com", + Path: "/photos/cat.jpg", + Size: Size{Width: 800, Height: 600}, + Format: FormatWebP, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseImagePath(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseImagePath() error = %v, wantErr %v", err, tt.wantErr) + + return + } + if err != nil { + return + } + if got.Host != tt.want.Host { + t.Errorf("Host = %q, want %q", got.Host, tt.want.Host) + } + if got.Path != tt.want.Path { + t.Errorf("Path = %q, want %q", got.Path, tt.want.Path) + } + if got.Size != tt.want.Size { + t.Errorf("Size = %v, want %v", got.Size, tt.want.Size) + } + if got.Format != tt.want.Format { + t.Errorf("Format = %q, want %q", got.Format, tt.want.Format) + } + }) + } +} + func TestParsedURL_ToImageRequest(t *testing.T) { parsed := &ParsedURL{ Host: "cdn.example.com",