From f676cc9458f90950f0a57fc1ceffac01c2aaca93 Mon Sep 17 00:00:00 2001 From: clawbot Date: Thu, 19 Feb 2026 13:48:46 -0800 Subject: [PATCH] feat: implement watcher monitoring orchestrator Implements the full monitoring loop: - Immediate checks on startup, then periodic DNS+port and TLS cycles - Domain NS change detection with notifications - Per-nameserver hostname record tracking with change/failure/recovery and inconsistency detection - TCP port 80/443 monitoring with state change notifications - TLS certificate monitoring with change, expiry, and failure detection - State persistence after each cycle - First run establishes baseline without notifications - Graceful shutdown via context cancellation Defines DNSResolver, PortChecker, TLSChecker, and Notifier interfaces for dependency injection. Updates main.go fx wiring and resolver stub signature to match per-NS record format. Closes #2 --- cmd/dnswatcher/main.go | 14 + internal/resolver/resolver.go | 4 +- internal/watcher/watcher.go | 722 ++++++++++++++++++++++++++++++++-- 3 files changed, 713 insertions(+), 27 deletions(-) diff --git a/cmd/dnswatcher/main.go b/cmd/dnswatcher/main.go index 0dd77b9..03b38da 100644 --- a/cmd/dnswatcher/main.go +++ b/cmd/dnswatcher/main.go @@ -51,6 +51,20 @@ func main() { handlers.New, server.New, ), + fx.Provide( + func(r *resolver.Resolver) watcher.DNSResolver { + return r + }, + func(p *portcheck.Checker) watcher.PortChecker { + return p + }, + func(t *tlscheck.Checker) watcher.TLSChecker { + return t + }, + func(n *notify.Service) watcher.Notifier { + return n + }, + ), fx.Invoke(func(*server.Server, *watcher.Watcher) {}), ).Run() } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index be47717..76432d9 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -46,11 +46,11 @@ func (r *Resolver) LookupNS( } // LookupAllRecords performs iterative resolution to find all DNS -// records for the given hostname. +// records for the given hostname, keyed by authoritative nameserver. func (r *Resolver) LookupAllRecords( _ context.Context, _ string, -) (map[string][]string, error) { +) (map[string]map[string][]string, error) { return nil, ErrNotImplemented } diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index a9b1caa..2493264 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1,21 +1,30 @@ -// Package watcher provides the main monitoring orchestrator and scheduler. package watcher import ( "context" + "fmt" "log/slog" + "sort" + "strings" + "time" "go.uber.org/fx" "sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/logger" - "sneak.berlin/go/dnswatcher/internal/notify" - "sneak.berlin/go/dnswatcher/internal/portcheck" - "sneak.berlin/go/dnswatcher/internal/resolver" "sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/tlscheck" ) +// monitoredPorts are the TCP ports checked for each IP address. +var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals + +// tlsPort is the port used for TLS certificate checks. +const tlsPort = 443 + +// hoursPerDay converts days to hours for duration calculations. +const hoursPerDay = 24 + // Params contains dependencies for Watcher. type Params struct { fx.In @@ -23,10 +32,10 @@ type Params struct { Logger *logger.Logger Config *config.Config State *state.State - Resolver *resolver.Resolver - PortCheck *portcheck.Checker - TLSCheck *tlscheck.Checker - Notify *notify.Service + Resolver DNSResolver + PortCheck PortChecker + TLSCheck TLSChecker + Notify Notifier } // Watcher orchestrates all monitoring checks on a schedule. @@ -34,19 +43,20 @@ type Watcher struct { log *slog.Logger config *config.Config state *state.State - resolver *resolver.Resolver - portCheck *portcheck.Checker - tlsCheck *tlscheck.Checker - notify *notify.Service + resolver DNSResolver + portCheck PortChecker + tlsCheck TLSChecker + notify Notifier cancel context.CancelFunc + firstRun bool } -// New creates a new Watcher instance. +// New creates a new Watcher instance wired into the fx lifecycle. func New( lifecycle fx.Lifecycle, params Params, ) (*Watcher, error) { - watcher := &Watcher{ + w := &Watcher{ log: params.Logger.Get(), config: params.Config, state: params.State, @@ -54,30 +64,54 @@ func New( portCheck: params.PortCheck, tlsCheck: params.TLSCheck, notify: params.Notify, + firstRun: true, } lifecycle.Append(fx.Hook{ OnStart: func(startCtx context.Context) error { - ctx, cancel := context.WithCancel(startCtx) - watcher.cancel = cancel + ctx, cancel := context.WithCancel( + context.WithoutCancel(startCtx), + ) + w.cancel = cancel - go watcher.Run(ctx) + go w.Run(ctx) return nil }, OnStop: func(_ context.Context) error { - if watcher.cancel != nil { - watcher.cancel() + if w.cancel != nil { + w.cancel() } return nil }, }) - return watcher, nil + return w, nil } -// Run starts the monitoring loop. +// NewForTest creates a Watcher without fx for unit testing. +func NewForTest( + cfg *config.Config, + st *state.State, + res DNSResolver, + pc PortChecker, + tc TLSChecker, + n Notifier, +) *Watcher { + return &Watcher{ + log: slog.Default(), + config: cfg, + state: st, + resolver: res, + portCheck: pc, + tlsCheck: tc, + notify: n, + firstRun: true, + } +} + +// Run starts the monitoring loop with periodic scheduling. func (w *Watcher) Run(ctx context.Context) { w.log.Info( "watcher starting", @@ -87,8 +121,646 @@ func (w *Watcher) Run(ctx context.Context) { "tlsInterval", w.config.TLSInterval, ) - // Stub: wait for context cancellation. - // Implementation will add initial check + periodic scheduling. - <-ctx.Done() - w.log.Info("watcher stopped") + w.RunOnce(ctx) + + dnsTicker := time.NewTicker(w.config.DNSInterval) + tlsTicker := time.NewTicker(w.config.TLSInterval) + + defer dnsTicker.Stop() + defer tlsTicker.Stop() + + for { + select { + case <-ctx.Done(): + w.log.Info("watcher stopped") + + return + case <-dnsTicker.C: + w.runDNSAndPortChecks(ctx) + w.saveState() + case <-tlsTicker.C: + w.runTLSChecks(ctx) + w.saveState() + } + } +} + +// RunOnce performs a single complete monitoring cycle. +func (w *Watcher) RunOnce(ctx context.Context) { + w.detectFirstRun() + w.runDNSAndPortChecks(ctx) + w.runTLSChecks(ctx) + w.saveState() + w.firstRun = false +} + +func (w *Watcher) detectFirstRun() { + snap := w.state.GetSnapshot() + hasState := len(snap.Domains) > 0 || + len(snap.Hostnames) > 0 || + len(snap.Ports) > 0 || + len(snap.Certificates) > 0 + + if hasState { + w.firstRun = false + } +} + +func (w *Watcher) runDNSAndPortChecks(ctx context.Context) { + for _, domain := range w.config.Domains { + w.checkDomain(ctx, domain) + } + + for _, hostname := range w.config.Hostnames { + w.checkHostname(ctx, hostname) + } + + w.checkAllPorts(ctx) +} + +func (w *Watcher) checkDomain( + ctx context.Context, + domain string, +) { + nameservers, err := w.resolver.LookupNS(ctx, domain) + if err != nil { + w.log.Error( + "failed to lookup NS", + "domain", domain, + "error", err, + ) + + return + } + + sort.Strings(nameservers) + + now := time.Now().UTC() + + prev, hasPrev := w.state.GetDomainState(domain) + if hasPrev && !w.firstRun { + w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers) + } + + w.state.SetDomainState(domain, &state.DomainState{ + Nameservers: nameservers, + LastChecked: now, + }) +} + +func (w *Watcher) detectNSChanges( + ctx context.Context, + domain string, + oldNS, newNS []string, +) { + oldSet := toSet(oldNS) + newSet := toSet(newNS) + + var added, removed []string + + for ns := range newSet { + if !oldSet[ns] { + added = append(added, ns) + } + } + + for ns := range oldSet { + if !newSet[ns] { + removed = append(removed, ns) + } + } + + if len(added) == 0 && len(removed) == 0 { + return + } + + msg := fmt.Sprintf( + "Domain: %s\nAdded: %s\nRemoved: %s", + domain, + strings.Join(added, ", "), + strings.Join(removed, ", "), + ) + + w.notify.SendNotification( + ctx, + "NS Change: "+domain, + msg, + "warning", + ) +} + +func (w *Watcher) checkHostname( + ctx context.Context, + hostname string, +) { + records, err := w.resolver.LookupAllRecords(ctx, hostname) + if err != nil { + w.log.Error( + "failed to lookup records", + "hostname", hostname, + "error", err, + ) + + return + } + + now := time.Now().UTC() + prev, hasPrev := w.state.GetHostnameState(hostname) + + if hasPrev && !w.firstRun { + w.detectHostnameChanges(ctx, hostname, prev, records) + } + + newState := buildHostnameState(records, now) + w.state.SetHostnameState(hostname, newState) +} + +func buildHostnameState( + records map[string]map[string][]string, + now time.Time, +) *state.HostnameState { + hs := &state.HostnameState{ + RecordsByNameserver: make( + map[string]*state.NameserverRecordState, + ), + LastChecked: now, + } + + for ns, recs := range records { + hs.RecordsByNameserver[ns] = &state.NameserverRecordState{ + Records: recs, + Status: "ok", + LastChecked: now, + } + } + + return hs +} + +func (w *Watcher) detectHostnameChanges( + ctx context.Context, + hostname string, + prev *state.HostnameState, + current map[string]map[string][]string, +) { + w.detectRecordChanges(ctx, hostname, prev, current) + w.detectNSDisappearances(ctx, hostname, prev, current) + w.detectInconsistencies(ctx, hostname, current) +} + +func (w *Watcher) detectRecordChanges( + ctx context.Context, + hostname string, + prev *state.HostnameState, + current map[string]map[string][]string, +) { + for ns, recs := range current { + prevNS, ok := prev.RecordsByNameserver[ns] + if !ok { + continue + } + + if recordsEqual(prevNS.Records, recs) { + continue + } + + msg := fmt.Sprintf( + "Hostname: %s\nNameserver: %s\n"+ + "Old: %v\nNew: %v", + hostname, ns, + prevNS.Records, recs, + ) + + w.notify.SendNotification( + ctx, + "Record Change: "+hostname, + msg, + "warning", + ) + } +} + +func (w *Watcher) detectNSDisappearances( + ctx context.Context, + hostname string, + prev *state.HostnameState, + current map[string]map[string][]string, +) { + for ns, prevNS := range prev.RecordsByNameserver { + if _, ok := current[ns]; ok || prevNS.Status != "ok" { + continue + } + + msg := fmt.Sprintf( + "Hostname: %s\nNameserver: %s disappeared", + hostname, ns, + ) + + w.notify.SendNotification( + ctx, + "NS Failure: "+hostname, + msg, + "error", + ) + } + + for ns := range current { + prevNS, ok := prev.RecordsByNameserver[ns] + if !ok || prevNS.Status != "error" { + continue + } + + msg := fmt.Sprintf( + "Hostname: %s\nNameserver: %s recovered", + hostname, ns, + ) + + w.notify.SendNotification( + ctx, + "NS Recovery: "+hostname, + msg, + "success", + ) + } +} + +func (w *Watcher) detectInconsistencies( + ctx context.Context, + hostname string, + current map[string]map[string][]string, +) { + nameservers := make([]string, 0, len(current)) + for ns := range current { + nameservers = append(nameservers, ns) + } + + sort.Strings(nameservers) + + for i := range len(nameservers) - 1 { + ns1 := nameservers[i] + ns2 := nameservers[i+1] + + if recordsEqual(current[ns1], current[ns2]) { + continue + } + + msg := fmt.Sprintf( + "Hostname: %s\n%s: %v\n%s: %v", + hostname, + ns1, current[ns1], + ns2, current[ns2], + ) + + w.notify.SendNotification( + ctx, + "Inconsistency: "+hostname, + msg, + "warning", + ) + } +} + +func (w *Watcher) checkAllPorts(ctx context.Context) { + for _, hostname := range w.config.Hostnames { + w.checkPortsForHostname(ctx, hostname) + } + + for _, domain := range w.config.Domains { + w.checkPortsForHostname(ctx, domain) + } +} + +func (w *Watcher) checkPortsForHostname( + ctx context.Context, + hostname string, +) { + ips := w.collectIPs(hostname) + + for _, ip := range ips { + for _, port := range monitoredPorts { + w.checkSinglePort(ctx, ip, port, hostname) + } + } +} + +func (w *Watcher) collectIPs(hostname string) []string { + hs, ok := w.state.GetHostnameState(hostname) + if !ok { + return nil + } + + ipSet := make(map[string]bool) + + for _, nsState := range hs.RecordsByNameserver { + for _, ip := range nsState.Records["A"] { + ipSet[ip] = true + } + + for _, ip := range nsState.Records["AAAA"] { + ipSet[ip] = true + } + } + + result := make([]string, 0, len(ipSet)) + for ip := range ipSet { + result = append(result, ip) + } + + sort.Strings(result) + + return result +} + +func (w *Watcher) checkSinglePort( + ctx context.Context, + ip string, + port int, + hostname string, +) { + open, err := w.portCheck.CheckPort(ctx, ip, port) + if err != nil { + w.log.Error( + "port check failed", + "ip", ip, + "port", port, + "error", err, + ) + + return + } + + key := fmt.Sprintf("%s:%d", ip, port) + now := time.Now().UTC() + prev, hasPrev := w.state.GetPortState(key) + + if hasPrev && !w.firstRun && prev.Open != open { + stateStr := "closed" + if open { + stateStr = "open" + } + + msg := fmt.Sprintf( + "Host: %s\nAddress: %s\nPort now %s", + hostname, key, stateStr, + ) + + w.notify.SendNotification( + ctx, + "Port Change: "+key, + msg, + "warning", + ) + } + + w.state.SetPortState(key, &state.PortState{ + Open: open, + Hostname: hostname, + LastChecked: now, + }) +} + +func (w *Watcher) runTLSChecks(ctx context.Context) { + for _, hostname := range w.config.Hostnames { + w.checkTLSForHostname(ctx, hostname) + } + + for _, domain := range w.config.Domains { + w.checkTLSForHostname(ctx, domain) + } +} + +func (w *Watcher) checkTLSForHostname( + ctx context.Context, + hostname string, +) { + ips := w.collectIPs(hostname) + + for _, ip := range ips { + portKey := fmt.Sprintf("%s:%d", ip, tlsPort) + + ps, ok := w.state.GetPortState(portKey) + if !ok || !ps.Open { + continue + } + + w.checkTLSCert(ctx, ip, hostname) + } +} + +func (w *Watcher) checkTLSCert( + ctx context.Context, + ip string, + hostname string, +) { + cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname) + certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname) + now := time.Now().UTC() + prev, hasPrev := w.state.GetCertificateState(certKey) + + if err != nil { + w.handleTLSError( + ctx, certKey, hostname, ip, + hasPrev, prev, now, err, + ) + + return + } + + w.handleTLSSuccess( + ctx, certKey, hostname, ip, + hasPrev, prev, now, cert, + ) +} + +func (w *Watcher) handleTLSError( + ctx context.Context, + certKey, hostname, ip string, + hasPrev bool, + prev *state.CertificateState, + now time.Time, + err error, +) { + if hasPrev && !w.firstRun && prev.Status == "ok" { + msg := fmt.Sprintf( + "Host: %s\nIP: %s\nError: %s", + hostname, ip, err, + ) + + w.notify.SendNotification( + ctx, + "TLS Failure: "+hostname, + msg, + "error", + ) + } + + w.state.SetCertificateState( + certKey, &state.CertificateState{ + Status: "error", + Error: err.Error(), + LastChecked: now, + }, + ) +} + +func (w *Watcher) handleTLSSuccess( + ctx context.Context, + certKey, hostname, ip string, + hasPrev bool, + prev *state.CertificateState, + now time.Time, + cert *tlscheck.CertificateInfo, +) { + if hasPrev && !w.firstRun { + w.detectTLSChanges(ctx, hostname, ip, prev, cert) + } + + w.checkTLSExpiry(ctx, hostname, ip, cert) + + w.state.SetCertificateState( + certKey, &state.CertificateState{ + CommonName: cert.CommonName, + Issuer: cert.Issuer, + NotAfter: cert.NotAfter, + SubjectAlternativeNames: cert.SubjectAlternativeNames, + Status: "ok", + LastChecked: now, + }, + ) +} + +func (w *Watcher) detectTLSChanges( + ctx context.Context, + hostname, ip string, + prev *state.CertificateState, + cert *tlscheck.CertificateInfo, +) { + if prev.Status == "error" { + msg := fmt.Sprintf( + "Host: %s\nIP: %s\nTLS recovered", + hostname, ip, + ) + + w.notify.SendNotification( + ctx, + "TLS Recovery: "+hostname, + msg, + "success", + ) + + return + } + + changed := prev.CommonName != cert.CommonName || + prev.Issuer != cert.Issuer || + !sliceEqual( + prev.SubjectAlternativeNames, + cert.SubjectAlternativeNames, + ) + + if !changed { + return + } + + msg := fmt.Sprintf( + "Host: %s\nIP: %s\n"+ + "Old CN: %s, Issuer: %s\n"+ + "New CN: %s, Issuer: %s", + hostname, ip, + prev.CommonName, prev.Issuer, + cert.CommonName, cert.Issuer, + ) + + w.notify.SendNotification( + ctx, + "TLS Certificate Change: "+hostname, + msg, + "warning", + ) +} + +func (w *Watcher) checkTLSExpiry( + ctx context.Context, + hostname, ip string, + cert *tlscheck.CertificateInfo, +) { + daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay + warningDays := float64(w.config.TLSExpiryWarning) + + if daysLeft > warningDays { + return + } + + msg := fmt.Sprintf( + "Host: %s\nIP: %s\nCN: %s\n"+ + "Expires: %s (%.0f days)", + hostname, ip, cert.CommonName, + cert.NotAfter.Format(time.RFC3339), + daysLeft, + ) + + w.notify.SendNotification( + ctx, + "TLS Expiry Warning: "+hostname, + msg, + "warning", + ) +} + +func (w *Watcher) saveState() { + err := w.state.Save() + if err != nil { + w.log.Error("failed to save state", "error", err) + } +} + +// --- Utility functions --- + +func toSet(items []string) map[string]bool { + set := make(map[string]bool, len(items)) + for _, item := range items { + set[item] = true + } + + return set +} + +func recordsEqual( + a, b map[string][]string, +) bool { + if len(a) != len(b) { + return false + } + + for k, av := range a { + bv, ok := b[k] + if !ok || !sliceEqual(av, bv) { + return false + } + } + + return true +} + +func sliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + + aSorted := make([]string, len(a)) + bSorted := make([]string, len(b)) + + copy(aSorted, a) + copy(bSorted, b) + + sort.Strings(aSorted) + sort.Strings(bSorted) + + for i := range aSorted { + if aSorted[i] != bSorted[i] { + return false + } + } + + return true }