4 Commits

13 changed files with 123 additions and 557 deletions

View File

@@ -1,6 +1,6 @@
.git/ .git
bin/ bin
*.md data
LICENSE .env
.editorconfig .DS_Store
.gitignore *.exe

View File

@@ -8,5 +8,8 @@ charset = utf-8
trim_trailing_whitespace = true trim_trailing_whitespace = true
insert_final_newline = true insert_final_newline = true
[*.go]
indent_style = tab
[Makefile] [Makefile]
indent_style = tab indent_style = tab

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 sneak
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt fmt-check test check clean hooks docker .PHONY: all build lint fmt fmt-check test check clean docker hooks
BINARY := dnswatcher BINARY := dnswatcher
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -18,32 +18,25 @@ fmt:
goimports -w . goimports -w .
fmt-check: fmt-check:
@test -z "$$(gofmt -l .)" || (echo "gofmt: files not formatted:" && gofmt -l . && exit 1) @test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
test: test:
go test -v -race -timeout 30s -cover ./... go test -v -race -cover -timeout 30s ./...
# Check runs all validation without making changes # Check runs all validation without making changes
# Used by CI and Docker build - fails if anything is wrong # Used by CI and Docker build - fails if anything is wrong
check: check: fmt-check lint test
@echo "==> Checking formatting..."
@test -z "$$(gofmt -l .)" || (echo "Files not formatted:" && gofmt -l . && exit 1)
@echo "==> Running linter..."
golangci-lint run --config .golangci.yml ./...
@echo "==> Running tests..."
go test -v -race -timeout 30s ./...
@echo "==> Building..." @echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/dnswatcher go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/dnswatcher
@echo "==> All checks passed!" @echo "==> All checks passed!"
clean: docker:
rm -rf bin/ docker build .
hooks: hooks:
@echo '#!/bin/sh' > .git/hooks/pre-commit @printf '#!/bin/sh\nset -e\nmake check\n' > .git/hooks/pre-commit
@echo 'make check' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit @chmod +x .git/hooks/pre-commit
@echo "Pre-commit hook installed." @echo "Pre-commit hook installed."
docker: clean:
docker build . rm -rf bin/

View File

@@ -1,10 +1,9 @@
# dnswatcher # dnswatcher
dnswatcher is a pre-1.0 Go daemon by [@sneak](https://sneak.berlin) that monitors DNS records, TCP port availability, and TLS certificates, delivering real-time change notifications via Slack, Mattermost, and ntfy webhooks.
> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice. > ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice.
dnswatcher watches configured DNS domains and hostnames for changes, monitors TCP dnswatcher is a production DNS and infrastructure monitoring daemon written in
Go. It watches configured DNS domains and hostnames for changes, monitors TCP
port availability, tracks TLS certificate expiry, and delivers real-time port availability, tracks TLS certificate expiry, and delivers real-time
notifications via Slack, Mattermost, and/or ntfy webhooks. notifications via Slack, Mattermost, and/or ntfy webhooks.
@@ -110,8 +109,8 @@ includes:
- **NS recoveries**: Which nameserver recovered, which hostname/domain. - **NS recoveries**: Which nameserver recovered, which hostname/domain.
- **NS inconsistencies**: Which nameservers disagree, what each one - **NS inconsistencies**: Which nameservers disagree, what each one
returned, which hostname affected. returned, which hostname affected.
- **Port changes**: Which IP:port, old state, new state, all associated - **Port changes**: Which IP:port, old state, new state, associated
hostnames. hostname.
- **TLS expiry warnings**: Which certificate, days remaining, CN, - **TLS expiry warnings**: Which certificate, days remaining, CN,
issuer, associated hostname and IP. issuer, associated hostname and IP.
- **TLS certificate changes**: Old and new CN/issuer/SANs, associated - **TLS certificate changes**: Old and new CN/issuer/SANs, associated
@@ -290,12 +289,12 @@ not as a merged view, to enable inconsistency detection.
"ports": { "ports": {
"93.184.216.34:80": { "93.184.216.34:80": {
"open": true, "open": true,
"hostnames": ["www.example.com"], "hostname": "www.example.com",
"lastChecked": "2026-02-19T12:00:00Z" "lastChecked": "2026-02-19T12:00:00Z"
}, },
"93.184.216.34:443": { "93.184.216.34:443": {
"open": true, "open": true,
"hostnames": ["www.example.com"], "hostname": "www.example.com",
"lastChecked": "2026-02-19T12:00:00Z" "lastChecked": "2026-02-19T12:00:00Z"
} }
}, },
@@ -328,10 +327,13 @@ tracks reachability:
```sh ```sh
make build # Build binary to bin/dnswatcher make build # Build binary to bin/dnswatcher
make test # Run tests with race detector make test # Run tests with race detector and 30s timeout
make lint # Run golangci-lint make lint # Run golangci-lint
make fmt # Format code make fmt # Format code (writes)
make check # Run all checks (format, lint, test, build) make fmt-check # Read-only format check
make check # Run all checks (fmt-check, lint, test, build)
make docker # Build Docker image
make hooks # Install pre-commit hook
make clean # Remove build artifacts make clean # Remove build artifacts
``` ```
@@ -367,15 +369,9 @@ docker run -d \
triggering change notifications). triggering change notifications).
2. **Initial check**: Immediately perform all DNS, port, and TLS checks 2. **Initial check**: Immediately perform all DNS, port, and TLS checks
on startup. on startup.
3. **Periodic checks** (DNS always runs first): 3. **Periodic checks**:
- DNS checks: every `DNSWATCHER_DNS_INTERVAL` (default 1h). Also - DNS and port checks: every `DNSWATCHER_DNS_INTERVAL` (default 1h).
re-run before every TLS check cycle to ensure fresh IPs. - TLS checks: every `DNSWATCHER_TLS_INTERVAL` (default 12h).
- Port checks: every `DNSWATCHER_DNS_INTERVAL`, after DNS completes.
- TLS checks: every `DNSWATCHER_TLS_INTERVAL` (default 12h), after
DNS completes.
- Port and TLS checks always use freshly resolved IP addresses from
the DNS phase that immediately precedes them — never stale IPs
from a previous cycle.
4. **On change detection**: Send notifications to all configured 4. **On change detection**: Send notifications to all configured
endpoints, update in-memory state, persist to disk. endpoints, update in-memory state, persist to disk.
5. **Shutdown**: Persist final state to disk, complete in-flight 5. **Shutdown**: Persist final state to disk, complete in-flight
@@ -401,8 +397,7 @@ Viper for configuration.
## License ## License
License has not yet been chosen for this project. Pending decision by the MIT — see [LICENSE](LICENSE).
author (MIT, GPL, or WTFPL).
## Author ## Author

View File

@@ -1,60 +0,0 @@
package handlers
import (
"net/http"
"time"
)
// domainResponse represents a single domain in the API response.
type domainResponse struct {
Domain string `json:"domain"`
Nameservers []string `json:"nameservers,omitempty"`
LastChecked string `json:"lastChecked,omitempty"`
Status string `json:"status"`
}
// domainsResponse is the top-level response for GET /api/v1/domains.
type domainsResponse struct {
Domains []domainResponse `json:"domains"`
}
// HandleDomains returns the configured domains and their status.
func (h *Handlers) HandleDomains() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
configured := h.config.Domains
snapshot := h.state.GetSnapshot()
domains := make(
[]domainResponse, 0, len(configured),
)
for _, domain := range configured {
dr := domainResponse{
Domain: domain,
Status: "pending",
}
ds, ok := snapshot.Domains[domain]
if ok {
dr.Nameservers = ds.Nameservers
dr.Status = "ok"
if !ds.LastChecked.IsZero() {
dr.LastChecked = ds.LastChecked.
Format(time.RFC3339)
}
}
domains = append(domains, dr)
}
h.respondJSON(
writer, request,
&domainsResponse{Domains: domains},
http.StatusOK,
)
}
}

View File

@@ -8,11 +8,9 @@ import (
"go.uber.org/fx" "go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/globals" "sneak.berlin/go/dnswatcher/internal/globals"
"sneak.berlin/go/dnswatcher/internal/healthcheck" "sneak.berlin/go/dnswatcher/internal/healthcheck"
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
"sneak.berlin/go/dnswatcher/internal/state"
) )
// Params contains dependencies for Handlers. // Params contains dependencies for Handlers.
@@ -22,8 +20,6 @@ type Params struct {
Logger *logger.Logger Logger *logger.Logger
Globals *globals.Globals Globals *globals.Globals
Healthcheck *healthcheck.Healthcheck Healthcheck *healthcheck.Healthcheck
State *state.State
Config *config.Config
} }
// Handlers provides HTTP request handlers. // Handlers provides HTTP request handlers.
@@ -32,8 +28,6 @@ type Handlers struct {
params *Params params *Params
globals *globals.Globals globals *globals.Globals
hc *healthcheck.Healthcheck hc *healthcheck.Healthcheck
state *state.State
config *config.Config
} }
// New creates a new Handlers instance. // New creates a new Handlers instance.
@@ -43,8 +37,6 @@ func New(_ fx.Lifecycle, params Params) (*Handlers, error) {
params: &params, params: &params,
globals: params.Globals, globals: params.Globals,
hc: params.Healthcheck, hc: params.Healthcheck,
state: params.State,
config: params.Config,
}, nil }, nil
} }
@@ -52,7 +44,7 @@ func (h *Handlers) respondJSON(
writer http.ResponseWriter, writer http.ResponseWriter,
_ *http.Request, _ *http.Request,
data any, data any,
status int, //nolint:unparam // general-purpose utility; status varies in future use status int,
) { ) {
writer.Header().Set("Content-Type", "application/json") writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(status) writer.WriteHeader(status)

View File

@@ -1,120 +0,0 @@
package handlers
import (
"net/http"
"sort"
"time"
"sneak.berlin/go/dnswatcher/internal/state"
)
// nameserverRecordResponse represents one nameserver's records
// in the API response.
type nameserverRecordResponse struct {
Nameserver string `json:"nameserver"`
Records map[string][]string `json:"records"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
LastChecked string `json:"lastChecked,omitempty"`
}
// hostnameResponse represents a single hostname in the API response.
type hostnameResponse struct {
Hostname string `json:"hostname"`
Nameservers []nameserverRecordResponse `json:"nameservers,omitempty"`
LastChecked string `json:"lastChecked,omitempty"`
Status string `json:"status"`
}
// hostnamesResponse is the top-level response for
// GET /api/v1/hostnames.
type hostnamesResponse struct {
Hostnames []hostnameResponse `json:"hostnames"`
}
// HandleHostnames returns the configured hostnames and their status.
func (h *Handlers) HandleHostnames() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
configured := h.config.Hostnames
snapshot := h.state.GetSnapshot()
hostnames := make(
[]hostnameResponse, 0, len(configured),
)
for _, hostname := range configured {
hr := hostnameResponse{
Hostname: hostname,
Status: "pending",
}
hs, ok := snapshot.Hostnames[hostname]
if ok {
hr.Status = "ok"
if !hs.LastChecked.IsZero() {
hr.LastChecked = hs.LastChecked.
Format(time.RFC3339)
}
hr.Nameservers = buildNameserverRecords(
hs,
)
}
hostnames = append(hostnames, hr)
}
h.respondJSON(
writer, request,
&hostnamesResponse{Hostnames: hostnames},
http.StatusOK,
)
}
}
// buildNameserverRecords converts the per-nameserver state map
// into a sorted slice for deterministic JSON output.
func buildNameserverRecords(
hs *state.HostnameState,
) []nameserverRecordResponse {
if hs.RecordsByNameserver == nil {
return nil
}
nsNames := make(
[]string, 0, len(hs.RecordsByNameserver),
)
for ns := range hs.RecordsByNameserver {
nsNames = append(nsNames, ns)
}
sort.Strings(nsNames)
records := make(
[]nameserverRecordResponse, 0, len(nsNames),
)
for _, ns := range nsNames {
nsr := hs.RecordsByNameserver[ns]
entry := nameserverRecordResponse{
Nameserver: ns,
Records: nsr.Records,
Status: nsr.Status,
Error: nsr.Error,
}
if !nsr.LastChecked.IsZero() {
entry.LastChecked = nsr.LastChecked.
Format(time.RFC3339)
}
records = append(records, entry)
}
return records
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net" "net"
"sort" "sort"
"strings" "strings"
@@ -41,6 +42,22 @@ func rootServerList() []string {
} }
} }
const maxRootServers = 3
// randomRootServers returns a shuffled subset of root servers.
func randomRootServers() []string {
all := rootServerList()
rand.Shuffle(len(all), func(i, j int) {
all[i], all[j] = all[j], all[i]
})
if len(all) > maxRootServers {
return all[:maxRootServers]
}
return all
}
func checkCtx(ctx context.Context) error { func checkCtx(ctx context.Context) error {
err := ctx.Err() err := ctx.Err()
if err != nil { if err != nil {
@@ -227,7 +244,7 @@ func (r *Resolver) followDelegation(
authNS := extractNSSet(resp.Ns) authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 { if len(authNS) == 0 {
return r.resolveNSIterative(ctx, domain) return r.resolveNSRecursive(ctx, domain)
} }
glue := extractGlue(resp.Extra) glue := extractGlue(resp.Extra)
@@ -291,84 +308,60 @@ func (r *Resolver) resolveNSIPs(
return ips return ips
} }
// resolveNSIterative queries for NS records using iterative // resolveNSRecursive queries for NS records using recursive
// resolution as a fallback when followDelegation finds no // resolution as a fallback for intercepted environments.
// authoritative answer in the delegation chain. func (r *Resolver) resolveNSRecursive(
func (r *Resolver) resolveNSIterative(
ctx context.Context, ctx context.Context,
domain string, domain string,
) ([]string, error) { ) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
domain = dns.Fqdn(domain) domain = dns.Fqdn(domain)
servers := rootServerList() msg := new(dns.Msg)
msg.SetQuestion(domain, dns.TypeNS)
msg.RecursionDesired = true
for range maxDelegation { for _, ip := range randomRootServers() {
if checkCtx(ctx) != nil { if checkCtx(ctx) != nil {
return nil, ErrContextCanceled return nil, ErrContextCanceled
} }
resp, err := r.queryServers( addr := net.JoinHostPort(ip, "53")
ctx, servers, domain, dns.TypeNS,
) resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil { if err != nil {
return nil, err continue
} }
nsNames := extractNSSet(resp.Answer) nsNames := extractNSSet(resp.Answer)
if len(nsNames) > 0 { if len(nsNames) > 0 {
return nsNames, nil return nsNames, nil
} }
// Follow delegation.
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
break
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
break
}
servers = nextServers
} }
return nil, ErrNoNameservers return nil, ErrNoNameservers
} }
// resolveARecord resolves a hostname to IPv4 addresses using // resolveARecord resolves a hostname to IPv4 addresses.
// iterative resolution through the delegation chain.
func (r *Resolver) resolveARecord( func (r *Resolver) resolveARecord(
ctx context.Context, ctx context.Context,
hostname string, hostname string,
) ([]string, error) { ) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
hostname = dns.Fqdn(hostname) hostname = dns.Fqdn(hostname)
servers := rootServerList() msg := new(dns.Msg)
msg.SetQuestion(hostname, dns.TypeA)
msg.RecursionDesired = true
for range maxDelegation { for _, ip := range randomRootServers() {
if checkCtx(ctx) != nil { if checkCtx(ctx) != nil {
return nil, ErrContextCanceled return nil, ErrContextCanceled
} }
resp, err := r.queryServers( addr := net.JoinHostPort(ip, "53")
ctx, servers, hostname, dns.TypeA,
) resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
if err != nil { if err != nil {
return nil, fmt.Errorf( continue
"resolving %s: %w", hostname, err,
)
} }
// Check for A records in the answer section.
var ips []string var ips []string
for _, rr := range resp.Answer { for _, rr := range resp.Answer {
@@ -380,24 +373,6 @@ func (r *Resolver) resolveARecord(
if len(ips) > 0 { if len(ips) > 0 {
return ips, nil return ips, nil
} }
// Follow delegation if present.
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
break
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
// Resolve NS IPs iteratively — but guard
// against infinite recursion by using only
// already-resolved servers.
break
}
servers = nextServers
} }
return nil, fmt.Errorf( return nil, fmt.Errorf(
@@ -427,7 +402,7 @@ func (r *Resolver) FindAuthoritativeNameservers(
candidate := strings.Join(labels[i:], ".") + "." candidate := strings.Join(labels[i:], ".") + "."
nsNames, err := r.followDelegation( nsNames, err := r.followDelegation(
ctx, candidate, rootServerList(), ctx, candidate, randomRootServers(),
) )
if err == nil && len(nsNames) > 0 { if err == nil && len(nsNames) > 0 {
sort.Strings(nsNames) sort.Strings(nsNames)

View File

@@ -28,8 +28,6 @@ func (s *Server) SetupRoutes() {
// API v1 routes // API v1 routes
s.router.Route("/api/v1", func(r chi.Router) { s.router.Route("/api/v1", func(r chi.Router) {
r.Get("/status", s.handlers.HandleStatus()) r.Get("/status", s.handlers.HandleStatus())
r.Get("/domains", s.handlers.HandleDomains())
r.Get("/hostnames", s.handlers.HandleHostnames())
}) })
// Metrics endpoint (optional, with basic auth) // Metrics endpoint (optional, with basic auth)

View File

@@ -57,49 +57,10 @@ type HostnameState struct {
// PortState holds the monitoring state for a port. // PortState holds the monitoring state for a port.
type PortState struct { type PortState struct {
Open bool `json:"open"` Open bool `json:"open"`
Hostnames []string `json:"hostnames"` Hostname string `json:"hostname"`
LastChecked time.Time `json:"lastChecked"` LastChecked time.Time `json:"lastChecked"`
} }
// UnmarshalJSON implements custom unmarshaling to handle both
// the old single-hostname format and the new multi-hostname
// format for backward compatibility with existing state files.
func (ps *PortState) UnmarshalJSON(data []byte) error {
// Use an alias to prevent infinite recursion.
type portStateAlias struct {
Open bool `json:"open"`
Hostnames []string `json:"hostnames"`
LastChecked time.Time `json:"lastChecked"`
}
var alias portStateAlias
err := json.Unmarshal(data, &alias)
if err != nil {
return fmt.Errorf("unmarshaling port state: %w", err)
}
ps.Open = alias.Open
ps.Hostnames = alias.Hostnames
ps.LastChecked = alias.LastChecked
// If Hostnames is empty, try reading the old single-hostname
// format for backward compatibility.
if len(ps.Hostnames) == 0 {
var old struct {
Hostname string `json:"hostname"`
}
// Best-effort: ignore errors since the main unmarshal
// already succeeded.
if json.Unmarshal(data, &old) == nil && old.Hostname != "" {
ps.Hostnames = []string{old.Hostname}
}
}
return nil
}
// CertificateState holds TLS certificate monitoring state. // CertificateState holds TLS certificate monitoring state.
type CertificateState struct { type CertificateState struct {
CommonName string `json:"commonName"` CommonName string `json:"commonName"`
@@ -302,27 +263,6 @@ func (s *State) GetPortState(key string) (*PortState, bool) {
return ps, ok return ps, ok
} }
// DeletePortState removes a port state entry.
func (s *State) DeletePortState(key string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.snapshot.Ports, key)
}
// GetAllPortKeys returns all port state keys.
func (s *State) GetAllPortKeys() []string {
s.mu.RLock()
defer s.mu.RUnlock()
keys := make([]string, 0, len(s.snapshot.Ports))
for k := range s.snapshot.Ports {
keys = append(keys, k)
}
return keys
}
// SetCertificateState updates the state for a certificate. // SetCertificateState updates the state for a certificate.
func (s *State) SetCertificateState( func (s *State) SetCertificateState(
key string, key string,

View File

@@ -72,15 +72,13 @@ func New(
} }
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(startCtx context.Context) error {
// Use context.Background() — the fx startup context ctx, cancel := context.WithCancel(
// expires after startup completes, so deriving from it context.WithoutCancel(startCtx),
// would cancel the watcher immediately. The watcher's )
// lifetime is controlled by w.cancel in OnStop.
ctx, cancel := context.WithCancel(context.Background())
w.cancel = cancel w.cancel = cancel
go w.Run(ctx) //nolint:contextcheck // intentionally not derived from startCtx go w.Run(ctx)
return nil return nil
}, },
@@ -143,16 +141,9 @@ func (w *Watcher) Run(ctx context.Context) {
return return
case <-dnsTicker.C: case <-dnsTicker.C:
w.runDNSChecks(ctx) w.runDNSAndPortChecks(ctx)
w.checkAllPorts(ctx)
w.saveState() w.saveState()
case <-tlsTicker.C: case <-tlsTicker.C:
// Run DNS first so TLS checks use freshly
// resolved IP addresses, not stale ones from
// a previous cycle.
w.runDNSChecks(ctx)
w.runTLSChecks(ctx) w.runTLSChecks(ctx)
w.saveState() w.saveState()
} }
@@ -160,26 +151,10 @@ func (w *Watcher) Run(ctx context.Context) {
} }
// RunOnce performs a single complete monitoring cycle. // RunOnce performs a single complete monitoring cycle.
// DNS checks run first so that port and TLS checks use
// freshly resolved IP addresses. Port checks run before
// TLS because TLS checks only target IPs with an open
// port 443.
func (w *Watcher) RunOnce(ctx context.Context) { func (w *Watcher) RunOnce(ctx context.Context) {
w.detectFirstRun() w.detectFirstRun()
w.runDNSAndPortChecks(ctx)
// Phase 1: DNS resolution must complete first so that
// subsequent checks use fresh IP addresses.
w.runDNSChecks(ctx)
// Phase 2: Port checks populate port state that TLS
// checks depend on (TLS only targets IPs where port
// 443 is open).
w.checkAllPorts(ctx)
// Phase 3: TLS checks use fresh DNS IPs and current
// port state.
w.runTLSChecks(ctx) w.runTLSChecks(ctx)
w.saveState() w.saveState()
w.firstRun = false w.firstRun = false
} }
@@ -196,11 +171,7 @@ func (w *Watcher) detectFirstRun() {
} }
} }
// runDNSChecks performs DNS resolution for all configured domains func (w *Watcher) runDNSAndPortChecks(ctx context.Context) {
// and hostnames, updating state with freshly resolved records.
// This must complete before port or TLS checks run so those
// checks operate on current IP addresses.
func (w *Watcher) runDNSChecks(ctx context.Context) {
for _, domain := range w.config.Domains { for _, domain := range w.config.Domains {
w.checkDomain(ctx, domain) w.checkDomain(ctx, domain)
} }
@@ -208,6 +179,8 @@ func (w *Watcher) runDNSChecks(ctx context.Context) {
for _, hostname := range w.config.Hostnames { for _, hostname := range w.config.Hostnames {
w.checkHostname(ctx, hostname) w.checkHostname(ctx, hostname)
} }
w.checkAllPorts(ctx)
} }
func (w *Watcher) checkDomain( func (w *Watcher) checkDomain(
@@ -475,94 +448,24 @@ func (w *Watcher) detectInconsistencies(
} }
func (w *Watcher) checkAllPorts(ctx context.Context) { func (w *Watcher) checkAllPorts(ctx context.Context) {
// Phase 1: Build current IP:port → hostname associations for _, hostname := range w.config.Hostnames {
// from fresh DNS data. w.checkPortsForHostname(ctx, hostname)
associations := w.buildPortAssociations()
// Phase 2: Check each unique IP:port and update state
// with the full set of associated hostnames.
for key, hostnames := range associations {
ip, port := parsePortKey(key)
if port == 0 {
continue
}
w.checkSinglePort(ctx, ip, port, hostnames)
} }
// Phase 3: Remove port state entries that no longer have for _, domain := range w.config.Domains {
// any hostname referencing them. w.checkPortsForHostname(ctx, domain)
w.cleanupStalePorts(associations) }
} }
// buildPortAssociations constructs a map from IP:port keys to func (w *Watcher) checkPortsForHostname(
// the sorted set of hostnames currently resolving to that IP. ctx context.Context,
func (w *Watcher) buildPortAssociations() map[string][]string { hostname string,
assoc := make(map[string]map[string]bool)
allNames := make(
[]string, 0,
len(w.config.Hostnames)+len(w.config.Domains),
)
allNames = append(allNames, w.config.Hostnames...)
allNames = append(allNames, w.config.Domains...)
for _, name := range allNames {
ips := w.collectIPs(name)
for _, ip := range ips {
for _, port := range monitoredPorts {
key := fmt.Sprintf("%s:%d", ip, port)
if assoc[key] == nil {
assoc[key] = make(map[string]bool)
}
assoc[key][name] = true
}
}
}
result := make(map[string][]string, len(assoc))
for key, set := range assoc {
hostnames := make([]string, 0, len(set))
for h := range set {
hostnames = append(hostnames, h)
}
sort.Strings(hostnames)
result[key] = hostnames
}
return result
}
// parsePortKey splits an "ip:port" key into its components.
func parsePortKey(key string) (string, int) {
lastColon := strings.LastIndex(key, ":")
if lastColon < 0 {
return key, 0
}
ip := key[:lastColon]
var p int
_, err := fmt.Sscanf(key[lastColon+1:], "%d", &p)
if err != nil {
return ip, 0
}
return ip, p
}
// cleanupStalePorts removes port state entries that are no
// longer referenced by any hostname in the current DNS data.
func (w *Watcher) cleanupStalePorts(
currentAssociations map[string][]string,
) { ) {
for _, key := range w.state.GetAllPortKeys() { ips := w.collectIPs(hostname)
if _, exists := currentAssociations[key]; !exists {
w.state.DeletePortState(key) for _, ip := range ips {
for _, port := range monitoredPorts {
w.checkSinglePort(ctx, ip, port, hostname)
} }
} }
} }
@@ -599,7 +502,7 @@ func (w *Watcher) checkSinglePort(
ctx context.Context, ctx context.Context,
ip string, ip string,
port int, port int,
hostnames []string, hostname string,
) { ) {
result, err := w.portCheck.CheckPort(ctx, ip, port) result, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil { if err != nil {
@@ -624,8 +527,8 @@ func (w *Watcher) checkSinglePort(
} }
msg := fmt.Sprintf( msg := fmt.Sprintf(
"Hosts: %s\nAddress: %s\nPort now %s", "Host: %s\nAddress: %s\nPort now %s",
strings.Join(hostnames, ", "), key, stateStr, hostname, key, stateStr,
) )
w.notify.SendNotification( w.notify.SendNotification(
@@ -638,7 +541,7 @@ func (w *Watcher) checkSinglePort(
w.state.SetPortState(key, &state.PortState{ w.state.SetPortState(key, &state.PortState{
Open: result.Open, Open: result.Open,
Hostnames: hostnames, Hostname: hostname,
LastChecked: now, LastChecked: now,
}) })
} }

View File

@@ -682,80 +682,6 @@ func TestGracefulShutdown(t *testing.T) {
} }
} }
func setupHostnameIP(
deps *testDeps,
hostname, ip string,
) {
deps.resolver.allRecords[hostname] = map[string]map[string][]string{
"ns1.example.com.": {"A": {ip}},
}
deps.portChecker.results[ip+":80"] = true
deps.portChecker.results[ip+":443"] = true
deps.tlsChecker.certs[ip+":"+hostname] = &tlscheck.CertificateInfo{
CommonName: hostname,
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{hostname},
}
}
func updateHostnameIP(deps *testDeps, hostname, ip string) {
deps.resolver.mu.Lock()
deps.resolver.allRecords[hostname] = map[string]map[string][]string{
"ns1.example.com.": {"A": {ip}},
}
deps.resolver.mu.Unlock()
deps.portChecker.mu.Lock()
deps.portChecker.results[ip+":80"] = true
deps.portChecker.results[ip+":443"] = true
deps.portChecker.mu.Unlock()
deps.tlsChecker.mu.Lock()
deps.tlsChecker.certs[ip+":"+hostname] = &tlscheck.CertificateInfo{
CommonName: hostname,
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{hostname},
}
deps.tlsChecker.mu.Unlock()
}
func TestDNSRunsBeforePortAndTLSChecks(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
setupHostnameIP(deps, "www.example.com", "10.0.0.1")
ctx := t.Context()
w.RunOnce(ctx)
snap := deps.state.GetSnapshot()
if _, ok := snap.Ports["10.0.0.1:80"]; !ok {
t.Fatal("expected port state for 10.0.0.1:80")
}
// DNS changes to a new IP; port and TLS must pick it up.
updateHostnameIP(deps, "www.example.com", "10.0.0.2")
w.RunOnce(ctx)
snap = deps.state.GetSnapshot()
if _, ok := snap.Ports["10.0.0.2:80"]; !ok {
t.Error("port check used stale DNS: missing 10.0.0.2:80")
}
certKey := "10.0.0.2:443:www.example.com"
if _, ok := snap.Certificates[certKey]; !ok {
t.Error("TLS check used stale DNS: missing " + certKey)
}
}
func TestNSFailureAndRecovery(t *testing.T) { func TestNSFailureAndRecovery(t *testing.T) {
t.Parallel() t.Parallel()