refactor: extract httpfetcher package from imgcache
All checks were successful
check / check (push) Successful in 57s

Move HTTPFetcher, Config (was FetcherConfig), SSRF-safe dialer, rate
limiting, content-type validation, and related error vars from
internal/imgcache/fetcher.go into new internal/httpfetcher/ package.

The Fetcher interface and FetchResult type also move to httpfetcher
to avoid circular imports (imgcache imports httpfetcher, not the other
way around).

Renames to avoid stuttering:
  NewHTTPFetcher -> httpfetcher.New
  FetcherConfig  -> httpfetcher.Config
  NewMockFetcher -> httpfetcher.NewMock

The ServiceConfig.FetcherConfig field is retained (it describes what
kind of config it holds, not a stutter).

Pure refactor - no behavior changes. Unit tests for the httpfetcher
package are included.

refs #39
This commit is contained in:
clawbot
2026-04-17 06:47:05 +00:00
parent 6b4a1d7607
commit a853fe7ee7
12 changed files with 414 additions and 74 deletions

View File

@@ -0,0 +1,477 @@
// Package httpfetcher fetches content from upstream HTTP origins with SSRF
// protection, per-host connection limits, and content-type validation.
package httpfetcher
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
neturl "net/url"
"strings"
"sync"
"time"
)
// Fetcher configuration constants. These are the defaults used by
// DefaultConfig and by the HTTP transport built in New.
const (
	DefaultFetchTimeout          = 30 * time.Second // total per-request timeout (http.Client.Timeout)
	DefaultMaxResponseSize       = 50 << 20         // 50MB cap on response body size
	DefaultTLSTimeout            = 10 * time.Second // TLS handshake timeout
	DefaultMaxIdleConns          = 100              // idle connections kept across all hosts
	DefaultIdleConnTimeout       = 90 * time.Second // how long an idle connection stays pooled
	DefaultMaxRedirects          = 10               // redirect hops allowed before aborting
	DefaultMaxConnectionsPerHost = 20               // concurrent fetches allowed per upstream host
)
// Fetcher errors. These are sentinel values: callers match them with
// errors.Is (some are returned wrapped with extra context), so their
// identities must remain stable.
var (
	ErrSSRFBlocked        = errors.New("request blocked: private or internal IP")
	ErrInvalidHost        = errors.New("invalid or unresolvable host")
	ErrUnsupportedScheme  = errors.New("only HTTPS is supported")
	ErrResponseTooLarge   = errors.New("response exceeds maximum size")
	ErrInvalidContentType = errors.New("invalid or unsupported content type")
	ErrUpstreamError      = errors.New("upstream server error")
	ErrUpstreamTimeout    = errors.New("upstream request timeout")
)
// Fetcher retrieves content from upstream origins.
type Fetcher interface {
	// Fetch retrieves content from the given URL. On success the caller
	// owns result.Content and must Close it (for the HTTP implementation,
	// closing also releases a per-host connection slot).
	Fetch(ctx context.Context, url string) (*FetchResult, error)
}
// FetchResult contains the result of fetching from upstream.
type FetchResult struct {
	// Content is the raw image data. The caller must Close it; with
	// HTTPFetcher this also releases the per-host semaphore slot.
	Content io.ReadCloser
	// ContentLength is the size in bytes (-1 if unknown).
	ContentLength int64
	// ContentType is the MIME type from upstream.
	ContentType string
	// Headers contains all response headers from upstream.
	Headers map[string][]string
	// StatusCode is the HTTP status code from upstream.
	StatusCode int
	// FetchDurationMs is how long the fetch took in milliseconds.
	FetchDurationMs int64
	// RemoteAddr is the IP:port of the upstream server.
	RemoteAddr string
	// HTTPVersion is the protocol version (e.g., "1.1", "2.0").
	HTTPVersion string
	// TLSVersion is the TLS protocol version (e.g., "TLS 1.3");
	// empty for plain-HTTP responses.
	TLSVersion string
	// TLSCipherSuite is the negotiated cipher suite name; empty for
	// plain-HTTP responses.
	TLSCipherSuite string
}
// Config holds configuration for the upstream fetcher. Passing nil to
// New is equivalent to passing DefaultConfig().
type Config struct {
	// Timeout for upstream requests.
	Timeout time.Duration
	// MaxResponseSize is the maximum allowed response body size.
	MaxResponseSize int64
	// UserAgent to send to upstream servers.
	UserAgent string
	// AllowedContentTypes is an allow list of MIME types to accept.
	// Matched case-insensitively against the response media type,
	// ignoring parameters such as "; charset=".
	AllowedContentTypes []string
	// AllowHTTP allows non-TLS connections (for testing only).
	AllowHTTP bool
	// MaxConnectionsPerHost limits concurrent connections to each upstream host.
	MaxConnectionsPerHost int
}
// DefaultConfig returns a Config with sensible defaults: conservative
// timeouts, a 50MB response cap, HTTPS only, and an allow list covering
// the common image formats.
func DefaultConfig() *Config {
	cfg := Config{
		Timeout:               DefaultFetchTimeout,
		MaxResponseSize:       DefaultMaxResponseSize,
		UserAgent:             "pixa/1.0",
		AllowHTTP:             false,
		MaxConnectionsPerHost: DefaultMaxConnectionsPerHost,
	}
	cfg.AllowedContentTypes = []string{
		"image/jpeg",
		"image/png",
		"image/gif",
		"image/webp",
		"image/avif",
		"image/svg+xml",
	}
	return &cfg
}
// HTTPFetcher implements Fetcher with SSRF protection and per-host connection limits.
type HTTPFetcher struct {
	client *http.Client // built by New with ssrfSafeDialer and a redirect validator
	config *Config
	// hostSems maps upstream host -> buffered-channel semaphore with
	// capacity config.MaxConnectionsPerHost; created lazily per host.
	hostSems  map[string]chan struct{} // per-host semaphores
	hostSemMu sync.Mutex               // protects hostSems map
}
// New creates a new HTTPFetcher with SSRF protection. A nil config is
// replaced with DefaultConfig().
func New(config *Config) *HTTPFetcher {
	cfg := config
	if cfg == nil {
		cfg = DefaultConfig()
	}

	// Redirects are not followed blindly: each hop is capped and
	// re-validated so an allowed origin cannot bounce us to an
	// internal address.
	checkRedirect := func(req *http.Request, via []*http.Request) error {
		if len(via) >= DefaultMaxRedirects {
			return errors.New("too many redirects")
		}
		if err := validateURL(req.URL.String(), cfg.AllowHTTP); err != nil {
			return fmt.Errorf("redirect blocked: %w", err)
		}
		return nil
	}

	return &HTTPFetcher{
		client: &http.Client{
			Transport: &http.Transport{
				DialContext:         ssrfSafeDialer, // validates resolved IPs at connect time
				TLSHandshakeTimeout: DefaultTLSTimeout,
				MaxIdleConns:        DefaultMaxIdleConns,
				IdleConnTimeout:     DefaultIdleConnTimeout,
			},
			Timeout:       cfg.Timeout,
			CheckRedirect: checkRedirect,
		},
		config:   cfg,
		hostSems: make(map[string]chan struct{}),
	}
}
// getHostSemaphore returns the semaphore for a host, creating it lazily
// on first use. The semaphore is a buffered channel with capacity
// MaxConnectionsPerHost; the map itself is guarded by hostSemMu.
func (f *HTTPFetcher) getHostSemaphore(host string) chan struct{} {
	f.hostSemMu.Lock()
	defer f.hostSemMu.Unlock()
	if sem, ok := f.hostSems[host]; ok {
		return sem
	}
	sem := make(chan struct{}, f.config.MaxConnectionsPerHost)
	f.hostSems[host] = sem
	return sem
}
// Fetch retrieves content from the given URL with SSRF protection.
//
// Flow: validate the URL, acquire a per-host semaphore slot, issue the
// request (capturing connection details via httptrace), then validate
// the status code and content type. On success the returned
// FetchResult.Content must be closed by the caller; closing it releases
// the semaphore slot. On any error path before success the deferred
// release returns the slot.
func (f *HTTPFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
	// Validate URL before making request
	if err := validateURL(url, f.config.AllowHTTP); err != nil {
		return nil, err
	}
	// Extract host for rate limiting
	host := extractHost(url)
	// Acquire semaphore slot for this host; blocks when the host already
	// has MaxConnectionsPerHost fetches in flight.
	sem := f.getHostSemaphore(host)
	select {
	case sem <- struct{}{}:
		// Acquired slot
	case <-ctx.Done():
		return nil, ctx.Err()
	}
	// If we fail before returning a result, release the slot
	success := false
	defer func() {
		if !success {
			<-sem
		}
	}()
	parsedURL, err := neturl.Parse(url)
	if err != nil {
		return nil, fmt.Errorf("failed to parse URL: %w", err)
	}
	req := &http.Request{
		Method: http.MethodGet,
		URL:    parsedURL,
		Header: make(http.Header),
	}
	req = req.WithContext(ctx)
	req.Header.Set("User-Agent", f.config.UserAgent)
	req.Header.Set("Accept", strings.Join(f.config.AllowedContentTypes, ", "))
	// Use httptrace to capture connection details
	var remoteAddr string
	trace := &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			if info.Conn != nil {
				remoteAddr = info.Conn.RemoteAddr().String()
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
	startTime := time.Now()
	//nolint:gosec // G704: URL validated by validateURL() above
	resp, err := f.client.Do(req)
	fetchDuration := time.Since(startTime)
	if err != nil {
		// NOTE(review): http.Client timeout errors wrap
		// context.DeadlineExceeded since Go 1.16, so this also maps
		// client-level timeouts to ErrUpstreamTimeout — confirm the
		// module's minimum Go version.
		if errors.Is(err, context.DeadlineExceeded) {
			return nil, ErrUpstreamTimeout
		}
		return nil, fmt.Errorf("upstream request failed: %w", err)
	}
	// Extract HTTP version (strip "HTTP/" prefix)
	httpVersion := strings.TrimPrefix(resp.Proto, "HTTP/")
	// Extract TLS info if available (nil for plain-HTTP responses)
	var tlsVersion, tlsCipherSuite string
	if resp.TLS != nil {
		tlsVersion = tls.VersionName(resp.TLS.Version)
		tlsCipherSuite = tls.CipherSuiteName(resp.TLS.CipherSuite)
	}
	// Check status code: only 2xx is accepted
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		_ = resp.Body.Close()
		return nil, fmt.Errorf("%w: status %d", ErrUpstreamError, resp.StatusCode)
	}
	// Validate content type
	contentType := resp.Header.Get("Content-Type")
	if !f.isAllowedContentType(contentType) {
		_ = resp.Body.Close()
		return nil, fmt.Errorf("%w: %s", ErrInvalidContentType, contentType)
	}
	// Wrap body with size limiter and semaphore releaser
	limitedBody := &limitedReader{
		reader:    resp.Body,
		remaining: f.config.MaxResponseSize,
	}
	// Mark success so defer doesn't release the semaphore; from here the
	// slot is released by Content.Close().
	success = true
	return &FetchResult{
		Content:         &semaphoreReleasingReadCloser{limitedBody, resp.Body, sem},
		ContentLength:   resp.ContentLength,
		ContentType:     contentType,
		Headers:         resp.Header,
		StatusCode:      resp.StatusCode,
		FetchDurationMs: fetchDuration.Milliseconds(),
		RemoteAddr:      remoteAddr,
		HTTPVersion:     httpVersion,
		TLSVersion:      tlsVersion,
		TLSCipherSuite:  tlsCipherSuite,
	}, nil
}
// isAllowedContentType reports whether the content type's media type —
// ignoring any ";"-separated parameters such as "charset" — appears in
// the configured allow list. Comparison is case-insensitive.
func (f *HTTPFetcher) isAllowedContentType(contentType string) bool {
	mediaType, _, _ := strings.Cut(contentType, ";")
	mediaType = strings.TrimSpace(mediaType)
	for _, allowed := range f.config.AllowedContentTypes {
		if strings.EqualFold(allowed, mediaType) {
			return true
		}
	}
	return false
}
// validateURL checks if a URL is safe to fetch (not internal/private).
// This is the first line of defense; ssrfSafeDialer re-validates the
// resolved IPs at connect time, which also covers DNS answers that
// change between this check and the actual dial (DNS rebinding).
func validateURL(rawURL string, allowHTTP bool) error {
	if !allowHTTP && !strings.HasPrefix(rawURL, "https://") {
		return ErrUnsupportedScheme
	}
	// Parse to extract host
	// NOTE(review): extractHost is a lightweight string scan, not a full
	// URL parse — confirm it strips userinfo ("user@host") so credentials
	// cannot smuggle a different-looking host past the checks below.
	host := extractHost(rawURL)
	if host == "" {
		return ErrInvalidHost
	}
	// Remove port if present
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h
	}
	// Block obvious localhost patterns
	if isLocalhost(host) {
		return ErrSSRFBlocked
	}
	// Resolve the host to check IP addresses. This resolution is
	// advisory: the dialer performs its own lookup, so the IPs actually
	// dialed may differ from the ones inspected here.
	ips, err := net.LookupIP(host)
	if err != nil {
		return fmt.Errorf("%w: %s", ErrInvalidHost, host)
	}
	for _, ip := range ips {
		if isPrivateIP(ip) {
			return ErrSSRFBlocked
		}
	}
	return nil
}
// extractHost extracts the host (and port, if present) from a URL string
// without a full URL parse.
//
// Fixes over the previous version: the userinfo component is stripped, so
// "https://localhost@example.com/" yields "example.com" instead of
// "localhost@example.com" (which bypassed the localhost check in
// validateURL and mis-keyed the per-host semaphore), and a fragment
// ("#...") no longer leaks into the host when the URL has no path.
func extractHost(rawURL string) string {
	url := rawURL
	// Strip the scheme, if any.
	if idx := strings.Index(url, "://"); idx != -1 {
		url = url[idx+3:]
	}
	// Cut at the first path, query, or fragment delimiter.
	if idx := strings.Index(url, "/"); idx != -1 {
		url = url[:idx]
	}
	if idx := strings.Index(url, "?"); idx != -1 {
		url = url[:idx]
	}
	if idx := strings.Index(url, "#"); idx != -1 {
		url = url[:idx]
	}
	// Drop userinfo: per RFC 3986 the host follows the last "@" in the
	// authority component.
	if idx := strings.LastIndex(url, "@"); idx != -1 {
		url = url[idx+1:]
	}
	return url
}
// isLocalhost reports whether host names the local machine by a common
// textual form: "localhost", the loopback literals, or a name under the
// .localhost / .local suffixes. Matching is case-insensitive; numeric
// loopback variants like 127.0.0.2 are handled by isPrivateIP instead.
func isLocalhost(host string) bool {
	h := strings.ToLower(host)
	switch h {
	case "localhost", "127.0.0.1", "::1", "[::1]":
		return true
	}
	return strings.HasSuffix(h, ".localhost") || strings.HasSuffix(h, ".local")
}
// isPrivateIP checks if an IP is private, loopback, or otherwise internal.
func isPrivateIP(ip net.IP) bool {
if ip == nil {
return true
}
// Check for loopback
if ip.IsLoopback() {
return true
}
// Check for private ranges
if ip.IsPrivate() {
return true
}
// Check for link-local
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
// Check for unspecified (0.0.0.0 or ::)
if ip.IsUnspecified() {
return true
}
// Check for multicast
if ip.IsMulticast() {
return true
}
// Additional checks for IPv4
if ip4 := ip.To4(); ip4 != nil {
// 169.254.0.0/16 - Link local
if ip4[0] == 169 && ip4[1] == 254 {
return true
}
// 0.0.0.0/8 - Current network
if ip4[0] == 0 {
return true
}
}
return false
}
// ssrfSafeDialer is a custom dialer that validates IP addresses before
// connecting. It re-resolves the host itself, so a DNS answer that
// changed after validateURL ran (DNS rebinding) is still caught at
// connect time.
//
// Fix: the previous version discarded the per-IP dial errors and
// returned a bare "failed to connect" message; the last dial error is
// now wrapped so callers can see the underlying cause.
func ssrfSafeDialer(ctx context.Context, network, addr string) (net.Conn, error) {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	// Resolve the address
	ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
	if err != nil {
		return nil, fmt.Errorf("%w: %s", ErrInvalidHost, host)
	}
	// Refuse to connect if ANY resolved IP is private/internal.
	for _, ip := range ips {
		if isPrivateIP(ip) {
			return nil, ErrSSRFBlocked
		}
	}
	// Try each resolved IP in order until one connects.
	var dialer net.Dialer
	var lastErr error
	for _, ip := range ips {
		conn, dialErr := dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
		if dialErr == nil {
			return conn, nil
		}
		lastErr = dialErr
	}
	if lastErr != nil {
		return nil, fmt.Errorf("failed to connect to %s: %w", host, lastErr)
	}
	// Unreachable in practice: a successful lookup returns at least one IP.
	return nil, fmt.Errorf("failed to connect to %s", host)
}
// limitedReader wraps a reader and limits the number of bytes read.
// Unlike io.LimitReader it reports ErrResponseTooLarge — rather than a
// silent EOF — when the source holds MORE data than the limit, while a
// source that is exactly at the limit still terminates with io.EOF.
type limitedReader struct {
	reader    io.Reader // underlying body
	remaining int64     // bytes still allowed to be read
}

