diff --git a/internal/resolver/iterative.go b/internal/resolver/iterative.go index 8f41b6d..68ce52f 100644 --- a/internal/resolver/iterative.go +++ b/internal/resolver/iterative.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand" "net" "sort" "strings" @@ -13,7 +14,7 @@ import ( ) const ( - queryTimeoutDuration = 5 * time.Second + queryTimeoutDuration = 2 * time.Second maxRetries = 2 maxDelegation = 20 timeoutMultiplier = 2 @@ -41,6 +42,22 @@ func rootServerList() []string { } } +const maxRootServers = 3 + +// randomRootServers returns a shuffled subset of root servers. +func randomRootServers() []string { + all := rootServerList() + rand.Shuffle(len(all), func(i, j int) { + all[i], all[j] = all[j], all[i] + }) + + if len(all) > maxRootServers { + return all[:maxRootServers] + } + + return all +} + func checkCtx(ctx context.Context) error { err := ctx.Err() if err != nil { @@ -302,7 +319,7 @@ func (r *Resolver) resolveNSRecursive( msg.SetQuestion(domain, dns.TypeNS) msg.RecursionDesired = true - for _, ip := range rootServerList()[:3] { + for _, ip := range randomRootServers() { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } @@ -333,7 +350,7 @@ func (r *Resolver) resolveARecord( msg.SetQuestion(hostname, dns.TypeA) msg.RecursionDesired = true - for _, ip := range rootServerList()[:3] { + for _, ip := range randomRootServers() { if checkCtx(ctx) != nil { return nil, ErrContextCanceled } @@ -385,7 +402,7 @@ func (r *Resolver) FindAuthoritativeNameservers( candidate := strings.Join(labels[i:], ".") + "." nsNames, err := r.followDelegation( - ctx, candidate, rootServerList(), + ctx, candidate, randomRootServers(), ) if err == nil && len(nsNames) > 0 { sort.Strings(nsNames)