package watcher_test import ( "context" "errors" "fmt" "sync" "testing" "time" "sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/portcheck" "sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/watcher" ) // errNotFound is returned when mock data is missing. var errNotFound = errors.New("not found") // --- Mock implementations --- type mockResolver struct { mu sync.Mutex nsRecords map[string][]string allRecords map[string]map[string]map[string][]string ipAddresses map[string][]string lookupNSErr error allRecordsErr error resolveIPErr error lookupNSCalls int allRecordCalls int } func (m *mockResolver) LookupNS( _ context.Context, domain string, ) ([]string, error) { m.mu.Lock() defer m.mu.Unlock() m.lookupNSCalls++ if m.lookupNSErr != nil { return nil, m.lookupNSErr } ns, ok := m.nsRecords[domain] if !ok { return nil, fmt.Errorf( "%w: NS for %s", errNotFound, domain, ) } return ns, nil } func (m *mockResolver) LookupAllRecords( _ context.Context, hostname string, ) (map[string]map[string][]string, error) { m.mu.Lock() defer m.mu.Unlock() m.allRecordCalls++ if m.allRecordsErr != nil { return nil, m.allRecordsErr } recs, ok := m.allRecords[hostname] if !ok { return nil, fmt.Errorf( "%w: records for %s", errNotFound, hostname, ) } return recs, nil } func (m *mockResolver) ResolveIPAddresses( _ context.Context, hostname string, ) ([]string, error) { m.mu.Lock() defer m.mu.Unlock() if m.resolveIPErr != nil { return nil, m.resolveIPErr } ips, ok := m.ipAddresses[hostname] if !ok { return nil, fmt.Errorf( "%w: IPs for %s", errNotFound, hostname, ) } return ips, nil } type mockPortChecker struct { mu sync.Mutex results map[string]bool err error calls int } func (m *mockPortChecker) CheckPort( _ context.Context, address string, port int, ) (*portcheck.PortResult, error) { m.mu.Lock() defer m.mu.Unlock() m.calls++ if m.err != nil { return nil, m.err } key := fmt.Sprintf("%s:%d", address, port) open := m.results[key] return &portcheck.PortResult{Open: open}, nil } type mockTLSChecker struct { mu sync.Mutex certs map[string]*tlscheck.CertificateInfo err error calls int } func (m *mockTLSChecker) CheckCertificate( _ context.Context, ip string, hostname string, ) (*tlscheck.CertificateInfo, error) { m.mu.Lock() defer m.mu.Unlock() m.calls++ if m.err != nil { return nil, m.err } key := fmt.Sprintf("%s:%s", ip, hostname) cert, ok := m.certs[key] if !ok { return nil, fmt.Errorf( "%w: cert for %s", errNotFound, key, ) } return cert, nil } type notification struct { Title string Message string Priority string } type mockNotifier struct { mu sync.Mutex notifications []notification } func (m *mockNotifier) SendNotification( _ context.Context, title, message, priority string, ) { m.mu.Lock() defer m.mu.Unlock() m.notifications = append(m.notifications, notification{ Title: title, Message: message, Priority: priority, }) } func (m *mockNotifier) getNotifications() []notification { m.mu.Lock() defer m.mu.Unlock() result := make([]notification, len(m.notifications)) copy(result, m.notifications) return result } // --- Helper to build a Watcher for testing --- type testDeps struct { resolver *mockResolver portChecker *mockPortChecker tlsChecker *mockTLSChecker notifier *mockNotifier state *state.State config *config.Config } func newTestWatcher( t *testing.T, cfg *config.Config, ) (*watcher.Watcher, *testDeps) { t.Helper() deps := &testDeps{ resolver: &mockResolver{ nsRecords: make(map[string][]string), allRecords: make(map[string]map[string]map[string][]string), ipAddresses: make(map[string][]string), }, portChecker: &mockPortChecker{ results: make(map[string]bool), }, tlsChecker: &mockTLSChecker{ certs: make(map[string]*tlscheck.CertificateInfo), }, notifier: &mockNotifier{}, config: cfg, } deps.state = state.NewForTest() w := watcher.NewForTest( deps.config, deps.state, deps.resolver, deps.portChecker, deps.tlsChecker, deps.notifier, ) return w, deps } func defaultTestConfig(t *testing.T) *config.Config { t.Helper() return &config.Config{ DNSInterval: time.Hour, TLSInterval: 12 * time.Hour, TLSExpiryWarning: 7, DataDir: t.TempDir(), } } func TestFirstRunBaseline(t *testing.T) { t.Parallel() cfg := defaultTestConfig(t) cfg.Domains = []string{"example.com"} cfg.Hostnames = []string{"www.example.com"} w, deps := newTestWatcher(t, cfg) setupBaselineMocks(deps) w.RunOnce(t.Context()) assertNoNotifications(t, deps) assertStatePopulated(t, deps) } func setupBaselineMocks(deps *testDeps) { deps.resolver.nsRecords["example.com"] = []string{ "ns1.example.com.", "ns2.example.com.", } deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ "ns1.example.com.": {"A": {"93.184.216.34"}}, "ns2.example.com.": {"A": {"93.184.216.34"}}, } deps.resolver.ipAddresses["www.example.com"] = []string{ "93.184.216.34", } deps.portChecker.results["93.184.216.34:80"] = true deps.portChecker.results["93.184.216.34:443"] = true deps.tlsChecker.certs["93.184.216.34:www.example.com"] = &tlscheck.CertificateInfo{ CommonName: "www.example.com", Issuer: "DigiCert", NotAfter: time.Now().Add(90 * 24 * time.Hour), SubjectAlternativeNames: []string{ "www.example.com", }, } } func assertNoNotifications( t *testing.T, deps *testDeps, ) { t.Helper() notifications := deps.notifier.getNotifications() if len(notifications) != 0 { t.Errorf( "expected 0 notifications on first run, got %d", len(notifications), ) } } func assertStatePopulated( t *testing.T, deps *testDeps, ) { t.Helper() snap := deps.state.GetSnapshot() if len(snap.Domains) != 1 { t.Errorf( "expected 1 domain in state, got %d", len(snap.Domains), ) } if len(snap.Hostnames) != 1 { t.Errorf( "expected 1 hostname in state, got %d", len(snap.Hostnames), ) } } func TestNSChangeDetection(t *testing.T) { t.Parallel() cfg := defaultTestConfig(t) cfg.Domains = []string{"example.com"} w, deps := newTestWatcher(t, cfg) deps.resolver.nsRecords["example.com"] = []string{ "ns1.example.com.", "ns2.example.com.", } ctx := t.Context() w.RunOnce(ctx) deps.resolver.mu.Lock() deps.resolver.nsRecords["example.com"] = []string{ "ns1.example.com.", "ns3.example.com.", } deps.resolver.mu.Unlock() w.RunOnce(ctx) notifications := deps.notifier.getNotifications() if len(notifications) == 0 { t.Error("expected notification for NS change") } found := false for _, n := range notifications { if n.Priority == "warning" { found = true } } if !found { t.Error("expected warning-priority NS change notification") } } func TestRecordChangeDetection(t *testing.T) { t.Parallel() cfg := defaultTestConfig(t) cfg.Hostnames = []string{"www.example.com"} w, deps := newTestWatcher(t, cfg) deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ "ns1.example.com.": {"A": {"93.184.216.34"}}, } deps.resolver.ipAddresses["www.example.com"] = []string{ "93.184.216.34", } deps.portChecker.results["93.184.216.34:80"] = false deps.portChecker.results["93.184.216.34:443"] = false ctx := t.Context() w.RunOnce(ctx) deps.resolver.mu.Lock() deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ "ns1.example.com.": {"A": {"93.184.216.35"}}, } deps.resolver.ipAddresses["www.example.com"] = []string{ "93.184.216.35", } deps.resolver.mu.Unlock() deps.portChecker.mu.Lock() deps.portChecker.results["93.184.216.35:80"] = false deps.portChecker.results["93.184.216.35:443"] = false deps.portChecker.mu.Unlock() w.RunOnce(ctx) notifications := deps.notifier.getNotifications() if len(notifications) == 0 { t.Error("expected notification for record change") } } func TestPortStateChange(t *testing.T) { t.Parallel() cfg := defaultTestConfig(t) cfg.Hostnames = []string{"www.example.com"} w, deps := newTestWatcher(t, cfg) deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ "ns1.example.com.": {"A": {"1.2.3.4"}}, } deps.resolver.ipAddresses["www.example.com"] = []string{ "1.2.3.4", } deps.portChecker.results["1.2.3.4:80"] = true deps.portChecker.results["1.2.3.4:443"] = true deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{ CommonName: "www.example.com", Issuer: "DigiCert", NotAfter: time.Now().Add(90 * 24 * time.Hour), SubjectAlternativeNames: []string{ "www.example.com", }, } ctx := t.Context() w.RunOnce(ctx) deps.portChecker.mu.Lock() deps.portChecker.results["1.2.3.4:443"] = false deps.portChecker.mu.Unlock() w.RunOnce(ctx) notifications := deps.notifier.getNotifications() if len(notifications) == 0 { t.Error("expected notification for port state change") } } func TestTLSExpiryWarning(t *testing.T) { t.Parallel() cfg := defaultTestConfig(t) cfg.Hostnames = []string{"www.example.com"} w, deps := newTestWatcher(t, cfg) deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ "ns1.example.com.": {"A": {"1.2.3.4"}}, } deps.resolver.ipAddresses["www.example.com"] = []string{ "1.2.3.4", } deps.portChecker.results["1.2.3.4:80"] = true deps.portChecker.results["1.2.3.4:443"] = true deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{ CommonName: "www.example.com", Issuer: "DigiCert", NotAfter: time.Now().Add(3 * 24 * time.Hour), SubjectAlternativeNames: []string{ "www.example.com", }, } ctx := t.Context() // First run = baseline w.RunOnce(ctx) // Second run should warn about expiry w.RunOnce(ctx) notifications := deps.notifier.getNotifications() found := false for _, n := range notifications { if n.Priority == "warning" { found = true } } if !found { t.Errorf( "expected expiry warning, got: %v", notifications, ) } } func TestGracefulShutdown(t *testing.T) { t.Parallel() cfg := defaultTestConfig(t) cfg.Domains = []string{"example.com"} cfg.DNSInterval = 100 * time.Millisecond cfg.TLSInterval = 100 * time.Millisecond w, deps := newTestWatcher(t, cfg) deps.resolver.nsRecords["example.com"] = []string{ "ns1.example.com.", } ctx, cancel := context.WithCancel(t.Context()) done := make(chan struct{}) go func() { w.Run(ctx) close(done) }() time.Sleep(250 * time.Millisecond) cancel() select { case <-done: // Shut down cleanly case <-time.After(5 * time.Second): t.Error("watcher did not shut down within timeout") } } func TestNSFailureAndRecovery(t *testing.T) { t.Parallel() cfg := defaultTestConfig(t) cfg.Hostnames = []string{"www.example.com"} w, deps := newTestWatcher(t, cfg) deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ "ns1.example.com.": {"A": {"1.2.3.4"}}, "ns2.example.com.": {"A": {"1.2.3.4"}}, } deps.resolver.ipAddresses["www.example.com"] = []string{ "1.2.3.4", } deps.portChecker.results["1.2.3.4:80"] = false deps.portChecker.results["1.2.3.4:443"] = false ctx := t.Context() w.RunOnce(ctx) deps.resolver.mu.Lock() deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ "ns1.example.com.": {"A": {"1.2.3.4"}}, } deps.resolver.mu.Unlock() w.RunOnce(ctx) notifications := deps.notifier.getNotifications() if len(notifications) == 0 { t.Error("expected notification for NS disappearance") } }