diff --git a/README.md b/README.md index 972a8ac..0a9d555 100644 --- a/README.md +++ b/README.md @@ -376,6 +376,13 @@ docker run -d \ --- +## Planned Future Features (Post-1.0) + +- **DNSSEC validation**: Validate the DNSSEC chain of trust during + iterative resolution and report DNSSEC failures as notifications. + +--- + ## Project Structure Follows the conventions defined in `CONVENTIONS.md`, adapted from the diff --git a/go.mod b/go.mod index 32ad532..58794b3 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/cors v1.2.2 github.com/joho/godotenv v1.5.1 + github.com/miekg/dns v1.1.72 github.com/prometheus/client_golang v1.23.2 github.com/spf13/viper v1.21.0 + github.com/stretchr/testify v1.11.1 go.uber.org/fx v1.24.0 golang.org/x/net v0.50.0 ) @@ -16,10 +18,12 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect @@ -34,7 +38,11 @@ require ( go.uber.org/zap v1.26.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/mod v0.32.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect + golang.org/x/tools v0.41.0 // indirect google.golang.org/protobuf v1.36.8 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 66cc528..720b18f 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= @@ -74,12 +76,18 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/resolver/dns_client.go b/internal/resolver/dns_client.go new file mode 100644 index 0000000..589c657 --- /dev/null +++ b/internal/resolver/dns_client.go @@ -0,0 +1,48 @@ +package resolver + +import ( + "context" + "time" + + "github.com/miekg/dns" +) + +// DNSClient abstracts DNS wire-protocol exchanges so the resolver +// can be tested without hitting real nameservers. +type DNSClient interface { + ExchangeContext( + ctx context.Context, + msg *dns.Msg, + addr string, + ) (*dns.Msg, time.Duration, error) +} + +// udpClient wraps a real dns.Client for production use. +type udpClient struct { + timeout time.Duration +} + +func (c *udpClient) ExchangeContext( + ctx context.Context, + msg *dns.Msg, + addr string, +) (*dns.Msg, time.Duration, error) { + cl := &dns.Client{Timeout: c.timeout} + + return cl.ExchangeContext(ctx, msg, addr) +} + +// tcpClient wraps a real dns.Client using TCP. +type tcpClient struct { + timeout time.Duration +} + +func (c *tcpClient) ExchangeContext( + ctx context.Context, + msg *dns.Msg, + addr string, +) (*dns.Msg, time.Duration, error) { + cl := &dns.Client{Net: "tcp", Timeout: c.timeout} + + return cl.ExchangeContext(ctx, msg, addr) +} diff --git a/internal/resolver/errors.go b/internal/resolver/errors.go new file mode 100644 index 0000000..94bc313 --- /dev/null +++ b/internal/resolver/errors.go @@ -0,0 +1,27 @@ +package resolver + +import "errors" + +// Sentinel errors returned by the resolver. +var ( + // ErrNotImplemented indicates a method is stubbed out. + ErrNotImplemented = errors.New( + "resolver not yet implemented", + ) + + // ErrNoNameservers is returned when no authoritative NS + // could be discovered for a domain. + ErrNoNameservers = errors.New( + "no authoritative nameservers found", + ) + + // ErrCNAMEDepthExceeded is returned when a CNAME chain + // exceeds MaxCNAMEDepth. + ErrCNAMEDepthExceeded = errors.New( + "CNAME chain depth exceeded", + ) + + // ErrContextCanceled wraps context cancellation for the + // resolver's iterative queries. + ErrContextCanceled = errors.New("context canceled") +) diff --git a/internal/resolver/iterative.go b/internal/resolver/iterative.go new file mode 100644 index 0000000..8f41b6d --- /dev/null +++ b/internal/resolver/iterative.go @@ -0,0 +1,725 @@ +package resolver + +import ( + "context" + "errors" + "fmt" + "net" + "sort" + "strings" + "time" + + "github.com/miekg/dns" +) + +const ( + queryTimeoutDuration = 5 * time.Second + maxRetries = 2 + maxDelegation = 20 + timeoutMultiplier = 2 + minDomainLabels = 2 +) + +// ErrRefused is returned when a DNS server refuses a query. +var ErrRefused = errors.New("dns query refused") + +func rootServerList() []string { + return []string{ + "198.41.0.4", // a.root-servers.net + "170.247.170.2", // b + "192.33.4.12", // c + "199.7.91.13", // d + "192.203.230.10", // e + "192.5.5.241", // f + "192.112.36.4", // g + "198.97.190.53", // h + "192.36.148.17", // i + "192.58.128.30", // j + "193.0.14.129", // k + "199.7.83.42", // l + "202.12.27.33", // m + } +} + +func checkCtx(ctx context.Context) error { + err := ctx.Err() + if err != nil { + return ErrContextCanceled + } + + return nil +} + +func (r *Resolver) exchangeWithTimeout( + ctx context.Context, + msg *dns.Msg, + addr string, + attempt int, +) (*dns.Msg, error) { + _ = attempt // timeout escalation handled by client config + + resp, _, err := r.client.ExchangeContext(ctx, msg, addr) + + return resp, err +} + +func (r *Resolver) tryExchange( + ctx context.Context, + msg *dns.Msg, + addr string, +) (*dns.Msg, error) { + var resp *dns.Msg + + var err error + + for attempt := range maxRetries { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + resp, err = r.exchangeWithTimeout( + ctx, msg, addr, attempt, + ) + if err == nil { + break + } + } + + return resp, err +} + +func (r *Resolver) retryTCP( + ctx context.Context, + msg *dns.Msg, + addr string, + resp *dns.Msg, +) *dns.Msg { + if !resp.Truncated { + return resp + } + + tcpResp, _, tcpErr := r.tcp.ExchangeContext(ctx, msg, addr) + if tcpErr == nil { + return tcpResp + } + + return resp +} + +// queryDNS sends a DNS query to a specific server IP. +// Tries non-recursive first, falls back to recursive on +// REFUSED (handles DNS interception environments). +func (r *Resolver) queryDNS( + ctx context.Context, + serverIP string, + name string, + qtype uint16, +) (*dns.Msg, error) { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + name = dns.Fqdn(name) + addr := net.JoinHostPort(serverIP, "53") + + msg := new(dns.Msg) + msg.SetQuestion(name, qtype) + msg.RecursionDesired = false + + resp, err := r.tryExchange(ctx, msg, addr) + if err != nil { + return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err) + } + + if resp.Rcode == dns.RcodeRefused { + msg.RecursionDesired = true + + resp, err = r.tryExchange(ctx, msg, addr) + if err != nil { + return nil, fmt.Errorf( + "query %s @%s: %w", name, serverIP, err, + ) + } + + if resp.Rcode == dns.RcodeRefused { + return nil, fmt.Errorf( + "query %s @%s: %w", name, serverIP, ErrRefused, + ) + } + } + + resp = r.retryTCP(ctx, msg, addr, resp) + + return resp, nil +} + +func extractNSSet(rrs []dns.RR) []string { + nsSet := make(map[string]bool) + + for _, rr := range rrs { + if ns, ok := rr.(*dns.NS); ok { + nsSet[strings.ToLower(ns.Ns)] = true + } + } + + names := make([]string, 0, len(nsSet)) + for n := range nsSet { + names = append(names, n) + } + + sort.Strings(names) + + return names +} + +func extractGlue(rrs []dns.RR) map[string][]net.IP { + glue := make(map[string][]net.IP) + + for _, rr := range rrs { + switch r := rr.(type) { + case *dns.A: + name := strings.ToLower(r.Hdr.Name) + glue[name] = append(glue[name], r.A) + case *dns.AAAA: + name := strings.ToLower(r.Hdr.Name) + glue[name] = append(glue[name], r.AAAA) + } + } + + return glue +} + +func glueIPs(nsNames []string, glue map[string][]net.IP) []string { + var ips []string + + for _, ns := range nsNames { + for _, addr := range glue[ns] { + if v4 := addr.To4(); v4 != nil { + ips = append(ips, v4.String()) + } + } + } + + return ips +} + +func (r *Resolver) followDelegation( + ctx context.Context, + domain string, + servers []string, +) ([]string, error) { + for range maxDelegation { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + resp, err := r.queryServers( + ctx, servers, domain, dns.TypeNS, + ) + if err != nil { + return nil, err + } + + ansNS := extractNSSet(resp.Answer) + if len(ansNS) > 0 { + return ansNS, nil + } + + authNS := extractNSSet(resp.Ns) + if len(authNS) == 0 { + return r.resolveNSRecursive(ctx, domain) + } + + glue := extractGlue(resp.Extra) + nextServers := glueIPs(authNS, glue) + + if len(nextServers) == 0 { + nextServers = r.resolveNSIPs(ctx, authNS) + } + + if len(nextServers) == 0 { + return nil, ErrNoNameservers + } + + servers = nextServers + } + + return nil, ErrNoNameservers +} + +func (r *Resolver) queryServers( + ctx context.Context, + servers []string, + name string, + qtype uint16, +) (*dns.Msg, error) { + var lastErr error + + for _, ip := range servers { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + resp, err := r.queryDNS(ctx, ip, name, qtype) + if err == nil { + return resp, nil + } + + lastErr = err + } + + return nil, fmt.Errorf("all servers failed: %w", lastErr) +} + +func (r *Resolver) resolveNSIPs( + ctx context.Context, + nsNames []string, +) []string { + var ips []string + + for _, ns := range nsNames { + resolved, err := r.resolveARecord(ctx, ns) + if err == nil { + ips = append(ips, resolved...) + } + + if len(ips) > 0 { + break + } + } + + return ips +} + +// resolveNSRecursive queries for NS records using recursive +// resolution as a fallback for intercepted environments. +func (r *Resolver) resolveNSRecursive( + ctx context.Context, + domain string, +) ([]string, error) { + domain = dns.Fqdn(domain) + msg := new(dns.Msg) + msg.SetQuestion(domain, dns.TypeNS) + msg.RecursionDesired = true + + for _, ip := range rootServerList()[:3] { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + addr := net.JoinHostPort(ip, "53") + + resp, _, err := r.client.ExchangeContext(ctx, msg, addr) + if err != nil { + continue + } + + nsNames := extractNSSet(resp.Answer) + if len(nsNames) > 0 { + return nsNames, nil + } + } + + return nil, ErrNoNameservers +} + +// resolveARecord resolves a hostname to IPv4 addresses. +func (r *Resolver) resolveARecord( + ctx context.Context, + hostname string, +) ([]string, error) { + hostname = dns.Fqdn(hostname) + msg := new(dns.Msg) + msg.SetQuestion(hostname, dns.TypeA) + msg.RecursionDesired = true + + for _, ip := range rootServerList()[:3] { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + addr := net.JoinHostPort(ip, "53") + + resp, _, err := r.client.ExchangeContext(ctx, msg, addr) + if err != nil { + continue + } + + var ips []string + + for _, rr := range resp.Answer { + if a, ok := rr.(*dns.A); ok { + ips = append(ips, a.A.String()) + } + } + + if len(ips) > 0 { + return ips, nil + } + } + + return nil, fmt.Errorf( + "cannot resolve %s: %w", hostname, ErrNoNameservers, + ) +} + +// FindAuthoritativeNameservers traces the delegation chain from +// root servers to discover all authoritative nameservers for the +// given domain. Walks up the label hierarchy for subdomains. +func (r *Resolver) FindAuthoritativeNameservers( + ctx context.Context, + domain string, +) ([]string, error) { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + domain = dns.Fqdn(strings.ToLower(domain)) + labels := dns.SplitDomainName(domain) + + for i := range labels { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + candidate := strings.Join(labels[i:], ".") + "." + + nsNames, err := r.followDelegation( + ctx, candidate, rootServerList(), + ) + if err == nil && len(nsNames) > 0 { + sort.Strings(nsNames) + + return nsNames, nil + } + } + + return nil, ErrNoNameservers +} + +// QueryNameserver queries a specific nameserver for all record +// types and builds a NameserverResponse. +func (r *Resolver) QueryNameserver( + ctx context.Context, + nsHostname string, + hostname string, +) (*NameserverResponse, error) { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + nsIPs, err := r.resolveARecord(ctx, nsHostname) + if err != nil { + return nil, fmt.Errorf("resolving NS %s: %w", nsHostname, err) + } + + hostname = dns.Fqdn(hostname) + + return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname) +} + +func (r *Resolver) queryAllTypes( + ctx context.Context, + nsHostname string, + nsIP string, + hostname string, +) (*NameserverResponse, error) { + resp := &NameserverResponse{ + Nameserver: nsHostname, + Records: make(map[string][]string), + Status: StatusOK, + } + + qtypes := []uint16{ + dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, + dns.TypeMX, dns.TypeTXT, dns.TypeSRV, + dns.TypeCAA, dns.TypeNS, + } + + state := r.queryEachType(ctx, nsIP, hostname, qtypes, resp) + classifyResponse(resp, state) + + return resp, nil +} + +type queryState struct { + gotNXDomain bool + gotSERVFAIL bool + hasRecords bool +} + +func (r *Resolver) queryEachType( + ctx context.Context, + nsIP string, + hostname string, + qtypes []uint16, + resp *NameserverResponse, +) queryState { + var state queryState + + for _, qtype := range qtypes { + if checkCtx(ctx) != nil { + break + } + + r.querySingleType(ctx, nsIP, hostname, qtype, resp, &state) + } + + for k := range resp.Records { + sort.Strings(resp.Records[k]) + } + + return state +} + +func (r *Resolver) querySingleType( + ctx context.Context, + nsIP string, + hostname string, + qtype uint16, + resp *NameserverResponse, + state *queryState, +) { + msg, err := r.queryDNS(ctx, nsIP, hostname, qtype) + if err != nil { + return + } + + if msg.Rcode == dns.RcodeNameError { + state.gotNXDomain = true + + return + } + + if msg.Rcode == dns.RcodeServerFailure { + state.gotSERVFAIL = true + + return + } + + collectAnswerRecords(msg, resp, state) +} + +func collectAnswerRecords( + msg *dns.Msg, + resp *NameserverResponse, + state *queryState, +) { + for _, rr := range msg.Answer { + val := extractRecordValue(rr) + if val == "" { + continue + } + + typeName := dns.TypeToString[rr.Header().Rrtype] + resp.Records[typeName] = append( + resp.Records[typeName], val, + ) + state.hasRecords = true + } +} + +func classifyResponse(resp *NameserverResponse, state queryState) { + switch { + case state.gotNXDomain && !state.hasRecords: + resp.Status = StatusNXDomain + case state.gotSERVFAIL && !state.hasRecords: + resp.Status = StatusError + case !state.hasRecords && !state.gotNXDomain: + resp.Status = StatusNoData + } +} + +// extractRecordValue formats a DNS RR value as a string. +func extractRecordValue(rr dns.RR) string { + switch r := rr.(type) { + case *dns.A: + return r.A.String() + case *dns.AAAA: + return r.AAAA.String() + case *dns.CNAME: + return r.Target + case *dns.MX: + return fmt.Sprintf("%d %s", r.Preference, r.Mx) + case *dns.TXT: + return strings.Join(r.Txt, "") + case *dns.SRV: + return fmt.Sprintf( + "%d %d %d %s", + r.Priority, r.Weight, r.Port, r.Target, + ) + case *dns.CAA: + return fmt.Sprintf( + "%d %s \"%s\"", r.Flag, r.Tag, r.Value, + ) + case *dns.NS: + return r.Ns + default: + return "" + } +} + +// parentDomain returns the registerable parent domain. +func parentDomain(hostname string) string { + hostname = dns.Fqdn(strings.ToLower(hostname)) + labels := dns.SplitDomainName(hostname) + + if len(labels) <= minDomainLabels { + return strings.Join(labels, ".") + "." + } + + return strings.Join( + labels[len(labels)-minDomainLabels:], ".", + ) + "." +} + +// QueryAllNameservers discovers auth NSes for the hostname's +// parent domain, then queries each one independently. +func (r *Resolver) QueryAllNameservers( + ctx context.Context, + hostname string, +) (map[string]*NameserverResponse, error) { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + parent := parentDomain(hostname) + + nameservers, err := r.FindAuthoritativeNameservers(ctx, parent) + if err != nil { + return nil, err + } + + return r.queryEachNS(ctx, nameservers, hostname) +} + +func (r *Resolver) queryEachNS( + ctx context.Context, + nameservers []string, + hostname string, +) (map[string]*NameserverResponse, error) { + results := make(map[string]*NameserverResponse) + + for _, ns := range nameservers { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + resp, err := r.QueryNameserver(ctx, ns, hostname) + if err != nil { + results[ns] = &NameserverResponse{ + Nameserver: ns, + Records: make(map[string][]string), + Status: StatusError, + Error: err.Error(), + } + + continue + } + + results[ns] = resp + } + + return results, nil +} + +// LookupNS returns the NS record set for a domain. +func (r *Resolver) LookupNS( + ctx context.Context, + domain string, +) ([]string, error) { + return r.FindAuthoritativeNameservers(ctx, domain) +} + +// LookupAllRecords performs iterative resolution to find all DNS +// records for the given hostname, keyed by authoritative nameserver. +func (r *Resolver) LookupAllRecords( + ctx context.Context, + hostname string, +) (map[string]map[string][]string, error) { + results, err := r.QueryAllNameservers(ctx, hostname) + if err != nil { + return nil, err + } + + out := make(map[string]map[string][]string, len(results)) + for ns, resp := range results { + out[ns] = resp.Records + } + + return out, nil +} + +// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6 +// addresses, following CNAME chains up to MaxCNAMEDepth. +func (r *Resolver) ResolveIPAddresses( + ctx context.Context, + hostname string, +) ([]string, error) { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + return r.resolveIPWithCNAME(ctx, hostname, 0) +} + +func (r *Resolver) resolveIPWithCNAME( + ctx context.Context, + hostname string, + depth int, +) ([]string, error) { + if depth > MaxCNAMEDepth { + return nil, ErrCNAMEDepthExceeded + } + + results, err := r.QueryAllNameservers(ctx, hostname) + if err != nil { + return nil, err + } + + ips, cnameTarget := collectIPs(results) + + if len(ips) == 0 && cnameTarget != "" { + return r.resolveIPWithCNAME(ctx, cnameTarget, depth+1) + } + + sort.Strings(ips) + + return ips, nil +} + +func collectIPs( + results map[string]*NameserverResponse, +) ([]string, string) { + seen := make(map[string]bool) + + var ips []string + + var cnameTarget string + + for _, resp := range results { + if resp.Status == StatusNXDomain { + continue + } + + for _, ip := range resp.Records["A"] { + if !seen[ip] { + seen[ip] = true + ips = append(ips, ip) + } + } + + for _, ip := range resp.Records["AAAA"] { + if !seen[ip] { + seen[ip] = true + ips = append(ips, ip) + } + } + + if len(resp.Records["CNAME"]) > 0 && cnameTarget == "" { + cnameTarget = resp.Records["CNAME"][0] + } + } + + return ips, cnameTarget +} diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 76432d9..889cdeb 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -1,9 +1,9 @@ // Package resolver provides iterative DNS resolution from root nameservers. +// It traces the full delegation chain from IANA root servers through TLD +// and domain nameservers, never relying on upstream recursive resolvers. package resolver import ( - "context" - "errors" "log/slog" "go.uber.org/fx" @@ -11,8 +11,16 @@ import ( "sneak.berlin/go/dnswatcher/internal/logger" ) -// ErrNotImplemented indicates the resolver is not yet implemented. -var ErrNotImplemented = errors.New("resolver not yet implemented") +// Query status constants matching the state model. +const ( + StatusOK = "ok" + StatusError = "error" + StatusNXDomain = "nxdomain" + StatusNoData = "nodata" +) + +// MaxCNAMEDepth is the maximum CNAME chain depth to follow. +const MaxCNAMEDepth = 10 // Params contains dependencies for Resolver. type Params struct { @@ -21,44 +29,54 @@ type Params struct { Logger *logger.Logger } -// Resolver performs iterative DNS resolution from root servers. -type Resolver struct { - log *slog.Logger +// NameserverResponse holds one nameserver's response for a query. +type NameserverResponse struct { + Nameserver string + Records map[string][]string + Status string + Error string } -// New creates a new Resolver instance. +// Resolver performs iterative DNS resolution from root servers. +type Resolver struct { + log *slog.Logger + client DNSClient + tcp DNSClient +} + +// New creates a new Resolver instance for use with uber/fx. func New( _ fx.Lifecycle, params Params, ) (*Resolver, error) { return &Resolver{ - log: params.Logger.Get(), + log: params.Logger.Get(), + client: &udpClient{timeout: queryTimeoutDuration}, + tcp: &tcpClient{timeout: queryTimeoutDuration}, }, nil } -// LookupNS performs iterative resolution to find authoritative -// nameservers for the given domain. -func (r *Resolver) LookupNS( - _ context.Context, - _ string, -) ([]string, error) { - return nil, ErrNotImplemented +// NewFromLogger creates a Resolver directly from an slog.Logger, +// useful for testing without the fx lifecycle. +func NewFromLogger(log *slog.Logger) *Resolver { + return &Resolver{ + log: log, + client: &udpClient{timeout: queryTimeoutDuration}, + tcp: &tcpClient{timeout: queryTimeoutDuration}, + } } -// LookupAllRecords performs iterative resolution to find all DNS -// records for the given hostname, keyed by authoritative nameserver. -func (r *Resolver) LookupAllRecords( - _ context.Context, - _ string, -) (map[string]map[string][]string, error) { - return nil, ErrNotImplemented +// NewFromLoggerWithClient creates a Resolver with a custom DNS +// client, useful for testing with mock DNS responses. +func NewFromLoggerWithClient( + log *slog.Logger, + client DNSClient, +) *Resolver { + return &Resolver{ + log: log, + client: client, + tcp: client, + } } -// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6 -// addresses, following CNAME chains. -func (r *Resolver) ResolveIPAddresses( - _ context.Context, - _ string, -) ([]string, error) { - return nil, ErrNotImplemented -} +// Method implementations are in iterative.go. diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go new file mode 100644 index 0000000..3b9d936 --- /dev/null +++ b/internal/resolver/resolver_test.go @@ -0,0 +1,634 @@ +package resolver_test + +import ( + "context" + "log/slog" + "net" + "os" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sneak.berlin/go/dnswatcher/internal/resolver" +) + +// ---------------------------------------------------------------- +// Test helpers +// ---------------------------------------------------------------- + +func newTestResolver(t *testing.T) *resolver.Resolver { + t.Helper() + + log := slog.New(slog.NewTextHandler( + os.Stderr, + &slog.HandlerOptions{Level: slog.LevelDebug}, + )) + + return resolver.NewFromLogger(log) +} + +func testContext(t *testing.T) context.Context { + t.Helper() + + ctx, cancel := context.WithTimeout( + context.Background(), 60*time.Second, + ) + t.Cleanup(cancel) + + return ctx +} + +func findOneNSForDomain( + t *testing.T, + r *resolver.Resolver, + ctx context.Context, //nolint:revive // test helper + domain string, +) string { + t.Helper() + + nameservers, err := r.FindAuthoritativeNameservers( + ctx, domain, + ) + require.NoError(t, err) + require.NotEmpty(t, nameservers) + + return nameservers[0] +} + +// ---------------------------------------------------------------- +// FindAuthoritativeNameservers tests +// ---------------------------------------------------------------- + +func TestFindAuthoritativeNameservers_ValidDomain( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + nameservers, err := r.FindAuthoritativeNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + require.NotEmpty(t, nameservers) + + hasGoogleNS := false + + for _, ns := range nameservers { + if strings.Contains(ns, "google") { + hasGoogleNS = true + + break + } + } + + assert.True(t, hasGoogleNS, + "expected google nameservers, got: %v", nameservers, + ) +} + +func TestFindAuthoritativeNameservers_Subdomain( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + nameservers, err := r.FindAuthoritativeNameservers( + ctx, "www.google.com", + ) + require.NoError(t, err) + require.NotEmpty(t, nameservers) +} + +func TestFindAuthoritativeNameservers_ReturnsSorted( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + nameservers, err := r.FindAuthoritativeNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + + assert.True( + t, + sort.StringsAreSorted(nameservers), + "nameservers should be sorted, got: %v", nameservers, + ) +} + +func TestFindAuthoritativeNameservers_Deterministic( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + first, err := r.FindAuthoritativeNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + + second, err := r.FindAuthoritativeNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + + assert.Equal(t, first, second) +} + +func TestFindAuthoritativeNameservers_TrailingDot( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + ns1, err := r.FindAuthoritativeNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + + ns2, err := r.FindAuthoritativeNameservers( + ctx, "google.com.", + ) + require.NoError(t, err) + + assert.Equal(t, ns1, ns2) +} + +func TestFindAuthoritativeNameservers_CloudflareDomain( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + nameservers, err := r.FindAuthoritativeNameservers( + ctx, "cloudflare.com", + ) + require.NoError(t, err) + require.NotEmpty(t, nameservers) + + for _, ns := range nameservers { + assert.True(t, strings.HasSuffix(ns, "."), + "NS should be FQDN with trailing dot: %s", ns, + ) + } +} + +// ---------------------------------------------------------------- +// QueryNameserver tests +// ---------------------------------------------------------------- + +func TestQueryNameserver_BasicA(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "google.com") + + resp, err := r.QueryNameserver( + ctx, ns, "www.google.com", + ) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Equal(t, resolver.StatusOK, resp.Status) + assert.Equal(t, ns, resp.Nameserver) + + hasRecords := len(resp.Records["A"]) > 0 || + len(resp.Records["CNAME"]) > 0 + assert.True(t, hasRecords, + "expected A or CNAME records for www.google.com", + ) +} + +func TestQueryNameserver_AAAA(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") + + resp, err := r.QueryNameserver( + ctx, ns, "cloudflare.com", + ) + require.NoError(t, err) + + aaaaRecords := resp.Records["AAAA"] + require.NotEmpty(t, aaaaRecords, + "cloudflare.com should have AAAA records", + ) + + for _, ip := range aaaaRecords { + parsed := net.ParseIP(ip) + require.NotNil(t, parsed, + "should be valid IP: %s", ip, + ) + } +} + +func TestQueryNameserver_MX(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "google.com") + + resp, err := r.QueryNameserver( + ctx, ns, "google.com", + ) + require.NoError(t, err) + + mxRecords := resp.Records["MX"] + require.NotEmpty(t, mxRecords, + "google.com should have MX records", + ) +} + +func TestQueryNameserver_TXT(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "google.com") + + resp, err := r.QueryNameserver( + ctx, ns, "google.com", + ) + require.NoError(t, err) + + txtRecords := resp.Records["TXT"] + require.NotEmpty(t, txtRecords, + "google.com should have TXT records", + ) + + hasSPF := false + + for _, txt := range txtRecords { + if strings.Contains(txt, "v=spf1") { + hasSPF = true + + break + } + } + + assert.True(t, hasSPF, + "google.com should have SPF TXT record", + ) +} + +func TestQueryNameserver_NXDomain(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "google.com") + + resp, err := r.QueryNameserver( + ctx, ns, + "this-surely-does-not-exist-xyz.google.com", + ) + require.NoError(t, err) + + assert.Equal(t, resolver.StatusNXDomain, resp.Status) +} + +func TestQueryNameserver_RecordsSorted(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "google.com") + + resp, err := r.QueryNameserver( + ctx, ns, "google.com", + ) + require.NoError(t, err) + + for recordType, values := range resp.Records { + assert.True( + t, + sort.StringsAreSorted(values), + "%s records should be sorted", recordType, + ) + } +} + +func TestQueryNameserver_ResponseIncludesNameserver( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") + + resp, err := r.QueryNameserver( + ctx, ns, "cloudflare.com", + ) + require.NoError(t, err) + + assert.Equal(t, ns, resp.Nameserver) +} + +func TestQueryNameserver_EmptyRecordsOnNXDomain( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "google.com") + + resp, err := r.QueryNameserver( + ctx, ns, + "this-surely-does-not-exist-xyz.google.com", + ) + require.NoError(t, err) + + totalRecords := 0 + for _, values := range resp.Records { + totalRecords += len(values) + } + + assert.Zero(t, totalRecords) +} + +func TestQueryNameserver_TrailingDotHandling(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + ns := findOneNSForDomain(t, r, ctx, "google.com") + + resp1, err := r.QueryNameserver( + ctx, ns, "google.com", + ) + require.NoError(t, err) + + resp2, err := r.QueryNameserver( + ctx, ns, "google.com.", + ) + require.NoError(t, err) + + assert.Equal(t, resp1.Status, resp2.Status) +} + +// ---------------------------------------------------------------- +// QueryAllNameservers tests +// ---------------------------------------------------------------- + +func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + results, err := r.QueryAllNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + require.NotEmpty(t, results) + + assert.GreaterOrEqual(t, len(results), 2) + + for ns, resp := range results { + assert.Equal(t, ns, resp.Nameserver) + } +} + +func TestQueryAllNameservers_AllReturnOK(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + results, err := r.QueryAllNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + + for ns, resp := range results { + assert.Equal( + t, resolver.StatusOK, resp.Status, + "NS %s should return OK", ns, + ) + } +} + +func TestQueryAllNameservers_NXDomainFromAllNS( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + results, err := r.QueryAllNameservers( + ctx, + "this-surely-does-not-exist-xyz.google.com", + ) + require.NoError(t, err) + + for ns, resp := range results { + assert.Equal( + t, resolver.StatusNXDomain, resp.Status, + "NS %s should return nxdomain", ns, + ) + } +} + +// ---------------------------------------------------------------- +// LookupNS tests +// ---------------------------------------------------------------- + +func TestLookupNS_ValidDomain(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + nameservers, err := r.LookupNS(ctx, "google.com") + require.NoError(t, err) + require.NotEmpty(t, nameservers) + + for _, ns := range nameservers { + assert.True(t, strings.HasSuffix(ns, "."), + "NS should have trailing dot: %s", ns, + ) + } +} + +func TestLookupNS_Sorted(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + nameservers, err := r.LookupNS(ctx, "google.com") + require.NoError(t, err) + + assert.True(t, sort.StringsAreSorted(nameservers)) +} + +func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + fromLookup, err := r.LookupNS(ctx, "google.com") + require.NoError(t, err) + + fromFind, err := r.FindAuthoritativeNameservers( + ctx, "google.com", + ) + require.NoError(t, err) + + assert.Equal(t, fromFind, fromLookup) +} + +// ---------------------------------------------------------------- +// ResolveIPAddresses tests +// ---------------------------------------------------------------- + +func TestResolveIPAddresses_ReturnsIPs(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + ips, err := r.ResolveIPAddresses(ctx, "google.com") + require.NoError(t, err) + require.NotEmpty(t, ips) + + for _, ip := range ips { + parsed := net.ParseIP(ip) + assert.NotNil(t, parsed, + "should be valid IP: %s", ip, + ) + } +} + +func TestResolveIPAddresses_Deduplicated(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + ips, err := r.ResolveIPAddresses(ctx, "google.com") + require.NoError(t, err) + + seen := make(map[string]bool) + + for _, ip := range ips { + assert.False(t, seen[ip], "duplicate IP: %s", ip) + seen[ip] = true + } +} + +func TestResolveIPAddresses_Sorted(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + ips, err := r.ResolveIPAddresses(ctx, "google.com") + require.NoError(t, err) + + assert.True(t, sort.StringsAreSorted(ips)) +} + +func TestResolveIPAddresses_NXDomainReturnsEmpty( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + ips, err := r.ResolveIPAddresses( + ctx, + "this-surely-does-not-exist-xyz.google.com", + ) + require.NoError(t, err) + assert.Empty(t, ips) +} + +func TestResolveIPAddresses_CloudflareDomain(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com") + require.NoError(t, err) + require.NotEmpty(t, ips) +} + +// ---------------------------------------------------------------- +// Context cancellation tests +// ---------------------------------------------------------------- + +func TestFindAuthoritativeNameservers_ContextCanceled( + t *testing.T, +) { + t.Parallel() + + r := newTestResolver(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := r.FindAuthoritativeNameservers(ctx, "google.com") + assert.Error(t, err) +} + +func TestQueryNameserver_ContextCanceled(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := r.QueryNameserver( + ctx, "ns1.google.com.", "google.com", + ) + assert.Error(t, err) +} + +func TestQueryAllNameservers_ContextCanceled(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := r.QueryAllNameservers(ctx, "google.com") + assert.Error(t, err) +} + +func TestResolveIPAddresses_ContextCanceled(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := r.ResolveIPAddresses(ctx, "google.com") + assert.Error(t, err) +}