package resolver_test

import (
	"context"
	"log/slog"
	"net"
	"os"
	"sort"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"sneak.berlin/go/dnswatcher/internal/resolver"
)

// ----------------------------------------------------------------
// Test helpers
//
// NOTE: apart from the context-cancellation tests, every test in
// this file issues live DNS queries against public nameservers
// (google.com, cloudflare.com) and therefore needs network access.
// ----------------------------------------------------------------

// newTestResolver builds a resolver whose logger writes debug-level
// text output to stderr.
func newTestResolver(t *testing.T) *resolver.Resolver {
	t.Helper()

	log := slog.New(slog.NewTextHandler(
		os.Stderr,
		&slog.HandlerOptions{Level: slog.LevelDebug},
	))

	return resolver.NewFromLogger(log)
}

// testContext returns a context with a 60-second timeout whose
// cancel function is registered with t.Cleanup, so it is released
// when the test (including its parallel subtests) completes.
func testContext(t *testing.T) context.Context {
	t.Helper()

	ctx, cancel := context.WithTimeout(
		context.Background(),
		60*time.Second,
	)
	t.Cleanup(cancel)

	return ctx
}

// findOneNSForDomain resolves the authoritative nameservers for
// domain and returns the first entry. It fails the test immediately
// if the lookup errors or returns no nameservers.
func findOneNSForDomain(
	t *testing.T,
	r *resolver.Resolver,
	ctx context.Context, //nolint:revive // test helper
	domain string,
) string {
	t.Helper()

	nameservers, err := r.FindAuthoritativeNameservers(
		ctx, domain,
	)
	require.NoError(t, err)
	require.NotEmpty(t, nameservers)

	return nameservers[0]
}

// ----------------------------------------------------------------
// FindAuthoritativeNameservers tests
// ----------------------------------------------------------------

func TestFindAuthoritativeNameservers_ValidDomain(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	nameservers, err := r.FindAuthoritativeNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)
	require.NotEmpty(t, nameservers)

	hasGoogleNS := false

	for _, ns := range nameservers {
		if strings.Contains(ns, "google") {
			hasGoogleNS = true

			break
		}
	}

	assert.True(t, hasGoogleNS,
		"expected google nameservers, got: %v", nameservers,
	)
}

func TestFindAuthoritativeNameservers_Subdomain(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	nameservers, err := r.FindAuthoritativeNameservers(
		ctx, "www.google.com",
	)
	require.NoError(t, err)
	require.NotEmpty(t, nameservers)
}

func TestFindAuthoritativeNameservers_ReturnsSorted(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	nameservers, err := r.FindAuthoritativeNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)
	// An empty slice is trivially sorted; require data so the
	// ordering assertion below is meaningful.
	require.NotEmpty(t, nameservers)

	assert.True(
		t,
		sort.StringsAreSorted(nameservers),
		"nameservers should be sorted, got: %v", nameservers,
	)
}

func TestFindAuthoritativeNameservers_Deterministic(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	first, err := r.FindAuthoritativeNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)
	// Two empty results would compare equal without proving anything.
	require.NotEmpty(t, first)

	second, err := r.FindAuthoritativeNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)

	assert.Equal(t, first, second)
}

func TestFindAuthoritativeNameservers_TrailingDot(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns1, err := r.FindAuthoritativeNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)
	require.NotEmpty(t, ns1)

	ns2, err := r.FindAuthoritativeNameservers(
		ctx, "google.com.",
	)
	require.NoError(t, err)

	assert.Equal(t, ns1, ns2)
}

func TestFindAuthoritativeNameservers_CloudflareDomain(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	nameservers, err := r.FindAuthoritativeNameservers(
		ctx, "cloudflare.com",
	)
	require.NoError(t, err)
	require.NotEmpty(t, nameservers)

	for _, ns := range nameservers {
		assert.True(t, strings.HasSuffix(ns, "."),
			"NS should be FQDN with trailing dot: %s", ns,
		)
	}
}

// ----------------------------------------------------------------
// QueryNameserver tests
// ----------------------------------------------------------------

