feat: implement iterative DNS resolver (closes #1) #9

Merged
sneak merged 6 commits from feature/resolver into main 2026-02-20 19:37:59 +01:00
7 changed files with 741 additions and 446 deletions
Showing only changes of commit 0486dcfd07 - Show all commits

4
go.mod
View File

@ -38,11 +38,11 @@ require (
go.uber.org/zap v1.26.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.40.0 // indirect
golang.org/x/tools v0.41.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

20
go.sum
View File

@ -76,18 +76,18 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -0,0 +1,48 @@
package resolver
import (
"context"
"time"
"github.com/miekg/dns"
)
// DNSClient abstracts DNS wire-protocol exchanges so the resolver
// can be tested without hitting real nameservers.
type DNSClient interface {
ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error)
}
// udpClient wraps a real dns.Client for production use.
type udpClient struct {
timeout time.Duration
}
func (c *udpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}
// tcpClient wraps a real dns.Client using TCP.
type tcpClient struct {
timeout time.Duration
}
func (c *tcpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Net: "tcp", Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}

View File

@ -50,25 +50,20 @@ func checkCtx(ctx context.Context) error {
return nil
}
func exchangeWithTimeout(
func (r *Resolver) exchangeWithTimeout(
ctx context.Context,
msg *dns.Msg,
addr string,
attempt int,
) (*dns.Msg, error) {
c := new(dns.Client)
c.Timeout = queryTimeoutDuration
_ = attempt // timeout escalation handled by client config
if attempt > 0 {
c.Timeout = queryTimeoutDuration * timeoutMultiplier
}
resp, _, err := c.ExchangeContext(ctx, msg, addr)
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
return resp, err
}
func tryExchange(
func (r *Resolver) tryExchange(
ctx context.Context,
msg *dns.Msg,
addr string,
@ -82,7 +77,9 @@ func tryExchange(
return nil, ErrContextCanceled
}
resp, err = exchangeWithTimeout(ctx, msg, addr, attempt)
resp, err = r.exchangeWithTimeout(
ctx, msg, addr, attempt,
)
if err == nil {
break
}
@ -91,7 +88,7 @@ func tryExchange(
return resp, err
}
func retryTCP(
func (r *Resolver) retryTCP(
ctx context.Context,
msg *dns.Msg,
addr string,
@ -101,12 +98,7 @@ func retryTCP(
return resp
}
c := &dns.Client{
Net: "tcp",
Timeout: queryTimeoutDuration,
}
tcpResp, _, tcpErr := c.ExchangeContext(ctx, msg, addr)
tcpResp, _, tcpErr := r.tcp.ExchangeContext(ctx, msg, addr)
if tcpErr == nil {
return tcpResp
}
@ -117,7 +109,7 @@ func retryTCP(
// queryDNS sends a DNS query to a specific server IP.
// Tries non-recursive first, falls back to recursive on
// REFUSED (handles DNS interception environments).
func queryDNS(
func (r *Resolver) queryDNS(
ctx context.Context,
serverIP string,
name string,
@ -134,7 +126,7 @@ func queryDNS(
msg.SetQuestion(name, qtype)
msg.RecursionDesired = false
resp, err := tryExchange(ctx, msg, addr)
resp, err := r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err)
}
@ -142,7 +134,7 @@ func queryDNS(
if resp.Rcode == dns.RcodeRefused {
msg.RecursionDesired = true
resp, err = tryExchange(ctx, msg, addr)
resp, err = r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, err,
@ -156,7 +148,7 @@ func queryDNS(
}
}
resp = retryTCP(ctx, msg, addr, resp)
resp = r.retryTCP(ctx, msg, addr, resp)
return resp, nil
}
@ -221,7 +213,9 @@ func (r *Resolver) followDelegation(
return nil, ErrContextCanceled
}
resp, err := queryServers(ctx, servers, domain, dns.TypeNS)
resp, err := r.queryServers(
ctx, servers, domain, dns.TypeNS,
)
if err != nil {
return nil, err
}
@ -253,7 +247,7 @@ func (r *Resolver) followDelegation(
return nil, ErrNoNameservers
}
func queryServers(
func (r *Resolver) queryServers(
ctx context.Context,
servers []string,
name string,
@ -266,7 +260,7 @@ func queryServers(
return nil, ErrContextCanceled
}
resp, err := queryDNS(ctx, ip, name, qtype)
resp, err := r.queryDNS(ctx, ip, name, qtype)
if err == nil {
return resp, nil
}
@ -308,8 +302,6 @@ func (r *Resolver) resolveNSRecursive(
msg.SetQuestion(domain, dns.TypeNS)
msg.RecursionDesired = true
c := &dns.Client{Timeout: queryTimeoutDuration}
for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
@ -317,7 +309,7 @@ func (r *Resolver) resolveNSRecursive(
addr := net.JoinHostPort(ip, "53")
resp, _, err := c.ExchangeContext(ctx, msg, addr)
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil {
continue
}
@ -341,8 +333,6 @@ func (r *Resolver) resolveARecord(
msg.SetQuestion(hostname, dns.TypeA)
msg.RecursionDesired = true
c := &dns.Client{Timeout: queryTimeoutDuration}
for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
@ -350,7 +340,7 @@ func (r *Resolver) resolveARecord(
addr := net.JoinHostPort(ip, "53")
resp, _, err := c.ExchangeContext(ctx, msg, addr)
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil {
continue
}
@ -490,7 +480,7 @@ func (r *Resolver) querySingleType(
resp *NameserverResponse,
state *queryState,
) {
msg, err := queryDNS(ctx, nsIP, hostname, qtype)
msg, err := r.queryDNS(ctx, nsIP, hostname, qtype)
if err != nil {
return
}
@ -641,6 +631,25 @@ func (r *Resolver) LookupNS(
return r.FindAuthoritativeNameservers(ctx, domain)
}
// LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname, keyed by authoritative nameserver.
func (r *Resolver) LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error) {
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
out := make(map[string]map[string][]string, len(results))
for ns, resp := range results {
out[ns] = resp.Records
}
return out, nil
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains up to MaxCNAMEDepth.
func (r *Resolver) ResolveIPAddresses(

View File

@ -39,7 +39,9 @@ type NameserverResponse struct {
// Resolver performs iterative DNS resolution from root servers.
type Resolver struct {
log *slog.Logger
log *slog.Logger
client DNSClient
tcp DNSClient
}
// New creates a new Resolver instance for use with uber/fx.
@ -48,14 +50,33 @@ func New(
params Params,
) (*Resolver, error) {
return &Resolver{
log: params.Logger.Get(),
log: params.Logger.Get(),
client: &udpClient{timeout: queryTimeoutDuration},
tcp: &tcpClient{timeout: queryTimeoutDuration},
}, nil
}
// NewFromLogger creates a Resolver directly from an slog.Logger,
// useful for testing without the fx lifecycle.
func NewFromLogger(log *slog.Logger) *Resolver {
return &Resolver{log: log}
return &Resolver{
log: log,
client: &udpClient{timeout: queryTimeoutDuration},
tcp: &tcpClient{timeout: queryTimeoutDuration},
}
}
// NewFromLoggerWithClient creates a Resolver with a custom DNS
// client, useful for testing with mock DNS responses.
func NewFromLoggerWithClient(
log *slog.Logger,
client DNSClient,
) *Resolver {
return &Resolver{
log: log,
client: client,
tcp: client,
}
}
// Method implementations are in iterative.go.

View File

@ -0,0 +1,85 @@
//go:build integration
package resolver_test
import (
"context"
"log/slog"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/resolver"
)
// Integration tests hit real DNS servers. Run with:
// go test -tags integration -timeout 60s ./internal/resolver/
func newIntegrationResolver(t *testing.T) *resolver.Resolver {
t.Helper()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
return resolver.NewFromLogger(log)
}
func TestIntegration_FindAuthoritativeNameservers(
t *testing.T,
) {
t.Parallel()
r := newIntegrationResolver(t)
ctx, cancel := context.WithTimeout(
context.Background(), 30*time.Second,
)
defer cancel()
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "example.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
t.Logf("example.com NS: %v", nameservers)
}
func TestIntegration_ResolveIPAddresses(t *testing.T) {
t.Parallel()
r := newIntegrationResolver(t)
ctx, cancel := context.WithTimeout(
context.Background(), 30*time.Second,
)
defer cancel()
// sneak.cloud is on Cloudflare
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "sneak.cloud",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
hasCloudflare := false
for _, ns := range nameservers {
if strings.Contains(ns, "cloudflare") {
hasCloudflare = true
break
}
}
assert.True(t, hasCloudflare,
"sneak.cloud should be on Cloudflare, got: %v",
nameservers,
)
}

File diff suppressed because it is too large Load Diff