feat: implement watcher monitoring orchestrator

Implements the full monitoring loop:
- Immediate checks on startup, then periodic DNS+port and TLS cycles
- Domain NS change detection with notifications
- Per-nameserver hostname record tracking with change/failure/recovery
  and inconsistency detection
- TCP port 80/443 monitoring with state change notifications
- TLS certificate monitoring with change, expiry, and failure detection
- State persistence after each cycle
- First run establishes baseline without notifications
- Graceful shutdown via context cancellation

Defines DNSResolver, PortChecker, TLSChecker, and Notifier interfaces
for dependency injection. Updates main.go fx wiring and resolver stub
signature to match per-NS record format.

Closes #2
This commit is contained in:
clawbot 2026-02-19 13:48:46 -08:00
parent dea30028b1
commit f676cc9458
3 changed files with 713 additions and 27 deletions

View File

@ -51,6 +51,20 @@ func main() {
handlers.New,
server.New,
),
fx.Provide(
func(r *resolver.Resolver) watcher.DNSResolver {
return r
},
func(p *portcheck.Checker) watcher.PortChecker {
return p
},
func(t *tlscheck.Checker) watcher.TLSChecker {
return t
},
func(n *notify.Service) watcher.Notifier {
return n
},
),
fx.Invoke(func(*server.Server, *watcher.Watcher) {}),
).Run()
}

View File

@ -46,11 +46,11 @@ func (r *Resolver) LookupNS(
}
// LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname.
// records for the given hostname, keyed by authoritative nameserver.
func (r *Resolver) LookupAllRecords(
_ context.Context,
_ string,
) (map[string][]string, error) {
) (map[string]map[string][]string, error) {
return nil, ErrNotImplemented
}

View File