func TestQueryNameserver_BasicA(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "google.com")

	resp, err := r.QueryNameserver(
		ctx, ns, "www.google.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)

	assert.Equal(t, resolver.StatusOK, resp.Status)
	assert.Equal(t, ns, resp.Nameserver)

	hasRecords := len(resp.Records["A"]) > 0 ||
		len(resp.Records["CNAME"]) > 0
	assert.True(t, hasRecords,
		"expected A or CNAME records for www.google.com",
	)
}

func TestQueryNameserver_AAAA(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "cloudflare.com")

	resp, err := r.QueryNameserver(
		ctx, ns, "cloudflare.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)

	aaaaRecords := resp.Records["AAAA"]
	require.NotEmpty(t, aaaaRecords,
		"cloudflare.com should have AAAA records",
	)

	for _, ip := range aaaaRecords {
		parsed := net.ParseIP(ip)
		require.NotNil(t, parsed,
			"should be valid IP: %s", ip,
		)
	}
}

func TestQueryNameserver_MX(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "google.com")

	resp, err := r.QueryNameserver(
		ctx, ns, "google.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)

	mxRecords := resp.Records["MX"]
	require.NotEmpty(t, mxRecords,
		"google.com should have MX records",
	)
}

func TestQueryNameserver_TXT(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "google.com")

	resp, err := r.QueryNameserver(
		ctx, ns, "google.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)

	txtRecords := resp.Records["TXT"]
	require.NotEmpty(t, txtRecords,
		"google.com should have TXT records",
	)

	hasSPF := false

	for _, txt := range txtRecords {
		if strings.Contains(txt, "v=spf1") {
			hasSPF = true

			break
		}
	}

	assert.True(t, hasSPF,
		"google.com should have SPF TXT record",
	)
}

func TestQueryNameserver_NXDomain(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "google.com")

	resp, err := r.QueryNameserver(
		ctx, ns,
		"this-surely-does-not-exist-xyz.google.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)

	assert.Equal(t, resolver.StatusNXDomain, resp.Status)
}

func TestQueryNameserver_RecordsSorted(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "google.com")

	resp, err := r.QueryNameserver(
		ctx, ns, "google.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)
	// Ranging over an empty map would assert nothing.
	require.NotEmpty(t, resp.Records,
		"google.com should have at least one record type",
	)

	for recordType, values := range resp.Records {
		assert.True(
			t,
			sort.StringsAreSorted(values),
			"%s records should be sorted", recordType,
		)
	}
}

func TestQueryNameserver_ResponseIncludesNameserver(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "cloudflare.com")

	resp, err := r.QueryNameserver(
		ctx, ns, "cloudflare.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)

	assert.Equal(t, ns, resp.Nameserver)
}

func TestQueryNameserver_EmptyRecordsOnNXDomain(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "google.com")

	resp, err := r.QueryNameserver(
		ctx, ns,
		"this-surely-does-not-exist-xyz.google.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp)

	totalRecords := 0
	for _, values := range resp.Records {
		totalRecords += len(values)
	}

	assert.Zero(t, totalRecords)
}

func TestQueryNameserver_TrailingDotHandling(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ns := findOneNSForDomain(t, r, ctx, "google.com")

	resp1, err := r.QueryNameserver(
		ctx, ns, "google.com",
	)
	require.NoError(t, err)
	require.NotNil(t, resp1)

	resp2, err := r.QueryNameserver(
		ctx, ns, "google.com.",
	)
	require.NoError(t, err)
	require.NotNil(t, resp2)

	assert.Equal(t, resp1.Status, resp2.Status)
}

// ----------------------------------------------------------------
// QueryAllNameservers tests
// ----------------------------------------------------------------

func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	results, err := r.QueryAllNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)
	require.NotEmpty(t, results)

	assert.GreaterOrEqual(t, len(results), 2)

	for ns, resp := range results {
		require.NotNil(t, resp, "NS %s returned nil response", ns)
		assert.Equal(t, ns, resp.Nameserver)
	}
}

func TestQueryAllNameservers_AllReturnOK(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	results, err := r.QueryAllNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)
	// Without at least one result the loop below asserts nothing.
	require.NotEmpty(t, results)

	for ns, resp := range results {
		require.NotNil(t, resp, "NS %s returned nil response", ns)
		assert.Equal(
			t, resolver.StatusOK, resp.Status,
			"NS %s should return OK", ns,
		)
	}
}

