Add ParseImagePath for chi wildcard and upstream fetcher with SSRF protection

This commit is contained in:
2026-01-08 02:59:48 -08:00
parent c69ddf6f61
commit 018c280267
3 changed files with 430 additions and 5 deletions

View File

@@ -0,0 +1,351 @@
package imgcache
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
)
// Default tuning values used by the upstream fetcher and its HTTP transport.
const (
	// DefaultFetchTimeout is the overall request timeout placed in
	// DefaultFetcherConfig.
	DefaultFetchTimeout = 30 * time.Second

	// DefaultMaxResponseSize is the default response body cap: 50 MiB.
	DefaultMaxResponseSize = 50 * 1024 * 1024

	// DefaultTLSTimeout bounds the TLS handshake on the fetcher's transport.
	DefaultTLSTimeout = 10 * time.Second

	// DefaultMaxIdleConns is the transport's idle connection pool size.
	DefaultMaxIdleConns = 100

	// DefaultIdleConnTimeout is how long an idle connection is kept around.
	DefaultIdleConnTimeout = 90 * time.Second

	// DefaultMaxRedirects is the number of redirect hops after which the
	// fetcher gives up.
	DefaultMaxRedirects = 10
)
// Sentinel errors reported by the fetcher. Callers should match them with
// errors.Is, since several are returned wrapped with extra detail.
var (
	// ErrSSRFBlocked is returned when the target is localhost or resolves to
	// an address that isPrivateIP rejects.
	ErrSSRFBlocked = errors.New("request blocked: private or internal IP")

	// ErrInvalidHost is returned when no host can be extracted from the URL
	// or the host does not resolve.
	ErrInvalidHost = errors.New("invalid or unresolvable host")

	// ErrUnsupportedScheme is returned for non-HTTPS URLs unless AllowHTTP
	// is set.
	ErrUnsupportedScheme = errors.New("only HTTPS is supported")

	// ErrResponseTooLarge is returned while reading a body that exceeds the
	// configured maximum size.
	ErrResponseTooLarge = errors.New("response exceeds maximum size")

	// ErrInvalidContentType is returned when the upstream Content-Type is
	// not in the configured allow-list.
	ErrInvalidContentType = errors.New("invalid or unsupported content type")

	// ErrUpstreamError is returned for any non-2xx upstream status.
	ErrUpstreamError = errors.New("upstream server error")

	// ErrUpstreamTimeout is returned when the upstream request hits its
	// deadline.
	ErrUpstreamTimeout = errors.New("upstream request timeout")
)
// FetcherConfig carries the tunable settings of the upstream fetcher.
type FetcherConfig struct {
	// Timeout bounds a whole upstream request; NewHTTPFetcher installs it as
	// the http.Client timeout.
	Timeout time.Duration

	// MaxResponseSize caps, in bytes, how much of a response body may be
	// read.
	MaxResponseSize int64

	// UserAgent is sent as the User-Agent header on upstream requests.
	UserAgent string

	// AllowedContentTypes is the allow-list of MIME types that are accepted;
	// it is also joined into the Accept request header.
	AllowedContentTypes []string

	// AllowHTTP permits non-TLS URLs (intended for testing only).
	AllowHTTP bool
}
// DefaultFetcherConfig builds a fresh configuration populated with the
// package defaults: HTTPS only, the default timeout and size cap, the
// "pixa/1.0" user agent, and an allow-list of common image MIME types.
// A new value is allocated on every call, so callers may modify the result.
func DefaultFetcherConfig() *FetcherConfig {
	imageTypes := []string{
		"image/jpeg",
		"image/png",
		"image/gif",
		"image/webp",
		"image/avif",
		"image/svg+xml",
	}

	cfg := new(FetcherConfig)
	cfg.Timeout = DefaultFetchTimeout
	cfg.MaxResponseSize = DefaultMaxResponseSize
	cfg.UserAgent = "pixa/1.0"
	cfg.AllowedContentTypes = imageTypes
	cfg.AllowHTTP = false // plain HTTP stays off unless a caller opts in
	return cfg
}
// HTTPFetcher implements the Fetcher interface with SSRF protection.
// Instances are built by NewHTTPFetcher, which wires the client to the
// SSRF-checking dialer and a redirect validator.
type HTTPFetcher struct {
// client performs the upstream HTTP requests.
client *http.Client
// config holds the limits and policy (size cap, content types, AllowHTTP)
// consulted on every Fetch.
config *FetcherConfig
}
// NewHTTPFetcher builds an HTTPFetcher whose connections all go through
// ssrfSafeDialer. A nil config selects DefaultFetcherConfig.
//
// Redirects are followed, but only after each hop has been re-validated with
// validateURL, and at most DefaultMaxRedirects hops are allowed.
func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
	if config == nil {
		config = DefaultFetcherConfig()
	}

	// Called by the client before following each redirect; a non-nil error
	// aborts the redirect chain.
	checkRedirect := func(req *http.Request, via []*http.Request) error {
		if len(via) >= DefaultMaxRedirects {
			return errors.New("too many redirects")
		}
		// The redirect target is attacker-influenced, so it gets the same
		// scrutiny as the original URL.
		err := validateURL(req.URL.String(), config.AllowHTTP)
		if err != nil {
			return fmt.Errorf("redirect blocked: %w", err)
		}
		return nil
	}

	fetcher := &HTTPFetcher{config: config}
	fetcher.client = &http.Client{
		Timeout:       config.Timeout,
		CheckRedirect: checkRedirect,
		Transport: &http.Transport{
			// Every connection is resolved and vetted by the SSRF-safe dialer.
			DialContext:         ssrfSafeDialer,
			TLSHandshakeTimeout: DefaultTLSTimeout,
			MaxIdleConns:        DefaultMaxIdleConns,
			IdleConnTimeout:     DefaultIdleConnTimeout,
		},
	}
	return fetcher
}
// Fetch retrieves content from the given URL with SSRF protection.
//
// The URL is validated before any request is made. On success the returned
// FetchResult streams the response body through a size limiter; the caller
// owns Content and must close it.
//
// Errors:
//   - whatever validateURL reports for a rejected URL
//   - ErrUpstreamTimeout when the request hits its deadline
//   - ErrUpstreamError (wrapped, with the status code) for non-2xx responses
//   - ErrInvalidContentType (wrapped) when the Content-Type is not allowed
//   - ErrResponseTooLarge (wrapped) when the declared Content-Length already
//     exceeds MaxResponseSize
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
	// Validate URL before making request.
	if err := validateURL(url, f.config.AllowHTTP); err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("User-Agent", f.config.UserAgent)
	req.Header.Set("Accept", strings.Join(f.config.AllowedContentTypes, ", "))

	resp, err := f.client.Do(req)
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return nil, ErrUpstreamTimeout
		}
		return nil, fmt.Errorf("upstream request failed: %w", err)
	}

	// From here on every early return must close the body.
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		_ = resp.Body.Close()
		return nil, fmt.Errorf("%w: status %d", ErrUpstreamError, resp.StatusCode)
	}

	contentType := resp.Header.Get("Content-Type")
	if !f.isAllowedContentType(contentType) {
		_ = resp.Body.Close()
		return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
	}

	// If the server declares a length that is already over the cap, fail now
	// instead of letting the caller stream up to the limit first.
	// ContentLength is -1 when unknown, which never trips this check for a
	// non-negative limit; the limiter below still guards that case.
	if resp.ContentLength > f.config.MaxResponseSize {
		_ = resp.Body.Close()
		return nil, fmt.Errorf("%w: declared length %d", ErrResponseTooLarge, resp.ContentLength)
	}

	// Wrap body with size limiter; Close still goes to the real body.
	limitedBody := &limitedReader{
		reader:    resp.Body,
		remaining: f.config.MaxResponseSize,
	}
	return &FetchResult{
		Content:       &limitedReadCloser{limitedBody, resp.Body},
		ContentLength: resp.ContentLength,
		ContentType:   contentType,
		Headers:       resp.Header,
	}, nil
}
// isAllowedContentType reports whether the media type of a Content-Type
// header value appears in the configured allow-list. Parameters such as
// "; charset=utf-8" are ignored and the comparison is case-insensitive.
func (f *HTTPFetcher) isAllowedContentType(contentType string) bool {
	// Keep only the part before the first ';'.
	mediaType := contentType
	if semi := strings.IndexByte(contentType, ';'); semi >= 0 {
		mediaType = contentType[:semi]
	}
	mediaType = strings.TrimSpace(mediaType)

	allowList := f.config.AllowedContentTypes
	for i := range allowList {
		if strings.EqualFold(allowList[i], mediaType) {
			return true
		}
	}
	return false
}
// validateURL checks if a URL is safe to fetch (not internal/private).
//
// It returns ErrUnsupportedScheme for non-"https://" URLs unless allowHTTP is
// set, ErrInvalidHost when no host can be extracted or resolved, and
// ErrSSRFBlocked when the host is a localhost name or any of its resolved
// addresses is rejected by isPrivateIP.
//
// This is a pre-flight check only: it performs its own DNS lookup, separate
// from the one made at connect time.
func validateURL(rawURL string, allowHTTP bool) error {
	if !allowHTTP && !strings.HasPrefix(rawURL, "https://") {
		return ErrUnsupportedScheme
	}

	host := extractHost(rawURL)
	if host == "" {
		return ErrInvalidHost
	}

	// Reduce "host:port" to the bare host. SplitHostPort fails when there is
	// no port, which would leave a bracketed IPv6 literal such as
	// "[2001:db8::1]" with its brackets on and make it unresolvable, so
	// strip them explicitly in that case.
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h
	} else if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	if host == "" {
		return ErrInvalidHost
	}

	// Block obvious localhost patterns before touching DNS.
	if isLocalhost(host) {
		return ErrSSRFBlocked
	}

	// Resolve the host and reject it if ANY address is internal.
	ips, err := net.LookupIP(host)
	if err != nil {
		return fmt.Errorf("%w: %s", ErrInvalidHost, host)
	}
	for _, ip := range ips {
		if isPrivateIP(ip) {
			return ErrSSRFBlocked
		}
	}
	return nil
}
// extractHost extracts the host (including any ":port") from a URL string
// using plain string operations rather than full URL parsing.
//
// Everything up to and including the first "://" is dropped, the remainder
// is cut at the first '/', '?' or '#', and any "userinfo@" prefix is removed
// so the result is the host the request would actually be sent to. It
// returns "" when nothing is left.
func extractHost(rawURL string) string {
	rest := rawURL
	if idx := strings.Index(rest, "://"); idx != -1 {
		rest = rest[idx+3:]
	}

	// The authority ends at the first path, query or fragment delimiter.
	if end := strings.IndexAny(rest, "/?#"); end != -1 {
		rest = rest[:end]
	}

	// Drop "user:password@". The last '@' is the separator, so a URL like
	// "https://public.example@127.0.0.1/" yields the real target host.
	if at := strings.LastIndex(rest, "@"); at != -1 {
		rest = rest[at+1:]
	}
	return rest
}
// isLocalhost reports whether host is an obvious name or literal for the
// local machine: "localhost", "127.0.0.1", "::1" (with or without
// brackets), or any name under ".localhost" or ".local".
//
// Matching is case-insensitive and ignores a single trailing dot, so the
// fully-qualified forms "localhost." and "printer.local." are caught too.
// This is a cheap textual pre-filter; other loopback addresses are left to
// the IP-level check.
func isLocalhost(host string) bool {
	host = strings.TrimSuffix(strings.ToLower(host), ".")

	switch host {
	case "localhost", "127.0.0.1", "::1", "[::1]":
		return true
	}
	return strings.HasSuffix(host, ".localhost") ||
		strings.HasSuffix(host, ".local")
}
// isPrivateIP reports whether ip must not be contacted because it is
// private, loopback, or otherwise not a normal public unicast address.
//
// It returns true for:
//   - nil or malformed (neither 4 nor 16 byte) addresses
//   - loopback, RFC 1918 / unique-local, unspecified, link-local and
//     multicast addresses
//   - IPv4 0.0.0.0/8 ("this network")
//   - IPv4 100.64.0.0/10 (carrier-grade NAT shared space, RFC 6598)
//   - IPv4 192.0.0.0/24 (IETF protocol assignments)
//   - IPv4 198.18.0.0/15 (benchmarking)
//   - IPv4 240.0.0.0/4 (reserved, including limited broadcast)
//
// IPv4-mapped IPv6 addresses are classified by their embedded IPv4 address.
func isPrivateIP(ip net.IP) bool {
	// Fail closed on anything that is not a well-formed address.
	if ip.To16() == nil {
		return true
	}

	// IsMulticast already covers link-local multicast, and
	// IsLinkLocalUnicast already covers 169.254.0.0/16.
	if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() || ip.IsMulticast() {
		return true
	}

	// Remaining checks are IPv4-only special-purpose ranges.
	ip4 := ip.To4()
	if ip4 == nil {
		return false
	}
	switch {
	case ip4[0] == 0: // 0.0.0.0/8
		return true
	case ip4[0] == 100 && ip4[1]&0xc0 == 64: // 100.64.0.0/10
		return true
	case ip4[0] == 192 && ip4[1] == 0 && ip4[2] == 0: // 192.0.0.0/24
		return true
	case ip4[0] == 198 && ip4[1]&0xfe == 18: // 198.18.0.0/15
		return true
	case ip4[0] >= 240: // 240.0.0.0/4, includes 255.255.255.255
		return true
	}
	return false
}
// ssrfSafeDialer is a DialContext implementation that validates IP addresses
// before connecting.
//
// It resolves the host itself, refuses the connection with ErrSSRFBlocked if
// ANY resolved address is rejected by isPrivateIP, and then dials the
// resolved IP literals directly, so the address connected to is always one
// that was checked. Addresses are tried in order and the first successful
// connection is returned.
//
// If every attempt fails, the last dial error is wrapped into the result so
// callers can still detect causes such as context.DeadlineExceeded with
// errors.Is.
func ssrfSafeDialer(ctx context.Context, network, addr string) (net.Conn, error) {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}

	// Resolve the address.
	ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
	if err != nil {
		return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host)
	}

	// Check all resolved IPs up front: one internal record poisons the set.
	for _, ip := range ips {
		if isPrivateIP(ip) {
			return nil, ErrSSRFBlocked
		}
	}

	// Connect to the vetted IPs, never to the hostname.
	var (
		dialer  net.Dialer
		lastErr error
	)
	for _, ip := range ips {
		conn, dialErr := dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
		if dialErr == nil {
			return conn, nil
		}
		lastErr = dialErr
	}

	if lastErr == nil {
		// The resolver returned no addresses at all.
		return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host)
	}
	return nil, fmt.Errorf("failed to connect to %s: %w", host, lastErr)
}
// limitedReader wraps a reader and limits the number of bytes read.
//
// Up to `remaining` bytes are passed through. Once the budget is used up,
// the reader distinguishes a stream that ended exactly at the limit (the
// underlying terminal error, normally io.EOF, is returned) from one that has
// more data (ErrResponseTooLarge is returned).
type limitedReader struct {
	reader    io.Reader
	remaining int64
	// exceeded latches the overflow result so every Read after the first
	// failure keeps returning ErrResponseTooLarge.
	exceeded bool
}

