package resolver_test import ( "context" "log/slog" "net" "os" "sort" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "sneak.berlin/go/dnswatcher/internal/resolver" ) // Test domain and hostnames hosted on Cloudflare. // These records must exist in the sneak.cloud Cloudflare zone: // // basic.dns.sneak.cloud A 192.0.2.1 // multi.dns.sneak.cloud A 192.0.2.1 // multi.dns.sneak.cloud A 192.0.2.2 // ipv6.dns.sneak.cloud AAAA 2001:db8::1 // dual.dns.sneak.cloud A 192.0.2.1 // dual.dns.sneak.cloud AAAA 2001:db8::1 // cname-target.dns.sneak.cloud A 198.51.100.1 // cname.dns.sneak.cloud CNAME cname-target.dns.sneak.cloud // mx.dns.sneak.cloud MX 10 mail.dns.sneak.cloud // mail.dns A 192.0.2.10 // txt.dns.sneak.cloud TXT "v=spf1 -all" const ( testDomain = "sneak.cloud" testHostBasic = "basic.dns.sneak.cloud" testHostMultiA = "multi.dns.sneak.cloud" testHostIPv6 = "ipv6.dns.sneak.cloud" testHostDualStack = "dual.dns.sneak.cloud" testHostCNAME = "cname.dns.sneak.cloud" testHostCNAMETarget = "cname-target.dns.sneak.cloud" testHostMX = "mx.dns.sneak.cloud" testHostMail = "mail.dns.sneak.cloud" testHostTXT = "txt.dns.sneak.cloud" testHostNXDomain = "nxdomain-surely-does-not-exist.google.com" testDomainNXDomain = "google.com" ) // queryTimeout is the default timeout for test queries. const queryTimeout = 60 * time.Second func newTestResolver(t *testing.T) *resolver.Resolver { t.Helper() log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: slog.LevelDebug, })) return resolver.NewFromLogger(log) } func testContext(t *testing.T) context.Context { t.Helper() ctx, cancel := context.WithTimeout( context.Background(), queryTimeout, ) t.Cleanup(cancel) return ctx } // --- FindAuthoritativeNameservers tests --- func TestFindAuthoritativeNameservers_ValidDomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, testDomain, ) require.NoError(t, err) require.NotEmpty(t, nameservers, "should find at least one NS") // sneak.cloud is on Cloudflare, NS should contain cloudflare for _, ns := range nameservers { t.Logf("discovered NS: %s", ns) assert.True( t, strings.HasSuffix(ns, "."), "NS should be FQDN with trailing dot: %s", ns, ) } // Verify at least one is a Cloudflare NS hasCloudflare := false for _, ns := range nameservers { if strings.Contains(ns, "cloudflare") { hasCloudflare = true break } } assert.True( t, hasCloudflare, "sneak.cloud should be hosted on Cloudflare, got: %v", nameservers, ) } func TestFindAuthoritativeNameservers_Subdomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) // Looking up NS for a hostname that isn't a zone should // return the parent zone's NS records. nameservers, err := r.FindAuthoritativeNameservers( ctx, testHostBasic, ) require.NoError(t, err) require.NotEmpty(t, nameservers) // Should be the same Cloudflare NSes as the parent domain hasCloudflare := false for _, ns := range nameservers { if strings.Contains(ns, "cloudflare") { hasCloudflare = true break } } assert.True(t, hasCloudflare) } func TestFindAuthoritativeNameservers_TLD(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "cloud", ) require.NoError(t, err) require.NotEmpty(t, nameservers, "should find TLD nameservers") for _, ns := range nameservers { t.Logf("TLD NS: %s", ns) } } func TestFindAuthoritativeNameservers_ReturnsSorted( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, testDomain, ) require.NoError(t, err) require.NotEmpty(t, nameservers) // Results should be sorted for deterministic comparison assert.True( t, sort.StringsAreSorted(nameservers), "nameservers should be sorted, got: %v", nameservers, ) } func TestFindAuthoritativeNameservers_Deterministic( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) first, err := r.FindAuthoritativeNameservers( ctx, testDomain, ) require.NoError(t, err) second, err := r.FindAuthoritativeNameservers( ctx, testDomain, ) require.NoError(t, err) assert.Equal( t, first, second, "repeated lookups should return same result", ) } // --- QueryNameserver tests --- func TestQueryNameserver_BasicA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostBasic) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) assert.Equal(t, ns, resp.Nameserver) aRecords := resp.Records["A"] require.NotEmpty(t, aRecords, "basic.dns should have A records") assert.Contains(t, aRecords, "192.0.2.1") t.Logf( "QueryNameserver(%s, %s) A records: %v", ns, testHostBasic, aRecords, ) } func TestQueryNameserver_MultipleA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostMultiA) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) aRecords := resp.Records["A"] require.Len( t, aRecords, 2, "multi.dns should have exactly 2 A records", ) sort.Strings(aRecords) assert.Equal(t, []string{"192.0.2.1", "192.0.2.2"}, aRecords) } func TestQueryNameserver_AAAA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostIPv6) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) aaaaRecords := resp.Records["AAAA"] require.NotEmpty( t, aaaaRecords, "ipv6.dns should have AAAA records", ) assert.Contains(t, aaaaRecords, "2001:db8::1") } func TestQueryNameserver_DualStack(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostDualStack) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) assert.Contains(t, resp.Records["A"], "192.0.2.1") assert.Contains(t, resp.Records["AAAA"], "2001:db8::1") } func TestQueryNameserver_CNAME(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostCNAME) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) cnameRecords := resp.Records["CNAME"] require.NotEmpty( t, cnameRecords, "cname.dns should have CNAME records", ) assert.Contains( t, cnameRecords, "cname-target.dns.sneak.cloud.", ) } func TestQueryNameserver_MX(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostMX) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) mxRecords := resp.Records["MX"] require.NotEmpty( t, mxRecords, "mx.dns should have MX records", ) // MX records are formatted as "priority host" hasMail := false for _, mx := range mxRecords { if strings.Contains(mx, "mail.dns.sneak.cloud.") { hasMail = true break } } assert.True( t, hasMail, "MX should reference mail.dns.sneak.cloud, got: %v", mxRecords, ) } func TestQueryNameserver_TXT(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostTXT) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) txtRecords := resp.Records["TXT"] require.NotEmpty( t, txtRecords, "txt.dns should have TXT records", ) hasSPF := false for _, txt := range txtRecords { if strings.Contains(txt, "v=spf1") { hasSPF = true break } } assert.True( t, hasSPF, "TXT should contain SPF record, got: %v", txtRecords, ) } func TestQueryNameserver_NXDomain(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, testDomainNXDomain) resp, err := r.QueryNameserver(ctx, ns, testHostNXDomain) require.NoError(t, err) require.NotNil(t, resp) assert.Equal( t, resolver.StatusNXDomain, resp.Status, "nonexistent host should return nxdomain status", ) } func TestQueryNameserver_RecordsSorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostMultiA) require.NoError(t, err) // Each record type's values should be sorted for determinism for recordType, values := range resp.Records { assert.True( t, sort.StringsAreSorted(values), "%s records should be sorted, got: %v", recordType, values, ) } } func TestQueryNameserver_ResponseIncludesNameserver( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, testHostBasic) require.NoError(t, err) assert.Equal( t, ns, resp.Nameserver, "response should include the queried nameserver", ) } func TestQueryNameserver_EmptyRecordsMapOnNXDomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNSForDomain(t, r, ctx, testDomainNXDomain) resp, err := r.QueryNameserver(ctx, ns, testHostNXDomain) require.NoError(t, err) totalRecords := 0 for _, values := range resp.Records { totalRecords += len(values) } assert.Zero( t, totalRecords, "NXDOMAIN should have no records, got: %v", resp.Records, ) } // --- QueryAllNameservers tests --- func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers(ctx, testHostBasic) require.NoError(t, err) require.NotEmpty(t, results) // Should have queried each NS independently t.Logf( "QueryAllNameservers returned %d nameserver results", len(results), ) for ns, resp := range results { t.Logf(" %s: status=%s A=%v", ns, resp.Status, resp.Records["A"]) assert.Equal(t, ns, resp.Nameserver) } // Should have more than one NS for Cloudflare-hosted domain assert.GreaterOrEqual( t, len(results), 2, "should query at least 2 nameservers", ) } func TestQueryAllNameservers_Consistent(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers(ctx, testHostBasic) require.NoError(t, err) require.NotEmpty(t, results) // All NSes should return the same A records for a // well-configured hostname. var referenceRecords map[string][]string for ns, resp := range results { require.Equal( t, resolver.StatusOK, resp.Status, "NS %s should return OK status", ns, ) if referenceRecords == nil { referenceRecords = resp.Records continue } assert.Equal( t, referenceRecords["A"], resp.Records["A"], "NS %s A records should match", ns, ) } } func TestQueryAllNameservers_NXDomainFromAllNS( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, testHostNXDomain, ) require.NoError(t, err) require.NotEmpty(t, results) for ns, resp := range results { assert.Equal( t, resolver.StatusNXDomain, resp.Status, "NS %s should return nxdomain for nonexistent host", ns, ) } } // --- LookupNS tests --- func TestLookupNS_ValidDomain(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, testDomain) require.NoError(t, err) require.NotEmpty(t, nameservers) for _, ns := range nameservers { t.Logf("NS record: %s", ns) assert.True( t, strings.HasSuffix(ns, "."), "NS should be FQDN: %s", ns, ) } } func TestLookupNS_Sorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, testDomain) require.NoError(t, err) assert.True( t, sort.StringsAreSorted(nameservers), "NS records should be sorted, got: %v", nameservers, ) } func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) fromLookup, err := r.LookupNS(ctx, testDomain) require.NoError(t, err) fromFind, err := r.FindAuthoritativeNameservers( ctx, testDomain, ) require.NoError(t, err) // Both methods should return the same NS set assert.Equal( t, fromFind, fromLookup, "LookupNS and FindAuthoritativeNameservers "+ "should return the same set", ) } // --- ResolveIPAddresses tests --- func TestResolveIPAddresses_BasicA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, testHostBasic) require.NoError(t, err) require.NotEmpty(t, ips) assert.Contains(t, ips, "192.0.2.1") } func TestResolveIPAddresses_MultipleA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, testHostMultiA) require.NoError(t, err) sort.Strings(ips) assert.Contains(t, ips, "192.0.2.1") assert.Contains(t, ips, "192.0.2.2") } func TestResolveIPAddresses_IPv6Only(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, testHostIPv6) require.NoError(t, err) require.NotEmpty(t, ips) assert.Contains(t, ips, "2001:db8::1") // Should not contain any IPv4 for _, ip := range ips { parsed := net.ParseIP(ip) require.NotNil(t, parsed, "should be valid IP: %s", ip) assert.Nil( t, parsed.To4(), "ipv6-only host should not return IPv4: %s", ip, ) } } func TestResolveIPAddresses_DualStack(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, testHostDualStack) require.NoError(t, err) assert.Contains(t, ips, "192.0.2.1") assert.Contains(t, ips, "2001:db8::1") } func TestResolveIPAddresses_FollowsCNAME(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) // cname.dns.sneak.cloud -> cname-target.dns.sneak.cloud -> 198.51.100.1 ips, err := r.ResolveIPAddresses(ctx, testHostCNAME) require.NoError(t, err) require.NotEmpty(t, ips) assert.Contains( t, ips, "198.51.100.1", "should follow CNAME to resolve target IP", ) } func TestResolveIPAddresses_Deduplicated(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, testHostBasic) require.NoError(t, err) // Check for duplicates seen := make(map[string]bool) for _, ip := range ips { assert.False( t, seen[ip], "IP %s appears more than once", ip, ) seen[ip] = true } } func TestResolveIPAddresses_Sorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, testHostDualStack) require.NoError(t, err) assert.True( t, sort.StringsAreSorted(ips), "IP addresses should be sorted, got: %v", ips, ) } func TestResolveIPAddresses_NXDomainReturnsEmpty( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, testHostNXDomain) // Should not error — NXDOMAIN is an expected DNS response. // It just means no IPs to return. require.NoError(t, err) assert.Empty(t, ips) } // --- Context cancellation tests --- func TestFindAuthoritativeNameservers_ContextCanceled( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately _, err := r.FindAuthoritativeNameservers(ctx, testDomain) assert.Error(t, err) } func TestQueryNameserver_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryNameserver( ctx, "ns1.example.com.", testHostBasic, ) assert.Error(t, err) } func TestQueryAllNameservers_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryAllNameservers(ctx, testHostBasic) assert.Error(t, err) } func TestResolveIPAddresses_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.ResolveIPAddresses(ctx, testHostBasic) assert.Error(t, err) } // --- Iterative resolution verification --- func TestFindAuthoritativeNameservers_IsIterative( t *testing.T, ) { // Verify that resolution works for well-known domains, // proving we trace from root rather than relying on a // system stub resolver that might not be configured. t.Parallel() r := newTestResolver(t) ctx := testContext(t) // Resolve a well-known domain to prove root->TLD->domain // tracing works. nameservers, err := r.FindAuthoritativeNameservers( ctx, "google.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) t.Logf("example.com NS: %v", nameservers) } // --- Edge cases --- func TestQueryNameserver_TrailingDotHandling(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) // Both with and without trailing dot should work resp1, err := r.QueryNameserver( ctx, ns, "basic.dns.sneak.cloud", ) require.NoError(t, err) resp2, err := r.QueryNameserver( ctx, ns, "basic.dns.sneak.cloud.", ) require.NoError(t, err) assert.Equal( t, resp1.Records["A"], resp2.Records["A"], "trailing dot should not affect results", ) } func TestFindAuthoritativeNameservers_TrailingDot( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns1, err := r.FindAuthoritativeNameservers( ctx, "sneak.cloud", ) require.NoError(t, err) ns2, err := r.FindAuthoritativeNameservers( ctx, "sneak.cloud.", ) require.NoError(t, err) assert.Equal( t, ns1, ns2, "trailing dot should not affect NS lookup", ) } // --- Helper functions --- // findOneNS discovers authoritative nameservers and returns the first // one, failing the test if none are found. func findOneNS( t *testing.T, r *resolver.Resolver, ctx context.Context, //nolint:revive // test helper ) string { t.Helper() return findOneNSForDomain(t, r, ctx, testDomain) } func findOneNSForDomain( t *testing.T, r *resolver.Resolver, ctx context.Context, //nolint:revive // test helper domain string, ) string { t.Helper() nameservers, err := r.FindAuthoritativeNameservers( ctx, domain, ) require.NoError(t, err) require.NotEmpty( t, nameservers, "should find at least one NS for %s", domain, ) return nameservers[0] }