package resolver

import (
	"context"
	"errors"
	"fmt"
	"net"
	"sort"
	"strings"
	"time"

	"github.com/miekg/dns"
)

const (
	// Judging by their names, a per-query timeout and an escalation factor
	// between attempts. NOTE(review): neither is referenced by the code in
	// this file; presumably consumed where the DNS clients are built —
	// verify, or drop them if dead.
	queryTimeoutDuration = 5 * time.Second
	timeoutMultiplier    = 2

	// maxRetries is how many exchange attempts tryExchange makes.
	maxRetries = 2

	// maxDelegation bounds the number of referral hops followed while
	// walking down from the root.
	maxDelegation = 20

	// minDomainLabels is the number of trailing labels parentDomain keeps.
	minDomainLabels = 2
)

// ErrRefused is returned when a DNS server refuses a query.
var ErrRefused = errors.New("dns query refused")

// rootServerList returns the IPv4 addresses of the thirteen root name
// servers, a.root-servers.net through m.root-servers.net in letter order.
// A new slice is built on every call, so callers may slice or modify it.
func rootServerList() []string {
	return []string{
		"198.41.0.4",     // a
		"170.247.170.2",  // b
		"192.33.4.12",    // c
		"199.7.91.13",    // d
		"192.203.230.10", // e
		"192.5.5.241",    // f
		"192.112.36.4",   // g
		"198.97.190.53",  // h
		"192.36.148.17",  // i
		"192.58.128.30",  // j
		"193.0.14.129",   // k
		"199.7.83.42",    // l
		"202.12.27.33",   // m
	}
}

// checkCtx returns ErrContextCanceled once ctx is done and nil otherwise.
// The specific ctx.Err() value (canceled vs. deadline) is not preserved.
func checkCtx(ctx context.Context) error {
	if ctx.Err() == nil {
		return nil
	}
	return ErrContextCanceled
}

// exchangeWithTimeout performs one exchange of msg with addr using the
// resolver's primary client. The attempt number is accepted but ignored:
// no per-attempt timeout is applied here.
func (r *Resolver) exchangeWithTimeout(
	ctx context.Context,
	msg *dns.Msg,
	addr string,
	attempt int,
) (*dns.Msg, error) {
	// Timeout escalation is left to the client configuration.
	_ = attempt
	reply, _, err := r.client.ExchangeContext(ctx, msg, addr)
	return reply, err
}

// tryExchange sends msg to addr up to maxRetries times and returns as soon
// as an attempt completes without error. If ctx is done before an attempt
// it returns ErrContextCanceled. When every attempt fails it returns the
// reply and error of the last attempt.
func (r *Resolver) tryExchange(
	ctx context.Context,
	msg *dns.Msg,
	addr string,
) (*dns.Msg, error) {
	var (
		reply   *dns.Msg
		lastErr error
	)
	for attempt := 0; attempt < maxRetries; attempt++ {
		if checkCtx(ctx) != nil {
			return nil, ErrContextCanceled
		}
		reply, lastErr = r.exchangeWithTimeout(ctx, msg, addr, attempt)
		if lastErr == nil {
			return reply, nil
		}
	}
	return reply, lastErr
}

// retryTCP re-sends msg over the resolver's TCP client when resp has the
// truncated (TC) bit set. The TCP reply replaces resp only when that
// exchange succeeds. Otherwise, and for replies that were not truncated,
// resp is returned unchanged.
func (r *Resolver) retryTCP(
	ctx context.Context,
	msg *dns.Msg,
	addr string,
	resp *dns.Msg,
) *dns.Msg {
	if resp.Truncated {
		if full, _, err := r.tcp.ExchangeContext(ctx, msg, addr); err == nil {
			return full
		}
	}
	return resp
}

