Add ParseImagePath for chi wildcard and upstream fetcher with SSRF protection
This commit is contained in:
351
internal/imgcache/fetcher.go
Normal file
351
internal/imgcache/fetcher.go
Normal file
@@ -0,0 +1,351 @@
|
||||
package imgcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Fetcher configuration constants.
|
||||
const (
|
||||
DefaultFetchTimeout = 30 * time.Second
|
||||
DefaultMaxResponseSize = 50 << 20 // 50MB
|
||||
DefaultTLSTimeout = 10 * time.Second
|
||||
DefaultMaxIdleConns = 100
|
||||
DefaultIdleConnTimeout = 90 * time.Second
|
||||
DefaultMaxRedirects = 10
|
||||
)
|
||||
|
||||
// Fetcher errors.
|
||||
var (
|
||||
ErrSSRFBlocked = errors.New("request blocked: private or internal IP")
|
||||
ErrInvalidHost = errors.New("invalid or unresolvable host")
|
||||
ErrUnsupportedScheme = errors.New("only HTTPS is supported")
|
||||
ErrResponseTooLarge = errors.New("response exceeds maximum size")
|
||||
ErrInvalidContentType = errors.New("invalid or unsupported content type")
|
||||
ErrUpstreamError = errors.New("upstream server error")
|
||||
ErrUpstreamTimeout = errors.New("upstream request timeout")
|
||||
)
|
||||
|
||||
// FetcherConfig holds configuration for the upstream fetcher.
|
||||
type FetcherConfig struct {
|
||||
// Timeout for upstream requests
|
||||
Timeout time.Duration
|
||||
// MaxResponseSize is the maximum allowed response body size
|
||||
MaxResponseSize int64
|
||||
// UserAgent to send to upstream servers
|
||||
UserAgent string
|
||||
// AllowedContentTypes is a whitelist of MIME types to accept
|
||||
AllowedContentTypes []string
|
||||
// AllowHTTP allows non-TLS connections (for testing only)
|
||||
AllowHTTP bool
|
||||
}
|
||||
|
||||
// DefaultFetcherConfig returns sensible defaults.
|
||||
func DefaultFetcherConfig() *FetcherConfig {
|
||||
return &FetcherConfig{
|
||||
Timeout: DefaultFetchTimeout,
|
||||
MaxResponseSize: DefaultMaxResponseSize,
|
||||
UserAgent: "pixa/1.0",
|
||||
AllowedContentTypes: []string{
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/avif",
|
||||
"image/svg+xml",
|
||||
},
|
||||
AllowHTTP: false,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPFetcher implements the Fetcher interface with SSRF protection.
|
||||
type HTTPFetcher struct {
|
||||
client *http.Client
|
||||
config *FetcherConfig
|
||||
}
|
||||
|
||||
// NewHTTPFetcher creates a new fetcher with SSRF protection.
|
||||
func NewHTTPFetcher(config *FetcherConfig) *HTTPFetcher {
|
||||
if config == nil {
|
||||
config = DefaultFetcherConfig()
|
||||
}
|
||||
|
||||
// Create transport with SSRF-safe dialer
|
||||
transport := &http.Transport{
|
||||
DialContext: ssrfSafeDialer,
|
||||
TLSHandshakeTimeout: DefaultTLSTimeout,
|
||||
MaxIdleConns: DefaultMaxIdleConns,
|
||||
IdleConnTimeout: DefaultIdleConnTimeout,
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: config.Timeout,
|
||||
// Don't follow redirects automatically - we need to validate each hop
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= DefaultMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
// Validate the redirect target
|
||||
if err := validateURL(req.URL.String(), config.AllowHTTP); err != nil {
|
||||
return fmt.Errorf("redirect blocked: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return &HTTPFetcher{
|
||||
client: client,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch retrieves content from the given URL with SSRF protection.
|
||||
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
|
||||
// Validate URL before making request
|
||||
if err := validateURL(url, f.config.AllowHTTP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", f.config.UserAgent)
|
||||
req.Header.Set("Accept", strings.Join(f.config.AllowedContentTypes, ", "))
|
||||
|
||||
resp, err := f.client.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, ErrUpstreamTimeout
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
// Check status code
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
return nil, fmt.Errorf("%w: status %d", ErrUpstreamError, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Validate content type
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !f.isAllowedContentType(contentType) {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
|
||||
}
|
||||
|
||||
// Wrap body with size limiter
|
||||
limitedBody := &limitedReader{
|
||||
reader: resp.Body,
|
||||
remaining: f.config.MaxResponseSize,
|
||||
}
|
||||
|
||||
return &FetchResult{
|
||||
Content: &limitedReadCloser{limitedBody, resp.Body},
|
||||
ContentLength: resp.ContentLength,
|
||||
ContentType: contentType,
|
||||
Headers: resp.Header,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isAllowedContentType checks if the content type is in the whitelist.
|
||||
func (f *HTTPFetcher) isAllowedContentType(contentType string) bool {
|
||||
// Extract the MIME type without parameters
|
||||
mediaType := strings.TrimSpace(strings.Split(contentType, ";")[0])
|
||||
|
||||
for _, allowed := range f.config.AllowedContentTypes {
|
||||
if strings.EqualFold(mediaType, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateURL checks if a URL is safe to fetch (not internal/private).
|
||||
func validateURL(rawURL string, allowHTTP bool) error {
|
||||
if !allowHTTP && !strings.HasPrefix(rawURL, "https://") {
|
||||
return ErrUnsupportedScheme
|
||||
}
|
||||
|
||||
// Parse to extract host
|
||||
host := extractHost(rawURL)
|
||||
if host == "" {
|
||||
return ErrInvalidHost
|
||||
}
|
||||
|
||||
// Remove port if present
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
// Block obvious localhost patterns
|
||||
if isLocalhost(host) {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
|
||||
// Resolve the host to check IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrInvalidHost, host)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if isPrivateIP(ip) {
|
||||
return ErrSSRFBlocked
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractHost extracts the host from a URL string.
|
||||
func extractHost(rawURL string) string {
|
||||
// Simple extraction without full URL parsing
|
||||
url := rawURL
|
||||
if idx := strings.Index(url, "://"); idx != -1 {
|
||||
url = url[idx+3:]
|
||||
}
|
||||
if idx := strings.Index(url, "/"); idx != -1 {
|
||||
url = url[:idx]
|
||||
}
|
||||
if idx := strings.Index(url, "?"); idx != -1 {
|
||||
url = url[:idx]
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
|
||||
// isLocalhost checks if the host is localhost.
|
||||
func isLocalhost(host string) bool {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
return host == "localhost" ||
|
||||
host == "127.0.0.1" ||
|
||||
host == "::1" ||
|
||||
host == "[::1]" ||
|
||||
strings.HasSuffix(host, ".localhost") ||
|
||||
strings.HasSuffix(host, ".local")
|
||||
}
|
||||
|
||||
// isPrivateIP checks if an IP is private, loopback, or otherwise internal.
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for loopback
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for private ranges
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for link-local
|
||||
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unspecified (0.0.0.0 or ::)
|
||||
if ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for multicast
|
||||
if ip.IsMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Additional checks for IPv4
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// 169.254.0.0/16 - Link local
|
||||
if ip4[0] == 169 && ip4[1] == 254 {
|
||||
return true
|
||||
}
|
||||
// 0.0.0.0/8 - Current network
|
||||
if ip4[0] == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ssrfSafeDialer is a custom dialer that validates IP addresses before connecting.
|
||||
func ssrfSafeDialer(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Resolve the address
|
||||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host)
|
||||
}
|
||||
|
||||
// Check all resolved IPs
|
||||
for _, ip := range ips {
|
||||
if isPrivateIP(ip) {
|
||||
return nil, ErrSSRFBlocked
|
||||
}
|
||||
}
|
||||
|
||||
// Connect using the first valid IP
|
||||
var dialer net.Dialer
|
||||
for _, ip := range ips {
|
||||
addr := net.JoinHostPort(ip.String(), port)
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to connect to %s", host)
|
||||
}
|
||||
|
||||
// limitedReader wraps a reader and limits the number of bytes read.
|
||||
type limitedReader struct {
|
||||
reader io.Reader
|
||||
remaining int64
|
||||
}
|
||||
|
||||
func (r *limitedReader) Read(p []byte) (int, error) {
|
||||
if r.remaining <= 0 {
|
||||
return 0, ErrResponseTooLarge
|
||||
}
|
||||
|
||||
if int64(len(p)) > r.remaining {
|
||||
p = p[:r.remaining]
|
||||
}
|
||||
|
||||
n, err := r.reader.Read(p)
|
||||
r.remaining -= int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// limitedReadCloser combines the limited reader with the original closer.
|
||||
type limitedReadCloser struct {
|
||||
*limitedReader
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (r *limitedReadCloser) Close() error {
|
||||
return r.closer.Close()
|
||||
}
|
||||
@@ -38,12 +38,24 @@ type ParsedURL struct {
|
||||
Format ImageFormat
|
||||
}
|
||||
|
||||
// ParseImageURL parses a URL path like /v1/image/<host>/<path>/<size>.<format>
|
||||
// ParseImagePath parses the path captured by chi's wildcard: <host>/<path>/<size>.<format>
|
||||
// This is the primary entry point when using chi routing.
|
||||
// Examples:
|
||||
// - /v1/image/cdn.example.com/photos/cat.jpg/800x600.webp
|
||||
// - /v1/image/cdn.example.com/photos/cat.jpg/0x0.jpeg
|
||||
// - /v1/image/cdn.example.com/photos/cat.jpg/orig.png
|
||||
// - /v1/image/cdn.example.com/photos/cat.jpg?q=1/800x600.webp
|
||||
// - cdn.example.com/photos/cat.jpg/800x600.webp
|
||||
// - cdn.example.com/photos/cat.jpg/0x0.jpeg
|
||||
// - cdn.example.com/photos/cat.jpg/orig.png
|
||||
func ParseImagePath(path string) (*ParsedURL, error) {
|
||||
// Strip leading slash if present (chi may include it)
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
if path == "" {
|
||||
return nil, ErrMissingHost
|
||||
}
|
||||
|
||||
return parseImageComponents(path)
|
||||
}
|
||||
|
||||
// ParseImageURL parses a full URL path like /v1/image/<host>/<path>/<size>.<format>
|
||||
// Use ParseImagePath instead when working with chi's wildcard capture.
|
||||
func ParseImageURL(urlPath string) (*ParsedURL, error) {
|
||||
// Remove the /v1/image/ prefix
|
||||
const prefix = "/v1/image/"
|
||||
@@ -56,6 +68,11 @@ func ParseImageURL(urlPath string) (*ParsedURL, error) {
|
||||
return nil, ErrMissingHost
|
||||
}
|
||||
|
||||
return parseImageComponents(remainder)
|
||||
}
|
||||
|
||||
// parseImageComponents parses <host>/<path>/<size>.<format> structure.
|
||||
func parseImageComponents(remainder string) (*ParsedURL, error) {
|
||||
// Find the last path segment which contains size.format
|
||||
lastSlash := strings.LastIndex(remainder, "/")
|
||||
if lastSlash == -1 {
|
||||
|
||||
@@ -162,6 +162,63 @@ func TestParseImageURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImagePath(t *testing.T) {
|
||||
// ParseImagePath is for chi wildcard capture (no /v1/image/ prefix)
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want *ParsedURL
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "chi wildcard capture",
|
||||
input: "cdn.example.com/photos/cat.jpg/800x600.webp",
|
||||
want: &ParsedURL{
|
||||
Host: "cdn.example.com",
|
||||
Path: "/photos/cat.jpg",
|
||||
Size: Size{Width: 800, Height: 600},
|
||||
Format: FormatWebP,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with leading slash from chi",
|
||||
input: "/cdn.example.com/photos/cat.jpg/800x600.webp",
|
||||
want: &ParsedURL{
|
||||
Host: "cdn.example.com",
|
||||
Path: "/photos/cat.jpg",
|
||||
Size: Size{Width: 800, Height: 600},
|
||||
Format: FormatWebP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseImagePath(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseImagePath() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if got.Host != tt.want.Host {
|
||||
t.Errorf("Host = %q, want %q", got.Host, tt.want.Host)
|
||||
}
|
||||
if got.Path != tt.want.Path {
|
||||
t.Errorf("Path = %q, want %q", got.Path, tt.want.Path)
|
||||
}
|
||||
if got.Size != tt.want.Size {
|
||||
t.Errorf("Size = %v, want %v", got.Size, tt.want.Size)
|
||||
}
|
||||
if got.Format != tt.want.Format {
|
||||
t.Errorf("Format = %q, want %q", got.Format, tt.want.Format)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsedURL_ToImageRequest(t *testing.T) {
|
||||
parsed := &ParsedURL{
|
||||
Host: "cdn.example.com",
|
||||
|
||||
Reference in New Issue
Block a user