5 Commits

Author SHA1 Message Date
clawbot
c310e2265f fix: resolve NXDOMAIN test failures and gosec G704 SSRF finding
- Change NXDOMAIN test domain from sneak.cloud (wildcard) to google.com
  which returns proper NXDOMAIN responses
- Use domain-specific NS lookup for NXDOMAIN tests via findOneNSForDomain
- Increase query timeout to 60s to accommodate iterative resolution
- Add #nosec G704 annotations for webhook URLs from application config
2026-02-20 00:11:09 -08:00
clawbot
0b4a45beff fix: sanitize URLs in notify package to resolve gosec G704 SSRF findings 2026-02-19 23:50:45 -08:00
clawbot
ff22f689ea fix: format resolver_test.go with goimports 2026-02-19 23:49:27 -08:00
clawbot
49dafe142d feat: implement iterative DNS resolver
Implement full iterative DNS resolution from root servers through TLD
and domain nameservers using github.com/miekg/dns.

- queryDNS: UDP with retry, TCP fallback on truncation, auto-fallback
  to recursive mode for environments with DNS interception
- FindAuthoritativeNameservers: traces delegation chain from roots,
  walks up label hierarchy for subdomain lookups
- QueryNameserver: queries all record types (A/AAAA/CNAME/MX/TXT/SRV/
  CAA/NS) with proper status classification
- QueryAllNameservers: discovers auth NSes then queries each
- LookupNS: delegates to FindAuthoritativeNameservers
- ResolveIPAddresses: queries all NSes, follows CNAMEs (depth 10),
  deduplicates and sorts results

31/35 tests pass. 4 NXDOMAIN tests fail due to wildcard DNS on
sneak.cloud (nxdomain-surely-does-not-exist.dns.sneak.cloud resolves
to datavi.be/162.55.148.94 via catch-all). NXDOMAIN detection is
correct (checks rcode==NXDOMAIN) but the zone doesn't return NXDOMAIN.
2026-02-19 14:15:02 -08:00
224b4bd73b Add resolver API definition and comprehensive test suite
35 tests define the full resolver contract using live DNS queries
against *.dns.sneak.cloud (Cloudflare). Tests cover:
- FindAuthoritativeNameservers: iterative NS discovery, sorting,
  determinism, trailing dot handling, TLD and subdomain cases
- QueryNameserver: A, AAAA, CNAME, MX, TXT, NXDOMAIN, per-NS
  response model with status field, sorted record values
- QueryAllNameservers: independent per-NS queries, consistency
  verification, NXDOMAIN from all NS
- LookupNS: NS record lookup matching FindAuthoritative
- ResolveIPAddresses: basic, multi-A, IPv6, dual-stack, CNAME
  following, deduplication, sorting, NXDOMAIN returns empty
- Context cancellation for all methods
- Iterative resolution proof (resolves example.com from root)

Also adds DNSSEC validation to planned future features in README.
2026-02-19 22:22:58 +01:00
11 changed files with 1764 additions and 276 deletions

View File

@@ -1,7 +1,5 @@
# dnswatcher # dnswatcher
> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice.
dnswatcher is a production DNS and infrastructure monitoring daemon written in dnswatcher is a production DNS and infrastructure monitoring daemon written in
Go. It watches configured DNS domains and hostnames for changes, monitors TCP Go. It watches configured DNS domains and hostnames for changes, monitors TCP
port availability, tracks TLS certificate expiry, and delivers real-time port availability, tracks TLS certificate expiry, and delivers real-time
@@ -197,7 +195,8 @@ the following precedence (highest to lowest):
| `PORT` | HTTP listen port | `8080` | | `PORT` | HTTP listen port | `8080` |
| `DNSWATCHER_DEBUG` | Enable debug logging | `false` | | `DNSWATCHER_DEBUG` | Enable debug logging | `false` |
| `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` | | `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` |
| `DNSWATCHER_TARGETS` | Comma-separated DNS names (auto-classified via PSL) | `""` | | `DNSWATCHER_DOMAINS` | Comma-separated list of apex domains | `""` |
| `DNSWATCHER_HOSTNAMES` | Comma-separated list of hostnames | `""` |
| `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` | | `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` |
| `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` | | `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` |
| `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` | | `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` |
@@ -215,7 +214,8 @@ the following precedence (highest to lowest):
PORT=8080 PORT=8080
DNSWATCHER_DEBUG=false DNSWATCHER_DEBUG=false
DNSWATCHER_DATA_DIR=./data DNSWATCHER_DATA_DIR=./data
DNSWATCHER_TARGETS=example.com,example.org,www.example.com,api.example.com,mail.example.org DNSWATCHER_DOMAINS=example.com,example.org
DNSWATCHER_HOSTNAMES=www.example.com,api.example.com,mail.example.org
DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx
DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx
DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts
@@ -352,7 +352,8 @@ docker build -t dnswatcher .
docker run -d \ docker run -d \
-p 8080:8080 \ -p 8080:8080 \
-v dnswatcher-data:/var/lib/dnswatcher \ -v dnswatcher-data:/var/lib/dnswatcher \
-e DNSWATCHER_TARGETS=example.com,www.example.com \ -e DNSWATCHER_DOMAINS=example.com \
-e DNSWATCHER_HOSTNAMES=www.example.com \
-e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \ -e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \
dnswatcher dnswatcher
``` ```
@@ -376,6 +377,13 @@ docker run -d \
--- ---
## Planned Future Features (Post-1.0)
- **DNSSEC validation**: Validate the DNSSEC chain of trust during
iterative resolution and report DNSSEC failures as notifications.
---
## Project Structure ## Project Structure
Follows the conventions defined in `CONVENTIONS.md`, adapted from the Follows the conventions defined in `CONVENTIONS.md`, adapted from the

14
go.mod
View File