// queryDNS sends the question (name, qtype) to serverIP on port 53. It
// first asks non-recursively. If the server answers REFUSED, it asks again
// with recursion desired, which copes with environments that intercept DNS
// traffic. A second REFUSED yields ErrRefused. A truncated reply is retried
// over TCP.
func (r *Resolver) queryDNS( ctx context.Context, serverIP string, name string, qtype uint16, ) (*dns.Msg, error) { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } name = dns.Fqdn(name) addr := net.JoinHostPort(serverIP, "53") msg := new(dns.Msg) msg.SetQuestion(name, qtype) msg.RecursionDesired = false resp, err := r.tryExchange(ctx, msg, addr) if err != nil { return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err) } if resp.Rcode == dns.RcodeRefused { msg.RecursionDesired = true resp, err = r.tryExchange(ctx, msg, addr) if err != nil { return nil, fmt.Errorf( "query %s @%s: %w", name, serverIP, err, ) } if resp.Rcode == dns.RcodeRefused { return nil, fmt.Errorf( "query %s @%s: %w", name, serverIP, ErrRefused, ) } } resp = r.retryTCP(ctx, msg, addr, resp) return resp, nil } func extractNSSet(rrs []dns.RR) []string { nsSet := make(map[string]bool) for _, rr := range rrs { if ns, ok := rr.(*dns.NS); ok { nsSet[strings.ToLower(ns.Ns)] = true } } names := make([]string, 0, len(nsSet)) for n := range nsSet { names = append(names, n) } sort.Strings(names) return names } func extractGlue(rrs []dns.RR) map[string][]net.IP { glue := make(map[string][]net.IP) for _, rr := range rrs { switch r := rr.(type) { case *dns.A: name := strings.ToLower(r.Hdr.Name) glue[name] = append(glue[name], r.A) case *dns.AAAA: name := strings.ToLower(r.Hdr.Name) glue[name] = append(glue[name], r.AAAA) } } return glue } func glueIPs(nsNames []string, glue map[string][]net.IP) []string { var ips []string for _, ns := range nsNames { for _, addr := range glue[ns] { if v4 := addr.To4(); v4 != nil { ips = append(ips, v4.String()) } } } return ips } func (r *Resolver) followDelegation( ctx context.Context, domain string, servers []string, ) ([]string, error) { for range maxDelegation { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } resp, err := r.queryServers( ctx, servers, domain, dns.TypeNS, ) if err != nil { return nil, err } ansNS := extractNSSet(resp.Answer) if 
len(ansNS) > 0 { return ansNS, nil } authNS := extractNSSet(resp.Ns) if len(authNS) == 0 { return r.resolveNSRecursive(ctx, domain) } glue := extractGlue(resp.Extra) nextServers := glueIPs(authNS, glue) if len(nextServers) == 0 { nextServers = r.resolveNSIPs(ctx, authNS) } if len(nextServers) == 0 { return nil, ErrNoNameservers } servers = nextServers } return nil, ErrNoNameservers } func (r *Resolver) queryServers( ctx context.Context, servers []string, name string, qtype uint16, ) (*dns.Msg, error) { var lastErr error for _, ip := range servers { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } resp, err := r.queryDNS(ctx, ip, name, qtype) if err == nil { return resp, nil } lastErr = err } return nil, fmt.Errorf("all servers failed: %w", lastErr) } func (r *Resolver) resolveNSIPs( ctx context.Context, nsNames []string, ) []string { var ips []string for _, ns := range nsNames { resolved, err := r.resolveARecord(ctx, ns) if err == nil { ips = append(ips, resolved...) } if len(ips) > 0 { break } } return ips } // resolveNSRecursive queries for NS records using recursive // resolution as a fallback for intercepted environments. func (r *Resolver) resolveNSRecursive( ctx context.Context, domain string, ) ([]string, error) { domain = dns.Fqdn(domain) msg := new(dns.Msg) msg.SetQuestion(domain, dns.TypeNS) msg.RecursionDesired = true for _, ip := range rootServerList()[:3] { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } addr := net.JoinHostPort(ip, "53") resp, _, err := r.client.ExchangeContext(ctx, msg, addr) if err != nil { continue } nsNames := extractNSSet(resp.Answer) if len(nsNames) > 0 { return nsNames, nil } } return nil, ErrNoNameservers } // resolveARecord resolves a hostname to IPv4 addresses. 
func (r *Resolver) resolveARecord( ctx context.Context, hostname string, ) ([]string, error) { hostname = dns.Fqdn(hostname) msg := new(dns.Msg) msg.SetQuestion(hostname, dns.TypeA) msg.RecursionDesired = true for _, ip := range rootServerList()[:3] { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } addr := net.JoinHostPort(ip, "53") resp, _, err := r.client.ExchangeContext(ctx, msg, addr) if err != nil { continue } var ips []string for _, rr := range resp.Answer { if a, ok := rr.(*dns.A); ok { ips = append(ips, a.A.String()) } } if len(ips) > 0 { return ips, nil } } return nil, fmt.Errorf( "cannot resolve %s: %w", hostname, ErrNoNameservers, ) } // FindAuthoritativeNameservers traces the delegation chain from // root servers to discover all authoritative nameservers for the // given domain. Walks up the label hierarchy for subdomains. func (r *Resolver) FindAuthoritativeNameservers( ctx context.Context, domain string, ) ([]string, error) { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } domain = dns.Fqdn(strings.ToLower(domain)) labels := dns.SplitDomainName(domain) for i := range labels { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } candidate := strings.Join(labels[i:], ".") + "." nsNames, err := r.followDelegation( ctx, candidate, rootServerList(), ) if err == nil && len(nsNames) > 0 { sort.Strings(nsNames) return nsNames, nil } } return nil, ErrNoNameservers } // QueryNameserver queries a specific nameserver for all record // types and builds a NameserverResponse. 
func (r *Resolver) QueryNameserver( ctx context.Context, nsHostname string, hostname string, ) (*NameserverResponse, error) { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } nsIPs, err := r.resolveARecord(ctx, nsHostname) if err != nil { return nil, fmt.Errorf("resolving NS %s: %w", nsHostname, err) } hostname = dns.Fqdn(hostname) return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname) } func (r *Resolver) queryAllTypes( ctx context.Context, nsHostname string, nsIP string, hostname string, ) (*NameserverResponse, error) { resp := &NameserverResponse{ Nameserver: nsHostname, Records: make(map[string][]string), Status: StatusOK, } qtypes := []uint16{ dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, dns.TypeMX, dns.TypeTXT, dns.TypeSRV, dns.TypeCAA, dns.TypeNS, } state := r.queryEachType(ctx, nsIP, hostname, qtypes, resp) classifyResponse(resp, state) return resp, nil } type queryState struct { gotNXDomain bool gotSERVFAIL bool hasRecords bool } func (r *Resolver) queryEachType( ctx context.Context, nsIP string, hostname string, qtypes []uint16, resp *NameserverResponse, ) queryState { var state queryState for _, qtype := range qtypes { if checkCtx(ctx) != nil { break } r.querySingleType(ctx, nsIP, hostname, qtype, resp, &state) } for k := range resp.Records { sort.Strings(resp.Records[k]) } return state } func (r *Resolver) querySingleType( ctx context.Context, nsIP string, hostname string, qtype uint16, resp *NameserverResponse, state *queryState, ) { msg, err := r.queryDNS(ctx, nsIP, hostname, qtype) if err != nil { return } if msg.Rcode == dns.RcodeNameError { state.gotNXDomain = true return } if msg.Rcode == dns.RcodeServerFailure { state.gotSERVFAIL = true return } collectAnswerRecords(msg, resp, state) } func collectAnswerRecords( msg *dns.Msg, resp *NameserverResponse, state *queryState, ) { for _, rr := range msg.Answer { val := extractRecordValue(rr) if val == "" { continue } typeName := dns.TypeToString[rr.Header().Rrtype] resp.Records[typeName] = 
append( resp.Records[typeName], val, ) state.hasRecords = true } } func classifyResponse(resp *NameserverResponse, state queryState) { switch { case state.gotNXDomain && !state.hasRecords: resp.Status = StatusNXDomain case state.gotSERVFAIL && !state.hasRecords: resp.Status = StatusError case !state.hasRecords && !state.gotNXDomain: resp.Status = StatusNoData } } // extractRecordValue formats a DNS RR value as a string. func extractRecordValue(rr dns.RR) string { switch r := rr.(type) { case *dns.A: return r.A.String() case *dns.AAAA: return r.AAAA.String() case *dns.CNAME: return r.Target case *dns.MX: return fmt.Sprintf("%d %s", r.Preference, r.Mx) case *dns.TXT: return strings.Join(r.Txt, "") case *dns.SRV: return fmt.Sprintf( "%d %d %d %s", r.Priority, r.Weight, r.Port, r.Target, ) case *dns.CAA: return fmt.Sprintf( "%d %s \"%s\"", r.Flag, r.Tag, r.Value, ) case *dns.NS: return r.Ns default: return "" } } // parentDomain returns the registerable parent domain. func parentDomain(hostname string) string { hostname = dns.Fqdn(strings.ToLower(hostname)) labels := dns.SplitDomainName(hostname) if len(labels) <= minDomainLabels { return strings.Join(labels, ".") + "." } return strings.Join( labels[len(labels)-minDomainLabels:], ".", ) + "." } // QueryAllNameservers discovers auth NSes for the hostname's // parent domain, then queries each one independently. 
func (r *Resolver) QueryAllNameservers( ctx context.Context, hostname string, ) (map[string]*NameserverResponse, error) { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } parent := parentDomain(hostname) nameservers, err := r.FindAuthoritativeNameservers(ctx, parent) if err != nil { return nil, err } return r.queryEachNS(ctx, nameservers, hostname) } func (r *Resolver) queryEachNS( ctx context.Context, nameservers []string, hostname string, ) (map[string]*NameserverResponse, error) { results := make(map[string]*NameserverResponse) for _, ns := range nameservers { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } resp, err := r.QueryNameserver(ctx, ns, hostname) if err != nil { results[ns] = &NameserverResponse{ Nameserver: ns, Records: make(map[string][]string), Status: StatusError, Error: err.Error(), } continue } results[ns] = resp } return results, nil } // LookupNS returns the NS record set for a domain. func (r *Resolver) LookupNS( ctx context.Context, domain string, ) ([]string, error) { return r.FindAuthoritativeNameservers(ctx, domain) } // LookupAllRecords performs iterative resolution to find all DNS // records for the given hostname, keyed by authoritative nameserver. func (r *Resolver) LookupAllRecords( ctx context.Context, hostname string, ) (map[string]map[string][]string, error) { results, err := r.QueryAllNameservers(ctx, hostname) if err != nil { return nil, err } out := make(map[string]map[string][]string, len(results)) for ns, resp := range results { out[ns] = resp.Records } return out, nil } // ResolveIPAddresses resolves a hostname to all IPv4 and IPv6 // addresses, following CNAME chains up to MaxCNAMEDepth. 
func (r *Resolver) ResolveIPAddresses( ctx context.Context, hostname string, ) ([]string, error) { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } return r.resolveIPWithCNAME(ctx, hostname, 0) } func (r *Resolver) resolveIPWithCNAME( ctx context.Context, hostname string, depth int, ) ([]string, error) { if depth > MaxCNAMEDepth { return nil, ErrCNAMEDepthExceeded } results, err := r.QueryAllNameservers(ctx, hostname) if err != nil { return nil, err } ips, cnameTarget := collectIPs(results) if len(ips) == 0 && cnameTarget != "" { return r.resolveIPWithCNAME(ctx, cnameTarget, depth+1) } sort.Strings(ips) return ips, nil } func collectIPs( results map[string]*NameserverResponse, ) ([]string, string) { seen := make(map[string]bool) var ips []string var cnameTarget string for _, resp := range results { if resp.Status == StatusNXDomain { continue } for _, ip := range resp.Records["A"] { if !seen[ip] { seen[ip] = true ips = append(ips, ip) } } for _, ip := range resp.Records["AAAA"] { if !seen[ip] { seen[ip] = true ips = append(ips, ip) } } if len(resp.Records["CNAME"]) > 0 && cnameTarget == "" { cnameTarget = resp.Records["CNAME"][0] } } return ips, cnameTarget }