// Read implements io.Reader. It never delivers more than the configured
// number of bytes in total.
func (r *limitedReader) Read(p []byte) (int, error) {
	if r.exceeded {
		return 0, ErrResponseTooLarge
	}
	if len(p) == 0 {
		return 0, nil
	}

	if r.remaining <= 0 {
		// Budget spent. Probe a single byte: if the stream is already
		// finished, report that (io.EOF or a transport error) rather than
		// a spurious size violation for a body of exactly the limit.
		var probe [1]byte
		if _, err := io.ReadFull(r.reader, probe[:]); err != nil {
			return 0, err
		}
		r.exceeded = true
		return 0, ErrResponseTooLarge
	}

	if int64(len(p)) > r.remaining {
		p = p[:r.remaining]
	}
	n, err := r.reader.Read(p)
	r.remaining -= int64(n)
	return n, err
}
// limitedReadCloser combines the limited reader with the original closer.
// Reads go through the embedded limitedReader, so the size cap applies,
// while Close is forwarded to the closer stored alongside it.
type limitedReadCloser struct {
// Embedded so that its Read method is promoted onto limitedReadCloser.
*limitedReader
// closer is what Close delegates to; Fetch passes the response body here.
closer io.Closer
}
// Close closes the wrapped closer and returns its error unchanged.
func (r *limitedReadCloser) Close() error {
return r.closer.Close()
}

