package delivery import ( "context" "errors" "fmt" "net" "net/http" "net/url" "time" ) const ( // dnsResolutionTimeout is the maximum time to wait for // DNS resolution during SSRF validation. dnsResolutionTimeout = 5 * time.Second ) // Sentinel errors for SSRF validation. var ( errNoHostname = errors.New("URL has no hostname") errNoIPs = errors.New( "hostname resolved to no IP addresses", ) errBlockedIP = errors.New( "blocked private/reserved IP range", ) errInvalidScheme = errors.New( "only http and https are allowed", ) ) // blockedNetworks contains all private/reserved IP ranges // that should be blocked to prevent SSRF attacks. // //nolint:gochecknoglobals // package-level network list is appropriate here var blockedNetworks []*net.IPNet //nolint:gochecknoinits // init is the idiomatic way to parse CIDRs once at startup func init() { cidrs := []string{ "127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "169.254.0.0/16", "0.0.0.0/8", "100.64.0.0/10", "192.0.0.0/24", "192.0.2.0/24", "198.18.0.0/15", "198.51.100.0/24", "203.0.113.0/24", "224.0.0.0/4", "240.0.0.0/4", "::1/128", "fc00::/7", "fe80::/10", } for _, cidr := range cidrs { _, network, err := net.ParseCIDR(cidr) if err != nil { panic(fmt.Sprintf( "ssrf: failed to parse CIDR %q: %v", cidr, err, )) } blockedNetworks = append( blockedNetworks, network, ) } } // isBlockedIP checks whether an IP address falls within // any blocked private/reserved network range. func isBlockedIP(ip net.IP) bool { for _, network := range blockedNetworks { if network.Contains(ip) { return true } } return false } // ValidateTargetURL checks that an HTTP delivery target // URL is safe from SSRF attacks. func ValidateTargetURL( ctx context.Context, targetURL string, ) error { parsed, err := url.Parse(targetURL) if err != nil { return fmt.Errorf("invalid URL: %w", err) } err = validateScheme(parsed.Scheme) if err != nil { return err } host := parsed.Hostname() if host == "" { return errNoHostname } if ip := net.ParseIP(host); ip != nil { return checkBlockedIP(ip) } return validateHostname(ctx, host) } func validateScheme(scheme string) error { if scheme != "http" && scheme != "https" { return fmt.Errorf( "unsupported URL scheme %q: %w", scheme, errInvalidScheme, ) } return nil } func checkBlockedIP(ip net.IP) error { if isBlockedIP(ip) { return fmt.Errorf( "target IP %s is in a blocked "+ "private/reserved range: %w", ip, errBlockedIP, ) } return nil } func validateHostname( ctx context.Context, host string, ) error { dnsCtx, cancel := context.WithTimeout( ctx, dnsResolutionTimeout, ) defer cancel() ips, err := net.DefaultResolver.LookupIPAddr( dnsCtx, host, ) if err != nil { return fmt.Errorf( "failed to resolve hostname %q: %w", host, err, ) } if len(ips) == 0 { return fmt.Errorf( "hostname %q: %w", host, errNoIPs, ) } for _, ipAddr := range ips { if isBlockedIP(ipAddr.IP) { return fmt.Errorf( "hostname %q resolves to blocked "+ "IP %s: %w", host, ipAddr.IP, errBlockedIP, ) } } return nil } // NewSSRFSafeTransport creates an http.Transport with a // custom DialContext that blocks connections to // private/reserved IP addresses. func NewSSRFSafeTransport() *http.Transport { return &http.Transport{ DialContext: ssrfDialContext, } } func ssrfDialContext( ctx context.Context, network, addr string, ) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf( "ssrf: invalid address %q: %w", addr, err, ) } ips, err := net.DefaultResolver.LookupIPAddr( ctx, host, ) if err != nil { return nil, fmt.Errorf( "ssrf: DNS resolution failed for %q: %w", host, err, ) } for _, ipAddr := range ips { if isBlockedIP(ipAddr.IP) { return nil, fmt.Errorf( "ssrf: connection to %s (%s) "+ "blocked: %w", host, ipAddr.IP, errBlockedIP, ) } } var dialer net.Dialer return dialer.DialContext( ctx, network, net.JoinHostPort(ips[0].IP.String(), port), ) }