diff --git a/internal/resolver/resolver_integration_test.go b/internal/resolver/resolver_integration_test.go deleted file mode 100644 index ec8dd0e..0000000 --- a/internal/resolver/resolver_integration_test.go +++ /dev/null @@ -1,85 +0,0 @@ -//go:build integration - -package resolver_test - -import ( - "context" - "log/slog" - "os" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "sneak.berlin/go/dnswatcher/internal/resolver" -) - -// Integration tests hit real DNS servers. Run with: -// go test -tags integration -timeout 60s ./internal/resolver/ - -func newIntegrationResolver(t *testing.T) *resolver.Resolver { - t.Helper() - - log := slog.New(slog.NewTextHandler( - os.Stderr, - &slog.HandlerOptions{Level: slog.LevelDebug}, - )) - - return resolver.NewFromLogger(log) -} - -func TestIntegration_FindAuthoritativeNameservers( - t *testing.T, -) { - t.Parallel() - - r := newIntegrationResolver(t) - - ctx, cancel := context.WithTimeout( - context.Background(), 30*time.Second, - ) - defer cancel() - - nameservers, err := r.FindAuthoritativeNameservers( - ctx, "example.com", - ) - require.NoError(t, err) - require.NotEmpty(t, nameservers) - - t.Logf("example.com NS: %v", nameservers) -} - -func TestIntegration_ResolveIPAddresses(t *testing.T) { - t.Parallel() - - r := newIntegrationResolver(t) - - ctx, cancel := context.WithTimeout( - context.Background(), 30*time.Second, - ) - defer cancel() - - // sneak.cloud is on Cloudflare - nameservers, err := r.FindAuthoritativeNameservers( - ctx, "sneak.cloud", - ) - require.NoError(t, err) - require.NotEmpty(t, nameservers) - - hasCloudflare := false - - for _, ns := range nameservers { - if strings.Contains(ns, "cloudflare") { - hasCloudflare = true - - break - } - } - - assert.True(t, hasCloudflare, - "sneak.cloud should be on Cloudflare, got: %v", - nameservers, - ) -} diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index 22fb538..3b9d936 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -2,386 +2,20 @@ package resolver_test import ( "context" - "errors" "log/slog" "net" "os" - "slices" "sort" "strings" "testing" "time" - "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "sneak.berlin/go/dnswatcher/internal/resolver" ) -var ( - errNoQuestion = errors.New("no question") - errUnexpectedServer = errors.New("unexpected server") -) - -// mockDNSClient implements resolver.DNSClient for hermetic tests. -// It dispatches responses based on the query name, type, and -// server address to simulate a full delegation chain. -type mockDNSClient struct { - // handler returns a response for a given query. Return nil - // to simulate a network error. - handler func( - msg *dns.Msg, addr string, - ) (*dns.Msg, error) -} - -func (m *mockDNSClient) ExchangeContext( - _ context.Context, - msg *dns.Msg, - addr string, -) (*dns.Msg, time.Duration, error) { - resp, err := m.handler(msg, addr) - - return resp, 0, err -} - -// ---------------------------------------------------------------- -// Zone data used across tests -// ---------------------------------------------------------------- - -const ( - testDomain = "example.com." - testHostBasic = "basic.example.com." - testHostMultiA = "multi.example.com." - testHostIPv6 = "ipv6.example.com." - testHostDualStack = "dual.example.com." - testHostCNAME = "cname.example.com." - testHostCNAMETgt = "cname-target.example.com." - testHostMX = "mx.example.com." - testHostTXT = "txt.example.com." - testHostNXDomain = "nxdomain.example.com." - - ns1Name = "ns1.example.com." - ns2Name = "ns2.example.com." - ns1IP = "10.0.0.1" - ns2IP = "10.0.0.2" - - // Root and TLD nameserver IPs (we intercept all queries). - comNS = "192.0.2.100" - comName = "a.gtld-servers.net." -) - -// newReply creates a dns.Msg reply for the given question. -func newReply(q *dns.Msg, rcode int) *dns.Msg { - r := new(dns.Msg) - r.SetReply(q) - r.Rcode = rcode - - return r -} - -// buildMockClient returns a mockDNSClient that simulates: -// -// root -> .com TLD delegation -> example.com NS delegation -// -> authoritative answers for test hostnames. -func buildMockClient() *mockDNSClient { - return &mockDNSClient{ - handler: func( - msg *dns.Msg, addr string, - ) (*dns.Msg, error) { - if len(msg.Question) == 0 { - return nil, errNoQuestion - } - - q := msg.Question[0] - name := strings.ToLower(q.Name) - host, _, _ := net.SplitHostPort(addr) - - // --- Root servers: delegate .com --- - if isRootServer(host) { - return handleRoot(msg, name, q.Qtype) - } - - // --- .com TLD: delegate example.com --- - if host == comNS { - return handleTLD(msg, name) - } - - // --- Authoritative NS for example.com --- - if host == ns1IP || host == ns2IP { - return handleAuth(msg, name, q.Qtype) - } - - return nil, errUnexpectedServer - }, - } -} - -func isRootServer(ip string) bool { - roots := []string{ - "198.41.0.4", "170.247.170.2", "192.33.4.12", - "199.7.91.13", "192.203.230.10", "192.5.5.241", - "192.112.36.4", "198.97.190.53", "192.36.148.17", - "192.58.128.30", "193.0.14.129", "199.7.83.42", - "202.12.27.33", - } - - return slices.Contains(roots, ip) -} - -func handleRoot( - msg *dns.Msg, name string, qtype uint16, -) (*dns.Msg, error) { - r := newReply(msg, dns.RcodeSuccess) - - // If asking for NS of "com.", return answer - if name == "com." && qtype == dns.TypeNS { - appendComDelegation(r, true) - - return r, nil - } - - // Recursive A queries for NS hostnames (used by resolveARecord) - if resp, ok := handleRootNSResolution( - r, msg, name, qtype, - ); ok { - return resp, nil - } - - // For anything under .com, delegate to TLD - if strings.HasSuffix(name, ".com.") || name == "com." { - appendComDelegation(r, false) - - return r, nil - } - - return r, nil -} - -func handleRootNSResolution( - r *dns.Msg, msg *dns.Msg, name string, qtype uint16, -) (*dns.Msg, bool) { - if qtype != dns.TypeA { - return nil, false - } - - if msg.RecursionDesired || !strings.HasSuffix(name, ".com.") { - switch name { - case ns1Name: - r.Answer = append(r.Answer, mkA(name, ns1IP)) - - return r, true - case ns2Name: - r.Answer = append(r.Answer, mkA(name, ns2IP)) - - return r, true - } - } - - return nil, false -} - -func appendComDelegation(r *dns.Msg, inAnswer bool) { - ns := &dns.NS{ - Hdr: dns.RR_Header{ - Name: "com.", - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - Ttl: 3600, - }, - Ns: comName, - } - glue := &dns.A{ - Hdr: dns.RR_Header{ - Name: comName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 3600, - }, - A: net.ParseIP(comNS), - } - - if inAnswer { - r.Answer = append(r.Answer, ns) - } else { - r.Ns = append(r.Ns, ns) - } - - r.Extra = append(r.Extra, glue) -} - -func handleTLD(msg *dns.Msg, name string) (*dns.Msg, error) { - r := newReply(msg, dns.RcodeSuccess) - - if !strings.HasSuffix(name, "example.com.") && - name != "example.com." { - r.Rcode = dns.RcodeNameError - - return r, nil - } - - // Delegate to example.com NS - r.Ns = append(r.Ns, - &dns.NS{ - Hdr: dns.RR_Header{ - Name: "example.com.", - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - Ttl: 3600, - }, - Ns: ns1Name, - }, - &dns.NS{ - Hdr: dns.RR_Header{ - Name: "example.com.", - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - Ttl: 3600, - }, - Ns: ns2Name, - }, - ) - r.Extra = append(r.Extra, - &dns.A{ - Hdr: dns.RR_Header{ - Name: ns1Name, Rrtype: dns.TypeA, - Class: dns.ClassINET, Ttl: 3600, - }, - A: net.ParseIP(ns1IP), - }, - &dns.A{ - Hdr: dns.RR_Header{ - Name: ns2Name, Rrtype: dns.TypeA, - Class: dns.ClassINET, Ttl: 3600, - }, - A: net.ParseIP(ns2IP), - }, - ) - - return r, nil -} - -func handleAuth( - msg *dns.Msg, name string, qtype uint16, -) (*dns.Msg, error) { - r := newReply(msg, dns.RcodeSuccess) - - if name == "example.com." && qtype == dns.TypeNS { - appendExampleNS(r) - - return r, nil - } - - if name == testHostNXDomain { - r.Rcode = dns.RcodeNameError - - return r, nil - } - - addAuthRecords(r, name, qtype) - - return r, nil -} - -func appendExampleNS(r *dns.Msg) { - r.Answer = append(r.Answer, - &dns.NS{ - Hdr: dns.RR_Header{ - Name: "example.com.", Rrtype: dns.TypeNS, - Class: dns.ClassINET, Ttl: 3600, - }, - Ns: ns1Name, - }, - &dns.NS{ - Hdr: dns.RR_Header{ - Name: "example.com.", Rrtype: dns.TypeNS, - Class: dns.ClassINET, Ttl: 3600, - }, - Ns: ns2Name, - }, - ) -} - -//nolint:cyclop // dispatch table for test zone records -func addAuthRecords( - r *dns.Msg, name string, qtype uint16, -) { - switch { - case name == testHostBasic && qtype == dns.TypeA: - r.Answer = append(r.Answer, mkA(name, "192.0.2.1")) - - case name == testHostMultiA && qtype == dns.TypeA: - r.Answer = append(r.Answer, - mkA(name, "192.0.2.1"), mkA(name, "192.0.2.2"), - ) - - case name == testHostIPv6 && qtype == dns.TypeAAAA: - r.Answer = append(r.Answer, mkAAAA(name, "2001:db8::1")) - - case name == testHostDualStack && qtype == dns.TypeA: - r.Answer = append(r.Answer, mkA(name, "192.0.2.1")) - - case name == testHostDualStack && qtype == dns.TypeAAAA: - r.Answer = append(r.Answer, mkAAAA(name, "2001:db8::1")) - - case name == testHostCNAME && qtype == dns.TypeCNAME: - r.Answer = append(r.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: name, Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, Ttl: 3600, - }, - Target: testHostCNAMETgt, - }) - - case name == testHostCNAMETgt && qtype == dns.TypeA: - r.Answer = append(r.Answer, mkA(name, "198.51.100.1")) - - case name == testHostMX && qtype == dns.TypeMX: - r.Answer = append(r.Answer, &dns.MX{ - Hdr: dns.RR_Header{ - Name: name, Rrtype: dns.TypeMX, - Class: dns.ClassINET, Ttl: 3600, - }, - Preference: 10, Mx: "mail.example.com.", - }) - - case name == testHostTXT && qtype == dns.TypeTXT: - r.Answer = append(r.Answer, &dns.TXT{ - Hdr: dns.RR_Header{ - Name: name, Rrtype: dns.TypeTXT, - Class: dns.ClassINET, Ttl: 3600, - }, - Txt: []string{"v=spf1 -all"}, - }) - - case name == ns1Name && qtype == dns.TypeA: - r.Answer = append(r.Answer, mkA(name, ns1IP)) - - case name == ns2Name && qtype == dns.TypeA: - r.Answer = append(r.Answer, mkA(name, ns2IP)) - } -} - -func mkA(name, ip string) *dns.A { - return &dns.A{ - Hdr: dns.RR_Header{ - Name: name, Rrtype: dns.TypeA, - Class: dns.ClassINET, Ttl: 3600, - }, - A: net.ParseIP(ip), - } -} - -func mkAAAA(name, ip string) *dns.AAAA { - return &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: name, Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, Ttl: 3600, - }, - AAAA: net.ParseIP(ip), - } -} - // ---------------------------------------------------------------- // Test helpers // ---------------------------------------------------------------- @@ -394,22 +28,37 @@ func newTestResolver(t *testing.T) *resolver.Resolver { &slog.HandlerOptions{Level: slog.LevelDebug}, )) - return resolver.NewFromLoggerWithClient( - log, buildMockClient(), - ) + return resolver.NewFromLogger(log) } func testContext(t *testing.T) context.Context { t.Helper() ctx, cancel := context.WithTimeout( - context.Background(), 5*time.Second, + context.Background(), 60*time.Second, ) t.Cleanup(cancel) return ctx } +func findOneNSForDomain( + t *testing.T, + r *resolver.Resolver, + ctx context.Context, //nolint:revive // test helper + domain string, +) string { + t.Helper() + + nameservers, err := r.FindAuthoritativeNameservers( + ctx, domain, + ) + require.NoError(t, err) + require.NotEmpty(t, nameservers) + + return nameservers[0] +} + // ---------------------------------------------------------------- // FindAuthoritativeNameservers tests // ---------------------------------------------------------------- @@ -423,13 +72,24 @@ func TestFindAuthoritativeNameservers_ValidDomain( ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( - ctx, "example.com", + ctx, "google.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) - assert.Contains(t, nameservers, ns1Name) - assert.Contains(t, nameservers, ns2Name) + hasGoogleNS := false + + for _, ns := range nameservers { + if strings.Contains(ns, "google") { + hasGoogleNS = true + + break + } + } + + assert.True(t, hasGoogleNS, + "expected google nameservers, got: %v", nameservers, + ) } func TestFindAuthoritativeNameservers_Subdomain( @@ -441,12 +101,10 @@ func TestFindAuthoritativeNameservers_Subdomain( ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( - ctx, "basic.example.com", + ctx, "www.google.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) - - assert.Contains(t, nameservers, ns1Name) } func TestFindAuthoritativeNameservers_ReturnsSorted( @@ -458,7 +116,7 @@ func TestFindAuthoritativeNameservers_ReturnsSorted( ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( - ctx, "example.com", + ctx, "google.com", ) require.NoError(t, err) @@ -478,12 +136,12 @@ func TestFindAuthoritativeNameservers_Deterministic( ctx := testContext(t) first, err := r.FindAuthoritativeNameservers( - ctx, "example.com", + ctx, "google.com", ) require.NoError(t, err) second, err := r.FindAuthoritativeNameservers( - ctx, "example.com", + ctx, "google.com", ) require.NoError(t, err) @@ -499,67 +157,64 @@ func TestFindAuthoritativeNameservers_TrailingDot( ctx := testContext(t) ns1, err := r.FindAuthoritativeNameservers( - ctx, "example.com", + ctx, "google.com", ) require.NoError(t, err) ns2, err := r.FindAuthoritativeNameservers( - ctx, "example.com.", + ctx, "google.com.", ) require.NoError(t, err) assert.Equal(t, ns1, ns2) } -// ---------------------------------------------------------------- -// QueryNameserver tests -// ---------------------------------------------------------------- - -func findOneNS( +func TestFindAuthoritativeNameservers_CloudflareDomain( t *testing.T, - r *resolver.Resolver, - ctx context.Context, //nolint:revive // test helper -) string { - t.Helper() +) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( - ctx, "example.com", + ctx, "cloudflare.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) - return nameservers[0] + for _, ns := range nameservers { + assert.True(t, strings.HasSuffix(ns, "."), + "NS should be FQDN with trailing dot: %s", ns, + ) + } } +// ---------------------------------------------------------------- +// QueryNameserver tests +// ---------------------------------------------------------------- + func TestQueryNameserver_BasicA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "google.com") - resp, err := r.QueryNameserver(ctx, ns, "basic.example.com") + resp, err := r.QueryNameserver( + ctx, ns, "www.google.com", + ) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) assert.Equal(t, ns, resp.Nameserver) - assert.Contains(t, resp.Records["A"], "192.0.2.1") -} -func TestQueryNameserver_MultipleA(t *testing.T) { - t.Parallel() - - r := newTestResolver(t) - ctx := testContext(t) - ns := findOneNS(t, r, ctx) - - resp, err := r.QueryNameserver(ctx, ns, "multi.example.com") - require.NoError(t, err) - - aRecords := resp.Records["A"] - sort.Strings(aRecords) - assert.Equal(t, []string{"192.0.2.1", "192.0.2.2"}, aRecords) + hasRecords := len(resp.Records["A"]) > 0 || + len(resp.Records["CNAME"]) > 0 + assert.True(t, hasRecords, + "expected A or CNAME records for www.google.com", + ) } func TestQueryNameserver_AAAA(t *testing.T) { @@ -567,41 +222,24 @@ func TestQueryNameserver_AAAA(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") - resp, err := r.QueryNameserver(ctx, ns, "ipv6.example.com") - require.NoError(t, err) - - assert.Contains(t, resp.Records["AAAA"], "2001:db8::1") -} - -func TestQueryNameserver_DualStack(t *testing.T) { - t.Parallel() - - r := newTestResolver(t) - ctx := testContext(t) - ns := findOneNS(t, r, ctx) - - resp, err := r.QueryNameserver(ctx, ns, "dual.example.com") - require.NoError(t, err) - - assert.Contains(t, resp.Records["A"], "192.0.2.1") - assert.Contains(t, resp.Records["AAAA"], "2001:db8::1") -} - -func TestQueryNameserver_CNAME(t *testing.T) { - t.Parallel() - - r := newTestResolver(t) - ctx := testContext(t) - ns := findOneNS(t, r, ctx) - - resp, err := r.QueryNameserver(ctx, ns, "cname.example.com") - require.NoError(t, err) - - assert.Contains( - t, resp.Records["CNAME"], testHostCNAMETgt, + resp, err := r.QueryNameserver( + ctx, ns, "cloudflare.com", ) + require.NoError(t, err) + + aaaaRecords := resp.Records["AAAA"] + require.NotEmpty(t, aaaaRecords, + "cloudflare.com should have AAAA records", + ) + + for _, ip := range aaaaRecords { + parsed := net.ParseIP(ip) + require.NotNil(t, parsed, + "should be valid IP: %s", ip, + ) + } } func TestQueryNameserver_MX(t *testing.T) { @@ -609,25 +247,17 @@ func TestQueryNameserver_MX(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "google.com") - resp, err := r.QueryNameserver(ctx, ns, "mx.example.com") + resp, err := r.QueryNameserver( + ctx, ns, "google.com", + ) require.NoError(t, err) mxRecords := resp.Records["MX"] - require.NotEmpty(t, mxRecords) - - hasMail := false - - for _, mx := range mxRecords { - if strings.Contains(mx, "mail.example.com.") { - hasMail = true - - break - } - } - - assert.True(t, hasMail) + require.NotEmpty(t, mxRecords, + "google.com should have MX records", + ) } func TestQueryNameserver_TXT(t *testing.T) { @@ -635,14 +265,21 @@ func TestQueryNameserver_TXT(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "google.com") - resp, err := r.QueryNameserver(ctx, ns, "txt.example.com") + resp, err := r.QueryNameserver( + ctx, ns, "google.com", + ) require.NoError(t, err) + txtRecords := resp.Records["TXT"] + require.NotEmpty(t, txtRecords, + "google.com should have TXT records", + ) + hasSPF := false - for _, txt := range resp.Records["TXT"] { + for _, txt := range txtRecords { if strings.Contains(txt, "v=spf1") { hasSPF = true @@ -650,7 +287,9 @@ func TestQueryNameserver_TXT(t *testing.T) { } } - assert.True(t, hasSPF) + assert.True(t, hasSPF, + "google.com should have SPF TXT record", + ) } func TestQueryNameserver_NXDomain(t *testing.T) { @@ -658,10 +297,11 @@ func TestQueryNameserver_NXDomain(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, "nxdomain.example.com", + ctx, ns, + "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) @@ -673,9 +313,11 @@ func TestQueryNameserver_RecordsSorted(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "google.com") - resp, err := r.QueryNameserver(ctx, ns, "multi.example.com") + resp, err := r.QueryNameserver( + ctx, ns, "google.com", + ) require.NoError(t, err) for recordType, values := range resp.Records { @@ -694,9 +336,11 @@ func TestQueryNameserver_ResponseIncludesNameserver( r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") - resp, err := r.QueryNameserver(ctx, ns, "basic.example.com") + resp, err := r.QueryNameserver( + ctx, ns, "cloudflare.com", + ) require.NoError(t, err) assert.Equal(t, ns, resp.Nameserver) @@ -709,10 +353,11 @@ func TestQueryNameserver_EmptyRecordsOnNXDomain( r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, "nxdomain.example.com", + ctx, ns, + "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) @@ -729,19 +374,19 @@ func TestQueryNameserver_TrailingDotHandling(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - ns := findOneNS(t, r, ctx) + ns := findOneNSForDomain(t, r, ctx, "google.com") resp1, err := r.QueryNameserver( - ctx, ns, "basic.example.com", + ctx, ns, "google.com", ) require.NoError(t, err) resp2, err := r.QueryNameserver( - ctx, ns, "basic.example.com.", + ctx, ns, "google.com.", ) require.NoError(t, err) - assert.Equal(t, resp1.Records["A"], resp2.Records["A"]) + assert.Equal(t, resp1.Status, resp2.Status) } // ---------------------------------------------------------------- @@ -755,7 +400,7 @@ func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { ctx := testContext(t) results, err := r.QueryAllNameservers( - ctx, "basic.example.com", + ctx, "google.com", ) require.NoError(t, err) require.NotEmpty(t, results) @@ -767,32 +412,22 @@ func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { } } -func TestQueryAllNameservers_Consistent(t *testing.T) { +func TestQueryAllNameservers_AllReturnOK(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( - ctx, "basic.example.com", + ctx, "google.com", ) require.NoError(t, err) - var referenceA []string - for ns, resp := range results { assert.Equal( t, resolver.StatusOK, resp.Status, "NS %s should return OK", ns, ) - - if referenceA == nil { - referenceA = resp.Records["A"] - - continue - } - - assert.Equal(t, referenceA, resp.Records["A"]) } } @@ -805,7 +440,8 @@ func TestQueryAllNameservers_NXDomainFromAllNS( ctx := testContext(t) results, err := r.QueryAllNameservers( - ctx, "nxdomain.example.com", + ctx, + "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) @@ -827,12 +463,14 @@ func TestLookupNS_ValidDomain(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - nameservers, err := r.LookupNS(ctx, "example.com") + nameservers, err := r.LookupNS(ctx, "google.com") require.NoError(t, err) require.NotEmpty(t, nameservers) for _, ns := range nameservers { - assert.True(t, strings.HasSuffix(ns, ".")) + assert.True(t, strings.HasSuffix(ns, "."), + "NS should have trailing dot: %s", ns, + ) } } @@ -842,7 +480,7 @@ func TestLookupNS_Sorted(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - nameservers, err := r.LookupNS(ctx, "example.com") + nameservers, err := r.LookupNS(ctx, "google.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(nameservers)) @@ -854,11 +492,11 @@ func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - fromLookup, err := r.LookupNS(ctx, "example.com") + fromLookup, err := r.LookupNS(ctx, "google.com") require.NoError(t, err) fromFind, err := r.FindAuthoritativeNameservers( - ctx, "example.com", + ctx, "google.com", ) require.NoError(t, err) @@ -869,81 +507,31 @@ func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { // ResolveIPAddresses tests // ---------------------------------------------------------------- -func TestResolveIPAddresses_BasicA(t *testing.T) { +func TestResolveIPAddresses_ReturnsIPs(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) - ips, err := r.ResolveIPAddresses(ctx, "basic.example.com") - require.NoError(t, err) - assert.Contains(t, ips, "192.0.2.1") -} - -func TestResolveIPAddresses_MultipleA(t *testing.T) { - t.Parallel() - - r := newTestResolver(t) - ctx := testContext(t) - - ips, err := r.ResolveIPAddresses(ctx, "multi.example.com") - require.NoError(t, err) - - assert.Contains(t, ips, "192.0.2.1") - assert.Contains(t, ips, "192.0.2.2") -} - -func TestResolveIPAddresses_IPv6Only(t *testing.T) { - t.Parallel() - - r := newTestResolver(t) - ctx := testContext(t) - - ips, err := r.ResolveIPAddresses(ctx, "ipv6.example.com") + ips, err := r.ResolveIPAddresses(ctx, "google.com") require.NoError(t, err) require.NotEmpty(t, ips) - assert.Contains(t, ips, "2001:db8::1") for _, ip := range ips { parsed := net.ParseIP(ip) - require.NotNil(t, parsed) - assert.Nil(t, parsed.To4(), - "should not contain IPv4: %s", ip, + assert.NotNil(t, parsed, + "should be valid IP: %s", ip, ) } } -func TestResolveIPAddresses_DualStack(t *testing.T) { - t.Parallel() - - r := newTestResolver(t) - ctx := testContext(t) - - ips, err := r.ResolveIPAddresses(ctx, "dual.example.com") - require.NoError(t, err) - - assert.Contains(t, ips, "192.0.2.1") - assert.Contains(t, ips, "2001:db8::1") -} - -func TestResolveIPAddresses_FollowsCNAME(t *testing.T) { - t.Parallel() - - r := newTestResolver(t) - ctx := testContext(t) - - ips, err := r.ResolveIPAddresses(ctx, "cname.example.com") - require.NoError(t, err) - assert.Contains(t, ips, "198.51.100.1") -} - func TestResolveIPAddresses_Deduplicated(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) - ips, err := r.ResolveIPAddresses(ctx, "basic.example.com") + ips, err := r.ResolveIPAddresses(ctx, "google.com") require.NoError(t, err) seen := make(map[string]bool) @@ -960,7 +548,7 @@ func TestResolveIPAddresses_Sorted(t *testing.T) { r := newTestResolver(t) ctx := testContext(t) - ips, err := r.ResolveIPAddresses(ctx, "dual.example.com") + ips, err := r.ResolveIPAddresses(ctx, "google.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(ips)) @@ -975,12 +563,24 @@ func TestResolveIPAddresses_NXDomainReturnsEmpty( ctx := testContext(t) ips, err := r.ResolveIPAddresses( - ctx, "nxdomain.example.com", + ctx, + "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) assert.Empty(t, ips) } +func TestResolveIPAddresses_CloudflareDomain(t *testing.T) { + t.Parallel() + + r := newTestResolver(t) + ctx := testContext(t) + + ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com") + require.NoError(t, err) + require.NotEmpty(t, ips) +} + // ---------------------------------------------------------------- // Context cancellation tests // ---------------------------------------------------------------- @@ -994,7 +594,7 @@ func TestFindAuthoritativeNameservers_ContextCanceled( ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := r.FindAuthoritativeNameservers(ctx, "example.com") + _, err := r.FindAuthoritativeNameservers(ctx, "google.com") assert.Error(t, err) } @@ -1006,7 +606,7 @@ func TestQueryNameserver_ContextCanceled(t *testing.T) { cancel() _, err := r.QueryNameserver( - ctx, "ns1.example.com.", "basic.example.com", + ctx, "ns1.google.com.", "google.com", ) assert.Error(t, err) } @@ -1018,7 +618,7 @@ func TestQueryAllNameservers_ContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := r.QueryAllNameservers(ctx, "basic.example.com") + _, err := r.QueryAllNameservers(ctx, "google.com") assert.Error(t, err) } @@ -1029,6 +629,6 @@ func TestResolveIPAddresses_ContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := r.ResolveIPAddresses(ctx, "basic.example.com") + _, err := r.ResolveIPAddresses(ctx, "google.com") assert.Error(t, err) }