package delivery import ( "context" "fmt" "net" "net/http" "net/url" "time" ) const ( // dnsResolutionTimeout is the maximum time to wait for DNS resolution // during SSRF validation. dnsResolutionTimeout = 5 * time.Second ) // blockedNetworks contains all private/reserved IP ranges that should be // blocked to prevent SSRF attacks. This includes RFC 1918 private // addresses, loopback, link-local, and IPv6 equivalents. // //nolint:gochecknoglobals // package-level network list is appropriate here var blockedNetworks []*net.IPNet //nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup func init() { cidrs := []string{ // IPv4 private/reserved ranges "127.0.0.0/8", // Loopback "10.0.0.0/8", // RFC 1918 Class A private "172.16.0.0/12", // RFC 1918 Class B private "192.168.0.0/16", // RFC 1918 Class C private "169.254.0.0/16", // Link-local (cloud metadata) "0.0.0.0/8", // "This" network "100.64.0.0/10", // Shared address space (CGN) "192.0.0.0/24", // IETF protocol assignments "192.0.2.0/24", // TEST-NET-1 "198.18.0.0/15", // Benchmarking "198.51.100.0/24", // TEST-NET-2 "203.0.113.0/24", // TEST-NET-3 "224.0.0.0/4", // Multicast "240.0.0.0/4", // Reserved for future use // IPv6 private/reserved ranges "::1/128", // Loopback "fc00::/7", // Unique local addresses "fe80::/10", // Link-local } for _, cidr := range cidrs { _, network, err := net.ParseCIDR(cidr) if err != nil { panic(fmt.Sprintf("ssrf: failed to parse CIDR %q: %v", cidr, err)) } blockedNetworks = append(blockedNetworks, network) } } // isBlockedIP checks whether an IP address falls within any blocked // private/reserved network range. func isBlockedIP(ip net.IP) bool { for _, network := range blockedNetworks { if network.Contains(ip) { return true } } return false } // ValidateTargetURL checks that an HTTP delivery target URL is safe // from SSRF attacks. It validates the URL format, resolves the hostname // to IP addresses, and verifies that none of the resolved IPs are in // blocked private/reserved ranges. // // Returns nil if the URL is safe, or an error describing the issue. func ValidateTargetURL(targetURL string) error { parsed, err := url.Parse(targetURL) if err != nil { return fmt.Errorf("invalid URL: %w", err) } // Only allow http and https schemes if parsed.Scheme != "http" && parsed.Scheme != "https" { return fmt.Errorf("unsupported URL scheme %q: only http and https are allowed", parsed.Scheme) } host := parsed.Hostname() if host == "" { return fmt.Errorf("URL has no hostname") } // Check if the host is a raw IP address first if ip := net.ParseIP(host); ip != nil { if isBlockedIP(ip) { return fmt.Errorf("target IP %s is in a blocked private/reserved range", ip) } return nil } // Resolve hostname to IPs and check each one ctx, cancel := context.WithTimeout(context.Background(), dnsResolutionTimeout) defer cancel() ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) if err != nil { return fmt.Errorf("failed to resolve hostname %q: %w", host, err) } if len(ips) == 0 { return fmt.Errorf("hostname %q resolved to no IP addresses", host) } for _, ipAddr := range ips { if isBlockedIP(ipAddr.IP) { return fmt.Errorf("hostname %q resolves to blocked IP %s (private/reserved range)", host, ipAddr.IP) } } return nil } // NewSSRFSafeTransport creates an http.Transport with a custom DialContext // that blocks connections to private/reserved IP addresses. This provides // defense-in-depth SSRF protection at the network layer, catching cases // where DNS records change between target creation and delivery time // (DNS rebinding attacks). func NewSSRFSafeTransport() *http.Transport { return &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("ssrf: invalid address %q: %w", addr, err) } // Resolve hostname to IPs ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) if err != nil { return nil, fmt.Errorf("ssrf: DNS resolution failed for %q: %w", host, err) } // Check all resolved IPs for _, ipAddr := range ips { if isBlockedIP(ipAddr.IP) { return nil, fmt.Errorf("ssrf: connection to %s (%s) blocked — private/reserved IP range", host, ipAddr.IP) } } // Connect to the first allowed IP var dialer net.Dialer return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) }, } }