diff --git a/internal/resolver/iterative.go b/internal/resolver/iterative.go index 8f41b6d..97c58b2 100644 --- a/internal/resolver/iterative.go +++ b/internal/resolver/iterative.go @@ -13,7 +13,7 @@ import ( ) const ( - queryTimeoutDuration = 5 * time.Second + queryTimeoutDuration = 700 * time.Millisecond maxRetries = 2 maxDelegation = 20 timeoutMultiplier = 2 @@ -227,7 +227,7 @@ func (r *Resolver) followDelegation( authNS := extractNSSet(resp.Ns) if len(authNS) == 0 { - return r.resolveNSRecursive(ctx, domain) + return r.resolveNSIterative(ctx, domain) } glue := extractGlue(resp.Extra) @@ -291,60 +291,84 @@ func (r *Resolver) resolveNSIPs( return ips } -// resolveNSRecursive queries for NS records using recursive -// resolution as a fallback for intercepted environments. -func (r *Resolver) resolveNSRecursive( +// resolveNSIterative queries for NS records using iterative +// resolution as a fallback when followDelegation finds no +// authoritative answer in the delegation chain. +func (r *Resolver) resolveNSIterative( ctx context.Context, domain string, ) ([]string, error) { - domain = dns.Fqdn(domain) - msg := new(dns.Msg) - msg.SetQuestion(domain, dns.TypeNS) - msg.RecursionDesired = true + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } - for _, ip := range rootServerList()[:3] { + domain = dns.Fqdn(domain) + servers := rootServerList() + + for range maxDelegation { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } - addr := net.JoinHostPort(ip, "53") - - resp, _, err := r.client.ExchangeContext(ctx, msg, addr) + resp, err := r.queryServers( + ctx, servers, domain, dns.TypeNS, + ) if err != nil { - continue + return nil, err } nsNames := extractNSSet(resp.Answer) if len(nsNames) > 0 { return nsNames, nil } + + // Follow delegation. + authNS := extractNSSet(resp.Ns) + if len(authNS) == 0 { + break + } + + glue := extractGlue(resp.Extra) + nextServers := glueIPs(authNS, glue) + + if len(nextServers) == 0 { + break + } + + servers = nextServers } return nil, ErrNoNameservers } -// resolveARecord resolves a hostname to IPv4 addresses. +// resolveARecord resolves a hostname to IPv4 addresses using +// iterative resolution through the delegation chain. func (r *Resolver) resolveARecord( ctx context.Context, hostname string, ) ([]string, error) { - hostname = dns.Fqdn(hostname) - msg := new(dns.Msg) - msg.SetQuestion(hostname, dns.TypeA) - msg.RecursionDesired = true + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } - for _, ip := range rootServerList()[:3] { + hostname = dns.Fqdn(hostname) + servers := rootServerList() + + for range maxDelegation { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } - addr := net.JoinHostPort(ip, "53") - - resp, _, err := r.client.ExchangeContext(ctx, msg, addr) + resp, err := r.queryServers( + ctx, servers, hostname, dns.TypeA, + ) if err != nil { - continue + return nil, fmt.Errorf( + "resolving %s: %w", hostname, err, + ) } + // Check for A records in the answer section. var ips []string for _, rr := range resp.Answer { @@ -356,6 +380,24 @@ func (r *Resolver) resolveARecord( if len(ips) > 0 { return ips, nil } + + // Follow delegation if present. + authNS := extractNSSet(resp.Ns) + if len(authNS) == 0 { + break + } + + glue := extractGlue(resp.Extra) + nextServers := glueIPs(authNS, glue) + + if len(nextServers) == 0 { + // Resolve NS IPs iteratively — but guard + // against infinite recursion by using only + // already-resolved servers. + break + } + + servers = nextServers } return nil, fmt.Errorf(