func TestQueryAllNameservers_NXDomainFromAllNS(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	results, err := r.QueryAllNameservers(
		ctx,
		"this-surely-does-not-exist-xyz.google.com",
	)
	require.NoError(t, err)
	// Without at least one result the loop below asserts nothing.
	require.NotEmpty(t, results)

	for ns, resp := range results {
		require.NotNil(t, resp, "NS %s returned nil response", ns)
		assert.Equal(
			t, resolver.StatusNXDomain, resp.Status,
			"NS %s should return nxdomain", ns,
		)
	}
}

// ----------------------------------------------------------------
// LookupNS tests
// ----------------------------------------------------------------

func TestLookupNS_ValidDomain(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	nameservers, err := r.LookupNS(ctx, "google.com")
	require.NoError(t, err)
	require.NotEmpty(t, nameservers)

	for _, ns := range nameservers {
		assert.True(t, strings.HasSuffix(ns, "."),
			"NS should have trailing dot: %s", ns,
		)
	}
}

func TestLookupNS_Sorted(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	nameservers, err := r.LookupNS(ctx, "google.com")
	require.NoError(t, err)
	require.NotEmpty(t, nameservers)

	assert.True(t, sort.StringsAreSorted(nameservers))
}

func TestLookupNS_MatchesFindAuthoritative(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	fromLookup, err := r.LookupNS(ctx, "google.com")
	require.NoError(t, err)
	require.NotEmpty(t, fromLookup)

	fromFind, err := r.FindAuthoritativeNameservers(
		ctx, "google.com",
	)
	require.NoError(t, err)

	assert.Equal(t, fromFind, fromLookup)
}

// ----------------------------------------------------------------
// ResolveIPAddresses tests
// ----------------------------------------------------------------

func TestResolveIPAddresses_ReturnsIPs(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ips, err := r.ResolveIPAddresses(ctx, "google.com")
	require.NoError(t, err)
	require.NotEmpty(t, ips)

	for _, ip := range ips {
		parsed := net.ParseIP(ip)
		assert.NotNil(t, parsed,
			"should be valid IP: %s", ip,
		)
	}
}

func TestResolveIPAddresses_Deduplicated(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ips, err := r.ResolveIPAddresses(ctx, "google.com")
	require.NoError(t, err)
	// An empty result has no duplicates by definition.
	require.NotEmpty(t, ips)

	seen := make(map[string]bool, len(ips))
	for _, ip := range ips {
		assert.False(t, seen[ip], "duplicate IP: %s", ip)
		seen[ip] = true
	}
}

func TestResolveIPAddresses_Sorted(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ips, err := r.ResolveIPAddresses(ctx, "google.com")
	require.NoError(t, err)
	require.NotEmpty(t, ips)

	assert.True(t, sort.StringsAreSorted(ips))
}

func TestResolveIPAddresses_NXDomainReturnsEmpty(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ips, err := r.ResolveIPAddresses(
		ctx,
		"this-surely-does-not-exist-xyz.google.com",
	)
	require.NoError(t, err)
	assert.Empty(t, ips)
}

func TestResolveIPAddresses_CloudflareDomain(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)
	ctx := testContext(t)

	ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com")
	require.NoError(t, err)
	require.NotEmpty(t, ips)
}

// ----------------------------------------------------------------
// Context cancellation tests
// ----------------------------------------------------------------

func TestFindAuthoritativeNameservers_ContextCanceled(
	t *testing.T,
) {
	t.Parallel()

	r := newTestResolver(t)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := r.FindAuthoritativeNameservers(ctx, "google.com")
	assert.Error(t, err)
}

func TestQueryNameserver_ContextCanceled(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := r.QueryNameserver(
		ctx, "ns1.google.com.", "google.com",
	)
	assert.Error(t, err)
}

func TestQueryAllNameservers_ContextCanceled(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := r.QueryAllNameservers(ctx, "google.com")
	assert.Error(t, err)
}

func TestResolveIPAddresses_ContextCanceled(t *testing.T) {
	t.Parallel()

	r := newTestResolver(t)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := r.ResolveIPAddresses(ctx, "google.com")
	assert.Error(t, err)
}