package resolver_test import ( "context" "fmt" "log/slog" "net" "os" "sort" "strings" "testing" "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "sneak.berlin/go/dnswatcher/internal/resolver" ) // ---------------------------------------------------------------- // Mock DNS client // ---------------------------------------------------------------- // mockDNSClient implements resolver.DNSClient with canned responses. type mockDNSClient struct { handlers map[string]func(msg *dns.Msg) *dns.Msg } func newMockClient() *mockDNSClient { return &mockDNSClient{ handlers: make(map[string]func(msg *dns.Msg) *dns.Msg), } } func (m *mockDNSClient) ExchangeContext( ctx context.Context, msg *dns.Msg, addr string, ) (*dns.Msg, time.Duration, error) { err := ctx.Err() if err != nil { return nil, 0, err } host, _, _ := net.SplitHostPort(addr) if host == "" { host = addr } qname := msg.Question[0].Name qtype := dns.TypeToString[msg.Question[0].Qtype] resp := m.findHandler(host, qname, qtype, msg) return resp, time.Millisecond, nil } func (m *mockDNSClient) findHandler( host, qname, qtype string, msg *dns.Msg, ) *dns.Msg { key := fmt.Sprintf( "%s|%s|%s", host, strings.ToLower(qname), qtype, ) if h, ok := m.handlers[key]; ok { return h(msg) } wildKey := fmt.Sprintf( "*|%s|%s", strings.ToLower(qname), qtype, ) if h, ok := m.handlers[wildKey]; ok { return h(msg) } resp := new(dns.Msg) resp.SetReply(msg) return resp } func (m *mockDNSClient) on( server, qname, qtype string, handler func(msg *dns.Msg) *dns.Msg, ) { key := fmt.Sprintf( "%s|%s|%s", server, dns.Fqdn(strings.ToLower(qname)), qtype, ) m.handlers[key] = handler } // ---------------------------------------------------------------- // Response builders // ---------------------------------------------------------------- func referralResponse( msg *dns.Msg, nsNames []string, glue map[string]string, ) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) for _, ns := range nsNames { resp.Ns = append(resp.Ns, &dns.NS{ Hdr: dns.RR_Header{ Name: msg.Question[0].Name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: dns.Fqdn(ns), }) } for name, ip := range glue { resp.Extra = append(resp.Extra, &dns.A{ Hdr: dns.RR_Header{ Name: dns.Fqdn(name), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600, }, A: net.ParseIP(ip), }) } return resp } func nsAnswerResponse( msg *dns.Msg, nsNames []string, ) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) for _, ns := range nsNames { resp.Answer = append(resp.Answer, &dns.NS{ Hdr: dns.RR_Header{ Name: msg.Question[0].Name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: dns.Fqdn(ns), }) } return resp } func nxdomainResponse(msg *dns.Msg) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) resp.Rcode = dns.RcodeNameError return resp } func aResponse( msg *dns.Msg, name string, ip string, ) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) resp.Answer = append(resp.Answer, &dns.A{ Hdr: dns.RR_Header{ Name: dns.Fqdn(name), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, }, A: net.ParseIP(ip), }) return resp } func aaaaResponse( msg *dns.Msg, name string, ip string, ) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) resp.Answer = append(resp.Answer, &dns.AAAA{ Hdr: dns.RR_Header{ Name: dns.Fqdn(name), Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 300, }, AAAA: net.ParseIP(ip), }) return resp } func emptyResponse(msg *dns.Msg) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) return resp } // ---------------------------------------------------------------- // Mock DNS hierarchy setup // ---------------------------------------------------------------- // mockData holds all test DNS hierarchy configuration. type mockData struct { tldNS []string tldGlue map[string]string exNS []string exGlue map[string]string cfNS []string cfGlue map[string]string } func newMockData() mockData { return mockData{ tldNS: []string{"ns1.tld.com", "ns2.tld.com"}, tldGlue: map[string]string{ "ns1.tld.com": "10.0.0.1", "ns2.tld.com": "10.0.0.2", }, exNS: []string{ "ns1.example.com", "ns2.example.com", "ns3.example.com", }, exGlue: map[string]string{ "ns1.example.com": "10.1.0.1", "ns2.example.com": "10.1.0.2", "ns3.example.com": "10.1.0.3", }, cfNS: []string{ "ns1.cloudflare.com", "ns2.cloudflare.com", }, cfGlue: map[string]string{ "ns1.cloudflare.com": "10.2.0.1", "ns2.cloudflare.com": "10.2.0.2", }, } } func rootIPList() []string { return []string{ "198.41.0.4", "170.247.170.2", "192.33.4.12", "199.7.91.13", "192.203.230.10", "192.5.5.241", "192.112.36.4", "198.97.190.53", "192.36.148.17", "192.58.128.30", "193.0.14.129", "199.7.83.42", "202.12.27.33", } } func allQueryTypes() []string { return []string{ "NS", "A", "AAAA", "CNAME", "MX", "TXT", "SRV", "CAA", } } func setupRootDelegations( m *mockDNSClient, tNS []string, tGlue map[string]string, ) { domains := []string{ "example.com.", "www.example.com.", "this-surely-does-not-exist-xyz.example.com.", "cloudflare.com.", } for _, rootIP := range rootIPList() { for _, domain := range domains { for _, qtype := range allQueryTypes() { m.on(rootIP, domain, qtype, func(msg *dns.Msg) *dns.Msg { return referralResponse( msg, tNS, tGlue, ) }, ) } } } } func setupRootARecords(m *mockDNSClient) { nsIPs := map[string]string{ "ns1.example.com.": "10.1.0.1", "ns2.example.com.": "10.1.0.2", "ns3.example.com.": "10.1.0.3", "ns1.cloudflare.com.": "10.2.0.1", "ns2.cloudflare.com.": "10.2.0.2", } for _, rootIP := range rootIPList() { for nsName, nsIP := range nsIPs { ip := nsIP name := nsName m.on(rootIP, name, "A", func(msg *dns.Msg) *dns.Msg { return aResponse(msg, name, ip) }, ) } } } func setupTLDDelegations( m *mockDNSClient, exNS []string, exGlue map[string]string, cfNS []string, cfGlue map[string]string, ) { tldIPs := []string{"10.0.0.1", "10.0.0.2"} exDomains := []string{ "example.com.", "www.example.com.", "this-surely-does-not-exist-xyz.example.com.", } for _, tldIP := range tldIPs { for _, domain := range exDomains { for _, qtype := range allQueryTypes() { m.on(tldIP, domain, qtype, func(msg *dns.Msg) *dns.Msg { return referralResponse( msg, exNS, exGlue, ) }, ) } } for _, qtype := range allQueryTypes() { m.on(tldIP, "cloudflare.com.", qtype, func(msg *dns.Msg) *dns.Msg { return referralResponse( msg, cfNS, cfGlue, ) }, ) } } } func setupExampleNSAndA( m *mockDNSClient, exNS []string, ) { exIPs := []string{"10.1.0.1", "10.1.0.2", "10.1.0.3"} for _, authIP := range exIPs { m.on(authIP, "example.com.", "NS", func(msg *dns.Msg) *dns.Msg { return nsAnswerResponse(msg, exNS) }, ) m.on(authIP, "example.com.", "A", func(msg *dns.Msg) *dns.Msg { return aResponse( msg, "example.com.", "93.184.216.34", ) }, ) m.on(authIP, "example.com.", "AAAA", func(msg *dns.Msg) *dns.Msg { return aaaaResponse( msg, "example.com.", "2606:2800:220:1:248:1893:25c8:1946", ) }, ) } } func setupExampleMXAndTXT(m *mockDNSClient) { exIPs := []string{"10.1.0.1", "10.1.0.2", "10.1.0.3"} for _, authIP := range exIPs { m.on(authIP, "example.com.", "MX", func(msg *dns.Msg) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) resp.Answer = append(resp.Answer, &dns.MX{ Hdr: dns.RR_Header{ Name: "example.com.", Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 300, }, Preference: 10, Mx: "mail.example.com.", }, &dns.MX{ Hdr: dns.RR_Header{ Name: "example.com.", Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 300, }, Preference: 20, Mx: "mail2.example.com.", }, ) return resp }, ) m.on(authIP, "example.com.", "TXT", func(msg *dns.Msg) *dns.Msg { resp := new(dns.Msg) resp.SetReply(msg) resp.Answer = append(resp.Answer, &dns.TXT{ Hdr: dns.RR_Header{ Name: "example.com.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 300, }, Txt: []string{ "v=spf1 include:_spf.example.com ~all", }, }) return resp }, ) } } func setupExampleSubdomains( m *mockDNSClient, exNS []string, ) { exIPs := []string{"10.1.0.1", "10.1.0.2", "10.1.0.3"} for _, authIP := range exIPs { m.on(authIP, "www.example.com.", "NS", func(msg *dns.Msg) *dns.Msg { return nsAnswerResponse(msg, exNS) }, ) m.on(authIP, "www.example.com.", "A", func(msg *dns.Msg) *dns.Msg { return aResponse( msg, "www.example.com.", "93.184.216.34", ) }, ) nxName := "this-surely-does-not-exist-xyz.example.com." for _, qtype := range allQueryTypes() { m.on(authIP, nxName, qtype, nxdomainResponse) } } } func setupCloudflareAuthRecords( m *mockDNSClient, cfNS []string, ) { cfIPs := []string{"10.2.0.1", "10.2.0.2"} for _, authIP := range cfIPs { m.on(authIP, "cloudflare.com.", "NS", func(msg *dns.Msg) *dns.Msg { return nsAnswerResponse(msg, cfNS) }, ) m.on(authIP, "cloudflare.com.", "A", func(msg *dns.Msg) *dns.Msg { return aResponse( msg, "cloudflare.com.", "104.16.132.229", ) }, ) m.on(authIP, "cloudflare.com.", "AAAA", func(msg *dns.Msg) *dns.Msg { return aaaaResponse( msg, "cloudflare.com.", "2606:4700::6810:84e5", ) }, ) m.on(authIP, "cloudflare.com.", "MX", emptyResponse) m.on(authIP, "cloudflare.com.", "TXT", emptyResponse) } } func setupMockDNS() *mockDNSClient { m := newMockClient() d := newMockData() setupRootDelegations(m, d.tldNS, d.tldGlue) setupRootARecords(m) setupTLDDelegations(m, d.exNS, d.exGlue, d.cfNS, d.cfGlue) setupExampleNSAndA(m, d.exNS) setupExampleMXAndTXT(m) setupExampleSubdomains(m, d.exNS) setupCloudflareAuthRecords(m, d.cfNS) return m } // ---------------------------------------------------------------- // Test helpers // ---------------------------------------------------------------- func newTestResolver(t *testing.T) *resolver.Resolver { t.Helper() log := slog.New(slog.NewTextHandler( os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}, )) return resolver.NewFromLoggerWithClient(log, setupMockDNS()) } func testContext(t *testing.T) context.Context { t.Helper() ctx, cancel := context.WithTimeout( context.Background(), 10*time.Second, ) t.Cleanup(cancel) return ctx } func findOneNSForDomain( t *testing.T, r *resolver.Resolver, ctx context.Context, //nolint:revive // test helper domain string, ) string { t.Helper() nameservers, err := r.FindAuthoritativeNameservers( ctx, domain, ) require.NoError(t, err) require.NotEmpty(t, nameservers) return nameservers[0] } // ---------------------------------------------------------------- // FindAuthoritativeNameservers tests // ---------------------------------------------------------------- func TestFindAuthoritativeNameservers_ValidDomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) hasExampleNS := false for _, ns := range nameservers { if strings.Contains(ns, "example") { hasExampleNS = true break } } assert.True(t, hasExampleNS, "expected example nameservers, got: %v", nameservers, ) } func TestFindAuthoritativeNameservers_Subdomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "www.example.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) } func TestFindAuthoritativeNameservers_ReturnsSorted( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) assert.True( t, sort.StringsAreSorted(nameservers), "nameservers should be sorted, got: %v", nameservers, ) } func TestFindAuthoritativeNameservers_Deterministic( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) first, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) second, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) assert.Equal(t, first, second) } func TestFindAuthoritativeNameservers_TrailingDot( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns1, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) ns2, err := r.FindAuthoritativeNameservers( ctx, "example.com.", ) require.NoError(t, err) assert.Equal(t, ns1, ns2) } func TestFindAuthoritativeNameservers_CloudflareDomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "cloudflare.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) for _, ns := range nameservers { assert.True(t, strings.HasSuffix(ns, "."), "NS should be FQDN with trailing dot: %s", ns, ) } } // ---------------------------------------------------------------- // QueryNameserver tests // ---------------------------------------------------------------- func TestQueryNameserver_BasicA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "example.com") resp, err := r.QueryNameserver( ctx, ns, "www.example.com", ) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) assert.Equal(t, ns, resp.Nameserver) hasRecords := len(resp.Records["A"]) > 0 || len(resp.Records["CNAME"]) > 0 assert.True(t, hasRecords, "expected A or CNAME records for www.example.com", ) } func TestQueryNameserver_AAAA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") resp, err := r.QueryNameserver( ctx, ns, "cloudflare.com", ) require.NoError(t, err) aaaaRecords := resp.Records["AAAA"] require.NotEmpty(t, aaaaRecords, "cloudflare.com should have AAAA records", ) for _, ip := range aaaaRecords { parsed := net.ParseIP(ip) require.NotNil(t, parsed, "should be valid IP: %s", ip, ) } } func TestQueryNameserver_MX(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "example.com") resp, err := r.QueryNameserver( ctx, ns, "example.com", ) require.NoError(t, err) mxRecords := resp.Records["MX"] require.NotEmpty(t, mxRecords, "example.com should have MX records", ) } func TestQueryNameserver_TXT(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "example.com") resp, err := r.QueryNameserver( ctx, ns, "example.com", ) require.NoError(t, err) txtRecords := resp.Records["TXT"] require.NotEmpty(t, txtRecords, "example.com should have TXT records", ) hasSPF := false for _, txt := range txtRecords { if strings.Contains(txt, "v=spf1") { hasSPF = true break } } assert.True(t, hasSPF, "example.com should have SPF TXT record", ) } func TestQueryNameserver_NXDomain(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "example.com") resp, err := r.QueryNameserver( ctx, ns, "this-surely-does-not-exist-xyz.example.com", ) require.NoError(t, err) assert.Equal(t, resolver.StatusNXDomain, resp.Status) } func TestQueryNameserver_RecordsSorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "example.com") resp, err := r.QueryNameserver( ctx, ns, "example.com", ) require.NoError(t, err) for recordType, values := range resp.Records { assert.True( t, sort.StringsAreSorted(values), "%s records should be sorted", recordType, ) } } func TestQueryNameserver_ResponseIncludesNameserver( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") resp, err := r.QueryNameserver( ctx, ns, "cloudflare.com", ) require.NoError(t, err) assert.Equal(t, ns, resp.Nameserver) } func TestQueryNameserver_EmptyRecordsOnNXDomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "example.com") resp, err := r.QueryNameserver( ctx, ns, "this-surely-does-not-exist-xyz.example.com", ) require.NoError(t, err) totalRecords := 0 for _, values := range resp.Records { totalRecords += len(values) } assert.Zero(t, totalRecords) } func TestQueryNameserver_TrailingDotHandling(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, "example.com") resp1, err := r.QueryNameserver( ctx, ns, "example.com", ) require.NoError(t, err) resp2, err := r.QueryNameserver( ctx, ns, "example.com.", ) require.NoError(t, err) assert.Equal(t, resp1.Status, resp2.Status) } // ---------------------------------------------------------------- // QueryAllNameservers tests // ---------------------------------------------------------------- func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "example.com", ) require.NoError(t, err) require.NotEmpty(t, results) assert.GreaterOrEqual(t, len(results), 2) for ns, resp := range results { assert.Equal(t, ns, resp.Nameserver) } } func TestQueryAllNameservers_AllReturnOK(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "example.com", ) require.NoError(t, err) for ns, resp := range results { assert.Equal( t, resolver.StatusOK, resp.Status, "NS %s should return OK", ns, ) } } func TestQueryAllNameservers_NXDomainFromAllNS( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "this-surely-does-not-exist-xyz.example.com", ) require.NoError(t, err) for ns, resp := range results { assert.Equal( t, resolver.StatusNXDomain, resp.Status, "NS %s should return nxdomain", ns, ) } } // ---------------------------------------------------------------- // LookupNS tests // ---------------------------------------------------------------- func TestLookupNS_ValidDomain(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "example.com") require.NoError(t, err) require.NotEmpty(t, nameservers) for _, ns := range nameservers { assert.True(t, strings.HasSuffix(ns, "."), "NS should have trailing dot: %s", ns, ) } } func TestLookupNS_Sorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "example.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(nameservers)) } func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) fromLookup, err := r.LookupNS(ctx, "example.com") require.NoError(t, err) fromFind, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) assert.Equal(t, fromFind, fromLookup) } // ---------------------------------------------------------------- // ResolveIPAddresses tests // ---------------------------------------------------------------- func TestResolveIPAddresses_ReturnsIPs(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "example.com") require.NoError(t, err) require.NotEmpty(t, ips) for _, ip := range ips { parsed := net.ParseIP(ip) assert.NotNil(t, parsed, "should be valid IP: %s", ip, ) } } func TestResolveIPAddresses_Deduplicated(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "example.com") require.NoError(t, err) seen := make(map[string]bool) for _, ip := range ips { assert.False(t, seen[ip], "duplicate IP: %s", ip) seen[ip] = true } } func TestResolveIPAddresses_Sorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "example.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(ips)) } func TestResolveIPAddresses_NXDomainReturnsEmpty( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses( ctx, "this-surely-does-not-exist-xyz.example.com", ) require.NoError(t, err) assert.Empty(t, ips) } func TestResolveIPAddresses_CloudflareDomain(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com") require.NoError(t, err) require.NotEmpty(t, ips) } // ---------------------------------------------------------------- // Context cancellation tests // ---------------------------------------------------------------- func TestFindAuthoritativeNameservers_ContextCanceled( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) assert.Error(t, err) } func TestQueryNameserver_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryNameserver( ctx, "ns1.example.com.", "example.com", ) assert.Error(t, err) } func TestQueryAllNameservers_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryAllNameservers(ctx, "example.com") assert.Error(t, err) } func TestResolveIPAddresses_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.ResolveIPAddresses(ctx, "example.com") assert.Error(t, err) }