Files
dnswatcher/internal/resolver/iterative.go
clawbot 63c79c0bad resolver: reduce query timeout to 1s and limit root fan-out to 3 (closes #29)
Timeout rationale: 3× max antipodal RTT (~300ms) + 10ms processing = ~910ms, rounded to 1s.
Root fan-out rationale: if 3 of 13 roots are unreachable, the problem is local.
2026-02-22 03:44:10 -08:00

746 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package resolver
import (
"context"
"errors"
"fmt"
"math/rand/v2"
"net"
"sort"
"strings"
"time"
"github.com/miekg/dns"
)
const (
// queryTimeoutDuration is the per-exchange DNS timeout.
//
// Rationale: maximum RTT to antipodal root/TLD servers is
// ~300ms. We use 3× max RTT + 10ms processing ≈ 910ms,
// rounded to 1s. Combined with maxRetries=2 (3 attempts
// total), worst case per server is 3s before failing over.
queryTimeoutDuration = 1 * time.Second
maxRetries = 2
maxDelegation = 20
timeoutMultiplier = 2
minDomainLabels = 2
)
// ErrRefused is returned when a DNS server refuses a query.
var ErrRefused = errors.New("dns query refused")
func allRootServers() []string {
return []string{
"198.41.0.4", // a.root-servers.net
"170.247.170.2", // b
"192.33.4.12", // c
"199.7.91.13", // d
"192.203.230.10", // e
"192.5.5.241", // f
"192.112.36.4", // g
"198.97.190.53", // h
"192.36.148.17", // i
"192.58.128.30", // j
"193.0.14.129", // k
"199.7.83.42", // l
"202.12.27.33", // m
}
}
// rootServerList returns 3 randomly-selected root servers.
// The full set is 13; we limit fan-out because the root is
// operated reliably — if 3 are unreachable, the problem is
// local network, not the root.
func rootServerList() []string {
shuffled := allRootServers()
rand.Shuffle(len(shuffled), func(i, j int) {
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
})
return shuffled[:3]
}
func checkCtx(ctx context.Context) error {
err := ctx.Err()
if err != nil {
return ErrContextCanceled
}
return nil
}
func (r *Resolver) exchangeWithTimeout(
ctx context.Context,
msg *dns.Msg,
addr string,
attempt int,
) (*dns.Msg, error) {
_ = attempt // timeout escalation handled by client config
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
return resp, err
}
func (r *Resolver) tryExchange(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, error) {
var resp *dns.Msg
var err error
for attempt := range maxRetries {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err = r.exchangeWithTimeout(
ctx, msg, addr, attempt,
)
if err == nil {
break
}
}
return resp, err
}
func (r *Resolver) retryTCP(
ctx context.Context,
msg *dns.Msg,
addr string,
resp *dns.Msg,
) *dns.Msg {
if !resp.Truncated {
return resp
}
tcpResp, _, tcpErr := r.tcp.ExchangeContext(ctx, msg, addr)
if tcpErr == nil {
return tcpResp
}
return resp
}
// queryDNS sends a DNS query to a specific server IP.
// Tries non-recursive first, falls back to recursive on
// REFUSED (handles DNS interception environments).
func (r *Resolver) queryDNS(
ctx context.Context,
serverIP string,
name string,
qtype uint16,
) (*dns.Msg, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
name = dns.Fqdn(name)
addr := net.JoinHostPort(serverIP, "53")
msg := new(dns.Msg)
msg.SetQuestion(name, qtype)
msg.RecursionDesired = false
resp, err := r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err)
}
if resp.Rcode == dns.RcodeRefused {
msg.RecursionDesired = true
resp, err = r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, err,
)
}
if resp.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, ErrRefused,
)
}
}
resp = r.retryTCP(ctx, msg, addr, resp)
return resp, nil
}
func extractNSSet(rrs []dns.RR) []string {
nsSet := make(map[string]bool)
for _, rr := range rrs {
if ns, ok := rr.(*dns.NS); ok {
nsSet[strings.ToLower(ns.Ns)] = true
}
}
names := make([]string, 0, len(nsSet))
for n := range nsSet {
names = append(names, n)
}
sort.Strings(names)
return names
}
func extractGlue(rrs []dns.RR) map[string][]net.IP {
glue := make(map[string][]net.IP)
for _, rr := range rrs {
switch r := rr.(type) {
case *dns.A:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.A)
case *dns.AAAA:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.AAAA)
}
}
return glue
}
func glueIPs(nsNames []string, glue map[string][]net.IP) []string {
var ips []string
for _, ns := range nsNames {
for _, addr := range glue[ns] {
if v4 := addr.To4(); v4 != nil {
ips = append(ips, v4.String())
}
}
}
return ips
}
func (r *Resolver) followDelegation(
ctx context.Context,
domain string,
servers []string,
) ([]string, error) {
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryServers(
ctx, servers, domain, dns.TypeNS,
)
if err != nil {
return nil, err
}
ansNS := extractNSSet(resp.Answer)
if len(ansNS) > 0 {
return ansNS, nil
}
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
return r.resolveNSRecursive(ctx, domain)
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
nextServers = r.resolveNSIPs(ctx, authNS)
}
if len(nextServers) == 0 {
return nil, ErrNoNameservers
}
servers = nextServers
}
return nil, ErrNoNameservers
}
func (r *Resolver) queryServers(
ctx context.Context,
servers []string,
name string,
qtype uint16,
) (*dns.Msg, error) {
var lastErr error
for _, ip := range servers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryDNS(ctx, ip, name, qtype)
if err == nil {
return resp, nil
}
lastErr = err
}
return nil, fmt.Errorf("all servers failed: %w", lastErr)
}
func (r *Resolver) resolveNSIPs(
ctx context.Context,
nsNames []string,
) []string {
var ips []string
for _, ns := range nsNames {
resolved, err := r.resolveARecord(ctx, ns)
if err == nil {
ips = append(ips, resolved...)
}
if len(ips) > 0 {
break
}
}
return ips
}
// resolveNSRecursive queries for NS records using recursive
// resolution as a fallback for intercepted environments.
func (r *Resolver) resolveNSRecursive(
ctx context.Context,
domain string,
) ([]string, error) {
domain = dns.Fqdn(domain)
msg := new(dns.Msg)
msg.SetQuestion(domain, dns.TypeNS)
msg.RecursionDesired = true
for _, ip := range rootServerList() {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
addr := net.JoinHostPort(ip, "53")
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil {
continue
}
nsNames := extractNSSet(resp.Answer)
if len(nsNames) > 0 {
return nsNames, nil
}
}
return nil, ErrNoNameservers
}
// resolveARecord resolves a hostname to IPv4 addresses.
func (r *Resolver) resolveARecord(
ctx context.Context,
hostname string,
) ([]string, error) {
hostname = dns.Fqdn(hostname)
msg := new(dns.Msg)
msg.SetQuestion(hostname, dns.TypeA)
msg.RecursionDesired = true
for _, ip := range rootServerList() {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
addr := net.JoinHostPort(ip, "53")
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil {
continue
}
var ips []string
for _, rr := range resp.Answer {
if a, ok := rr.(*dns.A); ok {
ips = append(ips, a.A.String())
}
}
if len(ips) > 0 {
return ips, nil
}
}
return nil, fmt.Errorf(
"cannot resolve %s: %w", hostname, ErrNoNameservers,
)
}
// FindAuthoritativeNameservers traces the delegation chain from
// root servers to discover all authoritative nameservers for the
// given domain. Walks up the label hierarchy for subdomains.
func (r *Resolver) FindAuthoritativeNameservers(
ctx context.Context,
domain string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
domain = dns.Fqdn(strings.ToLower(domain))
labels := dns.SplitDomainName(domain)
for i := range labels {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
candidate := strings.Join(labels[i:], ".") + "."
nsNames, err := r.followDelegation(
ctx, candidate, rootServerList(),
)
if err == nil && len(nsNames) > 0 {
sort.Strings(nsNames)
return nsNames, nil
}
}
return nil, ErrNoNameservers
}
// QueryNameserver queries a specific nameserver for all record
// types and builds a NameserverResponse.
func (r *Resolver) QueryNameserver(
ctx context.Context,
nsHostname string,
hostname string,
) (*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
nsIPs, err := r.resolveARecord(ctx, nsHostname)
if err != nil {
return nil, fmt.Errorf("resolving NS %s: %w", nsHostname, err)
}
hostname = dns.Fqdn(hostname)
return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname)
}
func (r *Resolver) queryAllTypes(
ctx context.Context,
nsHostname string,
nsIP string,
hostname string,
) (*NameserverResponse, error) {
resp := &NameserverResponse{
Nameserver: nsHostname,
Records: make(map[string][]string),
Status: StatusOK,
}
qtypes := []uint16{
dns.TypeA, dns.TypeAAAA, dns.TypeCNAME,
dns.TypeMX, dns.TypeTXT, dns.TypeSRV,
dns.TypeCAA, dns.TypeNS,
}
state := r.queryEachType(ctx, nsIP, hostname, qtypes, resp)
classifyResponse(resp, state)
return resp, nil
}
type queryState struct {
gotNXDomain bool
gotSERVFAIL bool
hasRecords bool
}
func (r *Resolver) queryEachType(
ctx context.Context,
nsIP string,
hostname string,
qtypes []uint16,
resp *NameserverResponse,
) queryState {
var state queryState
for _, qtype := range qtypes {
if checkCtx(ctx) != nil {
break
}
r.querySingleType(ctx, nsIP, hostname, qtype, resp, &state)
}
for k := range resp.Records {
sort.Strings(resp.Records[k])
}
return state
}
func (r *Resolver) querySingleType(
ctx context.Context,
nsIP string,
hostname string,
qtype uint16,
resp *NameserverResponse,
state *queryState,
) {
msg, err := r.queryDNS(ctx, nsIP, hostname, qtype)
if err != nil {
return
}
if msg.Rcode == dns.RcodeNameError {
state.gotNXDomain = true
return
}
if msg.Rcode == dns.RcodeServerFailure {
state.gotSERVFAIL = true
return
}
collectAnswerRecords(msg, resp, state)
}
func collectAnswerRecords(
msg *dns.Msg,
resp *NameserverResponse,
state *queryState,
) {
for _, rr := range msg.Answer {
val := extractRecordValue(rr)
if val == "" {
continue
}
typeName := dns.TypeToString[rr.Header().Rrtype]
resp.Records[typeName] = append(
resp.Records[typeName], val,
)
state.hasRecords = true
}
}
func classifyResponse(resp *NameserverResponse, state queryState) {
switch {
case state.gotNXDomain && !state.hasRecords:
resp.Status = StatusNXDomain
case state.gotSERVFAIL && !state.hasRecords:
resp.Status = StatusError
case !state.hasRecords && !state.gotNXDomain:
resp.Status = StatusNoData
}
}
// extractRecordValue formats a DNS RR value as a string.
func extractRecordValue(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
return r.A.String()
case *dns.AAAA:
return r.AAAA.String()
case *dns.CNAME:
return r.Target
case *dns.MX:
return fmt.Sprintf("%d %s", r.Preference, r.Mx)
case *dns.TXT:
return strings.Join(r.Txt, "")
case *dns.SRV:
return fmt.Sprintf(
"%d %d %d %s",
r.Priority, r.Weight, r.Port, r.Target,
)
case *dns.CAA:
return fmt.Sprintf(
"%d %s \"%s\"", r.Flag, r.Tag, r.Value,
)
case *dns.NS:
return r.Ns
default:
return ""
}
}
// parentDomain returns the registerable parent domain.
func parentDomain(hostname string) string {
hostname = dns.Fqdn(strings.ToLower(hostname))
labels := dns.SplitDomainName(hostname)
if len(labels) <= minDomainLabels {
return strings.Join(labels, ".") + "."
}
return strings.Join(
labels[len(labels)-minDomainLabels:], ".",
) + "."
}
// QueryAllNameservers discovers auth NSes for the hostname's
// parent domain, then queries each one independently.
func (r *Resolver) QueryAllNameservers(
ctx context.Context,
hostname string,
) (map[string]*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
parent := parentDomain(hostname)
nameservers, err := r.FindAuthoritativeNameservers(ctx, parent)
if err != nil {
return nil, err
}
return r.queryEachNS(ctx, nameservers, hostname)
}
func (r *Resolver) queryEachNS(
ctx context.Context,
nameservers []string,
hostname string,
) (map[string]*NameserverResponse, error) {
results := make(map[string]*NameserverResponse)
for _, ns := range nameservers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.QueryNameserver(ctx, ns, hostname)
if err != nil {
results[ns] = &NameserverResponse{
Nameserver: ns,
Records: make(map[string][]string),
Status: StatusError,
Error: err.Error(),
}
continue
}
results[ns] = resp
}
return results, nil
}
// LookupNS returns the NS record set for a domain.
func (r *Resolver) LookupNS(
ctx context.Context,
domain string,
) ([]string, error) {
return r.FindAuthoritativeNameservers(ctx, domain)
}
// LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname, keyed by authoritative nameserver.
func (r *Resolver) LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error) {
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
out := make(map[string]map[string][]string, len(results))
for ns, resp := range results {
out[ns] = resp.Records
}
return out, nil
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains up to MaxCNAMEDepth.
func (r *Resolver) ResolveIPAddresses(
ctx context.Context,
hostname string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
return r.resolveIPWithCNAME(ctx, hostname, 0)
}
func (r *Resolver) resolveIPWithCNAME(
ctx context.Context,
hostname string,
depth int,
) ([]string, error) {
if depth > MaxCNAMEDepth {
return nil, ErrCNAMEDepthExceeded
}
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
ips, cnameTarget := collectIPs(results)
if len(ips) == 0 && cnameTarget != "" {
return r.resolveIPWithCNAME(ctx, cnameTarget, depth+1)
}
sort.Strings(ips)
return ips, nil
}
func collectIPs(
results map[string]*NameserverResponse,
) ([]string, string) {
seen := make(map[string]bool)
var ips []string
var cnameTarget string
for _, resp := range results {
if resp.Status == StatusNXDomain {
continue
}
for _, ip := range resp.Records["A"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
for _, ip := range resp.Records["AAAA"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
if len(resp.Records["CNAME"]) > 0 && cnameTarget == "" {
cnameTarget = resp.Records["CNAME"][0]
}
}
return ips, cnameTarget
}