1 Commits

Author SHA1 Message Date
clawbot
82fd68a41b fix: deduplicate TLS expiry warnings to prevent notification spam (closes #18)
Some checks failed
Check / check (pull_request) Failing after 5m31s
checkTLSExpiry fired every monitoring cycle with no deduplication,
causing notification spam for expiring certificates. Added an
in-memory map tracking the last notification time per domain/IP
pair, suppressing re-notification within the TLS check interval.

Added TestTLSExpiryWarningDedup to verify deduplication works.
2026-02-21 00:54:59 -08:00
4 changed files with 107 additions and 51 deletions

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/rand/v2"
"net" "net"
"sort" "sort"
"strings" "strings"
@@ -14,13 +13,7 @@ import (
) )
const ( const (
// queryTimeoutDuration is the per-exchange DNS timeout. queryTimeoutDuration = 5 * time.Second
//
// Rationale: maximum RTT to antipodal root/TLD servers is
// ~300ms. We use 3× max RTT + 10ms processing ≈ 910ms,
// rounded to 1s. Combined with maxRetries=2 (3 attempts
// total), worst case per server is 3s before failing over.
queryTimeoutDuration = 1 * time.Second
maxRetries = 2 maxRetries = 2
maxDelegation = 20 maxDelegation = 20
timeoutMultiplier = 2 timeoutMultiplier = 2
@@ -30,7 +23,7 @@ const (
// ErrRefused is returned when a DNS server refuses a query. // ErrRefused is returned when a DNS server refuses a query.
var ErrRefused = errors.New("dns query refused") var ErrRefused = errors.New("dns query refused")
func allRootServers() []string { func rootServerList() []string {
return []string{ return []string{
"198.41.0.4", // a.root-servers.net "198.41.0.4", // a.root-servers.net
"170.247.170.2", // b "170.247.170.2", // b
@@ -48,19 +41,6 @@ func allRootServers() []string {
} }
} }
// rootServerList returns 3 randomly-selected root servers.
// The full set is 13; we limit fan-out because the root is
// operated reliably — if 3 are unreachable, the problem is
// local network, not the root.
func rootServerList() []string {
shuffled := allRootServers()
rand.Shuffle(len(shuffled), func(i, j int) {
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
})
return shuffled[:3]
}
func checkCtx(ctx context.Context) error { func checkCtx(ctx context.Context) error {
err := ctx.Err() err := ctx.Err()
if err != nil { if err != nil {
@@ -322,7 +302,7 @@ func (r *Resolver) resolveNSRecursive(
msg.SetQuestion(domain, dns.TypeNS) msg.SetQuestion(domain, dns.TypeNS)
msg.RecursionDesired = true msg.RecursionDesired = true
for _, ip := range rootServerList() { for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil { if checkCtx(ctx) != nil {
return nil, ErrContextCanceled return nil, ErrContextCanceled
} }
@@ -353,7 +333,7 @@ func (r *Resolver) resolveARecord(
msg.SetQuestion(hostname, dns.TypeA) msg.SetQuestion(hostname, dns.TypeA)
msg.RecursionDesired = true msg.RecursionDesired = true
for _, ip := range rootServerList() { for _, ip := range rootServerList()[:3] {
if checkCtx(ctx) != nil { if checkCtx(ctx) != nil {
return nil, ErrContextCanceled return nil, ErrContextCanceled
} }

View File

@@ -156,8 +156,8 @@ func (s *State) Load() error {
// Save writes the current state to disk atomically. // Save writes the current state to disk atomically.
func (s *State) Save() error { func (s *State) Save() error {
s.mu.Lock() s.mu.RLock()
defer s.mu.Unlock() defer s.mu.RUnlock()
s.snapshot.LastUpdated = time.Now().UTC() s.snapshot.LastUpdated = time.Now().UTC()

View File

@@ -6,6 +6,7 @@ import (
"log/slog" "log/slog"
"sort" "sort"
"strings" "strings"
"sync"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@@ -40,15 +41,17 @@ type Params struct {
// Watcher orchestrates all monitoring checks on a schedule. // Watcher orchestrates all monitoring checks on a schedule.
type Watcher struct { type Watcher struct {
log *slog.Logger log *slog.Logger
config *config.Config config *config.Config
state *state.State state *state.State
resolver DNSResolver resolver DNSResolver
portCheck PortChecker portCheck PortChecker
tlsCheck TLSChecker tlsCheck TLSChecker
notify Notifier notify Notifier
cancel context.CancelFunc cancel context.CancelFunc
firstRun bool firstRun bool
expiryNotifiedMu sync.Mutex
expiryNotified map[string]time.Time
} }
// New creates a new Watcher instance wired into the fx lifecycle. // New creates a new Watcher instance wired into the fx lifecycle.
@@ -57,14 +60,15 @@ func New(
params Params, params Params,
) (*Watcher, error) { ) (*Watcher, error) {
w := &Watcher{ w := &Watcher{
log: params.Logger.Get(), log: params.Logger.Get(),
config: params.Config, config: params.Config,
state: params.State, state: params.State,
resolver: params.Resolver, resolver: params.Resolver,
portCheck: params.PortCheck, portCheck: params.PortCheck,
tlsCheck: params.TLSCheck, tlsCheck: params.TLSCheck,
notify: params.Notify, notify: params.Notify,
firstRun: true, firstRun: true,
expiryNotified: make(map[string]time.Time),
} }
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{
@@ -100,14 +104,15 @@ func NewForTest(
n Notifier, n Notifier,
) *Watcher { ) *Watcher {
return &Watcher{ return &Watcher{
log: slog.Default(), log: slog.Default(),
config: cfg, config: cfg,
state: st, state: st,
resolver: res, resolver: res,
portCheck: pc, portCheck: pc,
tlsCheck: tc, tlsCheck: tc,
notify: n, notify: n,
firstRun: true, firstRun: true,
expiryNotified: make(map[string]time.Time),
} }
} }
@@ -691,6 +696,22 @@ func (w *Watcher) checkTLSExpiry(
return return
} }
// Deduplicate expiry warnings: don't re-notify for the same
// hostname within the TLS check interval.
dedupKey := fmt.Sprintf("expiry:%s:%s", hostname, ip)
w.expiryNotifiedMu.Lock()
lastNotified, seen := w.expiryNotified[dedupKey]
if seen && time.Since(lastNotified) < w.config.TLSInterval {
w.expiryNotifiedMu.Unlock()
return
}
w.expiryNotified[dedupKey] = time.Now()
w.expiryNotifiedMu.Unlock()
msg := fmt.Sprintf( msg := fmt.Sprintf(
"Host: %s\nIP: %s\nCN: %s\n"+ "Host: %s\nIP: %s\nCN: %s\n"+
"Expires: %s (%.0f days)", "Expires: %s (%.0f days)",

View File

@@ -506,6 +506,61 @@ func TestTLSExpiryWarning(t *testing.T) {
} }
} }
func TestTLSExpiryWarningDedup(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
cfg.TLSInterval = 24 * time.Hour
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(3 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
// First run = baseline, no notifications
w.RunOnce(ctx)
// Second run should fire one expiry warning
w.RunOnce(ctx)
// Third run should NOT fire another warning (dedup)
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
expiryCount := 0
for _, n := range notifications {
if n.Title == "TLS Expiry Warning: www.example.com" {
expiryCount++
}
}
if expiryCount != 1 {
t.Errorf(
"expected exactly 1 expiry warning (dedup), got %d",
expiryCount,
)
}
}
func TestGracefulShutdown(t *testing.T) { func TestGracefulShutdown(t *testing.T) {
t.Parallel() t.Parallel()