7 Commits

Author SHA1 Message Date
d5738d6d43 Merge branch 'main' into feature/watcher-implementation 2026-02-20 09:06:27 +01:00
5e4631776a Merge pull request 'feat: unify DOMAINS/HOSTNAMES into single TARGETS config (closes #10)' (#11) from feature/unified-targets into main
Reviewed-on: #11
2026-02-20 09:04:59 +01:00
clawbot
f8d5a8f6cc fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:43:42 -08:00
clawbot
e09135d9d9 fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:42:50 -08:00
clawbot
73e01c7664 feat: unify DOMAINS/HOSTNAMES into single TARGETS config
Replace DNSWATCHER_DOMAINS and DNSWATCHER_HOSTNAMES with a single
DNSWATCHER_TARGETS env var. Names are automatically classified as apex
domains or hostnames using the Public Suffix List
(golang.org/x/net/publicsuffix).

- ClassifyDNSName() uses EffectiveTLDPlusOne to determine type
- Public suffixes themselves (e.g. co.uk) are rejected with an error
- Old DOMAINS/HOSTNAMES vars removed entirely (pre-1.0, no compat needed)
- README updated with pre-1.0 warning

Closes #10
2026-02-19 20:09:39 -08:00
clawbot
f676cc9458 feat: implement watcher monitoring orchestrator
Implements the full monitoring loop:
- Immediate checks on startup, then periodic DNS+port and TLS cycles
- Domain NS change detection with notifications
- Per-nameserver hostname record tracking with change/failure/recovery
  and inconsistency detection
- TCP port 80/443 monitoring with state change notifications
- TLS certificate monitoring with change, expiry, and failure detection
- State persistence after each cycle
- First run establishes baseline without notifications
- Graceful shutdown via context cancellation

Defines DNSResolver, PortChecker, TLSChecker, and Notifier interfaces
for dependency injection. Updates main.go fx wiring and resolver stub
signature to match per-NS record format.

Closes #2
2026-02-19 13:48:46 -08:00
clawbot
dea30028b1 test: add watcher orchestrator tests with mock dependencies
Tests cover: first-run baseline, NS change detection, record change
detection, port state changes, TLS expiry warnings, graceful shutdown,
and NS failure/recovery scenarios.
2026-02-19 13:48:38 -08:00
9 changed files with 1436 additions and 101 deletions

View File

@@ -1,5 +1,7 @@
# dnswatcher # dnswatcher
> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice.
dnswatcher is a production DNS and infrastructure monitoring daemon written in dnswatcher is a production DNS and infrastructure monitoring daemon written in
Go. It watches configured DNS domains and hostnames for changes, monitors TCP Go. It watches configured DNS domains and hostnames for changes, monitors TCP
port availability, tracks TLS certificate expiry, and delivers real-time port availability, tracks TLS certificate expiry, and delivers real-time
@@ -196,8 +198,6 @@ the following precedence (highest to lowest):
| `DNSWATCHER_DEBUG` | Enable debug logging | `false` | | `DNSWATCHER_DEBUG` | Enable debug logging | `false` |
| `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` | | `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` |
| `DNSWATCHER_TARGETS` | Comma-separated DNS names (auto-classified via PSL) | `""` | | `DNSWATCHER_TARGETS` | Comma-separated DNS names (auto-classified via PSL) | `""` |
| `DNSWATCHER_DOMAINS` | *(deprecated)* Comma-separated apex domains | `""` |
| `DNSWATCHER_HOSTNAMES` | *(deprecated)* Comma-separated hostnames | `""` |
| `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` | | `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` |
| `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` | | `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` |
| `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` | | `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` |

View File

@@ -51,6 +51,20 @@ func main() {
handlers.New, handlers.New,
server.New, server.New,
), ),
fx.Provide(
func(r *resolver.Resolver) watcher.DNSResolver {
return r
},
func(p *portcheck.Checker) watcher.PortChecker {
return p
},
func(t *tlscheck.Checker) watcher.TLSChecker {
return t
},
func(n *notify.Service) watcher.Notifier {
return n
},
),
fx.Invoke(func(*server.Server, *watcher.Watcher) {}), fx.Invoke(func(*server.Server, *watcher.Watcher) {}),
).Run() ).Run()
} }

View File

