From dea30028b1ba38c8d30bf0b209fe687979a1d3d8 Mon Sep 17 00:00:00 2001 From: clawbot Date: Thu, 19 Feb 2026 13:48:38 -0800 Subject: [PATCH] test: add watcher orchestrator tests with mock dependencies Tests cover: first-run baseline, NS change detection, record change detection, port state changes, TLS expiry warnings, graceful shutdown, and NS failure/recovery scenarios. --- internal/state/state_test_helper.go | 22 ++ internal/watcher/interfaces.go | 60 +++ internal/watcher/watcher_test.go | 580 ++++++++++++++++++++++++++++ 3 files changed, 662 insertions(+) create mode 100644 internal/state/state_test_helper.go create mode 100644 internal/watcher/interfaces.go create mode 100644 internal/watcher/watcher_test.go diff --git a/internal/state/state_test_helper.go b/internal/state/state_test_helper.go new file mode 100644 index 0000000..2142bfb --- /dev/null +++ b/internal/state/state_test_helper.go @@ -0,0 +1,22 @@ +package state + +import ( + "log/slog" + + "sneak.berlin/go/dnswatcher/internal/config" +) + +// NewForTest creates a State for unit testing with no persistence. +func NewForTest() *State { + return &State{ + log: slog.Default(), + snapshot: &Snapshot{ + Version: stateVersion, + Domains: make(map[string]*DomainState), + Hostnames: make(map[string]*HostnameState), + Ports: make(map[string]*PortState), + Certificates: make(map[string]*CertificateState), + }, + config: &config.Config{DataDir: ""}, + } +} diff --git a/internal/watcher/interfaces.go b/internal/watcher/interfaces.go new file mode 100644 index 0000000..695139d --- /dev/null +++ b/internal/watcher/interfaces.go @@ -0,0 +1,60 @@ +// Package watcher provides the main monitoring orchestrator. +package watcher + +import ( + "context" + + "sneak.berlin/go/dnswatcher/internal/tlscheck" +) + +// DNSResolver performs iterative DNS resolution. +type DNSResolver interface { + // LookupNS discovers authoritative nameservers for a domain. + LookupNS( + ctx context.Context, + domain string, + ) ([]string, error) + + // LookupAllRecords queries all record types for a hostname, + // returning results keyed by nameserver then record type. + LookupAllRecords( + ctx context.Context, + hostname string, + ) (map[string]map[string][]string, error) + + // ResolveIPAddresses resolves a hostname to all IP addresses. + ResolveIPAddresses( + ctx context.Context, + hostname string, + ) ([]string, error) +} + +// PortChecker tests TCP port connectivity. +type PortChecker interface { + // CheckPort tests TCP connectivity to an address and port. + CheckPort( + ctx context.Context, + address string, + port int, + ) (bool, error) +} + +// TLSChecker inspects TLS certificates. +type TLSChecker interface { + // CheckCertificate connects via TLS and returns cert info. + CheckCertificate( + ctx context.Context, + ip string, + hostname string, + ) (*tlscheck.CertificateInfo, error) +} + +// Notifier delivers notifications to configured endpoints. +type Notifier interface { + // SendNotification sends a notification with the given + // details. + SendNotification( + ctx context.Context, + title, message, priority string, + ) +} diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go new file mode 100644 index 0000000..69772aa --- /dev/null +++ b/internal/watcher/watcher_test.go @@ -0,0 +1,580 @@ +package watcher_test + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "sneak.berlin/go/dnswatcher/internal/config" + "sneak.berlin/go/dnswatcher/internal/state" + "sneak.berlin/go/dnswatcher/internal/tlscheck" + "sneak.berlin/go/dnswatcher/internal/watcher" +) + +// errNotFound is returned when mock data is missing. +var errNotFound = errors.New("not found") + +// --- Mock implementations --- + +type mockResolver struct { + mu sync.Mutex + nsRecords map[string][]string + allRecords map[string]map[string]map[string][]string + ipAddresses map[string][]string + lookupNSErr error + allRecordsErr error + resolveIPErr error + lookupNSCalls int + allRecordCalls int +} + +func (m *mockResolver) LookupNS( + _ context.Context, + domain string, +) ([]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.lookupNSCalls++ + + if m.lookupNSErr != nil { + return nil, m.lookupNSErr + } + + ns, ok := m.nsRecords[domain] + if !ok { + return nil, fmt.Errorf( + "%w: NS for %s", errNotFound, domain, + ) + } + + return ns, nil +} + +func (m *mockResolver) LookupAllRecords( + _ context.Context, + hostname string, +) (map[string]map[string][]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.allRecordCalls++ + + if m.allRecordsErr != nil { + return nil, m.allRecordsErr + } + + recs, ok := m.allRecords[hostname] + if !ok { + return nil, fmt.Errorf( + "%w: records for %s", errNotFound, hostname, + ) + } + + return recs, nil +} + +func (m *mockResolver) ResolveIPAddresses( + _ context.Context, + hostname string, +) ([]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.resolveIPErr != nil { + return nil, m.resolveIPErr + } + + ips, ok := m.ipAddresses[hostname] + if !ok { + return nil, fmt.Errorf( + "%w: IPs for %s", errNotFound, hostname, + ) + } + + return ips, nil +} + +type mockPortChecker struct { + mu sync.Mutex + results map[string]bool + err error + calls int +} + +func (m *mockPortChecker) CheckPort( + _ context.Context, + address string, + port int, +) (bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.calls++ + + if m.err != nil { + return false, m.err + } + + key := fmt.Sprintf("%s:%d", address, port) + open, ok := m.results[key] + + if !ok { + return false, nil + } + + return open, nil +} + +type mockTLSChecker struct { + mu sync.Mutex + certs map[string]*tlscheck.CertificateInfo + err error + calls int +} + +func (m *mockTLSChecker) CheckCertificate( + _ context.Context, + ip string, + hostname string, +) (*tlscheck.CertificateInfo, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.calls++ + + if m.err != nil { + return nil, m.err + } + + key := fmt.Sprintf("%s:%s", ip, hostname) + cert, ok := m.certs[key] + + if !ok { + return nil, fmt.Errorf( + "%w: cert for %s", errNotFound, key, + ) + } + + return cert, nil +} + +type notification struct { + Title string + Message string + Priority string +} + +type mockNotifier struct { + mu sync.Mutex + notifications []notification +} + +func (m *mockNotifier) SendNotification( + _ context.Context, + title, message, priority string, +) { + m.mu.Lock() + defer m.mu.Unlock() + + m.notifications = append(m.notifications, notification{ + Title: title, + Message: message, + Priority: priority, + }) +} + +func (m *mockNotifier) getNotifications() []notification { + m.mu.Lock() + defer m.mu.Unlock() + + result := make([]notification, len(m.notifications)) + copy(result, m.notifications) + + return result +} + +// --- Helper to build a Watcher for testing --- + +type testDeps struct { + resolver *mockResolver + portChecker *mockPortChecker + tlsChecker *mockTLSChecker + notifier *mockNotifier + state *state.State + config *config.Config +} + +func newTestWatcher( + t *testing.T, + cfg *config.Config, +) (*watcher.Watcher, *testDeps) { + t.Helper() + + deps := &testDeps{ + resolver: &mockResolver{ + nsRecords: make(map[string][]string), + allRecords: make(map[string]map[string]map[string][]string), + ipAddresses: make(map[string][]string), + }, + portChecker: &mockPortChecker{ + results: make(map[string]bool), + }, + tlsChecker: &mockTLSChecker{ + certs: make(map[string]*tlscheck.CertificateInfo), + }, + notifier: &mockNotifier{}, + config: cfg, + } + + deps.state = state.NewForTest() + + w := watcher.NewForTest( + deps.config, + deps.state, + deps.resolver, + deps.portChecker, + deps.tlsChecker, + deps.notifier, + ) + + return w, deps +} + +func defaultTestConfig(t *testing.T) *config.Config { + t.Helper() + + return &config.Config{ + DNSInterval: time.Hour, + TLSInterval: 12 * time.Hour, + TLSExpiryWarning: 7, + DataDir: t.TempDir(), + } +} + +func TestFirstRunBaseline(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Domains = []string{"example.com"} + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + setupBaselineMocks(deps) + + w.RunOnce(t.Context()) + + assertNoNotifications(t, deps) + assertStatePopulated(t, deps) +} + +func setupBaselineMocks(deps *testDeps) { + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + "ns2.example.com.", + } + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"93.184.216.34"}}, + "ns2.example.com.": {"A": {"93.184.216.34"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "93.184.216.34", + } + deps.portChecker.results["93.184.216.34:80"] = true + deps.portChecker.results["93.184.216.34:443"] = true + deps.tlsChecker.certs["93.184.216.34:www.example.com"] = &tlscheck.CertificateInfo{ + CommonName: "www.example.com", + Issuer: "DigiCert", + NotAfter: time.Now().Add(90 * 24 * time.Hour), + SubjectAlternativeNames: []string{ + "www.example.com", + }, + } +} + +func assertNoNotifications( + t *testing.T, + deps *testDeps, +) { + t.Helper() + + notifications := deps.notifier.getNotifications() + if len(notifications) != 0 { + t.Errorf( + "expected 0 notifications on first run, got %d", + len(notifications), + ) + } +} + +func assertStatePopulated( + t *testing.T, + deps *testDeps, +) { + t.Helper() + + snap := deps.state.GetSnapshot() + + if len(snap.Domains) != 1 { + t.Errorf( + "expected 1 domain in state, got %d", + len(snap.Domains), + ) + } + + if len(snap.Hostnames) != 1 { + t.Errorf( + "expected 1 hostname in state, got %d", + len(snap.Hostnames), + ) + } +} + +func TestNSChangeDetection(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Domains = []string{"example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + "ns2.example.com.", + } + + ctx := t.Context() + w.RunOnce(ctx) + + deps.resolver.mu.Lock() + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + "ns3.example.com.", + } + deps.resolver.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for NS change") + } + + found := false + + for _, n := range notifications { + if n.Priority == "warning" { + found = true + } + } + + if !found { + t.Error("expected warning-priority NS change notification") + } +} + +func TestRecordChangeDetection(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"93.184.216.34"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "93.184.216.34", + } + deps.portChecker.results["93.184.216.34:80"] = false + deps.portChecker.results["93.184.216.34:443"] = false + + ctx := t.Context() + w.RunOnce(ctx) + + deps.resolver.mu.Lock() + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"93.184.216.35"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "93.184.216.35", + } + deps.resolver.mu.Unlock() + + deps.portChecker.mu.Lock() + deps.portChecker.results["93.184.216.35:80"] = false + deps.portChecker.results["93.184.216.35:443"] = false + deps.portChecker.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for record change") + } +} + +func TestPortStateChange(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "1.2.3.4", + } + deps.portChecker.results["1.2.3.4:80"] = true + deps.portChecker.results["1.2.3.4:443"] = true + deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{ + CommonName: "www.example.com", + Issuer: "DigiCert", + NotAfter: time.Now().Add(90 * 24 * time.Hour), + SubjectAlternativeNames: []string{ + "www.example.com", + }, + } + + ctx := t.Context() + w.RunOnce(ctx) + + deps.portChecker.mu.Lock() + deps.portChecker.results["1.2.3.4:443"] = false + deps.portChecker.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for port state change") + } +} + +func TestTLSExpiryWarning(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "1.2.3.4", + } + deps.portChecker.results["1.2.3.4:80"] = true + deps.portChecker.results["1.2.3.4:443"] = true + deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{ + CommonName: "www.example.com", + Issuer: "DigiCert", + NotAfter: time.Now().Add(3 * 24 * time.Hour), + SubjectAlternativeNames: []string{ + "www.example.com", + }, + } + + ctx := t.Context() + + // First run = baseline + w.RunOnce(ctx) + + // Second run should warn about expiry + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + + found := false + + for _, n := range notifications { + if n.Priority == "warning" { + found = true + } + } + + if !found { + t.Errorf( + "expected expiry warning, got: %v", + notifications, + ) + } +} + +func TestGracefulShutdown(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Domains = []string{"example.com"} + cfg.DNSInterval = 100 * time.Millisecond + cfg.TLSInterval = 100 * time.Millisecond + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + } + + ctx, cancel := context.WithCancel(t.Context()) + + done := make(chan struct{}) + + go func() { + w.Run(ctx) + close(done) + }() + + time.Sleep(250 * time.Millisecond) + cancel() + + select { + case <-done: + // Shut down cleanly + case <-time.After(5 * time.Second): + t.Error("watcher did not shut down within timeout") + } +} + +func TestNSFailureAndRecovery(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + "ns2.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "1.2.3.4", + } + deps.portChecker.results["1.2.3.4:80"] = false + deps.portChecker.results["1.2.3.4:443"] = false + + ctx := t.Context() + + w.RunOnce(ctx) + + deps.resolver.mu.Lock() + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for NS disappearance") + } +}