Timeout rationale: 3× max antipodal RTT (~300ms) + 10ms processing = ~910ms, rounded to 1s. Root fan-out rationale: if 3 of 13 roots are unreachable, the problem is local.
746 lines
15 KiB
Go
746 lines
15 KiB
Go
package resolver
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"math/rand/v2"
|
||
"net"
|
||
"sort"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
const (
	// queryTimeoutDuration is the timeout budget for one DNS exchange.
	//
	// Rationale: the round trip to an antipodal root/TLD server peaks at
	// roughly 300ms. Three such round trips plus 10ms of processing come
	// to about 910ms, which is rounded up to 1s. Together with
	// maxRetries=2 (3 attempts in total) a single server costs at most
	// 3s before the caller fails over to the next one.
	queryTimeoutDuration = time.Second

	// maxRetries is the retry budget per server (see the rationale above).
	maxRetries = 2

	// maxDelegation caps how many referrals followDelegation will chase
	// before giving up.
	maxDelegation = 20

	// timeoutMultiplier is not referenced in this part of the file.
	// NOTE(review): presumably the per-retry timeout growth factor —
	// confirm against its users.
	timeoutMultiplier = 2

	// minDomainLabels is the number of trailing labels parentDomain keeps.
	minDomainLabels = 2
)
|
||
|
||
// ErrRefused is the sentinel reported when a DNS server answers REFUSED.
// queryDNS wraps it once both the non-recursive query and the
// recursion-desired fallback have been refused; match it with errors.Is.
var ErrRefused = errors.New("dns query refused")
|
||
|
||
// allRootServers returns the IPv4 addresses of the 13 root name servers
// (a through m, in that order). Every call yields a fresh slice, so the
// caller is free to reorder or modify it.
func allRootServers() []string {
	roots := [...]string{
		"198.41.0.4",     // a.root-servers.net
		"170.247.170.2",  // b.root-servers.net
		"192.33.4.12",    // c.root-servers.net
		"199.7.91.13",    // d.root-servers.net
		"192.203.230.10", // e.root-servers.net
		"192.5.5.241",    // f.root-servers.net
		"192.112.36.4",   // g.root-servers.net
		"198.97.190.53",  // h.root-servers.net
		"192.36.148.17",  // i.root-servers.net
		"192.58.128.30",  // j.root-servers.net
		"193.0.14.129",   // k.root-servers.net
		"199.7.83.42",    // l.root-servers.net
		"202.12.27.33",   // m.root-servers.net
	}

	return roots[:]
}
|
||
|
||
// rootServerList returns 3 randomly-selected root servers.
|
||
// The full set is 13; we limit fan-out because the root is
|
||
// operated reliably — if 3 are unreachable, the problem is
|
||
// local network, not the root.
|
||
func rootServerList() []string {
|
||
shuffled := allRootServers()
|
||
rand.Shuffle(len(shuffled), func(i, j int) {
|
||
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
|
||
})
|
||
|
||
return shuffled[:3]
|
||
}
|
||
|
||
func checkCtx(ctx context.Context) error {
|
||
err := ctx.Err()
|
||
if err != nil {
|
||
return ErrContextCanceled
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Resolver) exchangeWithTimeout(
|
||
ctx context.Context,
|
||
msg *dns.Msg,
|
||
addr string,
|
||
attempt int,
|
||
) (*dns.Msg, error) {
|
||
_ = attempt // timeout escalation handled by client config
|
||
|
||
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
|
||
|
||
return resp, err
|
||
}
|
||
|
||
func (r *Resolver) tryExchange(
|
||
ctx context.Context,
|
||
msg *dns.Msg,
|
||
addr string,
|
||
) (*dns.Msg, error) {
|
||
var resp *dns.Msg
|
||
|
||
var err error
|
||
|
||
for attempt := range maxRetries {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
resp, err = r.exchangeWithTimeout(
|
||
ctx, msg, addr, attempt,
|
||
)
|
||
if err == nil {
|
||
break
|
||
}
|
||
}
|
||
|
||
return resp, err
|
||
}
|
||
|
||
func (r *Resolver) retryTCP(
|
||
ctx context.Context,
|
||
msg *dns.Msg,
|
||
addr string,
|
||
resp *dns.Msg,
|
||
) *dns.Msg {
|
||
if !resp.Truncated {
|
||
return resp
|
||
}
|
||
|
||
tcpResp, _, tcpErr := r.tcp.ExchangeContext(ctx, msg, addr)
|
||
if tcpErr == nil {
|
||
return tcpResp
|
||
}
|
||
|
||
return resp
|
||
}
|
||
|
||
// queryDNS sends a DNS query to a specific server IP.
|
||
// Tries non-recursive first, falls back to recursive on
|
||
// REFUSED (handles DNS interception environments).
|
||
func (r *Resolver) queryDNS(
|
||
ctx context.Context,
|
||
serverIP string,
|
||
name string,
|
||
qtype uint16,
|
||
) (*dns.Msg, error) {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
name = dns.Fqdn(name)
|
||
addr := net.JoinHostPort(serverIP, "53")
|
||
|
||
msg := new(dns.Msg)
|
||
msg.SetQuestion(name, qtype)
|
||
msg.RecursionDesired = false
|
||
|
||
resp, err := r.tryExchange(ctx, msg, addr)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err)
|
||
}
|
||
|
||
if resp.Rcode == dns.RcodeRefused {
|
||
msg.RecursionDesired = true
|
||
|
||
resp, err = r.tryExchange(ctx, msg, addr)
|
||
if err != nil {
|
||
return nil, fmt.Errorf(
|
||
"query %s @%s: %w", name, serverIP, err,
|
||
)
|
||
}
|
||
|
||
if resp.Rcode == dns.RcodeRefused {
|
||
return nil, fmt.Errorf(
|
||
"query %s @%s: %w", name, serverIP, ErrRefused,
|
||
)
|
||
}
|
||
}
|
||
|
||
resp = r.retryTCP(ctx, msg, addr, resp)
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
func extractNSSet(rrs []dns.RR) []string {
|
||
nsSet := make(map[string]bool)
|
||
|
||
for _, rr := range rrs {
|
||
if ns, ok := rr.(*dns.NS); ok {
|
||
nsSet[strings.ToLower(ns.Ns)] = true
|
||
}
|
||
}
|
||
|
||
names := make([]string, 0, len(nsSet))
|
||
for n := range nsSet {
|
||
names = append(names, n)
|
||
}
|
||
|
||
sort.Strings(names)
|
||
|
||
return names
|
||
}
|
||
|
||
func extractGlue(rrs []dns.RR) map[string][]net.IP {
|
||
glue := make(map[string][]net.IP)
|
||
|
||
for _, rr := range rrs {
|
||
switch r := rr.(type) {
|
||
case *dns.A:
|
||
name := strings.ToLower(r.Hdr.Name)
|
||
glue[name] = append(glue[name], r.A)
|
||
case *dns.AAAA:
|
||
name := strings.ToLower(r.Hdr.Name)
|
||
glue[name] = append(glue[name], r.AAAA)
|
||
}
|
||
}
|
||
|
||
return glue
|
||
}
|
||
|
||
// glueIPs returns, in nsNames order, the IPv4 glue addresses known for
// each nameserver as dotted-quad strings. Addresses with no 4-byte form
// (IPv6 glue) and names missing from glue are skipped. The result is nil
// when nothing matches.
func glueIPs(nsNames []string, glue map[string][]net.IP) []string {
	var out []string

	for i := range nsNames {
		candidates := glue[nsNames[i]]

		for j := range candidates {
			ipv4 := candidates[j].To4()
			if ipv4 == nil {
				continue
			}

			out = append(out, ipv4.String())
		}
	}

	return out
}
|
||
|
||
func (r *Resolver) followDelegation(
|
||
ctx context.Context,
|
||
domain string,
|
||
servers []string,
|
||
) ([]string, error) {
|
||
for range maxDelegation {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
resp, err := r.queryServers(
|
||
ctx, servers, domain, dns.TypeNS,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
ansNS := extractNSSet(resp.Answer)
|
||
if len(ansNS) > 0 {
|
||
return ansNS, nil
|
||
}
|
||
|
||
authNS := extractNSSet(resp.Ns)
|
||
if len(authNS) == 0 {
|
||
return r.resolveNSRecursive(ctx, domain)
|
||
}
|
||
|
||
glue := extractGlue(resp.Extra)
|
||
nextServers := glueIPs(authNS, glue)
|
||
|
||
if len(nextServers) == 0 {
|
||
nextServers = r.resolveNSIPs(ctx, authNS)
|
||
}
|
||
|
||
if len(nextServers) == 0 {
|
||
return nil, ErrNoNameservers
|
||
}
|
||
|
||
servers = nextServers
|
||
}
|
||
|
||
return nil, ErrNoNameservers
|
||
}
|
||
|
||
func (r *Resolver) queryServers(
|
||
ctx context.Context,
|
||
servers []string,
|
||
name string,
|
||
qtype uint16,
|
||
) (*dns.Msg, error) {
|
||
var lastErr error
|
||
|
||
for _, ip := range servers {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
resp, err := r.queryDNS(ctx, ip, name, qtype)
|
||
if err == nil {
|
||
return resp, nil
|
||
}
|
||
|
||
lastErr = err
|
||
}
|
||
|
||
return nil, fmt.Errorf("all servers failed: %w", lastErr)
|
||
}
|
||
|
||
func (r *Resolver) resolveNSIPs(
|
||
ctx context.Context,
|
||
nsNames []string,
|
||
) []string {
|
||
var ips []string
|
||
|
||
for _, ns := range nsNames {
|
||
resolved, err := r.resolveARecord(ctx, ns)
|
||
if err == nil {
|
||
ips = append(ips, resolved...)
|
||
}
|
||
|
||
if len(ips) > 0 {
|
||
break
|
||
}
|
||
}
|
||
|
||
return ips
|
||
}
|
||
|
||
// resolveNSRecursive queries for NS records using recursive
|
||
// resolution as a fallback for intercepted environments.
|
||
func (r *Resolver) resolveNSRecursive(
|
||
ctx context.Context,
|
||
domain string,
|
||
) ([]string, error) {
|
||
domain = dns.Fqdn(domain)
|
||
msg := new(dns.Msg)
|
||
msg.SetQuestion(domain, dns.TypeNS)
|
||
msg.RecursionDesired = true
|
||
|
||
for _, ip := range rootServerList() {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
addr := net.JoinHostPort(ip, "53")
|
||
|
||
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
nsNames := extractNSSet(resp.Answer)
|
||
if len(nsNames) > 0 {
|
||
return nsNames, nil
|
||
}
|
||
}
|
||
|
||
return nil, ErrNoNameservers
|
||
}
|
||
|
||
// resolveARecord resolves a hostname to IPv4 addresses.
|
||
func (r *Resolver) resolveARecord(
|
||
ctx context.Context,
|
||
hostname string,
|
||
) ([]string, error) {
|
||
hostname = dns.Fqdn(hostname)
|
||
msg := new(dns.Msg)
|
||
msg.SetQuestion(hostname, dns.TypeA)
|
||
msg.RecursionDesired = true
|
||
|
||
for _, ip := range rootServerList() {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
addr := net.JoinHostPort(ip, "53")
|
||
|
||
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
var ips []string
|
||
|
||
for _, rr := range resp.Answer {
|
||
if a, ok := rr.(*dns.A); ok {
|
||
ips = append(ips, a.A.String())
|
||
}
|
||
}
|
||
|
||
if len(ips) > 0 {
|
||
return ips, nil
|
||
}
|
||
}
|
||
|
||
return nil, fmt.Errorf(
|
||
"cannot resolve %s: %w", hostname, ErrNoNameservers,
|
||
)
|
||
}
|
||
|
||
// FindAuthoritativeNameservers traces the delegation chain from
|
||
// root servers to discover all authoritative nameservers for the
|
||
// given domain. Walks up the label hierarchy for subdomains.
|
||
func (r *Resolver) FindAuthoritativeNameservers(
|
||
ctx context.Context,
|
||
domain string,
|
||
) ([]string, error) {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
domain = dns.Fqdn(strings.ToLower(domain))
|
||
labels := dns.SplitDomainName(domain)
|
||
|
||
for i := range labels {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
candidate := strings.Join(labels[i:], ".") + "."
|
||
|
||
nsNames, err := r.followDelegation(
|
||
ctx, candidate, rootServerList(),
|
||
)
|
||
if err == nil && len(nsNames) > 0 {
|
||
sort.Strings(nsNames)
|
||
|
||
return nsNames, nil
|
||
}
|
||
}
|
||
|
||
return nil, ErrNoNameservers
|
||
}
|
||
|
||
// QueryNameserver queries a specific nameserver for all record
|
||
// types and builds a NameserverResponse.
|
||
func (r *Resolver) QueryNameserver(
|
||
ctx context.Context,
|
||
nsHostname string,
|
||
hostname string,
|
||
) (*NameserverResponse, error) {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
nsIPs, err := r.resolveARecord(ctx, nsHostname)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("resolving NS %s: %w", nsHostname, err)
|
||
}
|
||
|
||
hostname = dns.Fqdn(hostname)
|
||
|
||
return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname)
|
||
}
|
||
|
||
func (r *Resolver) queryAllTypes(
|
||
ctx context.Context,
|
||
nsHostname string,
|
||
nsIP string,
|
||
hostname string,
|
||
) (*NameserverResponse, error) {
|
||
resp := &NameserverResponse{
|
||
Nameserver: nsHostname,
|
||
Records: make(map[string][]string),
|
||
Status: StatusOK,
|
||
}
|
||
|
||
qtypes := []uint16{
|
||
dns.TypeA, dns.TypeAAAA, dns.TypeCNAME,
|
||
dns.TypeMX, dns.TypeTXT, dns.TypeSRV,
|
||
dns.TypeCAA, dns.TypeNS,
|
||
}
|
||
|
||
state := r.queryEachType(ctx, nsIP, hostname, qtypes, resp)
|
||
classifyResponse(resp, state)
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// queryState accumulates what was observed while querying one
// nameserver for several record types; classifyResponse turns it into a
// response status.
type queryState struct {
	// gotNXDomain is set when a reply carried the NXDOMAIN rcode.
	gotNXDomain bool
	// gotSERVFAIL is set when a reply carried the SERVFAIL rcode.
	gotSERVFAIL bool
	// hasRecords is set once at least one answer record has been stored.
	hasRecords bool
}
|
||
|
||
func (r *Resolver) queryEachType(
|
||
ctx context.Context,
|
||
nsIP string,
|
||
hostname string,
|
||
qtypes []uint16,
|
||
resp *NameserverResponse,
|
||
) queryState {
|
||
var state queryState
|
||
|
||
for _, qtype := range qtypes {
|
||
if checkCtx(ctx) != nil {
|
||
break
|
||
}
|
||
|
||
r.querySingleType(ctx, nsIP, hostname, qtype, resp, &state)
|
||
}
|
||
|
||
for k := range resp.Records {
|
||
sort.Strings(resp.Records[k])
|
||
}
|
||
|
||
return state
|
||
}
|
||
|
||
func (r *Resolver) querySingleType(
|
||
ctx context.Context,
|
||
nsIP string,
|
||
hostname string,
|
||
qtype uint16,
|
||
resp *NameserverResponse,
|
||
state *queryState,
|
||
) {
|
||
msg, err := r.queryDNS(ctx, nsIP, hostname, qtype)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
if msg.Rcode == dns.RcodeNameError {
|
||
state.gotNXDomain = true
|
||
|
||
return
|
||
}
|
||
|
||
if msg.Rcode == dns.RcodeServerFailure {
|
||
state.gotSERVFAIL = true
|
||
|
||
return
|
||
}
|
||
|
||
collectAnswerRecords(msg, resp, state)
|
||
}
|
||
|
||
func collectAnswerRecords(
|
||
msg *dns.Msg,
|
||
resp *NameserverResponse,
|
||
state *queryState,
|
||
) {
|
||
for _, rr := range msg.Answer {
|
||
val := extractRecordValue(rr)
|
||
if val == "" {
|
||
continue
|
||
}
|
||
|
||
typeName := dns.TypeToString[rr.Header().Rrtype]
|
||
resp.Records[typeName] = append(
|
||
resp.Records[typeName], val,
|
||
)
|
||
state.hasRecords = true
|
||
}
|
||
}
|
||
|
||
func classifyResponse(resp *NameserverResponse, state queryState) {
|
||
switch {
|
||
case state.gotNXDomain && !state.hasRecords:
|
||
resp.Status = StatusNXDomain
|
||
case state.gotSERVFAIL && !state.hasRecords:
|
||
resp.Status = StatusError
|
||
case !state.hasRecords && !state.gotNXDomain:
|
||
resp.Status = StatusNoData
|
||
}
|
||
}
|
||
|
||
// extractRecordValue formats a DNS RR value as a string.
|
||
func extractRecordValue(rr dns.RR) string {
|
||
switch r := rr.(type) {
|
||
case *dns.A:
|
||
return r.A.String()
|
||
case *dns.AAAA:
|
||
return r.AAAA.String()
|
||
case *dns.CNAME:
|
||
return r.Target
|
||
case *dns.MX:
|
||
return fmt.Sprintf("%d %s", r.Preference, r.Mx)
|
||
case *dns.TXT:
|
||
return strings.Join(r.Txt, "")
|
||
case *dns.SRV:
|
||
return fmt.Sprintf(
|
||
"%d %d %d %s",
|
||
r.Priority, r.Weight, r.Port, r.Target,
|
||
)
|
||
case *dns.CAA:
|
||
return fmt.Sprintf(
|
||
"%d %s \"%s\"", r.Flag, r.Tag, r.Value,
|
||
)
|
||
case *dns.NS:
|
||
return r.Ns
|
||
default:
|
||
return ""
|
||
}
|
||
}
|
||
|
||
// parentDomain returns the registerable parent domain.
|
||
func parentDomain(hostname string) string {
|
||
hostname = dns.Fqdn(strings.ToLower(hostname))
|
||
labels := dns.SplitDomainName(hostname)
|
||
|
||
if len(labels) <= minDomainLabels {
|
||
return strings.Join(labels, ".") + "."
|
||
}
|
||
|
||
return strings.Join(
|
||
labels[len(labels)-minDomainLabels:], ".",
|
||
) + "."
|
||
}
|
||
|
||
// QueryAllNameservers discovers auth NSes for the hostname's
|
||
// parent domain, then queries each one independently.
|
||
func (r *Resolver) QueryAllNameservers(
|
||
ctx context.Context,
|
||
hostname string,
|
||
) (map[string]*NameserverResponse, error) {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
parent := parentDomain(hostname)
|
||
|
||
nameservers, err := r.FindAuthoritativeNameservers(ctx, parent)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return r.queryEachNS(ctx, nameservers, hostname)
|
||
}
|
||
|
||
func (r *Resolver) queryEachNS(
|
||
ctx context.Context,
|
||
nameservers []string,
|
||
hostname string,
|
||
) (map[string]*NameserverResponse, error) {
|
||
results := make(map[string]*NameserverResponse)
|
||
|
||
for _, ns := range nameservers {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
resp, err := r.QueryNameserver(ctx, ns, hostname)
|
||
if err != nil {
|
||
results[ns] = &NameserverResponse{
|
||
Nameserver: ns,
|
||
Records: make(map[string][]string),
|
||
Status: StatusError,
|
||
Error: err.Error(),
|
||
}
|
||
|
||
continue
|
||
}
|
||
|
||
results[ns] = resp
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// LookupNS returns the NS record set for a domain.
|
||
func (r *Resolver) LookupNS(
|
||
ctx context.Context,
|
||
domain string,
|
||
) ([]string, error) {
|
||
return r.FindAuthoritativeNameservers(ctx, domain)
|
||
}
|
||
|
||
// LookupAllRecords performs iterative resolution to find all DNS
|
||
// records for the given hostname, keyed by authoritative nameserver.
|
||
func (r *Resolver) LookupAllRecords(
|
||
ctx context.Context,
|
||
hostname string,
|
||
) (map[string]map[string][]string, error) {
|
||
results, err := r.QueryAllNameservers(ctx, hostname)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
out := make(map[string]map[string][]string, len(results))
|
||
for ns, resp := range results {
|
||
out[ns] = resp.Records
|
||
}
|
||
|
||
return out, nil
|
||
}
|
||
|
||
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
|
||
// addresses, following CNAME chains up to MaxCNAMEDepth.
|
||
func (r *Resolver) ResolveIPAddresses(
|
||
ctx context.Context,
|
||
hostname string,
|
||
) ([]string, error) {
|
||
if checkCtx(ctx) != nil {
|
||
return nil, ErrContextCanceled
|
||
}
|
||
|
||
return r.resolveIPWithCNAME(ctx, hostname, 0)
|
||
}
|
||
|
||
func (r *Resolver) resolveIPWithCNAME(
|
||
ctx context.Context,
|
||
hostname string,
|
||
depth int,
|
||
) ([]string, error) {
|
||
if depth > MaxCNAMEDepth {
|
||
return nil, ErrCNAMEDepthExceeded
|
||
}
|
||
|
||
results, err := r.QueryAllNameservers(ctx, hostname)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
ips, cnameTarget := collectIPs(results)
|
||
|
||
if len(ips) == 0 && cnameTarget != "" {
|
||
return r.resolveIPWithCNAME(ctx, cnameTarget, depth+1)
|
||
}
|
||
|
||
sort.Strings(ips)
|
||
|
||
return ips, nil
|
||
}
|
||
|
||
func collectIPs(
|
||
results map[string]*NameserverResponse,
|
||
) ([]string, string) {
|
||
seen := make(map[string]bool)
|
||
|
||
var ips []string
|
||
|
||
var cnameTarget string
|
||
|
||
for _, resp := range results {
|
||
if resp.Status == StatusNXDomain {
|
||
continue
|
||
}
|
||
|
||
for _, ip := range resp.Records["A"] {
|
||
if !seen[ip] {
|
||
seen[ip] = true
|
||
ips = append(ips, ip)
|
||
}
|
||
}
|
||
|
||
for _, ip := range resp.Records["AAAA"] {
|
||
if !seen[ip] {
|
||
seen[ip] = true
|
||
ips = append(ips, ip)
|
||
}
|
||
}
|
||
|
||
if len(resp.Records["CNAME"]) > 0 && cnameTarget == "" {
|
||
cnameTarget = resp.Records["CNAME"][0]
|
||
}
|
||
}
|
||
|
||
return ips, cnameTarget
|
||
}
|