1 Commits

Author SHA1 Message Date
f2970143d2 fix: retry on DNS timeout, distinguish authoritative negatives (closes #35)
Some checks failed
Check / check (pull_request) Failing after 6m4s
- Add StatusTimeout constant for timeout responses
- querySingleType now retries on timeout and SERVFAIL (3 attempts,
  exponential backoff starting at 100ms)
- NXDOMAIN and NOERROR+empty are treated as authoritative negatives
  with no retry
- classifyResponse sets structured error messages for timeout and
  SERVFAIL cases
- Refactored into smaller functions to satisfy cyclomatic complexity
  limits
2026-02-28 03:22:57 -08:00
3 changed files with 113 additions and 93 deletions

View File

@@ -4,6 +4,11 @@ import "errors"
// Sentinel errors returned by the resolver. // Sentinel errors returned by the resolver.
var ( var (
// ErrNotImplemented indicates a method is stubbed out.
ErrNotImplemented = errors.New(
"resolver not yet implemented",
)
// ErrNoNameservers is returned when no authoritative NS // ErrNoNameservers is returned when no authoritative NS
// could be discovered for a domain. // could be discovered for a domain.
ErrNoNameservers = errors.New( ErrNoNameservers = errors.New(
@@ -19,4 +24,8 @@ var (
// ErrContextCanceled wraps context cancellation for the // ErrContextCanceled wraps context cancellation for the
// resolver's iterative queries. // resolver's iterative queries.
ErrContextCanceled = errors.New("context canceled") ErrContextCanceled = errors.New("context canceled")
// ErrSERVFAIL is returned when a DNS server responds with
// SERVFAIL after all retries are exhausted.
ErrSERVFAIL = errors.New("SERVFAIL from server")
) )

View File

@@ -435,23 +435,6 @@ func (r *Resolver) QueryNameserver(
return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname) return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname)
} }
// QueryNameserverIP queries a nameserver by its IP address directly,
// bypassing NS hostname resolution.
func (r *Resolver) QueryNameserverIP(
ctx context.Context,
nsHostname string,
nsIP string,
hostname string,
) (*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
hostname = dns.Fqdn(hostname)
return r.queryAllTypes(ctx, nsHostname, nsIP, hostname)
}
func (r *Resolver) queryAllTypes( func (r *Resolver) queryAllTypes(
ctx context.Context, ctx context.Context,
nsHostname string, nsHostname string,
@@ -476,6 +459,11 @@ func (r *Resolver) queryAllTypes(
return resp, nil return resp, nil
} }
const (
singleTypeMaxRetries = 3
singleTypeInitialBackoff = 100 * time.Millisecond
)
type queryState struct { type queryState struct {
gotNXDomain bool gotNXDomain bool
gotSERVFAIL bool gotSERVFAIL bool
@@ -507,6 +495,21 @@ func (r *Resolver) queryEachType(
return state return state
} }
// isTimeout checks whether an error represents a DNS timeout.
func isTimeout(err error) bool {
if err == nil {
return false
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
// Also catch i/o timeout strings from the dns library.
return strings.Contains(err.Error(), "i/o timeout")
}
func (r *Resolver) querySingleType( func (r *Resolver) querySingleType(
ctx context.Context, ctx context.Context,
nsIP string, nsIP string,
@@ -515,27 +518,99 @@ func (r *Resolver) querySingleType(
resp *NameserverResponse, resp *NameserverResponse,
state *queryState, state *queryState,
) { ) {
msg, err := r.queryDNS(ctx, nsIP, hostname, qtype) msg, lastErr := r.querySingleTypeWithRetry(
if err != nil { ctx, nsIP, hostname, qtype,
if isTimeout(err) { )
state.gotTimeout = true if msg == nil {
r.recordRetryFailure(lastErr, state)
return
}
r.handleDNSResponse(msg, resp, state)
}
func (r *Resolver) querySingleTypeWithRetry(
ctx context.Context,
nsIP string,
hostname string,
qtype uint16,
) (*dns.Msg, error) {
var lastErr error
backoff := singleTypeInitialBackoff
for attempt := range singleTypeMaxRetries {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
} }
if attempt > 0 {
if !waitBackoff(ctx, backoff) {
return nil, ErrContextCanceled
}
backoff *= timeoutMultiplier
}
msg, err := r.queryDNS(ctx, nsIP, hostname, qtype)
if err != nil {
lastErr = err
if !isTimeout(err) {
return nil, err
}
continue
}
if msg.Rcode == dns.RcodeServerFailure {
lastErr = ErrSERVFAIL
continue
}
return msg, nil
}
return nil, lastErr
}
func waitBackoff(ctx context.Context, d time.Duration) bool {
select {
case <-ctx.Done():
return false
case <-time.After(d):
return true
}
}
func (r *Resolver) recordRetryFailure(
lastErr error,
state *queryState,
) {
if lastErr == nil {
return return
} }
if isTimeout(lastErr) {
state.gotTimeout = true
} else if errors.Is(lastErr, ErrSERVFAIL) {
state.gotSERVFAIL = true
}
}
func (r *Resolver) handleDNSResponse(
msg *dns.Msg,
resp *NameserverResponse,
state *queryState,
) {
if msg.Rcode == dns.RcodeNameError { if msg.Rcode == dns.RcodeNameError {
state.gotNXDomain = true state.gotNXDomain = true
return return
} }
if msg.Rcode == dns.RcodeServerFailure {
state.gotSERVFAIL = true
return
}
collectAnswerRecords(msg, resp, state) collectAnswerRecords(msg, resp, state)
} }
@@ -558,26 +633,16 @@ func collectAnswerRecords(
} }
} }
// isTimeout checks whether an error is a network timeout.
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func classifyResponse(resp *NameserverResponse, state queryState) { func classifyResponse(resp *NameserverResponse, state queryState) {
switch { switch {
case state.gotNXDomain && !state.hasRecords: case state.gotNXDomain && !state.hasRecords:
resp.Status = StatusNXDomain resp.Status = StatusNXDomain
case state.gotTimeout && !state.hasRecords: case state.gotTimeout && !state.hasRecords:
resp.Status = StatusTimeout resp.Status = StatusTimeout
resp.Error = "all queries timed out" resp.Error = "all queries timed out after retries"
case state.gotSERVFAIL && !state.hasRecords: case state.gotSERVFAIL && !state.hasRecords:
resp.Status = StatusError resp.Status = StatusError
resp.Error = "server returned SERVFAIL" resp.Error = "server failure (SERVFAIL) after retries"
case !state.hasRecords && !state.gotNXDomain: case !state.hasRecords && !state.gotNXDomain:
resp.Status = StatusNoData resp.Status = StatusNoData
} }

View File

@@ -10,7 +10,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -623,59 +622,6 @@ func TestQueryAllNameservers_ContextCanceled(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} }
// ----------------------------------------------------------------
// Timeout tests
// ----------------------------------------------------------------
func TestQueryNameserverIP_Timeout(t *testing.T) {
t.Parallel()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
r := resolver.NewFromLoggerWithClient(
log, &timeoutClient{},
)
ctx, cancel := context.WithTimeout(
context.Background(), 10*time.Second,
)
t.Cleanup(cancel)
// Query any IP — the client always returns a timeout error.
resp, err := r.QueryNameserverIP(
ctx, "unreachable.test.", "192.0.2.1",
"example.com",
)
require.NoError(t, err)
assert.Equal(t, resolver.StatusTimeout, resp.Status)
assert.NotEmpty(t, resp.Error)
}
// timeoutClient simulates DNS timeout errors for testing.
type timeoutClient struct{}
func (c *timeoutClient) ExchangeContext(
_ context.Context,
_ *dns.Msg,
_ string,
) (*dns.Msg, time.Duration, error) {
return nil, 0, &net.OpError{
Op: "read",
Net: "udp",
Err: &timeoutError{},
}
}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
func TestResolveIPAddresses_ContextCanceled(t *testing.T) { func TestResolveIPAddresses_ContextCanceled(t *testing.T) {
t.Parallel() t.Parallel()