Compare commits

...

10 Commits

Author SHA1 Message Date
clawbot
d786315452 fix: mock DNS in resolver tests for hermetic, fast unit tests
- Extract DNSClient interface from resolver to allow dependency injection
- Convert all resolver methods from package-level to receiver methods
  using the injectable DNS client
- Rewrite resolver_test.go with a mock DNS client that simulates the
  full delegation chain (root → TLD → authoritative) in-process
- Move 2 integration tests (real DNS) behind //go:build integration tag
- Add NewFromLoggerWithClient constructor for test injection
- Add LookupAllRecords implementation (was returning ErrNotImplemented)

All unit tests are hermetic (no network) and complete in <1s.
Total make check passes in ~5s.

Closes #12
2026-02-20 00:17:23 -08:00
clawbot
ee40af94da Merge origin/main into feature/resolver 2026-02-20 00:10:20 -08:00
e185000402 Merge pull request 'feat: implement watcher monitoring orchestrator (closes #2)' (#8) from feature/watcher-implementation into main
Reviewed-on: #8
2026-02-20 09:06:42 +01:00
d5738d6d43 Merge branch 'main' into feature/watcher-implementation 2026-02-20 09:06:27 +01:00
5e4631776a Merge pull request 'feat: unify DOMAINS/HOSTNAMES into single TARGETS config (closes #10)' (#11) from feature/unified-targets into main
Reviewed-on: #11
2026-02-20 09:04:59 +01:00
clawbot
f8d5a8f6cc fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:43:42 -08:00
clawbot
e09135d9d9 fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:42:50 -08:00
clawbot
73e01c7664 feat: unify DOMAINS/HOSTNAMES into single TARGETS config
Replace DNSWATCHER_DOMAINS and DNSWATCHER_HOSTNAMES with a single
DNSWATCHER_TARGETS env var. Names are automatically classified as apex
domains or hostnames using the Public Suffix List
(golang.org/x/net/publicsuffix).

- ClassifyDNSName() uses EffectiveTLDPlusOne to determine type
- Public suffixes themselves (e.g. co.uk) are rejected with an error
- Old DOMAINS/HOSTNAMES vars removed entirely (pre-1.0, no compat needed)
- README updated with pre-1.0 warning

Closes #10
2026-02-19 20:09:39 -08:00
clawbot
f676cc9458 feat: implement watcher monitoring orchestrator
Implements the full monitoring loop:
- Immediate checks on startup, then periodic DNS+port and TLS cycles
- Domain NS change detection with notifications
- Per-nameserver hostname record tracking with change/failure/recovery
  and inconsistency detection
- TCP port 80/443 monitoring with state change notifications
- TLS certificate monitoring with change, expiry, and failure detection
- State persistence after each cycle
- First run establishes baseline without notifications
- Graceful shutdown via context cancellation

Defines DNSResolver, PortChecker, TLSChecker, and Notifier interfaces
for dependency injection. Updates main.go fx wiring and resolver stub
signature to match per-NS record format.

Closes #2
2026-02-19 13:48:46 -08:00
clawbot
dea30028b1 test: add watcher orchestrator tests with mock dependencies
Tests cover: first-run baseline, NS change detection, record change
detection, port state changes, TLS expiry warnings, graceful shutdown,
and NS failure/recovery scenarios.
2026-02-19 13:48:38 -08:00
17 changed files with 2355 additions and 523 deletions

View File

@ -1,5 +1,7 @@
# dnswatcher # dnswatcher
> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice.
dnswatcher is a production DNS and infrastructure monitoring daemon written in dnswatcher is a production DNS and infrastructure monitoring daemon written in
Go. It watches configured DNS domains and hostnames for changes, monitors TCP Go. It watches configured DNS domains and hostnames for changes, monitors TCP
port availability, tracks TLS certificate expiry, and delivers real-time port availability, tracks TLS certificate expiry, and delivers real-time
@ -195,8 +197,7 @@ the following precedence (highest to lowest):
| `PORT` | HTTP listen port | `8080` | | `PORT` | HTTP listen port | `8080` |
| `DNSWATCHER_DEBUG` | Enable debug logging | `false` | | `DNSWATCHER_DEBUG` | Enable debug logging | `false` |
| `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` | | `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` |
| `DNSWATCHER_DOMAINS` | Comma-separated list of apex domains | `""` | | `DNSWATCHER_TARGETS` | Comma-separated DNS names (auto-classified via PSL) | `""` |
| `DNSWATCHER_HOSTNAMES` | Comma-separated list of hostnames | `""` |
| `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` | | `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` |
| `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` | | `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` |
| `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` | | `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` |
@ -214,8 +215,7 @@ the following precedence (highest to lowest):
PORT=8080 PORT=8080
DNSWATCHER_DEBUG=false DNSWATCHER_DEBUG=false
DNSWATCHER_DATA_DIR=./data DNSWATCHER_DATA_DIR=./data
DNSWATCHER_DOMAINS=example.com,example.org DNSWATCHER_TARGETS=example.com,example.org,www.example.com,api.example.com,mail.example.org
DNSWATCHER_HOSTNAMES=www.example.com,api.example.com,mail.example.org
DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx
DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx
DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts
@ -352,8 +352,7 @@ docker build -t dnswatcher .
docker run -d \ docker run -d \
-p 8080:8080 \ -p 8080:8080 \
-v dnswatcher-data:/var/lib/dnswatcher \ -v dnswatcher-data:/var/lib/dnswatcher \
-e DNSWATCHER_DOMAINS=example.com \ -e DNSWATCHER_TARGETS=example.com,www.example.com \
-e DNSWATCHER_HOSTNAMES=www.example.com \
-e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \ -e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \
dnswatcher dnswatcher
``` ```

View File

@ -51,6 +51,20 @@ func main() {
handlers.New, handlers.New,
server.New, server.New,
), ),
fx.Provide(
func(r *resolver.Resolver) watcher.DNSResolver {
return r
},
func(p *portcheck.Checker) watcher.PortChecker {
return p
},
func(t *tlscheck.Checker) watcher.TLSChecker {
return t
},
func(n *notify.Service) watcher.Notifier {
return n
},
),
fx.Invoke(func(*server.Server, *watcher.Watcher) {}), fx.Invoke(func(*server.Server, *watcher.Watcher) {}),
).Run() ).Run()
} }

