package watcher

import (
	"context"
	"fmt"
	"log/slog"
	"sort"
	"strings"
	"time"

	"go.uber.org/fx"
	"sneak.berlin/go/dnswatcher/internal/config"
	"sneak.berlin/go/dnswatcher/internal/logger"
	"sneak.berlin/go/dnswatcher/internal/state"
	"sneak.berlin/go/dnswatcher/internal/tlscheck"
)

// monitoredPorts are the TCP ports checked for each IP address.
var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals

// tlsPort is the port used for TLS certificate checks.
const tlsPort = 443

// hoursPerDay converts days to hours for duration calculations.
const hoursPerDay = 24

// Params contains dependencies for Watcher.
type Params struct {
	fx.In

	Logger    *logger.Logger
	Config    *config.Config
	State     *state.State
	Resolver  DNSResolver
	PortCheck PortChecker
	TLSCheck  TLSChecker
	Notify    Notifier
}

// Watcher orchestrates all monitoring checks on a schedule.
type Watcher struct {
	log       *slog.Logger
	config    *config.Config
	state     *state.State
	resolver  DNSResolver
	portCheck PortChecker
	tlsCheck  TLSChecker
	notify    Notifier

	// cancel stops the Run loop started by the OnStart hook; it is nil
	// until OnStart has run.
	cancel context.CancelFunc

	// done is closed when the Run goroutine started by OnStart returns.
	// It is nil for watchers that were not started through the hook.
	done chan struct{}

	// firstRun suppresses change notifications while the initial
	// baseline is being recorded.
	firstRun bool
}

// New creates a new Watcher instance wired into the fx lifecycle.
//
// OnStart launches Run in a background goroutine. OnStop cancels that
// goroutine and then waits for it to return, bounded by the stop
// context, so an in-flight check cycle or state save is not abandoned
// mid-way during shutdown. The returned error is always nil.
func New(
	lifecycle fx.Lifecycle,
	params Params,
) (*Watcher, error) {
	w := &Watcher{
		log:       params.Logger.Get(),
		config:    params.Config,
		state:     params.State,
		resolver:  params.Resolver,
		portCheck: params.PortCheck,
		tlsCheck:  params.TLSCheck,
		notify:    params.Notify,
		firstRun:  true,
	}

	lifecycle.Append(fx.Hook{
		OnStart: func(startCtx context.Context) error {
			// WithoutCancel detaches the loop from startCtx's own
			// cancellation while keeping its values; only w.cancel
			// (invoked from OnStop) ends the loop.
			ctx, cancel := context.WithCancel(
				context.WithoutCancel(startCtx),
			)
			done := make(chan struct{})

			w.cancel = cancel
			w.done = done

			go func() {
				defer close(done)

				w.Run(ctx)
			}()

			return nil
		},
		OnStop: func(stopCtx context.Context) error {
			// OnStart never ran, so there is nothing to stop.
			if w.cancel == nil {
				return nil
			}

			w.cancel()

			// Wait for Run to finish, but never past the stop
			// deadline.
			select {
			case <-w.done:
			case <-stopCtx.Done():
				w.log.Warn(
					"watcher did not stop before shutdown deadline",
					"error", stopCtx.Err(),
				)
			}

			return nil
		},
	})

	return w, nil
}