View File

@@ -38,12 +38,24 @@ type ParsedURL struct {
Format ImageFormat
}
// ParseImageURL parses a URL path like /v1/image/<host>/<path>/<size>.<format>
// ParseImagePath parses the path captured by chi's wildcard: <host>/<path>/<size>.<format>
// This is the primary entry point when using chi routing.
// Examples:
// - /v1/image/cdn.example.com/photos/cat.jpg/800x600.webp
// - /v1/image/cdn.example.com/photos/cat.jpg/0x0.jpeg
// - /v1/image/cdn.example.com/photos/cat.jpg/orig.png
// - /v1/image/cdn.example.com/photos/cat.jpg?q=1/800x600.webp
// - cdn.example.com/photos/cat.jpg/800x600.webp
// - cdn.example.com/photos/cat.jpg/0x0.jpeg
// - cdn.example.com/photos/cat.jpg/orig.png
func ParseImagePath(path string) (*ParsedURL, error) {
	// chi's wildcard capture may or may not carry the leading slash, so
	// drop at most one.
	if strings.HasPrefix(path, "/") {
		path = path[1:]
	}

	// Nothing left means there is not even a host component.
	if len(path) == 0 {
		return nil, ErrMissingHost
	}

	// The remainder has the <host>/<path>/<size>.<format> shape shared with
	// ParseImageURL.
	return parseImageComponents(path)
}
// ParseImageURL parses a full URL path like /v1/image/<host>/<path>/<size>.<format>
// Use ParseImagePath instead when working with chi's wildcard capture.
func ParseImageURL(urlPath string) (*ParsedURL, error) {
// Remove the /v1/image/ prefix
const prefix = "/v1/image/"
@@ -56,6 +68,11 @@ func ParseImageURL(urlPath string) (*ParsedURL, error) {
return nil, ErrMissingHost
}
return parseImageComponents(remainder)
}
// parseImageComponents parses <host>/<path>/<size>.<format> structure.
func parseImageComponents(remainder string) (*ParsedURL, error) {
// Find the last path segment which contains size.format
lastSlash := strings.LastIndex(remainder, "/")
if lastSlash == -1 {

View File

@@ -162,6 +162,63 @@ func TestParseImageURL(t *testing.T) {
}
}
// TestParseImagePath covers ParseImagePath, the entry point for paths
// captured by chi's wildcard (no /v1/image/ prefix). It checks both the
// success path, with and without a leading slash, and the empty-input
// failures that ParseImagePath rejects before any component parsing.
func TestParseImagePath(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		want    *ParsedURL
		wantErr bool
	}{
		{
			name:  "chi wildcard capture",
			input: "cdn.example.com/photos/cat.jpg/800x600.webp",
			want: &ParsedURL{
				Host:   "cdn.example.com",
				Path:   "/photos/cat.jpg",
				Size:   Size{Width: 800, Height: 600},
				Format: FormatWebP,
			},
		},
		{
			name:  "with leading slash from chi",
			input: "/cdn.example.com/photos/cat.jpg/800x600.webp",
			want: &ParsedURL{
				Host:   "cdn.example.com",
				Path:   "/photos/cat.jpg",
				Size:   Size{Width: 800, Height: 600},
				Format: FormatWebP,
			},
		},
		{
			// Nothing captured at all.
			name:    "empty path",
			input:   "",
			wantErr: true,
		},
		{
			// Only the leading slash: empty once it is stripped.
			name:    "only leading slash",
			input:   "/",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseImagePath(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseImagePath() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				return
			}
			// Guard the field comparisons below against a nil result or a
			// table entry that forgot its expectation.
			if got == nil || tt.want == nil {
				t.Fatalf("ParseImagePath() = %v, want %v", got, tt.want)
			}
			if got.Host != tt.want.Host {
				t.Errorf("Host = %q, want %q", got.Host, tt.want.Host)
			}
			if got.Path != tt.want.Path {
				t.Errorf("Path = %q, want %q", got.Path, tt.want.Path)
			}
			if got.Size != tt.want.Size {
				t.Errorf("Size = %v, want %v", got.Size, tt.want.Size)
			}
			if got.Format != tt.want.Format {
				t.Errorf("Format = %q, want %q", got.Format, tt.want.Format)
			}
		})
	}
}
func TestParsedURL_ToImageRequest(t *testing.T) {
parsed := &ParsedURL{
Host: "cdn.example.com",