package watcher import ( "context" "fmt" "log/slog" "sort" "strings" "sync" "time" "go.uber.org/fx" "sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/tlscheck" ) // monitoredPorts are the TCP ports checked for each IP address. var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals // tlsPort is the port used for TLS certificate checks. const tlsPort = 443 // hoursPerDay converts days to hours for duration calculations. const hoursPerDay = 24 // Params contains dependencies for Watcher. type Params struct { fx.In Logger *logger.Logger Config *config.Config State *state.State Resolver DNSResolver PortCheck PortChecker TLSCheck TLSChecker Notify Notifier } // Watcher orchestrates all monitoring checks on a schedule. type Watcher struct { log *slog.Logger config *config.Config state *state.State resolver DNSResolver portCheck PortChecker tlsCheck TLSChecker notify Notifier cancel context.CancelFunc firstRun bool expiryNotifiedMu sync.Mutex expiryNotified map[string]time.Time } // New creates a new Watcher instance wired into the fx lifecycle. func New( lifecycle fx.Lifecycle, params Params, ) (*Watcher, error) { w := &Watcher{ log: params.Logger.Get(), config: params.Config, state: params.State, resolver: params.Resolver, portCheck: params.PortCheck, tlsCheck: params.TLSCheck, notify: params.Notify, firstRun: true, expiryNotified: make(map[string]time.Time), } lifecycle.Append(fx.Hook{ OnStart: func(_ context.Context) error { // Use context.Background() — the fx startup context // expires after startup completes, so deriving from it // would cancel the watcher immediately. The watcher's // lifetime is controlled by w.cancel in OnStop. ctx, cancel := context.WithCancel(context.Background()) w.cancel = cancel go w.Run(ctx) //nolint:contextcheck // intentionally not derived from startCtx return nil }, OnStop: func(_ context.Context) error { if w.cancel != nil { w.cancel() } return nil }, }) return w, nil } // NewForTest creates a Watcher without fx for unit testing. func NewForTest( cfg *config.Config, st *state.State, res DNSResolver, pc PortChecker, tc TLSChecker, n Notifier, ) *Watcher { return &Watcher{ log: slog.Default(), config: cfg, state: st, resolver: res, portCheck: pc, tlsCheck: tc, notify: n, firstRun: true, expiryNotified: make(map[string]time.Time), } } // Run starts the monitoring loop with periodic scheduling. func (w *Watcher) Run(ctx context.Context) { w.log.Info( "watcher starting", "domains", len(w.config.Domains), "hostnames", len(w.config.Hostnames), "dnsInterval", w.config.DNSInterval, "tlsInterval", w.config.TLSInterval, ) w.RunOnce(ctx) w.maybeSendTestNotification(ctx) dnsTicker := time.NewTicker(w.config.DNSInterval) tlsTicker := time.NewTicker(w.config.TLSInterval) defer dnsTicker.Stop() defer tlsTicker.Stop() for { select { case <-ctx.Done(): w.log.Info("watcher stopped") return case <-dnsTicker.C: w.runDNSChecks(ctx) w.checkAllPorts(ctx) w.saveState() case <-tlsTicker.C: // Run DNS first so TLS checks use freshly // resolved IP addresses, not stale ones from // a previous cycle. w.runDNSChecks(ctx) w.runTLSChecks(ctx) w.saveState() } } } // RunOnce performs a single complete monitoring cycle. // DNS checks run first so that port and TLS checks use // freshly resolved IP addresses. Port checks run before // TLS because TLS checks only target IPs with an open // port 443. func (w *Watcher) RunOnce(ctx context.Context) { w.detectFirstRun() // Phase 1: DNS resolution must complete first so that // subsequent checks use fresh IP addresses. w.runDNSChecks(ctx) // Phase 2: Port checks populate port state that TLS // checks depend on (TLS only targets IPs where port // 443 is open). w.checkAllPorts(ctx) // Phase 3: TLS checks use fresh DNS IPs and current // port state. w.runTLSChecks(ctx) w.saveState() w.firstRun = false } func (w *Watcher) detectFirstRun() { snap := w.state.GetSnapshot() hasState := len(snap.Domains) > 0 || len(snap.Hostnames) > 0 || len(snap.Ports) > 0 || len(snap.Certificates) > 0 if hasState { w.firstRun = false } } // runDNSChecks performs DNS resolution for all configured domains // and hostnames, updating state with freshly resolved records. // This must complete before port or TLS checks run so those // checks operate on current IP addresses. func (w *Watcher) runDNSChecks(ctx context.Context) { for _, domain := range w.config.Domains { w.checkDomain(ctx, domain) } for _, hostname := range w.config.Hostnames { w.checkHostname(ctx, hostname) } } func (w *Watcher) checkDomain( ctx context.Context, domain string, ) { nameservers, err := w.resolver.LookupNS(ctx, domain) if err != nil { w.log.Error( "failed to lookup NS", "domain", domain, "error", err, ) return } sort.Strings(nameservers) now := time.Now().UTC() prev, hasPrev := w.state.GetDomainState(domain) if hasPrev && !w.firstRun { w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers) } w.state.SetDomainState(domain, &state.DomainState{ Nameservers: nameservers, LastChecked: now, }) // Also look up A/AAAA records for the apex domain so that // port and TLS checks (which read HostnameState) can find // the domain's IP addresses. records, err := w.resolver.LookupAllRecords(ctx, domain) if err != nil { w.log.Error( "failed to lookup records for domain", "domain", domain, "error", err, ) return } prevHS, hasPrevHS := w.state.GetHostnameState(domain) if hasPrevHS && !w.firstRun { w.detectHostnameChanges(ctx, domain, prevHS, records) } newState := buildHostnameState(records, now) w.state.SetHostnameState(domain, newState) } func (w *Watcher) detectNSChanges( ctx context.Context, domain string, oldNS, newNS []string, ) { oldSet := toSet(oldNS) newSet := toSet(newNS) var added, removed []string for ns := range newSet { if !oldSet[ns] { added = append(added, ns) } } for ns := range oldSet { if !newSet[ns] { removed = append(removed, ns) } } if len(added) == 0 && len(removed) == 0 { return } msg := fmt.Sprintf( "Domain: %s\nAdded: %s\nRemoved: %s", domain, strings.Join(added, ", "), strings.Join(removed, ", "), ) w.notify.SendNotification( ctx, "NS Change: "+domain, msg, "warning", ) } func (w *Watcher) checkHostname( ctx context.Context, hostname string, ) { records, err := w.resolver.LookupAllRecords(ctx, hostname) if err != nil { w.log.Error( "failed to lookup records", "hostname", hostname, "error", err, ) return } now := time.Now().UTC() prev, hasPrev := w.state.GetHostnameState(hostname) if hasPrev && !w.firstRun { w.detectHostnameChanges(ctx, hostname, prev, records) } newState := buildHostnameState(records, now) w.state.SetHostnameState(hostname, newState) } func buildHostnameState( records map[string]map[string][]string, now time.Time, ) *state.HostnameState { hs := &state.HostnameState{ RecordsByNameserver: make( map[string]*state.NameserverRecordState, ), LastChecked: now, } for ns, recs := range records { hs.RecordsByNameserver[ns] = &state.NameserverRecordState{ Records: recs, Status: "ok", LastChecked: now, } } return hs } func (w *Watcher) detectHostnameChanges( ctx context.Context, hostname string, prev *state.HostnameState, current map[string]map[string][]string, ) { w.detectRecordChanges(ctx, hostname, prev, current) w.detectNSDisappearances(ctx, hostname, prev, current) w.detectInconsistencies(ctx, hostname, current) } func (w *Watcher) detectRecordChanges( ctx context.Context, hostname string, prev *state.HostnameState, current map[string]map[string][]string, ) { for ns, recs := range current { prevNS, ok := prev.RecordsByNameserver[ns] if !ok { continue } if recordsEqual(prevNS.Records, recs) { continue } msg := fmt.Sprintf( "Hostname: %s\nNameserver: %s\n"+ "Old: %v\nNew: %v", hostname, ns, prevNS.Records, recs, ) w.notify.SendNotification( ctx, "Record Change: "+hostname, msg, "warning", ) } } func (w *Watcher) detectNSDisappearances( ctx context.Context, hostname string, prev *state.HostnameState, current map[string]map[string][]string, ) { for ns, prevNS := range prev.RecordsByNameserver { if _, ok := current[ns]; ok || prevNS.Status != "ok" { continue } msg := fmt.Sprintf( "Hostname: %s\nNameserver: %s disappeared", hostname, ns, ) w.notify.SendNotification( ctx, "NS Failure: "+hostname, msg, "error", ) } for ns := range current { prevNS, ok := prev.RecordsByNameserver[ns] if !ok || prevNS.Status != "error" { continue } msg := fmt.Sprintf( "Hostname: %s\nNameserver: %s recovered", hostname, ns, ) w.notify.SendNotification( ctx, "NS Recovery: "+hostname, msg, "success", ) } } func (w *Watcher) detectInconsistencies( ctx context.Context, hostname string, current map[string]map[string][]string, ) { nameservers := make([]string, 0, len(current)) for ns := range current { nameservers = append(nameservers, ns) } sort.Strings(nameservers) for i := range len(nameservers) - 1 { ns1 := nameservers[i] ns2 := nameservers[i+1] if recordsEqual(current[ns1], current[ns2]) { continue } msg := fmt.Sprintf( "Hostname: %s\n%s: %v\n%s: %v", hostname, ns1, current[ns1], ns2, current[ns2], ) w.notify.SendNotification( ctx, "Inconsistency: "+hostname, msg, "warning", ) } } func (w *Watcher) checkAllPorts(ctx context.Context) { // Phase 1: Build current IP:port → hostname associations // from fresh DNS data. associations := w.buildPortAssociations() // Phase 2: Check each unique IP:port and update state // with the full set of associated hostnames. for key, hostnames := range associations { ip, port := parsePortKey(key) if port == 0 { continue } w.checkSinglePort(ctx, ip, port, hostnames) } // Phase 3: Remove port state entries that no longer have // any hostname referencing them. w.cleanupStalePorts(associations) } // buildPortAssociations constructs a map from IP:port keys to // the sorted set of hostnames currently resolving to that IP. func (w *Watcher) buildPortAssociations() map[string][]string { assoc := make(map[string]map[string]bool) allNames := make( []string, 0, len(w.config.Hostnames)+len(w.config.Domains), ) allNames = append(allNames, w.config.Hostnames...) allNames = append(allNames, w.config.Domains...) for _, name := range allNames { ips := w.collectIPs(name) for _, ip := range ips { for _, port := range monitoredPorts { key := fmt.Sprintf("%s:%d", ip, port) if assoc[key] == nil { assoc[key] = make(map[string]bool) } assoc[key][name] = true } } } result := make(map[string][]string, len(assoc)) for key, set := range assoc { hostnames := make([]string, 0, len(set)) for h := range set { hostnames = append(hostnames, h) } sort.Strings(hostnames) result[key] = hostnames } return result } // parsePortKey splits an "ip:port" key into its components. func parsePortKey(key string) (string, int) { lastColon := strings.LastIndex(key, ":") if lastColon < 0 { return key, 0 } ip := key[:lastColon] var p int _, err := fmt.Sscanf(key[lastColon+1:], "%d", &p) if err != nil { return ip, 0 } return ip, p } // cleanupStalePorts removes port state entries that are no // longer referenced by any hostname in the current DNS data. func (w *Watcher) cleanupStalePorts( currentAssociations map[string][]string, ) { for _, key := range w.state.GetAllPortKeys() { if _, exists := currentAssociations[key]; !exists { w.state.DeletePortState(key) } } } func (w *Watcher) collectIPs(hostname string) []string { hs, ok := w.state.GetHostnameState(hostname) if !ok { return nil } ipSet := make(map[string]bool) for _, nsState := range hs.RecordsByNameserver { for _, ip := range nsState.Records["A"] { ipSet[ip] = true } for _, ip := range nsState.Records["AAAA"] { ipSet[ip] = true } } result := make([]string, 0, len(ipSet)) for ip := range ipSet { result = append(result, ip) } sort.Strings(result) return result } func (w *Watcher) checkSinglePort( ctx context.Context, ip string, port int, hostnames []string, ) { result, err := w.portCheck.CheckPort(ctx, ip, port) if err != nil { w.log.Error( "port check failed", "ip", ip, "port", port, "error", err, ) return } key := fmt.Sprintf("%s:%d", ip, port) now := time.Now().UTC() prev, hasPrev := w.state.GetPortState(key) if hasPrev && !w.firstRun && prev.Open != result.Open { stateStr := "closed" if result.Open { stateStr = "open" } msg := fmt.Sprintf( "Hosts: %s\nAddress: %s\nPort now %s", strings.Join(hostnames, ", "), key, stateStr, ) w.notify.SendNotification( ctx, "Port Change: "+key, msg, "warning", ) } w.state.SetPortState(key, &state.PortState{ Open: result.Open, Hostnames: hostnames, LastChecked: now, }) } func (w *Watcher) runTLSChecks(ctx context.Context) { for _, hostname := range w.config.Hostnames { w.checkTLSForHostname(ctx, hostname) } for _, domain := range w.config.Domains { w.checkTLSForHostname(ctx, domain) } } func (w *Watcher) checkTLSForHostname( ctx context.Context, hostname string, ) { ips := w.collectIPs(hostname) for _, ip := range ips { portKey := fmt.Sprintf("%s:%d", ip, tlsPort) ps, ok := w.state.GetPortState(portKey) if !ok || !ps.Open { continue } w.checkTLSCert(ctx, ip, hostname) } } func (w *Watcher) checkTLSCert( ctx context.Context, ip string, hostname string, ) { cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname) certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname) now := time.Now().UTC() prev, hasPrev := w.state.GetCertificateState(certKey) if err != nil { w.handleTLSError( ctx, certKey, hostname, ip, hasPrev, prev, now, err, ) return } w.handleTLSSuccess( ctx, certKey, hostname, ip, hasPrev, prev, now, cert, ) } func (w *Watcher) handleTLSError( ctx context.Context, certKey, hostname, ip string, hasPrev bool, prev *state.CertificateState, now time.Time, err error, ) { if hasPrev && !w.firstRun && prev.Status == "ok" { msg := fmt.Sprintf( "Host: %s\nIP: %s\nError: %s", hostname, ip, err, ) w.notify.SendNotification( ctx, "TLS Failure: "+hostname, msg, "error", ) } w.state.SetCertificateState( certKey, &state.CertificateState{ Status: "error", Error: err.Error(), LastChecked: now, }, ) } func (w *Watcher) handleTLSSuccess( ctx context.Context, certKey, hostname, ip string, hasPrev bool, prev *state.CertificateState, now time.Time, cert *tlscheck.CertificateInfo, ) { if hasPrev && !w.firstRun { w.detectTLSChanges(ctx, hostname, ip, prev, cert) } w.checkTLSExpiry(ctx, hostname, ip, cert) w.state.SetCertificateState( certKey, &state.CertificateState{ CommonName: cert.CommonName, Issuer: cert.Issuer, NotAfter: cert.NotAfter, SubjectAlternativeNames: cert.SubjectAlternativeNames, Status: "ok", LastChecked: now, }, ) } func (w *Watcher) detectTLSChanges( ctx context.Context, hostname, ip string, prev *state.CertificateState, cert *tlscheck.CertificateInfo, ) { if prev.Status == "error" { msg := fmt.Sprintf( "Host: %s\nIP: %s\nTLS recovered", hostname, ip, ) w.notify.SendNotification( ctx, "TLS Recovery: "+hostname, msg, "success", ) return } changed := prev.CommonName != cert.CommonName || prev.Issuer != cert.Issuer || !sliceEqual( prev.SubjectAlternativeNames, cert.SubjectAlternativeNames, ) if !changed { return } msg := fmt.Sprintf( "Host: %s\nIP: %s\n"+ "Old CN: %s, Issuer: %s\n"+ "New CN: %s, Issuer: %s", hostname, ip, prev.CommonName, prev.Issuer, cert.CommonName, cert.Issuer, ) w.notify.SendNotification( ctx, "TLS Certificate Change: "+hostname, msg, "warning", ) } func (w *Watcher) checkTLSExpiry( ctx context.Context, hostname, ip string, cert *tlscheck.CertificateInfo, ) { daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay warningDays := float64(w.config.TLSExpiryWarning) if daysLeft > warningDays { return } // Deduplicate expiry warnings: don't re-notify for the same // hostname within the TLS check interval. dedupKey := fmt.Sprintf("expiry:%s:%s", hostname, ip) w.expiryNotifiedMu.Lock() lastNotified, seen := w.expiryNotified[dedupKey] if seen && time.Since(lastNotified) < w.config.TLSInterval { w.expiryNotifiedMu.Unlock() return } w.expiryNotified[dedupKey] = time.Now() w.expiryNotifiedMu.Unlock() msg := fmt.Sprintf( "Host: %s\nIP: %s\nCN: %s\n"+ "Expires: %s (%.0f days)", hostname, ip, cert.CommonName, cert.NotAfter.Format(time.RFC3339), daysLeft, ) w.notify.SendNotification( ctx, "TLS Expiry Warning: "+hostname, msg, "warning", ) } func (w *Watcher) saveState() { err := w.state.Save() if err != nil { w.log.Error("failed to save state", "error", err) } } // maybeSendTestNotification sends a startup status notification // after the first full scan completes, if SEND_TEST_NOTIFICATION // is enabled. The message is clearly informational ("all ok") // and not an error or anomaly alert. func (w *Watcher) maybeSendTestNotification(ctx context.Context) { if !w.config.SendTestNotification { return } snap := w.state.GetSnapshot() msg := fmt.Sprintf( "dnswatcher has started and completed its initial scan.\n"+ "Monitoring %d domain(s) and %d hostname(s).\n"+ "Tracking %d port endpoint(s) and %d TLS certificate(s).\n"+ "All notification channels are working.", len(snap.Domains), len(snap.Hostnames), len(snap.Ports), len(snap.Certificates), ) w.log.Info("sending startup test notification") w.notify.SendNotification( ctx, "✅ dnswatcher startup complete", msg, "success", ) } // --- Utility functions --- func toSet(items []string) map[string]bool { set := make(map[string]bool, len(items)) for _, item := range items { set[item] = true } return set } func recordsEqual( a, b map[string][]string, ) bool { if len(a) != len(b) { return false } for k, av := range a { bv, ok := b[k] if !ok || !sliceEqual(av, bv) { return false } } return true } func sliceEqual(a, b []string) bool { if len(a) != len(b) { return false } aSorted := make([]string, len(a)) bSorted := make([]string, len(b)) copy(aSorted, a) copy(bSorted, b) sort.Strings(aSorted) sort.Strings(bSorted) for i := range aSorted { if aSorted[i] != bSorted[i] { return false } } return true }