// NewForTest creates a Watcher without fx for unit testing.
func NewForTest( cfg *config.Config, st *state.State, res DNSResolver, pc PortChecker, tc TLSChecker, n Notifier, ) *Watcher { return &Watcher{ log: slog.Default(), config: cfg, state: st, resolver: res, portCheck: pc, tlsCheck: tc, notify: n, firstRun: true, } } // Run starts the monitoring loop with periodic scheduling. func (w *Watcher) Run(ctx context.Context) { w.log.Info( "watcher starting", "domains", len(w.config.Domains), "hostnames", len(w.config.Hostnames), "dnsInterval", w.config.DNSInterval, "tlsInterval", w.config.TLSInterval, ) w.RunOnce(ctx) dnsTicker := time.NewTicker(w.config.DNSInterval) tlsTicker := time.NewTicker(w.config.TLSInterval) defer dnsTicker.Stop() defer tlsTicker.Stop() for { select { case <-ctx.Done(): w.log.Info("watcher stopped") return case <-dnsTicker.C: w.runDNSAndPortChecks(ctx) w.saveState() case <-tlsTicker.C: w.runTLSChecks(ctx) w.saveState() } } } // RunOnce performs a single complete monitoring cycle. func (w *Watcher) RunOnce(ctx context.Context) { w.detectFirstRun() w.runDNSAndPortChecks(ctx) w.runTLSChecks(ctx) w.saveState() w.firstRun = false } func (w *Watcher) detectFirstRun() { snap := w.state.GetSnapshot() hasState := len(snap.Domains) > 0 || len(snap.Hostnames) > 0 || len(snap.Ports) > 0 || len(snap.Certificates) > 0 if hasState { w.firstRun = false } } func (w *Watcher) runDNSAndPortChecks(ctx context.Context) { for _, domain := range w.config.Domains { w.checkDomain(ctx, domain) } for _, hostname := range w.config.Hostnames { w.checkHostname(ctx, hostname) } w.checkAllPorts(ctx) } func (w *Watcher) checkDomain( ctx context.Context, domain string, ) { nameservers, err := w.resolver.LookupNS(ctx, domain) if err != nil { w.log.Error( "failed to lookup NS", "domain", domain, "error", err, ) return } sort.Strings(nameservers) now := time.Now().UTC() prev, hasPrev := w.state.GetDomainState(domain) if hasPrev && !w.firstRun { w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers) } w.state.SetDomainState(domain, 
&state.DomainState{ Nameservers: nameservers, LastChecked: now, }) } func (w *Watcher) detectNSChanges( ctx context.Context, domain string, oldNS, newNS []string, ) { oldSet := toSet(oldNS) newSet := toSet(newNS) var added, removed []string for ns := range newSet { if !oldSet[ns] { added = append(added, ns) } } for ns := range oldSet { if !newSet[ns] { removed = append(removed, ns) } } if len(added) == 0 && len(removed) == 0 { return } msg := fmt.Sprintf( "Domain: %s\nAdded: %s\nRemoved: %s", domain, strings.Join(added, ", "), strings.Join(removed, ", "), ) w.notify.SendNotification( ctx, "NS Change: "+domain, msg, "warning", ) } func (w *Watcher) checkHostname( ctx context.Context, hostname string, ) { records, err := w.resolver.LookupAllRecords(ctx, hostname) if err != nil { w.log.Error( "failed to lookup records", "hostname", hostname, "error", err, ) return } now := time.Now().UTC() prev, hasPrev := w.state.GetHostnameState(hostname) if hasPrev && !w.firstRun { w.detectHostnameChanges(ctx, hostname, prev, records) } newState := buildHostnameState(records, now) w.state.SetHostnameState(hostname, newState) } func buildHostnameState( records map[string]map[string][]string, now time.Time, ) *state.HostnameState { hs := &state.HostnameState{ RecordsByNameserver: make( map[string]*state.NameserverRecordState, ), LastChecked: now, } for ns, recs := range records { hs.RecordsByNameserver[ns] = &state.NameserverRecordState{ Records: recs, Status: "ok", LastChecked: now, } } return hs } func (w *Watcher) detectHostnameChanges( ctx context.Context, hostname string, prev *state.HostnameState, current map[string]map[string][]string, ) { w.detectRecordChanges(ctx, hostname, prev, current) w.detectNSDisappearances(ctx, hostname, prev, current) w.detectInconsistencies(ctx, hostname, current) } func (w *Watcher) detectRecordChanges( ctx context.Context, hostname string, prev *state.HostnameState, current map[string]map[string][]string, ) { for ns, recs := range current { 
prevNS, ok := prev.RecordsByNameserver[ns] if !ok { continue } if recordsEqual(prevNS.Records, recs) { continue } msg := fmt.Sprintf( "Hostname: %s\nNameserver: %s\n"+ "Old: %v\nNew: %v", hostname, ns, prevNS.Records, recs, ) w.notify.SendNotification( ctx, "Record Change: "+hostname, msg, "warning", ) } } func (w *Watcher) detectNSDisappearances( ctx context.Context, hostname string, prev *state.HostnameState, current map[string]map[string][]string, ) { for ns, prevNS := range prev.RecordsByNameserver { if _, ok := current[ns]; ok || prevNS.Status != "ok" { continue } msg := fmt.Sprintf( "Hostname: %s\nNameserver: %s disappeared", hostname, ns, ) w.notify.SendNotification( ctx, "NS Failure: "+hostname, msg, "error", ) } for ns := range current { prevNS, ok := prev.RecordsByNameserver[ns] if !ok || prevNS.Status != "error" { continue } msg := fmt.Sprintf( "Hostname: %s\nNameserver: %s recovered", hostname, ns, ) w.notify.SendNotification( ctx, "NS Recovery: "+hostname, msg, "success", ) } } func (w *Watcher) detectInconsistencies( ctx context.Context, hostname string, current map[string]map[string][]string, ) { nameservers := make([]string, 0, len(current)) for ns := range current { nameservers = append(nameservers, ns) } sort.Strings(nameservers) for i := range len(nameservers) - 1 { ns1 := nameservers[i] ns2 := nameservers[i+1] if recordsEqual(current[ns1], current[ns2]) { continue } msg := fmt.Sprintf( "Hostname: %s\n%s: %v\n%s: %v", hostname, ns1, current[ns1], ns2, current[ns2], ) w.notify.SendNotification( ctx, "Inconsistency: "+hostname, msg, "warning", ) } } func (w *Watcher) checkAllPorts(ctx context.Context) { for _, hostname := range w.config.Hostnames { w.checkPortsForHostname(ctx, hostname) } for _, domain := range w.config.Domains { w.checkPortsForHostname(ctx, domain) } } func (w *Watcher) checkPortsForHostname( ctx context.Context, hostname string, ) { ips := w.collectIPs(hostname) for _, ip := range ips { for _, port := range monitoredPorts { 
w.checkSinglePort(ctx, ip, port, hostname) } } } func (w *Watcher) collectIPs(hostname string) []string { hs, ok := w.state.GetHostnameState(hostname) if !ok { return nil } ipSet := make(map[string]bool) for _, nsState := range hs.RecordsByNameserver { for _, ip := range nsState.Records["A"] { ipSet[ip] = true } for _, ip := range nsState.Records["AAAA"] { ipSet[ip] = true } } result := make([]string, 0, len(ipSet)) for ip := range ipSet { result = append(result, ip) } sort.Strings(result) return result } func (w *Watcher) checkSinglePort( ctx context.Context, ip string, port int, hostname string, ) { result, err := w.portCheck.CheckPort(ctx, ip, port) if err != nil { w.log.Error( "port check failed", "ip", ip, "port", port, "error", err, ) return } key := fmt.Sprintf("%s:%d", ip, port) now := time.Now().UTC() prev, hasPrev := w.state.GetPortState(key) if hasPrev && !w.firstRun && prev.Open != result.Open { stateStr := "closed" if result.Open { stateStr = "open" } msg := fmt.Sprintf( "Host: %s\nAddress: %s\nPort now %s", hostname, key, stateStr, ) w.notify.SendNotification( ctx, "Port Change: "+key, msg, "warning", ) } w.state.SetPortState(key, &state.PortState{ Open: result.Open, Hostname: hostname, LastChecked: now, }) } func (w *Watcher) runTLSChecks(ctx context.Context) { for _, hostname := range w.config.Hostnames { w.checkTLSForHostname(ctx, hostname) } for _, domain := range w.config.Domains { w.checkTLSForHostname(ctx, domain) } } func (w *Watcher) checkTLSForHostname( ctx context.Context, hostname string, ) { ips := w.collectIPs(hostname) for _, ip := range ips { portKey := fmt.Sprintf("%s:%d", ip, tlsPort) ps, ok := w.state.GetPortState(portKey) if !ok || !ps.Open { continue } w.checkTLSCert(ctx, ip, hostname) } } func (w *Watcher) checkTLSCert( ctx context.Context, ip string, hostname string, ) { cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname) certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname) now := time.Now().UTC() prev, hasPrev := 
w.state.GetCertificateState(certKey) if err != nil { w.handleTLSError( ctx, certKey, hostname, ip, hasPrev, prev, now, err, ) return } w.handleTLSSuccess( ctx, certKey, hostname, ip, hasPrev, prev, now, cert, ) } func (w *Watcher) handleTLSError( ctx context.Context, certKey, hostname, ip string, hasPrev bool, prev *state.CertificateState, now time.Time, err error, ) { if hasPrev && !w.firstRun && prev.Status == "ok" { msg := fmt.Sprintf( "Host: %s\nIP: %s\nError: %s", hostname, ip, err, ) w.notify.SendNotification( ctx, "TLS Failure: "+hostname, msg, "error", ) } w.state.SetCertificateState( certKey, &state.CertificateState{ Status: "error", Error: err.Error(), LastChecked: now, }, ) } func (w *Watcher) handleTLSSuccess( ctx context.Context, certKey, hostname, ip string, hasPrev bool, prev *state.CertificateState, now time.Time, cert *tlscheck.CertificateInfo, ) { if hasPrev && !w.firstRun { w.detectTLSChanges(ctx, hostname, ip, prev, cert) } w.checkTLSExpiry(ctx, hostname, ip, cert) w.state.SetCertificateState( certKey, &state.CertificateState{ CommonName: cert.CommonName, Issuer: cert.Issuer, NotAfter: cert.NotAfter, SubjectAlternativeNames: cert.SubjectAlternativeNames, Status: "ok", LastChecked: now, }, ) } func (w *Watcher) detectTLSChanges( ctx context.Context, hostname, ip string, prev *state.CertificateState, cert *tlscheck.CertificateInfo, ) { if prev.Status == "error" { msg := fmt.Sprintf( "Host: %s\nIP: %s\nTLS recovered", hostname, ip, ) w.notify.SendNotification( ctx, "TLS Recovery: "+hostname, msg, "success", ) return } changed := prev.CommonName != cert.CommonName || prev.Issuer != cert.Issuer || !sliceEqual( prev.SubjectAlternativeNames, cert.SubjectAlternativeNames, ) if !changed { return } msg := fmt.Sprintf( "Host: %s\nIP: %s\n"+ "Old CN: %s, Issuer: %s\n"+ "New CN: %s, Issuer: %s", hostname, ip, prev.CommonName, prev.Issuer, cert.CommonName, cert.Issuer, ) w.notify.SendNotification( ctx, "TLS Certificate Change: "+hostname, msg, "warning", ) } 
func (w *Watcher) checkTLSExpiry( ctx context.Context, hostname, ip string, cert *tlscheck.CertificateInfo, ) { daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay warningDays := float64(w.config.TLSExpiryWarning) if daysLeft > warningDays { return } msg := fmt.Sprintf( "Host: %s\nIP: %s\nCN: %s\n"+ "Expires: %s (%.0f days)", hostname, ip, cert.CommonName, cert.NotAfter.Format(time.RFC3339), daysLeft, ) w.notify.SendNotification( ctx, "TLS Expiry Warning: "+hostname, msg, "warning", ) } func (w *Watcher) saveState() { err := w.state.Save() if err != nil { w.log.Error("failed to save state", "error", err) } } // --- Utility functions --- func toSet(items []string) map[string]bool { set := make(map[string]bool, len(items)) for _, item := range items { set[item] = true } return set } func recordsEqual( a, b map[string][]string, ) bool { if len(a) != len(b) { return false } for k, av := range a { bv, ok := b[k] if !ok || !sliceEqual(av, bv) { return false } } return true } func sliceEqual(a, b []string) bool { if len(a) != len(b) { return false } aSorted := make([]string, len(a)) bSorted := make([]string, len(b)) copy(aSorted, a) copy(bSorted, b) sort.Strings(aSorted) sort.Strings(bSorted) for i := range aSorted { if aSorted[i] != bSorted[i] { return false } } return true }