@@ -90,8 +90,6 @@ func setupViper(name string) {
viper.SetDefault("DEBUG", false) viper.SetDefault("DEBUG", false)
viper.SetDefault("DATA_DIR", "./data") viper.SetDefault("DATA_DIR", "./data")
viper.SetDefault("TARGETS", "") viper.SetDefault("TARGETS", "")
viper.SetDefault("DOMAINS", "")
viper.SetDefault("HOSTNAMES", "")
viper.SetDefault("SLACK_WEBHOOK", "") viper.SetDefault("SLACK_WEBHOOK", "")
viper.SetDefault("MATTERMOST_WEBHOOK", "") viper.SetDefault("MATTERMOST_WEBHOOK", "")
viper.SetDefault("NTFY_TOPIC", "") viper.SetDefault("NTFY_TOPIC", "")
@@ -134,7 +132,9 @@ func buildConfig(
tlsInterval = defaultTLSInterval tlsInterval = defaultTLSInterval
} }
domains, hostnames, err := resolveTargets(log) domains, hostnames, err := ClassifyTargets(
parseCSV(viper.GetString("TARGETS")),
)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid targets configuration: %w", err) return nil, fmt.Errorf("invalid targets configuration: %w", err)
} }
@@ -187,56 +187,6 @@ func configureDebugLogging(cfg *Config, params Params) {
} }
} }
// resolveTargets merges DNSWATCHER_TARGETS with the deprecated
// DNSWATCHER_DOMAINS and DNSWATCHER_HOSTNAMES variables. When TARGETS
// is set, names are automatically classified using the Public Suffix
// List. Legacy variables are merged in and a deprecation warning is
// logged when they are used.
func resolveTargets(log *slog.Logger) ([]string, []string, error) {
targets := parseCSV(viper.GetString("TARGETS"))
legacyDomains := parseCSV(viper.GetString("DOMAINS"))
legacyHostnames := parseCSV(viper.GetString("HOSTNAMES"))
if len(legacyDomains) > 0 || len(legacyHostnames) > 0 {
log.Warn(
"DNSWATCHER_DOMAINS and DNSWATCHER_HOSTNAMES are deprecated; use DNSWATCHER_TARGETS instead",
)
}
var domains, hostnames []string
if len(targets) > 0 {
var err error
domains, hostnames, err = ClassifyTargets(targets)
if err != nil {
return nil, nil, err
}
}
domains = mergeUnique(domains, legacyDomains)
hostnames = mergeUnique(hostnames, legacyHostnames)
return domains, hostnames, nil
}
// mergeUnique appends items from b into a, skipping duplicates.
func mergeUnique(a, b []string) []string {
seen := make(map[string]bool, len(a))
for _, v := range a {
seen[v] = true
}
for _, v := range b {
if !seen[v] {
a = append(a, v)
seen[v] = true
}
}
return a
}
// StatePath returns the full path to the state JSON file. // StatePath returns the full path to the state JSON file.
func (c *Config) StatePath() string { func (c *Config) StatePath() string {
return c.DataDir + "/state.json" return c.DataDir + "/state.json"

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@@ -48,6 +49,9 @@ type Service struct {
log *slog.Logger log *slog.Logger
client *http.Client client *http.Client
config *config.Config config *config.Config
ntfyURL *url.URL
slackWebhookURL *url.URL
mattermostWebhookURL *url.URL
} }
// New creates a new notify Service. // New creates a new notify Service.
@@ -55,13 +59,44 @@ func New(
_ fx.Lifecycle, _ fx.Lifecycle,
params Params, params Params,
) (*Service, error) { ) (*Service, error) {
return &Service{ svc := &Service{
log: params.Logger.Get(), log: params.Logger.Get(),
client: &http.Client{ client: &http.Client{
Timeout: httpClientTimeout, Timeout: httpClientTimeout,
}, },
config: params.Config, config: params.Config,
}, nil }
if params.Config.NtfyTopic != "" {
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
if err != nil {
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
}
svc.ntfyURL = u
}
if params.Config.SlackWebhook != "" {
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
if err != nil {
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
}
svc.slackWebhookURL = u
}
if params.Config.MattermostWebhook != "" {
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
if err != nil {
return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err,
)
}
svc.mattermostWebhookURL = u
}
return svc, nil
} }
// SendNotification sends a notification to all configured endpoints. // SendNotification sends a notification to all configured endpoints.
@@ -69,13 +104,13 @@ func (svc *Service) SendNotification(
ctx context.Context, ctx context.Context,
title, message, priority string, title, message, priority string,
) { ) {
if svc.config.NtfyTopic != "" { if svc.ntfyURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendNtfy( err := svc.sendNtfy(
notifyCtx, notifyCtx,
svc.config.NtfyTopic, svc.ntfyURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -87,13 +122,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.SlackWebhook != "" { if svc.slackWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.SlackWebhook, svc.slackWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -105,13 +140,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.MattermostWebhook != "" { if svc.mattermostWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.MattermostWebhook, svc.mattermostWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -126,18 +161,19 @@ func (svc *Service) SendNotification(
func (svc *Service) sendNtfy( func (svc *Service) sendNtfy(
ctx context.Context, ctx context.Context,
topic, title, message, priority string, topicURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending ntfy notification", "sending ntfy notification",
"topic", topic, "topic", topicURL.String(),
"title", title, "title", title,
) )
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
topic, topicURL.String(),
bytes.NewBufferString(message), bytes.NewBufferString(message),
) )
if err != nil { if err != nil {
@@ -147,7 +183,7 @@ func (svc *Service) sendNtfy(
request.Header.Set("Title", title) request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority)) request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending ntfy request: %w", err) return fmt.Errorf("sending ntfy request: %w", err)
} }
@@ -193,11 +229,12 @@ type SlackAttachment struct {
func (svc *Service) sendSlack( func (svc *Service) sendSlack(
ctx context.Context, ctx context.Context,
webhookURL, title, message, priority string, webhookURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending webhook notification", "sending webhook notification",
"url", webhookURL, "url", webhookURL.String(),
"title", title, "title", title,
) )
@@ -219,7 +256,7 @@ func (svc *Service) sendSlack(
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
webhookURL, webhookURL.String(),
bytes.NewBuffer(body), bytes.NewBuffer(body),
) )
if err != nil { if err != nil {
@@ -228,7 +265,7 @@ func (svc *Service) sendSlack(
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending webhook request: %w", err) return fmt.Errorf("sending webhook request: %w", err)
} }

View File

@@ -46,11 +46,11 @@ func (r *Resolver) LookupNS(
} }
// LookupAllRecords performs iterative resolution to find all DNS // LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname. // records for the given hostname, keyed by authoritative nameserver.
func (r *Resolver) LookupAllRecords( func (r *Resolver) LookupAllRecords(
_ context.Context, _ context.Context,
_ string, _ string,
) (map[string][]string, error) { ) (map[string]map[string][]string, error) {
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }

View File

@@ -0,0 +1,22 @@
package state
import (
"log/slog"
"sneak.berlin/go/dnswatcher/internal/config"
)
// NewForTest creates a State for unit testing with no persistence.
func NewForTest() *State {
return &State{
log: slog.Default(),
snapshot: &Snapshot{
Version: stateVersion,
Domains: make(map[string]*DomainState),
Hostnames: make(map[string]*HostnameState),
Ports: make(map[string]*PortState),
Certificates: make(map[string]*CertificateState),
},
config: &config.Config{DataDir: ""},
}
}

View File

@@ -0,0 +1,60 @@
// Package watcher provides the main monitoring orchestrator.
package watcher
import (
"context"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
// DNSResolver performs iterative DNS resolution.
type DNSResolver interface {
// LookupNS discovers authoritative nameservers for a domain.
LookupNS(
ctx context.Context,
domain string,
) ([]string, error)
// LookupAllRecords queries all record types for a hostname,
// returning results keyed by nameserver then record type.
LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error)
// ResolveIPAddresses resolves a hostname to all IP addresses.
ResolveIPAddresses(
ctx context.Context,
hostname string,
) ([]string, error)
}
// PortChecker tests TCP port connectivity.
type PortChecker interface {
// CheckPort tests TCP connectivity to an address and port.
CheckPort(
ctx context.Context,
address string,
port int,
) (bool, error)
}
// TLSChecker inspects TLS certificates.
type TLSChecker interface {
// CheckCertificate connects via TLS and returns cert info.
CheckCertificate(
ctx context.Context,
ip string,
hostname string,
) (*tlscheck.CertificateInfo, error)
}
// Notifier delivers notifications to configured endpoints.
type Notifier interface {
// SendNotification sends a notification with the given
// details.
SendNotification(
ctx context.Context,
title, message, priority string,
)
}

View File

@@ -1,21 +1,30 @@
// Package watcher provides the main monitoring orchestrator and scheduler.
package watcher package watcher
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"sort"
"strings"
"time"
"go.uber.org/fx" "go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
"sneak.berlin/go/dnswatcher/internal/notify"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/resolver"
"sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/tlscheck"
) )
// monitoredPorts are the TCP ports checked for each IP address.
var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals
// tlsPort is the port used for TLS certificate checks.
const tlsPort = 443
// hoursPerDay converts days to hours for duration calculations.
const hoursPerDay = 24
// Params contains dependencies for Watcher. // Params contains dependencies for Watcher.
type Params struct { type Params struct {
fx.In fx.In
@@ -23,10 +32,10 @@ type Params struct {
Logger *logger.Logger Logger *logger.Logger
Config *config.Config Config *config.Config
State *state.State State *state.State
Resolver *resolver.Resolver Resolver DNSResolver
PortCheck *portcheck.Checker PortCheck PortChecker
TLSCheck *tlscheck.Checker TLSCheck TLSChecker
Notify *notify.Service Notify Notifier
} }
// Watcher orchestrates all monitoring checks on a schedule. // Watcher orchestrates all monitoring checks on a schedule.
@@ -34,19 +43,20 @@ type Watcher struct {
log *slog.Logger log *slog.Logger
config *config.Config config *config.Config
state *state.State state *state.State
resolver *resolver.Resolver resolver DNSResolver
portCheck *portcheck.Checker portCheck PortChecker
tlsCheck *tlscheck.Checker tlsCheck TLSChecker
notify *notify.Service notify Notifier
cancel context.CancelFunc cancel context.CancelFunc
firstRun bool
} }
// New creates a new Watcher instance. // New creates a new Watcher instance wired into the fx lifecycle.
func New( func New(
lifecycle fx.Lifecycle, lifecycle fx.Lifecycle,
params Params, params Params,
) (*Watcher, error) { ) (*Watcher, error) {
watcher := &Watcher{ w := &Watcher{
log: params.Logger.Get(), log: params.Logger.Get(),
config: params.Config, config: params.Config,
state: params.State, state: params.State,
@@ -54,30 +64,54 @@ func New(
portCheck: params.PortCheck, portCheck: params.PortCheck,
tlsCheck: params.TLSCheck, tlsCheck: params.TLSCheck,
notify: params.Notify, notify: params.Notify,
firstRun: true,
} }
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{
OnStart: func(startCtx context.Context) error { OnStart: func(startCtx context.Context) error {
ctx, cancel := context.WithCancel(startCtx) ctx, cancel := context.WithCancel(
watcher.cancel = cancel context.WithoutCancel(startCtx),
)
w.cancel = cancel
go watcher.Run(ctx) go w.Run(ctx)
return nil return nil
}, },
OnStop: func(_ context.Context) error { OnStop: func(_ context.Context) error {
if watcher.cancel != nil { if w.cancel != nil {
watcher.cancel() w.cancel()
} }
return nil return nil
}, },
}) })
return watcher, nil return w, nil
} }
// Run starts the monitoring loop. // NewForTest creates a Watcher without fx for unit testing.
func NewForTest(
cfg *config.Config,
st *state.State,
res DNSResolver,
pc PortChecker,
tc TLSChecker,
n Notifier,
) *Watcher {
return &Watcher{
log: slog.Default(),
config: cfg,
state: st,
resolver: res,
portCheck: pc,
tlsCheck: tc,
notify: n,
firstRun: true,
}
}
// Run starts the monitoring loop with periodic scheduling.
func (w *Watcher) Run(ctx context.Context) { func (w *Watcher) Run(ctx context.Context) {
w.log.Info( w.log.Info(
"watcher starting", "watcher starting",
@@ -87,8 +121,646 @@ func (w *Watcher) Run(ctx context.Context) {
"tlsInterval", w.config.TLSInterval, "tlsInterval", w.config.TLSInterval,
) )
// Stub: wait for context cancellation. w.RunOnce(ctx)
// Implementation will add initial check + periodic scheduling.
<-ctx.Done() dnsTicker := time.NewTicker(w.config.DNSInterval)
tlsTicker := time.NewTicker(w.config.TLSInterval)
defer dnsTicker.Stop()
defer tlsTicker.Stop()
for {
select {
case <-ctx.Done():
w.log.Info("watcher stopped") w.log.Info("watcher stopped")
return
case <-dnsTicker.C:
w.runDNSAndPortChecks(ctx)
w.saveState()
case <-tlsTicker.C:
w.runTLSChecks(ctx)
w.saveState()
}
}
}
// RunOnce performs a single complete monitoring cycle.
func (w *Watcher) RunOnce(ctx context.Context) {
w.detectFirstRun()
w.runDNSAndPortChecks(ctx)
w.runTLSChecks(ctx)
w.saveState()
w.firstRun = false
}
func (w *Watcher) detectFirstRun() {
snap := w.state.GetSnapshot()
hasState := len(snap.Domains) > 0 ||
len(snap.Hostnames) > 0 ||
len(snap.Ports) > 0 ||
len(snap.Certificates) > 0
if hasState {
w.firstRun = false
}
}
func (w *Watcher) runDNSAndPortChecks(ctx context.Context) {
for _, domain := range w.config.Domains {
w.checkDomain(ctx, domain)
}
for _, hostname := range w.config.Hostnames {
w.checkHostname(ctx, hostname)
}
w.checkAllPorts(ctx)
}
func (w *Watcher) checkDomain(
ctx context.Context,
domain string,
) {
nameservers, err := w.resolver.LookupNS(ctx, domain)
if err != nil {
w.log.Error(
"failed to lookup NS",
"domain", domain,
"error", err,
)
return
}
sort.Strings(nameservers)
now := time.Now().UTC()
prev, hasPrev := w.state.GetDomainState(domain)
if hasPrev && !w.firstRun {
w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers)
}
w.state.SetDomainState(domain, &state.DomainState{
Nameservers: nameservers,
LastChecked: now,
})
}
func (w *Watcher) detectNSChanges(
ctx context.Context,
domain string,
oldNS, newNS []string,
) {
oldSet := toSet(oldNS)
newSet := toSet(newNS)
var added, removed []string
for ns := range newSet {
if !oldSet[ns] {
added = append(added, ns)
}
}
for ns := range oldSet {
if !newSet[ns] {
removed = append(removed, ns)
}
}
if len(added) == 0 && len(removed) == 0 {
return
}
msg := fmt.Sprintf(
"Domain: %s\nAdded: %s\nRemoved: %s",
domain,
strings.Join(added, ", "),
strings.Join(removed, ", "),
)
w.notify.SendNotification(
ctx,
"NS Change: "+domain,
msg,
"warning",
)
}
func (w *Watcher) checkHostname(
ctx context.Context,
hostname string,
) {
records, err := w.resolver.LookupAllRecords(ctx, hostname)
if err != nil {
w.log.Error(
"failed to lookup records",
"hostname", hostname,
"error", err,
)
return
}
now := time.Now().UTC()
prev, hasPrev := w.state.GetHostnameState(hostname)
if hasPrev && !w.firstRun {
w.detectHostnameChanges(ctx, hostname, prev, records)
}
newState := buildHostnameState(records, now)
w.state.SetHostnameState(hostname, newState)
}
func buildHostnameState(
records map[string]map[string][]string,
now time.Time,
) *state.HostnameState {
hs := &state.HostnameState{
RecordsByNameserver: make(
map[string]*state.NameserverRecordState,
),
LastChecked: now,
}
for ns, recs := range records {
hs.RecordsByNameserver[ns] = &state.NameserverRecordState{
Records: recs,
Status: "ok",
LastChecked: now,
}
}
return hs
}
func (w *Watcher) detectHostnameChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
w.detectRecordChanges(ctx, hostname, prev, current)
w.detectNSDisappearances(ctx, hostname, prev, current)
w.detectInconsistencies(ctx, hostname, current)
}
func (w *Watcher) detectRecordChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, recs := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok {
continue
}
if recordsEqual(prevNS.Records, recs) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s\n"+
"Old: %v\nNew: %v",
hostname, ns,
prevNS.Records, recs,
)
w.notify.SendNotification(
ctx,
"Record Change: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) detectNSDisappearances(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, prevNS := range prev.RecordsByNameserver {
if _, ok := current[ns]; ok || prevNS.Status != "ok" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s disappeared",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Failure: "+hostname,
msg,
"error",
)
}
for ns := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok || prevNS.Status != "error" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s recovered",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Recovery: "+hostname,
msg,
"success",
)
}
}
func (w *Watcher) detectInconsistencies(
ctx context.Context,
hostname string,
current map[string]map[string][]string,
) {
nameservers := make([]string, 0, len(current))
for ns := range current {
nameservers = append(nameservers, ns)
}
sort.Strings(nameservers)
for i := range len(nameservers) - 1 {
ns1 := nameservers[i]
ns2 := nameservers[i+1]
if recordsEqual(current[ns1], current[ns2]) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\n%s: %v\n%s: %v",
hostname,
ns1, current[ns1],
ns2, current[ns2],
)
w.notify.SendNotification(
ctx,
"Inconsistency: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) checkAllPorts(ctx context.Context) {
for _, hostname := range w.config.Hostnames {
w.checkPortsForHostname(ctx, hostname)
}
for _, domain := range w.config.Domains {
w.checkPortsForHostname(ctx, domain)
}
}
func (w *Watcher) checkPortsForHostname(
ctx context.Context,
hostname string,
) {
ips := w.collectIPs(hostname)
for _, ip := range ips {
for _, port := range monitoredPorts {
w.checkSinglePort(ctx, ip, port, hostname)
}
}
}
func (w *Watcher) collectIPs(hostname string) []string {
hs, ok := w.state.GetHostnameState(hostname)
if !ok {
return nil
}
ipSet := make(map[string]bool)
for _, nsState := range hs.RecordsByNameserver {
for _, ip := range nsState.Records["A"] {
ipSet[ip] = true
}
for _, ip := range nsState.Records["AAAA"] {
ipSet[ip] = true
}
}
result := make([]string, 0, len(ipSet))
for ip := range ipSet {
result = append(result, ip)
}
sort.Strings(result)
return result
}
func (w *Watcher) checkSinglePort(
ctx context.Context,
ip string,
port int,
hostname string,
) {
open, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil {
w.log.Error(
"port check failed",
"ip", ip,
"port", port,
"error", err,
)
return
}
key := fmt.Sprintf("%s:%d", ip, port)
now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != open {
stateStr := "closed"
if open {
stateStr = "open"
}
msg := fmt.Sprintf(
"Host: %s\nAddress: %s\nPort now %s",
hostname, key, stateStr,
)
w.notify.SendNotification(
ctx,
"Port Change: "+key,
msg,
"warning",
)
}
w.state.SetPortState(key, &state.PortState{
Open: open,
Hostname: hostname,
LastChecked: now,
})
}
func (w *Watcher) runTLSChecks(ctx context.Context) {
for _, hostname := range w.config.Hostnames {
w.checkTLSForHostname(ctx, hostname)
}
for _, domain := range w.config.Domains {
w.checkTLSForHostname(ctx, domain)
}
}
func (w *Watcher) checkTLSForHostname(
ctx context.Context,
hostname string,
) {
ips := w.collectIPs(hostname)
for _, ip := range ips {
portKey := fmt.Sprintf("%s:%d", ip, tlsPort)
ps, ok := w.state.GetPortState(portKey)
if !ok || !ps.Open {
continue
}
w.checkTLSCert(ctx, ip, hostname)
}
}
func (w *Watcher) checkTLSCert(
ctx context.Context,
ip string,
hostname string,
) {
cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname)
certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname)
now := time.Now().UTC()
prev, hasPrev := w.state.GetCertificateState(certKey)
if err != nil {
w.handleTLSError(
ctx, certKey, hostname, ip,
hasPrev, prev, now, err,
)
return
}
w.handleTLSSuccess(
ctx, certKey, hostname, ip,
hasPrev, prev, now, cert,
)
}
func (w *Watcher) handleTLSError(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
err error,
) {
if hasPrev && !w.firstRun && prev.Status == "ok" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nError: %s",
hostname, ip, err,
)
w.notify.SendNotification(
ctx,
"TLS Failure: "+hostname,
msg,
"error",
)
}
w.state.SetCertificateState(
certKey, &state.CertificateState{
Status: "error",
Error: err.Error(),
LastChecked: now,
},
)
}
func (w *Watcher) handleTLSSuccess(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
cert *tlscheck.CertificateInfo,
) {
if hasPrev && !w.firstRun {
w.detectTLSChanges(ctx, hostname, ip, prev, cert)
}
w.checkTLSExpiry(ctx, hostname, ip, cert)
w.state.SetCertificateState(
certKey, &state.CertificateState{
CommonName: cert.CommonName,
Issuer: cert.Issuer,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: cert.SubjectAlternativeNames,
Status: "ok",
LastChecked: now,
},
)
}
func (w *Watcher) detectTLSChanges(
ctx context.Context,
hostname, ip string,
prev *state.CertificateState,
cert *tlscheck.CertificateInfo,
) {
if prev.Status == "error" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nTLS recovered",
hostname, ip,
)
w.notify.SendNotification(
ctx,
"TLS Recovery: "+hostname,
msg,
"success",
)
return
}
changed := prev.CommonName != cert.CommonName ||
prev.Issuer != cert.Issuer ||
!sliceEqual(
prev.SubjectAlternativeNames,
cert.SubjectAlternativeNames,
)
if !changed {
return
}
msg := fmt.Sprintf(
"Host: %s\nIP: %s\n"+
"Old CN: %s, Issuer: %s\n"+
"New CN: %s, Issuer: %s",
hostname, ip,
prev.CommonName, prev.Issuer,
cert.CommonName, cert.Issuer,
)
w.notify.SendNotification(
ctx,
"TLS Certificate Change: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) checkTLSExpiry(
ctx context.Context,
hostname, ip string,
cert *tlscheck.CertificateInfo,
) {
daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay
warningDays := float64(w.config.TLSExpiryWarning)
if daysLeft > warningDays {
return
}
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nCN: %s\n"+
"Expires: %s (%.0f days)",
hostname, ip, cert.CommonName,
cert.NotAfter.Format(time.RFC3339),
daysLeft,
)
w.notify.SendNotification(
ctx,
"TLS Expiry Warning: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) saveState() {
err := w.state.Save()
if err != nil {
w.log.Error("failed to save state", "error", err)
}
}
// --- Utility functions ---
func toSet(items []string) map[string]bool {
set := make(map[string]bool, len(items))
for _, item := range items {
set[item] = true
}
return set
}
func recordsEqual(
a, b map[string][]string,
) bool {
if len(a) != len(b) {
return false
}
for k, av := range a {
bv, ok := b[k]
if !ok || !sliceEqual(av, bv) {
return false
}
}
return true
}
func sliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
aSorted := make([]string, len(a))
bSorted := make([]string, len(b))
copy(aSorted, a)
copy(bSorted, b)
sort.Strings(aSorted)
sort.Strings(bSorted)
for i := range aSorted {
if aSorted[i] != bSorted[i] {
return false
}
}
return true
} }

