All checks were successful
check / check (push) Successful in 34s
## Summary Port state keys are `ip:port` with a single `hostname` field. When multiple hostnames resolve to the same IP (shared hosting, CDN), only one hostname was associated. This caused orphaned port state when that hostname removed the IP from DNS while the IP remained valid for other hostnames. ## Changes ### State (`internal/state/state.go`) - `PortState.Hostname` (string) → `PortState.Hostnames` ([]string) - Custom `UnmarshalJSON` for backward compatibility: reads old single `hostname` field and migrates to a single-element `hostnames` slice - Added `DeletePortState` and `GetAllPortKeys` methods for cleanup ### Watcher (`internal/watcher/watcher.go`) - Refactored `checkAllPorts` into three phases: 1. Build IP:port → hostname associations from current DNS data 2. Check each unique IP:port once with all associated hostnames 3. Clean up stale port state entries with no hostname references - Port change notifications now list all associated hostnames (`Hosts:` instead of `Host:`) - Added `buildPortAssociations`, `parsePortKey`, and `cleanupStalePorts` helper functions ### README - Updated state file format example: `hostname` → `hostnames` (array) - Updated notification description to reflect multiple hostnames ## Backward Compatibility Existing state files with the old single `hostname` string are handled gracefully via custom JSON unmarshaling — they are read as single-element `hostnames` slices. Closes #55 Co-authored-by: clawbot <clawbot@noreply.eeqj.de> Reviewed-on: #65 Co-authored-by: clawbot <clawbot@noreply.example.org> Co-committed-by: clawbot <clawbot@noreply.example.org>
905 lines
18 KiB
Go
905 lines
18 KiB
Go
package watcher
|
|
|
|
import (
	"context"
	"fmt"
	"log/slog"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"go.uber.org/fx"

	"sneak.berlin/go/dnswatcher/internal/config"
	"sneak.berlin/go/dnswatcher/internal/logger"
	"sneak.berlin/go/dnswatcher/internal/state"
	"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
|
|
|
|
// monitoredPorts lists the TCP ports probed on every resolved
// IP address during each port-check cycle.
var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals

const (
	// tlsPort is the TCP port targeted by TLS certificate checks.
	tlsPort = 443

	// hoursPerDay converts a day count into hours for
	// time.Duration arithmetic.
	hoursPerDay = 24
)
|
|
|
|
// Params contains dependencies for Watcher.
// All fields are injected by fx (via the embedded fx.In).
type Params struct {
	fx.In

	// Logger supplies the structured logger used for all watcher output.
	Logger *logger.Logger
	// Config holds the monitored domains/hostnames and check intervals.
	Config *config.Config
	// State is the persisted store of previous observations.
	State *state.State
	// Resolver performs DNS lookups (NS and per-nameserver records).
	Resolver DNSResolver
	// PortCheck probes TCP ports on resolved IP addresses.
	PortCheck PortChecker
	// TLSCheck fetches TLS certificates from IP/hostname pairs.
	TLSCheck TLSChecker
	// Notify delivers change/failure notifications.
	Notify Notifier
}
|
|
|
|
// Watcher orchestrates all monitoring checks on a schedule.
type Watcher struct {
	log       *slog.Logger
	config    *config.Config
	state     *state.State
	resolver  DNSResolver
	portCheck PortChecker
	tlsCheck  TLSChecker
	notify    Notifier
	// cancel stops the Run loop; set in the fx OnStart hook.
	cancel context.CancelFunc
	// firstRun suppresses change notifications on the very first
	// cycle when no prior state exists to compare against.
	firstRun bool
	// expiryNotifiedMu guards expiryNotified, which records when
	// an expiry warning was last sent per hostname/IP so warnings
	// are not repeated within one TLS check interval.
	expiryNotifiedMu sync.Mutex
	expiryNotified   map[string]time.Time
}
|
|
|
|
// New creates a new Watcher instance wired into the fx lifecycle.
|
|
func New(
|
|
lifecycle fx.Lifecycle,
|
|
params Params,
|
|
) (*Watcher, error) {
|
|
w := &Watcher{
|
|
log: params.Logger.Get(),
|
|
config: params.Config,
|
|
state: params.State,
|
|
resolver: params.Resolver,
|
|
portCheck: params.PortCheck,
|
|
tlsCheck: params.TLSCheck,
|
|
notify: params.Notify,
|
|
firstRun: true,
|
|
expiryNotified: make(map[string]time.Time),
|
|
}
|
|
|
|
lifecycle.Append(fx.Hook{
|
|
OnStart: func(startCtx context.Context) error {
|
|
ctx, cancel := context.WithCancel(
|
|
context.WithoutCancel(startCtx),
|
|
)
|
|
w.cancel = cancel
|
|
|
|
go w.Run(ctx)
|
|
|
|
return nil
|
|
},
|
|
OnStop: func(_ context.Context) error {
|
|
if w.cancel != nil {
|
|
w.cancel()
|
|
}
|
|
|
|
return nil
|
|
},
|
|
})
|
|
|
|
return w, nil
|
|
}
|
|
|
|
// NewForTest creates a Watcher without fx for unit testing.
|
|
func NewForTest(
|
|
cfg *config.Config,
|
|
st *state.State,
|
|
res DNSResolver,
|
|
pc PortChecker,
|
|
tc TLSChecker,
|
|
n Notifier,
|
|
) *Watcher {
|
|
return &Watcher{
|
|
log: slog.Default(),
|
|
config: cfg,
|
|
state: st,
|
|
resolver: res,
|
|
portCheck: pc,
|
|
tlsCheck: tc,
|
|
notify: n,
|
|
firstRun: true,
|
|
expiryNotified: make(map[string]time.Time),
|
|
}
|
|
}
|
|
|
|
// Run starts the monitoring loop with periodic scheduling.
|
|
func (w *Watcher) Run(ctx context.Context) {
|
|
w.log.Info(
|
|
"watcher starting",
|
|
"domains", len(w.config.Domains),
|
|
"hostnames", len(w.config.Hostnames),
|
|
"dnsInterval", w.config.DNSInterval,
|
|
"tlsInterval", w.config.TLSInterval,
|
|
)
|
|
|
|
w.RunOnce(ctx)
|
|
|
|
dnsTicker := time.NewTicker(w.config.DNSInterval)
|
|
tlsTicker := time.NewTicker(w.config.TLSInterval)
|
|
|
|
defer dnsTicker.Stop()
|
|
defer tlsTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
w.log.Info("watcher stopped")
|
|
|
|
return
|
|
case <-dnsTicker.C:
|
|
w.runDNSChecks(ctx)
|
|
|
|
w.checkAllPorts(ctx)
|
|
w.saveState()
|
|
case <-tlsTicker.C:
|
|
// Run DNS first so TLS checks use freshly
|
|
// resolved IP addresses, not stale ones from
|
|
// a previous cycle.
|
|
w.runDNSChecks(ctx)
|
|
|
|
w.runTLSChecks(ctx)
|
|
w.saveState()
|
|
}
|
|
}
|
|
}
|
|
|
|
// RunOnce performs a single complete monitoring cycle.
|
|
// DNS checks run first so that port and TLS checks use
|
|
// freshly resolved IP addresses. Port checks run before
|
|
// TLS because TLS checks only target IPs with an open
|
|
// port 443.
|
|
func (w *Watcher) RunOnce(ctx context.Context) {
|
|
w.detectFirstRun()
|
|
|
|
// Phase 1: DNS resolution must complete first so that
|
|
// subsequent checks use fresh IP addresses.
|
|
w.runDNSChecks(ctx)
|
|
|
|
// Phase 2: Port checks populate port state that TLS
|
|
// checks depend on (TLS only targets IPs where port
|
|
// 443 is open).
|
|
w.checkAllPorts(ctx)
|
|
|
|
// Phase 3: TLS checks use fresh DNS IPs and current
|
|
// port state.
|
|
w.runTLSChecks(ctx)
|
|
|
|
w.saveState()
|
|
w.firstRun = false
|
|
}
|
|
|
|
func (w *Watcher) detectFirstRun() {
|
|
snap := w.state.GetSnapshot()
|
|
hasState := len(snap.Domains) > 0 ||
|
|
len(snap.Hostnames) > 0 ||
|
|
len(snap.Ports) > 0 ||
|
|
len(snap.Certificates) > 0
|
|
|
|
if hasState {
|
|
w.firstRun = false
|
|
}
|
|
}
|
|
|
|
// runDNSChecks performs DNS resolution for all configured domains
|
|
// and hostnames, updating state with freshly resolved records.
|
|
// This must complete before port or TLS checks run so those
|
|
// checks operate on current IP addresses.
|
|
func (w *Watcher) runDNSChecks(ctx context.Context) {
|
|
for _, domain := range w.config.Domains {
|
|
w.checkDomain(ctx, domain)
|
|
}
|
|
|
|
for _, hostname := range w.config.Hostnames {
|
|
w.checkHostname(ctx, hostname)
|
|
}
|
|
}
|
|
|
|
func (w *Watcher) checkDomain(
|
|
ctx context.Context,
|
|
domain string,
|
|
) {
|
|
nameservers, err := w.resolver.LookupNS(ctx, domain)
|
|
if err != nil {
|
|
w.log.Error(
|
|
"failed to lookup NS",
|
|
"domain", domain,
|
|
"error", err,
|
|
)
|
|
|
|
return
|
|
}
|
|
|
|
sort.Strings(nameservers)
|
|
|
|
now := time.Now().UTC()
|
|
|
|
prev, hasPrev := w.state.GetDomainState(domain)
|
|
if hasPrev && !w.firstRun {
|
|
w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers)
|
|
}
|
|
|
|
w.state.SetDomainState(domain, &state.DomainState{
|
|
Nameservers: nameservers,
|
|
LastChecked: now,
|
|
})
|
|
|
|
// Also look up A/AAAA records for the apex domain so that
|
|
// port and TLS checks (which read HostnameState) can find
|
|
// the domain's IP addresses.
|
|
records, err := w.resolver.LookupAllRecords(ctx, domain)
|
|
if err != nil {
|
|
w.log.Error(
|
|
"failed to lookup records for domain",
|
|
"domain", domain,
|
|
"error", err,
|
|
)
|
|
|
|
return
|
|
}
|
|
|
|
prevHS, hasPrevHS := w.state.GetHostnameState(domain)
|
|
if hasPrevHS && !w.firstRun {
|
|
w.detectHostnameChanges(ctx, domain, prevHS, records)
|
|
}
|
|
|
|
newState := buildHostnameState(records, now)
|
|
w.state.SetHostnameState(domain, newState)
|
|
}
|
|
|
|
func (w *Watcher) detectNSChanges(
|
|
ctx context.Context,
|
|
domain string,
|
|
oldNS, newNS []string,
|
|
) {
|
|
oldSet := toSet(oldNS)
|
|
newSet := toSet(newNS)
|
|
|
|
var added, removed []string
|
|
|
|
for ns := range newSet {
|
|
if !oldSet[ns] {
|
|
added = append(added, ns)
|
|
}
|
|
}
|
|
|
|
for ns := range oldSet {
|
|
if !newSet[ns] {
|
|
removed = append(removed, ns)
|
|
}
|
|
}
|
|
|
|
if len(added) == 0 && len(removed) == 0 {
|
|
return
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"Domain: %s\nAdded: %s\nRemoved: %s",
|
|
domain,
|
|
strings.Join(added, ", "),
|
|
strings.Join(removed, ", "),
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"NS Change: "+domain,
|
|
msg,
|
|
"warning",
|
|
)
|
|
}
|
|
|
|
func (w *Watcher) checkHostname(
|
|
ctx context.Context,
|
|
hostname string,
|
|
) {
|
|
records, err := w.resolver.LookupAllRecords(ctx, hostname)
|
|
if err != nil {
|
|
w.log.Error(
|
|
"failed to lookup records",
|
|
"hostname", hostname,
|
|
"error", err,
|
|
)
|
|
|
|
return
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
prev, hasPrev := w.state.GetHostnameState(hostname)
|
|
|
|
if hasPrev && !w.firstRun {
|
|
w.detectHostnameChanges(ctx, hostname, prev, records)
|
|
}
|
|
|
|
newState := buildHostnameState(records, now)
|
|
w.state.SetHostnameState(hostname, newState)
|
|
}
|
|
|
|
func buildHostnameState(
|
|
records map[string]map[string][]string,
|
|
now time.Time,
|
|
) *state.HostnameState {
|
|
hs := &state.HostnameState{
|
|
RecordsByNameserver: make(
|
|
map[string]*state.NameserverRecordState,
|
|
),
|
|
LastChecked: now,
|
|
}
|
|
|
|
for ns, recs := range records {
|
|
hs.RecordsByNameserver[ns] = &state.NameserverRecordState{
|
|
Records: recs,
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
}
|
|
}
|
|
|
|
return hs
|
|
}
|
|
|
|
// detectHostnameChanges runs all three per-hostname detectors
// against the previous state and current records: per-nameserver
// record changes, nameserver disappearance/recovery, and
// cross-nameserver inconsistencies. Each detector sends its own
// notifications; nothing is returned.
func (w *Watcher) detectHostnameChanges(
	ctx context.Context,
	hostname string,
	prev *state.HostnameState,
	current map[string]map[string][]string,
) {
	w.detectRecordChanges(ctx, hostname, prev, current)
	w.detectNSDisappearances(ctx, hostname, prev, current)
	w.detectInconsistencies(ctx, hostname, current)
}
|
|
|
|
func (w *Watcher) detectRecordChanges(
|
|
ctx context.Context,
|
|
hostname string,
|
|
prev *state.HostnameState,
|
|
current map[string]map[string][]string,
|
|
) {
|
|
for ns, recs := range current {
|
|
prevNS, ok := prev.RecordsByNameserver[ns]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
if recordsEqual(prevNS.Records, recs) {
|
|
continue
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"Hostname: %s\nNameserver: %s\n"+
|
|
"Old: %v\nNew: %v",
|
|
hostname, ns,
|
|
prevNS.Records, recs,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"Record Change: "+hostname,
|
|
msg,
|
|
"warning",
|
|
)
|
|
}
|
|
}
|
|
|
|
func (w *Watcher) detectNSDisappearances(
|
|
ctx context.Context,
|
|
hostname string,
|
|
prev *state.HostnameState,
|
|
current map[string]map[string][]string,
|
|
) {
|
|
for ns, prevNS := range prev.RecordsByNameserver {
|
|
if _, ok := current[ns]; ok || prevNS.Status != "ok" {
|
|
continue
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"Hostname: %s\nNameserver: %s disappeared",
|
|
hostname, ns,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"NS Failure: "+hostname,
|
|
msg,
|
|
"error",
|
|
)
|
|
}
|
|
|
|
for ns := range current {
|
|
prevNS, ok := prev.RecordsByNameserver[ns]
|
|
if !ok || prevNS.Status != "error" {
|
|
continue
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"Hostname: %s\nNameserver: %s recovered",
|
|
hostname, ns,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"NS Recovery: "+hostname,
|
|
msg,
|
|
"success",
|
|
)
|
|
}
|
|
}
|
|
|
|
func (w *Watcher) detectInconsistencies(
|
|
ctx context.Context,
|
|
hostname string,
|
|
current map[string]map[string][]string,
|
|
) {
|
|
nameservers := make([]string, 0, len(current))
|
|
for ns := range current {
|
|
nameservers = append(nameservers, ns)
|
|
}
|
|
|
|
sort.Strings(nameservers)
|
|
|
|
for i := range len(nameservers) - 1 {
|
|
ns1 := nameservers[i]
|
|
ns2 := nameservers[i+1]
|
|
|
|
if recordsEqual(current[ns1], current[ns2]) {
|
|
continue
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"Hostname: %s\n%s: %v\n%s: %v",
|
|
hostname,
|
|
ns1, current[ns1],
|
|
ns2, current[ns2],
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"Inconsistency: "+hostname,
|
|
msg,
|
|
"warning",
|
|
)
|
|
}
|
|
}
|
|
|
|
func (w *Watcher) checkAllPorts(ctx context.Context) {
|
|
// Phase 1: Build current IP:port → hostname associations
|
|
// from fresh DNS data.
|
|
associations := w.buildPortAssociations()
|
|
|
|
// Phase 2: Check each unique IP:port and update state
|
|
// with the full set of associated hostnames.
|
|
for key, hostnames := range associations {
|
|
ip, port := parsePortKey(key)
|
|
if port == 0 {
|
|
continue
|
|
}
|
|
|
|
w.checkSinglePort(ctx, ip, port, hostnames)
|
|
}
|
|
|
|
// Phase 3: Remove port state entries that no longer have
|
|
// any hostname referencing them.
|
|
w.cleanupStalePorts(associations)
|
|
}
|
|
|
|
// buildPortAssociations constructs a map from IP:port keys to
|
|
// the sorted set of hostnames currently resolving to that IP.
|
|
func (w *Watcher) buildPortAssociations() map[string][]string {
|
|
assoc := make(map[string]map[string]bool)
|
|
|
|
allNames := make(
|
|
[]string, 0,
|
|
len(w.config.Hostnames)+len(w.config.Domains),
|
|
)
|
|
allNames = append(allNames, w.config.Hostnames...)
|
|
allNames = append(allNames, w.config.Domains...)
|
|
|
|
for _, name := range allNames {
|
|
ips := w.collectIPs(name)
|
|
for _, ip := range ips {
|
|
for _, port := range monitoredPorts {
|
|
key := fmt.Sprintf("%s:%d", ip, port)
|
|
if assoc[key] == nil {
|
|
assoc[key] = make(map[string]bool)
|
|
}
|
|
|
|
assoc[key][name] = true
|
|
}
|
|
}
|
|
}
|
|
|
|
result := make(map[string][]string, len(assoc))
|
|
for key, set := range assoc {
|
|
hostnames := make([]string, 0, len(set))
|
|
for h := range set {
|
|
hostnames = append(hostnames, h)
|
|
}
|
|
|
|
sort.Strings(hostnames)
|
|
|
|
result[key] = hostnames
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// parsePortKey splits an "ip:port" state key into its IP and
// port components. The split is on the LAST colon so that IPv6
// addresses (which themselves contain colons) parse correctly.
// It returns port 0 when the key has no colon or when the port
// segment is not a valid TCP port number (1-65535); callers
// treat 0 as "skip this key".
func parsePortKey(key string) (string, int) {
	lastColon := strings.LastIndex(key, ":")
	if lastColon < 0 {
		return key, 0
	}

	ip := key[:lastColon]

	// strconv.Atoi rejects trailing garbage (e.g. "80x") that
	// the previous fmt.Sscanf "%d" silently accepted as 80.
	port, err := strconv.Atoi(key[lastColon+1:])

	const maxTCPPort = 65535
	if err != nil || port < 1 || port > maxTCPPort {
		return ip, 0
	}

	return ip, port
}
|
|
|
|
// cleanupStalePorts removes port state entries that are no
|
|
// longer referenced by any hostname in the current DNS data.
|
|
func (w *Watcher) cleanupStalePorts(
|
|
currentAssociations map[string][]string,
|
|
) {
|
|
for _, key := range w.state.GetAllPortKeys() {
|
|
if _, exists := currentAssociations[key]; !exists {
|
|
w.state.DeletePortState(key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *Watcher) collectIPs(hostname string) []string {
|
|
hs, ok := w.state.GetHostnameState(hostname)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
ipSet := make(map[string]bool)
|
|
|
|
for _, nsState := range hs.RecordsByNameserver {
|
|
for _, ip := range nsState.Records["A"] {
|
|
ipSet[ip] = true
|
|
}
|
|
|
|
for _, ip := range nsState.Records["AAAA"] {
|
|
ipSet[ip] = true
|
|
}
|
|
}
|
|
|
|
result := make([]string, 0, len(ipSet))
|
|
for ip := range ipSet {
|
|
result = append(result, ip)
|
|
}
|
|
|
|
sort.Strings(result)
|
|
|
|
return result
|
|
}
|
|
|
|
func (w *Watcher) checkSinglePort(
|
|
ctx context.Context,
|
|
ip string,
|
|
port int,
|
|
hostnames []string,
|
|
) {
|
|
result, err := w.portCheck.CheckPort(ctx, ip, port)
|
|
if err != nil {
|
|
w.log.Error(
|
|
"port check failed",
|
|
"ip", ip,
|
|
"port", port,
|
|
"error", err,
|
|
)
|
|
|
|
return
|
|
}
|
|
|
|
key := fmt.Sprintf("%s:%d", ip, port)
|
|
now := time.Now().UTC()
|
|
prev, hasPrev := w.state.GetPortState(key)
|
|
|
|
if hasPrev && !w.firstRun && prev.Open != result.Open {
|
|
stateStr := "closed"
|
|
if result.Open {
|
|
stateStr = "open"
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"Hosts: %s\nAddress: %s\nPort now %s",
|
|
strings.Join(hostnames, ", "), key, stateStr,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"Port Change: "+key,
|
|
msg,
|
|
"warning",
|
|
)
|
|
}
|
|
|
|
w.state.SetPortState(key, &state.PortState{
|
|
Open: result.Open,
|
|
Hostnames: hostnames,
|
|
LastChecked: now,
|
|
})
|
|
}
|
|
|
|
func (w *Watcher) runTLSChecks(ctx context.Context) {
|
|
for _, hostname := range w.config.Hostnames {
|
|
w.checkTLSForHostname(ctx, hostname)
|
|
}
|
|
|
|
for _, domain := range w.config.Domains {
|
|
w.checkTLSForHostname(ctx, domain)
|
|
}
|
|
}
|
|
|
|
func (w *Watcher) checkTLSForHostname(
|
|
ctx context.Context,
|
|
hostname string,
|
|
) {
|
|
ips := w.collectIPs(hostname)
|
|
|
|
for _, ip := range ips {
|
|
portKey := fmt.Sprintf("%s:%d", ip, tlsPort)
|
|
|
|
ps, ok := w.state.GetPortState(portKey)
|
|
if !ok || !ps.Open {
|
|
continue
|
|
}
|
|
|
|
w.checkTLSCert(ctx, ip, hostname)
|
|
}
|
|
}
|
|
|
|
func (w *Watcher) checkTLSCert(
|
|
ctx context.Context,
|
|
ip string,
|
|
hostname string,
|
|
) {
|
|
cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname)
|
|
certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname)
|
|
now := time.Now().UTC()
|
|
prev, hasPrev := w.state.GetCertificateState(certKey)
|
|
|
|
if err != nil {
|
|
w.handleTLSError(
|
|
ctx, certKey, hostname, ip,
|
|
hasPrev, prev, now, err,
|
|
)
|
|
|
|
return
|
|
}
|
|
|
|
w.handleTLSSuccess(
|
|
ctx, certKey, hostname, ip,
|
|
hasPrev, prev, now, cert,
|
|
)
|
|
}
|
|
|
|
func (w *Watcher) handleTLSError(
|
|
ctx context.Context,
|
|
certKey, hostname, ip string,
|
|
hasPrev bool,
|
|
prev *state.CertificateState,
|
|
now time.Time,
|
|
err error,
|
|
) {
|
|
if hasPrev && !w.firstRun && prev.Status == "ok" {
|
|
msg := fmt.Sprintf(
|
|
"Host: %s\nIP: %s\nError: %s",
|
|
hostname, ip, err,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"TLS Failure: "+hostname,
|
|
msg,
|
|
"error",
|
|
)
|
|
}
|
|
|
|
w.state.SetCertificateState(
|
|
certKey, &state.CertificateState{
|
|
Status: "error",
|
|
Error: err.Error(),
|
|
LastChecked: now,
|
|
},
|
|
)
|
|
}
|
|
|
|
func (w *Watcher) handleTLSSuccess(
|
|
ctx context.Context,
|
|
certKey, hostname, ip string,
|
|
hasPrev bool,
|
|
prev *state.CertificateState,
|
|
now time.Time,
|
|
cert *tlscheck.CertificateInfo,
|
|
) {
|
|
if hasPrev && !w.firstRun {
|
|
w.detectTLSChanges(ctx, hostname, ip, prev, cert)
|
|
}
|
|
|
|
w.checkTLSExpiry(ctx, hostname, ip, cert)
|
|
|
|
w.state.SetCertificateState(
|
|
certKey, &state.CertificateState{
|
|
CommonName: cert.CommonName,
|
|
Issuer: cert.Issuer,
|
|
NotAfter: cert.NotAfter,
|
|
SubjectAlternativeNames: cert.SubjectAlternativeNames,
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
},
|
|
)
|
|
}
|
|
|
|
func (w *Watcher) detectTLSChanges(
|
|
ctx context.Context,
|
|
hostname, ip string,
|
|
prev *state.CertificateState,
|
|
cert *tlscheck.CertificateInfo,
|
|
) {
|
|
if prev.Status == "error" {
|
|
msg := fmt.Sprintf(
|
|
"Host: %s\nIP: %s\nTLS recovered",
|
|
hostname, ip,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"TLS Recovery: "+hostname,
|
|
msg,
|
|
"success",
|
|
)
|
|
|
|
return
|
|
}
|
|
|
|
changed := prev.CommonName != cert.CommonName ||
|
|
prev.Issuer != cert.Issuer ||
|
|
!sliceEqual(
|
|
prev.SubjectAlternativeNames,
|
|
cert.SubjectAlternativeNames,
|
|
)
|
|
|
|
if !changed {
|
|
return
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"Host: %s\nIP: %s\n"+
|
|
"Old CN: %s, Issuer: %s\n"+
|
|
"New CN: %s, Issuer: %s",
|
|
hostname, ip,
|
|
prev.CommonName, prev.Issuer,
|
|
cert.CommonName, cert.Issuer,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"TLS Certificate Change: "+hostname,
|
|
msg,
|
|
"warning",
|
|
)
|
|
}
|
|
|
|
func (w *Watcher) checkTLSExpiry(
|
|
ctx context.Context,
|
|
hostname, ip string,
|
|
cert *tlscheck.CertificateInfo,
|
|
) {
|
|
daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay
|
|
warningDays := float64(w.config.TLSExpiryWarning)
|
|
|
|
if daysLeft > warningDays {
|
|
return
|
|
}
|
|
|
|
// Deduplicate expiry warnings: don't re-notify for the same
|
|
// hostname within the TLS check interval.
|
|
dedupKey := fmt.Sprintf("expiry:%s:%s", hostname, ip)
|
|
|
|
w.expiryNotifiedMu.Lock()
|
|
|
|
lastNotified, seen := w.expiryNotified[dedupKey]
|
|
if seen && time.Since(lastNotified) < w.config.TLSInterval {
|
|
w.expiryNotifiedMu.Unlock()
|
|
|
|
return
|
|
}
|
|
|
|
w.expiryNotified[dedupKey] = time.Now()
|
|
w.expiryNotifiedMu.Unlock()
|
|
|
|
msg := fmt.Sprintf(
|
|
"Host: %s\nIP: %s\nCN: %s\n"+
|
|
"Expires: %s (%.0f days)",
|
|
hostname, ip, cert.CommonName,
|
|
cert.NotAfter.Format(time.RFC3339),
|
|
daysLeft,
|
|
)
|
|
|
|
w.notify.SendNotification(
|
|
ctx,
|
|
"TLS Expiry Warning: "+hostname,
|
|
msg,
|
|
"warning",
|
|
)
|
|
}
|
|
|
|
func (w *Watcher) saveState() {
|
|
err := w.state.Save()
|
|
if err != nil {
|
|
w.log.Error("failed to save state", "error", err)
|
|
}
|
|
}
|
|
|
|
// --- Utility functions ---
|
|
|
|
// toSet converts a string slice into a membership set;
// duplicate entries collapse to a single key.
func toSet(items []string) map[string]bool {
	members := make(map[string]bool, len(items))

	for _, item := range items {
		members[item] = true
	}

	return members
}
|
|
|
|
func recordsEqual(
|
|
a, b map[string][]string,
|
|
) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
|
|
for k, av := range a {
|
|
bv, ok := b[k]
|
|
if !ok || !sliceEqual(av, bv) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// sliceEqual reports whether a and b contain the same elements
// with the same multiplicities, regardless of order. Neither
// input slice is modified.
func sliceEqual(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	// Sort copies so callers' slices stay untouched.
	left := append([]string(nil), a...)
	right := append([]string(nil), b...)

	sort.Strings(left)
	sort.Strings(right)

	for i, v := range left {
		if v != right[i] {
			return false
		}
	}

	return true
}
|