@ -1,21 +1,30 @@
// Package watcher provides the main monitoring orchestrator and scheduler.
package watcher
import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
"time"
"go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/logger"
"sneak.berlin/go/dnswatcher/internal/notify"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/resolver"
"sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
// monitoredPorts are the TCP ports checked for each IP address.
var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals
// tlsPort is the port used for TLS certificate checks.
const tlsPort = 443
// hoursPerDay converts days to hours for duration calculations.
const hoursPerDay = 24
// Params contains dependencies for Watcher.
type Params struct {
fx.In
@ -23,10 +32,10 @@ type Params struct {
Logger *logger.Logger
Config *config.Config
State *state.State
Resolver *resolver.Resolver
PortCheck *portcheck.Checker
TLSCheck *tlscheck.Checker
Notify *notify.Service
Resolver DNSResolver
PortCheck PortChecker
TLSCheck TLSChecker
Notify Notifier
}
// Watcher orchestrates all monitoring checks on a schedule.
@ -34,19 +43,20 @@ type Watcher struct {
log *slog.Logger
config *config.Config
state *state.State
resolver *resolver.Resolver
portCheck *portcheck.Checker
tlsCheck *tlscheck.Checker
notify *notify.Service
resolver DNSResolver
portCheck PortChecker
tlsCheck TLSChecker
notify Notifier
cancel context.CancelFunc
firstRun bool
}
// New creates a new Watcher instance.
// New creates a new Watcher instance wired into the fx lifecycle.
func New(
lifecycle fx.Lifecycle,
params Params,
) (*Watcher, error) {
watcher := &Watcher{
w := &Watcher{
log: params.Logger.Get(),
config: params.Config,
state: params.State,
@ -54,30 +64,54 @@ func New(
portCheck: params.PortCheck,
tlsCheck: params.TLSCheck,
notify: params.Notify,
firstRun: true,
}
lifecycle.Append(fx.Hook{
OnStart: func(startCtx context.Context) error {
ctx, cancel := context.WithCancel(startCtx)
watcher.cancel = cancel
ctx, cancel := context.WithCancel(
context.WithoutCancel(startCtx),
)
w.cancel = cancel
go watcher.Run(ctx)
go w.Run(ctx)
return nil
},
OnStop: func(_ context.Context) error {
if watcher.cancel != nil {
watcher.cancel()
if w.cancel != nil {
w.cancel()
}
return nil
},
})
return watcher, nil
return w, nil
}
// Run starts the monitoring loop.
// NewForTest creates a Watcher without fx for unit testing.
func NewForTest(
cfg *config.Config,
st *state.State,
res DNSResolver,
pc PortChecker,
tc TLSChecker,
n Notifier,
) *Watcher {
return &Watcher{
log: slog.Default(),
config: cfg,
state: st,
resolver: res,
portCheck: pc,
tlsCheck: tc,
notify: n,
firstRun: true,
}
}
// Run starts the monitoring loop with periodic scheduling.
func (w *Watcher) Run(ctx context.Context) {
w.log.Info(
"watcher starting",
@ -87,8 +121,646 @@ func (w *Watcher) Run(ctx context.Context) {
"tlsInterval", w.config.TLSInterval,
)
// Stub: wait for context cancellation.
// Implementation will add initial check + periodic scheduling.
<-ctx.Done()
w.log.Info("watcher stopped")
w.RunOnce(ctx)
dnsTicker := time.NewTicker(w.config.DNSInterval)
tlsTicker := time.NewTicker(w.config.TLSInterval)
defer dnsTicker.Stop()
defer tlsTicker.Stop()
for {
select {
case <-ctx.Done():
w.log.Info("watcher stopped")
return
case <-dnsTicker.C:
w.runDNSAndPortChecks(ctx)
w.saveState()
case <-tlsTicker.C:
w.runTLSChecks(ctx)
w.saveState()
}
}
}
// RunOnce performs a single complete monitoring cycle.
func (w *Watcher) RunOnce(ctx context.Context) {
w.detectFirstRun()
w.runDNSAndPortChecks(ctx)
w.runTLSChecks(ctx)
w.saveState()
w.firstRun = false
}
func (w *Watcher) detectFirstRun() {
snap := w.state.GetSnapshot()
hasState := len(snap.Domains) > 0 ||
len(snap.Hostnames) > 0 ||
len(snap.Ports) > 0 ||
len(snap.Certificates) > 0
if hasState {
w.firstRun = false
}
}
func (w *Watcher) runDNSAndPortChecks(ctx context.Context) {
for _, domain := range w.config.Domains {
w.checkDomain(ctx, domain)
}
for _, hostname := range w.config.Hostnames {
w.checkHostname(ctx, hostname)
}
w.checkAllPorts(ctx)
}
func (w *Watcher) checkDomain(
ctx context.Context,
domain string,
) {
nameservers, err := w.resolver.LookupNS(ctx, domain)
if err != nil {
w.log.Error(
"failed to lookup NS",
"domain", domain,
"error", err,
)
return
}
sort.Strings(nameservers)
now := time.Now().UTC()
prev, hasPrev := w.state.GetDomainState(domain)
if hasPrev && !w.firstRun {
w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers)
}
w.state.SetDomainState(domain, &state.DomainState{
Nameservers: nameservers,
LastChecked: now,
})
}
func (w *Watcher) detectNSChanges(
ctx context.Context,
domain string,
oldNS, newNS []string,
) {
oldSet := toSet(oldNS)
newSet := toSet(newNS)
var added, removed []string
for ns := range newSet {
if !oldSet[ns] {
added = append(added, ns)
}
}
for ns := range oldSet {
if !newSet[ns] {
removed = append(removed, ns)
}
}
if len(added) == 0 && len(removed) == 0 {
return
}
msg := fmt.Sprintf(
"Domain: %s\nAdded: %s\nRemoved: %s",
domain,
strings.Join(added, ", "),
strings.Join(removed, ", "),
)
w.notify.SendNotification(
ctx,
"NS Change: "+domain,
msg,
"warning",
)
}
func (w *Watcher) checkHostname(
ctx context.Context,
hostname string,
) {
records, err := w.resolver.LookupAllRecords(ctx, hostname)
if err != nil {
w.log.Error(
"failed to lookup records",
"hostname", hostname,
"error", err,
)
return
}
now := time.Now().UTC()
prev, hasPrev := w.state.GetHostnameState(hostname)
if hasPrev && !w.firstRun {
w.detectHostnameChanges(ctx, hostname, prev, records)
}
newState := buildHostnameState(records, now)
w.state.SetHostnameState(hostname, newState)
}
func buildHostnameState(
records map[string]map[string][]string,
now time.Time,
) *state.HostnameState {
hs := &state.HostnameState{
RecordsByNameserver: make(
map[string]*state.NameserverRecordState,
),
LastChecked: now,
}
for ns, recs := range records {
hs.RecordsByNameserver[ns] = &state.NameserverRecordState{
Records: recs,
Status: "ok",
LastChecked: now,
}
}
return hs
}
func (w *Watcher) detectHostnameChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
w.detectRecordChanges(ctx, hostname, prev, current)
w.detectNSDisappearances(ctx, hostname, prev, current)
w.detectInconsistencies(ctx, hostname, current)
}
func (w *Watcher) detectRecordChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, recs := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok {
continue
}
if recordsEqual(prevNS.Records, recs) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s\n"+
"Old: %v\nNew: %v",
hostname, ns,
prevNS.Records, recs,
)
w.notify.SendNotification(
ctx,
"Record Change: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) detectNSDisappearances(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, prevNS := range prev.RecordsByNameserver {
if _, ok := current[ns]; ok || prevNS.Status != "ok" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s disappeared",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Failure: "+hostname,
msg,
"error",
)
}
for ns := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok || prevNS.Status != "error" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s recovered",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Recovery: "+hostname,
msg,
"success",
)
}
}
func (w *Watcher) detectInconsistencies(
ctx context.Context,
hostname string,
current map[string]map[string][]string,
) {
nameservers := make([]string, 0, len(current))
for ns := range current {
nameservers = append(nameservers, ns)
}
sort.Strings(nameservers)
for i := range len(nameservers) - 1 {
ns1 := nameservers[i]
ns2 := nameservers[i+1]
if recordsEqual(current[ns1], current[ns2]) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\n%s: %v\n%s: %v",
hostname,
ns1, current[ns1],
ns2, current[ns2],
)
w.notify.SendNotification(
ctx,
"Inconsistency: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) checkAllPorts(ctx context.Context) {
for _, hostname := range w.config.Hostnames {
w.checkPortsForHostname(ctx, hostname)
}
for _, domain := range w.config.Domains {
w.checkPortsForHostname(ctx, domain)
}
}
func (w *Watcher) checkPortsForHostname(
ctx context.Context,
hostname string,
) {
ips := w.collectIPs(hostname)
for _, ip := range ips {
for _, port := range monitoredPorts {
w.checkSinglePort(ctx, ip, port, hostname)
}
}
}
func (w *Watcher) collectIPs(hostname string) []string {
hs, ok := w.state.GetHostnameState(hostname)
if !ok {
return nil
}
ipSet := make(map[string]bool)
for _, nsState := range hs.RecordsByNameserver {
for _, ip := range nsState.Records["A"] {
ipSet[ip] = true
}
for _, ip := range nsState.Records["AAAA"] {
ipSet[ip] = true
}
}
result := make([]string, 0, len(ipSet))
for ip := range ipSet {
result = append(result, ip)
}
sort.Strings(result)
return result
}
func (w *Watcher) checkSinglePort(
ctx context.Context,
ip string,
port int,
hostname string,
) {
open, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil {
w.log.Error(
"port check failed",
"ip", ip,
"port", port,
"error", err,
)
return
}
key := fmt.Sprintf("%s:%d", ip, port)
now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != open {
stateStr := "closed"
if open {
stateStr = "open"
}
msg := fmt.Sprintf(
"Host: %s\nAddress: %s\nPort now %s",
hostname, key, stateStr,
)
w.notify.SendNotification(
ctx,
"Port Change: "+key,
msg,
"warning",
)
}
w.state.SetPortState(key, &state.PortState{
Open: open,
Hostname: hostname,
LastChecked: now,
})
}
func (w *Watcher) runTLSChecks(ctx context.Context) {
for _, hostname := range w.config.Hostnames {
w.checkTLSForHostname(ctx, hostname)
}
for _, domain := range w.config.Domains {
w.checkTLSForHostname(ctx, domain)
}
}
func (w *Watcher) checkTLSForHostname(
ctx context.Context,
hostname string,
) {
ips := w.collectIPs(hostname)
for _, ip := range ips {
portKey := fmt.Sprintf("%s:%d", ip, tlsPort)
ps, ok := w.state.GetPortState(portKey)
if !ok || !ps.Open {
continue
}
w.checkTLSCert(ctx, ip, hostname)
}
}
func (w *Watcher) checkTLSCert(
ctx context.Context,
ip string,
hostname string,
) {
cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname)
certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname)
now := time.Now().UTC()
prev, hasPrev := w.state.GetCertificateState(certKey)
if err != nil {
w.handleTLSError(
ctx, certKey, hostname, ip,
hasPrev, prev, now, err,
)
return
}
w.handleTLSSuccess(
ctx, certKey, hostname, ip,
hasPrev, prev, now, cert,
)
}
func (w *Watcher) handleTLSError(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
err error,
) {
if hasPrev && !w.firstRun && prev.Status == "ok" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nError: %s",
hostname, ip, err,
)
w.notify.SendNotification(
ctx,
"TLS Failure: "+hostname,
msg,
"error",
)
}
w.state.SetCertificateState(
certKey, &state.CertificateState{
Status: "error",
Error: err.Error(),
LastChecked: now,
},
)
}
func (w *Watcher) handleTLSSuccess(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
cert *tlscheck.CertificateInfo,
) {
if hasPrev && !w.firstRun {
w.detectTLSChanges(ctx, hostname, ip, prev, cert)
}
w.checkTLSExpiry(ctx, hostname, ip, cert)
w.state.SetCertificateState(
certKey, &state.CertificateState{
CommonName: cert.CommonName,
Issuer: cert.Issuer,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: cert.SubjectAlternativeNames,
Status: "ok",
LastChecked: now,
},
)
}
func (w *Watcher) detectTLSChanges(
ctx context.Context,
hostname, ip string,
prev *state.CertificateState,
cert *tlscheck.CertificateInfo,
) {
if prev.Status == "error" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nTLS recovered",
hostname, ip,
)
w.notify.SendNotification(
ctx,
"TLS Recovery: "+hostname,
msg,
"success",
)
return
}
changed := prev.CommonName != cert.CommonName ||
prev.Issuer != cert.Issuer ||
!sliceEqual(
prev.SubjectAlternativeNames,
cert.SubjectAlternativeNames,
)
if !changed {
return
}
msg := fmt.Sprintf(
"Host: %s\nIP: %s\n"+
"Old CN: %s, Issuer: %s\n"+
"New CN: %s, Issuer: %s",
hostname, ip,
prev.CommonName, prev.Issuer,
cert.CommonName, cert.Issuer,
)
w.notify.SendNotification(
ctx,
"TLS Certificate Change: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) checkTLSExpiry(
ctx context.Context,
hostname, ip string,
cert *tlscheck.CertificateInfo,
) {
daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay
warningDays := float64(w.config.TLSExpiryWarning)
if daysLeft > warningDays {
return
}
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nCN: %s\n"+
"Expires: %s (%.0f days)",
hostname, ip, cert.CommonName,
cert.NotAfter.Format(time.RFC3339),
daysLeft,
)
w.notify.SendNotification(
ctx,
"TLS Expiry Warning: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) saveState() {
err := w.state.Save()
if err != nil {
w.log.Error("failed to save state", "error", err)
}
}
// --- Utility functions ---
func toSet(items []string) map[string]bool {
set := make(map[string]bool, len(items))
for _, item := range items {
set[item] = true
}
return set
}
func recordsEqual(
a, b map[string][]string,
) bool {
if len(a) != len(b) {
return false
}
for k, av := range a {
bv, ok := b[k]
if !ok || !sliceEqual(av, bv) {
return false
}
}
return true
}
func sliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
aSorted := make([]string, len(a))
bSorted := make([]string, len(b))
copy(aSorted, a)
copy(bSorted, b)
sort.Strings(aSorted)
sort.Strings(bSorted)
for i := range aSorted {
if aSorted[i] != bSorted[i] {
return false
}
}
return true
}