View File

@@ -0,0 +1,580 @@
package watcher_test
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
"sneak.berlin/go/dnswatcher/internal/watcher"
)
// errNotFound is returned when mock data is missing.
var errNotFound = errors.New("not found")
// --- Mock implementations ---
type mockResolver struct {
mu sync.Mutex
nsRecords map[string][]string
allRecords map[string]map[string]map[string][]string
ipAddresses map[string][]string
lookupNSErr error
allRecordsErr error
resolveIPErr error
lookupNSCalls int
allRecordCalls int
}
func (m *mockResolver) LookupNS(
_ context.Context,
domain string,
) ([]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.lookupNSCalls++
if m.lookupNSErr != nil {
return nil, m.lookupNSErr
}
ns, ok := m.nsRecords[domain]
if !ok {
return nil, fmt.Errorf(
"%w: NS for %s", errNotFound, domain,
)
}
return ns, nil
}
func (m *mockResolver) LookupAllRecords(
_ context.Context,
hostname string,
) (map[string]map[string][]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.allRecordCalls++
if m.allRecordsErr != nil {
return nil, m.allRecordsErr
}
recs, ok := m.allRecords[hostname]
if !ok {
return nil, fmt.Errorf(
"%w: records for %s", errNotFound, hostname,
)
}
return recs, nil
}
func (m *mockResolver) ResolveIPAddresses(
_ context.Context,
hostname string,
) ([]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.resolveIPErr != nil {
return nil, m.resolveIPErr
}
ips, ok := m.ipAddresses[hostname]
if !ok {
return nil, fmt.Errorf(
"%w: IPs for %s", errNotFound, hostname,
)
}
return ips, nil
}
type mockPortChecker struct {
mu sync.Mutex
results map[string]bool
err error
calls int
}
func (m *mockPortChecker) CheckPort(
_ context.Context,
address string,
port int,
) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return false, m.err
}
key := fmt.Sprintf("%s:%d", address, port)
open, ok := m.results[key]
if !ok {
return false, nil
}
return open, nil
}
type mockTLSChecker struct {
mu sync.Mutex
certs map[string]*tlscheck.CertificateInfo
err error
calls int
}
func (m *mockTLSChecker) CheckCertificate(
_ context.Context,
ip string,
hostname string,
) (*tlscheck.CertificateInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return nil, m.err
}
key := fmt.Sprintf("%s:%s", ip, hostname)
cert, ok := m.certs[key]
if !ok {
return nil, fmt.Errorf(
"%w: cert for %s", errNotFound, key,
)
}
return cert, nil
}
type notification struct {
Title string
Message string
Priority string
}
type mockNotifier struct {
mu sync.Mutex
notifications []notification
}
func (m *mockNotifier) SendNotification(
_ context.Context,
title, message, priority string,
) {
m.mu.Lock()
defer m.mu.Unlock()
m.notifications = append(m.notifications, notification{
Title: title,
Message: message,
Priority: priority,
})
}
func (m *mockNotifier) getNotifications() []notification {
m.mu.Lock()
defer m.mu.Unlock()
result := make([]notification, len(m.notifications))
copy(result, m.notifications)
return result
}
// --- Helper to build a Watcher for testing ---
type testDeps struct {
resolver *mockResolver
portChecker *mockPortChecker
tlsChecker *mockTLSChecker
notifier *mockNotifier
state *state.State
config *config.Config
}
func newTestWatcher(
t *testing.T,
cfg *config.Config,
) (*watcher.Watcher, *testDeps) {
t.Helper()
deps := &testDeps{
resolver: &mockResolver{
nsRecords: make(map[string][]string),
allRecords: make(map[string]map[string]map[string][]string),
ipAddresses: make(map[string][]string),
},
portChecker: &mockPortChecker{
results: make(map[string]bool),
},
tlsChecker: &mockTLSChecker{
certs: make(map[string]*tlscheck.CertificateInfo),
},
notifier: &mockNotifier{},
config: cfg,
}
deps.state = state.NewForTest()
w := watcher.NewForTest(
deps.config,
deps.state,
deps.resolver,
deps.portChecker,
deps.tlsChecker,
deps.notifier,
)
return w, deps
}
func defaultTestConfig(t *testing.T) *config.Config {
t.Helper()
return &config.Config{
DNSInterval: time.Hour,
TLSInterval: 12 * time.Hour,
TLSExpiryWarning: 7,
DataDir: t.TempDir(),
}
}
func TestFirstRunBaseline(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
setupBaselineMocks(deps)
w.RunOnce(t.Context())
assertNoNotifications(t, deps)
assertStatePopulated(t, deps)
}
func setupBaselineMocks(deps *testDeps) {
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns2.example.com.",
}
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
"ns2.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.34",
}
deps.portChecker.results["93.184.216.34:80"] = true
deps.portChecker.results["93.184.216.34:443"] = true
deps.tlsChecker.certs["93.184.216.34:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
}
func assertNoNotifications(
t *testing.T,
deps *testDeps,
) {
t.Helper()
notifications := deps.notifier.getNotifications()
if len(notifications) != 0 {
t.Errorf(
"expected 0 notifications on first run, got %d",
len(notifications),
)
}
}
func assertStatePopulated(
t *testing.T,
deps *testDeps,
) {
t.Helper()
snap := deps.state.GetSnapshot()
if len(snap.Domains) != 1 {
t.Errorf(
"expected 1 domain in state, got %d",
len(snap.Domains),
)
}
if len(snap.Hostnames) != 1 {
t.Errorf(
"expected 1 hostname in state, got %d",
len(snap.Hostnames),
)
}
}
func TestNSChangeDetection(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns2.example.com.",
}
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns3.example.com.",
}
deps.resolver.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for NS change")
}
found := false
for _, n := range notifications {
if n.Priority == "warning" {
found = true
}
}
if !found {
t.Error("expected warning-priority NS change notification")
}
}
func TestRecordChangeDetection(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.34",
}
deps.portChecker.results["93.184.216.34:80"] = false
deps.portChecker.results["93.184.216.34:443"] = false
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.35"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.35",
}
deps.resolver.mu.Unlock()
deps.portChecker.mu.Lock()
deps.portChecker.results["93.184.216.35:80"] = false
deps.portChecker.results["93.184.216.35:443"] = false
deps.portChecker.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for record change")
}
}
func TestPortStateChange(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
w.RunOnce(ctx)
deps.portChecker.mu.Lock()
deps.portChecker.results["1.2.3.4:443"] = false
deps.portChecker.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for port state change")
}
}
func TestTLSExpiryWarning(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(3 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
// First run = baseline
w.RunOnce(ctx)
// Second run should warn about expiry
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
found := false
for _, n := range notifications {
if n.Priority == "warning" {
found = true
}
}
if !found {
t.Errorf(
"expected expiry warning, got: %v",
notifications,
)
}
}
func TestGracefulShutdown(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.DNSInterval = 100 * time.Millisecond
cfg.TLSInterval = 100 * time.Millisecond
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
}
ctx, cancel := context.WithCancel(t.Context())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
time.Sleep(250 * time.Millisecond)
cancel()
select {
case <-done:
// Shut down cleanly
case <-time.After(5 * time.Second):
t.Error("watcher did not shut down within timeout")
}
}
func TestNSFailureAndRecovery(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
"ns2.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = false
deps.portChecker.results["1.2.3.4:443"] = false
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for NS disappearance")
}
}