- Extract a DNSClient interface from the resolver to allow dependency injection - Convert all package-level resolver functions into methods on the resolver that use the injectable DNS client - Rewrite resolver_test.go with a mock DNS client that simulates the full delegation chain (root → TLD → authoritative) in-process - Move the 2 integration tests that hit real DNS behind the `//go:build integration` build tag - Add a NewFromLoggerWithClient constructor so tests can inject a client - Implement LookupAllRecords (it previously returned ErrNotImplemented). All unit tests are hermetic (no network access) and complete in under 1s; the full `make check` run passes in ~5s. Closes #12
1035 lines
21 KiB
Go
1035 lines
21 KiB
Go
package resolver_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"slices"
|
|
"sort"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"sneak.berlin/go/dnswatcher/internal/resolver"
|
|
)
|
|
|
|
// Sentinel errors produced by the mock DNS client when it cannot
// answer a query at all.
var (
	// errNoQuestion is returned for a message whose question section
	// is empty.
	errNoQuestion = errors.New("no question")

	// errUnexpectedServer is returned when a query is addressed to a
	// server that is not part of the simulated delegation chain.
	errUnexpectedServer = errors.New("unexpected server")
)
|
|
|
|
// mockDNSClient implements resolver.DNSClient for hermetic tests.
|
|
// It dispatches responses based on the query name, type, and
|
|
// server address to simulate a full delegation chain.
|
|
type mockDNSClient struct {
|
|
// handler returns a response for a given query. Return nil
|
|
// to simulate a network error.
|
|
handler func(
|
|
msg *dns.Msg, addr string,
|
|
) (*dns.Msg, error)
|
|
}
|
|
|
|
func (m *mockDNSClient) ExchangeContext(
|
|
_ context.Context,
|
|
msg *dns.Msg,
|
|
addr string,
|
|
) (*dns.Msg, time.Duration, error) {
|
|
resp, err := m.handler(msg, addr)
|
|
|
|
return resp, 0, err
|
|
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// Zone data used across tests
|
|
// ----------------------------------------------------------------
|
|
|
|
// Names and addresses that make up the simulated zone data.
const (
	// The zone under test.
	testDomain = "example.com."

	// Hostnames inside the zone, one per record scenario.
	testHostBasic     = "basic.example.com."
	testHostMultiA    = "multi.example.com."
	testHostIPv6      = "ipv6.example.com."
	testHostDualStack = "dual.example.com."
	testHostCNAME     = "cname.example.com."
	testHostCNAMETgt  = "cname-target.example.com."
	testHostMX        = "mx.example.com."
	testHostTXT       = "txt.example.com."
	testHostNXDomain  = "nxdomain.example.com."

	// Authoritative nameservers for testDomain and their addresses.
	ns1Name = "ns1.example.com."
	ns1IP   = "10.0.0.1"
	ns2Name = "ns2.example.com."
	ns2IP   = "10.0.0.2"

	// The single simulated .com TLD server and its address. Root
	// server addresses are listed separately (see isRootServer).
	comName = "a.gtld-servers.net."
	comNS   = "192.0.2.100"
)
|
|
|
|
// newReply creates a dns.Msg reply for the given question.
|
|
func newReply(q *dns.Msg, rcode int) *dns.Msg {
|
|
r := new(dns.Msg)
|
|
r.SetReply(q)
|
|
r.Rcode = rcode
|
|
|
|
return r
|
|
}
|
|
|
|
// buildMockClient returns a mockDNSClient that simulates:
|
|
//
|
|
// root -> .com TLD delegation -> example.com NS delegation
|
|
// -> authoritative answers for test hostnames.
|
|
func buildMockClient() *mockDNSClient {
|
|
return &mockDNSClient{
|
|
handler: func(
|
|
msg *dns.Msg, addr string,
|
|
) (*dns.Msg, error) {
|
|
if len(msg.Question) == 0 {
|
|
return nil, errNoQuestion
|
|
}
|
|
|
|
q := msg.Question[0]
|
|
name := strings.ToLower(q.Name)
|
|
host, _, _ := net.SplitHostPort(addr)
|
|
|
|
// --- Root servers: delegate .com ---
|
|
if isRootServer(host) {
|
|
return handleRoot(msg, name, q.Qtype)
|
|
}
|
|
|
|
// --- .com TLD: delegate example.com ---
|
|
if host == comNS {
|
|
return handleTLD(msg, name)
|
|
}
|
|
|
|
// --- Authoritative NS for example.com ---
|
|
if host == ns1IP || host == ns2IP {
|
|
return handleAuth(msg, name, q.Qtype)
|
|
}
|
|
|
|
return nil, errUnexpectedServer
|
|
},
|
|
}
|
|
}
|
|
|
|
// mockRootServerIPs holds the IPv4 addresses of the 13 root servers
// (a through m). Queries sent to any of them are answered by
// handleRoot.
// NOTE(review): presumably this must stay in sync with the resolver's
// built-in root hints — confirm.
var mockRootServerIPs = []string{
	"198.41.0.4",     // a
	"170.247.170.2",  // b
	"192.33.4.12",    // c
	"199.7.91.13",    // d
	"192.203.230.10", // e
	"192.5.5.241",    // f
	"192.112.36.4",   // g
	"198.97.190.53",  // h
	"192.36.148.17",  // i
	"192.58.128.30",  // j
	"193.0.14.129",   // k
	"199.7.83.42",    // l
	"202.12.27.33",   // m
}

// isRootServer reports whether ip is one of the simulated root servers.
func isRootServer(ip string) bool {
	return slices.Contains(mockRootServerIPs, ip)
}
|
|
|
|
func handleRoot(
|
|
msg *dns.Msg, name string, qtype uint16,
|
|
) (*dns.Msg, error) {
|
|
r := newReply(msg, dns.RcodeSuccess)
|
|
|
|
// If asking for NS of "com.", return answer
|
|
if name == "com." && qtype == dns.TypeNS {
|
|
appendComDelegation(r, true)
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// Recursive A queries for NS hostnames (used by resolveARecord)
|
|
if resp, ok := handleRootNSResolution(
|
|
r, msg, name, qtype,
|
|
); ok {
|
|
return resp, nil
|
|
}
|
|
|
|
// For anything under .com, delegate to TLD
|
|
if strings.HasSuffix(name, ".com.") || name == "com." {
|
|
appendComDelegation(r, false)
|
|
|
|
return r, nil
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func handleRootNSResolution(
|
|
r *dns.Msg, msg *dns.Msg, name string, qtype uint16,
|
|
) (*dns.Msg, bool) {
|
|
if qtype != dns.TypeA {
|
|
return nil, false
|
|
}
|
|
|
|
if msg.RecursionDesired || !strings.HasSuffix(name, ".com.") {
|
|
switch name {
|
|
case ns1Name:
|
|
r.Answer = append(r.Answer, mkA(name, ns1IP))
|
|
|
|
return r, true
|
|
case ns2Name:
|
|
r.Answer = append(r.Answer, mkA(name, ns2IP))
|
|
|
|
return r, true
|
|
}
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
func appendComDelegation(r *dns.Msg, inAnswer bool) {
|
|
ns := &dns.NS{
|
|
Hdr: dns.RR_Header{
|
|
Name: "com.",
|
|
Rrtype: dns.TypeNS,
|
|
Class: dns.ClassINET,
|
|
Ttl: 3600,
|
|
},
|
|
Ns: comName,
|
|
}
|
|
glue := &dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: comName,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: 3600,
|
|
},
|
|
A: net.ParseIP(comNS),
|
|
}
|
|
|
|
if inAnswer {
|
|
r.Answer = append(r.Answer, ns)
|
|
} else {
|
|
r.Ns = append(r.Ns, ns)
|
|
}
|
|
|
|
r.Extra = append(r.Extra, glue)
|
|
}
|
|
|
|
func handleTLD(msg *dns.Msg, name string) (*dns.Msg, error) {
|
|
r := newReply(msg, dns.RcodeSuccess)
|
|
|
|
if !strings.HasSuffix(name, "example.com.") &&
|
|
name != "example.com." {
|
|
r.Rcode = dns.RcodeNameError
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// Delegate to example.com NS
|
|
r.Ns = append(r.Ns,
|
|
&dns.NS{
|
|
Hdr: dns.RR_Header{
|
|
Name: "example.com.",
|
|
Rrtype: dns.TypeNS,
|
|
Class: dns.ClassINET,
|
|
Ttl: 3600,
|
|
},
|
|
Ns: ns1Name,
|
|
},
|
|
&dns.NS{
|
|
Hdr: dns.RR_Header{
|
|
Name: "example.com.",
|
|
Rrtype: dns.TypeNS,
|
|
Class: dns.ClassINET,
|
|
Ttl: 3600,
|
|
},
|
|
Ns: ns2Name,
|
|
},
|
|
)
|
|
r.Extra = append(r.Extra,
|
|
&dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: ns1Name, Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
A: net.ParseIP(ns1IP),
|
|
},
|
|
&dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: ns2Name, Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
A: net.ParseIP(ns2IP),
|
|
},
|
|
)
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func handleAuth(
|
|
msg *dns.Msg, name string, qtype uint16,
|
|
) (*dns.Msg, error) {
|
|
r := newReply(msg, dns.RcodeSuccess)
|
|
|
|
if name == "example.com." && qtype == dns.TypeNS {
|
|
appendExampleNS(r)
|
|
|
|
return r, nil
|
|
}
|
|
|
|
if name == testHostNXDomain {
|
|
r.Rcode = dns.RcodeNameError
|
|
|
|
return r, nil
|
|
}
|
|
|
|
addAuthRecords(r, name, qtype)
|
|
|
|
return r, nil
|
|
}
|
|
|
|
func appendExampleNS(r *dns.Msg) {
|
|
r.Answer = append(r.Answer,
|
|
&dns.NS{
|
|
Hdr: dns.RR_Header{
|
|
Name: "example.com.", Rrtype: dns.TypeNS,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
Ns: ns1Name,
|
|
},
|
|
&dns.NS{
|
|
Hdr: dns.RR_Header{
|
|
Name: "example.com.", Rrtype: dns.TypeNS,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
Ns: ns2Name,
|
|
},
|
|
)
|
|
}
|
|
|
|
//nolint:cyclop // dispatch table for test zone records
|
|
func addAuthRecords(
|
|
r *dns.Msg, name string, qtype uint16,
|
|
) {
|
|
switch {
|
|
case name == testHostBasic && qtype == dns.TypeA:
|
|
r.Answer = append(r.Answer, mkA(name, "192.0.2.1"))
|
|
|
|
case name == testHostMultiA && qtype == dns.TypeA:
|
|
r.Answer = append(r.Answer,
|
|
mkA(name, "192.0.2.1"), mkA(name, "192.0.2.2"),
|
|
)
|
|
|
|
case name == testHostIPv6 && qtype == dns.TypeAAAA:
|
|
r.Answer = append(r.Answer, mkAAAA(name, "2001:db8::1"))
|
|
|
|
case name == testHostDualStack && qtype == dns.TypeA:
|
|
r.Answer = append(r.Answer, mkA(name, "192.0.2.1"))
|
|
|
|
case name == testHostDualStack && qtype == dns.TypeAAAA:
|
|
r.Answer = append(r.Answer, mkAAAA(name, "2001:db8::1"))
|
|
|
|
case name == testHostCNAME && qtype == dns.TypeCNAME:
|
|
r.Answer = append(r.Answer, &dns.CNAME{
|
|
Hdr: dns.RR_Header{
|
|
Name: name, Rrtype: dns.TypeCNAME,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
Target: testHostCNAMETgt,
|
|
})
|
|
|
|
case name == testHostCNAMETgt && qtype == dns.TypeA:
|
|
r.Answer = append(r.Answer, mkA(name, "198.51.100.1"))
|
|
|
|
case name == testHostMX && qtype == dns.TypeMX:
|
|
r.Answer = append(r.Answer, &dns.MX{
|
|
Hdr: dns.RR_Header{
|
|
Name: name, Rrtype: dns.TypeMX,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
Preference: 10, Mx: "mail.example.com.",
|
|
})
|
|
|
|
case name == testHostTXT && qtype == dns.TypeTXT:
|
|
r.Answer = append(r.Answer, &dns.TXT{
|
|
Hdr: dns.RR_Header{
|
|
Name: name, Rrtype: dns.TypeTXT,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
Txt: []string{"v=spf1 -all"},
|
|
})
|
|
|
|
case name == ns1Name && qtype == dns.TypeA:
|
|
r.Answer = append(r.Answer, mkA(name, ns1IP))
|
|
|
|
case name == ns2Name && qtype == dns.TypeA:
|
|
r.Answer = append(r.Answer, mkA(name, ns2IP))
|
|
}
|
|
}
|
|
|
|
func mkA(name, ip string) *dns.A {
|
|
return &dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: name, Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
A: net.ParseIP(ip),
|
|
}
|
|
}
|
|
|
|
func mkAAAA(name, ip string) *dns.AAAA {
|
|
return &dns.AAAA{
|
|
Hdr: dns.RR_Header{
|
|
Name: name, Rrtype: dns.TypeAAAA,
|
|
Class: dns.ClassINET, Ttl: 3600,
|
|
},
|
|
AAAA: net.ParseIP(ip),
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// Test helpers
|
|
// ----------------------------------------------------------------
|
|
|
|
func newTestResolver(t *testing.T) *resolver.Resolver {
|
|
t.Helper()
|
|
|
|
log := slog.New(slog.NewTextHandler(
|
|
os.Stderr,
|
|
&slog.HandlerOptions{Level: slog.LevelDebug},
|
|
))
|
|
|
|
return resolver.NewFromLoggerWithClient(
|
|
log, buildMockClient(),
|
|
)
|
|
}
|
|
|
|
// testContext returns a context with a five-second deadline. The
// context is cancelled automatically when the test finishes.
func testContext(t *testing.T) context.Context {
	t.Helper()

	const timeout = 5 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	t.Cleanup(cancel)

	return ctx
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// FindAuthoritativeNameservers tests
|
|
// ----------------------------------------------------------------
|
|
|
|
func TestFindAuthoritativeNameservers_ValidDomain(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
nameservers, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, nameservers)
|
|
|
|
assert.Contains(t, nameservers, ns1Name)
|
|
assert.Contains(t, nameservers, ns2Name)
|
|
}
|
|
|
|
func TestFindAuthoritativeNameservers_Subdomain(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
nameservers, err := r.FindAuthoritativeNameservers(
|
|
ctx, "basic.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, nameservers)
|
|
|
|
assert.Contains(t, nameservers, ns1Name)
|
|
}
|
|
|
|
func TestFindAuthoritativeNameservers_ReturnsSorted(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
nameservers, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
assert.True(
|
|
t,
|
|
sort.StringsAreSorted(nameservers),
|
|
"nameservers should be sorted, got: %v", nameservers,
|
|
)
|
|
}
|
|
|
|
func TestFindAuthoritativeNameservers_Deterministic(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
first, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
second, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, first, second)
|
|
}
|
|
|
|
func TestFindAuthoritativeNameservers_TrailingDot(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ns1, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
ns2, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com.",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, ns1, ns2)
|
|
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// QueryNameserver tests
|
|
// ----------------------------------------------------------------
|
|
|
|
func findOneNS(
|
|
t *testing.T,
|
|
r *resolver.Resolver,
|
|
ctx context.Context, //nolint:revive // test helper
|
|
) string {
|
|
t.Helper()
|
|
|
|
nameservers, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, nameservers)
|
|
|
|
return nameservers[0]
|
|
}
|
|
|
|
func TestQueryNameserver_BasicA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "basic.example.com")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
|
|
assert.Equal(t, resolver.StatusOK, resp.Status)
|
|
assert.Equal(t, ns, resp.Nameserver)
|
|
assert.Contains(t, resp.Records["A"], "192.0.2.1")
|
|
}
|
|
|
|
func TestQueryNameserver_MultipleA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "multi.example.com")
|
|
require.NoError(t, err)
|
|
|
|
aRecords := resp.Records["A"]
|
|
sort.Strings(aRecords)
|
|
assert.Equal(t, []string{"192.0.2.1", "192.0.2.2"}, aRecords)
|
|
}
|
|
|
|
func TestQueryNameserver_AAAA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "ipv6.example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.Contains(t, resp.Records["AAAA"], "2001:db8::1")
|
|
}
|
|
|
|
func TestQueryNameserver_DualStack(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "dual.example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.Contains(t, resp.Records["A"], "192.0.2.1")
|
|
assert.Contains(t, resp.Records["AAAA"], "2001:db8::1")
|
|
}
|
|
|
|
func TestQueryNameserver_CNAME(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "cname.example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.Contains(
|
|
t, resp.Records["CNAME"], testHostCNAMETgt,
|
|
)
|
|
}
|
|
|
|
func TestQueryNameserver_MX(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "mx.example.com")
|
|
require.NoError(t, err)
|
|
|
|
mxRecords := resp.Records["MX"]
|
|
require.NotEmpty(t, mxRecords)
|
|
|
|
hasMail := false
|
|
|
|
for _, mx := range mxRecords {
|
|
if strings.Contains(mx, "mail.example.com.") {
|
|
hasMail = true
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
assert.True(t, hasMail)
|
|
}
|
|
|
|
func TestQueryNameserver_TXT(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "txt.example.com")
|
|
require.NoError(t, err)
|
|
|
|
hasSPF := false
|
|
|
|
for _, txt := range resp.Records["TXT"] {
|
|
if strings.Contains(txt, "v=spf1") {
|
|
hasSPF = true
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
assert.True(t, hasSPF)
|
|
}
|
|
|
|
func TestQueryNameserver_NXDomain(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(
|
|
ctx, ns, "nxdomain.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, resolver.StatusNXDomain, resp.Status)
|
|
}
|
|
|
|
func TestQueryNameserver_RecordsSorted(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "multi.example.com")
|
|
require.NoError(t, err)
|
|
|
|
for recordType, values := range resp.Records {
|
|
assert.True(
|
|
t,
|
|
sort.StringsAreSorted(values),
|
|
"%s records should be sorted", recordType,
|
|
)
|
|
}
|
|
}
|
|
|
|
func TestQueryNameserver_ResponseIncludesNameserver(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(ctx, ns, "basic.example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, ns, resp.Nameserver)
|
|
}
|
|
|
|
func TestQueryNameserver_EmptyRecordsOnNXDomain(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp, err := r.QueryNameserver(
|
|
ctx, ns, "nxdomain.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
totalRecords := 0
|
|
for _, values := range resp.Records {
|
|
totalRecords += len(values)
|
|
}
|
|
|
|
assert.Zero(t, totalRecords)
|
|
}
|
|
|
|
func TestQueryNameserver_TrailingDotHandling(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
ns := findOneNS(t, r, ctx)
|
|
|
|
resp1, err := r.QueryNameserver(
|
|
ctx, ns, "basic.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
resp2, err := r.QueryNameserver(
|
|
ctx, ns, "basic.example.com.",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, resp1.Records["A"], resp2.Records["A"])
|
|
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// QueryAllNameservers tests
|
|
// ----------------------------------------------------------------
|
|
|
|
func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
results, err := r.QueryAllNameservers(
|
|
ctx, "basic.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, results)
|
|
|
|
assert.GreaterOrEqual(t, len(results), 2)
|
|
|
|
for ns, resp := range results {
|
|
assert.Equal(t, ns, resp.Nameserver)
|
|
}
|
|
}
|
|
|
|
func TestQueryAllNameservers_Consistent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
results, err := r.QueryAllNameservers(
|
|
ctx, "basic.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
var referenceA []string
|
|
|
|
for ns, resp := range results {
|
|
assert.Equal(
|
|
t, resolver.StatusOK, resp.Status,
|
|
"NS %s should return OK", ns,
|
|
)
|
|
|
|
if referenceA == nil {
|
|
referenceA = resp.Records["A"]
|
|
|
|
continue
|
|
}
|
|
|
|
assert.Equal(t, referenceA, resp.Records["A"])
|
|
}
|
|
}
|
|
|
|
func TestQueryAllNameservers_NXDomainFromAllNS(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
results, err := r.QueryAllNameservers(
|
|
ctx, "nxdomain.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
for ns, resp := range results {
|
|
assert.Equal(
|
|
t, resolver.StatusNXDomain, resp.Status,
|
|
"NS %s should return nxdomain", ns,
|
|
)
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// LookupNS tests
|
|
// ----------------------------------------------------------------
|
|
|
|
func TestLookupNS_ValidDomain(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
nameservers, err := r.LookupNS(ctx, "example.com")
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, nameservers)
|
|
|
|
for _, ns := range nameservers {
|
|
assert.True(t, strings.HasSuffix(ns, "."))
|
|
}
|
|
}
|
|
|
|
func TestLookupNS_Sorted(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
nameservers, err := r.LookupNS(ctx, "example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.True(t, sort.StringsAreSorted(nameservers))
|
|
}
|
|
|
|
func TestLookupNS_MatchesFindAuthoritative(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
fromLookup, err := r.LookupNS(ctx, "example.com")
|
|
require.NoError(t, err)
|
|
|
|
fromFind, err := r.FindAuthoritativeNameservers(
|
|
ctx, "example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, fromFind, fromLookup)
|
|
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// ResolveIPAddresses tests
|
|
// ----------------------------------------------------------------
|
|
|
|
func TestResolveIPAddresses_BasicA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(ctx, "basic.example.com")
|
|
require.NoError(t, err)
|
|
assert.Contains(t, ips, "192.0.2.1")
|
|
}
|
|
|
|
func TestResolveIPAddresses_MultipleA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(ctx, "multi.example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.Contains(t, ips, "192.0.2.1")
|
|
assert.Contains(t, ips, "192.0.2.2")
|
|
}
|
|
|
|
func TestResolveIPAddresses_IPv6Only(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(ctx, "ipv6.example.com")
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, ips)
|
|
assert.Contains(t, ips, "2001:db8::1")
|
|
|
|
for _, ip := range ips {
|
|
parsed := net.ParseIP(ip)
|
|
require.NotNil(t, parsed)
|
|
assert.Nil(t, parsed.To4(),
|
|
"should not contain IPv4: %s", ip,
|
|
)
|
|
}
|
|
}
|
|
|
|
func TestResolveIPAddresses_DualStack(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(ctx, "dual.example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.Contains(t, ips, "192.0.2.1")
|
|
assert.Contains(t, ips, "2001:db8::1")
|
|
}
|
|
|
|
func TestResolveIPAddresses_FollowsCNAME(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(ctx, "cname.example.com")
|
|
require.NoError(t, err)
|
|
assert.Contains(t, ips, "198.51.100.1")
|
|
}
|
|
|
|
func TestResolveIPAddresses_Deduplicated(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(ctx, "basic.example.com")
|
|
require.NoError(t, err)
|
|
|
|
seen := make(map[string]bool)
|
|
|
|
for _, ip := range ips {
|
|
assert.False(t, seen[ip], "duplicate IP: %s", ip)
|
|
seen[ip] = true
|
|
}
|
|
}
|
|
|
|
func TestResolveIPAddresses_Sorted(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(ctx, "dual.example.com")
|
|
require.NoError(t, err)
|
|
|
|
assert.True(t, sort.StringsAreSorted(ips))
|
|
}
|
|
|
|
func TestResolveIPAddresses_NXDomainReturnsEmpty(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx := testContext(t)
|
|
|
|
ips, err := r.ResolveIPAddresses(
|
|
ctx, "nxdomain.example.com",
|
|
)
|
|
require.NoError(t, err)
|
|
assert.Empty(t, ips)
|
|
}
|
|
|
|
// ----------------------------------------------------------------
|
|
// Context cancellation tests
|
|
// ----------------------------------------------------------------
|
|
|
|
func TestFindAuthoritativeNameservers_ContextCanceled(
|
|
t *testing.T,
|
|
) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
_, err := r.FindAuthoritativeNameservers(ctx, "example.com")
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestQueryNameserver_ContextCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
_, err := r.QueryNameserver(
|
|
ctx, "ns1.example.com.", "basic.example.com",
|
|
)
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestQueryAllNameservers_ContextCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
_, err := r.QueryAllNameservers(ctx, "basic.example.com")
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestResolveIPAddresses_ContextCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := newTestResolver(t)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
_, err := r.ResolveIPAddresses(ctx, "basic.example.com")
|
|
assert.Error(t, err)
|
|
}
|