diff --git a/internal/resolver/errors.go b/internal/resolver/errors.go index 94bc313..3f203d4 100644 --- a/internal/resolver/errors.go +++ b/internal/resolver/errors.go @@ -4,11 +4,6 @@ import "errors" // Sentinel errors returned by the resolver. var ( - // ErrNotImplemented indicates a method is stubbed out. - ErrNotImplemented = errors.New( - "resolver not yet implemented", - ) - // ErrNoNameservers is returned when no authoritative NS // could be discovered for a domain. ErrNoNameservers = errors.New( diff --git a/internal/resolver/iterative.go b/internal/resolver/iterative.go index 9b26190..eebab39 100644 --- a/internal/resolver/iterative.go +++ b/internal/resolver/iterative.go @@ -460,6 +460,23 @@ func (r *Resolver) QueryNameserver( return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname) } +// QueryNameserverIP queries a nameserver by its IP address directly, +// bypassing NS hostname resolution. +func (r *Resolver) QueryNameserverIP( + ctx context.Context, + nsHostname string, + nsIP string, + hostname string, +) (*NameserverResponse, error) { + if checkCtx(ctx) != nil { + return nil, ErrContextCanceled + } + + hostname = dns.Fqdn(hostname) + + return r.queryAllTypes(ctx, nsHostname, nsIP, hostname) +} + func (r *Resolver) queryAllTypes( ctx context.Context, nsHostname string, @@ -487,6 +504,7 @@ func (r *Resolver) queryAllTypes( type queryState struct { gotNXDomain bool gotSERVFAIL bool + gotTimeout bool hasRecords bool } @@ -524,6 +542,10 @@ func (r *Resolver) querySingleType( ) { msg, err := r.queryDNS(ctx, nsIP, hostname, qtype) if err != nil { + if isTimeout(err) { + state.gotTimeout = true + } + return } @@ -561,12 +583,26 @@ func collectAnswerRecords( } } +// isTimeout checks whether an error is a network timeout. +func isTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + + return false +} + func classifyResponse(resp *NameserverResponse, state queryState) { switch { case state.gotNXDomain && !state.hasRecords: resp.Status = StatusNXDomain + case state.gotTimeout && !state.hasRecords: + resp.Status = StatusTimeout + resp.Error = "all queries timed out" case state.gotSERVFAIL && !state.hasRecords: resp.Status = StatusError + resp.Error = "server returned SERVFAIL" case !state.hasRecords && !state.gotNXDomain: resp.Status = StatusNoData } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 889cdeb..aec9b89 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -17,6 +17,7 @@ const ( StatusError = "error" StatusNXDomain = "nxdomain" StatusNoData = "nodata" + StatusTimeout = "timeout" ) // MaxCNAMEDepth is the maximum CNAME chain depth to follow. diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index 3b9d936..bcebfb9 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -622,6 +623,59 @@ func TestQueryAllNameservers_ContextCanceled(t *testing.T) { assert.Error(t, err) } +// ---------------------------------------------------------------- +// Timeout tests +// ---------------------------------------------------------------- + +func TestQueryNameserverIP_Timeout(t *testing.T) { + t.Parallel() + + log := slog.New(slog.NewTextHandler( + os.Stderr, + &slog.HandlerOptions{Level: slog.LevelDebug}, + )) + + r := resolver.NewFromLoggerWithClient( + log, &timeoutClient{}, + ) + + ctx, cancel := context.WithTimeout( + context.Background(), 10*time.Second, + ) + t.Cleanup(cancel) + + // Query any IP — the client always returns a timeout error. + resp, err := r.QueryNameserverIP( + ctx, "unreachable.test.", "192.0.2.1", + "example.com", + ) + require.NoError(t, err) + + assert.Equal(t, resolver.StatusTimeout, resp.Status) + assert.NotEmpty(t, resp.Error) +} + +// timeoutClient simulates DNS timeout errors for testing. +type timeoutClient struct{} + +func (c *timeoutClient) ExchangeContext( + _ context.Context, + _ *dns.Msg, + _ string, +) (*dns.Msg, time.Duration, error) { + return nil, 0, &net.OpError{ + Op: "read", + Net: "udp", + Err: &timeoutError{}, + } +} + +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + func TestResolveIPAddresses_ContextCanceled(t *testing.T) { t.Parallel()