// Read reads up to r.remaining bytes from the underlying reader.
//
// Fix: once the budget is exhausted, a one-byte probe distinguishes
// "body exactly at the limit" (propagate io.EOF) from "body exceeds the
// limit" (ErrResponseTooLarge). The previous version returned
// ErrResponseTooLarge unconditionally, rejecting responses whose size
// equals MaxResponseSize exactly.
func (r *limitedReader) Read(p []byte) (int, error) {
	if r.remaining <= 0 {
		var probe [1]byte
		n, err := r.reader.Read(probe[:])
		if n > 0 {
			// More data exists beyond the limit.
			return 0, ErrResponseTooLarge
		}
		if err != nil {
			// io.EOF (or a real error): the body ended exactly at the limit.
			return 0, err
		}
		// Degenerate (0, nil) read from the source; let the caller retry.
		return 0, nil
	}
	if int64(len(p)) > r.remaining {
		p = p[:r.remaining]
	}
	n, err := r.reader.Read(p)
	r.remaining -= int64(n)
	return n, err
}
// semaphoreReleasingReadCloser releases a per-host semaphore slot when
// closed. Reads are delegated to the embedded limitedReader while Close
// goes to the raw response body (so the size limit never interferes with
// closing).
type semaphoreReleasingReadCloser struct {
	*limitedReader
	closer io.Closer
	sem    chan struct{}
}

