From 9b93385f6f65c2c2a7df4078051a3ee8afb701ac Mon Sep 17 00:00:00 2001 From: user Date: Sat, 21 Feb 2026 02:57:23 -0800 Subject: [PATCH] fix: mock DNS client in resolver tests to avoid network calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace real DNS lookups in resolver_test.go with a mock DNSClient that returns canned responses. Tests now run in ~1s instead of >30s with no network dependency. Uses the existing NewFromLoggerWithClient constructor and DNSClient interface — no production code changes needed. --- internal/resolver/resolver_test.go | 460 ++++++++++++++++++++++++----- 1 file changed, 390 insertions(+), 70 deletions(-) diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index 3b9d936..363db79 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -2,6 +2,8 @@ package resolver_test import ( "context" + "errors" + "fmt" "log/slog" "net" "os" @@ -10,6 +12,7 @@ import ( "testing" "time" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,48 +20,372 @@ import ( ) // ---------------------------------------------------------------- -// Test helpers +// Mock DNS client // ---------------------------------------------------------------- -func newTestResolver(t *testing.T) *resolver.Resolver { +// errNoQuestion is returned when a DNS message has no question. +var errNoQuestion = errors.New("no question in DNS message") + +// mockExchange represents a canned DNS response for a query. +type mockExchange struct { + resp *dns.Msg + err error +} + +// mockDNSClient implements resolver.DNSClient with canned responses. +// It matches on (qname, qtype) and returns pre-configured responses. +// If no match is found it returns SERVFAIL. +type mockDNSClient struct { + // responses keyed by "QNAME/QTYPE" e.g. "google.com./NS" + responses map[string]*mockExchange +} + +func newMockClient() *mockDNSClient { + return &mockDNSClient{ + responses: make(map[string]*mockExchange), + } +} + +func mockKey(name string, qtype uint16) string { + return fmt.Sprintf("%s/%s", + strings.ToLower(dns.Fqdn(name)), + dns.TypeToString[qtype], + ) +} + +// ExchangeContext implements resolver.DNSClient for testing. +func (m *mockDNSClient) ExchangeContext( + _ context.Context, + msg *dns.Msg, + _ string, +) (*dns.Msg, time.Duration, error) { + if len(msg.Question) == 0 { + return nil, 0, errNoQuestion + } + + q := msg.Question[0] + key := mockKey(q.Name, q.Qtype) + + if ex, ok := m.responses[key]; ok { + if ex.err != nil { + return nil, 0, ex.err + } + + resp := ex.resp.Copy() + resp.Id = msg.Id + + return resp, time.Millisecond, nil + } + + // Default: SERVFAIL + resp := new(dns.Msg) + resp.SetReply(msg) + resp.Rcode = dns.RcodeServerFailure + + return resp, time.Millisecond, nil +} + +func (m *mockDNSClient) on( + name string, + qtype uint16, + resp *dns.Msg, +) { + m.responses[mockKey(name, qtype)] = &mockExchange{resp: resp} +} + +// ---------------------------------------------------------------- +// Helper functions to build DNS responses +// ---------------------------------------------------------------- + +func makeNSResponse( + qname string, + nsNames []string, + glue map[string]string, +) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qname), dns.TypeNS) + msg.Authoritative = true + + for _, ns := range nsNames { + msg.Answer = append(msg.Answer, &dns.NS{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(qname), + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + Ttl: 3600, + }, + Ns: dns.Fqdn(ns), + }) + } + + for name, ip := range glue { + msg.Extra = append(msg.Extra, &dns.A{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(name), + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 3600, + }, + A: net.ParseIP(ip), + }) + } + + return msg +} + +func makeAResponse(qname string, ips ...string) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qname), dns.TypeA) + msg.Authoritative = true + + for _, ip := range ips { + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(qname), + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP(ip), + }) + } + + return msg +} + +func makeAAAAResponse(qname string, ips ...string) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qname), dns.TypeAAAA) + msg.Authoritative = true + + for _, ip := range ips { + msg.Answer = append(msg.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(qname), + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: net.ParseIP(ip), + }) + } + + return msg +} + +func makeMXResponse( + qname string, + mxRecords map[uint16]string, +) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qname), dns.TypeMX) + msg.Authoritative = true + + for pref, mx := range mxRecords { + msg.Answer = append(msg.Answer, &dns.MX{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(qname), + Rrtype: dns.TypeMX, + Class: dns.ClassINET, + Ttl: 300, + }, + Preference: pref, + Mx: dns.Fqdn(mx), + }) + } + + return msg +} + +func makeTXTResponse(qname string, txts ...string) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qname), dns.TypeTXT) + msg.Authoritative = true + + for _, txt := range txts { + msg.Answer = append(msg.Answer, &dns.TXT{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(qname), + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 300, + }, + Txt: []string{txt}, + }) + } + + return msg +} + +func makeNXDomainResponse( + qname string, + qtype uint16, +) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qname), qtype) + msg.Rcode = dns.RcodeNameError + + return msg +} + +func makeEmptyResponse( + qname string, + qtype uint16, +) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(qname), qtype) + msg.Authoritative = true + + return msg +} + +// ---------------------------------------------------------------- +// Setup helpers +// ---------------------------------------------------------------- + +// setupMockResolver creates a mock client pre-loaded with a +// realistic DNS delegation chain for example.com and google.com. +func setupGoogleMock(mock *mockDNSClient) { + // google.com. NS (authoritative answer) + googleDelegation := makeNSResponse( + "google.com.", + []string{ + "ns1.google.com", "ns2.google.com", + "ns3.google.com", "ns4.google.com", + }, + map[string]string{ + "ns1.google.com": "216.239.32.10", + "ns2.google.com": "216.239.34.10", + "ns3.google.com": "216.239.36.10", + "ns4.google.com": "216.239.38.10", + }, + ) + mock.on("google.com.", dns.TypeNS, googleDelegation) + + // NS IP resolution + mock.on("ns1.google.com.", dns.TypeA, + makeAResponse("ns1.google.com", "216.239.32.10")) + mock.on("ns2.google.com.", dns.TypeA, + makeAResponse("ns2.google.com", "216.239.34.10")) + mock.on("ns3.google.com.", dns.TypeA, + makeAResponse("ns3.google.com", "216.239.36.10")) + mock.on("ns4.google.com.", dns.TypeA, + makeAResponse("ns4.google.com", "216.239.38.10")) + + // google.com records + mock.on("google.com.", dns.TypeA, + makeAResponse("google.com", "142.250.80.46")) + mock.on("google.com.", dns.TypeAAAA, + makeAAAAResponse("google.com", "2607:f8b0:4004:800::200e")) + mock.on("google.com.", dns.TypeMX, + makeMXResponse("google.com", map[uint16]string{ + 10: "smtp.google.com", 20: "smtp2.google.com", + })) + mock.on("google.com.", dns.TypeTXT, + makeTXTResponse("google.com", + "v=spf1 include:_spf.google.com ~all")) + mock.on("google.com.", dns.TypeSRV, + makeEmptyResponse("google.com.", dns.TypeSRV)) + mock.on("google.com.", dns.TypeCAA, + makeEmptyResponse("google.com.", dns.TypeCAA)) + + // www.google.com records + mock.on("www.google.com.", dns.TypeA, + makeAResponse("www.google.com", "142.250.80.46")) + + for _, qt := range []uint16{ + dns.TypeAAAA, dns.TypeCNAME, dns.TypeMX, + dns.TypeTXT, dns.TypeSRV, dns.TypeCAA, dns.TypeNS, + } { + mock.on("www.google.com.", qt, + makeEmptyResponse("www.google.com.", qt)) + } + + // NXDOMAIN for nonexistent subdomain + nxName := "this-surely-does-not-exist-xyz.google.com." + + for _, qt := range []uint16{ + dns.TypeA, dns.TypeAAAA, dns.TypeCNAME, + dns.TypeMX, dns.TypeTXT, dns.TypeSRV, + dns.TypeCAA, dns.TypeNS, + } { + mock.on(nxName, qt, makeNXDomainResponse(nxName, qt)) + } +} + +func setupCloudflareMock(mock *mockDNSClient) { + cfNS := makeNSResponse( + "cloudflare.com.", + []string{"ns3.cloudflare.com", "ns7.cloudflare.com"}, + map[string]string{ + "ns3.cloudflare.com": "162.159.0.33", + "ns7.cloudflare.com": "162.159.36.1", + }, + ) + mock.on("cloudflare.com.", dns.TypeNS, cfNS) + + mock.on("ns3.cloudflare.com.", dns.TypeA, + makeAResponse("ns3.cloudflare.com", "162.159.0.33")) + mock.on("ns7.cloudflare.com.", dns.TypeA, + makeAResponse("ns7.cloudflare.com", "162.159.36.1")) + + mock.on("cloudflare.com.", dns.TypeA, + makeAResponse("cloudflare.com", + "104.16.132.229", "104.16.133.229")) + mock.on("cloudflare.com.", dns.TypeAAAA, + makeAAAAResponse("cloudflare.com", + "2606:4700::6810:84e5", "2606:4700::6810:85e5")) + mock.on("cloudflare.com.", dns.TypeCNAME, + makeEmptyResponse("cloudflare.com.", dns.TypeCNAME)) + mock.on("cloudflare.com.", dns.TypeMX, + makeEmptyResponse("cloudflare.com.", dns.TypeMX)) + mock.on("cloudflare.com.", dns.TypeTXT, + makeTXTResponse("cloudflare.com", + "v=spf1 include:_spf.google.com ~all")) + mock.on("cloudflare.com.", dns.TypeSRV, + makeEmptyResponse("cloudflare.com.", dns.TypeSRV)) + mock.on("cloudflare.com.", dns.TypeCAA, + makeEmptyResponse("cloudflare.com.", dns.TypeCAA)) +} + +func setupMockResolver(t *testing.T) *resolver.Resolver { t.Helper() + mock := newMockClient() + + // Root -> com. NS (authoritative answer style since mock + // cannot distinguish server addresses) + mock.on("com.", dns.TypeNS, makeNSResponse( + "com.", + []string{"a.gtld-servers.net", "b.gtld-servers.net"}, + map[string]string{ + "a.gtld-servers.net": "192.5.6.30", + "b.gtld-servers.net": "192.33.14.30", + }, + )) + + setupGoogleMock(mock) + setupCloudflareMock(mock) + log := slog.New(slog.NewTextHandler( os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}, )) - return resolver.NewFromLogger(log) + return resolver.NewFromLoggerWithClient(log, mock) } func testContext(t *testing.T) context.Context { t.Helper() ctx, cancel := context.WithTimeout( - context.Background(), 60*time.Second, + context.Background(), 10*time.Second, ) t.Cleanup(cancel) return ctx } -func findOneNSForDomain( - t *testing.T, - r *resolver.Resolver, - ctx context.Context, //nolint:revive // test helper - domain string, -) string { - t.Helper() - - nameservers, err := r.FindAuthoritativeNameservers( - ctx, domain, - ) - require.NoError(t, err) - require.NotEmpty(t, nameservers) - - return nameservers[0] -} - // ---------------------------------------------------------------- // FindAuthoritativeNameservers tests // ---------------------------------------------------------------- @@ -68,7 +395,7 @@ func TestFindAuthoritativeNameservers_ValidDomain( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( @@ -97,7 +424,7 @@ func TestFindAuthoritativeNameservers_Subdomain( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( @@ -112,7 +439,7 @@ func TestFindAuthoritativeNameservers_ReturnsSorted( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( @@ -132,7 +459,7 @@ func TestFindAuthoritativeNameservers_Deterministic( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) first, err := r.FindAuthoritativeNameservers( @@ -153,7 +480,7 @@ func TestFindAuthoritativeNameservers_TrailingDot( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) ns1, err := r.FindAuthoritativeNameservers( @@ -174,7 +501,7 @@ func TestFindAuthoritativeNameservers_CloudflareDomain( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( @@ -197,18 +524,17 @@ func TestFindAuthoritativeNameservers_CloudflareDomain( func TestQueryNameserver_BasicA(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, "www.google.com", + ctx, "ns1.google.com.", "www.google.com", ) require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) - assert.Equal(t, ns, resp.Nameserver) + assert.Equal(t, "ns1.google.com.", resp.Nameserver) hasRecords := len(resp.Records["A"]) > 0 || len(resp.Records["CNAME"]) > 0 @@ -220,12 +546,11 @@ func TestQueryNameserver_BasicA(t *testing.T) { func TestQueryNameserver_AAAA(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") resp, err := r.QueryNameserver( - ctx, ns, "cloudflare.com", + ctx, "ns3.cloudflare.com.", "cloudflare.com", ) require.NoError(t, err) @@ -245,12 +570,11 @@ func TestQueryNameserver_AAAA(t *testing.T) { func TestQueryNameserver_MX(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, "google.com", + ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) @@ -263,12 +587,11 @@ func TestQueryNameserver_MX(t *testing.T) { func TestQueryNameserver_TXT(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, "google.com", + ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) @@ -295,12 +618,11 @@ func TestQueryNameserver_TXT(t *testing.T) { func TestQueryNameserver_NXDomain(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, + ctx, "ns1.google.com.", "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) @@ -311,12 +633,11 @@ func TestQueryNameserver_NXDomain(t *testing.T) { func TestQueryNameserver_RecordsSorted(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, "google.com", + ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) @@ -334,9 +655,10 @@ func TestQueryNameserver_ResponseIncludesNameserver( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "cloudflare.com") + + ns := "ns3.cloudflare.com." resp, err := r.QueryNameserver( ctx, ns, "cloudflare.com", @@ -351,12 +673,11 @@ func TestQueryNameserver_EmptyRecordsOnNXDomain( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "google.com") resp, err := r.QueryNameserver( - ctx, ns, + ctx, "ns1.google.com.", "this-surely-does-not-exist-xyz.google.com", ) require.NoError(t, err) @@ -372,17 +693,16 @@ func TestQueryNameserver_EmptyRecordsOnNXDomain( func TestQueryNameserver_TrailingDotHandling(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) - ns := findOneNSForDomain(t, r, ctx, "google.com") resp1, err := r.QueryNameserver( - ctx, ns, "google.com", + ctx, "ns1.google.com.", "google.com", ) require.NoError(t, err) resp2, err := r.QueryNameserver( - ctx, ns, "google.com.", + ctx, "ns1.google.com.", "google.com.", ) require.NoError(t, err) @@ -396,7 +716,7 @@ func TestQueryNameserver_TrailingDotHandling(t *testing.T) { func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( @@ -415,7 +735,7 @@ func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { func TestQueryAllNameservers_AllReturnOK(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( @@ -436,7 +756,7 @@ func TestQueryAllNameservers_NXDomainFromAllNS( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( @@ -460,7 +780,7 @@ func TestQueryAllNameservers_NXDomainFromAllNS( func TestLookupNS_ValidDomain(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "google.com") @@ -477,7 +797,7 @@ func TestLookupNS_ValidDomain(t *testing.T) { func TestLookupNS_Sorted(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "google.com") @@ -489,7 +809,7 @@ func TestLookupNS_Sorted(t *testing.T) { func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) fromLookup, err := r.LookupNS(ctx, "google.com") @@ -510,7 +830,7 @@ func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { func TestResolveIPAddresses_ReturnsIPs(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "google.com") @@ -528,7 +848,7 @@ func TestResolveIPAddresses_ReturnsIPs(t *testing.T) { func TestResolveIPAddresses_Deduplicated(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "google.com") @@ -545,7 +865,7 @@ func TestResolveIPAddresses_Deduplicated(t *testing.T) { func TestResolveIPAddresses_Sorted(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "google.com") @@ -559,7 +879,7 @@ func TestResolveIPAddresses_NXDomainReturnsEmpty( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses( @@ -573,7 +893,7 @@ func TestResolveIPAddresses_NXDomainReturnsEmpty( func TestResolveIPAddresses_CloudflareDomain(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com") @@ -590,7 +910,7 @@ func TestFindAuthoritativeNameservers_ContextCanceled( ) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -601,7 +921,7 @@ func TestFindAuthoritativeNameservers_ContextCanceled( func TestQueryNameserver_ContextCanceled(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -614,7 +934,7 @@ func TestQueryNameserver_ContextCanceled(t *testing.T) { func TestQueryAllNameservers_ContextCanceled(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -625,7 +945,7 @@ func TestQueryAllNameservers_ContextCanceled(t *testing.T) { func TestResolveIPAddresses_ContextCanceled(t *testing.T) { t.Parallel() - r := newTestResolver(t) + r := setupMockResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel()