diff --git a/README.md b/README.md index 25fa039..0a9d555 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # dnswatcher +> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice. + dnswatcher is a production DNS and infrastructure monitoring daemon written in Go. It watches configured DNS domains and hostnames for changes, monitors TCP port availability, tracks TLS certificate expiry, and delivers real-time @@ -195,8 +197,7 @@ the following precedence (highest to lowest): | `PORT` | HTTP listen port | `8080` | | `DNSWATCHER_DEBUG` | Enable debug logging | `false` | | `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` | -| `DNSWATCHER_DOMAINS` | Comma-separated list of apex domains | `""` | -| `DNSWATCHER_HOSTNAMES` | Comma-separated list of hostnames | `""` | +| `DNSWATCHER_TARGETS` | Comma-separated DNS names, each auto-classified as an apex domain or hostname via the Public Suffix List | `""` | | `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` | | `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` | | `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` | @@ -214,8 +215,7 @@ the following precedence (highest to lowest): PORT=8080 DNSWATCHER_DEBUG=false DNSWATCHER_DATA_DIR=./data -DNSWATCHER_DOMAINS=example.com,example.org -DNSWATCHER_HOSTNAMES=www.example.com,api.example.com,mail.example.org +DNSWATCHER_TARGETS=example.com,example.org,www.example.com,api.example.com,mail.example.org DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts @@ -352,8 +352,7 @@ docker build -t dnswatcher . 
docker run -d \ -p 8080:8080 \ -v dnswatcher-data:/var/lib/dnswatcher \ - -e DNSWATCHER_DOMAINS=example.com \ - -e DNSWATCHER_HOSTNAMES=www.example.com \ + -e DNSWATCHER_TARGETS=example.com,www.example.com \ -e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \ dnswatcher ``` diff --git a/cmd/dnswatcher/main.go b/cmd/dnswatcher/main.go index 0dd77b9..03b38da 100644 --- a/cmd/dnswatcher/main.go +++ b/cmd/dnswatcher/main.go @@ -51,6 +51,20 @@ func main() { handlers.New, server.New, ), + fx.Provide( + func(r *resolver.Resolver) watcher.DNSResolver { + return r + }, + func(p *portcheck.Checker) watcher.PortChecker { + return p + }, + func(t *tlscheck.Checker) watcher.TLSChecker { + return t + }, + func(n *notify.Service) watcher.Notifier { + return n + }, + ), fx.Invoke(func(*server.Server, *watcher.Watcher) {}), ).Run() } diff --git a/go.mod b/go.mod index 777aff0..58794b3 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 go.uber.org/fx v1.24.0 + golang.org/x/net v0.50.0 ) require ( @@ -37,12 +38,11 @@ require ( go.uber.org/zap v1.26.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/mod v0.31.0 // indirect - golang.org/x/net v0.48.0 // indirect + golang.org/x/mod v0.32.0 // indirect golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect - golang.org/x/tools v0.40.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect + golang.org/x/tools v0.41.0 // indirect google.golang.org/protobuf v1.36.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0fd1a03..720b18f 100644 --- a/go.sum +++ b/go.sum @@ -76,18 +76,18 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 
h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod 
h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/config/classify.go b/internal/config/classify.go new file mode 100644 index 0000000..1076215 --- /dev/null +++ b/internal/config/classify.go @@ -0,0 +1,85 @@ +package config + +import ( + "errors" + "fmt" + "strings" + + "golang.org/x/net/publicsuffix" +) + +// DNSNameType indicates whether a DNS name is an apex domain or a hostname. +type DNSNameType int + +const ( + // DNSNameTypeDomain indicates the name is an apex (eTLD+1) domain. + DNSNameTypeDomain DNSNameType = iota + // DNSNameTypeHostname indicates the name is a subdomain/hostname. + DNSNameTypeHostname +) + +// ErrEmptyDNSName is returned when an empty string is passed to ClassifyDNSName. +var ErrEmptyDNSName = errors.New("empty DNS name") + +// String returns the string representation of a DNSNameType. +func (t DNSNameType) String() string { + switch t { + case DNSNameTypeDomain: + return "domain" + case DNSNameTypeHostname: + return "hostname" + default: + return "unknown" + } +} + +// ClassifyDNSName determines whether a DNS name is an apex domain or a +// hostname (subdomain) using the Public Suffix List. It returns an error +// if the input is empty or is itself a public suffix (e.g. "co.uk"). +func ClassifyDNSName(name string) (DNSNameType, error) { + name = strings.ToLower(strings.TrimSuffix(strings.TrimSpace(name), ".")) + + if name == "" { + return 0, ErrEmptyDNSName + } + + etld1, err := publicsuffix.EffectiveTLDPlusOne(name) + if err != nil { + return 0, fmt.Errorf("invalid DNS name %q: %w", name, err) + } + + if name == etld1 { + return DNSNameTypeDomain, nil + } + + return DNSNameTypeHostname, nil +} + +// ClassifyTargets splits a list of DNS names into apex domains and +// hostnames using the Public Suffix List. It returns an error if any +// name cannot be classified. 
+func ClassifyTargets(targets []string) ([]string, []string, error) { + var domains, hostnames []string + + for _, t := range targets { + normalized := strings.ToLower(strings.TrimSuffix(strings.TrimSpace(t), ".")) + + if normalized == "" { + continue + } + + typ, classErr := ClassifyDNSName(normalized) + if classErr != nil { + return nil, nil, classErr + } + + switch typ { + case DNSNameTypeDomain: + domains = append(domains, normalized) + case DNSNameTypeHostname: + hostnames = append(hostnames, normalized) + } + } + + return domains, hostnames, nil +} diff --git a/internal/config/classify_test.go b/internal/config/classify_test.go new file mode 100644 index 0000000..fb21bbc --- /dev/null +++ b/internal/config/classify_test.go @@ -0,0 +1,83 @@ +package config_test + +import ( + "testing" + + "sneak.berlin/go/dnswatcher/internal/config" +) + +func TestClassifyDNSName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want config.DNSNameType + wantErr bool + }{ + {name: "apex domain simple", input: "example.com", want: config.DNSNameTypeDomain}, + {name: "hostname simple", input: "www.example.com", want: config.DNSNameTypeHostname}, + {name: "apex domain multi-part TLD", input: "example.co.uk", want: config.DNSNameTypeDomain}, + {name: "hostname multi-part TLD", input: "api.example.co.uk", want: config.DNSNameTypeHostname}, + {name: "public suffix itself", input: "co.uk", wantErr: true}, + {name: "empty string", input: "", wantErr: true}, + {name: "deeply nested hostname", input: "a.b.c.example.com", want: config.DNSNameTypeHostname}, + {name: "trailing dot stripped", input: "example.com.", want: config.DNSNameTypeDomain}, + {name: "uppercase normalized", input: "WWW.Example.COM", want: config.DNSNameTypeHostname}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := config.ClassifyDNSName(tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("ClassifyDNSName(%q) expected 
error, got %v", tt.input, got) + } + + return + } + + if err != nil { + t.Fatalf("ClassifyDNSName(%q) unexpected error: %v", tt.input, err) + } + + if got != tt.want { + t.Errorf("ClassifyDNSName(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestClassifyTargets(t *testing.T) { + t.Parallel() + + domains, hostnames, err := config.ClassifyTargets([]string{ + "example.com", + "www.example.com", + "example.co.uk", + "api.example.co.uk", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(domains) != 2 { + t.Errorf("expected 2 domains, got %d: %v", len(domains), domains) + } + + if len(hostnames) != 2 { + t.Errorf("expected 2 hostnames, got %d: %v", len(hostnames), hostnames) + } +} + +func TestClassifyTargetsRejectsPublicSuffix(t *testing.T) { + t.Parallel() + + _, _, err := config.ClassifyTargets([]string{"co.uk"}) + if err == nil { + t.Error("expected error for public suffix, got nil") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index b43027d..0acf89e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -89,8 +89,7 @@ func setupViper(name string) { viper.SetDefault("PORT", defaultPort) viper.SetDefault("DEBUG", false) viper.SetDefault("DATA_DIR", "./data") - viper.SetDefault("DOMAINS", "") - viper.SetDefault("HOSTNAMES", "") + viper.SetDefault("TARGETS", "") viper.SetDefault("SLACK_WEBHOOK", "") viper.SetDefault("MATTERMOST_WEBHOOK", "") viper.SetDefault("NTFY_TOPIC", "") @@ -133,12 +132,19 @@ func buildConfig( tlsInterval = defaultTLSInterval } + domains, hostnames, err := ClassifyTargets( + parseCSV(viper.GetString("TARGETS")), + ) + if err != nil { + return nil, fmt.Errorf("invalid targets configuration: %w", err) + } + cfg := &Config{ Port: viper.GetInt("PORT"), Debug: viper.GetBool("DEBUG"), DataDir: viper.GetString("DATA_DIR"), - Domains: parseCSV(viper.GetString("DOMAINS")), - Hostnames: parseCSV(viper.GetString("HOSTNAMES")), + Domains: domains, + Hostnames: 
hostnames, SlackWebhook: viper.GetString("SLACK_WEBHOOK"), MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"), NtfyTopic: viper.GetString("NTFY_TOPIC"), diff --git a/internal/notify/notify.go b/internal/notify/notify.go index e0bc3f8..ea57155 100644 --- a/internal/notify/notify.go +++ b/internal/notify/notify.go @@ -36,16 +36,6 @@ var ( ) ) -// sanitizeURL parses and re-serializes a URL to satisfy static analysis (gosec G704). -func sanitizeURL(raw string) (string, error) { - u, err := url.Parse(raw) - if err != nil { - return "", fmt.Errorf("invalid URL %q: %w", raw, err) - } - - return u.String(), nil -} - // Params contains dependencies for Service. type Params struct { fx.In @@ -56,9 +46,12 @@ type Params struct { // Service provides notification functionality. type Service struct { - log *slog.Logger - client *http.Client - config *config.Config + log *slog.Logger + client *http.Client + config *config.Config + ntfyURL *url.URL + slackWebhookURL *url.URL + mattermostWebhookURL *url.URL } // New creates a new notify Service. 
@@ -66,13 +59,44 @@ func New( _ fx.Lifecycle, params Params, ) (*Service, error) { - return &Service{ + svc := &Service{ log: params.Logger.Get(), client: &http.Client{ Timeout: httpClientTimeout, }, config: params.Config, - }, nil + } + + if params.Config.NtfyTopic != "" { + u, err := url.ParseRequestURI(params.Config.NtfyTopic) + if err != nil { + return nil, fmt.Errorf("invalid ntfy topic URL: %w", err) + } + + svc.ntfyURL = u + } + + if params.Config.SlackWebhook != "" { + u, err := url.ParseRequestURI(params.Config.SlackWebhook) + if err != nil { + return nil, fmt.Errorf("invalid slack webhook URL: %w", err) + } + + svc.slackWebhookURL = u + } + + if params.Config.MattermostWebhook != "" { + u, err := url.ParseRequestURI(params.Config.MattermostWebhook) + if err != nil { + return nil, fmt.Errorf( + "invalid mattermost webhook URL: %w", err, + ) + } + + svc.mattermostWebhookURL = u + } + + return svc, nil } // SendNotification sends a notification to all configured endpoints. @@ -80,13 +104,13 @@ func (svc *Service) SendNotification( ctx context.Context, title, message, priority string, ) { - if svc.config.NtfyTopic != "" { + if svc.ntfyURL != nil { go func() { notifyCtx := context.WithoutCancel(ctx) err := svc.sendNtfy( notifyCtx, - svc.config.NtfyTopic, + svc.ntfyURL, title, message, priority, ) if err != nil { @@ -98,13 +122,13 @@ func (svc *Service) SendNotification( }() } - if svc.config.SlackWebhook != "" { + if svc.slackWebhookURL != nil { go func() { notifyCtx := context.WithoutCancel(ctx) err := svc.sendSlack( notifyCtx, - svc.config.SlackWebhook, + svc.slackWebhookURL, title, message, priority, ) if err != nil { @@ -116,13 +140,13 @@ func (svc *Service) SendNotification( }() } - if svc.config.MattermostWebhook != "" { + if svc.mattermostWebhookURL != nil { go func() { notifyCtx := context.WithoutCancel(ctx) err := svc.sendSlack( notifyCtx, - svc.config.MattermostWebhook, + svc.mattermostWebhookURL, title, message, priority, ) if err != nil { @@ 
-137,23 +161,19 @@ func (svc *Service) SendNotification( func (svc *Service) sendNtfy( ctx context.Context, - topic, title, message, priority string, + topicURL *url.URL, + title, message, priority string, ) error { svc.log.Debug( "sending ntfy notification", - "topic", topic, + "topic", topicURL.String(), "title", title, ) - cleanURL, err := sanitizeURL(topic) - if err != nil { - return fmt.Errorf("invalid ntfy topic URL: %w", err) - } - request, err := http.NewRequestWithContext( ctx, http.MethodPost, - cleanURL, + topicURL.String(), bytes.NewBufferString(message), ) if err != nil { @@ -163,7 +183,7 @@ func (svc *Service) sendNtfy( request.Header.Set("Title", title) request.Header.Set("Priority", ntfyPriority(priority)) - resp, err := svc.client.Do(request) + resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time if err != nil { return fmt.Errorf("sending ntfy request: %w", err) } @@ -209,11 +229,12 @@ type SlackAttachment struct { func (svc *Service) sendSlack( ctx context.Context, - webhookURL, title, message, priority string, + webhookURL *url.URL, + title, message, priority string, ) error { svc.log.Debug( "sending webhook notification", - "url", webhookURL, + "url", webhookURL.String(), "title", title, ) @@ -232,15 +253,10 @@ func (svc *Service) sendSlack( return fmt.Errorf("marshaling webhook payload: %w", err) } - cleanURL, err := sanitizeURL(webhookURL) - if err != nil { - return fmt.Errorf("invalid webhook URL: %w", err) - } - request, err := http.NewRequestWithContext( ctx, http.MethodPost, - cleanURL, + webhookURL.String(), bytes.NewBuffer(body), ) if err != nil { @@ -249,7 +265,7 @@ func (svc *Service) sendSlack( request.Header.Set("Content-Type", "application/json") - resp, err := svc.client.Do(request) + resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time if err != nil { return fmt.Errorf("sending webhook request: %w", err) } diff --git 
a/internal/state/state_test_helper.go b/internal/state/state_test_helper.go new file mode 100644 index 0000000..2142bfb --- /dev/null +++ b/internal/state/state_test_helper.go @@ -0,0 +1,22 @@ +package state + +import ( + "log/slog" + + "sneak.berlin/go/dnswatcher/internal/config" +) + +// NewForTest creates a State for unit testing with no persistence. +func NewForTest() *State { + return &State{ + log: slog.Default(), + snapshot: &Snapshot{ + Version: stateVersion, + Domains: make(map[string]*DomainState), + Hostnames: make(map[string]*HostnameState), + Ports: make(map[string]*PortState), + Certificates: make(map[string]*CertificateState), + }, + config: &config.Config{DataDir: ""}, + } +} diff --git a/internal/watcher/interfaces.go b/internal/watcher/interfaces.go new file mode 100644 index 0000000..695139d --- /dev/null +++ b/internal/watcher/interfaces.go @@ -0,0 +1,60 @@ +// Package watcher provides the main monitoring orchestrator. +package watcher + +import ( + "context" + + "sneak.berlin/go/dnswatcher/internal/tlscheck" +) + +// DNSResolver performs iterative DNS resolution. +type DNSResolver interface { + // LookupNS discovers authoritative nameservers for a domain. + LookupNS( + ctx context.Context, + domain string, + ) ([]string, error) + + // LookupAllRecords queries all record types for a hostname, + // returning results keyed by nameserver then record type. + LookupAllRecords( + ctx context.Context, + hostname string, + ) (map[string]map[string][]string, error) + + // ResolveIPAddresses resolves a hostname to all IP addresses. + ResolveIPAddresses( + ctx context.Context, + hostname string, + ) ([]string, error) +} + +// PortChecker tests TCP port connectivity. +type PortChecker interface { + // CheckPort tests TCP connectivity to an address and port. + CheckPort( + ctx context.Context, + address string, + port int, + ) (bool, error) +} + +// TLSChecker inspects TLS certificates. 
+type TLSChecker interface { + // CheckCertificate connects via TLS and returns cert info. + CheckCertificate( + ctx context.Context, + ip string, + hostname string, + ) (*tlscheck.CertificateInfo, error) +} + +// Notifier delivers notifications to configured endpoints. +type Notifier interface { + // SendNotification sends a notification with the given + // details. + SendNotification( + ctx context.Context, + title, message, priority string, + ) +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index a9b1caa..2493264 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1,21 +1,30 @@ -// Package watcher provides the main monitoring orchestrator and scheduler. package watcher import ( "context" + "fmt" "log/slog" + "sort" + "strings" + "time" "go.uber.org/fx" "sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/logger" - "sneak.berlin/go/dnswatcher/internal/notify" - "sneak.berlin/go/dnswatcher/internal/portcheck" - "sneak.berlin/go/dnswatcher/internal/resolver" "sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/tlscheck" ) +// monitoredPorts are the TCP ports checked for each IP address. +var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals + +// tlsPort is the port used for TLS certificate checks. +const tlsPort = 443 + +// hoursPerDay is the number of hours in a day; hour counts are divided by it to get days. +const hoursPerDay = 24 + // Params contains dependencies for Watcher. type Params struct { fx.In @@ -23,10 +32,10 @@ type Params struct { Logger *logger.Logger Config *config.Config State *state.State - Resolver *resolver.Resolver - PortCheck *portcheck.Checker - TLSCheck *tlscheck.Checker - Notify *notify.Service + Resolver DNSResolver + PortCheck PortChecker + TLSCheck TLSChecker + Notify Notifier } // Watcher orchestrates all monitoring checks on a schedule. 
@@ -34,19 +43,20 @@ type Watcher struct { log *slog.Logger config *config.Config state *state.State - resolver *resolver.Resolver - portCheck *portcheck.Checker - tlsCheck *tlscheck.Checker - notify *notify.Service + resolver DNSResolver + portCheck PortChecker + tlsCheck TLSChecker + notify Notifier cancel context.CancelFunc + firstRun bool } -// New creates a new Watcher instance. +// New creates a new Watcher instance wired into the fx lifecycle. func New( lifecycle fx.Lifecycle, params Params, ) (*Watcher, error) { - watcher := &Watcher{ + w := &Watcher{ log: params.Logger.Get(), config: params.Config, state: params.State, @@ -54,30 +64,54 @@ func New( portCheck: params.PortCheck, tlsCheck: params.TLSCheck, notify: params.Notify, + firstRun: true, } lifecycle.Append(fx.Hook{ OnStart: func(startCtx context.Context) error { - ctx, cancel := context.WithCancel(startCtx) - watcher.cancel = cancel + ctx, cancel := context.WithCancel( + context.WithoutCancel(startCtx), + ) + w.cancel = cancel - go watcher.Run(ctx) + go w.Run(ctx) return nil }, OnStop: func(_ context.Context) error { - if watcher.cancel != nil { - watcher.cancel() + if w.cancel != nil { + w.cancel() } return nil }, }) - return watcher, nil + return w, nil } -// Run starts the monitoring loop. +// NewForTest creates a Watcher without fx for unit testing. +func NewForTest( + cfg *config.Config, + st *state.State, + res DNSResolver, + pc PortChecker, + tc TLSChecker, + n Notifier, +) *Watcher { + return &Watcher{ + log: slog.Default(), + config: cfg, + state: st, + resolver: res, + portCheck: pc, + tlsCheck: tc, + notify: n, + firstRun: true, + } +} + +// Run starts the monitoring loop with periodic scheduling. func (w *Watcher) Run(ctx context.Context) { w.log.Info( "watcher starting", @@ -87,8 +121,646 @@ func (w *Watcher) Run(ctx context.Context) { "tlsInterval", w.config.TLSInterval, ) - // Stub: wait for context cancellation. - // Implementation will add initial check + periodic scheduling. 
- <-ctx.Done() - w.log.Info("watcher stopped") + w.RunOnce(ctx) + + dnsTicker := time.NewTicker(w.config.DNSInterval) + tlsTicker := time.NewTicker(w.config.TLSInterval) + + defer dnsTicker.Stop() + defer tlsTicker.Stop() + + for { + select { + case <-ctx.Done(): + w.log.Info("watcher stopped") + + return + case <-dnsTicker.C: + w.runDNSAndPortChecks(ctx) + w.saveState() + case <-tlsTicker.C: + w.runTLSChecks(ctx) + w.saveState() + } + } +} + +// RunOnce performs a single complete monitoring cycle. +func (w *Watcher) RunOnce(ctx context.Context) { + w.detectFirstRun() + w.runDNSAndPortChecks(ctx) + w.runTLSChecks(ctx) + w.saveState() + w.firstRun = false +} + +func (w *Watcher) detectFirstRun() { + snap := w.state.GetSnapshot() + hasState := len(snap.Domains) > 0 || + len(snap.Hostnames) > 0 || + len(snap.Ports) > 0 || + len(snap.Certificates) > 0 + + if hasState { + w.firstRun = false + } +} + +func (w *Watcher) runDNSAndPortChecks(ctx context.Context) { + for _, domain := range w.config.Domains { + w.checkDomain(ctx, domain) + } + + for _, hostname := range w.config.Hostnames { + w.checkHostname(ctx, hostname) + } + + w.checkAllPorts(ctx) +} + +func (w *Watcher) checkDomain( + ctx context.Context, + domain string, +) { + nameservers, err := w.resolver.LookupNS(ctx, domain) + if err != nil { + w.log.Error( + "failed to lookup NS", + "domain", domain, + "error", err, + ) + + return + } + + sort.Strings(nameservers) + + now := time.Now().UTC() + + prev, hasPrev := w.state.GetDomainState(domain) + if hasPrev && !w.firstRun { + w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers) + } + + w.state.SetDomainState(domain, &state.DomainState{ + Nameservers: nameservers, + LastChecked: now, + }) +} + +func (w *Watcher) detectNSChanges( + ctx context.Context, + domain string, + oldNS, newNS []string, +) { + oldSet := toSet(oldNS) + newSet := toSet(newNS) + + var added, removed []string + + for ns := range newSet { + if !oldSet[ns] { + added = append(added, ns) + } + 
} + + for ns := range oldSet { + if !newSet[ns] { + removed = append(removed, ns) + } + } + + if len(added) == 0 && len(removed) == 0 { + return + } + + msg := fmt.Sprintf( + "Domain: %s\nAdded: %s\nRemoved: %s", + domain, + strings.Join(added, ", "), + strings.Join(removed, ", "), + ) + + w.notify.SendNotification( + ctx, + "NS Change: "+domain, + msg, + "warning", + ) +} + +func (w *Watcher) checkHostname( + ctx context.Context, + hostname string, +) { + records, err := w.resolver.LookupAllRecords(ctx, hostname) + if err != nil { + w.log.Error( + "failed to lookup records", + "hostname", hostname, + "error", err, + ) + + return + } + + now := time.Now().UTC() + prev, hasPrev := w.state.GetHostnameState(hostname) + + if hasPrev && !w.firstRun { + w.detectHostnameChanges(ctx, hostname, prev, records) + } + + newState := buildHostnameState(records, now) + w.state.SetHostnameState(hostname, newState) +} + +func buildHostnameState( + records map[string]map[string][]string, + now time.Time, +) *state.HostnameState { + hs := &state.HostnameState{ + RecordsByNameserver: make( + map[string]*state.NameserverRecordState, + ), + LastChecked: now, + } + + for ns, recs := range records { + hs.RecordsByNameserver[ns] = &state.NameserverRecordState{ + Records: recs, + Status: "ok", + LastChecked: now, + } + } + + return hs +} + +func (w *Watcher) detectHostnameChanges( + ctx context.Context, + hostname string, + prev *state.HostnameState, + current map[string]map[string][]string, +) { + w.detectRecordChanges(ctx, hostname, prev, current) + w.detectNSDisappearances(ctx, hostname, prev, current) + w.detectInconsistencies(ctx, hostname, current) +} + +func (w *Watcher) detectRecordChanges( + ctx context.Context, + hostname string, + prev *state.HostnameState, + current map[string]map[string][]string, +) { + for ns, recs := range current { + prevNS, ok := prev.RecordsByNameserver[ns] + if !ok { + continue + } + + if recordsEqual(prevNS.Records, recs) { + continue + } + + msg := 
fmt.Sprintf( + "Hostname: %s\nNameserver: %s\n"+ + "Old: %v\nNew: %v", + hostname, ns, + prevNS.Records, recs, + ) + + w.notify.SendNotification( + ctx, + "Record Change: "+hostname, + msg, + "warning", + ) + } +} + +func (w *Watcher) detectNSDisappearances( + ctx context.Context, + hostname string, + prev *state.HostnameState, + current map[string]map[string][]string, +) { + for ns, prevNS := range prev.RecordsByNameserver { + if _, ok := current[ns]; ok || prevNS.Status != "ok" { + continue + } + + msg := fmt.Sprintf( + "Hostname: %s\nNameserver: %s disappeared", + hostname, ns, + ) + + w.notify.SendNotification( + ctx, + "NS Failure: "+hostname, + msg, + "error", + ) + } + + for ns := range current { + prevNS, ok := prev.RecordsByNameserver[ns] + if !ok || prevNS.Status != "error" { + continue + } + + msg := fmt.Sprintf( + "Hostname: %s\nNameserver: %s recovered", + hostname, ns, + ) + + w.notify.SendNotification( + ctx, + "NS Recovery: "+hostname, + msg, + "success", + ) + } +} + +func (w *Watcher) detectInconsistencies( + ctx context.Context, + hostname string, + current map[string]map[string][]string, +) { + nameservers := make([]string, 0, len(current)) + for ns := range current { + nameservers = append(nameservers, ns) + } + + sort.Strings(nameservers) + + for i := range len(nameservers) - 1 { + ns1 := nameservers[i] + ns2 := nameservers[i+1] + + if recordsEqual(current[ns1], current[ns2]) { + continue + } + + msg := fmt.Sprintf( + "Hostname: %s\n%s: %v\n%s: %v", + hostname, + ns1, current[ns1], + ns2, current[ns2], + ) + + w.notify.SendNotification( + ctx, + "Inconsistency: "+hostname, + msg, + "warning", + ) + } +} + +func (w *Watcher) checkAllPorts(ctx context.Context) { + for _, hostname := range w.config.Hostnames { + w.checkPortsForHostname(ctx, hostname) + } + + for _, domain := range w.config.Domains { + w.checkPortsForHostname(ctx, domain) + } +} + +func (w *Watcher) checkPortsForHostname( + ctx context.Context, + hostname string, +) { + ips := 
w.collectIPs(hostname) + + for _, ip := range ips { + for _, port := range monitoredPorts { + w.checkSinglePort(ctx, ip, port, hostname) + } + } +} + +func (w *Watcher) collectIPs(hostname string) []string { + hs, ok := w.state.GetHostnameState(hostname) + if !ok { + return nil + } + + ipSet := make(map[string]bool) + + for _, nsState := range hs.RecordsByNameserver { + for _, ip := range nsState.Records["A"] { + ipSet[ip] = true + } + + for _, ip := range nsState.Records["AAAA"] { + ipSet[ip] = true + } + } + + result := make([]string, 0, len(ipSet)) + for ip := range ipSet { + result = append(result, ip) + } + + sort.Strings(result) + + return result +} + +func (w *Watcher) checkSinglePort( + ctx context.Context, + ip string, + port int, + hostname string, +) { + open, err := w.portCheck.CheckPort(ctx, ip, port) + if err != nil { + w.log.Error( + "port check failed", + "ip", ip, + "port", port, + "error", err, + ) + + return + } + + key := fmt.Sprintf("%s:%d", ip, port) + now := time.Now().UTC() + prev, hasPrev := w.state.GetPortState(key) + + if hasPrev && !w.firstRun && prev.Open != open { + stateStr := "closed" + if open { + stateStr = "open" + } + + msg := fmt.Sprintf( + "Host: %s\nAddress: %s\nPort now %s", + hostname, key, stateStr, + ) + + w.notify.SendNotification( + ctx, + "Port Change: "+key, + msg, + "warning", + ) + } + + w.state.SetPortState(key, &state.PortState{ + Open: open, + Hostname: hostname, + LastChecked: now, + }) +} + +func (w *Watcher) runTLSChecks(ctx context.Context) { + for _, hostname := range w.config.Hostnames { + w.checkTLSForHostname(ctx, hostname) + } + + for _, domain := range w.config.Domains { + w.checkTLSForHostname(ctx, domain) + } +} + +func (w *Watcher) checkTLSForHostname( + ctx context.Context, + hostname string, +) { + ips := w.collectIPs(hostname) + + for _, ip := range ips { + portKey := fmt.Sprintf("%s:%d", ip, tlsPort) + + ps, ok := w.state.GetPortState(portKey) + if !ok || !ps.Open { + continue + } + + 
w.checkTLSCert(ctx, ip, hostname) + } +} + +func (w *Watcher) checkTLSCert( + ctx context.Context, + ip string, + hostname string, +) { + cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname) + certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname) + now := time.Now().UTC() + prev, hasPrev := w.state.GetCertificateState(certKey) + + if err != nil { + w.handleTLSError( + ctx, certKey, hostname, ip, + hasPrev, prev, now, err, + ) + + return + } + + w.handleTLSSuccess( + ctx, certKey, hostname, ip, + hasPrev, prev, now, cert, + ) +} + +func (w *Watcher) handleTLSError( + ctx context.Context, + certKey, hostname, ip string, + hasPrev bool, + prev *state.CertificateState, + now time.Time, + err error, +) { + if hasPrev && !w.firstRun && prev.Status == "ok" { + msg := fmt.Sprintf( + "Host: %s\nIP: %s\nError: %s", + hostname, ip, err, + ) + + w.notify.SendNotification( + ctx, + "TLS Failure: "+hostname, + msg, + "error", + ) + } + + w.state.SetCertificateState( + certKey, &state.CertificateState{ + Status: "error", + Error: err.Error(), + LastChecked: now, + }, + ) +} + +func (w *Watcher) handleTLSSuccess( + ctx context.Context, + certKey, hostname, ip string, + hasPrev bool, + prev *state.CertificateState, + now time.Time, + cert *tlscheck.CertificateInfo, +) { + if hasPrev && !w.firstRun { + w.detectTLSChanges(ctx, hostname, ip, prev, cert) + } + + w.checkTLSExpiry(ctx, hostname, ip, cert) + + w.state.SetCertificateState( + certKey, &state.CertificateState{ + CommonName: cert.CommonName, + Issuer: cert.Issuer, + NotAfter: cert.NotAfter, + SubjectAlternativeNames: cert.SubjectAlternativeNames, + Status: "ok", + LastChecked: now, + }, + ) +} + +func (w *Watcher) detectTLSChanges( + ctx context.Context, + hostname, ip string, + prev *state.CertificateState, + cert *tlscheck.CertificateInfo, +) { + if prev.Status == "error" { + msg := fmt.Sprintf( + "Host: %s\nIP: %s\nTLS recovered", + hostname, ip, + ) + + w.notify.SendNotification( + ctx, + "TLS Recovery: 
"+hostname, + msg, + "success", + ) + + return + } + + changed := prev.CommonName != cert.CommonName || + prev.Issuer != cert.Issuer || + !sliceEqual( + prev.SubjectAlternativeNames, + cert.SubjectAlternativeNames, + ) + + if !changed { + return + } + + msg := fmt.Sprintf( + "Host: %s\nIP: %s\n"+ + "Old CN: %s, Issuer: %s\n"+ + "New CN: %s, Issuer: %s", + hostname, ip, + prev.CommonName, prev.Issuer, + cert.CommonName, cert.Issuer, + ) + + w.notify.SendNotification( + ctx, + "TLS Certificate Change: "+hostname, + msg, + "warning", + ) +} + +func (w *Watcher) checkTLSExpiry( + ctx context.Context, + hostname, ip string, + cert *tlscheck.CertificateInfo, +) { + daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay + warningDays := float64(w.config.TLSExpiryWarning) + + if daysLeft > warningDays { + return + } + + msg := fmt.Sprintf( + "Host: %s\nIP: %s\nCN: %s\n"+ + "Expires: %s (%.0f days)", + hostname, ip, cert.CommonName, + cert.NotAfter.Format(time.RFC3339), + daysLeft, + ) + + w.notify.SendNotification( + ctx, + "TLS Expiry Warning: "+hostname, + msg, + "warning", + ) +} + +func (w *Watcher) saveState() { + err := w.state.Save() + if err != nil { + w.log.Error("failed to save state", "error", err) + } +} + +// --- Utility functions --- + +func toSet(items []string) map[string]bool { + set := make(map[string]bool, len(items)) + for _, item := range items { + set[item] = true + } + + return set +} + +func recordsEqual( + a, b map[string][]string, +) bool { + if len(a) != len(b) { + return false + } + + for k, av := range a { + bv, ok := b[k] + if !ok || !sliceEqual(av, bv) { + return false + } + } + + return true +} + +func sliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + + aSorted := make([]string, len(a)) + bSorted := make([]string, len(b)) + + copy(aSorted, a) + copy(bSorted, b) + + sort.Strings(aSorted) + sort.Strings(bSorted) + + for i := range aSorted { + if aSorted[i] != bSorted[i] { + return false + } + } + + return true 
// --- Utility functions ---

// toSet returns a membership map containing every element of items.
func toSet(items []string) map[string]bool {
	members := make(map[string]bool, len(items))
	for i := range items {
		members[items[i]] = true
	}

	return members
}

// recordsEqual reports whether a and b hold the same keys and, for each
// key, the same values when order is ignored.
func recordsEqual(
	a, b map[string][]string,
) bool {
	if len(a) != len(b) {
		return false
	}

	for name, left := range a {
		right, present := b[name]
		if !present {
			return false
		}

		if !sliceEqual(left, right) {
			return false
		}
	}

	return true
}

// sliceEqual reports whether a and b contain the same strings with the
// same multiplicities, regardless of order. Neither input is modified:
// the comparison works on sorted copies.
func sliceEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	left := append(make([]string, 0, len(a)), a...)
	right := append(make([]string, 0, len(b)), b...)

	sort.Strings(left)
	sort.Strings(right)

	for i, v := range left {
		if right[i] != v {
			return false
		}
	}

	return true
}
// errNotFound is returned when mock data is missing.
var errNotFound = errors.New("not found")

// --- Mock implementations ---

// mockResolver is an in-memory resolver stub. Each lookup is answered
// from the matching map, or fails with the matching injected error when
// one is set. All access is serialized through mu, and LookupNS and
// LookupAllRecords count their invocations.
type mockResolver struct {
	mu             sync.Mutex
	nsRecords      map[string][]string
	allRecords     map[string]map[string]map[string][]string
	ipAddresses    map[string][]string
	lookupNSErr    error
	allRecordsErr  error
	resolveIPErr   error
	lookupNSCalls  int
	allRecordCalls int
}

// LookupNS returns the canned NS list for domain. The call counter is
// incremented before anything else; an injected lookupNSErr takes
// precedence, and an unknown domain yields an error wrapping errNotFound.
func (m *mockResolver) LookupNS(
	_ context.Context,
	domain string,
) ([]string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.lookupNSCalls++

	if m.lookupNSErr != nil {
		return nil, m.lookupNSErr
	}

	if servers, known := m.nsRecords[domain]; known {
		return servers, nil
	}

	return nil, fmt.Errorf(
		"%w: NS for %s", errNotFound, domain,
	)
}

// LookupAllRecords returns the canned per-nameserver records for
// hostname. The call counter is incremented first; an injected
// allRecordsErr takes precedence, and an unknown hostname yields an error
// wrapping errNotFound.
func (m *mockResolver) LookupAllRecords(
	_ context.Context,
	hostname string,
) (map[string]map[string][]string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.allRecordCalls++

	if m.allRecordsErr != nil {
		return nil, m.allRecordsErr
	}

	if byNS, known := m.allRecords[hostname]; known {
		return byNS, nil
	}

	return nil, fmt.Errorf(
		"%w: records for %s", errNotFound, hostname,
	)
}

// ResolveIPAddresses returns the canned address list for hostname. An
// injected resolveIPErr takes precedence, and an unknown hostname yields
// an error wrapping errNotFound.
func (m *mockResolver) ResolveIPAddresses(
	_ context.Context,
	hostname string,
) ([]string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.resolveIPErr != nil {
		return nil, m.resolveIPErr
	}

	if addrs, known := m.ipAddresses[hostname]; known {
		return addrs, nil
	}

	return nil, fmt.Errorf(
		"%w: IPs for %s", errNotFound, hostname,
	)
}
deps := &testDeps{ + resolver: &mockResolver{ + nsRecords: make(map[string][]string), + allRecords: make(map[string]map[string]map[string][]string), + ipAddresses: make(map[string][]string), + }, + portChecker: &mockPortChecker{ + results: make(map[string]bool), + }, + tlsChecker: &mockTLSChecker{ + certs: make(map[string]*tlscheck.CertificateInfo), + }, + notifier: &mockNotifier{}, + config: cfg, + } + + deps.state = state.NewForTest() + + w := watcher.NewForTest( + deps.config, + deps.state, + deps.resolver, + deps.portChecker, + deps.tlsChecker, + deps.notifier, + ) + + return w, deps +} + +func defaultTestConfig(t *testing.T) *config.Config { + t.Helper() + + return &config.Config{ + DNSInterval: time.Hour, + TLSInterval: 12 * time.Hour, + TLSExpiryWarning: 7, + DataDir: t.TempDir(), + } +} + +func TestFirstRunBaseline(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Domains = []string{"example.com"} + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + setupBaselineMocks(deps) + + w.RunOnce(t.Context()) + + assertNoNotifications(t, deps) + assertStatePopulated(t, deps) +} + +func setupBaselineMocks(deps *testDeps) { + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + "ns2.example.com.", + } + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"93.184.216.34"}}, + "ns2.example.com.": {"A": {"93.184.216.34"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "93.184.216.34", + } + deps.portChecker.results["93.184.216.34:80"] = true + deps.portChecker.results["93.184.216.34:443"] = true + deps.tlsChecker.certs["93.184.216.34:www.example.com"] = &tlscheck.CertificateInfo{ + CommonName: "www.example.com", + Issuer: "DigiCert", + NotAfter: time.Now().Add(90 * 24 * time.Hour), + SubjectAlternativeNames: []string{ + "www.example.com", + }, + } +} + +func assertNoNotifications( + t *testing.T, + deps *testDeps, +) { + 
t.Helper() + + notifications := deps.notifier.getNotifications() + if len(notifications) != 0 { + t.Errorf( + "expected 0 notifications on first run, got %d", + len(notifications), + ) + } +} + +func assertStatePopulated( + t *testing.T, + deps *testDeps, +) { + t.Helper() + + snap := deps.state.GetSnapshot() + + if len(snap.Domains) != 1 { + t.Errorf( + "expected 1 domain in state, got %d", + len(snap.Domains), + ) + } + + if len(snap.Hostnames) != 1 { + t.Errorf( + "expected 1 hostname in state, got %d", + len(snap.Hostnames), + ) + } +} + +func TestNSChangeDetection(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Domains = []string{"example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + "ns2.example.com.", + } + + ctx := t.Context() + w.RunOnce(ctx) + + deps.resolver.mu.Lock() + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + "ns3.example.com.", + } + deps.resolver.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for NS change") + } + + found := false + + for _, n := range notifications { + if n.Priority == "warning" { + found = true + } + } + + if !found { + t.Error("expected warning-priority NS change notification") + } +} + +func TestRecordChangeDetection(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"93.184.216.34"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "93.184.216.34", + } + deps.portChecker.results["93.184.216.34:80"] = false + deps.portChecker.results["93.184.216.34:443"] = false + + ctx := t.Context() + w.RunOnce(ctx) + + deps.resolver.mu.Lock() + deps.resolver.allRecords["www.example.com"] = 
map[string]map[string][]string{ + "ns1.example.com.": {"A": {"93.184.216.35"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "93.184.216.35", + } + deps.resolver.mu.Unlock() + + deps.portChecker.mu.Lock() + deps.portChecker.results["93.184.216.35:80"] = false + deps.portChecker.results["93.184.216.35:443"] = false + deps.portChecker.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for record change") + } +} + +func TestPortStateChange(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "1.2.3.4", + } + deps.portChecker.results["1.2.3.4:80"] = true + deps.portChecker.results["1.2.3.4:443"] = true + deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{ + CommonName: "www.example.com", + Issuer: "DigiCert", + NotAfter: time.Now().Add(90 * 24 * time.Hour), + SubjectAlternativeNames: []string{ + "www.example.com", + }, + } + + ctx := t.Context() + w.RunOnce(ctx) + + deps.portChecker.mu.Lock() + deps.portChecker.results["1.2.3.4:443"] = false + deps.portChecker.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for port state change") + } +} + +func TestTLSExpiryWarning(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "1.2.3.4", + } + deps.portChecker.results["1.2.3.4:80"] 
= true + deps.portChecker.results["1.2.3.4:443"] = true + deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{ + CommonName: "www.example.com", + Issuer: "DigiCert", + NotAfter: time.Now().Add(3 * 24 * time.Hour), + SubjectAlternativeNames: []string{ + "www.example.com", + }, + } + + ctx := t.Context() + + // First run = baseline + w.RunOnce(ctx) + + // Second run should warn about expiry + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + + found := false + + for _, n := range notifications { + if n.Priority == "warning" { + found = true + } + } + + if !found { + t.Errorf( + "expected expiry warning, got: %v", + notifications, + ) + } +} + +func TestGracefulShutdown(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Domains = []string{"example.com"} + cfg.DNSInterval = 100 * time.Millisecond + cfg.TLSInterval = 100 * time.Millisecond + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.nsRecords["example.com"] = []string{ + "ns1.example.com.", + } + + ctx, cancel := context.WithCancel(t.Context()) + + done := make(chan struct{}) + + go func() { + w.Run(ctx) + close(done) + }() + + time.Sleep(250 * time.Millisecond) + cancel() + + select { + case <-done: + // Shut down cleanly + case <-time.After(5 * time.Second): + t.Error("watcher did not shut down within timeout") + } +} + +func TestNSFailureAndRecovery(t *testing.T) { + t.Parallel() + + cfg := defaultTestConfig(t) + cfg.Hostnames = []string{"www.example.com"} + + w, deps := newTestWatcher(t, cfg) + + deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + "ns2.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.ipAddresses["www.example.com"] = []string{ + "1.2.3.4", + } + deps.portChecker.results["1.2.3.4:80"] = false + deps.portChecker.results["1.2.3.4:443"] = false + + ctx := t.Context() + + w.RunOnce(ctx) + + deps.resolver.mu.Lock() + 
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ + "ns1.example.com.": {"A": {"1.2.3.4"}}, + } + deps.resolver.mu.Unlock() + + w.RunOnce(ctx) + + notifications := deps.notifier.getNotifications() + if len(notifications) == 0 { + t.Error("expected notification for NS disappearance") + } +}