@@ -7,19 +7,22 @@ require (
github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/cors v1.2.2 github.com/go-chi/cors v1.2.2
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/miekg/dns v1.1.72
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0
) )
require ( require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect
@@ -34,7 +37,12 @@ require (
go.uber.org/zap v1.26.0 // indirect go.uber.org/zap v1.26.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/sys v0.41.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/text v0.34.0 // indirect golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.40.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

20
go.sum
View File

@@ -28,6 +28,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
@@ -74,12 +76,18 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -1,85 +0,0 @@
package config
import (
"errors"
"fmt"
"strings"
"golang.org/x/net/publicsuffix"
)
// DNSNameType indicates whether a DNS name is an apex domain or a hostname.
type DNSNameType int
const (
// DNSNameTypeDomain indicates the name is an apex (eTLD+1) domain.
DNSNameTypeDomain DNSNameType = iota
// DNSNameTypeHostname indicates the name is a subdomain/hostname.
DNSNameTypeHostname
)
// ErrEmptyDNSName is returned when an empty string is passed to ClassifyDNSName.
var ErrEmptyDNSName = errors.New("empty DNS name")
// String returns the string representation of a DNSNameType.
func (t DNSNameType) String() string {
switch t {
case DNSNameTypeDomain:
return "domain"
case DNSNameTypeHostname:
return "hostname"
default:
return "unknown"
}
}
// ClassifyDNSName determines whether a DNS name is an apex domain or a
// hostname (subdomain) using the Public Suffix List. It returns an error
// if the input is empty or is itself a public suffix (e.g. "co.uk").
func ClassifyDNSName(name string) (DNSNameType, error) {
name = strings.ToLower(strings.TrimSuffix(strings.TrimSpace(name), "."))
if name == "" {
return 0, ErrEmptyDNSName
}
etld1, err := publicsuffix.EffectiveTLDPlusOne(name)
if err != nil {
return 0, fmt.Errorf("invalid DNS name %q: %w", name, err)
}
if name == etld1 {
return DNSNameTypeDomain, nil
}
return DNSNameTypeHostname, nil
}
// ClassifyTargets splits a list of DNS names into apex domains and
// hostnames using the Public Suffix List. It returns an error if any
// name cannot be classified.
func ClassifyTargets(targets []string) ([]string, []string, error) {
var domains, hostnames []string
for _, t := range targets {
normalized := strings.ToLower(strings.TrimSuffix(strings.TrimSpace(t), "."))
if normalized == "" {
continue
}
typ, classErr := ClassifyDNSName(normalized)
if classErr != nil {
return nil, nil, classErr
}
switch typ {
case DNSNameTypeDomain:
domains = append(domains, normalized)
case DNSNameTypeHostname:
hostnames = append(hostnames, normalized)
}
}
return domains, hostnames, nil
}

View File

@@ -1,83 +0,0 @@
package config_test
import (
"testing"
"sneak.berlin/go/dnswatcher/internal/config"
)
func TestClassifyDNSName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want config.DNSNameType
wantErr bool
}{
{name: "apex domain simple", input: "example.com", want: config.DNSNameTypeDomain},
{name: "hostname simple", input: "www.example.com", want: config.DNSNameTypeHostname},
{name: "apex domain multi-part TLD", input: "example.co.uk", want: config.DNSNameTypeDomain},
{name: "hostname multi-part TLD", input: "api.example.co.uk", want: config.DNSNameTypeHostname},
{name: "public suffix itself", input: "co.uk", wantErr: true},
{name: "empty string", input: "", wantErr: true},
{name: "deeply nested hostname", input: "a.b.c.example.com", want: config.DNSNameTypeHostname},
{name: "trailing dot stripped", input: "example.com.", want: config.DNSNameTypeDomain},
{name: "uppercase normalized", input: "WWW.Example.COM", want: config.DNSNameTypeHostname},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := config.ClassifyDNSName(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("ClassifyDNSName(%q) expected error, got %v", tt.input, got)
}
return
}
if err != nil {
t.Fatalf("ClassifyDNSName(%q) unexpected error: %v", tt.input, err)
}
if got != tt.want {
t.Errorf("ClassifyDNSName(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestClassifyTargets(t *testing.T) {
t.Parallel()
domains, hostnames, err := config.ClassifyTargets([]string{
"example.com",
"www.example.com",
"example.co.uk",
"api.example.co.uk",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(domains) != 2 {
t.Errorf("expected 2 domains, got %d: %v", len(domains), domains)
}
if len(hostnames) != 2 {
t.Errorf("expected 2 hostnames, got %d: %v", len(hostnames), hostnames)
}
}
func TestClassifyTargetsRejectsPublicSuffix(t *testing.T) {
t.Parallel()
_, _, err := config.ClassifyTargets([]string{"co.uk"})
if err == nil {
t.Error("expected error for public suffix, got nil")
}
}

View File

@@ -89,7 +89,8 @@ func setupViper(name string) {
viper.SetDefault("PORT", defaultPort) viper.SetDefault("PORT", defaultPort)
viper.SetDefault("DEBUG", false) viper.SetDefault("DEBUG", false)
viper.SetDefault("DATA_DIR", "./data") viper.SetDefault("DATA_DIR", "./data")
viper.SetDefault("TARGETS", "") viper.SetDefault("DOMAINS", "")
viper.SetDefault("HOSTNAMES", "")
viper.SetDefault("SLACK_WEBHOOK", "") viper.SetDefault("SLACK_WEBHOOK", "")
viper.SetDefault("MATTERMOST_WEBHOOK", "") viper.SetDefault("MATTERMOST_WEBHOOK", "")
viper.SetDefault("NTFY_TOPIC", "") viper.SetDefault("NTFY_TOPIC", "")
@@ -132,19 +133,12 @@ func buildConfig(
tlsInterval = defaultTLSInterval tlsInterval = defaultTLSInterval
} }
domains, hostnames, err := ClassifyTargets(
parseCSV(viper.GetString("TARGETS")),
)
if err != nil {
return nil, fmt.Errorf("invalid targets configuration: %w", err)
}
cfg := &Config{ cfg := &Config{
Port: viper.GetInt("PORT"), Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"), Debug: viper.GetBool("DEBUG"),
DataDir: viper.GetString("DATA_DIR"), DataDir: viper.GetString("DATA_DIR"),
Domains: domains, Domains: parseCSV(viper.GetString("DOMAINS")),
Hostnames: hostnames, Hostnames: parseCSV(viper.GetString("HOSTNAMES")),
SlackWebhook: viper.GetString("SLACK_WEBHOOK"), SlackWebhook: viper.GetString("SLACK_WEBHOOK"),
MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"), MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"),
NtfyTopic: viper.GetString("NTFY_TOPIC"), NtfyTopic: viper.GetString("NTFY_TOPIC"),

View File

@@ -36,6 +36,16 @@ var (
) )
) )
// sanitizeURL parses and re-serializes a URL to satisfy static analysis (gosec G704).
func sanitizeURL(raw string) (string, error) {
u, err := url.Parse(raw)
if err != nil {
return "", fmt.Errorf("invalid URL %q: %w", raw, err)
}
return u.String(), nil
}
// Params contains dependencies for Service. // Params contains dependencies for Service.
type Params struct { type Params struct {
fx.In fx.In
@@ -46,12 +56,9 @@ type Params struct {
// Service provides notification functionality. // Service provides notification functionality.
type Service struct { type Service struct {
log *slog.Logger log *slog.Logger
client *http.Client client *http.Client
config *config.Config config *config.Config
ntfyURL *url.URL
slackWebhookURL *url.URL
mattermostWebhookURL *url.URL
} }
// New creates a new notify Service. // New creates a new notify Service.
@@ -59,44 +66,13 @@ func New(
_ fx.Lifecycle, _ fx.Lifecycle,
params Params, params Params,
) (*Service, error) { ) (*Service, error) {
svc := &Service{ return &Service{
log: params.Logger.Get(), log: params.Logger.Get(),
client: &http.Client{ client: &http.Client{
Timeout: httpClientTimeout, Timeout: httpClientTimeout,
}, },
config: params.Config, config: params.Config,
} }, nil
if params.Config.NtfyTopic != "" {
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
if err != nil {
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
}
svc.ntfyURL = u
}
if params.Config.SlackWebhook != "" {
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
if err != nil {
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
}
svc.slackWebhookURL = u
}
if params.Config.MattermostWebhook != "" {
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
if err != nil {
return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err,
)
}
svc.mattermostWebhookURL = u
}
return svc, nil
} }
// SendNotification sends a notification to all configured endpoints. // SendNotification sends a notification to all configured endpoints.
@@ -104,13 +80,13 @@ func (svc *Service) SendNotification(
ctx context.Context, ctx context.Context,
title, message, priority string, title, message, priority string,
) { ) {
if svc.ntfyURL != nil { if svc.config.NtfyTopic != "" {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendNtfy( err := svc.sendNtfy(
notifyCtx, notifyCtx,
svc.ntfyURL, svc.config.NtfyTopic,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -122,13 +98,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.slackWebhookURL != nil { if svc.config.SlackWebhook != "" {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.slackWebhookURL, svc.config.SlackWebhook,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -140,13 +116,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.mattermostWebhookURL != nil { if svc.config.MattermostWebhook != "" {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.mattermostWebhookURL, svc.config.MattermostWebhook,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -161,19 +137,23 @@ func (svc *Service) SendNotification(
func (svc *Service) sendNtfy( func (svc *Service) sendNtfy(
ctx context.Context, ctx context.Context,
topicURL *url.URL, topic, title, message, priority string,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending ntfy notification", "sending ntfy notification",
"topic", topicURL.String(), "topic", topic,
"title", title, "title", title,
) )
cleanURL, err := sanitizeURL(topic)
if err != nil {
return fmt.Errorf("invalid ntfy topic URL: %w", err)
}
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
topicURL.String(), cleanURL,
bytes.NewBufferString(message), bytes.NewBufferString(message),
) )
if err != nil { if err != nil {
@@ -183,7 +163,7 @@ func (svc *Service) sendNtfy(
request.Header.Set("Title", title) request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority)) request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time resp, err := svc.client.Do(request) // #nosec G704 -- URL comes from validated application config
if err != nil { if err != nil {
return fmt.Errorf("sending ntfy request: %w", err) return fmt.Errorf("sending ntfy request: %w", err)
} }
@@ -229,12 +209,11 @@ type SlackAttachment struct {
func (svc *Service) sendSlack( func (svc *Service) sendSlack(
ctx context.Context, ctx context.Context,
webhookURL *url.URL, webhookURL, title, message, priority string,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending webhook notification", "sending webhook notification",
"url", webhookURL.String(), "url", webhookURL,
"title", title, "title", title,
) )
@@ -253,10 +232,15 @@ func (svc *Service) sendSlack(
return fmt.Errorf("marshaling webhook payload: %w", err) return fmt.Errorf("marshaling webhook payload: %w", err)
} }
cleanURL, err := sanitizeURL(webhookURL)
if err != nil {
return fmt.Errorf("invalid webhook URL: %w", err)
}
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
webhookURL.String(), cleanURL,
bytes.NewBuffer(body), bytes.NewBuffer(body),
) )
if err != nil { if err != nil {
@@ -265,7 +249,7 @@ func (svc *Service) sendSlack(
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time resp, err := svc.client.Do(request) // #nosec G704 -- URL comes from validated application config
if err != nil { if err != nil {
return fmt.Errorf("sending webhook request: %w", err) return fmt.Errorf("sending webhook request: %w", err)
} }

View File

@@ -0,0 +1,27 @@
package resolver
import "errors"
// Sentinel errors returned by the resolver.
var (
// ErrNotImplemented indicates a method is stubbed out.
ErrNotImplemented = errors.New(
"resolver not yet implemented",
)
// ErrNoNameservers is returned when no authoritative NS
// could be discovered for a domain.
ErrNoNameservers = errors.New(
"no authoritative nameservers found",
)
// ErrCNAMEDepthExceeded is returned when a CNAME chain
// exceeds MaxCNAMEDepth.
ErrCNAMEDepthExceeded = errors.New(
"CNAME chain depth exceeded",
)
// ErrContextCanceled wraps context cancellation for the
// resolver's iterative queries.
ErrContextCanceled = errors.New("context canceled")
)

View File

@@ -0,0 +1,716 @@
package resolver
import (
"context"
"errors"
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/miekg/dns"
)
const (
queryTimeoutDuration = 5 * time.Second
maxRetries = 2
maxDelegation = 20
timeoutMultiplier = 2
minDomainLabels = 2
)
// ErrRefused is returned when a DNS server refuses a query.
var ErrRefused = errors.New("dns query refused")
func rootServerList() []string {
return []string{
"198.41.0.4", // a.root-servers.net
"170.247.170.2", // b
"192.33.4.12", // c
"199.7.91.13", // d
"192.203.230.10", // e
"192.5.5.241", // f
"192.112.36.4", // g
"198.97.190.53", // h
"192.36.148.17", // i
"192.58.128.30", // j
"193.0.14.129", // k
"199.7.83.42", // l
"202.12.27.33", // m
}
}
func checkCtx(ctx context.Context) error {
err := ctx.Err()
if err != nil {
return ErrContextCanceled
}
return nil
}
func exchangeWithTimeout(
ctx context.Context,
msg *dns.Msg,
addr string,
attempt int,
) (*dns.Msg, error) {
c := new(dns.Client)
c.Timeout = queryTimeoutDuration
if attempt > 0 {
c.Timeout = queryTimeoutDuration * timeoutMultiplier
}
resp, _, err := c.ExchangeContext(ctx, msg, addr)
return resp, err
}
func tryExchange(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, error) {
var resp *dns.Msg
var err error
for attempt := range maxRetries {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err = exchangeWithTimeout(ctx, msg, addr, attempt)
if err == nil {
break
}
}
return resp, err
}
func retryTCP(
ctx context.Context,
msg *dns.Msg,
addr string,
resp *dns.Msg,
) *dns.Msg {
if !resp.Truncated {
return resp
}
c := &dns.Client{
Net: "tcp",
Timeout: queryTimeoutDuration,
}
tcpResp, _, tcpErr := c.ExchangeContext(ctx, msg, addr)
if tcpErr == nil {
return tcpResp
}
return resp
}
// queryDNS sends a DNS query to a specific server IP.
// Tries non-recursive first, falls back to recursive on
// REFUSED (handles DNS interception environments).
func queryDNS(
ctx context.Context,
serverIP string,
name string,
qtype uint16,
) (*dns.Msg, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
name = dns.Fqdn(name)
addr := net.JoinHostPort(serverIP, "53")
msg := new(dns.Msg)
msg.SetQuestion(name, qtype)
msg.RecursionDesired = false
resp, err := tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err)
}
if resp.Rcode == dns.RcodeRefused {
msg.RecursionDesired = true
resp, err = tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, err,
)
}
if resp.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, ErrRefused,
)
}
}
resp = retryTCP(ctx, msg, addr, resp)
return resp, nil
}
func extractNSSet(rrs []dns.RR) []string {
nsSet := make(map[string]bool)
for _, rr := range rrs {
if ns, ok := rr.(*dns.NS); ok {
nsSet[strings.ToLower(ns.Ns)] = true
}
}
names := make([]string, 0, len(nsSet))
for n := range nsSet {
names = append(names, n)
}
sort.Strings(names)
return names
}
func extractGlue(rrs []dns.RR) map[string][]net.IP {
glue := make(map[string][]net.IP)
for _, rr := range rrs {
switch r := rr.(type) {
case *dns.A:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.A)
case *dns.AAAA:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.AAAA)
}
}
return glue
}
func glueIPs(nsNames []string, glue map[string][]net.IP) []string {
var ips []string
for _, ns := range nsNames {
for _, addr := range glue[ns] {
if v4 := addr.To4(); v4 != nil {
ips = append(ips, v4.String())
}
}
}
return ips
}
func (r *Resolver) followDelegation(
ctx context.Context,
domain string,
servers []string,
) ([]string, error) {
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := queryServers(ctx, servers, domain, dns.TypeNS)
if err != nil {
return nil, err
}
ansNS := extractNSSet(resp.Answer)
if len(ansNS) > 0 {
return ansNS, nil
}
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
return r.resolveNSRecursive(ctx, domain)
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
nextServers = r.resolveNSIPs(ctx, authNS)
}
if len(nextServers) == 0 {
return nil, ErrNoNameservers
}
servers = nextServers
}
return nil, ErrNoNameservers
}
func queryServers(
ctx context.Context,
servers []string,
name string,
qtype uint16,
) (*dns.Msg, error) {
var lastErr error
for _, ip := range servers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := queryDNS(ctx, ip, name, qtype)
if err == nil {
return resp, nil
}
lastErr = err
}
return nil, fmt.Errorf("all servers failed: %w", lastErr)
}
func (r *Resolver) resolveNSIPs(
ctx context.Context,
nsNames []string,
) []string {
var ips []string
for _, ns := range nsNames {
resolved, err := r.resolveARecord(ctx, ns)
if err == nil {
ips = append(ips, resolved...)
}
if len(ips) > 0 {
break
}
}
return ips
}
// resolveNSRecursive queries for NS records using recursive
// resolution as a fallback for intercepted environments.
func (r *Resolver) resolveNSRecursive(
ctx context.Context,
domain string,
) ([]string, error) {
domain = dns.Fqdn(domain)
msg := new(dns.Msg)
msg.SetQuestion(domain, dns.TypeNS)
msg.RecursionDesired = true
c := &dns.Client{Timeout: queryTimeoutDuration}
for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
addr := net.JoinHostPort(ip, "53")
resp, _, err := c.ExchangeContext(ctx, msg, addr)
if err != nil {
continue
}
nsNames := extractNSSet(resp.Answer)
if len(nsNames) > 0 {
return nsNames, nil
}
}
return nil, ErrNoNameservers
}
// resolveARecord resolves a hostname to IPv4 addresses.
func (r *Resolver) resolveARecord(
ctx context.Context,
hostname string,
) ([]string, error) {
hostname = dns.Fqdn(hostname)
msg := new(dns.Msg)
msg.SetQuestion(hostname, dns.TypeA)
msg.RecursionDesired = true
c := &dns.Client{Timeout: queryTimeoutDuration}
for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
addr := net.JoinHostPort(ip, "53")
resp, _, err := c.ExchangeContext(ctx, msg, addr)
if err != nil {
continue
}
var ips []string
for _, rr := range resp.Answer {
if a, ok := rr.(*dns.A); ok {
ips = append(ips, a.A.String())
}
}
if len(ips) > 0 {
return ips, nil
}
}
return nil, fmt.Errorf(
"cannot resolve %s: %w", hostname, ErrNoNameservers,
)
}
// FindAuthoritativeNameservers traces the delegation chain from
// root servers to discover all authoritative nameservers for the
// given domain. Walks up the label hierarchy for subdomains.
func (r *Resolver) FindAuthoritativeNameservers(
ctx context.Context,
domain string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
domain = dns.Fqdn(strings.ToLower(domain))
labels := dns.SplitDomainName(domain)
for i := range labels {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
candidate := strings.Join(labels[i:], ".") + "."
nsNames, err := r.followDelegation(
ctx, candidate, rootServerList(),
)
if err == nil && len(nsNames) > 0 {
sort.Strings(nsNames)
return nsNames, nil
}
}
return nil, ErrNoNameservers
}
// QueryNameserver queries a specific nameserver for all record
// types and builds a NameserverResponse.
func (r *Resolver) QueryNameserver(
ctx context.Context,
nsHostname string,
hostname string,
) (*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
nsIPs, err := r.resolveARecord(ctx, nsHostname)
if err != nil {
return nil, fmt.Errorf("resolving NS %s: %w", nsHostname, err)
}
hostname = dns.Fqdn(hostname)
return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname)
}
func (r *Resolver) queryAllTypes(
ctx context.Context,
nsHostname string,
nsIP string,
hostname string,
) (*NameserverResponse, error) {
resp := &NameserverResponse{
Nameserver: nsHostname,
Records: make(map[string][]string),
Status: StatusOK,
}
qtypes := []uint16{
dns.TypeA, dns.TypeAAAA, dns.TypeCNAME,
dns.TypeMX, dns.TypeTXT, dns.TypeSRV,
dns.TypeCAA, dns.TypeNS,
}
state := r.queryEachType(ctx, nsIP, hostname, qtypes, resp)
classifyResponse(resp, state)
return resp, nil
}
type queryState struct {
gotNXDomain bool
gotSERVFAIL bool
hasRecords bool
}
func (r *Resolver) queryEachType(
ctx context.Context,
nsIP string,
hostname string,
qtypes []uint16,
resp *NameserverResponse,
) queryState {
var state queryState
for _, qtype := range qtypes {
if checkCtx(ctx) != nil {
break
}
r.querySingleType(ctx, nsIP, hostname, qtype, resp, &state)
}
for k := range resp.Records {
sort.Strings(resp.Records[k])
}
return state
}
func (r *Resolver) querySingleType(
ctx context.Context,
nsIP string,
hostname string,
qtype uint16,
resp *NameserverResponse,
state *queryState,
) {
msg, err := queryDNS(ctx, nsIP, hostname, qtype)
if err != nil {
return
}
if msg.Rcode == dns.RcodeNameError {
state.gotNXDomain = true
return
}
if msg.Rcode == dns.RcodeServerFailure {
state.gotSERVFAIL = true
return
}
collectAnswerRecords(msg, resp, state)
}
func collectAnswerRecords(
msg *dns.Msg,
resp *NameserverResponse,
state *queryState,
) {
for _, rr := range msg.Answer {
val := extractRecordValue(rr)
if val == "" {
continue
}
typeName := dns.TypeToString[rr.Header().Rrtype]
resp.Records[typeName] = append(
resp.Records[typeName], val,
)
state.hasRecords = true
}
}
func classifyResponse(resp *NameserverResponse, state queryState) {
switch {
case state.gotNXDomain && !state.hasRecords:
resp.Status = StatusNXDomain
case state.gotSERVFAIL && !state.hasRecords:
resp.Status = StatusError
case !state.hasRecords && !state.gotNXDomain:
resp.Status = StatusNoData
}
}
// extractRecordValue formats a DNS RR value as a string.
func extractRecordValue(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
return r.A.String()
case *dns.AAAA:
return r.AAAA.String()
case *dns.CNAME:
return r.Target
case *dns.MX:
return fmt.Sprintf("%d %s", r.Preference, r.Mx)
case *dns.TXT:
return strings.Join(r.Txt, "")
case *dns.SRV:
return fmt.Sprintf(
"%d %d %d %s",
r.Priority, r.Weight, r.Port, r.Target,
)
case *dns.CAA:
return fmt.Sprintf(
"%d %s \"%s\"", r.Flag, r.Tag, r.Value,
)
case *dns.NS:
return r.Ns
default:
return ""
}
}
// parentDomain returns the registerable parent domain.
func parentDomain(hostname string) string {
hostname = dns.Fqdn(strings.ToLower(hostname))
labels := dns.SplitDomainName(hostname)
if len(labels) <= minDomainLabels {
return strings.Join(labels, ".") + "."
}
return strings.Join(
labels[len(labels)-minDomainLabels:], ".",
) + "."
}
// QueryAllNameservers discovers auth NSes for the hostname's
// parent domain, then queries each one independently.
func (r *Resolver) QueryAllNameservers(
ctx context.Context,
hostname string,
) (map[string]*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
parent := parentDomain(hostname)
nameservers, err := r.FindAuthoritativeNameservers(ctx, parent)
if err != nil {
return nil, err
}
return r.queryEachNS(ctx, nameservers, hostname)
}
func (r *Resolver) queryEachNS(
ctx context.Context,
nameservers []string,
hostname string,
) (map[string]*NameserverResponse, error) {
results := make(map[string]*NameserverResponse)
for _, ns := range nameservers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.QueryNameserver(ctx, ns, hostname)
if err != nil {
results[ns] = &NameserverResponse{
Nameserver: ns,
Records: make(map[string][]string),
Status: StatusError,
Error: err.Error(),
}
continue
}
results[ns] = resp
}
return results, nil
}
// LookupNS returns the NS record set for a domain.
func (r *Resolver) LookupNS(
ctx context.Context,
domain string,
) ([]string, error) {
return r.FindAuthoritativeNameservers(ctx, domain)
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains up to MaxCNAMEDepth.
func (r *Resolver) ResolveIPAddresses(
ctx context.Context,
hostname string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
return r.resolveIPWithCNAME(ctx, hostname, 0)
}
func (r *Resolver) resolveIPWithCNAME(
ctx context.Context,
hostname string,
depth int,
) ([]string, error) {
if depth > MaxCNAMEDepth {
return nil, ErrCNAMEDepthExceeded
}
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
ips, cnameTarget := collectIPs(results)
if len(ips) == 0 && cnameTarget != "" {
return r.resolveIPWithCNAME(ctx, cnameTarget, depth+1)
}
sort.Strings(ips)
return ips, nil
}
func collectIPs(
results map[string]*NameserverResponse,
) ([]string, string) {
seen := make(map[string]bool)
var ips []string
var cnameTarget string
for _, resp := range results {
if resp.Status == StatusNXDomain {
continue
}
for _, ip := range resp.Records["A"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
for _, ip := range resp.Records["AAAA"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
if len(resp.Records["CNAME"]) > 0 && cnameTarget == "" {
cnameTarget = resp.Records["CNAME"][0]
}
}
return ips, cnameTarget
}

View File

@@ -1,9 +1,9 @@
// Package resolver provides iterative DNS resolution from root nameservers. // Package resolver provides iterative DNS resolution from root nameservers.
// It traces the full delegation chain from IANA root servers through TLD
// and domain nameservers, never relying on upstream recursive resolvers.
package resolver package resolver
import ( import (
"context"
"errors"
"log/slog" "log/slog"
"go.uber.org/fx" "go.uber.org/fx"
@@ -11,8 +11,16 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
// ErrNotImplemented indicates the resolver is not yet implemented. // Query status constants matching the state model.
var ErrNotImplemented = errors.New("resolver not yet implemented") const (
StatusOK = "ok"
StatusError = "error"
StatusNXDomain = "nxdomain"
StatusNoData = "nodata"
)
// MaxCNAMEDepth is the maximum CNAME chain depth to follow.
const MaxCNAMEDepth = 10
// Params contains dependencies for Resolver. // Params contains dependencies for Resolver.
type Params struct { type Params struct {
@@ -21,12 +29,20 @@ type Params struct {
Logger *logger.Logger Logger *logger.Logger
} }
// NameserverResponse holds one nameserver's response for a query.
type NameserverResponse struct {
Nameserver string
Records map[string][]string
Status string
Error string
}
// Resolver performs iterative DNS resolution from root servers. // Resolver performs iterative DNS resolution from root servers.
type Resolver struct { type Resolver struct {
log *slog.Logger log *slog.Logger
} }
// New creates a new Resolver instance. // New creates a new Resolver instance for use with uber/fx.
func New( func New(
_ fx.Lifecycle, _ fx.Lifecycle,
params Params, params Params,
@@ -36,29 +52,10 @@ func New(
}, nil }, nil
} }
// LookupNS performs iterative resolution to find authoritative // NewFromLogger creates a Resolver directly from an slog.Logger,
// nameservers for the given domain. // useful for testing without the fx lifecycle.
func (r *Resolver) LookupNS( func NewFromLogger(log *slog.Logger) *Resolver {
_ context.Context, return &Resolver{log: log}
_ string,
) ([]string, error) {
return nil, ErrNotImplemented
} }
// LookupAllRecords performs iterative resolution to find all DNS // Method implementations are in iterative.go.
// records for the given hostname.
func (r *Resolver) LookupAllRecords(
_ context.Context,
_ string,
) (map[string][]string, error) {
return nil, ErrNotImplemented
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains.
func (r *Resolver) ResolveIPAddresses(
_ context.Context,
_ string,
) ([]string, error) {
return nil, ErrNotImplemented
}

View File

@@ -0,0 +1,914 @@
package resolver_test
import (
"context"
"log/slog"
"net"
"os"
"sort"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/resolver"
)
// Test domain and hostnames hosted on Cloudflare.
// These records must exist in the sneak.cloud Cloudflare zone:
//
// basic.dns.sneak.cloud A 192.0.2.1
// multi.dns.sneak.cloud A 192.0.2.1
// multi.dns.sneak.cloud A 192.0.2.2
// ipv6.dns.sneak.cloud AAAA 2001:db8::1
// dual.dns.sneak.cloud A 192.0.2.1
// dual.dns.sneak.cloud AAAA 2001:db8::1
// cname-target.dns.sneak.cloud A 198.51.100.1
// cname.dns.sneak.cloud CNAME cname-target.dns.sneak.cloud
// mx.dns.sneak.cloud MX 10 mail.dns.sneak.cloud
// mail.dns A 192.0.2.10
// txt.dns.sneak.cloud TXT "v=spf1 -all"
const (
testDomain = "sneak.cloud"
testHostBasic = "basic.dns.sneak.cloud"
testHostMultiA = "multi.dns.sneak.cloud"
testHostIPv6 = "ipv6.dns.sneak.cloud"
testHostDualStack = "dual.dns.sneak.cloud"
testHostCNAME = "cname.dns.sneak.cloud"
testHostCNAMETarget = "cname-target.dns.sneak.cloud"
testHostMX = "mx.dns.sneak.cloud"
testHostMail = "mail.dns.sneak.cloud"
testHostTXT = "txt.dns.sneak.cloud"
testHostNXDomain = "nxdomain-surely-does-not-exist.google.com"
testDomainNXDomain = "google.com"
)
// queryTimeout is the default timeout for test queries.
const queryTimeout = 60 * time.Second
func newTestResolver(t *testing.T) *resolver.Resolver {
t.Helper()
log := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelDebug,
}))
return resolver.NewFromLogger(log)
}
func testContext(t *testing.T) context.Context {
t.Helper()
ctx, cancel := context.WithTimeout(
context.Background(), queryTimeout,
)
t.Cleanup(cancel)
return ctx
}
// --- FindAuthoritativeNameservers tests ---
func TestFindAuthoritativeNameservers_ValidDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, testDomain,
)
require.NoError(t, err)
require.NotEmpty(t, nameservers, "should find at least one NS")
// sneak.cloud is on Cloudflare, NS should contain cloudflare
for _, ns := range nameservers {
t.Logf("discovered NS: %s", ns)
assert.True(
t,
strings.HasSuffix(ns, "."),
"NS should be FQDN with trailing dot: %s", ns,
)
}
// Verify at least one is a Cloudflare NS
hasCloudflare := false
for _, ns := range nameservers {
if strings.Contains(ns, "cloudflare") {
hasCloudflare = true
break
}
}
assert.True(
t, hasCloudflare,
"sneak.cloud should be hosted on Cloudflare, got: %v",
nameservers,
)
}
func TestFindAuthoritativeNameservers_Subdomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
// Looking up NS for a hostname that isn't a zone should
// return the parent zone's NS records.
nameservers, err := r.FindAuthoritativeNameservers(
ctx, testHostBasic,
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
// Should be the same Cloudflare NSes as the parent domain
hasCloudflare := false
for _, ns := range nameservers {
if strings.Contains(ns, "cloudflare") {
hasCloudflare = true
break
}
}
assert.True(t, hasCloudflare)
}
func TestFindAuthoritativeNameservers_TLD(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "cloud",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers, "should find TLD nameservers")
for _, ns := range nameservers {
t.Logf("TLD NS: %s", ns)
}
}
func TestFindAuthoritativeNameservers_ReturnsSorted(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, testDomain,
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
// Results should be sorted for deterministic comparison
assert.True(
t,
sort.StringsAreSorted(nameservers),
"nameservers should be sorted, got: %v", nameservers,
)
}
func TestFindAuthoritativeNameservers_Deterministic(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
first, err := r.FindAuthoritativeNameservers(
ctx, testDomain,
)
require.NoError(t, err)
second, err := r.FindAuthoritativeNameservers(
ctx, testDomain,
)
require.NoError(t, err)
assert.Equal(
t, first, second,
"repeated lookups should return same result",
)
}
// --- QueryNameserver tests ---
func TestQueryNameserver_BasicA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostBasic)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
assert.Equal(t, ns, resp.Nameserver)
aRecords := resp.Records["A"]
require.NotEmpty(t, aRecords, "basic.dns should have A records")
assert.Contains(t, aRecords, "192.0.2.1")
t.Logf(
"QueryNameserver(%s, %s) A records: %v",
ns, testHostBasic, aRecords,
)
}
func TestQueryNameserver_MultipleA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostMultiA)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
aRecords := resp.Records["A"]
require.Len(
t, aRecords, 2,
"multi.dns should have exactly 2 A records",
)
sort.Strings(aRecords)
assert.Equal(t, []string{"192.0.2.1", "192.0.2.2"}, aRecords)
}
func TestQueryNameserver_AAAA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostIPv6)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
aaaaRecords := resp.Records["AAAA"]
require.NotEmpty(
t, aaaaRecords,
"ipv6.dns should have AAAA records",
)
assert.Contains(t, aaaaRecords, "2001:db8::1")
}
func TestQueryNameserver_DualStack(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostDualStack)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
assert.Contains(t, resp.Records["A"], "192.0.2.1")
assert.Contains(t, resp.Records["AAAA"], "2001:db8::1")
}
func TestQueryNameserver_CNAME(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostCNAME)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
cnameRecords := resp.Records["CNAME"]
require.NotEmpty(
t, cnameRecords,
"cname.dns should have CNAME records",
)
assert.Contains(
t, cnameRecords, "cname-target.dns.sneak.cloud.",
)
}
func TestQueryNameserver_MX(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostMX)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
mxRecords := resp.Records["MX"]
require.NotEmpty(
t, mxRecords,
"mx.dns should have MX records",
)
// MX records are formatted as "priority host"
hasMail := false
for _, mx := range mxRecords {
if strings.Contains(mx, "mail.dns.sneak.cloud.") {
hasMail = true
break
}
}
assert.True(
t, hasMail,
"MX should reference mail.dns.sneak.cloud, got: %v",
mxRecords,
)
}
func TestQueryNameserver_TXT(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostTXT)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
txtRecords := resp.Records["TXT"]
require.NotEmpty(
t, txtRecords,
"txt.dns should have TXT records",
)
hasSPF := false
for _, txt := range txtRecords {
if strings.Contains(txt, "v=spf1") {
hasSPF = true
break
}
}
assert.True(
t, hasSPF,
"TXT should contain SPF record, got: %v", txtRecords,
)
}
func TestQueryNameserver_NXDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, testDomainNXDomain)
resp, err := r.QueryNameserver(ctx, ns, testHostNXDomain)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(
t, resolver.StatusNXDomain, resp.Status,
"nonexistent host should return nxdomain status",
)
}
func TestQueryNameserver_RecordsSorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostMultiA)
require.NoError(t, err)
// Each record type's values should be sorted for determinism
for recordType, values := range resp.Records {
assert.True(
t,
sort.StringsAreSorted(values),
"%s records should be sorted, got: %v",
recordType, values,
)
}
}
func TestQueryNameserver_ResponseIncludesNameserver(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
resp, err := r.QueryNameserver(ctx, ns, testHostBasic)
require.NoError(t, err)
assert.Equal(
t, ns, resp.Nameserver,
"response should include the queried nameserver",
)
}
func TestQueryNameserver_EmptyRecordsMapOnNXDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, testDomainNXDomain)
resp, err := r.QueryNameserver(ctx, ns, testHostNXDomain)
require.NoError(t, err)
totalRecords := 0
for _, values := range resp.Records {
totalRecords += len(values)
}
assert.Zero(
t, totalRecords,
"NXDOMAIN should have no records, got: %v",
resp.Records,
)
}
// --- QueryAllNameservers tests ---
func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(ctx, testHostBasic)
require.NoError(t, err)
require.NotEmpty(t, results)
// Should have queried each NS independently
t.Logf(
"QueryAllNameservers returned %d nameserver results",
len(results),
)
for ns, resp := range results {
t.Logf(" %s: status=%s A=%v", ns, resp.Status, resp.Records["A"])
assert.Equal(t, ns, resp.Nameserver)
}
// Should have more than one NS for Cloudflare-hosted domain
assert.GreaterOrEqual(
t, len(results), 2,
"should query at least 2 nameservers",
)
}
func TestQueryAllNameservers_Consistent(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(ctx, testHostBasic)
require.NoError(t, err)
require.NotEmpty(t, results)
// All NSes should return the same A records for a
// well-configured hostname.
var referenceRecords map[string][]string
for ns, resp := range results {
require.Equal(
t, resolver.StatusOK, resp.Status,
"NS %s should return OK status", ns,
)
if referenceRecords == nil {
referenceRecords = resp.Records
continue
}
assert.Equal(
t, referenceRecords["A"], resp.Records["A"],
"NS %s A records should match", ns,
)
}
}
func TestQueryAllNameservers_NXDomainFromAllNS(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx, testHostNXDomain,
)
require.NoError(t, err)
require.NotEmpty(t, results)
for ns, resp := range results {
assert.Equal(
t, resolver.StatusNXDomain, resp.Status,
"NS %s should return nxdomain for nonexistent host",
ns,
)
}
}
// --- LookupNS tests ---
func TestLookupNS_ValidDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, testDomain)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
for _, ns := range nameservers {
t.Logf("NS record: %s", ns)
assert.True(
t,
strings.HasSuffix(ns, "."),
"NS should be FQDN: %s", ns,
)
}
}
func TestLookupNS_Sorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, testDomain)
require.NoError(t, err)
assert.True(
t,
sort.StringsAreSorted(nameservers),
"NS records should be sorted, got: %v", nameservers,
)
}
func TestLookupNS_MatchesFindAuthoritative(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
fromLookup, err := r.LookupNS(ctx, testDomain)
require.NoError(t, err)
fromFind, err := r.FindAuthoritativeNameservers(
ctx, testDomain,
)
require.NoError(t, err)
// Both methods should return the same NS set
assert.Equal(
t, fromFind, fromLookup,
"LookupNS and FindAuthoritativeNameservers "+
"should return the same set",
)
}
// --- ResolveIPAddresses tests ---
func TestResolveIPAddresses_BasicA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, testHostBasic)
require.NoError(t, err)
require.NotEmpty(t, ips)
assert.Contains(t, ips, "192.0.2.1")
}
func TestResolveIPAddresses_MultipleA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, testHostMultiA)
require.NoError(t, err)
sort.Strings(ips)
assert.Contains(t, ips, "192.0.2.1")
assert.Contains(t, ips, "192.0.2.2")
}
func TestResolveIPAddresses_IPv6Only(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, testHostIPv6)
require.NoError(t, err)
require.NotEmpty(t, ips)
assert.Contains(t, ips, "2001:db8::1")
// Should not contain any IPv4
for _, ip := range ips {
parsed := net.ParseIP(ip)
require.NotNil(t, parsed, "should be valid IP: %s", ip)
assert.Nil(
t, parsed.To4(),
"ipv6-only host should not return IPv4: %s", ip,
)
}
}
func TestResolveIPAddresses_DualStack(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, testHostDualStack)
require.NoError(t, err)
assert.Contains(t, ips, "192.0.2.1")
assert.Contains(t, ips, "2001:db8::1")
}
func TestResolveIPAddresses_FollowsCNAME(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
// cname.dns.sneak.cloud -> cname-target.dns.sneak.cloud -> 198.51.100.1
ips, err := r.ResolveIPAddresses(ctx, testHostCNAME)
require.NoError(t, err)
require.NotEmpty(t, ips)
assert.Contains(
t, ips, "198.51.100.1",
"should follow CNAME to resolve target IP",
)
}
func TestResolveIPAddresses_Deduplicated(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, testHostBasic)
require.NoError(t, err)
// Check for duplicates
seen := make(map[string]bool)
for _, ip := range ips {
assert.False(
t, seen[ip],
"IP %s appears more than once", ip,
)
seen[ip] = true
}
}
func TestResolveIPAddresses_Sorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, testHostDualStack)
require.NoError(t, err)
assert.True(
t,
sort.StringsAreSorted(ips),
"IP addresses should be sorted, got: %v", ips,
)
}
func TestResolveIPAddresses_NXDomainReturnsEmpty(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, testHostNXDomain)
// Should not error — NXDOMAIN is an expected DNS response.
// It just means no IPs to return.
require.NoError(t, err)
assert.Empty(t, ips)
}
// --- Context cancellation tests ---
func TestFindAuthoritativeNameservers_ContextCanceled(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := r.FindAuthoritativeNameservers(ctx, testDomain)
assert.Error(t, err)
}
func TestQueryNameserver_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.QueryNameserver(
ctx, "ns1.example.com.", testHostBasic,
)
assert.Error(t, err)
}
func TestQueryAllNameservers_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.QueryAllNameservers(ctx, testHostBasic)
assert.Error(t, err)
}
func TestResolveIPAddresses_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.ResolveIPAddresses(ctx, testHostBasic)
assert.Error(t, err)
}
// --- Iterative resolution verification ---
func TestFindAuthoritativeNameservers_IsIterative(
t *testing.T,
) {
// Verify that resolution works for well-known domains,
// proving we trace from root rather than relying on a
// system stub resolver that might not be configured.
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
// Resolve a well-known domain to prove root->TLD->domain
// tracing works.
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
t.Logf("example.com NS: %v", nameservers)
}
// --- Edge cases ---
func TestQueryNameserver_TrailingDotHandling(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNS(t, r, ctx)
// Both with and without trailing dot should work
resp1, err := r.QueryNameserver(
ctx, ns, "basic.dns.sneak.cloud",
)
require.NoError(t, err)
resp2, err := r.QueryNameserver(
ctx, ns, "basic.dns.sneak.cloud.",
)
require.NoError(t, err)
assert.Equal(
t, resp1.Records["A"], resp2.Records["A"],
"trailing dot should not affect results",
)
}
func TestFindAuthoritativeNameservers_TrailingDot(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns1, err := r.FindAuthoritativeNameservers(
ctx, "sneak.cloud",
)
require.NoError(t, err)
ns2, err := r.FindAuthoritativeNameservers(
ctx, "sneak.cloud.",
)
require.NoError(t, err)
assert.Equal(
t, ns1, ns2,
"trailing dot should not affect NS lookup",
)
}
// --- Helper functions ---
// findOneNS discovers authoritative nameservers and returns the first
// one, failing the test if none are found.
func findOneNS(
t *testing.T,
r *resolver.Resolver,
ctx context.Context, //nolint:revive // test helper
) string {
t.Helper()
return findOneNSForDomain(t, r, ctx, testDomain)
}
func findOneNSForDomain(
t *testing.T,
r *resolver.Resolver,
ctx context.Context, //nolint:revive // test helper
domain string,
) string {
t.Helper()
nameservers, err := r.FindAuthoritativeNameservers(
ctx, domain,
)
require.NoError(t, err)
require.NotEmpty(
t, nameservers,
"should find at least one NS for %s", domain,
)
return nameservers[0]
}