package resolver_test import ( "context" "errors" "log/slog" "net" "os" "slices" "sort" "strings" "testing" "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "sneak.berlin/go/dnswatcher/internal/resolver" ) var ( errNoQuestion = errors.New("no question") errUnexpectedServer = errors.New("unexpected server") ) // mockDNSClient implements resolver.DNSClient for hermetic tests. // It dispatches responses based on the query name, type, and // server address to simulate a full delegation chain. type mockDNSClient struct { // handler returns a response for a given query. Return nil // to simulate a network error. handler func( msg *dns.Msg, addr string, ) (*dns.Msg, error) } func (m *mockDNSClient) ExchangeContext( _ context.Context, msg *dns.Msg, addr string, ) (*dns.Msg, time.Duration, error) { resp, err := m.handler(msg, addr) return resp, 0, err } // ---------------------------------------------------------------- // Zone data used across tests // ---------------------------------------------------------------- const ( testDomain = "example.com." testHostBasic = "basic.example.com." testHostMultiA = "multi.example.com." testHostIPv6 = "ipv6.example.com." testHostDualStack = "dual.example.com." testHostCNAME = "cname.example.com." testHostCNAMETgt = "cname-target.example.com." testHostMX = "mx.example.com." testHostTXT = "txt.example.com." testHostNXDomain = "nxdomain.example.com." ns1Name = "ns1.example.com." ns2Name = "ns2.example.com." ns1IP = "10.0.0.1" ns2IP = "10.0.0.2" // Root and TLD nameserver IPs (we intercept all queries). comNS = "192.0.2.100" comName = "a.gtld-servers.net." ) // newReply creates a dns.Msg reply for the given question. func newReply(q *dns.Msg, rcode int) *dns.Msg { r := new(dns.Msg) r.SetReply(q) r.Rcode = rcode return r } // buildMockClient returns a mockDNSClient that simulates: // // root -> .com TLD delegation -> example.com NS delegation // -> authoritative answers for test hostnames. func buildMockClient() *mockDNSClient { return &mockDNSClient{ handler: func( msg *dns.Msg, addr string, ) (*dns.Msg, error) { if len(msg.Question) == 0 { return nil, errNoQuestion } q := msg.Question[0] name := strings.ToLower(q.Name) host, _, _ := net.SplitHostPort(addr) // --- Root servers: delegate .com --- if isRootServer(host) { return handleRoot(msg, name, q.Qtype) } // --- .com TLD: delegate example.com --- if host == comNS { return handleTLD(msg, name) } // --- Authoritative NS for example.com --- if host == ns1IP || host == ns2IP { return handleAuth(msg, name, q.Qtype) } return nil, errUnexpectedServer }, } } func isRootServer(ip string) bool { roots := []string{ "198.41.0.4", "170.247.170.2", "192.33.4.12", "199.7.91.13", "192.203.230.10", "192.5.5.241", "192.112.36.4", "198.97.190.53", "192.36.148.17", "192.58.128.30", "193.0.14.129", "199.7.83.42", "202.12.27.33", } return slices.Contains(roots, ip) } func handleRoot( msg *dns.Msg, name string, qtype uint16, ) (*dns.Msg, error) { r := newReply(msg, dns.RcodeSuccess) // If asking for NS of "com.", return answer if name == "com." && qtype == dns.TypeNS { appendComDelegation(r, true) return r, nil } // Recursive A queries for NS hostnames (used by resolveARecord) if resp, ok := handleRootNSResolution( r, msg, name, qtype, ); ok { return resp, nil } // For anything under .com, delegate to TLD if strings.HasSuffix(name, ".com.") || name == "com." { appendComDelegation(r, false) return r, nil } return r, nil } func handleRootNSResolution( r *dns.Msg, msg *dns.Msg, name string, qtype uint16, ) (*dns.Msg, bool) { if qtype != dns.TypeA { return nil, false } if msg.RecursionDesired || !strings.HasSuffix(name, ".com.") { switch name { case ns1Name: r.Answer = append(r.Answer, mkA(name, ns1IP)) return r, true case ns2Name: r.Answer = append(r.Answer, mkA(name, ns2IP)) return r, true } } return nil, false } func appendComDelegation(r *dns.Msg, inAnswer bool) { ns := &dns.NS{ Hdr: dns.RR_Header{ Name: "com.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: comName, } glue := &dns.A{ Hdr: dns.RR_Header{ Name: comName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600, }, A: net.ParseIP(comNS), } if inAnswer { r.Answer = append(r.Answer, ns) } else { r.Ns = append(r.Ns, ns) } r.Extra = append(r.Extra, glue) } func handleTLD(msg *dns.Msg, name string) (*dns.Msg, error) { r := newReply(msg, dns.RcodeSuccess) if !strings.HasSuffix(name, "example.com.") && name != "example.com." { r.Rcode = dns.RcodeNameError return r, nil } // Delegate to example.com NS r.Ns = append(r.Ns, &dns.NS{ Hdr: dns.RR_Header{ Name: "example.com.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: ns1Name, }, &dns.NS{ Hdr: dns.RR_Header{ Name: "example.com.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: ns2Name, }, ) r.Extra = append(r.Extra, &dns.A{ Hdr: dns.RR_Header{ Name: ns1Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600, }, A: net.ParseIP(ns1IP), }, &dns.A{ Hdr: dns.RR_Header{ Name: ns2Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600, }, A: net.ParseIP(ns2IP), }, ) return r, nil } func handleAuth( msg *dns.Msg, name string, qtype uint16, ) (*dns.Msg, error) { r := newReply(msg, dns.RcodeSuccess) if name == "example.com." && qtype == dns.TypeNS { appendExampleNS(r) return r, nil } if name == testHostNXDomain { r.Rcode = dns.RcodeNameError return r, nil } addAuthRecords(r, name, qtype) return r, nil } func appendExampleNS(r *dns.Msg) { r.Answer = append(r.Answer, &dns.NS{ Hdr: dns.RR_Header{ Name: "example.com.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: ns1Name, }, &dns.NS{ Hdr: dns.RR_Header{ Name: "example.com.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 3600, }, Ns: ns2Name, }, ) } //nolint:cyclop // dispatch table for test zone records func addAuthRecords( r *dns.Msg, name string, qtype uint16, ) { switch { case name == testHostBasic && qtype == dns.TypeA: r.Answer = append(r.Answer, mkA(name, "192.0.2.1")) case name == testHostMultiA && qtype == dns.TypeA: r.Answer = append(r.Answer, mkA(name, "192.0.2.1"), mkA(name, "192.0.2.2"), ) case name == testHostIPv6 && qtype == dns.TypeAAAA: r.Answer = append(r.Answer, mkAAAA(name, "2001:db8::1")) case name == testHostDualStack && qtype == dns.TypeA: r.Answer = append(r.Answer, mkA(name, "192.0.2.1")) case name == testHostDualStack && qtype == dns.TypeAAAA: r.Answer = append(r.Answer, mkAAAA(name, "2001:db8::1")) case name == testHostCNAME && qtype == dns.TypeCNAME: r.Answer = append(r.Answer, &dns.CNAME{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 3600, }, Target: testHostCNAMETgt, }) case name == testHostCNAMETgt && qtype == dns.TypeA: r.Answer = append(r.Answer, mkA(name, "198.51.100.1")) case name == testHostMX && qtype == dns.TypeMX: r.Answer = append(r.Answer, &dns.MX{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 3600, }, Preference: 10, Mx: "mail.example.com.", }) case name == testHostTXT && qtype == dns.TypeTXT: r.Answer = append(r.Answer, &dns.TXT{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 3600, }, Txt: []string{"v=spf1 -all"}, }) case name == ns1Name && qtype == dns.TypeA: r.Answer = append(r.Answer, mkA(name, ns1IP)) case name == ns2Name && qtype == dns.TypeA: r.Answer = append(r.Answer, mkA(name, ns2IP)) } } func mkA(name, ip string) *dns.A { return &dns.A{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 3600, }, A: net.ParseIP(ip), } } func mkAAAA(name, ip string) *dns.AAAA { return &dns.AAAA{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 3600, }, AAAA: net.ParseIP(ip), } } // ---------------------------------------------------------------- // Test helpers // ---------------------------------------------------------------- func newTestResolver(t *testing.T) *resolver.Resolver { t.Helper() log := slog.New(slog.NewTextHandler( os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}, )) return resolver.NewFromLoggerWithClient( log, buildMockClient(), ) } func testContext(t *testing.T) context.Context { t.Helper() ctx, cancel := context.WithTimeout( context.Background(), 5*time.Second, ) t.Cleanup(cancel) return ctx } // ---------------------------------------------------------------- // FindAuthoritativeNameservers tests // ---------------------------------------------------------------- func TestFindAuthoritativeNameservers_ValidDomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) assert.Contains(t, nameservers, ns1Name) assert.Contains(t, nameservers, ns2Name) } func TestFindAuthoritativeNameservers_Subdomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "basic.example.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) assert.Contains(t, nameservers, ns1Name) } func TestFindAuthoritativeNameservers_ReturnsSorted( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) assert.True( t, sort.StringsAreSorted(nameservers), "nameservers should be sorted, got: %v", nameservers, ) } func TestFindAuthoritativeNameservers_Deterministic( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) first, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) second, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) assert.Equal(t, first, second) } func TestFindAuthoritativeNameservers_TrailingDot( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns1, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) ns2, err := r.FindAuthoritativeNameservers( ctx, "example.com.", ) require.NoError(t, err) assert.Equal(t, ns1, ns2) } // ---------------------------------------------------------------- // QueryNameserver tests // ---------------------------------------------------------------- func findOneNS( t *testing.T, r *resolver.Resolver, ctx context.Context, //nolint:revive // test helper ) string { t.Helper() nameservers, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) require.NotEmpty(t, nameservers) return nameservers[0] } func TestQueryNameserver_BasicA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "basic.example.com") require.NoError(t, err) require.NotNil(t, resp) assert.Equal(t, resolver.StatusOK, resp.Status) assert.Equal(t, ns, resp.Nameserver) assert.Contains(t, resp.Records["A"], "192.0.2.1") } func TestQueryNameserver_MultipleA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "multi.example.com") require.NoError(t, err) aRecords := resp.Records["A"] sort.Strings(aRecords) assert.Equal(t, []string{"192.0.2.1", "192.0.2.2"}, aRecords) } func TestQueryNameserver_AAAA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "ipv6.example.com") require.NoError(t, err) assert.Contains(t, resp.Records["AAAA"], "2001:db8::1") } func TestQueryNameserver_DualStack(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "dual.example.com") require.NoError(t, err) assert.Contains(t, resp.Records["A"], "192.0.2.1") assert.Contains(t, resp.Records["AAAA"], "2001:db8::1") } func TestQueryNameserver_CNAME(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "cname.example.com") require.NoError(t, err) assert.Contains( t, resp.Records["CNAME"], testHostCNAMETgt, ) } func TestQueryNameserver_MX(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "mx.example.com") require.NoError(t, err) mxRecords := resp.Records["MX"] require.NotEmpty(t, mxRecords) hasMail := false for _, mx := range mxRecords { if strings.Contains(mx, "mail.example.com.") { hasMail = true break } } assert.True(t, hasMail) } func TestQueryNameserver_TXT(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "txt.example.com") require.NoError(t, err) hasSPF := false for _, txt := range resp.Records["TXT"] { if strings.Contains(txt, "v=spf1") { hasSPF = true break } } assert.True(t, hasSPF) } func TestQueryNameserver_NXDomain(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver( ctx, ns, "nxdomain.example.com", ) require.NoError(t, err) assert.Equal(t, resolver.StatusNXDomain, resp.Status) } func TestQueryNameserver_RecordsSorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "multi.example.com") require.NoError(t, err) for recordType, values := range resp.Records { assert.True( t, sort.StringsAreSorted(values), "%s records should be sorted", recordType, ) } } func TestQueryNameserver_ResponseIncludesNameserver( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver(ctx, ns, "basic.example.com") require.NoError(t, err) assert.Equal(t, ns, resp.Nameserver) } func TestQueryNameserver_EmptyRecordsOnNXDomain( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp, err := r.QueryNameserver( ctx, ns, "nxdomain.example.com", ) require.NoError(t, err) totalRecords := 0 for _, values := range resp.Records { totalRecords += len(values) } assert.Zero(t, totalRecords) } func TestQueryNameserver_TrailingDotHandling(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ns := findOneNS(t, r, ctx) resp1, err := r.QueryNameserver( ctx, ns, "basic.example.com", ) require.NoError(t, err) resp2, err := r.QueryNameserver( ctx, ns, "basic.example.com.", ) require.NoError(t, err) assert.Equal(t, resp1.Records["A"], resp2.Records["A"]) } // ---------------------------------------------------------------- // QueryAllNameservers tests // ---------------------------------------------------------------- func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "basic.example.com", ) require.NoError(t, err) require.NotEmpty(t, results) assert.GreaterOrEqual(t, len(results), 2) for ns, resp := range results { assert.Equal(t, ns, resp.Nameserver) } } func TestQueryAllNameservers_Consistent(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "basic.example.com", ) require.NoError(t, err) var referenceA []string for ns, resp := range results { assert.Equal( t, resolver.StatusOK, resp.Status, "NS %s should return OK", ns, ) if referenceA == nil { referenceA = resp.Records["A"] continue } assert.Equal(t, referenceA, resp.Records["A"]) } } func TestQueryAllNameservers_NXDomainFromAllNS( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) results, err := r.QueryAllNameservers( ctx, "nxdomain.example.com", ) require.NoError(t, err) for ns, resp := range results { assert.Equal( t, resolver.StatusNXDomain, resp.Status, "NS %s should return nxdomain", ns, ) } } // ---------------------------------------------------------------- // LookupNS tests // ---------------------------------------------------------------- func TestLookupNS_ValidDomain(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "example.com") require.NoError(t, err) require.NotEmpty(t, nameservers) for _, ns := range nameservers { assert.True(t, strings.HasSuffix(ns, ".")) } } func TestLookupNS_Sorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) nameservers, err := r.LookupNS(ctx, "example.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(nameservers)) } func TestLookupNS_MatchesFindAuthoritative(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) fromLookup, err := r.LookupNS(ctx, "example.com") require.NoError(t, err) fromFind, err := r.FindAuthoritativeNameservers( ctx, "example.com", ) require.NoError(t, err) assert.Equal(t, fromFind, fromLookup) } // ---------------------------------------------------------------- // ResolveIPAddresses tests // ---------------------------------------------------------------- func TestResolveIPAddresses_BasicA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "basic.example.com") require.NoError(t, err) assert.Contains(t, ips, "192.0.2.1") } func TestResolveIPAddresses_MultipleA(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "multi.example.com") require.NoError(t, err) assert.Contains(t, ips, "192.0.2.1") assert.Contains(t, ips, "192.0.2.2") } func TestResolveIPAddresses_IPv6Only(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "ipv6.example.com") require.NoError(t, err) require.NotEmpty(t, ips) assert.Contains(t, ips, "2001:db8::1") for _, ip := range ips { parsed := net.ParseIP(ip) require.NotNil(t, parsed) assert.Nil(t, parsed.To4(), "should not contain IPv4: %s", ip, ) } } func TestResolveIPAddresses_DualStack(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "dual.example.com") require.NoError(t, err) assert.Contains(t, ips, "192.0.2.1") assert.Contains(t, ips, "2001:db8::1") } func TestResolveIPAddresses_FollowsCNAME(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "cname.example.com") require.NoError(t, err) assert.Contains(t, ips, "198.51.100.1") } func TestResolveIPAddresses_Deduplicated(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "basic.example.com") require.NoError(t, err) seen := make(map[string]bool) for _, ip := range ips { assert.False(t, seen[ip], "duplicate IP: %s", ip) seen[ip] = true } } func TestResolveIPAddresses_Sorted(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses(ctx, "dual.example.com") require.NoError(t, err) assert.True(t, sort.StringsAreSorted(ips)) } func TestResolveIPAddresses_NXDomainReturnsEmpty( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx := testContext(t) ips, err := r.ResolveIPAddresses( ctx, "nxdomain.example.com", ) require.NoError(t, err) assert.Empty(t, ips) } // ---------------------------------------------------------------- // Context cancellation tests // ---------------------------------------------------------------- func TestFindAuthoritativeNameservers_ContextCanceled( t *testing.T, ) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.FindAuthoritativeNameservers(ctx, "example.com") assert.Error(t, err) } func TestQueryNameserver_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryNameserver( ctx, "ns1.example.com.", "basic.example.com", ) assert.Error(t, err) } func TestQueryAllNameservers_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.QueryAllNameservers(ctx, "basic.example.com") assert.Error(t, err) } func TestResolveIPAddresses_ContextCanceled(t *testing.T) { t.Parallel() r := newTestResolver(t) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := r.ResolveIPAddresses(ctx, "basic.example.com") assert.Error(t, err) }