// Close closes the underlying body and releases the semaphore slot.
//
// Fix: the release is now guarded so that calling Close more than once
// (e.g. an explicit Close plus a deferred one) cannot drain a second
// token from the channel and steal another request's slot. Not safe for
// concurrent Close calls, matching the usual io.Closer contract.
func (r *semaphoreReleasingReadCloser) Close() error {
	err := r.closer.Close()
	if r.sem != nil {
		<-r.sem // Release semaphore slot (first Close only)
		r.sem = nil
	}
	return err
}

View File

@@ -0,0 +1,329 @@
package httpfetcher
import (
"context"
"errors"
"io"
"net"
"testing"
"testing/fstest"
)
// TestDefaultConfig verifies every documented default on the Config
// returned by DefaultConfig.
func TestDefaultConfig(t *testing.T) {
	got := DefaultConfig()
	if got.Timeout != DefaultFetchTimeout {
		t.Errorf("Timeout = %v, want %v", got.Timeout, DefaultFetchTimeout)
	}
	if got.MaxResponseSize != DefaultMaxResponseSize {
		t.Errorf("MaxResponseSize = %d, want %d", got.MaxResponseSize, DefaultMaxResponseSize)
	}
	if got.MaxConnectionsPerHost != DefaultMaxConnectionsPerHost {
		t.Errorf("MaxConnectionsPerHost = %d, want %d",
			got.MaxConnectionsPerHost, DefaultMaxConnectionsPerHost)
	}
	if got.AllowHTTP {
		t.Error("AllowHTTP should default to false")
	}
	if len(got.AllowedContentTypes) == 0 {
		t.Error("AllowedContentTypes should not be empty")
	}
}
// TestNewWithNilConfigUsesDefaults checks that New(nil) falls back to
// DefaultConfig rather than leaving config unset.
func TestNewWithNilConfigUsesDefaults(t *testing.T) {
	fetcher := New(nil)
	if fetcher == nil {
		t.Fatal("New(nil) returned nil")
	}
	cfg := fetcher.config
	if cfg == nil {
		t.Fatal("config should be populated from DefaultConfig")
	}
	if cfg.Timeout != DefaultFetchTimeout {
		t.Errorf("Timeout = %v, want %v", cfg.Timeout, DefaultFetchTimeout)
	}
}
// TestIsAllowedContentType covers the default allow list, parameter
// stripping, case folding, and rejected/empty content types.
func TestIsAllowedContentType(t *testing.T) {
	fetcher := New(DefaultConfig())
	cases := []struct {
		contentType string
		want        bool
	}{
		{contentType: "image/jpeg", want: true},
		{contentType: "image/png", want: true},
		{contentType: "image/webp", want: true},
		{contentType: "image/jpeg; charset=utf-8", want: true},
		{contentType: "IMAGE/JPEG", want: true},
		{contentType: "text/html", want: false},
		{contentType: "application/octet-stream", want: false},
		{contentType: "", want: false},
	}
	for _, tt := range cases {
		t.Run(tt.contentType, func(t *testing.T) {
			if got := fetcher.isAllowedContentType(tt.contentType); got != tt.want {
				t.Errorf("isAllowedContentType(%q) = %v, want %v", tt.contentType, got, tt.want)
			}
		})
	}
}
// TestExtractHost exercises scheme stripping, port retention, query
// trimming, schemeless input, and the empty string.
func TestExtractHost(t *testing.T) {
	cases := []struct {
		url  string
		want string
	}{
		{url: "https://example.com/path", want: "example.com"},
		{url: "http://example.com:8080/path", want: "example.com:8080"},
		{url: "https://example.com", want: "example.com"},
		{url: "https://example.com?q=1", want: "example.com"},
		{url: "example.com/path", want: "example.com"},
		{url: "", want: ""},
	}
	for _, tt := range cases {
		t.Run(tt.url, func(t *testing.T) {
			if got := extractHost(tt.url); got != tt.want {
				t.Errorf("extractHost(%q) = %q, want %q", tt.url, got, tt.want)
			}
		})
	}
}
// TestIsLocalhost covers literal, case-folded, bracketed-IPv6, and
// suffix-matched localhost names plus non-local hosts.
func TestIsLocalhost(t *testing.T) {
	cases := []struct {
		host string
		want bool
	}{
		{host: "localhost", want: true},
		{host: "LOCALHOST", want: true},
		{host: "127.0.0.1", want: true},
		{host: "::1", want: true},
		{host: "[::1]", want: true},
		{host: "foo.localhost", want: true},
		{host: "foo.local", want: true},
		{host: "example.com", want: false},
		// 127.0.0.2 is handled by isPrivateIP, not this string match.
		{host: "127.0.0.2", want: false},
	}
	for _, tt := range cases {
		t.Run(tt.host, func(t *testing.T) {
			if got := isLocalhost(tt.host); got != tt.want {
				t.Errorf("isLocalhost(%q) = %v, want %v", tt.host, got, tt.want)
			}
		})
	}
}
// TestIsPrivateIP covers loopback, RFC 1918, link-local, unspecified,
// multicast, and public addresses for both IPv4 and IPv6, plus the nil
// fail-closed case.
func TestIsPrivateIP(t *testing.T) {
	cases := []struct {
		ip   string
		want bool
	}{
		{ip: "127.0.0.1", want: true},              // loopback
		{ip: "10.0.0.1", want: true},               // private
		{ip: "192.168.1.1", want: true},            // private
		{ip: "172.16.0.1", want: true},             // private
		{ip: "169.254.1.1", want: true},            // link-local
		{ip: "0.0.0.0", want: true},                // unspecified
		{ip: "224.0.0.1", want: true},              // multicast
		{ip: "::1", want: true},                    // IPv6 loopback
		{ip: "fe80::1", want: true},                // IPv6 link-local
		{ip: "8.8.8.8", want: false},               // public
		{ip: "2001:4860:4860::8888", want: false},  // public IPv6
	}
	for _, tt := range cases {
		t.Run(tt.ip, func(t *testing.T) {
			parsed := net.ParseIP(tt.ip)
			if parsed == nil {
				t.Fatalf("failed to parse IP %q", tt.ip)
			}
			if got := isPrivateIP(parsed); got != tt.want {
				t.Errorf("isPrivateIP(%q) = %v, want %v", tt.ip, got, tt.want)
			}
		})
	}
	if !isPrivateIP(nil) {
		t.Error("isPrivateIP(nil) should return true")
	}
}
// TestValidateURL_RejectsNonHTTPS checks that plain http URLs are
// refused when AllowHTTP is off.
func TestValidateURL_RejectsNonHTTPS(t *testing.T) {
	if err := validateURL("http://example.com/path", false); !errors.Is(err, ErrUnsupportedScheme) {
		t.Errorf("validateURL http = %v, want ErrUnsupportedScheme", err)
	}
}
// TestValidateURL_AllowsHTTPWhenConfigured checks that the scheme gate
// is bypassed when AllowHTTP is true.
func TestValidateURL_AllowsHTTPWhenConfigured(t *testing.T) {
	// Use a host that won't resolve (explicit .invalid TLD) so we don't hit DNS.
	// Any error other than ErrUnsupportedScheme means the scheme was accepted.
	if err := validateURL("http://nonexistent.invalid/path", true); errors.Is(err, ErrUnsupportedScheme) {
		t.Error("validateURL with AllowHTTP should not return ErrUnsupportedScheme")
	}
}
// TestValidateURL_RejectsLocalhost checks the SSRF gate on a literal
// localhost URL.
func TestValidateURL_RejectsLocalhost(t *testing.T) {
	if err := validateURL("https://localhost/path", false); !errors.Is(err, ErrSSRFBlocked) {
		t.Errorf("validateURL localhost = %v, want ErrSSRFBlocked", err)
	}
}
// TestValidateURL_EmptyHost checks that a URL with no authority is
// rejected as an invalid host.
func TestValidateURL_EmptyHost(t *testing.T) {
	if err := validateURL("https:///path", false); !errors.Is(err, ErrInvalidHost) {
		t.Errorf("validateURL empty host = %v, want ErrInvalidHost", err)
	}
}
// TestMockFetcher_FetchesFile checks the happy path: URL-to-path
// mapping, content-type detection, body contents, and length.
func TestMockFetcher_FetchesFile(t *testing.T) {
	const payload = "fake-jpeg-data"
	fetcher := NewMock(fstest.MapFS{
		"example.com/images/photo.jpg": &fstest.MapFile{Data: []byte(payload)},
	})
	result, err := fetcher.Fetch(context.Background(), "https://example.com/images/photo.jpg")
	if err != nil {
		t.Fatalf("Fetch() error = %v", err)
	}
	defer func() { _ = result.Content.Close() }()
	if result.ContentType != "image/jpeg" {
		t.Errorf("ContentType = %q, want image/jpeg", result.ContentType)
	}
	data, err := io.ReadAll(result.Content)
	if err != nil {
		t.Fatalf("read content: %v", err)
	}
	if string(data) != payload {
		t.Errorf("Content = %q, want %q", string(data), payload)
	}
	if result.ContentLength != int64(len(payload)) {
		t.Errorf("ContentLength = %d, want %d", result.ContentLength, len(payload))
	}
}
// TestMockFetcher_MissingFileReturnsUpstreamError checks that a miss in
// the backing filesystem surfaces as ErrUpstreamError (a mock 404).
func TestMockFetcher_MissingFileReturnsUpstreamError(t *testing.T) {
	fetcher := NewMock(fstest.MapFS{})
	if _, err := fetcher.Fetch(context.Background(), "https://example.com/missing.jpg"); !errors.Is(err, ErrUpstreamError) {
		t.Errorf("Fetch() error = %v, want ErrUpstreamError", err)
	}
}
// TestMockFetcher_RespectsContextCancellation checks that an
// already-cancelled context short-circuits the fetch.
func TestMockFetcher_RespectsContextCancellation(t *testing.T) {
	fetcher := NewMock(fstest.MapFS{
		"example.com/photo.jpg": &fstest.MapFile{Data: []byte("data")},
	})
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	if _, err := fetcher.Fetch(ctx, "https://example.com/photo.jpg"); !errors.Is(err, context.Canceled) {
		t.Errorf("Fetch() error = %v, want context.Canceled", err)
	}
}
// TestDetectContentTypeFromPath covers every mapped extension, case
// folding, and the octet-stream fallback.
func TestDetectContentTypeFromPath(t *testing.T) {
	cases := []struct {
		path string
		want string
	}{
		{path: "foo/bar.jpg", want: "image/jpeg"},
		{path: "foo/bar.JPG", want: "image/jpeg"},
		{path: "foo/bar.jpeg", want: "image/jpeg"},
		{path: "foo/bar.png", want: "image/png"},
		{path: "foo/bar.gif", want: "image/gif"},
		{path: "foo/bar.webp", want: "image/webp"},
		{path: "foo/bar.avif", want: "image/avif"},
		{path: "foo/bar.svg", want: "image/svg+xml"},
		{path: "foo/bar.bin", want: "application/octet-stream"},
		{path: "foo/bar", want: "application/octet-stream"},
	}
	for _, tt := range cases {
		t.Run(tt.path, func(t *testing.T) {
			if got := detectContentTypeFromPath(tt.path); got != tt.want {
				t.Errorf("detectContentTypeFromPath(%q) = %q, want %q", tt.path, got, tt.want)
			}
		})
	}
}
// TestLimitedReader_EnforcesLimit checks that reads are capped at the
// configured budget and that exceeding it yields ErrResponseTooLarge.
func TestLimitedReader_EnforcesLimit(t *testing.T) {
	r := &limitedReader{
		reader:    &byteReader{data: make([]byte, 100)},
		remaining: 50,
	}
	buf := make([]byte, 100)
	n, err := r.Read(buf)
	if err != nil {
		t.Fatalf("first Read error = %v", err)
	}
	if n > 50 {
		t.Errorf("read %d bytes, should be capped at 50", n)
	}
	// Consume whatever remains of the 50-byte budget.
	for total := n; total < 50; {
		nn, err := r.Read(buf)
		if err != nil {
			t.Fatalf("during drain: %v", err)
		}
		total += nn
	}
	// The budget is spent but the source has more data, so Read must fail.
	if _, err := r.Read(buf); !errors.Is(err, ErrResponseTooLarge) {
		t.Errorf("exhausted Read error = %v, want ErrResponseTooLarge", err)
	}
}
// byteReader is a minimal io.Reader over a byte slice for testing.
type byteReader struct {
data []byte
pos int
}
func (r *byteReader) Read(p []byte) (int, error) {
if r.pos >= len(r.data) {
return 0, io.EOF
}
n := copy(p, r.data[r.pos:])
r.pos += n
return n, nil
}

View File

@@ -0,0 +1,115 @@
package httpfetcher
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"strings"
)
// MockFetcher implements Fetcher using an embedded filesystem.
// Files are organized as: hostname/path/to/file.ext
// URLs like https://example.com/images/photo.jpg map to example.com/images/photo.jpg.
// Intended for tests; construct with NewMock.
type MockFetcher struct {
	fs fs.FS // backing filesystem, keyed by host-prefixed paths
}
// NewMock creates a new mock fetcher backed by the given filesystem.
// Files in fsys are addressed as "host/path" (see MockFetcher).
func NewMock(fsys fs.FS) *MockFetcher {
	m := MockFetcher{fs: fsys}
	return &m
}
// Fetch retrieves content from the mock filesystem, mapping the URL's
// host and path to a file path. A missing file is reported as
// ErrUpstreamError ("status 404"), mirroring the real fetcher's error
// contract.
//
// Fixes: StatusCode is now set to 200 on success (it was previously left
// at its zero value, unlike the real fetcher, so callers inspecting
// StatusCode saw 0), and the redundant io.ReadCloser type assertion is
// gone — fs.File already has Read and Close, so it always satisfies the
// interface.
func (m *MockFetcher) Fetch(ctx context.Context, url string) (*FetchResult, error) {
	// Honor an already-cancelled context before touching the filesystem.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}
	// Parse URL to get filesystem path
	path, err := urlToFSPath(url)
	if err != nil {
		return nil, err
	}
	// Open the file
	f, err := m.fs.Open(path)
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return nil, fmt.Errorf("%w: status 404", ErrUpstreamError)
		}
		return nil, fmt.Errorf("failed to open mock file: %w", err)
	}
	// Get file info for content length
	stat, err := f.Stat()
	if err != nil {
		_ = f.Close()
		return nil, fmt.Errorf("failed to stat mock file: %w", err)
	}
	// fs.File satisfies io.ReadCloser by definition; no assertion needed.
	var body io.ReadCloser = f
	return &FetchResult{
		Content:       body,
		ContentLength: stat.Size(),
		// Detect content type from the file extension.
		ContentType: detectContentTypeFromPath(path),
		Headers:     make(http.Header),
		// Report 200 so callers that inspect StatusCode see what the real
		// fetcher would produce for a successful fetch.
		StatusCode: http.StatusOK,
	}, nil
}
// urlToFSPath converts a URL to a filesystem path by discarding the
// scheme, query string, and fragment.
// https://example.com/images/photo.jpg -> example.com/images/photo.jpg
func urlToFSPath(rawURL string) (string, error) {
	path := rawURL
	// Strip scheme ("://" and everything before it).
	if _, rest, found := strings.Cut(path, "://"); found {
		path = rest
	}
	// Drop query string, then fragment.
	path, _, _ = strings.Cut(path, "?")
	path, _, _ = strings.Cut(path, "#")
	if path == "" {
		return "", errors.New("empty URL path")
	}
	return path, nil
}
// detectContentTypeFromPath returns the MIME type based on file
// extension (case-insensitive), falling back to
// "application/octet-stream" for anything unrecognized.
func detectContentTypeFromPath(path string) string {
	lower := strings.ToLower(path)
	suffixTypes := []struct {
		suffix string
		mime   string
	}{
		{suffix: ".jpg", mime: "image/jpeg"},
		{suffix: ".jpeg", mime: "image/jpeg"},
		{suffix: ".png", mime: "image/png"},
		{suffix: ".gif", mime: "image/gif"},
		{suffix: ".webp", mime: "image/webp"},
		{suffix: ".avif", mime: "image/avif"},
		{suffix: ".svg", mime: "image/svg+xml"},
	}
	for _, st := range suffixTypes {
		if strings.HasSuffix(lower, st.suffix) {
			return st.mime
		}
	}
	return "application/octet-stream"
}