package resolver_test import ( "context" "errors" "fmt" "log/slog" "net" "os" "sort" "strings" "testing" "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "sneak.berlin/go/dnswatcher/internal/resolver" ) // ---------------------------------------------------------------- // Mock DNS client // ---------------------------------------------------------------- // errNoQuestion is returned when a DNS message has no question. var errNoQuestion = errors.New("no question in DNS message") // mockExchange represents a canned DNS response for a query. type mockExchange struct { resp *dns.Msg err error } // mockDNSClient implements resolver.DNSClient with canned responses. // It matches on (qname, qtype) and returns pre-configured responses. // If no match is found it returns SERVFAIL. type mockDNSClient struct { // responses keyed by "QNAME/QTYPE" e.g. "google.com./NS" responses map[string]*mockExchange } func newMockClient() *mockDNSClient { return &mockDNSClient{ responses: make(map[string]*mockExchange), } } func mockKey(name string, qtype uint16) string { return fmt.Sprintf("%s/%s", strings.ToLower(dns.Fqdn(name)), dns.TypeToString[qtype], ) } // ExchangeContext implements resolver.DNSClient for testing. func (m *mockDNSClient) ExchangeContext( _ context.Context, msg *dns.Msg, _ string, ) (*dns.Msg, time.Duration, error) { if len(msg.Question) == 0 { return nil, 0, errNoQuestion } q := msg.Question[0] key := mockKey(q.Name, q.Qtype) if ex, ok := m.responses[key]; ok { if ex.err != nil { return nil, 0, ex.err } resp := ex.resp.Copy() resp.Id = msg.Id return resp, time.Millisecond, nil } // Default: SERVFAIL resp := new(dns.Msg) resp.SetReply(msg) resp.Rcode = dns.RcodeServerFailure return resp, time.Millisecond, nil } func (m *mockDNSClient) on( name string, qtype uint16, resp *dns.Msg, ) { m.responses[mockKey(name, qtype)] = &mockExchange{resp: resp} } // ---------------------------------------------------------------- // Helper functions to build DNS responses // ---------------------------------------------------------------- func makeNSResponse( qname string, nsNames []string, glue map[string]string, ) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(qname), dns.TypeNS) msg.Authoritative = true for _, ns := range nsNames { msg.Answer = append(msg.Answer, &dns.NS{ Hdr: dns.RR_Header{ Name: dns.Fqdn(qname), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: dns.Fqdn(ns), }) } for name, ip := range glue { msg.Extra = append(msg.Extra, &dns.A{ Hdr: dns.RR_Header{ Name: dns.Fqdn(name), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600, }, A: net.ParseIP(ip), }) } return msg } func makeAResponse(qname string, ips ...string) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(qname), dns.TypeA) msg.Authoritative = true for _, ip := range ips { msg.Answer = append(msg.Answer, &dns.A{ Hdr: dns.RR_Header{ Name: dns.Fqdn(qname), Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300, }, A: net.ParseIP(ip), }) } return msg } func makeAAAAResponse(qname string, ips ...string) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(qname), dns.TypeAAAA) msg.Authoritative = true for _, ip := range ips { msg.Answer = append(msg.Answer, &dns.AAAA{ Hdr: dns.RR_Header{ Name: dns.Fqdn(qname), Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 300, }, AAAA: net.ParseIP(ip), }) } return msg } func makeMXResponse( qname string, mxRecords map[uint16]string, ) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(qname), dns.TypeMX) msg.Authoritative = true for pref, mx := range mxRecords { msg.Answer = append(msg.Answer, &dns.MX{ Hdr: dns.RR_Header{ Name: dns.Fqdn(qname), Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 300, }, Preference: pref, Mx: dns.Fqdn(mx), }) } return msg } func makeTXTResponse(qname string, txts ...string) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(qname), dns.TypeTXT) msg.Authoritative = true for _, txt := range txts { msg.Answer = append(msg.Answer, &dns.TXT{ Hdr: dns.RR_Header{ Name: dns.Fqdn(qname), Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 300, }, Txt: []string{txt}, }) } return msg } func makeNXDomainResponse( qname string, qtype uint16, ) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(qname), qtype) msg.Rcode = dns.RcodeNameError return msg } func makeEmptyResponse( qname string, qtype uint16, ) *dns.Msg { msg := new(dns.Msg) msg.SetQuestion(dns.Fqdn(qname), qtype) msg.Authoritative = true return msg } // ---------------------------------------------------------------- // Setup helpers // ---------------------------------------------------------------- // setupMockResolver creates a mock client pre-loaded with a // realistic DNS delegation chain for example.com and google.com. func setupGoogleMock(mock *mockDNSClient) { // google.com. NS (authoritative answer) googleDelegation := makeNSResponse( "google.com.", []string{ "ns1.google.com", "ns2.google.com", "ns3.google.com", "ns4.google.com", }, map[string]string{ "ns1.google.com": "216.239.32.10", "ns2.google.com": "216.239.34.10", "ns3.google.com": "216.239.36.10", "ns4.google.com": "216.239.38.10", }, ) mock.on("google.com.", dns.TypeNS, googleDelegation) // NS IP resolution mock.on("ns1.google.com.", dns.TypeA, makeAResponse("ns1.google.com", "216.239.32.10")) mock.on("ns2.google.com.", dns.TypeA, makeAResponse("ns2.google.com", "216.239.34.10")) mock.on("ns3.google.com.", dns.TypeA, makeAResponse("ns3.google.com", "216.239.36.10")) mock.on("ns4.google.com.", dns.TypeA, makeAResponse("ns4.google.com", "216.239.38.10")) // google.com records mock.on("google.com.", dns.TypeA, makeAResponse("google.com", "142.250.80.46")) mock.on("google.com.", dns.TypeAAAA, makeAAAAResponse("google.com", "2607:f8b0:4004:800::200e")) mock.on("google.com.", dns.TypeMX, makeMXResponse("google.com", map[uint16]string{ 10: "smtp.google.com", 20: "smtp2.google.com", })) mock.on("google.com.", dns.TypeTXT, makeTXTResponse("google.com", "v=spf1 include:_spf.google.com ~all")) mock.on("google.com.", dns.TypeSRV, makeEmptyResponse("google.com.", dns.TypeSRV)) mock.on("google.com.", dns.TypeCAA, makeEmptyResponse("google.com.", dns.TypeCAA)) // www.google.com records mock.on("www.google.com.", dns.TypeA, makeAResponse("www.google.com", "142.250.80.46")) for _, qt := range []uint16{ dns.TypeAAAA, dns.TypeCNAME, dns.TypeMX, dns.TypeTXT, dns.TypeSRV, dns.TypeCAA, dns.TypeNS, } { mock.on("www.google.com.", qt, makeEmptyResponse("www.google.com.", qt)) } // NXDOMAIN for nonexistent subdomain nxName := "this-surely-does-not-exist-xyz.google.com." for _, qt := range []uint16{ dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, dns.TypeMX, dns.TypeTXT, dns.TypeSRV, dns.TypeCAA, dns.TypeNS, } { mock.on(nxName, qt, makeNXDomainResponse(nxName, qt)) } } func setupCloudflareMock(mock *mockDNSClient) { cfNS := makeNSResponse( "cloudflare.com.", []string{"ns3.cloudflare.com", "ns7.cloudflare.com"}, map[string]string{ "ns3.cloudflare.com": "162.159.0.33", "ns7.cloudflare.com": "162.159.36.1", }, ) mock.on("cloudflare.com.", dns.TypeNS, cfNS) mock.on("ns3.cloudflare.com.", dns.TypeA, makeAResponse("ns3.cloudflare.com", "162.159.0.33")) mock.on("ns7.cloudflare.com.", dns.TypeA, makeAResponse("ns7.cloudflare.com", "162.159.36.1")) mock.on("cloudflare.com.", dns.TypeA, makeAResponse("cloudflare.com", "104.16.132.229", "104.16.133.229")) mock.on("cloudflare.com.", dns.TypeAAAA, makeAAAAResponse("cloudflare.com", "2606:4700::6810:84e5", "2606:4700::6810:85e5")) mock.on("cloudflare.com.", dns.TypeCNAME, makeEmptyResponse("cloudflare.com.", dns.TypeCNAME)) mock.on("cloudflare.com.", dns.TypeMX, makeEmptyResponse("cloudflare.com.", dns.TypeMX)) mock.on("cloudflare.com.", dns.TypeTXT, makeTXTResponse("cloudflare.com", "v=spf1 include:_spf.google.com ~all")) mock.on("cloudflare.com.", dns.TypeSRV, makeEmptyResponse("cloudflare.com.", dns.TypeSRV)) mock.on("cloudflare.com.", dns.TypeCAA, makeEmptyResponse("cloudflare.com.", dns.TypeCAA)) } func setupMockResolver(t *testing.T) *resolver.Resolver { t.Helper() mock := newMockClient() // Root -> com. NS (authoritative answer style since mock // cannot distinguish server addresses) mock.on("com.", dns.TypeNS, makeNSResponse( "com.", []string{"a.gtld-servers.net", "b.gtld-servers.net"}, map[string]string{ "a.gtld-servers.net": "192.5.6.30", "b.gtld-servers.net": "192.33.14.30", }, )) setupGoogleMock(mock) setupCloudflareMock(mock) log := slog.New(slog.NewTextHandler( os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}, )) return resolver.NewFromLoggerWithClient(log, mock) } func testContext(t *testing.T) context.Context { t.Helper() ctx, cancel := context.WithTimeout( context.Background(), 10*time.Second, ) t.Cleanup(cancel) return ctx } // ---------------------------------------------------------------- // FindAuthoritativeNameservers tests // ---------------------------------------------------------------- func TestFindAuthoritativeNameservers_ValidDomain( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "google.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) hasGoogleNS := false for _, ns := range nameservers { if strings.Contains(ns, "google") { hasGoogleNS = true break } } assert.True(t, hasGoogleNS, "expected google nameservers, got: %v", nameservers, ) } func TestFindAuthoritativeNameservers_Subdomain( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "www.google.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) } func TestFindAuthoritativeNameservers_ReturnsSorted( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "google.com", ) require.NoError(t, err) assert.True( t, sort.StringsAreSorted(nameservers), "nameservers should be sorted, got: %v", nameservers, ) } func TestFindAuthoritativeNameservers_Deterministic( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) first, err := r.FindAuthoritativeNameservers( ctx, "google.com", ) require.NoError(t, err) second, err := r.FindAuthoritativeNameservers( ctx, "google.com", ) require.NoError(t, err) assert.Equal(t, first, second) } func TestFindAuthoritativeNameservers_TrailingDot( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) ns1, err := r.FindAuthoritativeNameservers( ctx, "google.com", ) require.NoError(t, err) ns2, err := r.FindAuthoritativeNameservers( ctx, "google.com.", ) require.NoError(t, err) assert.Equal(t, ns1, ns2) } func TestFindAuthoritativeNameservers_CloudflareDomain( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "cloudflare.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) for _, ns := range nameservers { assert.True(t, strings.HasSuffix(ns, "."), "NS should be FQDN with trailing dot: %s", ns, ) } } // ---------------------------------------------------------------- // QueryNameserver tests // ---------------------------------------------------------------- func TestQueryNameserver_BasicA(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp, err := r.QueryNameserver( ctx, "ns1.google.com.", "www.google.com", ) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) assert.Equal(t, "ns1.google.com.", resp.Nameserver) hasRecords := len(resp.Records["A"]) > 0 || len(resp.Records["CNAME"]) > 0 assert.True(t, hasRecords, "expected A or CNAME records for www.google.com", ) } func TestQueryNameserver_AAAA(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp, err := r.QueryNameserver( ctx, "ns3.cloudflare.com.", "cloudflare.com", ) require.NoError(t, err) aaaaRecords := resp.Records["AAAA"] require.NotEmpty(t, aaaaRecords, "cloudflare.com should have AAAA records", ) for _, ip := range aaaaRecords { parsed := net.ParseIP(ip) require.NotNil(t, parsed, "should be valid IP: %s", ip, ) } } func TestQueryNameserver_MX(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp, err := r.QueryNameserver( ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) mxRecords := resp.Records["MX"] require.NotEmpty(t, mxRecords, "google.com should have MX records", ) } func TestQueryNameserver_TXT(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp, err := r.QueryNameserver( ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) txtRecords := resp.Records["TXT"] require.NotEmpty(t, txtRecords, "google.com should have TXT records", ) hasSPF := false for _, txt := range txtRecords { if strings.Contains(txt, "v=spf1") { hasSPF = true break } } assert.True(t, hasSPF, "google.com should have SPF TXT record", ) } func TestQueryNameserver_NXDomain(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp, err := r.QueryNameserver( ctx, "ns1.google.com.", "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) assert.Equal(t, resolver.StatusNXDomain, resp.Status) } func TestQueryNameserver_RecordsSorted(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp, err := r.QueryNameserver( ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) for recordType, values := range resp.Records { assert.True( t, sort.StringsAreSorted(values), "%s records should be sorted", recordType, ) } } func TestQueryNameserver_ResponseIncludesNameserver( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) ns := "ns3.cloudflare.com." resp, err := r.QueryNameserver( ctx, ns, "cloudflare.com", ) require.NoError(t, err) assert.Equal(t, ns, resp.Nameserver) } func TestQueryNameserver_EmptyRecordsOnNXDomain( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp, err := r.QueryNameserver( ctx, "ns1.google.com.", "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) totalRecords := 0 for _, values := range resp.Records { totalRecords += len(values) } assert.Zero(t, totalRecords) } func TestQueryNameserver_TrailingDotHandling(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) resp1, err := r.QueryNameserver( ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) resp2, err := r.QueryNameserver( ctx, "ns1.google.com.", "google.com.", ) require.NoError(t, err) assert.Equal(t, resp1.Status, resp2.Status) } // ---------------------------------------------------------------- // QueryAllNameservers tests // ---------------------------------------------------------------- func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "google.com", ) require.NoError(t, err) require.NotEmpty(t, results) assert.GreaterOrEqual(t, len(results), 2) for ns, resp := range results { assert.Equal(t, ns, resp.Nameserver) } } func TestQueryAllNameservers_AllReturnOK(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "google.com", ) require.NoError(t, err) for ns, resp := range results { assert.Equal( t, resolver.StatusOK, resp.Status, "NS %s should return OK", ns, ) } } func TestQueryAllNameservers_NXDomainFromAllNS( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) for ns, resp := range results { assert.Equal( t, resolver.StatusNXDomain, resp.Status, "NS %s should return nxdomain", ns, ) } } // ---------------------------------------------------------------- // LookupNS tests // ---------------------------------------------------------------- func TestLookupNS_ValidDomain(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "google.com") require.NoError(t, err) require.NotEmpty(t, nameservers) for _, ns := range nameservers { assert.True(t, strings.HasSuffix(ns, "."), "NS should have trailing dot: %s", ns, ) } } func TestLookupNS_Sorted(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "google.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(nameservers)) } func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) fromLookup, err := r.LookupNS(ctx, "google.com") require.NoError(t, err) fromFind, err := r.FindAuthoritativeNameservers( ctx, "google.com", ) require.NoError(t, err) assert.Equal(t, fromFind, fromLookup) } // ---------------------------------------------------------------- // ResolveIPAddresses tests // ---------------------------------------------------------------- func TestResolveIPAddresses_ReturnsIPs(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "google.com") require.NoError(t, err) require.NotEmpty(t, ips) for _, ip := range ips { parsed := net.ParseIP(ip) assert.NotNil(t, parsed, "should be valid IP: %s", ip, ) } } func TestResolveIPAddresses_Deduplicated(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "google.com") require.NoError(t, err) seen := make(map[string]bool) for _, ip := range ips { assert.False(t, seen[ip], "duplicate IP: %s", ip) seen[ip] = true } } func TestResolveIPAddresses_Sorted(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "google.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(ips)) } func TestResolveIPAddresses_NXDomainReturnsEmpty( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses( ctx, "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) assert.Empty(t, ips) } func TestResolveIPAddresses_CloudflareDomain(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com") require.NoError(t, err) require.NotEmpty(t, ips) } // ---------------------------------------------------------------- // Context cancellation tests // ---------------------------------------------------------------- func TestFindAuthoritativeNameservers_ContextCanceled( t *testing.T, ) { t.Parallel() r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.FindAuthoritativeNameservers(ctx, "google.com") assert.Error(t, err) } func TestQueryNameserver_ContextCanceled(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryNameserver( ctx, "ns1.google.com.", "google.com", ) assert.Error(t, err) } func TestQueryAllNameservers_ContextCanceled(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryAllNameservers(ctx, "google.com") assert.Error(t, err) } func TestResolveIPAddresses_ContextCanceled(t *testing.T) { t.Parallel() r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.ResolveIPAddresses(ctx, "google.com") assert.Error(t, err) }