10
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0
) )
require ( require (
@ -37,12 +38,11 @@ require (
go.uber.org/zap v1.26.0 // indirect go.uber.org/zap v1.26.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.31.0 // indirect golang.org/x/mod v0.32.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.32.0 // indirect golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.40.0 // indirect golang.org/x/tools v0.41.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

20
go.sum
View File

@ -76,18 +76,18 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -0,0 +1,85 @@
package config
import (
"errors"
"fmt"
"strings"
"golang.org/x/net/publicsuffix"
)
// DNSNameType indicates whether a DNS name is an apex domain or a hostname.
type DNSNameType int
const (
// DNSNameTypeDomain indicates the name is an apex (eTLD+1) domain.
DNSNameTypeDomain DNSNameType = iota
// DNSNameTypeHostname indicates the name is a subdomain/hostname.
DNSNameTypeHostname
)
// ErrEmptyDNSName is returned when an empty string is passed to ClassifyDNSName.
var ErrEmptyDNSName = errors.New("empty DNS name")
// String returns the string representation of a DNSNameType.
func (t DNSNameType) String() string {
switch t {
case DNSNameTypeDomain:
return "domain"
case DNSNameTypeHostname:
return "hostname"
default:
return "unknown"
}
}
// ClassifyDNSName determines whether a DNS name is an apex domain or a
// hostname (subdomain) using the Public Suffix List. It returns an error
// if the input is empty or is itself a public suffix (e.g. "co.uk").
func ClassifyDNSName(name string) (DNSNameType, error) {
name = strings.ToLower(strings.TrimSuffix(strings.TrimSpace(name), "."))
if name == "" {
return 0, ErrEmptyDNSName
}
etld1, err := publicsuffix.EffectiveTLDPlusOne(name)
if err != nil {
return 0, fmt.Errorf("invalid DNS name %q: %w", name, err)
}
if name == etld1 {
return DNSNameTypeDomain, nil
}
return DNSNameTypeHostname, nil
}
// ClassifyTargets splits a list of DNS names into apex domains and
// hostnames using the Public Suffix List. It returns an error if any
// name cannot be classified.
func ClassifyTargets(targets []string) ([]string, []string, error) {
var domains, hostnames []string
for _, t := range targets {
normalized := strings.ToLower(strings.TrimSuffix(strings.TrimSpace(t), "."))
if normalized == "" {
continue
}
typ, classErr := ClassifyDNSName(normalized)
if classErr != nil {
return nil, nil, classErr
}
switch typ {
case DNSNameTypeDomain:
domains = append(domains, normalized)
case DNSNameTypeHostname:
hostnames = append(hostnames, normalized)
}
}
return domains, hostnames, nil
}

View File

@ -0,0 +1,83 @@
package config_test
import (
"testing"
"sneak.berlin/go/dnswatcher/internal/config"
)
func TestClassifyDNSName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want config.DNSNameType
wantErr bool
}{
{name: "apex domain simple", input: "example.com", want: config.DNSNameTypeDomain},
{name: "hostname simple", input: "www.example.com", want: config.DNSNameTypeHostname},
{name: "apex domain multi-part TLD", input: "example.co.uk", want: config.DNSNameTypeDomain},
{name: "hostname multi-part TLD", input: "api.example.co.uk", want: config.DNSNameTypeHostname},
{name: "public suffix itself", input: "co.uk", wantErr: true},
{name: "empty string", input: "", wantErr: true},
{name: "deeply nested hostname", input: "a.b.c.example.com", want: config.DNSNameTypeHostname},
{name: "trailing dot stripped", input: "example.com.", want: config.DNSNameTypeDomain},
{name: "uppercase normalized", input: "WWW.Example.COM", want: config.DNSNameTypeHostname},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := config.ClassifyDNSName(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("ClassifyDNSName(%q) expected error, got %v", tt.input, got)
}
return
}
if err != nil {
t.Fatalf("ClassifyDNSName(%q) unexpected error: %v", tt.input, err)
}
if got != tt.want {
t.Errorf("ClassifyDNSName(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestClassifyTargets(t *testing.T) {
t.Parallel()
domains, hostnames, err := config.ClassifyTargets([]string{
"example.com",
"www.example.com",
"example.co.uk",
"api.example.co.uk",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(domains) != 2 {
t.Errorf("expected 2 domains, got %d: %v", len(domains), domains)
}
if len(hostnames) != 2 {
t.Errorf("expected 2 hostnames, got %d: %v", len(hostnames), hostnames)
}
}
func TestClassifyTargetsRejectsPublicSuffix(t *testing.T) {
t.Parallel()
_, _, err := config.ClassifyTargets([]string{"co.uk"})
if err == nil {
t.Error("expected error for public suffix, got nil")
}
}

View File

@ -89,8 +89,7 @@ func setupViper(name string) {
viper.SetDefault("PORT", defaultPort) viper.SetDefault("PORT", defaultPort)
viper.SetDefault("DEBUG", false) viper.SetDefault("DEBUG", false)
viper.SetDefault("DATA_DIR", "./data") viper.SetDefault("DATA_DIR", "./data")
viper.SetDefault("DOMAINS", "") viper.SetDefault("TARGETS", "")
viper.SetDefault("HOSTNAMES", "")
viper.SetDefault("SLACK_WEBHOOK", "") viper.SetDefault("SLACK_WEBHOOK", "")
viper.SetDefault("MATTERMOST_WEBHOOK", "") viper.SetDefault("MATTERMOST_WEBHOOK", "")
viper.SetDefault("NTFY_TOPIC", "") viper.SetDefault("NTFY_TOPIC", "")
@ -133,12 +132,19 @@ func buildConfig(
tlsInterval = defaultTLSInterval tlsInterval = defaultTLSInterval
} }
domains, hostnames, err := ClassifyTargets(
parseCSV(viper.GetString("TARGETS")),
)
if err != nil {
return nil, fmt.Errorf("invalid targets configuration: %w", err)
}
cfg := &Config{ cfg := &Config{
Port: viper.GetInt("PORT"), Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"), Debug: viper.GetBool("DEBUG"),
DataDir: viper.GetString("DATA_DIR"), DataDir: viper.GetString("DATA_DIR"),
Domains: parseCSV(viper.GetString("DOMAINS")), Domains: domains,
Hostnames: parseCSV(viper.GetString("HOSTNAMES")), Hostnames: hostnames,
SlackWebhook: viper.GetString("SLACK_WEBHOOK"), SlackWebhook: viper.GetString("SLACK_WEBHOOK"),
MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"), MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"),
NtfyTopic: viper.GetString("NTFY_TOPIC"), NtfyTopic: viper.GetString("NTFY_TOPIC"),

View File

@ -36,16 +36,6 @@ var (
) )
) )
// sanitizeURL parses and re-serializes a URL to satisfy static analysis (gosec G704).
func sanitizeURL(raw string) (string, error) {
u, err := url.Parse(raw)
if err != nil {
return "", fmt.Errorf("invalid URL %q: %w", raw, err)
}
return u.String(), nil
}
// Params contains dependencies for Service. // Params contains dependencies for Service.
type Params struct { type Params struct {
fx.In fx.In
@ -56,9 +46,12 @@ type Params struct {
// Service provides notification functionality. // Service provides notification functionality.
type Service struct { type Service struct {
log *slog.Logger log *slog.Logger
client *http.Client client *http.Client
config *config.Config config *config.Config
ntfyURL *url.URL
slackWebhookURL *url.URL
mattermostWebhookURL *url.URL
} }
// New creates a new notify Service. // New creates a new notify Service.
@ -66,13 +59,44 @@ func New(
_ fx.Lifecycle, _ fx.Lifecycle,
params Params, params Params,
) (*Service, error) { ) (*Service, error) {
return &Service{ svc := &Service{
log: params.Logger.Get(), log: params.Logger.Get(),
client: &http.Client{ client: &http.Client{
Timeout: httpClientTimeout, Timeout: httpClientTimeout,
}, },
config: params.Config, config: params.Config,
}, nil }
if params.Config.NtfyTopic != "" {
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
if err != nil {
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
}
svc.ntfyURL = u
}
if params.Config.SlackWebhook != "" {
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
if err != nil {
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
}
svc.slackWebhookURL = u
}
if params.Config.MattermostWebhook != "" {
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
if err != nil {
return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err,
)
}
svc.mattermostWebhookURL = u
}
return svc, nil
} }
// SendNotification sends a notification to all configured endpoints. // SendNotification sends a notification to all configured endpoints.
@ -80,13 +104,13 @@ func (svc *Service) SendNotification(
ctx context.Context, ctx context.Context,
title, message, priority string, title, message, priority string,
) { ) {
if svc.config.NtfyTopic != "" { if svc.ntfyURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendNtfy( err := svc.sendNtfy(
notifyCtx, notifyCtx,
svc.config.NtfyTopic, svc.ntfyURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@ -98,13 +122,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.SlackWebhook != "" { if svc.slackWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.SlackWebhook, svc.slackWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@ -116,13 +140,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.MattermostWebhook != "" { if svc.mattermostWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.MattermostWebhook, svc.mattermostWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@ -137,23 +161,19 @@ func (svc *Service) SendNotification(
func (svc *Service) sendNtfy( func (svc *Service) sendNtfy(
ctx context.Context, ctx context.Context,
topic, title, message, priority string, topicURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending ntfy notification", "sending ntfy notification",
"topic", topic, "topic", topicURL.String(),
"title", title, "title", title,
) )
cleanURL, err := sanitizeURL(topic)
if err != nil {
return fmt.Errorf("invalid ntfy topic URL: %w", err)
}
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
cleanURL, topicURL.String(),
bytes.NewBufferString(message), bytes.NewBufferString(message),
) )
if err != nil { if err != nil {
@ -163,7 +183,7 @@ func (svc *Service) sendNtfy(
request.Header.Set("Title", title) request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority)) request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending ntfy request: %w", err) return fmt.Errorf("sending ntfy request: %w", err)
} }
@ -209,11 +229,12 @@ type SlackAttachment struct {
func (svc *Service) sendSlack( func (svc *Service) sendSlack(
ctx context.Context, ctx context.Context,
webhookURL, title, message, priority string, webhookURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending webhook notification", "sending webhook notification",
"url", webhookURL, "url", webhookURL.String(),
"title", title, "title", title,
) )
@ -232,15 +253,10 @@ func (svc *Service) sendSlack(
return fmt.Errorf("marshaling webhook payload: %w", err) return fmt.Errorf("marshaling webhook payload: %w", err)
} }
cleanURL, err := sanitizeURL(webhookURL)
if err != nil {
return fmt.Errorf("invalid webhook URL: %w", err)
}
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
cleanURL, webhookURL.String(),
bytes.NewBuffer(body), bytes.NewBuffer(body),
) )
if err != nil { if err != nil {
@ -249,7 +265,7 @@ func (svc *Service) sendSlack(
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending webhook request: %w", err) return fmt.Errorf("sending webhook request: %w", err)
} }

View File

@ -0,0 +1,48 @@
package resolver
import (
"context"
"time"
"github.com/miekg/dns"
)
// DNSClient abstracts DNS wire-protocol exchanges so the resolver
// can be tested without hitting real nameservers.
type DNSClient interface {
ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error)
}
// udpClient wraps a real dns.Client for production use.
type udpClient struct {
timeout time.Duration
}
func (c *udpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}
// tcpClient wraps a real dns.Client using TCP.
type tcpClient struct {
timeout time.Duration
}
func (c *tcpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Net: "tcp", Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}

View File

@ -50,25 +50,20 @@ func checkCtx(ctx context.Context) error {
return nil return nil
} }
func exchangeWithTimeout( func (r *Resolver) exchangeWithTimeout(
ctx context.Context, ctx context.Context,
msg *dns.Msg, msg *dns.Msg,
addr string, addr string,
attempt int, attempt int,
) (*dns.Msg, error) { ) (*dns.Msg, error) {
c := new(dns.Client) _ = attempt // timeout escalation handled by client config
c.Timeout = queryTimeoutDuration
if attempt > 0 { resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
c.Timeout = queryTimeoutDuration * timeoutMultiplier
}
resp, _, err := c.ExchangeContext(ctx, msg, addr)
return resp, err return resp, err
} }
func tryExchange( func (r *Resolver) tryExchange(
ctx context.Context, ctx context.Context,
msg *dns.Msg, msg *dns.Msg,
addr string, addr string,
@ -82,7 +77,9 @@ func tryExchange(
return nil, ErrContextCanceled return nil, ErrContextCanceled
} }
resp, err = exchangeWithTimeout(ctx, msg, addr, attempt) resp, err = r.exchangeWithTimeout(
ctx, msg, addr, attempt,
)
if err == nil { if err == nil {
break break
} }
@ -91,7 +88,7 @@ func tryExchange(
return resp, err return resp, err
} }
func retryTCP( func (r *Resolver) retryTCP(
ctx context.Context, ctx context.Context,
msg *dns.Msg, msg *dns.Msg,
addr string, addr string,
@ -101,12 +98,7 @@ func retryTCP(
return resp return resp
} }
c := &dns.Client{ tcpResp, _, tcpErr := r.tcp.ExchangeContext(ctx, msg, addr)
Net: "tcp",
Timeout: queryTimeoutDuration,
}
tcpResp, _, tcpErr := c.ExchangeContext(ctx, msg, addr)
if tcpErr == nil { if tcpErr == nil {
return tcpResp return tcpResp
} }
@ -117,7 +109,7 @@ func retryTCP(
// queryDNS sends a DNS query to a specific server IP. // queryDNS sends a DNS query to a specific server IP.
// Tries non-recursive first, falls back to recursive on // Tries non-recursive first, falls back to recursive on
// REFUSED (handles DNS interception environments). // REFUSED (handles DNS interception environments).
func queryDNS( func (r *Resolver) queryDNS(
ctx context.Context, ctx context.Context,
serverIP string, serverIP string,
name string, name string,
@ -134,7 +126,7 @@ func queryDNS(
msg.SetQuestion(name, qtype) msg.SetQuestion(name, qtype)
msg.RecursionDesired = false msg.RecursionDesired = false
resp, err := tryExchange(ctx, msg, addr) resp, err := r.tryExchange(ctx, msg, addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err) return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err)
} }
@ -142,7 +134,7 @@ func queryDNS(
if resp.Rcode == dns.RcodeRefused { if resp.Rcode == dns.RcodeRefused {
msg.RecursionDesired = true msg.RecursionDesired = true
resp, err = tryExchange(ctx, msg, addr) resp, err = r.tryExchange(ctx, msg, addr)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, err, "query %s @%s: %w", name, serverIP, err,
@ -156,7 +148,7 @@ func queryDNS(
} }
} }
resp = retryTCP(ctx, msg, addr, resp) resp = r.retryTCP(ctx, msg, addr, resp)
return resp, nil return resp, nil
} }
@ -221,7 +213,9 @@ func (r *Resolver) followDelegation(
return nil, ErrContextCanceled return nil, ErrContextCanceled
} }
resp, err := queryServers(ctx, servers, domain, dns.TypeNS) resp, err := r.queryServers(
ctx, servers, domain, dns.TypeNS,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -253,7 +247,7 @@ func (r *Resolver) followDelegation(
return nil, ErrNoNameservers return nil, ErrNoNameservers
} }
func queryServers( func (r *Resolver) queryServers(
ctx context.Context, ctx context.Context,
servers []string, servers []string,
name string, name string,
@ -266,7 +260,7 @@ func queryServers(
return nil, ErrContextCanceled return nil, ErrContextCanceled
} }
resp, err := queryDNS(ctx, ip, name, qtype) resp, err := r.queryDNS(ctx, ip, name, qtype)
if err == nil { if err == nil {
return resp, nil return resp, nil
} }
@ -308,8 +302,6 @@ func (r *Resolver) resolveNSRecursive(
msg.SetQuestion(domain, dns.TypeNS) msg.SetQuestion(domain, dns.TypeNS)
msg.RecursionDesired = true msg.RecursionDesired = true
c := &dns.Client{Timeout: queryTimeoutDuration}
for _, ip := range rootServerList()[:3] { for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil { if checkCtx(ctx) != nil {
return nil, ErrContextCanceled return nil, ErrContextCanceled
@ -317,7 +309,7 @@ func (r *Resolver) resolveNSRecursive(
addr := net.JoinHostPort(ip, "53") addr := net.JoinHostPort(ip, "53")
resp, _, err := c.ExchangeContext(ctx, msg, addr) resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil { if err != nil {
continue continue
} }
@ -341,8 +333,6 @@ func (r *Resolver) resolveARecord(
msg.SetQuestion(hostname, dns.TypeA) msg.SetQuestion(hostname, dns.TypeA)
msg.RecursionDesired = true msg.RecursionDesired = true
c := &dns.Client{Timeout: queryTimeoutDuration}
for _, ip := range rootServerList()[:3] { for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil { if checkCtx(ctx) != nil {
return nil, ErrContextCanceled return nil, ErrContextCanceled
@ -350,7 +340,7 @@ func (r *Resolver) resolveARecord(
addr := net.JoinHostPort(ip, "53") addr := net.JoinHostPort(ip, "53")
resp, _, err := c.ExchangeContext(ctx, msg, addr) resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil { if err != nil {
continue continue
} }
@ -490,7 +480,7 @@ func (r *Resolver) querySingleType(
resp *NameserverResponse, resp *NameserverResponse,
state *queryState, state *queryState,
) { ) {
msg, err := queryDNS(ctx, nsIP, hostname, qtype) msg, err := r.queryDNS(ctx, nsIP, hostname, qtype)
if err != nil { if err != nil {
return return
} }
@ -641,6 +631,25 @@ func (r *Resolver) LookupNS(
return r.FindAuthoritativeNameservers(ctx, domain) return r.FindAuthoritativeNameservers(ctx, domain)
} }
// LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname, keyed by authoritative nameserver.
func (r *Resolver) LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error) {
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
out := make(map[string]map[string][]string, len(results))
for ns, resp := range results {
out[ns] = resp.Records
}
return out, nil
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6 // ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains up to MaxCNAMEDepth. // addresses, following CNAME chains up to MaxCNAMEDepth.
func (r *Resolver) ResolveIPAddresses( func (r *Resolver) ResolveIPAddresses(

View File

@ -39,7 +39,9 @@ type NameserverResponse struct {
// Resolver performs iterative DNS resolution from root servers. // Resolver performs iterative DNS resolution from root servers.
type Resolver struct { type Resolver struct {
log *slog.Logger log *slog.Logger
client DNSClient
tcp DNSClient
} }
// New creates a new Resolver instance for use with uber/fx. // New creates a new Resolver instance for use with uber/fx.
@ -48,14 +50,33 @@ func New(
params Params, params Params,
) (*Resolver, error) { ) (*Resolver, error) {
return &Resolver{ return &Resolver{
log: params.Logger.Get(), log: params.Logger.Get(),
client: &udpClient{timeout: queryTimeoutDuration},
tcp: &tcpClient{timeout: queryTimeoutDuration},
}, nil }, nil
} }
// NewFromLogger creates a Resolver directly from an slog.Logger, // NewFromLogger creates a Resolver directly from an slog.Logger,
// useful for testing without the fx lifecycle. // useful for testing without the fx lifecycle.
func NewFromLogger(log *slog.Logger) *Resolver { func NewFromLogger(log *slog.Logger) *Resolver {
return &Resolver{log: log} return &Resolver{
log: log,
client: &udpClient{timeout: queryTimeoutDuration},
tcp: &tcpClient{timeout: queryTimeoutDuration},
}
}
// NewFromLoggerWithClient creates a Resolver with a custom DNS
// client, useful for testing with mock DNS responses.
func NewFromLoggerWithClient(
log *slog.Logger,
client DNSClient,
) *Resolver {
return &Resolver{
log: log,
client: client,
tcp: client,
}
} }
// Method implementations are in iterative.go. // Method implementations are in iterative.go.

View File

@ -0,0 +1,85 @@
//go:build integration
package resolver_test
import (
"context"
"log/slog"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/resolver"
)
// Integration tests hit real DNS servers. Run with:
// go test -tags integration -timeout 60s ./internal/resolver/
func newIntegrationResolver(t *testing.T) *resolver.Resolver {
t.Helper()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
return resolver.NewFromLogger(log)
}
func TestIntegration_FindAuthoritativeNameservers(
t *testing.T,
) {
t.Parallel()
r := newIntegrationResolver(t)
ctx, cancel := context.WithTimeout(
context.Background(), 30*time.Second,
)
defer cancel()
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "example.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
t.Logf("example.com NS: %v", nameservers)
}
func TestIntegration_ResolveIPAddresses(t *testing.T) {
t.Parallel()
r := newIntegrationResolver(t)
ctx, cancel := context.WithTimeout(
context.Background(), 30*time.Second,
)
defer cancel()
// sneak.cloud is on Cloudflare
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "sneak.cloud",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
hasCloudflare := false
for _, ns := range nameservers {
if strings.Contains(ns, "cloudflare") {
hasCloudflare = true
break
}
}
assert.True(t, hasCloudflare,
"sneak.cloud should be on Cloudflare, got: %v",
nameservers,
)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,22 @@
package state
import (
"log/slog"
"sneak.berlin/go/dnswatcher/internal/config"
)
// NewForTest creates a State for unit testing with no persistence.
func NewForTest() *State {
return &State{
log: slog.Default(),
snapshot: &Snapshot{
Version: stateVersion,
Domains: make(map[string]*DomainState),
Hostnames: make(map[string]*HostnameState),
Ports: make(map[string]*PortState),
Certificates: make(map[string]*CertificateState),
},
config: &config.Config{DataDir: ""},
}
}

View File

@ -0,0 +1,60 @@
// Package watcher provides the main monitoring orchestrator.
package watcher
import (
"context"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
// DNSResolver performs iterative DNS resolution.
type DNSResolver interface {
// LookupNS discovers authoritative nameservers for a domain.
LookupNS(
ctx context.Context,
domain string,
) ([]string, error)
// LookupAllRecords queries all record types for a hostname,
// returning results keyed by nameserver then record type.
LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error)
// ResolveIPAddresses resolves a hostname to all IP addresses.
ResolveIPAddresses(
ctx context.Context,
hostname string,
) ([]string, error)
}
// PortChecker tests TCP port connectivity.
type PortChecker interface {
// CheckPort tests TCP connectivity to an address and port.
CheckPort(
ctx context.Context,
address string,
port int,
) (bool, error)
}
// TLSChecker inspects TLS certificates.
type TLSChecker interface {
// CheckCertificate connects via TLS and returns cert info.
CheckCertificate(
ctx context.Context,
ip string,
hostname string,
) (*tlscheck.CertificateInfo, error)
}
// Notifier delivers notifications to configured endpoints.
type Notifier interface {
// SendNotification sends a notification with the given
// details.
SendNotification(
ctx context.Context,
title, message, priority string,
)
}

View File

@ -1,21 +1,30 @@
// Package watcher provides the main monitoring orchestrator and scheduler.
package watcher package watcher
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"sort"
"strings"
"time"
"go.uber.org/fx" "go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
"sneak.berlin/go/dnswatcher/internal/notify"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/resolver"
"sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/tlscheck"
) )
// monitoredPorts are the TCP ports checked for each IP address.
var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals
// tlsPort is the port used for TLS certificate checks.
const tlsPort = 443
// hoursPerDay converts days to hours for duration calculations.
const hoursPerDay = 24
// Params contains dependencies for Watcher. // Params contains dependencies for Watcher.
type Params struct { type Params struct {
fx.In fx.In
@ -23,10 +32,10 @@ type Params struct {
Logger *logger.Logger Logger *logger.Logger
Config *config.Config Config *config.Config
State *state.State State *state.State
Resolver *resolver.Resolver Resolver DNSResolver
PortCheck *portcheck.Checker PortCheck PortChecker
TLSCheck *tlscheck.Checker TLSCheck TLSChecker
Notify *notify.Service Notify Notifier
} }
// Watcher orchestrates all monitoring checks on a schedule. // Watcher orchestrates all monitoring checks on a schedule.
@ -34,19 +43,20 @@ type Watcher struct {
log *slog.Logger log *slog.Logger
config *config.Config config *config.Config
state *state.State state *state.State
resolver *resolver.Resolver resolver DNSResolver
portCheck *portcheck.Checker portCheck PortChecker
tlsCheck *tlscheck.Checker tlsCheck TLSChecker
notify *notify.Service notify Notifier
cancel context.CancelFunc cancel context.CancelFunc
firstRun bool
} }
// New creates a new Watcher instance. // New creates a new Watcher instance wired into the fx lifecycle.
func New( func New(
lifecycle fx.Lifecycle, lifecycle fx.Lifecycle,
params Params, params Params,
) (*Watcher, error) { ) (*Watcher, error) {
watcher := &Watcher{ w := &Watcher{
log: params.Logger.Get(), log: params.Logger.Get(),
config: params.Config, config: params.Config,
state: params.State, state: params.State,
@ -54,30 +64,54 @@ func New(
portCheck: params.PortCheck, portCheck: params.PortCheck,
tlsCheck: params.TLSCheck, tlsCheck: params.TLSCheck,
notify: params.Notify, notify: params.Notify,
firstRun: true,
} }
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{
OnStart: func(startCtx context.Context) error { OnStart: func(startCtx context.Context) error {
ctx, cancel := context.WithCancel(startCtx) ctx, cancel := context.WithCancel(
watcher.cancel = cancel context.WithoutCancel(startCtx),
)
w.cancel = cancel
go watcher.Run(ctx) go w.Run(ctx)
return nil return nil
}, },
OnStop: func(_ context.Context) error { OnStop: func(_ context.Context) error {
if watcher.cancel != nil { if w.cancel != nil {
watcher.cancel() w.cancel()
} }
return nil return nil
}, },
}) })
return watcher, nil return w, nil
} }
// Run starts the monitoring loop. // NewForTest creates a Watcher without fx for unit testing.
func NewForTest(
cfg *config.Config,
st *state.State,
res DNSResolver,
pc PortChecker,
tc TLSChecker,
n Notifier,
) *Watcher {
return &Watcher{
log: slog.Default(),
config: cfg,
state: st,
resolver: res,
portCheck: pc,
tlsCheck: tc,
notify: n,
firstRun: true,
}
}
// Run starts the monitoring loop with periodic scheduling.
func (w *Watcher) Run(ctx context.Context) { func (w *Watcher) Run(ctx context.Context) {
w.log.Info( w.log.Info(
"watcher starting", "watcher starting",
@ -87,8 +121,646 @@ func (w *Watcher) Run(ctx context.Context) {
"tlsInterval", w.config.TLSInterval, "tlsInterval", w.config.TLSInterval,
) )
// Stub: wait for context cancellation. w.RunOnce(ctx)
// Implementation will add initial check + periodic scheduling.
<-ctx.Done() dnsTicker := time.NewTicker(w.config.DNSInterval)
w.log.Info("watcher stopped") tlsTicker := time.NewTicker(w.config.TLSInterval)
defer dnsTicker.Stop()
defer tlsTicker.Stop()
for {
select {
case <-ctx.Done():
w.log.Info("watcher stopped")
return
case <-dnsTicker.C:
w.runDNSAndPortChecks(ctx)
w.saveState()
case <-tlsTicker.C:
w.runTLSChecks(ctx)
w.saveState()
}
}
}
// RunOnce performs a single complete monitoring cycle.
func (w *Watcher) RunOnce(ctx context.Context) {
w.detectFirstRun()
w.runDNSAndPortChecks(ctx)
w.runTLSChecks(ctx)
w.saveState()
w.firstRun = false
}
func (w *Watcher) detectFirstRun() {
snap := w.state.GetSnapshot()
hasState := len(snap.Domains) > 0 ||
len(snap.Hostnames) > 0 ||
len(snap.Ports) > 0 ||
len(snap.Certificates) > 0
if hasState {
w.firstRun = false
}
}
func (w *Watcher) runDNSAndPortChecks(ctx context.Context) {
for _, domain := range w.config.Domains {
w.checkDomain(ctx, domain)
}
for _, hostname := range w.config.Hostnames {
w.checkHostname(ctx, hostname)
}
w.checkAllPorts(ctx)
}
func (w *Watcher) checkDomain(
ctx context.Context,
domain string,
) {
nameservers, err := w.resolver.LookupNS(ctx, domain)
if err != nil {
w.log.Error(
"failed to lookup NS",
"domain", domain,
"error", err,
)
return
}
sort.Strings(nameservers)
now := time.Now().UTC()
prev, hasPrev := w.state.GetDomainState(domain)
if hasPrev && !w.firstRun {
w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers)
}
w.state.SetDomainState(domain, &state.DomainState{
Nameservers: nameservers,
LastChecked: now,
})
}
func (w *Watcher) detectNSChanges(
ctx context.Context,
domain string,
oldNS, newNS []string,
) {
oldSet := toSet(oldNS)
newSet := toSet(newNS)
var added, removed []string
for ns := range newSet {
if !oldSet[ns] {
added = append(added, ns)
}
}
for ns := range oldSet {
if !newSet[ns] {
removed = append(removed, ns)
}
}
if len(added) == 0 && len(removed) == 0 {
return
}
msg := fmt.Sprintf(
"Domain: %s\nAdded: %s\nRemoved: %s",
domain,
strings.Join(added, ", "),
strings.Join(removed, ", "),
)
w.notify.SendNotification(
ctx,
"NS Change: "+domain,
msg,
"warning",
)
}
func (w *Watcher) checkHostname(
ctx context.Context,
hostname string,
) {
records, err := w.resolver.LookupAllRecords(ctx, hostname)
if err != nil {
w.log.Error(
"failed to lookup records",
"hostname", hostname,
"error", err,
)
return
}
now := time.Now().UTC()
prev, hasPrev := w.state.GetHostnameState(hostname)
if hasPrev && !w.firstRun {
w.detectHostnameChanges(ctx, hostname, prev, records)
}
newState := buildHostnameState(records, now)
w.state.SetHostnameState(hostname, newState)
}
func buildHostnameState(
records map[string]map[string][]string,
now time.Time,
) *state.HostnameState {
hs := &state.HostnameState{
RecordsByNameserver: make(
map[string]*state.NameserverRecordState,
),
LastChecked: now,
}
for ns, recs := range records {
hs.RecordsByNameserver[ns] = &state.NameserverRecordState{
Records: recs,
Status: "ok",
LastChecked: now,
}
}
return hs
}
func (w *Watcher) detectHostnameChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
w.detectRecordChanges(ctx, hostname, prev, current)
w.detectNSDisappearances(ctx, hostname, prev, current)
w.detectInconsistencies(ctx, hostname, current)
}
func (w *Watcher) detectRecordChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, recs := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok {
continue
}
if recordsEqual(prevNS.Records, recs) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s\n"+
"Old: %v\nNew: %v",
hostname, ns,
prevNS.Records, recs,
)
w.notify.SendNotification(
ctx,
"Record Change: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) detectNSDisappearances(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, prevNS := range prev.RecordsByNameserver {
if _, ok := current[ns]; ok || prevNS.Status != "ok" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s disappeared",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Failure: "+hostname,
msg,
"error",
)
}
for ns := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok || prevNS.Status != "error" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s recovered",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Recovery: "+hostname,
msg,
"success",
)
}
}
func (w *Watcher) detectInconsistencies(
ctx context.Context,
hostname string,
current map[string]map[string][]string,
) {
nameservers := make([]string, 0, len(current))
for ns := range current {
nameservers = append(nameservers, ns)
}
sort.Strings(nameservers)
for i := range len(nameservers) - 1 {
ns1 := nameservers[i]
ns2 := nameservers[i+1]
if recordsEqual(current[ns1], current[ns2]) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\n%s: %v\n%s: %v",
hostname,
ns1, current[ns1],
ns2, current[ns2],
)
w.notify.SendNotification(
ctx,
"Inconsistency: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) checkAllPorts(ctx context.Context) {
for _, hostname := range w.config.Hostnames {
w.checkPortsForHostname(ctx, hostname)
}
for _, domain := range w.config.Domains {
w.checkPortsForHostname(ctx, domain)
}
}
func (w *Watcher) checkPortsForHostname(
ctx context.Context,
hostname string,
) {
ips := w.collectIPs(hostname)
for _, ip := range ips {
for _, port := range monitoredPorts {
w.checkSinglePort(ctx, ip, port, hostname)
}
}
}
func (w *Watcher) collectIPs(hostname string) []string {
hs, ok := w.state.GetHostnameState(hostname)
if !ok {
return nil
}
ipSet := make(map[string]bool)
for _, nsState := range hs.RecordsByNameserver {
for _, ip := range nsState.Records["A"] {
ipSet[ip] = true
}
for _, ip := range nsState.Records["AAAA"] {
ipSet[ip] = true
}
}
result := make([]string, 0, len(ipSet))
for ip := range ipSet {
result = append(result, ip)
}
sort.Strings(result)
return result
}
func (w *Watcher) checkSinglePort(
ctx context.Context,
ip string,
port int,
hostname string,
) {
open, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil {
w.log.Error(
"port check failed",
"ip", ip,
"port", port,
"error", err,
)
return
}
key := fmt.Sprintf("%s:%d", ip, port)
now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != open {
stateStr := "closed"
if open {
stateStr = "open"
}
msg := fmt.Sprintf(
"Host: %s\nAddress: %s\nPort now %s",
hostname, key, stateStr,
)
w.notify.SendNotification(
ctx,
"Port Change: "+key,
msg,
"warning",
)
}
w.state.SetPortState(key, &state.PortState{
Open: open,
Hostname: hostname,
LastChecked: now,
})
}
func (w *Watcher) runTLSChecks(ctx context.Context) {
for _, hostname := range w.config.Hostnames {
w.checkTLSForHostname(ctx, hostname)
}
for _, domain := range w.config.Domains {
w.checkTLSForHostname(ctx, domain)
}
}
func (w *Watcher) checkTLSForHostname(
ctx context.Context,
hostname string,
) {
ips := w.collectIPs(hostname)
for _, ip := range ips {
portKey := fmt.Sprintf("%s:%d", ip, tlsPort)
ps, ok := w.state.GetPortState(portKey)
if !ok || !ps.Open {
continue
}
w.checkTLSCert(ctx, ip, hostname)
}
}
func (w *Watcher) checkTLSCert(
ctx context.Context,
ip string,
hostname string,
) {
cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname)
certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname)
now := time.Now().UTC()
prev, hasPrev := w.state.GetCertificateState(certKey)
if err != nil {
w.handleTLSError(
ctx, certKey, hostname, ip,
hasPrev, prev, now, err,
)
return
}
w.handleTLSSuccess(
ctx, certKey, hostname, ip,
hasPrev, prev, now, cert,
)
}
func (w *Watcher) handleTLSError(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
err error,
) {
if hasPrev && !w.firstRun && prev.Status == "ok" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nError: %s",
hostname, ip, err,
)
w.notify.SendNotification(
ctx,
"TLS Failure: "+hostname,
msg,
"error",
)
}
w.state.SetCertificateState(
certKey, &state.CertificateState{
Status: "error",
Error: err.Error(),
LastChecked: now,
},
)
}
func (w *Watcher) handleTLSSuccess(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
cert *tlscheck.CertificateInfo,
) {
if hasPrev && !w.firstRun {
w.detectTLSChanges(ctx, hostname, ip, prev, cert)
}
w.checkTLSExpiry(ctx, hostname, ip, cert)
w.state.SetCertificateState(
certKey, &state.CertificateState{
CommonName: cert.CommonName,
Issuer: cert.Issuer,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: cert.SubjectAlternativeNames,
Status: "ok",
LastChecked: now,
},
)
}
func (w *Watcher) detectTLSChanges(
ctx context.Context,
hostname, ip string,
prev *state.CertificateState,
cert *tlscheck.CertificateInfo,
) {
if prev.Status == "error" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nTLS recovered",
hostname, ip,
)
w.notify.SendNotification(
ctx,
"TLS Recovery: "+hostname,
msg,
"success",
)
return
}
changed := prev.CommonName != cert.CommonName ||
prev.Issuer != cert.Issuer ||
!sliceEqual(
prev.SubjectAlternativeNames,
cert.SubjectAlternativeNames,
)
if !changed {
return
}
msg := fmt.Sprintf(
"Host: %s\nIP: %s\n"+
"Old CN: %s, Issuer: %s\n"+
"New CN: %s, Issuer: %s",
hostname, ip,
prev.CommonName, prev.Issuer,
cert.CommonName, cert.Issuer,
)
w.notify.SendNotification(
ctx,
"TLS Certificate Change: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) checkTLSExpiry(
ctx context.Context,
hostname, ip string,
cert *tlscheck.CertificateInfo,
) {
daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay
warningDays := float64(w.config.TLSExpiryWarning)
if daysLeft > warningDays {
return
}
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nCN: %s\n"+
"Expires: %s (%.0f days)",
hostname, ip, cert.CommonName,
cert.NotAfter.Format(time.RFC3339),
daysLeft,
)
w.notify.SendNotification(
ctx,
"TLS Expiry Warning: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) saveState() {
err := w.state.Save()
if err != nil {
w.log.Error("failed to save state", "error", err)
}
}
// --- Utility functions ---
func toSet(items []string) map[string]bool {
set := make(map[string]bool, len(items))
for _, item := range items {
set[item] = true
}
return set
}
func recordsEqual(
a, b map[string][]string,
) bool {
if len(a) != len(b) {
return false
}
for k, av := range a {
bv, ok := b[k]
if !ok || !sliceEqual(av, bv) {
return false
}
}
return true
}
func sliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
aSorted := make([]string, len(a))
bSorted := make([]string, len(b))
copy(aSorted, a)
copy(bSorted, b)
sort.Strings(aSorted)
sort.Strings(bSorted)
for i := range aSorted {
if aSorted[i] != bSorted[i] {
return false
}
}
return true
} }

View File

@ -0,0 +1,580 @@
package watcher_test
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
"sneak.berlin/go/dnswatcher/internal/watcher"
)
// errNotFound is returned when mock data is missing.
var errNotFound = errors.New("not found")
// --- Mock implementations ---
type mockResolver struct {
mu sync.Mutex
nsRecords map[string][]string
allRecords map[string]map[string]map[string][]string
ipAddresses map[string][]string
lookupNSErr error
allRecordsErr error
resolveIPErr error
lookupNSCalls int
allRecordCalls int
}
func (m *mockResolver) LookupNS(
_ context.Context,
domain string,
) ([]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.lookupNSCalls++
if m.lookupNSErr != nil {
return nil, m.lookupNSErr
}
ns, ok := m.nsRecords[domain]
if !ok {
return nil, fmt.Errorf(
"%w: NS for %s", errNotFound, domain,
)
}
return ns, nil
}
func (m *mockResolver) LookupAllRecords(
_ context.Context,
hostname string,
) (map[string]map[string][]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.allRecordCalls++
if m.allRecordsErr != nil {
return nil, m.allRecordsErr
}
recs, ok := m.allRecords[hostname]
if !ok {
return nil, fmt.Errorf(
"%w: records for %s", errNotFound, hostname,
)
}
return recs, nil
}
func (m *mockResolver) ResolveIPAddresses(
_ context.Context,
hostname string,
) ([]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.resolveIPErr != nil {
return nil, m.resolveIPErr
}
ips, ok := m.ipAddresses[hostname]
if !ok {
return nil, fmt.Errorf(
"%w: IPs for %s", errNotFound, hostname,
)
}
return ips, nil
}
type mockPortChecker struct {
mu sync.Mutex
results map[string]bool
err error
calls int
}
func (m *mockPortChecker) CheckPort(
_ context.Context,
address string,
port int,
) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return false, m.err
}
key := fmt.Sprintf("%s:%d", address, port)
open, ok := m.results[key]
if !ok {
return false, nil
}
return open, nil
}
type mockTLSChecker struct {
mu sync.Mutex
certs map[string]*tlscheck.CertificateInfo
err error
calls int
}
func (m *mockTLSChecker) CheckCertificate(
_ context.Context,
ip string,
hostname string,
) (*tlscheck.CertificateInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return nil, m.err
}
key := fmt.Sprintf("%s:%s", ip, hostname)
cert, ok := m.certs[key]
if !ok {
return nil, fmt.Errorf(
"%w: cert for %s", errNotFound, key,
)
}
return cert, nil
}
type notification struct {
Title string
Message string
Priority string
}
type mockNotifier struct {
mu sync.Mutex
notifications []notification
}
func (m *mockNotifier) SendNotification(
_ context.Context,
title, message, priority string,
) {
m.mu.Lock()
defer m.mu.Unlock()
m.notifications = append(m.notifications, notification{
Title: title,
Message: message,
Priority: priority,
})
}
func (m *mockNotifier) getNotifications() []notification {
m.mu.Lock()
defer m.mu.Unlock()
result := make([]notification, len(m.notifications))
copy(result, m.notifications)
return result
}
// --- Helper to build a Watcher for testing ---
type testDeps struct {
resolver *mockResolver
portChecker *mockPortChecker
tlsChecker *mockTLSChecker
notifier *mockNotifier
state *state.State
config *config.Config
}
func newTestWatcher(
t *testing.T,
cfg *config.Config,
) (*watcher.Watcher, *testDeps) {
t.Helper()
deps := &testDeps{
resolver: &mockResolver{
nsRecords: make(map[string][]string),
allRecords: make(map[string]map[string]map[string][]string),
ipAddresses: make(map[string][]string),
},
portChecker: &mockPortChecker{
results: make(map[string]bool),
},
tlsChecker: &mockTLSChecker{
certs: make(map[string]*tlscheck.CertificateInfo),
},
notifier: &mockNotifier{},
config: cfg,
}
deps.state = state.NewForTest()
w := watcher.NewForTest(
deps.config,
deps.state,
deps.resolver,
deps.portChecker,
deps.tlsChecker,
deps.notifier,
)
return w, deps
}
func defaultTestConfig(t *testing.T) *config.Config {
t.Helper()
return &config.Config{
DNSInterval: time.Hour,
TLSInterval: 12 * time.Hour,
TLSExpiryWarning: 7,
DataDir: t.TempDir(),
}
}
func TestFirstRunBaseline(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
setupBaselineMocks(deps)
w.RunOnce(t.Context())
assertNoNotifications(t, deps)
assertStatePopulated(t, deps)
}
func setupBaselineMocks(deps *testDeps) {
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns2.example.com.",
}
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
"ns2.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.34",
}
deps.portChecker.results["93.184.216.34:80"] = true
deps.portChecker.results["93.184.216.34:443"] = true
deps.tlsChecker.certs["93.184.216.34:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
}
func assertNoNotifications(
t *testing.T,
deps *testDeps,
) {
t.Helper()
notifications := deps.notifier.getNotifications()
if len(notifications) != 0 {
t.Errorf(
"expected 0 notifications on first run, got %d",
len(notifications),
)
}
}
func assertStatePopulated(
t *testing.T,
deps *testDeps,
) {
t.Helper()
snap := deps.state.GetSnapshot()
if len(snap.Domains) != 1 {
t.Errorf(
"expected 1 domain in state, got %d",
len(snap.Domains),
)
}
if len(snap.Hostnames) != 1 {
t.Errorf(
"expected 1 hostname in state, got %d",
len(snap.Hostnames),
)
}
}
func TestNSChangeDetection(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns2.example.com.",
}
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns3.example.com.",
}
deps.resolver.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for NS change")
}
found := false
for _, n := range notifications {
if n.Priority == "warning" {
found = true
}
}
if !found {
t.Error("expected warning-priority NS change notification")
}
}
func TestRecordChangeDetection(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.34",
}
deps.portChecker.results["93.184.216.34:80"] = false
deps.portChecker.results["93.184.216.34:443"] = false
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.35"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.35",
}
deps.resolver.mu.Unlock()
deps.portChecker.mu.Lock()
deps.portChecker.results["93.184.216.35:80"] = false
deps.portChecker.results["93.184.216.35:443"] = false
deps.portChecker.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for record change")
}
}
func TestPortStateChange(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
w.RunOnce(ctx)
deps.portChecker.mu.Lock()
deps.portChecker.results["1.2.3.4:443"] = false
deps.portChecker.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for port state change")
}
}
func TestTLSExpiryWarning(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(3 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
// First run = baseline
w.RunOnce(ctx)
// Second run should warn about expiry
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
found := false
for _, n := range notifications {
if n.Priority == "warning" {
found = true
}
}
if !found {
t.Errorf(
"expected expiry warning, got: %v",
notifications,
)
}
}
func TestGracefulShutdown(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.DNSInterval = 100 * time.Millisecond
cfg.TLSInterval = 100 * time.Millisecond
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
}
ctx, cancel := context.WithCancel(t.Context())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
time.Sleep(250 * time.Millisecond)
cancel()
select {
case <-done:
// Shut down cleanly
case <-time.After(5 * time.Second):
t.Error("watcher did not shut down within timeout")
}
}
func TestNSFailureAndRecovery(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
"ns2.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = false
deps.portChecker.results["1.2.3.4:443"] = false
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for NS disappearance")
}
}