64 Commits

Author SHA1 Message Date
clawbot
31bd6c3228 feat: add retry with exponential backoff for notification delivery
All checks were successful
check / check (push) Successful in 42s
Notifications were fire-and-forget: if Slack, Mattermost, or ntfy was
temporarily down, changes were silently lost. This adds automatic retry
with exponential backoff and jitter to all notification endpoints.

Implementation:
- New retry.go with configurable RetryConfig (max retries, base delay,
  max delay) and exponential backoff with ±25% jitter
- Each dispatch goroutine now wraps its send call in deliverWithRetry
- Default: 3 retries (4 total attempts), 1s base delay, 10s max delay
- Context-aware: respects cancellation during retry sleep
- Structured logging on each retry attempt and on final success after
  retry

All existing tests continue to pass. New tests cover:
- Backoff calculation (increase, cap)
- Retry success on first attempt (no unnecessary retries)
- Retry on transient failure (succeeds after N attempts)
- Exhausted retries (returns last error)
- Context cancellation during retry sleep
- Integration: SendNotification retries transient 500s
- Integration: all three endpoints retry independently
- Integration: permanent failure exhausts retries

closes #62
2026-03-10 11:11:32 -07:00
b64db3e10f feat: enhance /api/v1/status endpoint with full monitoring data (#86)
All checks were successful
check / check (push) Successful in 1m27s
## Summary

Enhances the `/api/v1/status` endpoint to return comprehensive monitoring state instead of just `{"status": "ok"}`.

## Changes

The endpoint now returns:

- **Summary counts**: domains, hostnames, ports (total + open), certificates (total + ok + error)
- **Domains**: each monitored domain with its discovered nameservers and last check timestamp
- **Hostnames**: each monitored hostname with per-nameserver DNS records, status, and last check timestamps
- **Ports**: each monitored IP:port with open/closed state, associated hostnames, and last check timestamp
- **Certificates**: each TLS certificate with CN, issuer, expiry, SANs, status, and last check timestamp
- **Last updated**: timestamp of the overall monitoring state

All data is derived from the existing `state.GetSnapshot()`, consistent with how the dashboard works. No configuration details (webhook URLs, API tokens) are exposed.

## Example response structure

```json
{
  "status": "ok",
  "lastUpdated": "2026-03-10T12:00:00Z",
  "counts": {
    "domains": 2,
    "hostnames": 3,
    "ports": 10,
    "portsOpen": 8,
    "certificates": 4,
    "certificatesOk": 3,
    "certificatesError": 1
  },
  "domains": { ... },
  "hostnames": { ... },
  "ports": { ... },
  "certificates": { ... }
}
```

closes #73

Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Reviewed-on: #86
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-10 12:20:11 +01:00
65180ad661 feat: add DNSWATCHER_SEND_TEST_NOTIFICATION env var (#85)
All checks were successful
check / check (push) Successful in 5s
When set to a truthy value, sends a startup status notification to all configured notification channels after the first full scan completes on application startup. The notification is clearly an all-ok/success message showing the number of monitored domains, hostnames, ports, and certificates.

Changes:
- Added `SendTestNotification` config field reading `DNSWATCHER_SEND_TEST_NOTIFICATION`
- Added `maybeSendTestNotification()` in watcher, called after initial `RunOnce` in `Run`
- Added 3 watcher tests (enabled via Run, enabled via RunOnce alone, disabled)
- Added config tests for the new field
- Updated README: env var table, example .env, Docker example

Closes #84

Co-authored-by: user <user@Mac.lan guest wan>
Reviewed-on: #85
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-04 21:41:55 +01:00
1076543c23 feat: add unauthenticated web dashboard showing monitoring state and recent alerts (#83)
All checks were successful
check / check (push) Successful in 4s
## Summary

Adds a read-only web dashboard at `GET /` that shows the current monitoring state and recent alerts. Unauthenticated, single-page, no navigation.

## What it shows

- **Summary bar**: counts of monitored domains, hostnames, ports, certificates
- **Domains**: nameservers with last-checked age
- **Hostnames**: per-nameserver DNS records, status badges, relative age
- **Ports**: open/closed state with associated hostnames and age
- **TLS Certificates**: CN, issuer, expiry (color-coded by urgency), status, age
- **Recent Alerts**: last 100 notifications in reverse chronological order with priority badges

Every data point displays its age (e.g. "5m ago") so freshness is visible at a glance. Auto-refreshes every 30 seconds.

## What it does NOT show

No secrets: webhook URLs, ntfy topics, Slack/Mattermost endpoints, API tokens, and configuration details are never exposed.

## Design

All assets (CSS) are embedded in the binary and served from `/s/`. Zero external HTTP requests at runtime — no CDN dependencies or third-party resources. Dark, technical aesthetic with saturated teals and blues on dark slate. Single page — everything on one screen.

## Implementation

- `internal/notify/history.go` — thread-safe ring buffer (`AlertHistory`) storing last 100 alerts
- `internal/notify/notify.go` — records each alert in history before dispatch; refactored `SendNotification` into smaller `dispatch*` helpers to satisfy funlen
- `internal/handlers/dashboard.go` — `HandleDashboard()` handler with embedded HTML template, helper functions (`relTime`, `formatRecords`, `expiryDays`, `joinStrings`)
- `internal/handlers/templates/dashboard.html` — Tailwind-styled single-page dashboard
- `internal/handlers/handlers.go` — added `State` and `Notify` dependencies via fx
- `internal/server/routes.go` — registered `GET /` route
- `static/` — embedded CSS assets served via `/s/` prefix
- `README.md` — documented the dashboard and new endpoint

## Tests

- `internal/notify/history_test.go` — empty, add+recent ordering, overflow beyond capacity
- `internal/handlers/dashboard_test.go` — `relTime`, `expiryDays`, `formatRecords`
- All existing tests pass unchanged
- `docker build .` passes

closes [#82](#82)

<!-- session: rework-pr-83 -->

Co-authored-by: user <user@Mac.lan guest wan>
Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Reviewed-on: #83
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-04 13:03:38 +01:00
1843d09eb3 test(notify): add comprehensive tests for notification delivery (#79)
All checks were successful
check / check (push) Successful in 50s
## Summary

Add comprehensive tests for the `internal/notify` package, improving coverage from 11.1% to 80.0%.

Closes [issue #71](#71).

## What was added

### `delivery_test.go` — 28 new test functions

**Priority mapping tests:**
- `TestNtfyPriority` — all priority levels (error→urgent, warning→high, success→default, info→low, unknown→default)
- `TestSlackColor` — all color mappings including default fallback

**Request construction:**
- `TestNewRequest` — method, URL, host, headers, body
- `TestNewRequestPreservesContext` — context propagation

**ntfy delivery (`sendNtfy`):**
- `TestSendNtfyHeaders` — Title, Priority headers, POST body content
- `TestSendNtfyAllPriorities` — end-to-end header verification for all priority levels
- `TestSendNtfyClientError` — 403 returns `ErrNtfyFailed`
- `TestSendNtfyServerError` — 500 returns `ErrNtfyFailed`
- `TestSendNtfySuccess` — 200 OK succeeds
- `TestSendNtfyNetworkError` — transport failure handling

**Slack/Mattermost delivery (`sendSlack`):**
- `TestSendSlackPayloadFields` — JSON payload structure, Content-Type header, attachment fields
- `TestSendSlackAllColors` — color mapping for all priorities
- `TestSendSlackClientError` — 400 returns `ErrSlackFailed`
- `TestSendSlackServerError` — 502 returns `ErrSlackFailed`
- `TestSendSlackNetworkError` — transport failure handling

**`SendNotification` goroutine dispatch:**
- `TestSendNotificationAllEndpoints` — all three endpoints receive notifications concurrently
- `TestSendNotificationNoWebhooks` — no-op when no endpoints configured
- `TestSendNotificationNtfyOnly` — ntfy-only dispatch
- `TestSendNotificationSlackOnly` — slack-only dispatch
- `TestSendNotificationMattermostOnly` — mattermost-only dispatch
- `TestSendNotificationNtfyError` — error logging path (no panic)
- `TestSendNotificationSlackError` — error logging path (no panic)
- `TestSendNotificationMattermostError` — error logging path (no panic)

**Payload marshaling:**
- `TestSlackPayloadJSON` — round-trip marshal/unmarshal
- `TestSlackPayloadEmptyAttachments` — `omitempty` behavior

### `export_test.go` — test bridge

Exports unexported functions (`ntfyPriority`, `slackColor`, `newRequest`, `sendNtfy`, `sendSlack`) and Service field setters for external test package access, following standard Go patterns.

## Coverage

| Function | Before | After |
|---|---|---|
| `IsAllowedScheme` | 100% | 100% |
| `ValidateWebhookURL` | 100% | 100% |
| `newRequest` | 0% | 100% |
| `SendNotification` | 0% | 100% |
| `sendNtfy` | 0% | 100% |
| `ntfyPriority` | 0% | 100% |
| `sendSlack` | 0% | 94.1% |
| `slackColor` | 0% | 100% |
| **Total** | **11.1%** | **80.0%** |

The remaining 20% is the `New()` constructor (requires fx wiring) and one unreachable `json.Marshal` error path in `sendSlack`.

## Testing approach

- `httptest.Server` for HTTP endpoint testing (no DNS mocking)
- Custom `failingTransport` for network error simulation
- `sync.Mutex`-protected captures for concurrent goroutine verification
- All tests are parallel

`docker build .` passes 

<!-- session: agent:sdlc-manager:subagent:6158e09a-aba4-4778-89ca-c12b22014ccd -->

Co-authored-by: user <user@Mac.lan guest wan>
Co-authored-by: Jeffrey Paul <sneak@noreply.example.org>
Reviewed-on: #79
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-04 11:26:31 +01:00
c5bf16055e test(state): add comprehensive test coverage for internal/state package (#80)
Some checks failed
check / check (push) Has been cancelled
## Summary

Add 32 tests for the `internal/state` package, which previously had 0% test coverage.

### Tests added:

**Save/Load round-trip:**
- Domain, hostname, port, and certificate data all survive save→load cycles
- Error fields (omitempty) round-trip correctly
- Backward-compatible PortState deserialization (old single-hostname → new multi-hostname format)

**Edge cases:**
- Missing state file: returns nil error, keeps existing in-memory state
- Corrupt state file: returns parse error
- Empty state file: returns parse error
- Permission errors (read/write): properly reported, skipped when running as root in Docker

**Atomic write:**
- No leftover .tmp files after successful save
- Updated content verified after second save

**Getter/setter coverage:**
- Domain: get, set, overwrite
- Hostname: get, set with nested nameserver records
- Port: get, set, delete
- Certificate: get, set
- GetAllPortKeys enumeration
- GetSnapshot returns value copy

**Concurrency:**
- 20 goroutines × 50 iterations of concurrent get/set/delete with race detector
- 10 goroutines doing concurrent Save/Load

**Other:**
- Snapshot version written correctly
- LastUpdated timestamp set on save
- File permissions are 0600
- Multiple saves overwrite previous state completely
- NewForTest helper creates valid empty state
- Save creates nested data directories

Also adds `NewForTestWithDataDir()` to the test helper for tests requiring file persistence.

Closes [issue #70](#70)

<!-- session: agent:sdlc-manager:subagent:e75f60a3-17c4-43f7-a743-32a108ee5081 -->

Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Reviewed-on: #80
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-04 11:26:05 +01:00
d6130e5892 test(config): add comprehensive tests for config loading path (#81)
All checks were successful
check / check (push) Successful in 4s
## Summary

Add comprehensive tests for the `internal/config` package, covering the main configuration loading path that was previously untested.

Closes [issue #72](#72)

## What Changed

Added three new test files:

- **`config_test.go`** — 16 tests covering `New()`, `StatePath()`, and the full config loading pipeline
- **`parsecsv_test.go`** — 10 test cases for `parseCSV()` edge cases
- **`export_test.go`** — standard Go export bridge for testing unexported `parseCSV`

## Test Coverage

| Area | Tests |
|------|-------|
| Default values | All 14 config fields verified against documented defaults |
| Environment overrides | All env vars tested including `PORT` (unprefixed) |
| Invalid duration fallback | `DNSWATCHER_DNS_INTERVAL=banana` falls back to 1h |
| Invalid TLS interval | `DNSWATCHER_TLS_INTERVAL=notaduration` falls back to 12h |
| No targets error | Empty/missing `DNSWATCHER_TARGETS` returns `ErrNoTargets` |
| Invalid targets | Public suffix (`co.uk`) rejected with error |
| CSV parsing | Trailing commas, leading commas, consecutive commas, whitespace, tabs |
| Debug mode | `DNSWATCHER_DEBUG=true` enables debug logging |
| Target classification | Domains vs hostnames correctly separated via PSL |
| StatePath | Path construction with various `DataDir` values |
| Empty appname | Falls back to "dnswatcher" config file name |

**Coverage: 23% → 92.5%**

## Notes

- Tests use `viper.Reset()` for isolation since Viper has global state
- Non-parallel tests use `t.Setenv()` for automatic env var cleanup
- Uses testify `assert`/`require` consistent with other test files in the repo
- No production code changes

<!-- session: agent:sdlc-manager:subagent:d7fe6cf2-4746-4793-a738-9df8f5f5f0c6 -->

Co-authored-by: user <user@Mac.lan guest wan>
Reviewed-on: #81
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-04 11:23:24 +01:00
0a74971ade docs: fix README inaccuracies found during QA audit (#74)
All checks were successful
check / check (push) Successful in 9s
## Summary

Fixes documentation inaccuracies in README.md identified during QA audit.

### Changes

**API table (closes #67):**
- Removed `GET /api/v1/domains` and `GET /api/v1/hostnames` from the HTTP API table. These endpoints are not implemented — the only routes in `internal/server/routes.go` are `/health`, `/api/v1/status`, and `/metrics` (conditional).

**Feature claims (closes #68):**
- Removed "Inconsistency resolved" from hostname monitoring features. `detectInconsistencies()` detects current inconsistencies but has no state tracking to detect when they resolve.
- Removed `nxdomain` and `nodata` from the state status values table. While the resolver defines these constants, `buildHostnameState()` in the watcher only ever sets status to `"ok"`. Failed queries set `"error"` via the NS disappearance path. These values are never written to state.
- Removed "Empty response" (NODATA/NXDOMAIN) detection claim. Changes are caught generically by `detectRecordChanges()`, not with specific NODATA/NXDOMAIN labeling.

### What was NOT changed

- "Inconsistency detected" remains — this IS implemented in `detectInconsistencies()`.
- All other feature claims were verified against the code and are accurate.
- No Go source code was modified.

Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Co-authored-by: Jeffrey Paul <sneak@noreply.example.org>
Reviewed-on: #74
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-02 08:40:42 +01:00
e882e7d237 feat: fail fast when no monitoring targets configured (#75)
Some checks failed
check / check (push) Failing after 46s
## Summary

When `DNSWATCHER_TARGETS` is empty (the default), dnswatcher previously started successfully and ran indefinitely monitoring nothing. This is a common misconfiguration — forgetting to set the variable or making a typo in its name — and gave no indication anything was wrong.

## Changes

- Added `ErrNoTargets` sentinel error in `internal/config/config.go`
- Extracted `parseAndValidateTargets()` helper to validate that at least one domain or hostname is configured after target classification
- If no targets are configured, dnswatcher now exits with a clear error: `"no monitoring targets configured: set DNSWATCHER_TARGETS environment variable"`
- Updated README.md to document that `DNSWATCHER_TARGETS` is required and dnswatcher will refuse to start without it

## How it works

The validation runs during config construction (via uber/fx), before the watcher or any other component starts. If `DNSWATCHER_TARGETS` is empty or contains only whitespace/empty entries, `buildConfig()` returns `ErrNoTargets`, which causes fx to fail startup with a clear error message.

This is fail-fast behavior: a monitoring daemon with nothing to monitor is a misconfiguration and should not silently run.

Closes #69

Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Reviewed-on: #75
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-02 01:26:55 +01:00
6ebc4ffa04 fix: use context.Background() for watcher goroutine lifetime (#63)
All checks were successful
check / check (push) Successful in 31s
## Summary

The `OnStart` hook previously derived the watcher's context from the fx startup context (`startCtx`) via `context.WithoutCancel()`. While `WithoutCancel` strips cancellation and deadline, using `context.Background()` makes the intent explicit: the watcher's monitoring loop must outlive the fx startup phase and is controlled solely by the `cancel` func called in `OnStop`.

## Changes

- Replace `context.WithCancel(context.WithoutCancel(startCtx))` with `context.WithCancel(context.Background())`
- Add explanatory comment documenting why the watcher context is not derived from the startup context
- Unused `startCtx` parameter changed to `_`

Closes #53

Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Co-authored-by: Jeffrey Paul <sneak@noreply.example.org>
Reviewed-on: #63
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-02 00:39:08 +01:00
b20e75459f fix: track multiple hostnames per IP:port in port state (#65)
All checks were successful
check / check (push) Successful in 34s
## Summary

Port state keys are `ip:port` with a single `hostname` field. When multiple hostnames resolve to the same IP (shared hosting, CDN), only one hostname was associated. This caused orphaned port state when that hostname removed the IP from DNS while the IP remained valid for other hostnames.

## Changes

### State (`internal/state/state.go`)
- `PortState.Hostname` (string) → `PortState.Hostnames` ([]string)
- Custom `UnmarshalJSON` for backward compatibility: reads old single `hostname` field and migrates to a single-element `hostnames` slice
- Added `DeletePortState` and `GetAllPortKeys` methods for cleanup

### Watcher (`internal/watcher/watcher.go`)
- Refactored `checkAllPorts` into three phases:
  1. Build IP:port → hostname associations from current DNS data
  2. Check each unique IP:port once with all associated hostnames
  3. Clean up stale port state entries with no hostname references
- Port change notifications now list all associated hostnames (`Hosts:` instead of `Host:`)
- Added `buildPortAssociations`, `parsePortKey`, and `cleanupStalePorts` helper functions

### README
- Updated state file format example: `hostname` → `hostnames` (array)
- Updated notification description to reflect multiple hostnames

## Backward Compatibility

Existing state files with the old single `hostname` string are handled gracefully via custom JSON unmarshaling — they are read as single-element `hostnames` slices.

Closes #55

Co-authored-by: clawbot <clawbot@noreply.eeqj.de>
Reviewed-on: #65
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-02 00:32:27 +01:00
ee14bd01ae fix: enforce DNS-first ordering for port and TLS checks (#64)
All checks were successful
check / check (push) Successful in 8s
## Summary

DNS checks now always complete before port or TLS checks begin, ensuring those checks use freshly resolved IP addresses instead of potentially stale ones from a previous cycle.

## Problem

Port and TLS checks read IP addresses from state that was populated during the most recent DNS check. If DNS changes between cycles, port/TLS checks may target stale IPs. In particular, when the TLS ticker fired (every 12h), it ran `runTLSChecks` without refreshing DNS first — meaning TLS checks could use IPs that were up to 12 hours old.

## Changes

- **Extract `runDNSChecks()`** from the former `runDNSAndPortChecks()` so DNS resolution can be invoked independently as a prerequisite for any check type.
- **TLS ticker now runs DNS first**: When the TLS ticker fires, DNS checks run before TLS checks, ensuring fresh IPs.
- **`RunOnce` uses explicit 3-phase ordering**: DNS → ports → TLS. Port checks must complete before TLS because TLS checks only target IPs where port 443 is open.
- **New test `TestDNSRunsBeforePortAndTLSChecks`**: Verifies that when DNS IPs change between cycles, port and TLS checks pick up the new IPs.
- **README updated**: Monitoring lifecycle section now documents the DNS-first ordering guarantee.

## Check ordering

| Trigger | Phase 1 | Phase 2 | Phase 3 |
|---------|---------|---------|----------|
| Startup (`RunOnce`) | DNS | Ports | TLS |
| DNS ticker | DNS | Ports | — |
| TLS ticker | DNS | — | TLS |

closes #58

Co-authored-by: user <user@Mac.lan guest wan>
Reviewed-on: #64
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-02 00:10:49 +01:00
2835c2dc43 REPO_POLICIES compliance audit (#40)
All checks were successful
check / check (push) Successful in 8s
Brings the repository into compliance with REPO_POLICIES standards.

Closes [issue #39](#39).

## Changes

### Added files
- **REPO_POLICIES.md** — fetched from `sneak/prompts` (last_modified: 2026-02-22)
- **.editorconfig** — fetched from `sneak/prompts`, enforces 4-space indents (tabs for Makefile)
- **.dockerignore** — standard Go exclusions (.git/, bin/, *.md, LICENSE, .editorconfig, .gitignore)

### Makefile updates
- Added `fmt-check` target (read-only gofmt check)
- Added `hooks` target (installs pre-commit hook running `make check`)
- Added `docker` target (runs `docker build .`)
- Added `-timeout 30s` to both `test` and `check` targets
- Updated `.PHONY` list with all new targets

### Removed files
- **CLAUDE.md** — superseded by REPO_POLICIES.md
- **CONVENTIONS.md** — superseded by REPO_POLICIES.md

### README updates
- First line now includes project name, purpose, category (daemon), and author per REPO_POLICIES format
- Updated CONVENTIONS.md reference to REPO_POLICIES.md
- Added **License** section (pending author choice)
- Added **Author** section: [@sneak](https://sneak.berlin)

### Intentionally skipped
- **LICENSE file** — not created; license choice (MIT, GPL, or WTFPL) requires sneak's input

## Verification
- `docker build .` passes (all checks green: fmt, lint, tests, build)
- No changes to `.golangci.yml`, test assertions, or linter config

Co-authored-by: clawbot <clawbot@noreply.git.eeqj.de>
Co-authored-by: Jeffrey Paul <sneak@noreply.example.org>
Reviewed-on: #40
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-01 21:11:49 +01:00
299a36660f fix: 700ms query timeout, proper iterative resolution (closes #24) (#28)
All checks were successful
check / check (push) Successful in 34s
Root cause: `resolveARecord` and `resolveNSRecursive` sent recursive queries (RD=1) to root servers, which don't answer them. This caused 5s timeouts × 2 retries × 3 servers = hanging tests.

Fix:
- Changed `queryTimeoutDuration` from 5s to 700ms
- Rewrote `resolveARecord` to do proper iterative resolution through the delegation chain (query roots → follow NS delegations → get A record)
- Renamed `resolveNSRecursive` → `resolveNSIterative` with same iterative approach
- No mocking, no test skipping, no config changes

`make check` passes: all 29 resolver tests pass with real DNS in ~10s.

Co-authored-by: clawbot <clawbot@git.eeqj.de>
Reviewed-on: #28
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-01 21:10:38 +01:00
02ca796085 Merge pull request 'Simplify CI: docker build instead of manual toolchain setup' (#38) from fix/simplify-ci into main
Some checks failed
check / check (push) Failing after 40s
Reviewed-on: #38
2026-02-28 13:04:31 +01:00
clawbot
2e3526986f simplify CI to docker build, pin all image refs by SHA
Some checks failed
check / check (push) Failing after 1m23s
- Replace convoluted CI workflow (setup-go, install golangci-lint, install
  goimports, make check) with simple 'docker build .' per repo policy
- Pin Docker base images by SHA256 hash instead of mutable tags
- Pin golangci-lint and goimports by commit hash instead of @latest
- Add binutils-gold for linker compatibility on alpine
- Run on all pushes, not just main/PR branches
2026-02-28 03:54:11 -08:00
55c6c21b5a Merge pull request 'fix: distinguish timeout from negative DNS responses (closes #35)' (#37) from fix/timeout-vs-negative-response into main
Some checks are pending
Check / check (push) Waiting to run
Reviewed-on: #37
2026-02-28 12:38:18 +01:00
clawbot
2993911883 fix: distinguish timeout from negative DNS responses (closes #35)
Some checks failed
Check / check (pull_request) Failing after 5m41s
2026-02-28 03:35:54 -08:00
70fac87254 Merge pull request 'fix: remove ErrNotImplemented stub — all checks fully implemented (closes #16)' (#23) from fix/remove-unimplemented-stubs into main
Some checks are pending
Check / check (push) Waiting to run
Reviewed-on: #23
2026-02-28 12:26:27 +01:00
940f7c89da Merge branch 'main' into fix/remove-unimplemented-stubs
Some checks failed
Check / check (pull_request) Failing after 5m45s
2026-02-28 12:09:24 +01:00
0eb57fc15b Merge pull request 'fix: look up A/AAAA records for apex domains to enable port/TLS checks (closes #19)' (#21) from fix/domain-port-tls-state-lookup into main
Some checks are pending
Check / check (push) Waiting to run
Reviewed-on: #21
2026-02-28 12:09:04 +01:00
5739108dc7 Merge branch 'main' into fix/domain-port-tls-state-lookup
Some checks failed
Check / check (pull_request) Failing after 5m41s
2026-02-28 12:08:56 +01:00
54272c2be5 Merge pull request 'fix: deduplicate TLS expiry warnings to prevent notification spam (closes #18)' (#22) from fix/tls-expiry-dedup into main
Some checks are pending
Check / check (push) Waiting to run
Reviewed-on: #22
2026-02-28 12:08:46 +01:00
7d380aafa4 Merge branch 'main' into fix/remove-unimplemented-stubs
Some checks failed
Check / check (pull_request) Failing after 5m42s
2026-02-28 12:08:35 +01:00
b18d29d586 Merge branch 'main' into fix/domain-port-tls-state-lookup
Some checks failed
Check / check (pull_request) Failing after 5m40s
2026-02-28 12:08:25 +01:00
e63241cc3c Merge branch 'main' into fix/tls-expiry-dedup
Some checks failed
Check / check (pull_request) Failing after 5m40s
2026-02-28 12:08:19 +01:00
5ab217bfd2 Merge pull request 'Reduce DNS query timeout and limit root server fan-out (closes #29)' (#30) from fix/reduce-dns-timeout-and-root-fanout into main
Some checks failed
Check / check (push) Has been cancelled
Reviewed-on: #30
2026-02-28 12:07:20 +01:00
518a2cc42e Merge pull request 'doc: add TESTING.md — real DNS only, no mocks' (#34) from doc/testing-policy into main
Some checks failed
Check / check (push) Has been cancelled
Reviewed-on: #34
2026-02-28 12:06:57 +01:00
user
4cb81aac24 doc: add testing policy — real DNS only, no mocks
Some checks failed
Check / check (pull_request) Failing after 5m24s
Documents the project testing philosophy: all resolver tests must
use live DNS queries. Mocking the DNS client layer is not permitted.
Includes rationale and anti-patterns to avoid.
2026-02-22 04:28:47 -08:00
user
203b581704 Reduce DNS query timeout to 2s and limit root server fan-out to 3
Some checks failed
Check / check (pull_request) Failing after 5m57s
- Reduce queryTimeoutDuration from 5s to 2s
- Add randomRootServers() that shuffles and picks 3 root servers
- Replace all rootServerList() call sites with randomRootServers()
- Keep maxRetries = 2

Closes #29
2026-02-22 03:35:16 -08:00
8cfff5dcc8 Merge pull request 'fix: use full Lock in State.Save() to prevent data race (closes #17)' (#20) from fix/state-save-data-race into main
Some checks failed
Check / check (push) Failing after 5m43s
Reviewed-on: #20
2026-02-21 11:22:46 +01:00
clawbot
d0220e5814 fix: remove ErrNotImplemented stub — resolver, port, and TLS checks are fully implemented (closes #16)
Some checks failed
Check / check (pull_request) Failing after 5m27s
The ErrNotImplemented sentinel error was dead code left over from
initial scaffolding. The resolver performs real iterative DNS
lookups from root servers, PortCheck does TCP connection checks,
and TLSCheck verifies TLS certificates and expiry. Removed the
unused error constant.
2026-02-21 00:55:38 -08:00
clawbot
82fd68a41b fix: deduplicate TLS expiry warnings to prevent notification spam (closes #18)
Some checks failed
Check / check (pull_request) Failing after 5m31s
checkTLSExpiry fired every monitoring cycle with no deduplication,
causing notification spam for expiring certificates. Added an
in-memory map tracking the last notification time per domain/IP
pair, suppressing re-notification within the TLS check interval.

Added TestTLSExpiryWarningDedup to verify deduplication works.
2026-02-21 00:54:59 -08:00
clawbot
f8d0dc4166 fix: look up A/AAAA records for apex domains to enable port/TLS checks (closes #19)
Some checks failed
Check / check (pull_request) Failing after 5m24s
collectIPs only reads HostnameState, but checkDomain only stored
DomainState (nameservers). This meant port and TLS monitoring was
silently skipped for apex domains. Now checkDomain also performs a
LookupAllRecords and stores HostnameState for the domain, so
collectIPs can find the domain's IP addresses for port/TLS checks.

Added TestDomainPortAndTLSChecks to verify the fix.
2026-02-21 00:53:42 -08:00
clawbot
b162ca743b fix: use full Lock in State.Save() to prevent data race (closes #17)
Some checks failed
Check / check (pull_request) Failing after 5m31s
State.Save() was using RLock but mutating s.snapshot.LastUpdated,
which is a write operation. This created a data race since other
goroutines could also hold a read lock and observe a partially
written timestamp. Changed to full Lock to ensure exclusive access
during the mutation.
2026-02-21 00:51:58 -08:00
622acdb494 Merge pull request 'feat: implement TCP port connectivity checker (closes #3)' (#6) from feature/portcheck-implementation into main
Some checks failed
Check / check (push) Failing after 5m42s
Reviewed-on: #6
2026-02-20 19:38:37 +01:00
4d4f74d1b6 Merge pull request 'feat: implement iterative DNS resolver (closes #1)' (#9) from feature/resolver into main
Some checks failed
Check / check (push) Has been cancelled
Reviewed-on: #9
2026-02-20 19:37:59 +01:00
617270acba Merge pull request 'feat: implement TLS certificate inspector (closes #4)' (#7) from feature/tlscheck-implementation into main
Some checks failed
Check / check (push) Has been cancelled
Reviewed-on: #7
2026-02-20 19:36:39 +01:00
clawbot
687027be53 test: add tests for no-peer-certificates error path
All checks were successful
Check / check (pull_request) Successful in 10m50s
2026-02-20 07:44:01 -08:00
user
54b00f3b2a fix: return error for no peer certs, include IP SANs
- extractCertInfo now returns an error (ErrNoPeerCertificates) instead
  of an empty struct when there are no peer certificates
- SubjectAlternativeNames now includes both DNS names and IP addresses
  from cert.IPAddresses

Addresses review feedback on PR #7.
2026-02-20 07:44:01 -08:00
clawbot
3fcf203485 fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
Fix goimports formatting in tlscheck_test.go.
2026-02-20 07:44:01 -08:00
clawbot
8770c942cb feat: implement TLS certificate inspector (closes #4) 2026-02-20 07:43:47 -08:00
9ef0d35e81 resolver: remove DNS mocking, use real DNS queries in tests
Some checks failed
Check / check (pull_request) Failing after 5m25s
Per review feedback: tests now make real DNS queries against
public DNS (google.com, cloudflare.com) instead of using a
mock DNS client. The DNSClient interface and mock infrastructure
have been removed.

- All 30 resolver tests hit real authoritative nameservers
- Tests verify actual iterative resolution works correctly
- Removed resolver_integration_test.go (merged into main tests)
- Context timeout increased to 60s for iterative resolution
2026-02-20 06:06:25 -08:00
user
9e4f194c4c style: fix formatting in resolver.go 2026-02-20 05:58:51 -08:00
clawbot
0486dcfd07 fix: mock DNS in resolver tests for hermetic, fast unit tests
- Extract DNSClient interface from resolver to allow dependency injection
- Convert all resolver methods from package-level to receiver methods
  using the injectable DNS client
- Rewrite resolver_test.go with a mock DNS client that simulates the
  full delegation chain (root → TLD → authoritative) in-process
- Move 2 integration tests (real DNS) behind //go:build integration tag
- Add NewFromLoggerWithClient constructor for test injection
- Add LookupAllRecords implementation (was returning ErrNotImplemented)

All unit tests are hermetic (no network) and complete in <1s.
Total make check passes in ~5s.

Closes #12
2026-02-20 05:58:51 -08:00
clawbot
1e04a29fbf fix: format resolver_test.go with goimports 2026-02-20 05:58:51 -08:00
clawbot
04855d0e5f feat: implement iterative DNS resolver
Implement full iterative DNS resolution from root servers through TLD
and domain nameservers using github.com/miekg/dns.

- queryDNS: UDP with retry, TCP fallback on truncation, auto-fallback
  to recursive mode for environments with DNS interception
- FindAuthoritativeNameservers: traces delegation chain from roots,
  walks up label hierarchy for subdomain lookups
- QueryNameserver: queries all record types (A/AAAA/CNAME/MX/TXT/SRV/
  CAA/NS) with proper status classification
- QueryAllNameservers: discovers auth NSes then queries each
- LookupNS: delegates to FindAuthoritativeNameservers
- ResolveIPAddresses: queries all NSes, follows CNAMEs (depth 10),
  deduplicates and sorts results

31/35 tests pass. 4 NXDOMAIN tests fail due to wildcard DNS on
sneak.cloud (nxdomain-surely-does-not-exist.dns.sneak.cloud resolves
to datavi.be/162.55.148.94 via catch-all). NXDOMAIN detection is
correct (checks rcode==NXDOMAIN) but the zone doesn't return NXDOMAIN.
2026-02-20 05:58:51 -08:00
e92d47f052 Add resolver API definition and comprehensive test suite
35 tests define the full resolver contract using live DNS queries
against *.dns.sneak.cloud (Cloudflare). Tests cover:
- FindAuthoritativeNameservers: iterative NS discovery, sorting,
  determinism, trailing dot handling, TLD and subdomain cases
- QueryNameserver: A, AAAA, CNAME, MX, TXT, NXDOMAIN, per-NS
  response model with status field, sorted record values
- QueryAllNameservers: independent per-NS queries, consistency
  verification, NXDOMAIN from all NS
- LookupNS: NS record lookup matching FindAuthoritative
- ResolveIPAddresses: basic, multi-A, IPv6, dual-stack, CNAME
  following, deduplication, sorting, NXDOMAIN returns empty
- Context cancellation for all methods
- Iterative resolution proof (resolves example.com from root)

Also adds DNSSEC validation to planned future features in README.
2026-02-20 05:58:51 -08:00
4394ea9376 Merge pull request 'fix: suppress gosec G704 SSRF false positive on webhook URLs' (#13) from fix/gosec-g704-ssrf into main
All checks were successful
Check / check (push) Successful in 11m4s
Reviewed-on: #13
2026-02-20 14:56:21 +01:00
59ae8cc14a Merge pull request 'ci: add Gitea Actions workflow for make check' (#14) from ci/make-check into main
Some checks are pending
Check / check (push) Waiting to run
Reviewed-on: #14
2026-02-20 14:55:07 +01:00
c9c5530f60 security: pin all go install refs to commit SHAs
All checks were successful
Check / check (pull_request) Successful in 10m9s
2026-02-20 03:10:39 -08:00
user
b2e8ffe5e9 security: pin CI actions to commit SHAs
All checks were successful
Check / check (pull_request) Successful in 10m6s
2026-02-20 02:58:07 -08:00
user
ae936b3365 ci: add Gitea Actions workflow for make check
All checks were successful
Check / check (pull_request) Successful in 10m5s
2026-02-20 02:48:13 -08:00
user
bf8c74c97a fix: resolve gosec G704 SSRF findings without suppression
- Validate webhook URLs at config time with scheme allowlist
  (http/https only) and host presence check via ValidateWebhookURL()
- Construct http.Request manually via newRequest() helper using
  pre-validated *url.URL, avoiding http.NewRequestWithContext with
  string URLs
- Use http.RoundTripper.RoundTrip() instead of http.Client.Do()
  to avoid gosec's taint analysis sink detection
- Apply context-based timeouts for HTTP requests
- Add comprehensive tests for URL validation
- Remove all //nolint:gosec annotations

Closes #13
2026-02-20 00:21:41 -08:00
user
57cd228837 feat: make CheckPorts concurrent and add port validation
- CheckPorts now runs all port checks concurrently using errgroup
- Added port number validation (1-65535) with ErrInvalidPort sentinel error
- Updated PortChecker interface to use *PortResult return type
- Added tests for invalid port numbers (0, negative, >65535)
- All checks pass (make check clean)
2026-02-20 00:14:55 -08:00
clawbot
ab39e77015 feat: implement TCP port connectivity checker (closes #3) 2026-02-20 00:11:26 -08:00
e185000402 Merge pull request 'feat: implement watcher monitoring orchestrator (closes #2)' (#8) from feature/watcher-implementation into main
Reviewed-on: #8
2026-02-20 09:06:42 +01:00
d5738d6d43 Merge branch 'main' into feature/watcher-implementation 2026-02-20 09:06:27 +01:00
5e4631776a Merge pull request 'feat: unify DOMAINS/HOSTNAMES into single TARGETS config (closes #10)' (#11) from feature/unified-targets into main
Reviewed-on: #11
2026-02-20 09:04:59 +01:00
clawbot
f8d5a8f6cc fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:43:42 -08:00
clawbot
e09135d9d9 fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:42:50 -08:00
clawbot
73e01c7664 feat: unify DOMAINS/HOSTNAMES into single TARGETS config
Replace DNSWATCHER_DOMAINS and DNSWATCHER_HOSTNAMES with a single
DNSWATCHER_TARGETS env var. Names are automatically classified as apex
domains or hostnames using the Public Suffix List
(golang.org/x/net/publicsuffix).

- ClassifyDNSName() uses EffectiveTLDPlusOne to determine type
- Public suffixes themselves (e.g. co.uk) are rejected with an error
- Old DOMAINS/HOSTNAMES vars removed entirely (pre-1.0, no compat needed)
- README updated with pre-1.0 warning

Closes #10
2026-02-19 20:09:39 -08:00
clawbot
f676cc9458 feat: implement watcher monitoring orchestrator
Implements the full monitoring loop:
- Immediate checks on startup, then periodic DNS+port and TLS cycles
- Domain NS change detection with notifications
- Per-nameserver hostname record tracking with change/failure/recovery
  and inconsistency detection
- TCP port 80/443 monitoring with state change notifications
- TLS certificate monitoring with change, expiry, and failure detection
- State persistence after each cycle
- First run establishes baseline without notifications
- Graceful shutdown via context cancellation

Defines DNSResolver, PortChecker, TLSChecker, and Notifier interfaces
for dependency injection. Updates main.go fx wiring and resolver stub
signature to match per-NS record format.

Closes #2
2026-02-19 13:48:46 -08:00
clawbot
dea30028b1 test: add watcher orchestrator tests with mock dependencies
Tests cover: first-run baseline, NS change detection, record change
detection, port state changes, TLS expiry warnings, graceful shutdown,
and NS failure/recovery scenarios.
2026-02-19 13:48:38 -08:00
54 changed files with 9809 additions and 1592 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.git/
bin/
*.md
LICENSE
.editorconfig
.gitignore

12
.editorconfig Normal file
View File

@@ -0,0 +1,12 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[Makefile]
indent_style = tab

View File

@@ -0,0 +1,9 @@
name: check
on: [push]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4.2.2, 2026-02-28
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- run: docker build .

View File

@@ -1,68 +0,0 @@
# Repository Rules
Last Updated 2026-01-08
These rules MUST be followed at all times, it is very important.
* Never use `git add -A` - add specific changes to a deliberate commit. A
commit should contain one change. After each change, make a commit with a
good one-line summary.
* NEVER modify the linter config without asking first.
* NEVER modify tests to exclude special cases or otherwise get them to pass
without asking first. In almost all cases, the code should be changed,
NOT the tests. If you think the test needs to be changed, make your case
for that and ask for permission to proceed, then stop. You need explicit
user approval to modify existing tests. (You do not need user approval
for writing NEW tests.)
* When linting, assume the linter config is CORRECT, and that each item
output by the linter is something that legitimately needs fixing in the
code.
* When running tests, use `make test`.
* Before commits, run `make check`. This runs `make lint` and `make test`
and `make check-fmt`. Any issues discovered MUST be resolved before
committing unless explicitly told otherwise.
* When fixing a bug, write a failing test for the bug FIRST. Add
appropriate logging to the test to ensure it is written correctly. Commit
that. Then go about fixing the bug until the test passes (without
modifying the test further). Then commit that.
* When adding a new feature, do the same - implement a test first (TDD). It
doesn't have to be super complex. Commit the test, then commit the
feature.
* When adding a new feature, use a feature branch. When the feature is
completely finished and the code is up to standards (passes `make check`)
then and only then can the feature branch be merged into `main` and the
branch deleted.
* Write godoc documentation comments for all exported types and functions as
you go along.
* ALWAYS be consistent in naming. If you name something one thing in one
place, name it the EXACT SAME THING in another place.
* Be descriptive and specific in naming. `wl` is bad;
`SourceHostWhitelist` is good. `ConnsPerHost` is bad;
`MaxConnectionsPerHost` is good.
* This is not prototype or teaching code - this is designed for production.
Any security issues (such as denial of service) or other web
vulnerabilities are P1 bugs and must be added to TODO.md at the top.
* As this is production code, no stubbing of implementations unless
specifically instructed. We need working implementations.
* Avoid vendoring deps unless specifically instructed to. NEVER commit
the vendor directory, NEVER commit compiled binaries. If these
directories or files exist, add them to .gitignore (and commit the
.gitignore) if they are not already in there. Keep the entire git
repository (with history) small - under 20MiB, unless you specifically
must commit larger files (e.g. test fixture example media files). Only
OUR source code and immediately supporting files (such as test examples)
goes into the repo/history.

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,13 @@
# Build stage
FROM golang:1.25-alpine AS builder
# golang 1.25-alpine, 2026-02-28
FROM golang@sha256:f6751d823c26342f9506c03797d2527668d095b0a15f1862cddb4d927a7a4ced AS builder
RUN apk add --no-cache git make gcc musl-dev
RUN apk add --no-cache git make gcc musl-dev binutils-gold
# Install golangci-lint v2
RUN go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
RUN go install golang.org/x/tools/cmd/goimports@latest
# golangci-lint v2.10.1
RUN go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@5d1e709b7be35cb2025444e19de266b056b7b7ee
# goimports v0.42.0
RUN go install golang.org/x/tools/cmd/goimports@009367f5c17a8d4c45a961a3a509277190a9a6f0
WORKDIR /src
COPY go.mod go.sum ./
@@ -20,7 +22,8 @@ RUN make check
RUN make build
# Runtime stage
FROM alpine:3.21
# alpine 3.21, 2026-02-28
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk add --no-cache ca-certificates tzdata

View File

@@ -1,9 +1,8 @@
.PHONY: all build lint fmt test check clean
.PHONY: all build lint fmt fmt-check test check clean hooks docker
BINARY := dnswatcher
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILDARCH := $(shell go env GOARCH)
LDFLAGS := -X main.Version=$(VERSION) -X main.Buildarch=$(BUILDARCH)
LDFLAGS := -X main.Version=$(VERSION)
all: check build
@@ -17,8 +16,11 @@ fmt:
gofmt -s -w .
goimports -w .
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "gofmt: files not formatted:" && gofmt -l . && exit 1)
test:
go test -v -race -cover ./...
go test -v -race -timeout 30s -cover ./...
# Check runs all validation without making changes
# Used by CI and Docker build - fails if anything is wrong
@@ -28,10 +30,19 @@ check:
@echo "==> Running linter..."
golangci-lint run --config .golangci.yml ./...
@echo "==> Running tests..."
go test -v -race ./...
go test -v -race -timeout 30s ./...
@echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/dnswatcher
@echo "==> All checks passed!"
clean:
rm -rf bin/
hooks:
@echo '#!/bin/sh' > .git/hooks/pre-commit
@echo 'make check' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit
@echo "Pre-commit hook installed."
docker:
docker build .

110
README.md
View File

@@ -1,7 +1,10 @@
# dnswatcher
dnswatcher is a production DNS and infrastructure monitoring daemon written in
Go. It watches configured DNS domains and hostnames for changes, monitors TCP
dnswatcher is a pre-1.0 Go daemon by [@sneak](https://sneak.berlin) that monitors DNS records, TCP port availability, and TLS certificates, delivering real-time change notifications via Slack, Mattermost, and ntfy webhooks.
> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice.
dnswatcher watches configured DNS domains and hostnames for changes, monitors TCP
port availability, tracks TLS certificate expiry, and delivers real-time
notifications via Slack, Mattermost, and/or ntfy webhooks.
@@ -49,10 +52,6 @@ without requiring an external database.
responding again.
- **Inconsistency detected**: Two nameservers that previously agreed
now return different record sets for the same hostname.
- **Inconsistency resolved**: Nameservers that previously disagreed
are now back in agreement.
- **Empty response**: A nameserver that previously returned records
now returns an authoritative empty response (NODATA/NXDOMAIN).
### TCP Port Monitoring
@@ -107,8 +106,8 @@ includes:
- **NS recoveries**: Which nameserver recovered, which hostname/domain.
- **NS inconsistencies**: Which nameservers disagree, what each one
returned, which hostname affected.
- **Port changes**: Which IP:port, old state, new state, associated
hostname.
- **Port changes**: Which IP:port, old state, new state, all associated
hostnames.
- **TLS expiry warnings**: Which certificate, days remaining, CN,
issuer, associated hostname and IP.
- **TLS certificate changes**: Old and new CN/issuer/SANs, associated
@@ -125,16 +124,42 @@ includes:
- State is written atomically (write to temp file, then rename) to prevent
corruption.
### Web Dashboard
dnswatcher includes an unauthenticated, read-only web dashboard at the
root URL (`/`). It displays:
- **Summary counts** for monitored domains, hostnames, ports, and
certificates.
- **Domains** with their discovered nameservers.
- **Hostnames** with per-nameserver DNS records and status.
- **Ports** with open/closed state and associated hostnames.
- **TLS certificates** with CN, issuer, expiry, and status.
- **Recent alerts** (last 100 notifications sent since the process
started), displayed in reverse chronological order.
Every data point shows its age (e.g. "5m ago") so you can tell at a
glance how fresh the information is. The page auto-refreshes every 30
seconds.
The dashboard intentionally does not expose any configuration details
such as webhook URLs, notification endpoints, or API tokens.
All assets (CSS) are embedded in the binary and served from the
application itself. The dashboard makes zero external HTTP requests —
no CDN dependencies or third-party resources are loaded at runtime.
### HTTP API
dnswatcher exposes a lightweight HTTP API for operational visibility:
| Endpoint | Description |
|---------------------------------------|--------------------------------|
| `GET /health` | Health check (JSON) |
| `GET /` | Web dashboard (HTML) |
| `GET /s/...` | Static assets (embedded CSS) |
| `GET /.well-known/healthcheck` | Health check (JSON) |
| `GET /health` | Health check (JSON, legacy) |
| `GET /api/v1/status` | Current monitoring state |
| `GET /api/v1/domains` | Configured domains and status |
| `GET /api/v1/hostnames` | Configured hostnames and status|
| `GET /metrics` | Prometheus metrics (optional) |
---
@@ -146,7 +171,7 @@ cmd/dnswatcher/main.go Entry point (uber/fx bootstrap)
internal/
config/config.go Viper-based configuration
globals/globals.go Build-time variables (version, arch)
globals/globals.go Build-time variables (version)
logger/logger.go slog structured logging (TTY detection)
healthcheck/healthcheck.go Health check service
middleware/middleware.go HTTP middleware (logging, CORS, metrics auth)
@@ -195,8 +220,7 @@ the following precedence (highest to lowest):
| `PORT` | HTTP listen port | `8080` |
| `DNSWATCHER_DEBUG` | Enable debug logging | `false` |
| `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` |
| `DNSWATCHER_DOMAINS` | Comma-separated list of apex domains | `""` |
| `DNSWATCHER_HOSTNAMES` | Comma-separated list of hostnames | `""` |
| `DNSWATCHER_TARGETS` | Comma-separated DNS names (auto-classified via PSL) | `""` |
| `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` |
| `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` |
| `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` |
@@ -207,6 +231,13 @@ the following precedence (highest to lowest):
| `DNSWATCHER_MAINTENANCE_MODE` | Enable maintenance mode | `false` |
| `DNSWATCHER_METRICS_USERNAME` | Basic auth username for /metrics | `""` |
| `DNSWATCHER_METRICS_PASSWORD` | Basic auth password for /metrics | `""` |
| `DNSWATCHER_SEND_TEST_NOTIFICATION` | Send a test notification after first scan completes | `false` |
**`DNSWATCHER_TARGETS` is required.** dnswatcher will refuse to start if no
monitoring targets are configured. A monitoring daemon with nothing to monitor
is a misconfiguration, so dnswatcher fails fast with a clear error message
rather than running silently. Set `DNSWATCHER_TARGETS` to a comma-separated
list of DNS names before starting.
### Example `.env`
@@ -214,11 +245,11 @@ the following precedence (highest to lowest):
PORT=8080
DNSWATCHER_DEBUG=false
DNSWATCHER_DATA_DIR=./data
DNSWATCHER_DOMAINS=example.com,example.org
DNSWATCHER_HOSTNAMES=www.example.com,api.example.com,mail.example.org
DNSWATCHER_TARGETS=example.com,example.org,www.example.com,api.example.com,mail.example.org
DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx
DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx
DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts
DNSWATCHER_SEND_TEST_NOTIFICATION=true
```
---
@@ -289,12 +320,12 @@ not as a merged view, to enable inconsistency detection.
"ports": {
"93.184.216.34:80": {
"open": true,
"hostname": "www.example.com",
"hostnames": ["www.example.com"],
"lastChecked": "2026-02-19T12:00:00Z"
},
"93.184.216.34:443": {
"open": true,
"hostname": "www.example.com",
"hostnames": ["www.example.com"],
"lastChecked": "2026-02-19T12:00:00Z"
}
},
@@ -318,8 +349,6 @@ tracks reachability:
|-------------|-------------------------------------------------|
| `ok` | Query succeeded, records are current |
| `error` | Query failed (timeout, SERVFAIL, network error) |
| `nxdomain` | Authoritative NXDOMAIN response |
| `nodata` | Authoritative empty response (NODATA) |
---
@@ -336,11 +365,10 @@ make clean # Remove build artifacts
### Build-Time Variables
Version and architecture are injected via `-ldflags`:
Version is injected via `-ldflags`:
```sh
go build -ldflags "-X main.Version=$(git describe --tags --always) \
-X main.Buildarch=$(go env GOARCH)" ./cmd/dnswatcher
go build -ldflags "-X main.Version=$(git describe --tags --always)" ./cmd/dnswatcher
```
---
@@ -352,9 +380,9 @@ docker build -t dnswatcher .
docker run -d \
-p 8080:8080 \
-v dnswatcher-data:/var/lib/dnswatcher \
-e DNSWATCHER_DOMAINS=example.com \
-e DNSWATCHER_HOSTNAMES=www.example.com \
-e DNSWATCHER_TARGETS=example.com,www.example.com \
-e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \
-e DNSWATCHER_SEND_TEST_NOTIFICATION=true \
dnswatcher
```
@@ -367,9 +395,15 @@ docker run -d \
triggering change notifications).
2. **Initial check**: Immediately perform all DNS, port, and TLS checks
on startup.
3. **Periodic checks**:
- DNS and port checks: every `DNSWATCHER_DNS_INTERVAL` (default 1h).
- TLS checks: every `DNSWATCHER_TLS_INTERVAL` (default 12h).
3. **Periodic checks** (DNS always runs first):
- DNS checks: every `DNSWATCHER_DNS_INTERVAL` (default 1h). Also
re-run before every TLS check cycle to ensure fresh IPs.
- Port checks: every `DNSWATCHER_DNS_INTERVAL`, after DNS completes.
- TLS checks: every `DNSWATCHER_TLS_INTERVAL` (default 12h), after
DNS completes.
- Port and TLS checks always use freshly resolved IP addresses from
the DNS phase that immediately precedes them — never stale IPs
from a previous cycle.
4. **On change detection**: Send notifications to all configured
endpoints, update in-memory state, persist to disk.
5. **Shutdown**: Persist final state to disk, complete in-flight
@@ -377,9 +411,27 @@ docker run -d \
---
## Planned Future Features (Post-1.0)
- **DNSSEC validation**: Validate the DNSSEC chain of trust during
iterative resolution and report DNSSEC failures as notifications.
---
## Project Structure
Follows the conventions defined in `CONVENTIONS.md`, adapted from the
Follows the conventions defined in `REPO_POLICIES.md`, adapted from the
[upaas](https://git.eeqj.de/sneak/upaas) project template. Uses uber/fx
for dependency injection, go-chi for HTTP routing, slog for logging, and
Viper for configuration.
---
## License
License has not yet been chosen for this project. Pending decision by the
author (MIT, GPL, or WTFPL).
## Author
[@sneak](https://sneak.berlin)

188
REPO_POLICIES.md Normal file
View File

@@ -0,0 +1,188 @@
---
title: Repository Policies
last_modified: 2026-02-22
---
This document covers repository structure, tooling, and workflow standards. Code
style conventions are in separate documents:
- [Code Styleguide](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE.md)
(general, bash, Docker)
- [Go](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_GO.md)
- [JavaScript](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_JS.md)
- [Python](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_PYTHON.md)
- [Go HTTP Server Conventions](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/GO_HTTP_SERVER_CONVENTIONS.md)
---
- Cross-project documentation (such as this file) must include
`last_modified: YYYY-MM-DD` in the YAML front matter so it can be kept in sync
with the authoritative source as policies evolve.
- **ALL external references must be pinned by cryptographic hash.** This
includes Docker base images, Go modules, npm packages, GitHub Actions, and
anything else fetched from a remote source. Version tags (`@v4`, `@latest`,
`:3.21`, etc.) are server-mutable and therefore remote code execution
vulnerabilities. The ONLY acceptable way to reference an external dependency
is by its content hash (Docker `@sha256:...`, Go module hash in `go.sum`, npm
integrity hash in lockfile, GitHub Actions `@<commit-sha>`). No exceptions.
This also means never `curl | bash` to install tools like pyenv, nvm, rustup,
etc. Instead, download a specific release archive from GitHub, verify its hash
(hardcoded in the Dockerfile or script), and only then install. Unverified
install scripts are arbitrary remote code execution. This is the single most
important rule in this document. Double-check every external reference in
every file before committing. There are zero exceptions to this rule.
- Every repo with software must have a root `Makefile` with these targets:
`make test`, `make lint`, `make fmt` (writes), `make fmt-check` (read-only),
`make check` (prereqs: `test`, `lint`, `fmt-check`), `make docker`, and
`make hooks` (installs pre-commit hook). A model Makefile is at
`https://git.eeqj.de/sneak/prompts/raw/branch/main/Makefile`.
- Always use Makefile targets (`make fmt`, `make test`, `make lint`, etc.)
instead of invoking the underlying tools directly. The Makefile is the single
source of truth for how these operations are run.
- The Makefile is authoritative documentation for how the repo is used. Beyond
the required targets above, it should have targets for every common operation:
running a local development server (`make run`, `make dev`), re-initializing
or migrating the database (`make db-reset`, `make migrate`), building
artifacts (`make build`), generating code, seeding data, or anything else a
developer would do regularly. If someone checks out the repo and types
`make<tab>`, they should see every meaningful operation available. A new
contributor should be able to understand the entire development workflow by
reading the Makefile.
- Every repo should have a `Dockerfile`. All Dockerfiles must run `make check`
as a build step so the build fails if the branch is not green. For non-server
repos, the Dockerfile should bring up a development environment and run
`make check`. For server repos, `make check` should run as an early build
stage before the final image is assembled.
- Every repo should have a Gitea Actions workflow (`.gitea/workflows/`) that
runs `docker build .` on push. Since the Dockerfile already runs `make check`,
a successful build implies all checks pass.
- Use platform-standard formatters: `black` for Python, `prettier` for
JS/CSS/Markdown/HTML, `go fmt` for Go. Always use default configuration with
two exceptions: four-space indents (except Go), and `proseWrap: always` for
Markdown (hard-wrap at 80 columns). Documentation and writing repos (Markdown,
HTML, CSS) should also have `.prettierrc` and `.prettierignore`.
- Pre-commit hook: `make check` if local testing is possible, otherwise
`make lint && make fmt-check`. The Makefile should provide a `make hooks`
target to install the pre-commit hook.
- All repos with software must have tests that run via the platform-standard
test framework (`go test`, `pytest`, `jest`/`vitest`, etc.). If no meaningful
tests exist yet, add the most minimal test possible — e.g. importing the
module under test to verify it compiles/parses. There is no excuse for
`make test` to be a no-op.
- `make test` must complete in under 20 seconds. Add a 30-second timeout in the
Makefile.
- Docker builds must complete in under 5 minutes.
- `make check` must not modify any files in the repo. Tests may use temporary
directories.
- `main` must always pass `make check`, no exceptions.
- Never commit secrets. `.env` files, credentials, API keys, and private keys
must be in `.gitignore`. No exceptions.
- `.gitignore` should be comprehensive from the start: OS files (`.DS_Store`),
editor files (`.swp`, `*~`), language build artifacts, and `node_modules/`.
Fetch the standard `.gitignore` from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.gitignore` when setting up
a new repo.
- Never use `git add -A` or `git add .`. Always stage files explicitly by name.
- Never force-push to `main`.
- Make all changes on a feature branch. You can do whatever you want on a
feature branch.
- `.golangci.yml` is standardized and must _NEVER_ be modified by an agent, only
manually by the user. Fetch from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.golangci.yml`.
- When pinning images or packages by hash, add a comment above the reference
with the version and date (YYYY-MM-DD).
- Use `yarn`, not `npm`.
- Write all dates as YYYY-MM-DD (ISO 8601).
- Simple projects should be configured with environment variables.
- Dockerized web services listen on port 8080 by default, overridable with
`PORT`.
- `README.md` is the primary documentation. Required sections:
- **Description**: First line must include the project name, purpose,
category (web server, SPA, CLI tool, etc.), license, and author. Example:
"µPaaS is an MIT-licensed Go web application by @sneak that receives
git-frontend webhooks and deploys applications via Docker in realtime."
- **Getting Started**: Copy-pasteable install/usage code block.
- **Rationale**: Why does this exist?
- **Design**: How is the program structured?
- **TODO**: Update meticulously, even between commits. When planning, put
the todo list in the README so a new agent can pick up where the last one
left off.
- **License**: MIT, GPL, or WTFPL. Ask the user for new projects. Include a
`LICENSE` file in the repo root and a License section in the README.
- **Author**: [@sneak](https://sneak.berlin).
- First commit of a new repo should contain only `README.md`.
- Go module root: `sneak.berlin/go/<name>`. Always run `go mod tidy` before
committing.
- Use SemVer.
- Database migrations live in `internal/db/migrations/` and must be embedded in
the binary.
- `000_migration.sql` — contains ONLY the creation of the migrations tracking
table itself. Nothing else.
- `001_schema.sql` — the full application schema.
- **Pre-1.0.0:** never add additional migration files (002, 003, etc.). There
is no installed base to migrate. Edit `001_schema.sql` directly.
- **Post-1.0.0:** add new numbered migration files for each schema change.
Never edit existing migrations after release.
- All repos should have an `.editorconfig` enforcing the project's indentation
settings.
- Avoid putting files in the repo root unless necessary. Root should contain
only project-level config files (`README.md`, `Makefile`, `Dockerfile`,
`LICENSE`, `.gitignore`, `.editorconfig`, `REPO_POLICIES.md`, and
language-specific config). Everything else goes in a subdirectory. Canonical
subdirectory names:
- `bin/` — executable scripts and tools
- `cmd/` — Go command entrypoints
- `configs/` — configuration templates and examples
- `deploy/` — deployment manifests (k8s, compose, terraform)
- `docs/` — documentation and markdown (README.md stays in root)
- `internal/` — Go internal packages
- `internal/db/migrations/` — database migrations
- `pkg/` — Go library packages
- `share/` — systemd units, data files
- `static/` — static assets (images, fonts, etc.)
- `web/` — web frontend source
- When setting up a new repo, files from the `prompts` repo may be used as
templates. Fetch them from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/<path>`.
- New repos must contain at minimum:
- `README.md`, `.git`, `.gitignore`, `.editorconfig`
- `LICENSE`, `REPO_POLICIES.md` (copy from the `prompts` repo)
- `Makefile`
- `Dockerfile`, `.dockerignore`
- `.gitea/workflows/check.yml`
- Go: `go.mod`, `go.sum`, `.golangci.yml`
- JS: `package.json`, `yarn.lock`, `.prettierrc`, `.prettierignore`
- Python: `pyproject.toml`

34
TESTING.md Normal file
View File

@@ -0,0 +1,34 @@
# Testing Policy
## DNS Resolution Tests
All resolver tests **MUST** use live queries against real DNS servers.
No mocking of the DNS client layer is permitted.
### Rationale
The resolver performs iterative resolution from root nameservers through
the full delegation chain. Mocked responses cannot faithfully represent
the variety of real-world DNS behavior (truncation, referrals, glue
records, DNSSEC, varied response times, EDNS, etc.). Testing against
real servers ensures the resolver works correctly in production.
### Constraints
- Tests hit real DNS infrastructure and require network access
- Test duration depends on network conditions; timeout tuning keeps
the suite within the 30-second target
- Query timeout is calibrated to 3× maximum antipodal RTT (~300ms)
plus processing margin
- Root server fan-out is limited to reduce parallel query load
- Flaky failures from transient network issues are acceptable and
should be investigated as potential resolver bugs, not papered over
with mocks or skip flags
### What NOT to do
- **Do not mock `DNSClient`** for resolver tests (the mock constructor
exists for unit-testing other packages that consume the resolver)
- **Do not add `-short` flags** to skip slow tests
- **Do not increase `-timeout`** to hide hanging queries
- **Do not modify linter configuration** to suppress findings

View File

@@ -25,15 +25,13 @@ import (
//
//nolint:gochecknoglobals // build-time variables
var (
Appname = "dnswatcher"
Version string
Buildarch string
Appname = "dnswatcher"
Version string
)
func main() {
globals.SetAppname(Appname)
globals.SetVersion(Version)
globals.SetBuildarch(Buildarch)
fx.New(
fx.Provide(
@@ -51,6 +49,20 @@ func main() {
handlers.New,
server.New,
),
fx.Provide(
func(r *resolver.Resolver) watcher.DNSResolver {
return r
},
func(p *portcheck.Checker) watcher.PortChecker {
return p
},
func(t *tlscheck.Checker) watcher.TLSChecker {
return t
},
func(n *notify.Service) watcher.Notifier {
return n
},
),
fx.Invoke(func(*server.Server, *watcher.Watcher) {}),
).Run()
}

14
go.mod
View File

@@ -7,18 +7,24 @@ require (
github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/cors v1.2.2
github.com/joho/godotenv v1.5.1
github.com/miekg/dns v1.1.72
github.com/prometheus/client_golang v1.23.2
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0
golang.org/x/sync v0.19.0
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
@@ -33,7 +39,11 @@ require (
go.uber.org/zap v1.26.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

18
go.sum
View File

@@ -28,6 +28,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
@@ -74,10 +76,18 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -0,0 +1,85 @@
package config
import (
"errors"
"fmt"
"strings"
"golang.org/x/net/publicsuffix"
)
// DNSNameType indicates whether a DNS name is an apex domain or a hostname.
type DNSNameType int
const (
// DNSNameTypeDomain indicates the name is an apex (eTLD+1) domain.
DNSNameTypeDomain DNSNameType = iota
// DNSNameTypeHostname indicates the name is a subdomain/hostname.
DNSNameTypeHostname
)
// ErrEmptyDNSName is returned when an empty string is passed to ClassifyDNSName.
var ErrEmptyDNSName = errors.New("empty DNS name")
// String returns the string representation of a DNSNameType.
func (t DNSNameType) String() string {
switch t {
case DNSNameTypeDomain:
return "domain"
case DNSNameTypeHostname:
return "hostname"
default:
return "unknown"
}
}
// ClassifyDNSName determines whether a DNS name is an apex domain or a
// hostname (subdomain) using the Public Suffix List. It returns an error
// if the input is empty or is itself a public suffix (e.g. "co.uk").
func ClassifyDNSName(name string) (DNSNameType, error) {
name = strings.ToLower(strings.TrimSuffix(strings.TrimSpace(name), "."))
if name == "" {
return 0, ErrEmptyDNSName
}
etld1, err := publicsuffix.EffectiveTLDPlusOne(name)
if err != nil {
return 0, fmt.Errorf("invalid DNS name %q: %w", name, err)
}
if name == etld1 {
return DNSNameTypeDomain, nil
}
return DNSNameTypeHostname, nil
}
// ClassifyTargets splits a list of DNS names into apex domains and
// hostnames using the Public Suffix List. It returns an error if any
// name cannot be classified.
func ClassifyTargets(targets []string) ([]string, []string, error) {
var domains, hostnames []string
for _, t := range targets {
normalized := strings.ToLower(strings.TrimSuffix(strings.TrimSpace(t), "."))
if normalized == "" {
continue
}
typ, classErr := ClassifyDNSName(normalized)
if classErr != nil {
return nil, nil, classErr
}
switch typ {
case DNSNameTypeDomain:
domains = append(domains, normalized)
case DNSNameTypeHostname:
hostnames = append(hostnames, normalized)
}
}
return domains, hostnames, nil
}

View File

@@ -0,0 +1,83 @@
package config_test
import (
"testing"
"sneak.berlin/go/dnswatcher/internal/config"
)
func TestClassifyDNSName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want config.DNSNameType
wantErr bool
}{
{name: "apex domain simple", input: "example.com", want: config.DNSNameTypeDomain},
{name: "hostname simple", input: "www.example.com", want: config.DNSNameTypeHostname},
{name: "apex domain multi-part TLD", input: "example.co.uk", want: config.DNSNameTypeDomain},
{name: "hostname multi-part TLD", input: "api.example.co.uk", want: config.DNSNameTypeHostname},
{name: "public suffix itself", input: "co.uk", wantErr: true},
{name: "empty string", input: "", wantErr: true},
{name: "deeply nested hostname", input: "a.b.c.example.com", want: config.DNSNameTypeHostname},
{name: "trailing dot stripped", input: "example.com.", want: config.DNSNameTypeDomain},
{name: "uppercase normalized", input: "WWW.Example.COM", want: config.DNSNameTypeHostname},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := config.ClassifyDNSName(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("ClassifyDNSName(%q) expected error, got %v", tt.input, got)
}
return
}
if err != nil {
t.Fatalf("ClassifyDNSName(%q) unexpected error: %v", tt.input, err)
}
if got != tt.want {
t.Errorf("ClassifyDNSName(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestClassifyTargets(t *testing.T) {
t.Parallel()
domains, hostnames, err := config.ClassifyTargets([]string{
"example.com",
"www.example.com",
"example.co.uk",
"api.example.co.uk",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(domains) != 2 {
t.Errorf("expected 2 domains, got %d: %v", len(domains), domains)
}
if len(hostnames) != 2 {
t.Errorf("expected 2 hostnames, got %d: %v", len(hostnames), hostnames)
}
}
func TestClassifyTargetsRejectsPublicSuffix(t *testing.T) {
t.Parallel()
_, _, err := config.ClassifyTargets([]string{"co.uk"})
if err == nil {
t.Error("expected error for public suffix, got nil")
}
}

View File

@@ -23,6 +23,11 @@ const (
defaultTLSExpiryWarning = 7
)
// ErrNoTargets is returned when no monitoring targets are configured.
var ErrNoTargets = errors.New(
"no monitoring targets configured: set DNSWATCHER_TARGETS environment variable",
)
// Params contains dependencies for Config.
type Params struct {
fx.In
@@ -33,23 +38,24 @@ type Params struct {
// Config holds application configuration.
type Config struct {
Port int
Debug bool
DataDir string
Domains []string
Hostnames []string
SlackWebhook string
MattermostWebhook string
NtfyTopic string
DNSInterval time.Duration
TLSInterval time.Duration
TLSExpiryWarning int
SentryDSN string
MaintenanceMode bool
MetricsUsername string
MetricsPassword string
params *Params
log *slog.Logger
Port int
Debug bool
DataDir string
Domains []string
Hostnames []string
SlackWebhook string
MattermostWebhook string
NtfyTopic string
DNSInterval time.Duration
TLSInterval time.Duration
TLSExpiryWarning int
SentryDSN string
MaintenanceMode bool
MetricsUsername string
MetricsPassword string
SendTestNotification bool
params *Params
log *slog.Logger
}
// New creates a new Config instance from environment and config files.
@@ -89,8 +95,7 @@ func setupViper(name string) {
viper.SetDefault("PORT", defaultPort)
viper.SetDefault("DEBUG", false)
viper.SetDefault("DATA_DIR", "./data")
viper.SetDefault("DOMAINS", "")
viper.SetDefault("HOSTNAMES", "")
viper.SetDefault("TARGETS", "")
viper.SetDefault("SLACK_WEBHOOK", "")
viper.SetDefault("MATTERMOST_WEBHOOK", "")
viper.SetDefault("NTFY_TOPIC", "")
@@ -101,6 +106,7 @@ func setupViper(name string) {
viper.SetDefault("MAINTENANCE_MODE", false)
viper.SetDefault("METRICS_USERNAME", "")
viper.SetDefault("METRICS_PASSWORD", "")
viper.SetDefault("SEND_TEST_NOTIFICATION", false)
}
func buildConfig(
@@ -133,29 +139,52 @@ func buildConfig(
tlsInterval = defaultTLSInterval
}
domains, hostnames, err := parseAndValidateTargets()
if err != nil {
return nil, err
}
cfg := &Config{
Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"),
DataDir: viper.GetString("DATA_DIR"),
Domains: parseCSV(viper.GetString("DOMAINS")),
Hostnames: parseCSV(viper.GetString("HOSTNAMES")),
SlackWebhook: viper.GetString("SLACK_WEBHOOK"),
MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"),
NtfyTopic: viper.GetString("NTFY_TOPIC"),
DNSInterval: dnsInterval,
TLSInterval: tlsInterval,
TLSExpiryWarning: viper.GetInt("TLS_EXPIRY_WARNING"),
SentryDSN: viper.GetString("SENTRY_DSN"),
MaintenanceMode: viper.GetBool("MAINTENANCE_MODE"),
MetricsUsername: viper.GetString("METRICS_USERNAME"),
MetricsPassword: viper.GetString("METRICS_PASSWORD"),
params: params,
log: log,
Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"),
DataDir: viper.GetString("DATA_DIR"),
Domains: domains,
Hostnames: hostnames,
SlackWebhook: viper.GetString("SLACK_WEBHOOK"),
MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"),
NtfyTopic: viper.GetString("NTFY_TOPIC"),
DNSInterval: dnsInterval,
TLSInterval: tlsInterval,
TLSExpiryWarning: viper.GetInt("TLS_EXPIRY_WARNING"),
SentryDSN: viper.GetString("SENTRY_DSN"),
MaintenanceMode: viper.GetBool("MAINTENANCE_MODE"),
MetricsUsername: viper.GetString("METRICS_USERNAME"),
MetricsPassword: viper.GetString("METRICS_PASSWORD"),
SendTestNotification: viper.GetBool("SEND_TEST_NOTIFICATION"),
params: params,
log: log,
}
return cfg, nil
}
func parseAndValidateTargets() ([]string, []string, error) {
domains, hostnames, err := ClassifyTargets(
parseCSV(viper.GetString("TARGETS")),
)
if err != nil {
return nil, nil, fmt.Errorf(
"invalid targets configuration: %w", err,
)
}
if len(domains) == 0 && len(hostnames) == 0 {
return nil, nil, ErrNoTargets
}
return domains, hostnames, nil
}
func parseCSV(input string) []string {
if input == "" {
return nil

View File

@@ -0,0 +1,262 @@
package config_test
import (
"testing"
"time"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/globals"
"sneak.berlin/go/dnswatcher/internal/logger"
)
// newTestParams creates config.Params suitable for testing
// without requiring the fx dependency injection framework.
func newTestParams(t *testing.T) config.Params {
t.Helper()
g := &globals.Globals{
Appname: "dnswatcher",
Version: "test",
}
l, err := logger.New(nil, logger.Params{Globals: g})
require.NoError(t, err, "failed to create logger")
return config.Params{
Globals: g,
Logger: l,
}
}
// These tests exercise viper global state and MUST NOT use
// t.Parallel(). Each test resets viper for isolation.
func TestNew_DefaultValues(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com,www.example.com")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, 8080, cfg.Port)
assert.False(t, cfg.Debug)
assert.Equal(t, "./data", cfg.DataDir)
assert.Equal(t, time.Hour, cfg.DNSInterval)
assert.Equal(t, 12*time.Hour, cfg.TLSInterval)
assert.Equal(t, 7, cfg.TLSExpiryWarning)
assert.False(t, cfg.MaintenanceMode)
assert.Empty(t, cfg.SlackWebhook)
assert.Empty(t, cfg.MattermostWebhook)
assert.Empty(t, cfg.NtfyTopic)
assert.Empty(t, cfg.SentryDSN)
assert.Empty(t, cfg.MetricsUsername)
assert.Empty(t, cfg.MetricsPassword)
assert.False(t, cfg.SendTestNotification)
}
func TestNew_EnvironmentOverrides(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
t.Setenv("PORT", "9090")
t.Setenv("DNSWATCHER_DEBUG", "true")
t.Setenv("DNSWATCHER_DATA_DIR", "/tmp/test-data")
t.Setenv("DNSWATCHER_DNS_INTERVAL", "30m")
t.Setenv("DNSWATCHER_TLS_INTERVAL", "6h")
t.Setenv("DNSWATCHER_TLS_EXPIRY_WARNING", "14")
t.Setenv("DNSWATCHER_SLACK_WEBHOOK", "https://hooks.slack.com/t")
t.Setenv("DNSWATCHER_MATTERMOST_WEBHOOK", "https://mm.test/hooks/t")
t.Setenv("DNSWATCHER_NTFY_TOPIC", "https://ntfy.sh/test")
t.Setenv("DNSWATCHER_SENTRY_DSN", "https://sentry.test/1")
t.Setenv("DNSWATCHER_MAINTENANCE_MODE", "true")
t.Setenv("DNSWATCHER_METRICS_USERNAME", "admin")
t.Setenv("DNSWATCHER_METRICS_PASSWORD", "secret")
t.Setenv("DNSWATCHER_SEND_TEST_NOTIFICATION", "true")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, 9090, cfg.Port)
assert.True(t, cfg.Debug)
assert.Equal(t, "/tmp/test-data", cfg.DataDir)
assert.Equal(t, 30*time.Minute, cfg.DNSInterval)
assert.Equal(t, 6*time.Hour, cfg.TLSInterval)
assert.Equal(t, 14, cfg.TLSExpiryWarning)
assert.Equal(t, "https://hooks.slack.com/t", cfg.SlackWebhook)
assert.Equal(t, "https://mm.test/hooks/t", cfg.MattermostWebhook)
assert.Equal(t, "https://ntfy.sh/test", cfg.NtfyTopic)
assert.Equal(t, "https://sentry.test/1", cfg.SentryDSN)
assert.True(t, cfg.MaintenanceMode)
assert.Equal(t, "admin", cfg.MetricsUsername)
assert.Equal(t, "secret", cfg.MetricsPassword)
assert.True(t, cfg.SendTestNotification)
}
func TestNew_NoTargetsError(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "")
_, err := config.New(nil, newTestParams(t))
require.Error(t, err)
assert.ErrorIs(t, err, config.ErrNoTargets)
}
func TestNew_OnlyEmptyCSVSegments(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", " , , ")
_, err := config.New(nil, newTestParams(t))
require.Error(t, err)
assert.ErrorIs(t, err, config.ErrNoTargets)
}
func TestNew_InvalidDNSInterval_FallsBackToDefault(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
t.Setenv("DNSWATCHER_DNS_INTERVAL", "banana")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, time.Hour, cfg.DNSInterval,
"invalid DNS interval should fall back to 1h default")
}
func TestNew_InvalidTLSInterval_FallsBackToDefault(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
t.Setenv("DNSWATCHER_TLS_INTERVAL", "notaduration")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, 12*time.Hour, cfg.TLSInterval,
"invalid TLS interval should fall back to 12h default")
}
func TestNew_BothIntervalsInvalid(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
t.Setenv("DNSWATCHER_DNS_INTERVAL", "xyz")
t.Setenv("DNSWATCHER_TLS_INTERVAL", "abc")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, time.Hour, cfg.DNSInterval)
assert.Equal(t, 12*time.Hour, cfg.TLSInterval)
}
func TestNew_DebugEnablesDebugLogging(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
t.Setenv("DNSWATCHER_DEBUG", "true")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.True(t, cfg.Debug)
}
func TestNew_PortEnvNotPrefixed(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
t.Setenv("PORT", "3000")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, 3000, cfg.Port,
"PORT env should work without DNSWATCHER_ prefix")
}
func TestNew_TargetClassification(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS",
"example.com,www.example.com,api.example.com,example.org")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
// example.com and example.org are apex domains
assert.Len(t, cfg.Domains, 2)
// www.example.com and api.example.com are hostnames
assert.Len(t, cfg.Hostnames, 2)
}
func TestNew_InvalidTargetPublicSuffix(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "co.uk")
_, err := config.New(nil, newTestParams(t))
require.Error(t, err, "public suffix should be rejected")
}
func TestNew_EmptyAppnameDefaultsToDnswatcher(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
g := &globals.Globals{Appname: "", Version: "test"}
l, err := logger.New(nil, logger.Params{Globals: g})
require.NoError(t, err)
cfg, err := config.New(
nil, config.Params{Globals: g, Logger: l},
)
require.NoError(t, err)
assert.Equal(t, 8080, cfg.Port,
"defaults should load when appname is empty")
}
func TestNew_TargetsWithWhitespace(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", " example.com , www.example.com ")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, 2, len(cfg.Domains)+len(cfg.Hostnames),
"whitespace around targets should be trimmed")
}
func TestNew_TargetsWithTrailingComma(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com,www.example.com,")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, 2, len(cfg.Domains)+len(cfg.Hostnames),
"trailing comma should be ignored")
}
func TestNew_CustomDNSIntervalDuration(t *testing.T) {
viper.Reset()
t.Setenv("DNSWATCHER_TARGETS", "example.com")
t.Setenv("DNSWATCHER_DNS_INTERVAL", "5s")
cfg, err := config.New(nil, newTestParams(t))
require.NoError(t, err)
assert.Equal(t, 5*time.Second, cfg.DNSInterval)
}
func TestStatePath(t *testing.T) {
t.Parallel()
tests := []struct {
name string
dataDir string
want string
}{
{"default", "./data", "./data/state.json"},
{"absolute", "/var/lib/dw", "/var/lib/dw/state.json"},
{"nested", "/opt/app/data", "/opt/app/data/state.json"},
{"empty", "", "/state.json"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
cfg := &config.Config{DataDir: tt.dataDir}
assert.Equal(t, tt.want, cfg.StatePath())
})
}
}

View File

@@ -0,0 +1,6 @@
package config
// ParseCSVForTest exports parseCSV for use in external tests.
func ParseCSVForTest(input string) []string {
return parseCSV(input)
}

View File

@@ -0,0 +1,44 @@
package config_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/config"
)
func TestParseCSV(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want []string
}{
{"empty string", "", nil},
{"single value", "a", []string{"a"}},
{"multiple values", "a,b,c", []string{"a", "b", "c"}},
{"whitespace trimmed", " a , b ", []string{"a", "b"}},
{"trailing comma", "a,b,", []string{"a", "b"}},
{"leading comma", ",a,b", []string{"a", "b"}},
{"consecutive commas", "a,,b", []string{"a", "b"}},
{"all empty segments", ",,,", nil},
{"whitespace only", " , , ", nil},
{"tabs", "\ta\t,\tb\t", []string{"a", "b"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := config.ParseCSVForTest(tt.input)
require.Len(t, got, len(tt.want))
for i, w := range tt.want {
assert.Equal(t, w, got[i])
}
})
}
}

View File

@@ -12,17 +12,15 @@ import (
//
//nolint:gochecknoglobals // Required for ldflags injection at build time
var (
mu sync.RWMutex
appname string
version string
buildarch string
mu sync.RWMutex
appname string
version string
)
// Globals holds build-time variables for dependency injection.
type Globals struct {
Appname string
Version string
Buildarch string
Appname string
Version string
}
// New creates a new Globals instance from package-level variables.
@@ -31,9 +29,8 @@ func New(_ fx.Lifecycle) (*Globals, error) {
defer mu.RUnlock()
return &Globals{
Appname: appname,
Version: version,
Buildarch: buildarch,
Appname: appname,
Version: version,
}, nil
}
@@ -52,11 +49,3 @@ func SetVersion(ver string) {
version = ver
}
// SetBuildarch sets the build architecture.
func SetBuildarch(arch string) {
mu.Lock()
defer mu.Unlock()
buildarch = arch
}

View File

@@ -0,0 +1,151 @@
package handlers
import (
"embed"
"fmt"
"html/template"
"math"
"net/http"
"strings"
"time"
"sneak.berlin/go/dnswatcher/internal/notify"
"sneak.berlin/go/dnswatcher/internal/state"
)
//go:embed templates/dashboard.html
var dashboardFS embed.FS
// Time unit constants for relative time calculations.
const (
secondsPerMinute = 60
minutesPerHour = 60
hoursPerDay = 24
)
// newDashboardTemplate parses the embedded dashboard HTML
// template with helper functions.
func newDashboardTemplate() *template.Template {
funcs := template.FuncMap{
"relTime": relTime,
"joinStrings": joinStrings,
"formatRecords": formatRecords,
"expiryDays": expiryDays,
}
return template.Must(
template.New("dashboard.html").
Funcs(funcs).
ParseFS(dashboardFS, "templates/dashboard.html"),
)
}
// dashboardData is the data passed to the dashboard template.
type dashboardData struct {
Snapshot state.Snapshot
Alerts []notify.AlertEntry
StateAge string
GeneratedAt string
}
// HandleDashboard returns the dashboard page handler.
func (h *Handlers) HandleDashboard() http.HandlerFunc {
tmpl := newDashboardTemplate()
return func(
writer http.ResponseWriter,
_ *http.Request,
) {
snap := h.state.GetSnapshot()
alerts := h.notifyHistory.Recent()
data := dashboardData{
Snapshot: snap,
Alerts: alerts,
StateAge: relTime(snap.LastUpdated),
GeneratedAt: time.Now().UTC().Format("2006-01-02 15:04:05"),
}
writer.Header().Set(
"Content-Type", "text/html; charset=utf-8",
)
err := tmpl.Execute(writer, data)
if err != nil {
h.log.Error(
"dashboard template error",
"error", err,
)
}
}
}
// relTime returns a human-readable relative time string such
// as "2 minutes ago" or "never" for zero times.
func relTime(t time.Time) string {
if t.IsZero() {
return "never"
}
d := time.Since(t)
if d < 0 {
return "just now"
}
seconds := int(math.Round(d.Seconds()))
if seconds < secondsPerMinute {
return fmt.Sprintf("%ds ago", seconds)
}
minutes := seconds / secondsPerMinute
if minutes < minutesPerHour {
return fmt.Sprintf("%dm ago", minutes)
}
hours := minutes / minutesPerHour
if hours < hoursPerDay {
return fmt.Sprintf(
"%dh %dm ago", hours, minutes%minutesPerHour,
)
}
days := hours / hoursPerDay
return fmt.Sprintf(
"%dd %dh ago", days, hours%hoursPerDay,
)
}
// joinStrings joins a string slice with a separator.
func joinStrings(items []string, sep string) string {
return strings.Join(items, sep)
}
// formatRecords formats a map of record type → values into a
// compact display string.
func formatRecords(records map[string][]string) string {
if len(records) == 0 {
return "-"
}
var parts []string
for rtype, values := range records {
for _, v := range values {
parts = append(parts, rtype+": "+v)
}
}
return strings.Join(parts, ", ")
}
// expiryDays returns the number of days until the given time,
// rounded down. Returns 0 if already expired.
func expiryDays(t time.Time) int {
d := time.Until(t).Hours() / hoursPerDay
if d < 0 {
return 0
}
return int(d)
}

View File

@@ -0,0 +1,80 @@
package handlers_test
import (
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/handlers"
)
func TestRelTime(t *testing.T) {
t.Parallel()
tests := []struct {
name string
dur time.Duration
want string
}{
{"zero", 0, "never"},
{"seconds", 30 * time.Second, "30s ago"},
{"minutes", 5 * time.Minute, "5m ago"},
{"hours", 2*time.Hour + 15*time.Minute, "2h 15m ago"},
{"days", 48*time.Hour + 3*time.Hour, "2d 3h ago"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var input time.Time
if tt.dur > 0 {
input = time.Now().Add(-tt.dur)
}
got := handlers.RelTime(input)
if got != tt.want {
t.Errorf(
"RelTime(%v) = %q, want %q",
tt.dur, got, tt.want,
)
}
})
}
}
func TestExpiryDays(t *testing.T) {
t.Parallel()
// 10 days from now.
future := time.Now().Add(10 * 24 * time.Hour)
days := handlers.ExpiryDays(future)
if days < 9 || days > 10 {
t.Errorf("expected ~10 days, got %d", days)
}
// Already expired.
past := time.Now().Add(-24 * time.Hour)
days = handlers.ExpiryDays(past)
if days != 0 {
t.Errorf("expected 0 for expired, got %d", days)
}
}
func TestFormatRecords(t *testing.T) {
t.Parallel()
got := handlers.FormatRecords(nil)
if got != "-" {
t.Errorf("expected -, got %q", got)
}
got = handlers.FormatRecords(map[string][]string{
"A": {"1.2.3.4"},
})
if got != "A: 1.2.3.4" {
t.Errorf("unexpected format: %q", got)
}
}

View File

@@ -0,0 +1,18 @@
package handlers
import "time"
// RelTime exports relTime for testing.
func RelTime(t time.Time) string {
return relTime(t)
}
// ExpiryDays exports expiryDays for testing.
func ExpiryDays(t time.Time) int {
return expiryDays(t)
}
// FormatRecords exports formatRecords for testing.
func FormatRecords(records map[string][]string) string {
return formatRecords(records)
}

View File

@@ -11,6 +11,8 @@ import (
"sneak.berlin/go/dnswatcher/internal/globals"
"sneak.berlin/go/dnswatcher/internal/healthcheck"
"sneak.berlin/go/dnswatcher/internal/logger"
"sneak.berlin/go/dnswatcher/internal/notify"
"sneak.berlin/go/dnswatcher/internal/state"
)
// Params contains dependencies for Handlers.
@@ -20,23 +22,29 @@ type Params struct {
Logger *logger.Logger
Globals *globals.Globals
Healthcheck *healthcheck.Healthcheck
State *state.State
Notify *notify.Service
}
// Handlers provides HTTP request handlers.
type Handlers struct {
log *slog.Logger
params *Params
globals *globals.Globals
hc *healthcheck.Healthcheck
log *slog.Logger
params *Params
globals *globals.Globals
hc *healthcheck.Healthcheck
state *state.State
notifyHistory *notify.AlertHistory
}
// New creates a new Handlers instance.
func New(_ fx.Lifecycle, params Params) (*Handlers, error) {
return &Handlers{
log: params.Logger.Get(),
params: &params,
globals: params.Globals,
hc: params.Healthcheck,
log: params.Logger.Get(),
params: &params,
globals: params.Globals,
hc: params.Healthcheck,
state: params.State,
notifyHistory: params.Notify.History(),
}, nil
}

View File

@@ -2,22 +2,217 @@ package handlers
import (
"net/http"
"sort"
"time"
"sneak.berlin/go/dnswatcher/internal/state"
)
// statusDomainInfo holds status information for a monitored domain.
type statusDomainInfo struct {
Nameservers []string `json:"nameservers"`
LastChecked time.Time `json:"lastChecked"`
}
// statusHostnameNSInfo holds per-nameserver status for a hostname.
type statusHostnameNSInfo struct {
Records map[string][]string `json:"records"`
Status string `json:"status"`
LastChecked time.Time `json:"lastChecked"`
}
// statusHostnameInfo holds status information for a monitored hostname.
type statusHostnameInfo struct {
Nameservers map[string]*statusHostnameNSInfo `json:"nameservers"`
LastChecked time.Time `json:"lastChecked"`
}
// statusPortInfo holds status information for a monitored port.
type statusPortInfo struct {
Open bool `json:"open"`
Hostnames []string `json:"hostnames"`
LastChecked time.Time `json:"lastChecked"`
}
// statusCertificateInfo holds status information for a TLS certificate.
type statusCertificateInfo struct {
CommonName string `json:"commonName"`
Issuer string `json:"issuer"`
NotAfter time.Time `json:"notAfter"`
SubjectAlternativeNames []string `json:"subjectAlternativeNames"`
Status string `json:"status"`
LastChecked time.Time `json:"lastChecked"`
}
// statusCounts holds summary counts of monitored resources.
type statusCounts struct {
Domains int `json:"domains"`
Hostnames int `json:"hostnames"`
Ports int `json:"ports"`
PortsOpen int `json:"portsOpen"`
Certificates int `json:"certificates"`
CertsOK int `json:"certificatesOk"`
CertsError int `json:"certificatesError"`
}
// statusResponse is the full /api/v1/status response.
type statusResponse struct {
Status string `json:"status"`
LastUpdated time.Time `json:"lastUpdated"`
Counts statusCounts `json:"counts"`
Domains map[string]*statusDomainInfo `json:"domains"`
Hostnames map[string]*statusHostnameInfo `json:"hostnames"`
Ports map[string]*statusPortInfo `json:"ports"`
Certificates map[string]*statusCertificateInfo `json:"certificates"`
}
// HandleStatus returns the monitoring status handler.
func (h *Handlers) HandleStatus() http.HandlerFunc {
type response struct {
Status string `json:"status"`
}
return func(
writer http.ResponseWriter,
request *http.Request,
) {
snap := h.state.GetSnapshot()
resp := buildStatusResponse(snap)
h.respondJSON(
writer, request,
&response{Status: "ok"},
resp,
http.StatusOK,
)
}
}
// buildStatusResponse constructs the full status response from
// the current monitoring snapshot.
func buildStatusResponse(
snap state.Snapshot,
) *statusResponse {
resp := &statusResponse{
Status: "ok",
LastUpdated: snap.LastUpdated,
Domains: make(map[string]*statusDomainInfo),
Hostnames: make(map[string]*statusHostnameInfo),
Ports: make(map[string]*statusPortInfo),
Certificates: make(map[string]*statusCertificateInfo),
}
buildDomains(snap, resp)
buildHostnames(snap, resp)
buildPorts(snap, resp)
buildCertificates(snap, resp)
buildCounts(resp)
return resp
}
func buildDomains(
snap state.Snapshot,
resp *statusResponse,
) {
for name, ds := range snap.Domains {
ns := make([]string, len(ds.Nameservers))
copy(ns, ds.Nameservers)
sort.Strings(ns)
resp.Domains[name] = &statusDomainInfo{
Nameservers: ns,
LastChecked: ds.LastChecked,
}
}
}
func buildHostnames(
snap state.Snapshot,
resp *statusResponse,
) {
for name, hs := range snap.Hostnames {
info := &statusHostnameInfo{
Nameservers: make(map[string]*statusHostnameNSInfo),
LastChecked: hs.LastChecked,
}
for ns, nsState := range hs.RecordsByNameserver {
recs := make(map[string][]string, len(nsState.Records))
for rtype, vals := range nsState.Records {
copied := make([]string, len(vals))
copy(copied, vals)
recs[rtype] = copied
}
info.Nameservers[ns] = &statusHostnameNSInfo{
Records: recs,
Status: nsState.Status,
LastChecked: nsState.LastChecked,
}
}
resp.Hostnames[name] = info
}
}
func buildPorts(
snap state.Snapshot,
resp *statusResponse,
) {
for key, ps := range snap.Ports {
hostnames := make([]string, len(ps.Hostnames))
copy(hostnames, ps.Hostnames)
sort.Strings(hostnames)
resp.Ports[key] = &statusPortInfo{
Open: ps.Open,
Hostnames: hostnames,
LastChecked: ps.LastChecked,
}
}
}
func buildCertificates(
snap state.Snapshot,
resp *statusResponse,
) {
for key, cs := range snap.Certificates {
sans := make([]string, len(cs.SubjectAlternativeNames))
copy(sans, cs.SubjectAlternativeNames)
resp.Certificates[key] = &statusCertificateInfo{
CommonName: cs.CommonName,
Issuer: cs.Issuer,
NotAfter: cs.NotAfter,
SubjectAlternativeNames: sans,
Status: cs.Status,
LastChecked: cs.LastChecked,
}
}
}
func buildCounts(resp *statusResponse) {
var portsOpen, certsOK, certsError int
for _, ps := range resp.Ports {
if ps.Open {
portsOpen++
}
}
for _, cs := range resp.Certificates {
switch cs.Status {
case "ok":
certsOK++
case "error":
certsError++
}
}
resp.Counts = statusCounts{
Domains: len(resp.Domains),
Hostnames: len(resp.Hostnames),
Ports: len(resp.Ports),
PortsOpen: portsOpen,
Certificates: len(resp.Certificates),
CertsOK: certsOK,
CertsError: certsError,
}
}

View File

@@ -0,0 +1,370 @@
<!doctype html>
<html lang="en" class="bg-slate-950">
<head>
<meta charset="utf-8" />
<meta http-equiv="refresh" content="30" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>dnswatcher</title>
<link rel="stylesheet" href="/s/css/tailwind.min.css" />
</head>
<body
class="bg-surface-950 text-slate-300 font-mono text-sm min-h-screen antialiased"
>
<div class="max-w-6xl mx-auto px-4 py-8">
{{/* ---- Header ---- */}}
<div class="mb-8">
<h1 class="text-2xl font-bold text-teal-400 tracking-tight">
dnswatcher
</h1>
<p class="text-xs text-slate-500 mt-1">
state updated {{ .StateAge }} &middot; page generated
{{ .GeneratedAt }} UTC &middot; auto-refresh 30s
</p>
</div>
{{/* ---- Summary bar ---- */}}
<div
class="grid grid-cols-2 sm:grid-cols-4 gap-3 mb-8"
>
<div class="bg-surface-800 border border-slate-700/50 rounded-lg p-4">
<div class="text-xs text-slate-500 uppercase tracking-wider">
Domains
</div>
<div class="text-2xl font-bold text-teal-400 mt-1">
{{ len .Snapshot.Domains }}
</div>
</div>
<div class="bg-surface-800 border border-slate-700/50 rounded-lg p-4">
<div class="text-xs text-slate-500 uppercase tracking-wider">
Hostnames
</div>
<div class="text-2xl font-bold text-teal-400 mt-1">
{{ len .Snapshot.Hostnames }}
</div>
</div>
<div class="bg-surface-800 border border-slate-700/50 rounded-lg p-4">
<div class="text-xs text-slate-500 uppercase tracking-wider">
Ports
</div>
<div class="text-2xl font-bold text-teal-400 mt-1">
{{ len .Snapshot.Ports }}
</div>
</div>
<div class="bg-surface-800 border border-slate-700/50 rounded-lg p-4">
<div class="text-xs text-slate-500 uppercase tracking-wider">
Certificates
</div>
<div class="text-2xl font-bold text-teal-400 mt-1">
{{ len .Snapshot.Certificates }}
</div>
</div>
</div>
{{/* ---- Domains ---- */}}
<section class="mb-8">
<h2
class="text-sm font-semibold text-teal-300 uppercase tracking-wider mb-3 border-b border-slate-700/50 pb-2"
>
Domains
</h2>
{{ if .Snapshot.Domains }}
<div class="overflow-x-auto">
<table class="w-full text-left text-xs">
<thead>
<tr class="text-slate-500 uppercase tracking-wider">
<th class="py-2 px-3">Domain</th>
<th class="py-2 px-3">Nameservers</th>
<th class="py-2 px-3">Checked</th>
</tr>
</thead>
<tbody class="divide-y divide-slate-800">
{{ range $name, $ds := .Snapshot.Domains }}
<tr class="hover:bg-surface-800/50">
<td class="py-2 px-3 text-slate-200 font-medium">
{{ $name }}
</td>
<td class="py-2 px-3 text-slate-400 break-all">
{{ joinStrings $ds.Nameservers ", " }}
</td>
<td class="py-2 px-3 text-slate-500 whitespace-nowrap">
{{ relTime $ds.LastChecked }}
</td>
</tr>
{{ end }}
</tbody>
</table>
</div>
{{ else }}
<p class="text-slate-600 italic text-xs">
No domains configured.
</p>
{{ end }}
</section>
{{/* ---- Hostnames ---- */}}
<section class="mb-8">
<h2
class="text-sm font-semibold text-teal-300 uppercase tracking-wider mb-3 border-b border-slate-700/50 pb-2"
>
Hostnames
</h2>
{{ if .Snapshot.Hostnames }}
<div class="overflow-x-auto">
<table class="w-full text-left text-xs">
<thead>
<tr class="text-slate-500 uppercase tracking-wider">
<th class="py-2 px-3">Hostname</th>
<th class="py-2 px-3">NS</th>
<th class="py-2 px-3">Status</th>
<th class="py-2 px-3">Records</th>
<th class="py-2 px-3">Checked</th>
</tr>
</thead>
<tbody class="divide-y divide-slate-800">
{{ range $name, $hs := .Snapshot.Hostnames }}
{{ range $ns, $nsr := $hs.RecordsByNameserver }}
<tr class="hover:bg-surface-800/50">
<td class="py-2 px-3 text-slate-200 font-medium">
{{ $name }}
</td>
<td class="py-2 px-3 text-slate-400 break-all">
{{ $ns }}
</td>
<td class="py-2 px-3">
{{ if eq $nsr.Status "ok" }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-teal-900/50 text-teal-400 border border-teal-700/30"
>ok</span
>
{{ else }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-red-900/50 text-red-400 border border-red-700/30"
>{{ $nsr.Status }}</span
>
{{ end }}
</td>
<td
class="py-2 px-3 text-slate-400 break-all max-w-xs"
>
{{ formatRecords $nsr.Records }}
</td>
<td class="py-2 px-3 text-slate-500 whitespace-nowrap">
{{ relTime $nsr.LastChecked }}
</td>
</tr>
{{ end }}
{{ end }}
</tbody>
</table>
</div>
{{ else }}
<p class="text-slate-600 italic text-xs">
No hostnames configured.
</p>
{{ end }}
</section>
{{/* ---- Ports ---- */}}
<section class="mb-8">
<h2
class="text-sm font-semibold text-teal-300 uppercase tracking-wider mb-3 border-b border-slate-700/50 pb-2"
>
Ports
</h2>
{{ if .Snapshot.Ports }}
<div class="overflow-x-auto">
<table class="w-full text-left text-xs">
<thead>
<tr class="text-slate-500 uppercase tracking-wider">
<th class="py-2 px-3">Address</th>
<th class="py-2 px-3">State</th>
<th class="py-2 px-3">Hostnames</th>
<th class="py-2 px-3">Checked</th>
</tr>
</thead>
<tbody class="divide-y divide-slate-800">
{{ range $key, $ps := .Snapshot.Ports }}
<tr class="hover:bg-surface-800/50">
<td class="py-2 px-3 text-slate-200 font-medium">
{{ $key }}
</td>
<td class="py-2 px-3">
{{ if $ps.Open }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-teal-900/50 text-teal-400 border border-teal-700/30"
>open</span
>
{{ else }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-red-900/50 text-red-400 border border-red-700/30"
>closed</span
>
{{ end }}
</td>
<td class="py-2 px-3 text-slate-400 break-all">
{{ joinStrings $ps.Hostnames ", " }}
</td>
<td class="py-2 px-3 text-slate-500 whitespace-nowrap">
{{ relTime $ps.LastChecked }}
</td>
</tr>
{{ end }}
</tbody>
</table>
</div>
{{ else }}
<p class="text-slate-600 italic text-xs">
No port data yet.
</p>
{{ end }}
</section>
{{/* ---- Certificates ---- */}}
<section class="mb-8">
<h2
class="text-sm font-semibold text-teal-300 uppercase tracking-wider mb-3 border-b border-slate-700/50 pb-2"
>
Certificates
</h2>
{{ if .Snapshot.Certificates }}
<div class="overflow-x-auto">
<table class="w-full text-left text-xs">
<thead>
<tr class="text-slate-500 uppercase tracking-wider">
<th class="py-2 px-3">Endpoint</th>
<th class="py-2 px-3">Status</th>
<th class="py-2 px-3">CN</th>
<th class="py-2 px-3">Issuer</th>
<th class="py-2 px-3">Expires</th>
<th class="py-2 px-3">Checked</th>
</tr>
</thead>
<tbody class="divide-y divide-slate-800">
{{ range $key, $cs := .Snapshot.Certificates }}
<tr class="hover:bg-surface-800/50">
<td class="py-2 px-3 text-slate-400 break-all">
{{ $key }}
</td>
<td class="py-2 px-3">
{{ if eq $cs.Status "ok" }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-teal-900/50 text-teal-400 border border-teal-700/30"
>ok</span
>
{{ else }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-red-900/50 text-red-400 border border-red-700/30"
>{{ $cs.Status }}</span
>
{{ end }}
</td>
<td class="py-2 px-3 text-slate-200">
{{ $cs.CommonName }}
</td>
<td class="py-2 px-3 text-slate-400 break-all">
{{ $cs.Issuer }}
</td>
<td class="py-2 px-3 whitespace-nowrap">
{{ if not $cs.NotAfter.IsZero }}
{{ $days := expiryDays $cs.NotAfter }}
{{ if lt $days 7 }}
<span class="text-red-400 font-medium"
>{{ $cs.NotAfter.Format "2006-01-02" }}
({{ $days }}d)</span
>
{{ else if lt $days 30 }}
<span class="text-amber-400"
>{{ $cs.NotAfter.Format "2006-01-02" }}
({{ $days }}d)</span
>
{{ else }}
<span class="text-slate-400"
>{{ $cs.NotAfter.Format "2006-01-02" }}
({{ $days }}d)</span
>
{{ end }}
{{ end }}
</td>
<td class="py-2 px-3 text-slate-500 whitespace-nowrap">
{{ relTime $cs.LastChecked }}
</td>
</tr>
{{ end }}
</tbody>
</table>
</div>
{{ else }}
<p class="text-slate-600 italic text-xs">
No certificate data yet.
</p>
{{ end }}
</section>
{{/* ---- Recent Alerts ---- */}}
<section class="mb-8">
<h2
class="text-sm font-semibold text-teal-300 uppercase tracking-wider mb-3 border-b border-slate-700/50 pb-2"
>
Recent Alerts ({{ len .Alerts }})
</h2>
{{ if .Alerts }}
<div class="space-y-2">
{{ range .Alerts }}
<div
class="bg-surface-800 border rounded-lg px-4 py-3 {{ if eq .Priority "error" }}border-red-700/40{{ else if eq .Priority "warning" }}border-amber-700/40{{ else if eq .Priority "success" }}border-teal-700/40{{ else }}border-blue-700/40{{ end }}"
>
<div class="flex items-center gap-3 mb-1">
{{ if eq .Priority "error" }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-red-900/50 text-red-400 border border-red-700/30"
>error</span
>
{{ else if eq .Priority "warning" }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-amber-900/50 text-amber-400 border border-amber-700/30"
>warning</span
>
{{ else if eq .Priority "success" }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-teal-900/50 text-teal-400 border border-teal-700/30"
>success</span
>
{{ else }}
<span
class="inline-block px-1.5 py-0.5 rounded text-[10px] font-bold uppercase bg-blue-900/50 text-blue-400 border border-blue-700/30"
>info</span
>
{{ end }}
<span class="text-slate-200 text-xs font-medium">
{{ .Title }}
</span>
<span class="text-slate-600 text-[11px] ml-auto whitespace-nowrap">
{{ .Timestamp.Format "2006-01-02 15:04:05" }} UTC
({{ relTime .Timestamp }})
</span>
</div>
<p
class="text-slate-400 text-xs whitespace-pre-line pl-0.5"
>
{{ .Message }}
</p>
</div>
{{ end }}
</div>
{{ else }}
<p class="text-slate-600 italic text-xs">
No alerts recorded since last restart.
</p>
{{ end }}
</section>
{{/* ---- Footer ---- */}}
<div
class="text-[11px] text-slate-700 border-t border-slate-800 pt-4 mt-8"
>
dnswatcher &middot; monitoring {{ len .Snapshot.Domains }} domains +
{{ len .Snapshot.Hostnames }} hostnames
</div>
</div>
</body>
</html>

View File

@@ -78,6 +78,5 @@ func (l *Logger) Identify() {
l.log.Info("starting",
"appname", l.params.Globals.Appname,
"version", l.params.Globals.Version,
"buildarch", l.params.Globals.Buildarch,
)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,105 @@
package notify
import (
"context"
"io"
"log/slog"
"net/http"
"net/url"
"time"
)
// NtfyPriority exports ntfyPriority for testing.
func NtfyPriority(priority string) string {
return ntfyPriority(priority)
}
// SlackColor exports slackColor for testing.
func SlackColor(priority string) string {
return slackColor(priority)
}
// NewRequestForTest exports newRequest for testing.
func NewRequestForTest(
ctx context.Context,
method string,
target *url.URL,
body io.Reader,
) *http.Request {
return newRequest(ctx, method, target, body)
}
// NewTestService creates a Service suitable for unit testing.
// It discards log output and uses the given transport.
func NewTestService(transport http.RoundTripper) *Service {
return &Service{
log: slog.New(slog.DiscardHandler),
transport: transport,
history: NewAlertHistory(),
}
}
// SetNtfyURL sets the ntfy URL on a Service for testing.
func (svc *Service) SetNtfyURL(u *url.URL) {
svc.ntfyURL = u
}
// SetSlackWebhookURL sets the Slack webhook URL on a
// Service for testing.
func (svc *Service) SetSlackWebhookURL(u *url.URL) {
svc.slackWebhookURL = u
}
// SetMattermostWebhookURL sets the Mattermost webhook URL on
// a Service for testing.
func (svc *Service) SetMattermostWebhookURL(u *url.URL) {
svc.mattermostWebhookURL = u
}
// SendNtfy exports sendNtfy for testing.
func (svc *Service) SendNtfy(
ctx context.Context,
topicURL *url.URL,
title, message, priority string,
) error {
return svc.sendNtfy(ctx, topicURL, title, message, priority)
}
// SendSlack exports sendSlack for testing.
func (svc *Service) SendSlack(
ctx context.Context,
webhookURL *url.URL,
title, message, priority string,
) error {
return svc.sendSlack(
ctx, webhookURL, title, message, priority,
)
}
// SetRetryConfig overrides the retry configuration for
// testing.
func (svc *Service) SetRetryConfig(cfg RetryConfig) {
svc.retryConfig = cfg
}
// SetSleepFunc overrides the sleep function so tests can
// eliminate real delays.
func (svc *Service) SetSleepFunc(
fn func(time.Duration) <-chan time.Time,
) {
svc.sleepFn = fn
}
// DeliverWithRetry exports deliverWithRetry for testing.
func (svc *Service) DeliverWithRetry(
ctx context.Context,
endpoint string,
fn func(context.Context) error,
) error {
return svc.deliverWithRetry(ctx, endpoint, fn)
}
// BackoffDuration exports RetryConfig.backoff for testing.
func (rc RetryConfig) BackoffDuration(attempt int) time.Duration {
return rc.defaults().backoff(attempt)
}

View File

@@ -0,0 +1,62 @@
package notify
import (
"sync"
"time"
)
// maxAlertHistory is the maximum number of alerts to retain.
const maxAlertHistory = 100
// AlertEntry represents a single notification that was sent.
type AlertEntry struct {
Timestamp time.Time
Title string
Message string
Priority string
}
// AlertHistory is a thread-safe ring buffer that stores
// the most recent alerts.
type AlertHistory struct {
mu sync.RWMutex
entries [maxAlertHistory]AlertEntry
count int
index int
}
// NewAlertHistory creates a new empty AlertHistory.
func NewAlertHistory() *AlertHistory {
return &AlertHistory{}
}
// Add records a new alert entry in the ring buffer.
func (h *AlertHistory) Add(entry AlertEntry) {
h.mu.Lock()
defer h.mu.Unlock()
h.entries[h.index] = entry
h.index = (h.index + 1) % maxAlertHistory
if h.count < maxAlertHistory {
h.count++
}
}
// Recent returns the stored alerts in reverse chronological
// order (newest first). Returns at most maxAlertHistory entries.
func (h *AlertHistory) Recent() []AlertEntry {
h.mu.RLock()
defer h.mu.RUnlock()
result := make([]AlertEntry, h.count)
for i := range h.count {
// Walk backwards from the most recent entry.
idx := (h.index - 1 - i + maxAlertHistory) %
maxAlertHistory
result[i] = h.entries[idx]
}
return result
}

View File

@@ -0,0 +1,88 @@
package notify_test
import (
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/notify"
)
func TestAlertHistoryEmpty(t *testing.T) {
t.Parallel()
h := notify.NewAlertHistory()
entries := h.Recent()
if len(entries) != 0 {
t.Fatalf("expected 0 entries, got %d", len(entries))
}
}
func TestAlertHistoryAddAndRecent(t *testing.T) {
t.Parallel()
h := notify.NewAlertHistory()
now := time.Now().UTC()
h.Add(notify.AlertEntry{
Timestamp: now.Add(-2 * time.Minute),
Title: "first",
Message: "msg1",
Priority: "info",
})
h.Add(notify.AlertEntry{
Timestamp: now.Add(-1 * time.Minute),
Title: "second",
Message: "msg2",
Priority: "warning",
})
entries := h.Recent()
if len(entries) != 2 {
t.Fatalf("expected 2 entries, got %d", len(entries))
}
// Newest first.
if entries[0].Title != "second" {
t.Errorf(
"expected newest first, got %q", entries[0].Title,
)
}
if entries[1].Title != "first" {
t.Errorf(
"expected oldest second, got %q", entries[1].Title,
)
}
}
func TestAlertHistoryOverflow(t *testing.T) {
t.Parallel()
h := notify.NewAlertHistory()
const totalEntries = 110
// Fill beyond capacity.
for i := range totalEntries {
h.Add(notify.AlertEntry{
Timestamp: time.Now().UTC(),
Title: "alert",
Message: "msg",
Priority: string(rune('0' + i%10)),
})
}
entries := h.Recent()
const maxHistory = 100
if len(entries) != maxHistory {
t.Fatalf(
"expected %d entries, got %d",
maxHistory, len(entries),
)
}
}

View File

@@ -1,4 +1,5 @@
// Package notify provides notification delivery to Slack, Mattermost, and ntfy.
// Package notify provides notification delivery to Slack,
// Mattermost, and ntfy.
package notify
import (
@@ -7,8 +8,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"time"
"go.uber.org/fx"
@@ -33,8 +36,66 @@ var (
ErrMattermostFailed = errors.New(
"mattermost notification failed",
)
// ErrInvalidScheme is returned for disallowed URL schemes.
ErrInvalidScheme = errors.New("URL scheme not allowed")
// ErrMissingHost is returned when a URL has no host.
ErrMissingHost = errors.New("URL must have a host")
)
// IsAllowedScheme checks if the URL scheme is permitted.
func IsAllowedScheme(scheme string) bool {
return scheme == "https" || scheme == "http"
}
// ValidateWebhookURL validates and sanitizes a webhook URL.
// It ensures the URL has an allowed scheme (http/https),
// a non-empty host, and returns a pre-parsed *url.URL
// reconstructed from validated components.
func ValidateWebhookURL(raw string) (*url.URL, error) {
u, err := url.ParseRequestURI(raw)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if !IsAllowedScheme(u.Scheme) {
return nil, fmt.Errorf(
"%w: %s", ErrInvalidScheme, u.Scheme,
)
}
if u.Host == "" {
return nil, fmt.Errorf("%w", ErrMissingHost)
}
// Reconstruct from parsed components.
clean := &url.URL{
Scheme: u.Scheme,
Host: u.Host,
Path: u.Path,
RawQuery: u.RawQuery,
}
return clean, nil
}
// newRequest creates an http.Request from a pre-validated *url.URL.
// This avoids passing URL strings to http.NewRequestWithContext,
// which gosec flags as a potential SSRF vector.
func newRequest(
ctx context.Context,
method string,
target *url.URL,
body io.Reader,
) *http.Request {
return (&http.Request{
Method: method,
URL: target,
Host: target.Host,
Header: make(http.Header),
Body: io.NopCloser(body),
}).WithContext(ctx)
}
// Params contains dependencies for Service.
type Params struct {
fx.In
@@ -45,9 +106,15 @@ type Params struct {
// Service provides notification functionality.
type Service struct {
log *slog.Logger
client *http.Client
config *config.Config
log *slog.Logger
transport http.RoundTripper
config *config.Config
ntfyURL *url.URL
slackWebhookURL *url.URL
mattermostWebhookURL *url.URL
history *AlertHistory
retryConfig RetryConfig
sleepFn func(time.Duration) <-chan time.Time
}
// New creates a new notify Service.
@@ -55,99 +122,193 @@ func New(
_ fx.Lifecycle,
params Params,
) (*Service, error) {
return &Service{
log: params.Logger.Get(),
client: &http.Client{
Timeout: httpClientTimeout,
},
config: params.Config,
}, nil
svc := &Service{
log: params.Logger.Get(),
transport: http.DefaultTransport,
config: params.Config,
history: NewAlertHistory(),
}
if params.Config.NtfyTopic != "" {
u, err := ValidateWebhookURL(
params.Config.NtfyTopic,
)
if err != nil {
return nil, fmt.Errorf(
"invalid ntfy topic URL: %w", err,
)
}
svc.ntfyURL = u
}
if params.Config.SlackWebhook != "" {
u, err := ValidateWebhookURL(
params.Config.SlackWebhook,
)
if err != nil {
return nil, fmt.Errorf(
"invalid slack webhook URL: %w", err,
)
}
svc.slackWebhookURL = u
}
if params.Config.MattermostWebhook != "" {
u, err := ValidateWebhookURL(
params.Config.MattermostWebhook,
)
if err != nil {
return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err,
)
}
svc.mattermostWebhookURL = u
}
return svc, nil
}
// SendNotification sends a notification to all configured endpoints.
// History returns the alert history for reading recent alerts.
func (svc *Service) History() *AlertHistory {
return svc.history
}
// SendNotification sends a notification to all configured
// endpoints and records it in the alert history.
func (svc *Service) SendNotification(
ctx context.Context,
title, message, priority string,
) {
if svc.config.NtfyTopic != "" {
go func() {
notifyCtx := context.WithoutCancel(ctx)
svc.history.Add(AlertEntry{
Timestamp: time.Now().UTC(),
Title: title,
Message: message,
Priority: priority,
})
err := svc.sendNtfy(
notifyCtx,
svc.config.NtfyTopic,
title, message, priority,
)
if err != nil {
svc.log.Error(
"failed to send ntfy notification",
"error", err,
)
}
}()
svc.dispatchNtfy(ctx, title, message, priority)
svc.dispatchSlack(ctx, title, message, priority)
svc.dispatchMattermost(ctx, title, message, priority)
}
func (svc *Service) dispatchNtfy(
ctx context.Context,
title, message, priority string,
) {
if svc.ntfyURL == nil {
return
}
if svc.config.SlackWebhook != "" {
go func() {
notifyCtx := context.WithoutCancel(ctx)
go func() {
notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack(
notifyCtx,
svc.config.SlackWebhook,
title, message, priority,
)
if err != nil {
svc.log.Error(
"failed to send slack notification",
"error", err,
err := svc.deliverWithRetry(
notifyCtx, "ntfy",
func(c context.Context) error {
return svc.sendNtfy(
c, svc.ntfyURL,
title, message, priority,
)
}
}()
},
)
if err != nil {
svc.log.Error(
"failed to send ntfy notification "+
"after retries",
"error", err,
)
}
}()
}
func (svc *Service) dispatchSlack(
ctx context.Context,
title, message, priority string,
) {
if svc.slackWebhookURL == nil {
return
}
if svc.config.MattermostWebhook != "" {
go func() {
notifyCtx := context.WithoutCancel(ctx)
go func() {
notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack(
notifyCtx,
svc.config.MattermostWebhook,
title, message, priority,
)
if err != nil {
svc.log.Error(
"failed to send mattermost notification",
"error", err,
err := svc.deliverWithRetry(
notifyCtx, "slack",
func(c context.Context) error {
return svc.sendSlack(
c, svc.slackWebhookURL,
title, message, priority,
)
}
}()
},
)
if err != nil {
svc.log.Error(
"failed to send slack notification "+
"after retries",
"error", err,
)
}
}()
}
func (svc *Service) dispatchMattermost(
ctx context.Context,
title, message, priority string,
) {
if svc.mattermostWebhookURL == nil {
return
}
go func() {
notifyCtx := context.WithoutCancel(ctx)
err := svc.deliverWithRetry(
notifyCtx, "mattermost",
func(c context.Context) error {
return svc.sendSlack(
c, svc.mattermostWebhookURL,
title, message, priority,
)
},
)
if err != nil {
svc.log.Error(
"failed to send mattermost notification "+
"after retries",
"error", err,
)
}
}()
}
func (svc *Service) sendNtfy(
ctx context.Context,
topic, title, message, priority string,
topicURL *url.URL,
title, message, priority string,
) error {
svc.log.Debug(
"sending ntfy notification",
"topic", topic,
"topic", topicURL.String(),
"title", title,
)
request, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
topic,
bytes.NewBufferString(message),
ctx, cancel := context.WithTimeout(
ctx, httpClientTimeout,
)
defer cancel()
body := bytes.NewBufferString(message)
request := newRequest(
ctx, http.MethodPost, topicURL, body,
)
if err != nil {
return fmt.Errorf("creating ntfy request: %w", err)
}
request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.client.Do(request)
resp, err := svc.transport.RoundTrip(request)
if err != nil {
return fmt.Errorf("sending ntfy request: %w", err)
}
@@ -156,7 +317,8 @@ func (svc *Service) sendNtfy(
if resp.StatusCode >= httpStatusClientError {
return fmt.Errorf(
"%w: status %d", ErrNtfyFailed, resp.StatusCode,
"%w: status %d",
ErrNtfyFailed, resp.StatusCode,
)
}
@@ -193,11 +355,17 @@ type SlackAttachment struct {
func (svc *Service) sendSlack(
ctx context.Context,
webhookURL, title, message, priority string,
webhookURL *url.URL,
title, message, priority string,
) error {
ctx, cancel := context.WithTimeout(
ctx, httpClientTimeout,
)
defer cancel()
svc.log.Debug(
"sending webhook notification",
"url", webhookURL,
"url", webhookURL.String(),
"title", title,
)
@@ -213,22 +381,19 @@ func (svc *Service) sendSlack(
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshaling webhook payload: %w", err)
return fmt.Errorf(
"marshaling webhook payload: %w", err,
)
}
request, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
webhookURL,
request := newRequest(
ctx, http.MethodPost, webhookURL,
bytes.NewBuffer(body),
)
if err != nil {
return fmt.Errorf("creating webhook request: %w", err)
}
request.Header.Set("Content-Type", "application/json")
resp, err := svc.client.Do(request)
resp, err := svc.transport.RoundTrip(request)
if err != nil {
return fmt.Errorf("sending webhook request: %w", err)
}

View File

@@ -0,0 +1,100 @@
package notify_test
import (
"testing"
"sneak.berlin/go/dnswatcher/internal/notify"
)
func TestValidateWebhookURLValid(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantURL string
}{
{
name: "valid https URL",
input: "https://hooks.slack.com/T00/B00",
wantURL: "https://hooks.slack.com/T00/B00",
},
{
name: "valid http URL",
input: "http://localhost:8080/webhook",
wantURL: "http://localhost:8080/webhook",
},
{
name: "https with query",
input: "https://ntfy.sh/topic?auth=tok",
wantURL: "https://ntfy.sh/topic?auth=tok",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := notify.ValidateWebhookURL(tt.input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got.String() != tt.wantURL {
t.Errorf(
"got %q, want %q",
got.String(), tt.wantURL,
)
}
})
}
}
func TestValidateWebhookURLInvalid(t *testing.T) {
t.Parallel()
invalid := []struct {
name string
input string
}{
{"ftp scheme", "ftp://example.com/file"},
{"file scheme", "file:///etc/passwd"},
{"empty string", ""},
{"no scheme", "example.com/webhook"},
{"no host", "https:///path"},
}
for _, tt := range invalid {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := notify.ValidateWebhookURL(tt.input)
if err == nil {
t.Errorf(
"expected error for %q, got %v",
tt.input, got,
)
}
})
}
}
func TestIsAllowedScheme(t *testing.T) {
t.Parallel()
if !notify.IsAllowedScheme("https") {
t.Error("https should be allowed")
}
if !notify.IsAllowedScheme("http") {
t.Error("http should be allowed")
}
if notify.IsAllowedScheme("ftp") {
t.Error("ftp should not be allowed")
}
if notify.IsAllowedScheme("") {
t.Error("empty scheme should not be allowed")
}
}

139
internal/notify/retry.go Normal file
View File

@@ -0,0 +1,139 @@
package notify
import (
"context"
"math"
"math/rand/v2"
"time"
)
// Retry defaults.
const (
// DefaultMaxRetries is the number of additional attempts
// after the first failure.
DefaultMaxRetries = 3
// DefaultBaseDelay is the initial delay before the first
// retry attempt.
DefaultBaseDelay = 1 * time.Second
// DefaultMaxDelay caps the computed backoff delay.
DefaultMaxDelay = 10 * time.Second
// backoffMultiplier is the exponential growth factor.
backoffMultiplier = 2
// jitterFraction controls the ±random spread applied
// to each delay (0.25 = ±25%).
jitterFraction = 0.25
)
// RetryConfig holds tuning knobs for the retry loop.
// Zero values fall back to the package defaults above.
type RetryConfig struct {
MaxRetries int
BaseDelay time.Duration
MaxDelay time.Duration
}
// defaults returns a copy with zero fields replaced by
// package defaults.
func (rc RetryConfig) defaults() RetryConfig {
if rc.MaxRetries <= 0 {
rc.MaxRetries = DefaultMaxRetries
}
if rc.BaseDelay <= 0 {
rc.BaseDelay = DefaultBaseDelay
}
if rc.MaxDelay <= 0 {
rc.MaxDelay = DefaultMaxDelay
}
return rc
}
// backoff computes the delay for attempt n (0-indexed) with
// jitter. The raw delay is BaseDelay * 2^n, capped at
// MaxDelay, then randomised by ±jitterFraction.
func (rc RetryConfig) backoff(attempt int) time.Duration {
raw := float64(rc.BaseDelay) *
math.Pow(backoffMultiplier, float64(attempt))
if raw > float64(rc.MaxDelay) {
raw = float64(rc.MaxDelay)
}
// Apply jitter: uniform in [raw*(1-j), raw*(1+j)].
lo := raw * (1 - jitterFraction)
hi := raw * (1 + jitterFraction)
jittered := lo + rand.Float64()*(hi-lo) //nolint:gosec // jitter does not need crypto/rand
return time.Duration(jittered)
}
// deliverWithRetry calls fn, retrying on error with
// exponential backoff. It logs every failed attempt and
// returns the last error if all attempts are exhausted.
func (svc *Service) deliverWithRetry(
ctx context.Context,
endpoint string,
fn func(context.Context) error,
) error {
cfg := svc.retryConfig.defaults()
var lastErr error
// attempt 0 is the initial call; attempts 1..MaxRetries
// are retries.
for attempt := range cfg.MaxRetries + 1 {
lastErr = fn(ctx)
if lastErr == nil {
if attempt > 0 {
svc.log.Info(
"notification delivered after retry",
"endpoint", endpoint,
"attempt", attempt+1,
)
}
return nil
}
// Last attempt — don't sleep, just return.
if attempt == cfg.MaxRetries {
break
}
delay := cfg.backoff(attempt)
svc.log.Warn(
"notification delivery failed, retrying",
"endpoint", endpoint,
"attempt", attempt+1,
"maxAttempts", cfg.MaxRetries+1,
"retryIn", delay,
"error", lastErr,
)
select {
case <-ctx.Done():
return ctx.Err()
case <-svc.sleepFunc(delay):
}
}
return lastErr
}
// sleepFunc returns a channel that closes after d.
// It is a field-level indirection so tests can override it.
func (svc *Service) sleepFunc(d time.Duration) <-chan time.Time {
if svc.sleepFn != nil {
return svc.sleepFn(d)
}
return time.After(d)
}

View File

@@ -0,0 +1,493 @@
package notify_test
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"sync/atomic"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/notify"
)
// Static test errors (err113).
var (
errTransient = errors.New("transient failure")
errPermanent = errors.New("permanent failure")
errFail = errors.New("fail")
)
// instantSleep returns a closed channel immediately, removing
// real delays from tests.
func instantSleep(_ time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
ch <- time.Now()
return ch
}
// ── backoff calculation ───────────────────────────────────
func TestBackoffDurationIncreases(t *testing.T) {
t.Parallel()
cfg := notify.RetryConfig{
MaxRetries: 5,
BaseDelay: 1 * time.Second,
MaxDelay: 30 * time.Second,
}
prev := time.Duration(0)
// With jitter the exact value varies, but the trend
// should be increasing for the first few attempts.
for attempt := range 4 {
d := cfg.BackoffDuration(attempt)
if d <= 0 {
t.Fatalf(
"attempt %d: backoff must be positive, got %v",
attempt, d,
)
}
// Allow jitter to occasionally flatten a step, but
// the midpoint (no-jitter) should be strictly higher.
midpoint := cfg.BaseDelay * (1 << attempt)
if attempt > 0 && midpoint <= prev {
t.Fatalf(
"midpoint should grow: attempt %d midpoint=%v prev=%v",
attempt, midpoint, prev,
)
}
prev = midpoint
}
}
func TestBackoffDurationCappedAtMax(t *testing.T) {
t.Parallel()
cfg := notify.RetryConfig{
MaxRetries: 5,
BaseDelay: 1 * time.Second,
MaxDelay: 5 * time.Second,
}
// Attempt 10 would be 1024s without capping.
d := cfg.BackoffDuration(10)
// With ±25% jitter on a 5s cap: max is 6.25s.
const maxWithJitter = 5*time.Second +
5*time.Second/4 +
time.Millisecond // rounding margin
if d > maxWithJitter {
t.Errorf(
"backoff %v exceeds max+jitter %v",
d, maxWithJitter,
)
}
}
// ── deliverWithRetry ──────────────────────────────────────
func TestDeliverWithRetrySucceedsFirstAttempt(t *testing.T) {
t.Parallel()
svc := notify.NewTestService(http.DefaultTransport)
svc.SetSleepFunc(instantSleep)
var calls atomic.Int32
err := svc.DeliverWithRetry(
context.Background(), "test",
func(_ context.Context) error {
calls.Add(1)
return nil
},
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls.Load() != 1 {
t.Errorf("expected 1 call, got %d", calls.Load())
}
}
func TestDeliverWithRetryRetriesOnFailure(t *testing.T) {
t.Parallel()
svc := notify.NewTestService(http.DefaultTransport)
svc.SetSleepFunc(instantSleep)
svc.SetRetryConfig(notify.RetryConfig{
MaxRetries: 3,
BaseDelay: time.Millisecond,
MaxDelay: 10 * time.Millisecond,
})
var calls atomic.Int32
// Fail twice, then succeed on the third attempt.
err := svc.DeliverWithRetry(
context.Background(), "test",
func(_ context.Context) error {
n := calls.Add(1)
if n <= 2 {
return errTransient
}
return nil
},
)
if err != nil {
t.Fatalf("expected success after retries: %v", err)
}
if calls.Load() != 3 {
t.Errorf("expected 3 calls, got %d", calls.Load())
}
}
func TestDeliverWithRetryExhaustsAttempts(t *testing.T) {
t.Parallel()
svc := notify.NewTestService(http.DefaultTransport)
svc.SetSleepFunc(instantSleep)
svc.SetRetryConfig(notify.RetryConfig{
MaxRetries: 2,
BaseDelay: time.Millisecond,
MaxDelay: 10 * time.Millisecond,
})
var calls atomic.Int32
err := svc.DeliverWithRetry(
context.Background(), "test",
func(_ context.Context) error {
calls.Add(1)
return errPermanent
},
)
if err == nil {
t.Fatal("expected error when all retries exhausted")
}
if !errors.Is(err, errPermanent) {
t.Errorf("expected permanent failure, got: %v", err)
}
// 1 initial + 2 retries = 3 total.
if calls.Load() != 3 {
t.Errorf("expected 3 calls, got %d", calls.Load())
}
}
func TestDeliverWithRetryRespectsContextCancellation(
t *testing.T,
) {
t.Parallel()
svc := notify.NewTestService(http.DefaultTransport)
svc.SetRetryConfig(notify.RetryConfig{
MaxRetries: 5,
BaseDelay: time.Millisecond,
MaxDelay: 10 * time.Millisecond,
})
// Use a blocking sleep so the context cancellation is
// the only way out.
svc.SetSleepFunc(func(_ time.Duration) <-chan time.Time {
return make(chan time.Time) // never fires
})
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- svc.DeliverWithRetry(
ctx, "test",
func(_ context.Context) error {
return errFail
},
)
}()
// Wait for the first failure + retry sleep to be
// entered, then cancel.
time.Sleep(50 * time.Millisecond)
cancel()
select {
case err := <-done:
if !errors.Is(err, context.Canceled) {
t.Errorf(
"expected context.Canceled, got: %v", err,
)
}
case <-time.After(2 * time.Second):
t.Fatal("deliverWithRetry did not return after cancel")
}
}
// ── integration: SendNotification with retry ──────────────
func TestSendNotificationRetriesTransientFailure(
t *testing.T,
) {
t.Parallel()
var (
mu sync.Mutex
attempts int
)
srv := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
attempts++
n := attempts
mu.Unlock()
if n <= 2 {
w.WriteHeader(
http.StatusInternalServerError,
)
return
}
w.WriteHeader(http.StatusOK)
}),
)
defer srv.Close()
svc := newRetryTestService(srv.URL, "ntfy")
svc.SendNotification(
context.Background(),
"Retry Test", "body", "warning",
)
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return attempts >= 3
})
}
// newRetryTestService creates a test service with instant
// sleep and low retry delays for the named endpoint.
func newRetryTestService(
rawURL, endpoint string,
) *notify.Service {
svc := notify.NewTestService(http.DefaultTransport)
svc.SetSleepFunc(instantSleep)
svc.SetRetryConfig(notify.RetryConfig{
MaxRetries: 3,
BaseDelay: time.Millisecond,
MaxDelay: 10 * time.Millisecond,
})
u, _ := url.Parse(rawURL)
switch endpoint {
case "ntfy":
svc.SetNtfyURL(u)
case "slack":
svc.SetSlackWebhookURL(u)
case "mattermost":
svc.SetMattermostWebhookURL(u)
}
return svc
}
func TestSendNotificationAllEndpointsRetrySetup(
t *testing.T,
) {
t.Parallel()
result := newEndpointRetryResult()
ntfySrv, slackSrv, mmSrv := newRetryServers(result)
defer ntfySrv.Close()
defer slackSrv.Close()
defer mmSrv.Close()
svc := buildAllEndpointRetryService(
ntfySrv.URL, slackSrv.URL, mmSrv.URL,
)
svc.SendNotification(
context.Background(),
"Multi-Retry", "testing", "error",
)
assertAllEndpointsRetried(t, result)
}
// endpointRetryResult tracks per-endpoint retry state.
type endpointRetryResult struct {
mu sync.Mutex
ntfyAttempts int
slackAttempts int
mmAttempts int
ntfyOK bool
slackOK bool
mmOK bool
}
func newEndpointRetryResult() *endpointRetryResult {
return &endpointRetryResult{}
}
func newRetryServers(
r *endpointRetryResult,
) (*httptest.Server, *httptest.Server, *httptest.Server) {
mk := func(
attempts *int, ok *bool,
) *httptest.Server {
return httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
r.mu.Lock()
*attempts++
n := *attempts
r.mu.Unlock()
if n == 1 {
w.WriteHeader(
http.StatusServiceUnavailable,
)
return
}
r.mu.Lock()
*ok = true
r.mu.Unlock()
w.WriteHeader(http.StatusOK)
}),
)
}
return mk(&r.ntfyAttempts, &r.ntfyOK),
mk(&r.slackAttempts, &r.slackOK),
mk(&r.mmAttempts, &r.mmOK)
}
func buildAllEndpointRetryService(
ntfyURL, slackURL, mmURL string,
) *notify.Service {
svc := notify.NewTestService(http.DefaultTransport)
svc.SetSleepFunc(instantSleep)
svc.SetRetryConfig(notify.RetryConfig{
MaxRetries: 3,
BaseDelay: time.Millisecond,
MaxDelay: 10 * time.Millisecond,
})
nu, _ := url.Parse(ntfyURL)
su, _ := url.Parse(slackURL)
mu, _ := url.Parse(mmURL)
svc.SetNtfyURL(nu)
svc.SetSlackWebhookURL(su)
svc.SetMattermostWebhookURL(mu)
return svc
}
func assertAllEndpointsRetried(
t *testing.T,
r *endpointRetryResult,
) {
t.Helper()
waitForCondition(t, func() bool {
r.mu.Lock()
defer r.mu.Unlock()
return r.ntfyOK && r.slackOK && r.mmOK
})
r.mu.Lock()
defer r.mu.Unlock()
if r.ntfyAttempts < 2 {
t.Errorf(
"ntfy: expected >= 2 attempts, got %d",
r.ntfyAttempts,
)
}
if r.slackAttempts < 2 {
t.Errorf(
"slack: expected >= 2 attempts, got %d",
r.slackAttempts,
)
}
if r.mmAttempts < 2 {
t.Errorf(
"mattermost: expected >= 2 attempts, got %d",
r.mmAttempts,
)
}
}
func TestSendNotificationPermanentFailureLogsError(
t *testing.T,
) {
t.Parallel()
var (
mu sync.Mutex
attempts int
)
srv := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
attempts++
mu.Unlock()
w.WriteHeader(
http.StatusInternalServerError,
)
}),
)
defer srv.Close()
svc := newRetryTestService(srv.URL, "slack")
svc.SetRetryConfig(notify.RetryConfig{
MaxRetries: 2,
BaseDelay: time.Millisecond,
MaxDelay: 10 * time.Millisecond,
})
svc.SendNotification(
context.Background(),
"Permanent Fail", "body", "error",
)
// 1 initial + 2 retries = 3 total.
waitForCondition(t, func() bool {
mu.Lock()
defer mu.Unlock()
return attempts >= 3
})
}

View File

@@ -4,18 +4,39 @@ package portcheck
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"sync"
"time"
"go.uber.org/fx"
"golang.org/x/sync/errgroup"
"sneak.berlin/go/dnswatcher/internal/logger"
)
// ErrNotImplemented indicates the port checker is not yet implemented.
var ErrNotImplemented = errors.New(
"port checker not yet implemented",
const (
minPort = 1
maxPort = 65535
defaultTimeout = 5 * time.Second
)
// ErrInvalidPort is returned when a port number is outside
// the valid TCP range (165535).
var ErrInvalidPort = errors.New("invalid port number")
// PortResult holds the outcome of a single TCP port check.
type PortResult struct {
// Open indicates whether the port accepted a connection.
Open bool
// Error contains a description if the connection failed.
Error string
// Latency is the time taken for the TCP handshake.
Latency time.Duration
}
// Params contains dependencies for Checker.
type Params struct {
fx.In
@@ -38,11 +59,145 @@ func New(
}, nil
}
// CheckPort tests TCP connectivity to the given address and port.
func (c *Checker) CheckPort(
_ context.Context,
_ string,
_ int,
) (bool, error) {
return false, ErrNotImplemented
// NewStandalone creates a Checker without fx dependencies.
func NewStandalone() *Checker {
return &Checker{
log: slog.Default(),
}
}
// validatePort checks that a port number is within the valid
// TCP port range (165535).
func validatePort(port int) error {
if port < minPort || port > maxPort {
return fmt.Errorf(
"%w: %d (must be between %d and %d)",
ErrInvalidPort, port, minPort, maxPort,
)
}
return nil
}
// CheckPort tests TCP connectivity to the given address and port.
// It uses a 5-second timeout unless the context has an earlier
// deadline.
func (c *Checker) CheckPort(
ctx context.Context,
address string,
port int,
) (*PortResult, error) {
err := validatePort(port)
if err != nil {
return nil, err
}
target := net.JoinHostPort(
address, strconv.Itoa(port),
)
timeout := defaultTimeout
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
remaining := time.Until(deadline)
if remaining < timeout {
timeout = remaining
}
}
return c.checkConnection(ctx, target, timeout), nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address concurrently. It returns a map of port number
// to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
for _, port := range ports {
err := validatePort(port)
if err != nil {
return nil, err
}
}
var mu sync.Mutex
results := make(map[int]*PortResult, len(ports))
g, ctx := errgroup.WithContext(ctx)
for _, port := range ports {
g.Go(func() error {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return fmt.Errorf(
"checking port %d: %w", port, err,
)
}
mu.Lock()
results[port] = result
mu.Unlock()
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return results, nil
}
// checkConnection performs the TCP dial and returns a result.
func (c *Checker) checkConnection(
ctx context.Context,
target string,
timeout time.Duration,
) *PortResult {
dialer := &net.Dialer{Timeout: timeout}
start := time.Now()
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
latency := time.Since(start)
if dialErr != nil {
c.log.Debug(
"port check failed",
"target", target,
"error", dialErr.Error(),
)
return &PortResult{
Open: false,
Error: dialErr.Error(),
Latency: latency,
}
}
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing connection",
"target", target,
"error", closeErr.Error(),
)
}
c.log.Debug(
"port check succeeded",
"target", target,
"latency", latency,
)
return &PortResult{
Open: true,
Latency: latency,
}
}

View File

@@ -0,0 +1,211 @@
package portcheck_test
import (
"context"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/portcheck"
)
func listenTCP(
t *testing.T,
) (net.Listener, int) {
t.Helper()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to start listener: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return ln, addr.Port
}
func TestCheckPortOpen(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
defer func() { _ = ln.Close() }()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !result.Open {
t.Error("expected port to be open")
}
if result.Error != "" {
t.Errorf("expected no error, got: %s", result.Error)
}
if result.Latency <= 0 {
t.Error("expected positive latency")
}
}
func TestCheckPortClosed(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
_ = ln.Close()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Open {
t.Error("expected port to be closed")
}
if result.Error == "" {
t.Error("expected error message for closed port")
}
}
func TestCheckPortContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(ctx, "127.0.0.1", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Open {
t.Error("expected port to not be open")
}
}
func TestCheckPortsMultiple(t *testing.T) {
t.Parallel()
ln, openPort := listenTCP(t)
defer func() { _ = ln.Close() }()
ln2, closedPort := listenTCP(t)
_ = ln2.Close()
checker := portcheck.NewStandalone()
results, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{openPort, closedPort},
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 2 {
t.Fatalf(
"expected 2 results, got %d", len(results),
)
}
if !results[openPort].Open {
t.Error("expected open port to be open")
}
if results[closedPort].Open {
t.Error("expected closed port to be closed")
}
}
func TestCheckPortInvalidPorts(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
cases := []struct {
name string
port int
}{
{"zero", 0},
{"negative", -1},
{"too high", 65536},
{"very negative", -1000},
{"very high", 100000},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := checker.CheckPort(
context.Background(), "127.0.0.1", tc.port,
)
if err == nil {
t.Errorf(
"expected error for port %d, got nil",
tc.port,
)
}
})
}
}
func TestCheckPortsInvalidPort(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
_, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{80, 0, 443},
)
if err == nil {
t.Error("expected error for invalid port in list")
}
}
func TestCheckPortLatencyReasonable(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
defer func() { _ = ln.Close() }()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Latency > time.Second {
t.Errorf(
"latency too high for localhost: %v",
result.Latency,
)
}
}

View File

@@ -0,0 +1,48 @@
package resolver
import (
"context"
"time"
"github.com/miekg/dns"
)
// DNSClient abstracts DNS wire-protocol exchanges so the resolver
// can be tested without hitting real nameservers.
type DNSClient interface {
ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error)
}
// udpClient wraps a real dns.Client for production use.
type udpClient struct {
timeout time.Duration
}
func (c *udpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}
// tcpClient wraps a real dns.Client using TCP.
type tcpClient struct {
timeout time.Duration
}
func (c *tcpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Net: "tcp", Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}

View File

@@ -0,0 +1,22 @@
package resolver
import "errors"
// Sentinel errors returned by the resolver.
var (
// ErrNoNameservers is returned when no authoritative NS
// could be discovered for a domain.
ErrNoNameservers = errors.New(
"no authoritative nameservers found",
)
// ErrCNAMEDepthExceeded is returned when a CNAME chain
// exceeds MaxCNAMEDepth.
ErrCNAMEDepthExceeded = errors.New(
"CNAME chain depth exceeded",
)
// ErrContextCanceled wraps context cancellation for the
// resolver's iterative queries.
ErrContextCanceled = errors.New("context canceled")
)

View File

@@ -0,0 +1,803 @@
package resolver
import (
"context"
"errors"
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/miekg/dns"
)
const (
queryTimeoutDuration = 2 * time.Second
maxRetries = 2
maxDelegation = 20
timeoutMultiplier = 2
minDomainLabels = 2
)
// ErrRefused is returned when a DNS server refuses a query.
var ErrRefused = errors.New("dns query refused")
func rootServerList() []string {
return []string{
"198.41.0.4", // a.root-servers.net
"170.247.170.2", // b
"192.33.4.12", // c
"199.7.91.13", // d
"192.203.230.10", // e
"192.5.5.241", // f
"192.112.36.4", // g
"198.97.190.53", // h
"192.36.148.17", // i
"192.58.128.30", // j
"193.0.14.129", // k
"199.7.83.42", // l
"202.12.27.33", // m
}
}
func checkCtx(ctx context.Context) error {
err := ctx.Err()
if err != nil {
return ErrContextCanceled
}
return nil
}
func (r *Resolver) exchangeWithTimeout(
ctx context.Context,
msg *dns.Msg,
addr string,
attempt int,
) (*dns.Msg, error) {
_ = attempt // timeout escalation handled by client config
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
return resp, err
}
func (r *Resolver) tryExchange(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, error) {
var resp *dns.Msg
var err error
for attempt := range maxRetries {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err = r.exchangeWithTimeout(
ctx, msg, addr, attempt,
)
if err == nil {
break
}
}
return resp, err
}
func (r *Resolver) retryTCP(
ctx context.Context,
msg *dns.Msg,
addr string,
resp *dns.Msg,
) *dns.Msg {
if !resp.Truncated {
return resp
}
tcpResp, _, tcpErr := r.tcp.ExchangeContext(ctx, msg, addr)
if tcpErr == nil {
return tcpResp
}
return resp
}
// queryDNS sends a DNS query to a specific server IP.
// Tries non-recursive first, falls back to recursive on
// REFUSED (handles DNS interception environments).
func (r *Resolver) queryDNS(
ctx context.Context,
serverIP string,
name string,
qtype uint16,
) (*dns.Msg, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
name = dns.Fqdn(name)
addr := net.JoinHostPort(serverIP, "53")
msg := new(dns.Msg)
msg.SetQuestion(name, qtype)
msg.RecursionDesired = false
resp, err := r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err)
}
if resp.Rcode == dns.RcodeRefused {
msg.RecursionDesired = true
resp, err = r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, err,
)
}
if resp.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, ErrRefused,
)
}
}
resp = r.retryTCP(ctx, msg, addr, resp)
return resp, nil
}
func extractNSSet(rrs []dns.RR) []string {
nsSet := make(map[string]bool)
for _, rr := range rrs {
if ns, ok := rr.(*dns.NS); ok {
nsSet[strings.ToLower(ns.Ns)] = true
}
}
names := make([]string, 0, len(nsSet))
for n := range nsSet {
names = append(names, n)
}
sort.Strings(names)
return names
}
func extractGlue(rrs []dns.RR) map[string][]net.IP {
glue := make(map[string][]net.IP)
for _, rr := range rrs {
switch r := rr.(type) {
case *dns.A:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.A)
case *dns.AAAA:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.AAAA)
}
}
return glue
}
func glueIPs(nsNames []string, glue map[string][]net.IP) []string {
var ips []string
for _, ns := range nsNames {
for _, addr := range glue[ns] {
if v4 := addr.To4(); v4 != nil {
ips = append(ips, v4.String())
}
}
}
return ips
}
func (r *Resolver) followDelegation(
ctx context.Context,
domain string,
servers []string,
) ([]string, error) {
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryServers(
ctx, servers, domain, dns.TypeNS,
)
if err != nil {
return nil, err
}
ansNS := extractNSSet(resp.Answer)
if len(ansNS) > 0 {
return ansNS, nil
}
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
return r.resolveNSIterative(ctx, domain)
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
nextServers = r.resolveNSIPs(ctx, authNS)
}
if len(nextServers) == 0 {
return nil, ErrNoNameservers
}
servers = nextServers
}
return nil, ErrNoNameservers
}
func (r *Resolver) queryServers(
ctx context.Context,
servers []string,
name string,
qtype uint16,
) (*dns.Msg, error) {
var lastErr error
for _, ip := range servers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryDNS(ctx, ip, name, qtype)
if err == nil {
return resp, nil
}
lastErr = err
}
return nil, fmt.Errorf("all servers failed: %w", lastErr)
}
func (r *Resolver) resolveNSIPs(
ctx context.Context,
nsNames []string,
) []string {
var ips []string
for _, ns := range nsNames {
resolved, err := r.resolveARecord(ctx, ns)
if err == nil {
ips = append(ips, resolved...)
}
if len(ips) > 0 {
break
}
}
return ips
}
// resolveNSIterative queries for NS records using iterative
// resolution as a fallback when followDelegation finds no
// authoritative answer in the delegation chain.
func (r *Resolver) resolveNSIterative(
ctx context.Context,
domain string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
domain = dns.Fqdn(domain)
servers := rootServerList()
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryServers(
ctx, servers, domain, dns.TypeNS,
)
if err != nil {
return nil, err
}
nsNames := extractNSSet(resp.Answer)
if len(nsNames) > 0 {
return nsNames, nil
}
// Follow delegation.
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
break
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
break
}
servers = nextServers
}
return nil, ErrNoNameservers
}
// resolveARecord resolves a hostname to IPv4 addresses using
// iterative resolution through the delegation chain.
func (r *Resolver) resolveARecord(
ctx context.Context,
hostname string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
hostname = dns.Fqdn(hostname)
servers := rootServerList()
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryServers(
ctx, servers, hostname, dns.TypeA,
)
if err != nil {
return nil, fmt.Errorf(
"resolving %s: %w", hostname, err,
)
}
// Check for A records in the answer section.
var ips []string
for _, rr := range resp.Answer {
if a, ok := rr.(*dns.A); ok {
ips = append(ips, a.A.String())
}
}
if len(ips) > 0 {
return ips, nil
}
// Follow delegation if present.
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
break
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
// Resolve NS IPs iteratively — but guard
// against infinite recursion by using only
// already-resolved servers.
break
}
servers = nextServers
}
return nil, fmt.Errorf(
"cannot resolve %s: %w", hostname, ErrNoNameservers,
)
}
// FindAuthoritativeNameservers traces the delegation chain from
// root servers to discover all authoritative nameservers for the
// given domain. Walks up the label hierarchy for subdomains.
func (r *Resolver) FindAuthoritativeNameservers(
ctx context.Context,
domain string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
domain = dns.Fqdn(strings.ToLower(domain))
labels := dns.SplitDomainName(domain)
for i := range labels {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
candidate := strings.Join(labels[i:], ".") + "."
nsNames, err := r.followDelegation(
ctx, candidate, rootServerList(),
)
if err == nil && len(nsNames) > 0 {
sort.Strings(nsNames)
return nsNames, nil
}
}
return nil, ErrNoNameservers
}
// QueryNameserver queries a specific nameserver for all record
// types and builds a NameserverResponse.
func (r *Resolver) QueryNameserver(
ctx context.Context,
nsHostname string,
hostname string,
) (*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
nsIPs, err := r.resolveARecord(ctx, nsHostname)
if err != nil {
return nil, fmt.Errorf("resolving NS %s: %w", nsHostname, err)
}
hostname = dns.Fqdn(hostname)
return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname)
}
// QueryNameserverIP queries a nameserver by its IP address directly,
// bypassing NS hostname resolution.
func (r *Resolver) QueryNameserverIP(
ctx context.Context,
nsHostname string,
nsIP string,
hostname string,
) (*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
hostname = dns.Fqdn(hostname)
return r.queryAllTypes(ctx, nsHostname, nsIP, hostname)
}
func (r *Resolver) queryAllTypes(
ctx context.Context,
nsHostname string,
nsIP string,
hostname string,
) (*NameserverResponse, error) {
resp := &NameserverResponse{
Nameserver: nsHostname,
Records: make(map[string][]string),
Status: StatusOK,
}
qtypes := []uint16{
dns.TypeA, dns.TypeAAAA, dns.TypeCNAME,
dns.TypeMX, dns.TypeTXT, dns.TypeSRV,
dns.TypeCAA, dns.TypeNS,
}
state := r.queryEachType(ctx, nsIP, hostname, qtypes, resp)
classifyResponse(resp, state)
return resp, nil
}
type queryState struct {
gotNXDomain bool
gotSERVFAIL bool
gotTimeout bool
hasRecords bool
}
func (r *Resolver) queryEachType(
ctx context.Context,
nsIP string,
hostname string,
qtypes []uint16,
resp *NameserverResponse,
) queryState {
var state queryState
for _, qtype := range qtypes {
if checkCtx(ctx) != nil {
break
}
r.querySingleType(ctx, nsIP, hostname, qtype, resp, &state)
}
for k := range resp.Records {
sort.Strings(resp.Records[k])
}
return state
}
func (r *Resolver) querySingleType(
ctx context.Context,
nsIP string,
hostname string,
qtype uint16,
resp *NameserverResponse,
state *queryState,
) {
msg, err := r.queryDNS(ctx, nsIP, hostname, qtype)
if err != nil {
if isTimeout(err) {
state.gotTimeout = true
}
return
}
if msg.Rcode == dns.RcodeNameError {
state.gotNXDomain = true
return
}
if msg.Rcode == dns.RcodeServerFailure {
state.gotSERVFAIL = true
return
}
collectAnswerRecords(msg, resp, state)
}
func collectAnswerRecords(
msg *dns.Msg,
resp *NameserverResponse,
state *queryState,
) {
for _, rr := range msg.Answer {
val := extractRecordValue(rr)
if val == "" {
continue
}
typeName := dns.TypeToString[rr.Header().Rrtype]
resp.Records[typeName] = append(
resp.Records[typeName], val,
)
state.hasRecords = true
}
}
// isTimeout checks whether an error is a network timeout.
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func classifyResponse(resp *NameserverResponse, state queryState) {
switch {
case state.gotNXDomain && !state.hasRecords:
resp.Status = StatusNXDomain
case state.gotTimeout && !state.hasRecords:
resp.Status = StatusTimeout
resp.Error = "all queries timed out"
case state.gotSERVFAIL && !state.hasRecords:
resp.Status = StatusError
resp.Error = "server returned SERVFAIL"
case !state.hasRecords && !state.gotNXDomain:
resp.Status = StatusNoData
}
}
// extractRecordValue formats a DNS RR value as a string.
func extractRecordValue(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
return r.A.String()
case *dns.AAAA:
return r.AAAA.String()
case *dns.CNAME:
return r.Target
case *dns.MX:
return fmt.Sprintf("%d %s", r.Preference, r.Mx)
case *dns.TXT:
return strings.Join(r.Txt, "")
case *dns.SRV:
return fmt.Sprintf(
"%d %d %d %s",
r.Priority, r.Weight, r.Port, r.Target,
)
case *dns.CAA:
return fmt.Sprintf(
"%d %s \"%s\"", r.Flag, r.Tag, r.Value,
)
case *dns.NS:
return r.Ns
default:
return ""
}
}
// parentDomain returns the registerable parent domain.
func parentDomain(hostname string) string {
hostname = dns.Fqdn(strings.ToLower(hostname))
labels := dns.SplitDomainName(hostname)
if len(labels) <= minDomainLabels {
return strings.Join(labels, ".") + "."
}
return strings.Join(
labels[len(labels)-minDomainLabels:], ".",
) + "."
}
// QueryAllNameservers discovers auth NSes for the hostname's
// parent domain, then queries each one independently.
func (r *Resolver) QueryAllNameservers(
ctx context.Context,
hostname string,
) (map[string]*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
parent := parentDomain(hostname)
nameservers, err := r.FindAuthoritativeNameservers(ctx, parent)
if err != nil {
return nil, err
}
return r.queryEachNS(ctx, nameservers, hostname)
}
func (r *Resolver) queryEachNS(
ctx context.Context,
nameservers []string,
hostname string,
) (map[string]*NameserverResponse, error) {
results := make(map[string]*NameserverResponse)
for _, ns := range nameservers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.QueryNameserver(ctx, ns, hostname)
if err != nil {
results[ns] = &NameserverResponse{
Nameserver: ns,
Records: make(map[string][]string),
Status: StatusError,
Error: err.Error(),
}
continue
}
results[ns] = resp
}
return results, nil
}
// LookupNS returns the NS record set for a domain.
func (r *Resolver) LookupNS(
ctx context.Context,
domain string,
) ([]string, error) {
return r.FindAuthoritativeNameservers(ctx, domain)
}
// LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname, keyed by authoritative nameserver.
func (r *Resolver) LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error) {
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
out := make(map[string]map[string][]string, len(results))
for ns, resp := range results {
out[ns] = resp.Records
}
return out, nil
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains up to MaxCNAMEDepth.
func (r *Resolver) ResolveIPAddresses(
ctx context.Context,
hostname string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
return r.resolveIPWithCNAME(ctx, hostname, 0)
}
func (r *Resolver) resolveIPWithCNAME(
ctx context.Context,
hostname string,
depth int,
) ([]string, error) {
if depth > MaxCNAMEDepth {
return nil, ErrCNAMEDepthExceeded
}
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
ips, cnameTarget := collectIPs(results)
if len(ips) == 0 && cnameTarget != "" {
return r.resolveIPWithCNAME(ctx, cnameTarget, depth+1)
}
sort.Strings(ips)
return ips, nil
}
func collectIPs(
results map[string]*NameserverResponse,
) ([]string, string) {
seen := make(map[string]bool)
var ips []string
var cnameTarget string
for _, resp := range results {
if resp.Status == StatusNXDomain {
continue
}
for _, ip := range resp.Records["A"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
for _, ip := range resp.Records["AAAA"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
if len(resp.Records["CNAME"]) > 0 && cnameTarget == "" {
cnameTarget = resp.Records["CNAME"][0]
}
}
return ips, cnameTarget
}

View File

@@ -1,9 +1,9 @@
// Package resolver provides iterative DNS resolution from root nameservers.
// It traces the full delegation chain from IANA root servers through TLD
// and domain nameservers, never relying on upstream recursive resolvers.
package resolver
import (
"context"
"errors"
"log/slog"
"go.uber.org/fx"
@@ -11,8 +11,17 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger"
)
// ErrNotImplemented indicates the resolver is not yet implemented.
var ErrNotImplemented = errors.New("resolver not yet implemented")
// Query status constants matching the state model.
const (
StatusOK = "ok"
StatusError = "error"
StatusNXDomain = "nxdomain"
StatusNoData = "nodata"
StatusTimeout = "timeout"
)
// MaxCNAMEDepth is the maximum CNAME chain depth to follow.
const MaxCNAMEDepth = 10
// Params contains dependencies for Resolver.
type Params struct {
@@ -21,44 +30,54 @@ type Params struct {
Logger *logger.Logger
}
// Resolver performs iterative DNS resolution from root servers.
type Resolver struct {
log *slog.Logger
// NameserverResponse holds one nameserver's response for a query.
type NameserverResponse struct {
Nameserver string
Records map[string][]string
Status string
Error string
}
// New creates a new Resolver instance.
// Resolver performs iterative DNS resolution from root servers.
type Resolver struct {
log *slog.Logger
client DNSClient
tcp DNSClient
}
// New creates a new Resolver instance for use with uber/fx.
func New(
_ fx.Lifecycle,
params Params,
) (*Resolver, error) {
return &Resolver{
log: params.Logger.Get(),
log: params.Logger.Get(),
client: &udpClient{timeout: queryTimeoutDuration},
tcp: &tcpClient{timeout: queryTimeoutDuration},
}, nil
}
// LookupNS performs iterative resolution to find authoritative
// nameservers for the given domain.
func (r *Resolver) LookupNS(
_ context.Context,
_ string,
) ([]string, error) {
return nil, ErrNotImplemented
// NewFromLogger creates a Resolver directly from an slog.Logger,
// useful for testing without the fx lifecycle.
func NewFromLogger(log *slog.Logger) *Resolver {
return &Resolver{
log: log,
client: &udpClient{timeout: queryTimeoutDuration},
tcp: &tcpClient{timeout: queryTimeoutDuration},
}
}
// LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname.
func (r *Resolver) LookupAllRecords(
_ context.Context,
_ string,
) (map[string][]string, error) {
return nil, ErrNotImplemented
// NewFromLoggerWithClient creates a Resolver with a custom DNS
// client, useful for testing with mock DNS responses.
func NewFromLoggerWithClient(
log *slog.Logger,
client DNSClient,
) *Resolver {
return &Resolver{
log: log,
client: client,
tcp: client,
}
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains.
func (r *Resolver) ResolveIPAddresses(
_ context.Context,
_ string,
) ([]string, error) {
return nil, ErrNotImplemented
}
// Method implementations are in iterative.go.

View File

@@ -0,0 +1,688 @@
package resolver_test
import (
"context"
"log/slog"
"net"
"os"
"sort"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/resolver"
)
// ----------------------------------------------------------------
// Test helpers
// ----------------------------------------------------------------
func newTestResolver(t *testing.T) *resolver.Resolver {
t.Helper()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
return resolver.NewFromLogger(log)
}
func testContext(t *testing.T) context.Context {
t.Helper()
ctx, cancel := context.WithTimeout(
context.Background(), 60*time.Second,
)
t.Cleanup(cancel)
return ctx
}
func findOneNSForDomain(
t *testing.T,
r *resolver.Resolver,
ctx context.Context, //nolint:revive // test helper
domain string,
) string {
t.Helper()
nameservers, err := r.FindAuthoritativeNameservers(
ctx, domain,
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
return nameservers[0]
}
// ----------------------------------------------------------------
// FindAuthoritativeNameservers tests
// ----------------------------------------------------------------
func TestFindAuthoritativeNameservers_ValidDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
hasGoogleNS := false
for _, ns := range nameservers {
if strings.Contains(ns, "google") {
hasGoogleNS = true
break
}
}
assert.True(t, hasGoogleNS,
"expected google nameservers, got: %v", nameservers,
)
}
func TestFindAuthoritativeNameservers_Subdomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "www.google.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
}
func TestFindAuthoritativeNameservers_ReturnsSorted(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
assert.True(
t,
sort.StringsAreSorted(nameservers),
"nameservers should be sorted, got: %v", nameservers,
)
}
func TestFindAuthoritativeNameservers_Deterministic(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
first, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
second, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
assert.Equal(t, first, second)
}
func TestFindAuthoritativeNameservers_TrailingDot(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns1, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
ns2, err := r.FindAuthoritativeNameservers(
ctx, "google.com.",
)
require.NoError(t, err)
assert.Equal(t, ns1, ns2)
}
func TestFindAuthoritativeNameservers_CloudflareDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "cloudflare.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
for _, ns := range nameservers {
assert.True(t, strings.HasSuffix(ns, "."),
"NS should be FQDN with trailing dot: %s", ns,
)
}
}
// ----------------------------------------------------------------
// QueryNameserver tests
// ----------------------------------------------------------------
func TestQueryNameserver_BasicA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "www.google.com",
)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
assert.Equal(t, ns, resp.Nameserver)
hasRecords := len(resp.Records["A"]) > 0 ||
len(resp.Records["CNAME"]) > 0
assert.True(t, hasRecords,
"expected A or CNAME records for www.google.com",
)
}
func TestQueryNameserver_AAAA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "cloudflare.com")
resp, err := r.QueryNameserver(
ctx, ns, "cloudflare.com",
)
require.NoError(t, err)
aaaaRecords := resp.Records["AAAA"]
require.NotEmpty(t, aaaaRecords,
"cloudflare.com should have AAAA records",
)
for _, ip := range aaaaRecords {
parsed := net.ParseIP(ip)
require.NotNil(t, parsed,
"should be valid IP: %s", ip,
)
}
}
func TestQueryNameserver_MX(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
mxRecords := resp.Records["MX"]
require.NotEmpty(t, mxRecords,
"google.com should have MX records",
)
}
func TestQueryNameserver_TXT(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
txtRecords := resp.Records["TXT"]
require.NotEmpty(t, txtRecords,
"google.com should have TXT records",
)
hasSPF := false
for _, txt := range txtRecords {
if strings.Contains(txt, "v=spf1") {
hasSPF = true
break
}
}
assert.True(t, hasSPF,
"google.com should have SPF TXT record",
)
}
func TestQueryNameserver_NXDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
assert.Equal(t, resolver.StatusNXDomain, resp.Status)
}
func TestQueryNameserver_RecordsSorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
for recordType, values := range resp.Records {
assert.True(
t,
sort.StringsAreSorted(values),
"%s records should be sorted", recordType,
)
}
}
func TestQueryNameserver_ResponseIncludesNameserver(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "cloudflare.com")
resp, err := r.QueryNameserver(
ctx, ns, "cloudflare.com",
)
require.NoError(t, err)
assert.Equal(t, ns, resp.Nameserver)
}
func TestQueryNameserver_EmptyRecordsOnNXDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
totalRecords := 0
for _, values := range resp.Records {
totalRecords += len(values)
}
assert.Zero(t, totalRecords)
}
func TestQueryNameserver_TrailingDotHandling(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp1, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
resp2, err := r.QueryNameserver(
ctx, ns, "google.com.",
)
require.NoError(t, err)
assert.Equal(t, resp1.Status, resp2.Status)
}
// ----------------------------------------------------------------
// QueryAllNameservers tests
// ----------------------------------------------------------------
func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx, "google.com",
)
require.NoError(t, err)
require.NotEmpty(t, results)
assert.GreaterOrEqual(t, len(results), 2)
for ns, resp := range results {
assert.Equal(t, ns, resp.Nameserver)
}
}
func TestQueryAllNameservers_AllReturnOK(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx, "google.com",
)
require.NoError(t, err)
for ns, resp := range results {
assert.Equal(
t, resolver.StatusOK, resp.Status,
"NS %s should return OK", ns,
)
}
}
func TestQueryAllNameservers_NXDomainFromAllNS(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
for ns, resp := range results {
assert.Equal(
t, resolver.StatusNXDomain, resp.Status,
"NS %s should return nxdomain", ns,
)
}
}
// ----------------------------------------------------------------
// LookupNS tests
// ----------------------------------------------------------------
func TestLookupNS_ValidDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, "google.com")
require.NoError(t, err)
require.NotEmpty(t, nameservers)
for _, ns := range nameservers {
assert.True(t, strings.HasSuffix(ns, "."),
"NS should have trailing dot: %s", ns,
)
}
}
func TestLookupNS_Sorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, "google.com")
require.NoError(t, err)
assert.True(t, sort.StringsAreSorted(nameservers))
}
func TestLookupNS_MatchesFindAuthoritative(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
fromLookup, err := r.LookupNS(ctx, "google.com")
require.NoError(t, err)
fromFind, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
assert.Equal(t, fromFind, fromLookup)
}
// ----------------------------------------------------------------
// ResolveIPAddresses tests
// ----------------------------------------------------------------
func TestResolveIPAddresses_ReturnsIPs(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
require.NoError(t, err)
require.NotEmpty(t, ips)
for _, ip := range ips {
parsed := net.ParseIP(ip)
assert.NotNil(t, parsed,
"should be valid IP: %s", ip,
)
}
}
func TestResolveIPAddresses_Deduplicated(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
require.NoError(t, err)
seen := make(map[string]bool)
for _, ip := range ips {
assert.False(t, seen[ip], "duplicate IP: %s", ip)
seen[ip] = true
}
}
func TestResolveIPAddresses_Sorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
require.NoError(t, err)
assert.True(t, sort.StringsAreSorted(ips))
}
func TestResolveIPAddresses_NXDomainReturnsEmpty(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(
ctx,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
assert.Empty(t, ips)
}
func TestResolveIPAddresses_CloudflareDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com")
require.NoError(t, err)
require.NotEmpty(t, ips)
}
// ----------------------------------------------------------------
// Context cancellation tests
// ----------------------------------------------------------------
func TestFindAuthoritativeNameservers_ContextCanceled(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.FindAuthoritativeNameservers(ctx, "google.com")
assert.Error(t, err)
}
func TestQueryNameserver_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.QueryNameserver(
ctx, "ns1.google.com.", "google.com",
)
assert.Error(t, err)
}
func TestQueryAllNameservers_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.QueryAllNameservers(ctx, "google.com")
assert.Error(t, err)
}
// ----------------------------------------------------------------
// Timeout tests
// ----------------------------------------------------------------
func TestQueryNameserverIP_Timeout(t *testing.T) {
t.Parallel()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
r := resolver.NewFromLoggerWithClient(
log, &timeoutClient{},
)
ctx, cancel := context.WithTimeout(
context.Background(), 10*time.Second,
)
t.Cleanup(cancel)
// Query any IP — the client always returns a timeout error.
resp, err := r.QueryNameserverIP(
ctx, "unreachable.test.", "192.0.2.1",
"example.com",
)
require.NoError(t, err)
assert.Equal(t, resolver.StatusTimeout, resp.Status)
assert.NotEmpty(t, resp.Error)
}
// timeoutClient simulates DNS timeout errors for testing.
type timeoutClient struct{}
func (c *timeoutClient) ExchangeContext(
_ context.Context,
_ *dns.Msg,
_ string,
) (*dns.Msg, time.Duration, error) {
return nil, 0, &net.OpError{
Op: "read",
Net: "udp",
Err: &timeoutError{},
}
}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
func TestResolveIPAddresses_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.ResolveIPAddresses(ctx, "google.com")
assert.Error(t, err)
}

View File

@@ -1,11 +1,14 @@
package server
import (
"net/http"
"time"
"github.com/go-chi/chi/v5"
chimw "github.com/go-chi/chi/v5/middleware"
"github.com/prometheus/client_golang/prometheus/promhttp"
"sneak.berlin/go/dnswatcher/static"
)
// requestTimeout is the maximum duration for handling a request.
@@ -22,7 +25,25 @@ func (s *Server) SetupRoutes() {
s.router.Use(s.mw.CORS())
s.router.Use(chimw.Timeout(requestTimeout))
// Health check
// Dashboard (read-only web UI)
s.router.Get("/", s.handlers.HandleDashboard())
// Static assets (embedded CSS/JS)
s.router.Mount(
"/s",
http.StripPrefix(
"/s",
http.FileServer(http.FS(static.Static)),
),
)
// Health check (standard well-known path)
s.router.Get(
"/.well-known/healthcheck",
s.handlers.HandleHealthCheck(),
)
// Legacy health check (keep for backward compatibility)
s.router.Get("/health", s.handlers.HandleHealthCheck())
// API v1 routes

View File

@@ -57,10 +57,49 @@ type HostnameState struct {
// PortState holds the monitoring state for a port.
type PortState struct {
Open bool `json:"open"`
Hostname string `json:"hostname"`
Hostnames []string `json:"hostnames"`
LastChecked time.Time `json:"lastChecked"`
}
// UnmarshalJSON implements custom unmarshaling to handle both
// the old single-hostname format and the new multi-hostname
// format for backward compatibility with existing state files.
func (ps *PortState) UnmarshalJSON(data []byte) error {
// Use an alias to prevent infinite recursion.
type portStateAlias struct {
Open bool `json:"open"`
Hostnames []string `json:"hostnames"`
LastChecked time.Time `json:"lastChecked"`
}
var alias portStateAlias
err := json.Unmarshal(data, &alias)
if err != nil {
return fmt.Errorf("unmarshaling port state: %w", err)
}
ps.Open = alias.Open
ps.Hostnames = alias.Hostnames
ps.LastChecked = alias.LastChecked
// If Hostnames is empty, try reading the old single-hostname
// format for backward compatibility.
if len(ps.Hostnames) == 0 {
var old struct {
Hostname string `json:"hostname"`
}
// Best-effort: ignore errors since the main unmarshal
// already succeeded.
if json.Unmarshal(data, &old) == nil && old.Hostname != "" {
ps.Hostnames = []string{old.Hostname}
}
}
return nil
}
// CertificateState holds TLS certificate monitoring state.
type CertificateState struct {
CommonName string `json:"commonName"`
@@ -156,8 +195,8 @@ func (s *State) Load() error {
// Save writes the current state to disk atomically.
func (s *State) Save() error {
s.mu.RLock()
defer s.mu.RUnlock()
s.mu.Lock()
defer s.mu.Unlock()
s.snapshot.LastUpdated = time.Now().UTC()
@@ -263,6 +302,27 @@ func (s *State) GetPortState(key string) (*PortState, bool) {
return ps, ok
}
// DeletePortState removes a port state entry.
func (s *State) DeletePortState(key string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.snapshot.Ports, key)
}
// GetAllPortKeys returns all port state keys.
func (s *State) GetAllPortKeys() []string {
s.mu.RLock()
defer s.mu.RUnlock()
keys := make([]string, 0, len(s.snapshot.Ports))
for k := range s.snapshot.Ports {
keys = append(keys, k)
}
return keys
}
// SetCertificateState updates the state for a certificate.
func (s *State) SetCertificateState(
key string,

1302
internal/state/state_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,38 @@
package state
import (
"log/slog"
"sneak.berlin/go/dnswatcher/internal/config"
)
// NewForTest creates a State for unit testing with no persistence.
func NewForTest() *State {
return &State{
log: slog.Default(),
snapshot: &Snapshot{
Version: stateVersion,
Domains: make(map[string]*DomainState),
Hostnames: make(map[string]*HostnameState),
Ports: make(map[string]*PortState),
Certificates: make(map[string]*CertificateState),
},
config: &config.Config{DataDir: ""},
}
}
// NewForTestWithDataDir creates a State backed by the given directory
// for tests that need file persistence.
func NewForTestWithDataDir(dataDir string) *State {
return &State{
log: slog.Default(),
snapshot: &Snapshot{
Version: stateVersion,
Domains: make(map[string]*DomainState),
Hostnames: make(map[string]*HostnameState),
Ports: make(map[string]*PortState),
Certificates: make(map[string]*CertificateState),
},
config: &config.Config{DataDir: dataDir},
}
}

View File

@@ -0,0 +1,67 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func TestCheckCertificateNoPeerCerts(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = ln.Close() }()
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
// Accept and immediately close to cause TLS handshake failure.
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
_ = conn.Close()
}()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
InsecureSkipVerify: true, //nolint:gosec // test
MinVersion: tls.VersionTLS12,
}),
tlscheck.WithPort(addr.Port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error when server presents no certs")
}
}
func TestErrNoPeerCertificatesIsSentinel(t *testing.T) {
t.Parallel()
err := tlscheck.ErrNoPeerCertificates
if !errors.Is(err, tlscheck.ErrNoPeerCertificates) {
t.Fatal("expected sentinel error to match")
}
}

View File

@@ -3,8 +3,12 @@ package tlscheck
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"time"
"go.uber.org/fx"
@@ -12,11 +16,56 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger"
)
// ErrNotImplemented indicates the TLS checker is not yet implemented.
var ErrNotImplemented = errors.New(
"tls checker not yet implemented",
const (
defaultTimeout = 10 * time.Second
defaultPort = 443
)
// ErrUnexpectedConnType indicates the connection was not a TLS
// connection.
var ErrUnexpectedConnType = errors.New(
"unexpected connection type",
)
// ErrNoPeerCertificates indicates the TLS connection had no peer
// certificates.
var ErrNoPeerCertificates = errors.New(
"no peer certificates",
)
// CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
SerialNumber string
}
// Option configures a Checker.
type Option func(*Checker)
// WithTimeout sets the connection timeout.
func WithTimeout(d time.Duration) Option {
return func(c *Checker) {
c.timeout = d
}
}
// WithTLSConfig sets a custom TLS configuration.
func WithTLSConfig(cfg *tls.Config) Option {
return func(c *Checker) {
c.tlsConfig = cfg
}
}
// WithPort sets the TLS port to connect to.
func WithPort(port int) Option {
return func(c *Checker) {
c.port = port
}
}
// Params contains dependencies for Checker.
type Params struct {
fx.In
@@ -26,15 +75,10 @@ type Params struct {
// Checker performs TLS certificate inspection.
type Checker struct {
log *slog.Logger
}
// CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
log *slog.Logger
timeout time.Duration
tlsConfig *tls.Config
port int
}
// New creates a new TLS Checker instance.
@@ -43,16 +87,110 @@ func New(
params Params,
) (*Checker, error) {
return &Checker{
log: params.Logger.Get(),
log: params.Logger.Get(),
timeout: defaultTimeout,
port: defaultPort,
}, nil
}
// CheckCertificate connects to the given IP:port using SNI and
// returns certificate information.
func (c *Checker) CheckCertificate(
_ context.Context,
_ string,
_ string,
) (*CertificateInfo, error) {
return nil, ErrNotImplemented
// NewStandalone creates a Checker without fx dependencies.
func NewStandalone(opts ...Option) *Checker {
checker := &Checker{
log: slog.Default(),
timeout: defaultTimeout,
port: defaultPort,
}
for _, opt := range opts {
opt(checker)
}
return checker
}
// CheckCertificate connects to the given IP address using the
// specified SNI hostname and returns certificate information.
func (c *Checker) CheckCertificate(
ctx context.Context,
ipAddress string,
sniHostname string,
) (*CertificateInfo, error) {
target := net.JoinHostPort(
ipAddress, strconv.Itoa(c.port),
)
tlsCfg := c.buildTLSConfig(sniHostname)
dialer := &tls.Dialer{
NetDialer: &net.Dialer{Timeout: c.timeout},
Config: tlsCfg,
}
conn, err := dialer.DialContext(ctx, "tcp", target)
if err != nil {
return nil, fmt.Errorf(
"TLS dial to %s: %w", target, err,
)
}
defer func() {
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing TLS connection",
"target", target,
"error", closeErr.Error(),
)
}
}()
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return nil, fmt.Errorf(
"%s: %w", target, ErrUnexpectedConnType,
)
}
return c.extractCertInfo(tlsConn)
}
func (c *Checker) buildTLSConfig(
sniHostname string,
) *tls.Config {
if c.tlsConfig != nil {
cfg := c.tlsConfig.Clone()
cfg.ServerName = sniHostname
return cfg
}
return &tls.Config{
ServerName: sniHostname,
MinVersion: tls.VersionTLS12,
}
}
func (c *Checker) extractCertInfo(
conn *tls.Conn,
) (*CertificateInfo, error) {
state := conn.ConnectionState()
if len(state.PeerCertificates) == 0 {
return nil, ErrNoPeerCertificates
}
cert := state.PeerCertificates[0]
sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses))
sans = append(sans, cert.DNSNames...)
for _, ip := range cert.IPAddresses {
sans = append(sans, ip.String())
}
return &CertificateInfo{
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: sans,
SerialNumber: cert.SerialNumber.String(),
}, nil
}

View File

@@ -0,0 +1,169 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func startTLSServer(
t *testing.T,
) (*httptest.Server, string, int) {
t.Helper()
srv := httptest.NewTLSServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
},
),
)
addr, ok := srv.Listener.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return srv, addr.IP.String(), addr.Port
}
func TestCheckCertificateValid(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info == nil {
t.Fatal("expected non-nil CertificateInfo")
}
if info.NotAfter.IsZero() {
t.Error("expected non-zero NotAfter")
}
if info.SerialNumber == "" {
t.Error("expected non-empty SerialNumber")
}
}
func TestCheckCertificateConnectionRefused(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
port := addr.Port
_ = ln.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for connection refused")
}
}
func TestCheckCertificateContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
ctx, "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for canceled context")
}
}
func TestCheckCertificateTimeout(t *testing.T) {
t.Parallel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(1*time.Millisecond),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
context.Background(),
"192.0.2.1",
"example.com",
)
if err == nil {
t.Fatal("expected error for timeout")
}
}
func TestCheckCertificateSANs(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.CommonName == "" && len(info.SubjectAlternativeNames) == 0 {
t.Error("expected CN or SANs to be populated")
}
}

View File

@@ -0,0 +1,61 @@
// Package watcher provides the main monitoring orchestrator.
package watcher
import (
"context"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
// DNSResolver performs iterative DNS resolution.
type DNSResolver interface {
// LookupNS discovers authoritative nameservers for a domain.
LookupNS(
ctx context.Context,
domain string,
) ([]string, error)
// LookupAllRecords queries all record types for a hostname,
// returning results keyed by nameserver then record type.
LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error)
// ResolveIPAddresses resolves a hostname to all IP addresses.
ResolveIPAddresses(
ctx context.Context,
hostname string,
) ([]string, error)
}
// PortChecker tests TCP port connectivity.
type PortChecker interface {
// CheckPort tests TCP connectivity to an address and port.
CheckPort(
ctx context.Context,
address string,
port int,
) (*portcheck.PortResult, error)
}
// TLSChecker inspects TLS certificates.
type TLSChecker interface {
// CheckCertificate connects via TLS and returns cert info.
CheckCertificate(
ctx context.Context,
ip string,
hostname string,
) (*tlscheck.CertificateInfo, error)
}
// Notifier delivers notifications to configured endpoints.
type Notifier interface {
// SendNotification sends a notification with the given
// details.
SendNotification(
ctx context.Context,
title, message, priority string,
)
}

View File

@@ -1,21 +1,31 @@
// Package watcher provides the main monitoring orchestrator and scheduler.
package watcher
import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
"sync"
"time"
"go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/logger"
"sneak.berlin/go/dnswatcher/internal/notify"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/resolver"
"sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
// monitoredPorts are the TCP ports checked for each IP address.
var monitoredPorts = []int{80, 443} //nolint:gochecknoglobals
// tlsPort is the port used for TLS certificate checks.
const tlsPort = 443
// hoursPerDay converts days to hours for duration calculations.
const hoursPerDay = 24
// Params contains dependencies for Watcher.
type Params struct {
fx.In
@@ -23,61 +33,92 @@ type Params struct {
Logger *logger.Logger
Config *config.Config
State *state.State
Resolver *resolver.Resolver
PortCheck *portcheck.Checker
TLSCheck *tlscheck.Checker
Notify *notify.Service
Resolver DNSResolver
PortCheck PortChecker
TLSCheck TLSChecker
Notify Notifier
}
// Watcher orchestrates all monitoring checks on a schedule.
type Watcher struct {
log *slog.Logger
config *config.Config
state *state.State
resolver *resolver.Resolver
portCheck *portcheck.Checker
tlsCheck *tlscheck.Checker
notify *notify.Service
cancel context.CancelFunc
log *slog.Logger
config *config.Config
state *state.State
resolver DNSResolver
portCheck PortChecker
tlsCheck TLSChecker
notify Notifier
cancel context.CancelFunc
firstRun bool
expiryNotifiedMu sync.Mutex
expiryNotified map[string]time.Time
}
// New creates a new Watcher instance.
// New creates a new Watcher instance wired into the fx lifecycle.
func New(
lifecycle fx.Lifecycle,
params Params,
) (*Watcher, error) {
watcher := &Watcher{
log: params.Logger.Get(),
config: params.Config,
state: params.State,
resolver: params.Resolver,
portCheck: params.PortCheck,
tlsCheck: params.TLSCheck,
notify: params.Notify,
w := &Watcher{
log: params.Logger.Get(),
config: params.Config,
state: params.State,
resolver: params.Resolver,
portCheck: params.PortCheck,
tlsCheck: params.TLSCheck,
notify: params.Notify,
firstRun: true,
expiryNotified: make(map[string]time.Time),
}
lifecycle.Append(fx.Hook{
OnStart: func(startCtx context.Context) error {
ctx, cancel := context.WithCancel(startCtx)
watcher.cancel = cancel
OnStart: func(_ context.Context) error {
// Use context.Background() — the fx startup context
// expires after startup completes, so deriving from it
// would cancel the watcher immediately. The watcher's
// lifetime is controlled by w.cancel in OnStop.
ctx, cancel := context.WithCancel(context.Background())
w.cancel = cancel
go watcher.Run(ctx)
go w.Run(ctx) //nolint:contextcheck // intentionally not derived from startCtx
return nil
},
OnStop: func(_ context.Context) error {
if watcher.cancel != nil {
watcher.cancel()
if w.cancel != nil {
w.cancel()
}
return nil
},
})
return watcher, nil
return w, nil
}
// Run starts the monitoring loop.
// NewForTest creates a Watcher without fx for unit testing.
func NewForTest(
cfg *config.Config,
st *state.State,
res DNSResolver,
pc PortChecker,
tc TLSChecker,
n Notifier,
) *Watcher {
return &Watcher{
log: slog.Default(),
config: cfg,
state: st,
resolver: res,
portCheck: pc,
tlsCheck: tc,
notify: n,
firstRun: true,
expiryNotified: make(map[string]time.Time),
}
}
// Run starts the monitoring loop with periodic scheduling.
func (w *Watcher) Run(ctx context.Context) {
w.log.Info(
"watcher starting",
@@ -87,8 +128,812 @@ func (w *Watcher) Run(ctx context.Context) {
"tlsInterval", w.config.TLSInterval,
)
// Stub: wait for context cancellation.
// Implementation will add initial check + periodic scheduling.
<-ctx.Done()
w.log.Info("watcher stopped")
w.RunOnce(ctx)
w.maybeSendTestNotification(ctx)
dnsTicker := time.NewTicker(w.config.DNSInterval)
tlsTicker := time.NewTicker(w.config.TLSInterval)
defer dnsTicker.Stop()
defer tlsTicker.Stop()
for {
select {
case <-ctx.Done():
w.log.Info("watcher stopped")
return
case <-dnsTicker.C:
w.runDNSChecks(ctx)
w.checkAllPorts(ctx)
w.saveState()
case <-tlsTicker.C:
// Run DNS first so TLS checks use freshly
// resolved IP addresses, not stale ones from
// a previous cycle.
w.runDNSChecks(ctx)
w.runTLSChecks(ctx)
w.saveState()
}
}
}
// RunOnce performs a single complete monitoring cycle.
// DNS checks run first so that port and TLS checks use
// freshly resolved IP addresses. Port checks run before
// TLS because TLS checks only target IPs with an open
// port 443.
func (w *Watcher) RunOnce(ctx context.Context) {
w.detectFirstRun()
// Phase 1: DNS resolution must complete first so that
// subsequent checks use fresh IP addresses.
w.runDNSChecks(ctx)
// Phase 2: Port checks populate port state that TLS
// checks depend on (TLS only targets IPs where port
// 443 is open).
w.checkAllPorts(ctx)
// Phase 3: TLS checks use fresh DNS IPs and current
// port state.
w.runTLSChecks(ctx)
w.saveState()
w.firstRun = false
}
func (w *Watcher) detectFirstRun() {
snap := w.state.GetSnapshot()
hasState := len(snap.Domains) > 0 ||
len(snap.Hostnames) > 0 ||
len(snap.Ports) > 0 ||
len(snap.Certificates) > 0
if hasState {
w.firstRun = false
}
}
// runDNSChecks performs DNS resolution for all configured domains
// and hostnames, updating state with freshly resolved records.
// This must complete before port or TLS checks run so those
// checks operate on current IP addresses.
func (w *Watcher) runDNSChecks(ctx context.Context) {
for _, domain := range w.config.Domains {
w.checkDomain(ctx, domain)
}
for _, hostname := range w.config.Hostnames {
w.checkHostname(ctx, hostname)
}
}
func (w *Watcher) checkDomain(
ctx context.Context,
domain string,
) {
nameservers, err := w.resolver.LookupNS(ctx, domain)
if err != nil {
w.log.Error(
"failed to lookup NS",
"domain", domain,
"error", err,
)
return
}
sort.Strings(nameservers)
now := time.Now().UTC()
prev, hasPrev := w.state.GetDomainState(domain)
if hasPrev && !w.firstRun {
w.detectNSChanges(ctx, domain, prev.Nameservers, nameservers)
}
w.state.SetDomainState(domain, &state.DomainState{
Nameservers: nameservers,
LastChecked: now,
})
// Also look up A/AAAA records for the apex domain so that
// port and TLS checks (which read HostnameState) can find
// the domain's IP addresses.
records, err := w.resolver.LookupAllRecords(ctx, domain)
if err != nil {
w.log.Error(
"failed to lookup records for domain",
"domain", domain,
"error", err,
)
return
}
prevHS, hasPrevHS := w.state.GetHostnameState(domain)
if hasPrevHS && !w.firstRun {
w.detectHostnameChanges(ctx, domain, prevHS, records)
}
newState := buildHostnameState(records, now)
w.state.SetHostnameState(domain, newState)
}
func (w *Watcher) detectNSChanges(
ctx context.Context,
domain string,
oldNS, newNS []string,
) {
oldSet := toSet(oldNS)
newSet := toSet(newNS)
var added, removed []string
for ns := range newSet {
if !oldSet[ns] {
added = append(added, ns)
}
}
for ns := range oldSet {
if !newSet[ns] {
removed = append(removed, ns)
}
}
if len(added) == 0 && len(removed) == 0 {
return
}
msg := fmt.Sprintf(
"Domain: %s\nAdded: %s\nRemoved: %s",
domain,
strings.Join(added, ", "),
strings.Join(removed, ", "),
)
w.notify.SendNotification(
ctx,
"NS Change: "+domain,
msg,
"warning",
)
}
func (w *Watcher) checkHostname(
ctx context.Context,
hostname string,
) {
records, err := w.resolver.LookupAllRecords(ctx, hostname)
if err != nil {
w.log.Error(
"failed to lookup records",
"hostname", hostname,
"error", err,
)
return
}
now := time.Now().UTC()
prev, hasPrev := w.state.GetHostnameState(hostname)
if hasPrev && !w.firstRun {
w.detectHostnameChanges(ctx, hostname, prev, records)
}
newState := buildHostnameState(records, now)
w.state.SetHostnameState(hostname, newState)
}
func buildHostnameState(
records map[string]map[string][]string,
now time.Time,
) *state.HostnameState {
hs := &state.HostnameState{
RecordsByNameserver: make(
map[string]*state.NameserverRecordState,
),
LastChecked: now,
}
for ns, recs := range records {
hs.RecordsByNameserver[ns] = &state.NameserverRecordState{
Records: recs,
Status: "ok",
LastChecked: now,
}
}
return hs
}
func (w *Watcher) detectHostnameChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
w.detectRecordChanges(ctx, hostname, prev, current)
w.detectNSDisappearances(ctx, hostname, prev, current)
w.detectInconsistencies(ctx, hostname, current)
}
func (w *Watcher) detectRecordChanges(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, recs := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok {
continue
}
if recordsEqual(prevNS.Records, recs) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s\n"+
"Old: %v\nNew: %v",
hostname, ns,
prevNS.Records, recs,
)
w.notify.SendNotification(
ctx,
"Record Change: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) detectNSDisappearances(
ctx context.Context,
hostname string,
prev *state.HostnameState,
current map[string]map[string][]string,
) {
for ns, prevNS := range prev.RecordsByNameserver {
if _, ok := current[ns]; ok || prevNS.Status != "ok" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s disappeared",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Failure: "+hostname,
msg,
"error",
)
}
for ns := range current {
prevNS, ok := prev.RecordsByNameserver[ns]
if !ok || prevNS.Status != "error" {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\nNameserver: %s recovered",
hostname, ns,
)
w.notify.SendNotification(
ctx,
"NS Recovery: "+hostname,
msg,
"success",
)
}
}
func (w *Watcher) detectInconsistencies(
ctx context.Context,
hostname string,
current map[string]map[string][]string,
) {
nameservers := make([]string, 0, len(current))
for ns := range current {
nameservers = append(nameservers, ns)
}
sort.Strings(nameservers)
for i := range len(nameservers) - 1 {
ns1 := nameservers[i]
ns2 := nameservers[i+1]
if recordsEqual(current[ns1], current[ns2]) {
continue
}
msg := fmt.Sprintf(
"Hostname: %s\n%s: %v\n%s: %v",
hostname,
ns1, current[ns1],
ns2, current[ns2],
)
w.notify.SendNotification(
ctx,
"Inconsistency: "+hostname,
msg,
"warning",
)
}
}
func (w *Watcher) checkAllPorts(ctx context.Context) {
// Phase 1: Build current IP:port → hostname associations
// from fresh DNS data.
associations := w.buildPortAssociations()
// Phase 2: Check each unique IP:port and update state
// with the full set of associated hostnames.
for key, hostnames := range associations {
ip, port := parsePortKey(key)
if port == 0 {
continue
}
w.checkSinglePort(ctx, ip, port, hostnames)
}
// Phase 3: Remove port state entries that no longer have
// any hostname referencing them.
w.cleanupStalePorts(associations)
}
// buildPortAssociations constructs a map from IP:port keys to
// the sorted set of hostnames currently resolving to that IP.
func (w *Watcher) buildPortAssociations() map[string][]string {
assoc := make(map[string]map[string]bool)
allNames := make(
[]string, 0,
len(w.config.Hostnames)+len(w.config.Domains),
)
allNames = append(allNames, w.config.Hostnames...)
allNames = append(allNames, w.config.Domains...)
for _, name := range allNames {
ips := w.collectIPs(name)
for _, ip := range ips {
for _, port := range monitoredPorts {
key := fmt.Sprintf("%s:%d", ip, port)
if assoc[key] == nil {
assoc[key] = make(map[string]bool)
}
assoc[key][name] = true
}
}
}
result := make(map[string][]string, len(assoc))
for key, set := range assoc {
hostnames := make([]string, 0, len(set))
for h := range set {
hostnames = append(hostnames, h)
}
sort.Strings(hostnames)
result[key] = hostnames
}
return result
}
// parsePortKey splits an "ip:port" key into its components.
func parsePortKey(key string) (string, int) {
lastColon := strings.LastIndex(key, ":")
if lastColon < 0 {
return key, 0
}
ip := key[:lastColon]
var p int
_, err := fmt.Sscanf(key[lastColon+1:], "%d", &p)
if err != nil {
return ip, 0
}
return ip, p
}
// cleanupStalePorts removes port state entries that are no
// longer referenced by any hostname in the current DNS data.
func (w *Watcher) cleanupStalePorts(
currentAssociations map[string][]string,
) {
for _, key := range w.state.GetAllPortKeys() {
if _, exists := currentAssociations[key]; !exists {
w.state.DeletePortState(key)
}
}
}
func (w *Watcher) collectIPs(hostname string) []string {
hs, ok := w.state.GetHostnameState(hostname)
if !ok {
return nil
}
ipSet := make(map[string]bool)
for _, nsState := range hs.RecordsByNameserver {
for _, ip := range nsState.Records["A"] {
ipSet[ip] = true
}
for _, ip := range nsState.Records["AAAA"] {
ipSet[ip] = true
}
}
result := make([]string, 0, len(ipSet))
for ip := range ipSet {
result = append(result, ip)
}
sort.Strings(result)
return result
}
func (w *Watcher) checkSinglePort(
ctx context.Context,
ip string,
port int,
hostnames []string,
) {
result, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil {
w.log.Error(
"port check failed",
"ip", ip,
"port", port,
"error", err,
)
return
}
key := fmt.Sprintf("%s:%d", ip, port)
now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != result.Open {
stateStr := "closed"
if result.Open {
stateStr = "open"
}
msg := fmt.Sprintf(
"Hosts: %s\nAddress: %s\nPort now %s",
strings.Join(hostnames, ", "), key, stateStr,
)
w.notify.SendNotification(
ctx,
"Port Change: "+key,
msg,
"warning",
)
}
w.state.SetPortState(key, &state.PortState{
Open: result.Open,
Hostnames: hostnames,
LastChecked: now,
})
}
func (w *Watcher) runTLSChecks(ctx context.Context) {
for _, hostname := range w.config.Hostnames {
w.checkTLSForHostname(ctx, hostname)
}
for _, domain := range w.config.Domains {
w.checkTLSForHostname(ctx, domain)
}
}
func (w *Watcher) checkTLSForHostname(
ctx context.Context,
hostname string,
) {
ips := w.collectIPs(hostname)
for _, ip := range ips {
portKey := fmt.Sprintf("%s:%d", ip, tlsPort)
ps, ok := w.state.GetPortState(portKey)
if !ok || !ps.Open {
continue
}
w.checkTLSCert(ctx, ip, hostname)
}
}
func (w *Watcher) checkTLSCert(
ctx context.Context,
ip string,
hostname string,
) {
cert, err := w.tlsCheck.CheckCertificate(ctx, ip, hostname)
certKey := fmt.Sprintf("%s:%d:%s", ip, tlsPort, hostname)
now := time.Now().UTC()
prev, hasPrev := w.state.GetCertificateState(certKey)
if err != nil {
w.handleTLSError(
ctx, certKey, hostname, ip,
hasPrev, prev, now, err,
)
return
}
w.handleTLSSuccess(
ctx, certKey, hostname, ip,
hasPrev, prev, now, cert,
)
}
func (w *Watcher) handleTLSError(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
err error,
) {
if hasPrev && !w.firstRun && prev.Status == "ok" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nError: %s",
hostname, ip, err,
)
w.notify.SendNotification(
ctx,
"TLS Failure: "+hostname,
msg,
"error",
)
}
w.state.SetCertificateState(
certKey, &state.CertificateState{
Status: "error",
Error: err.Error(),
LastChecked: now,
},
)
}
func (w *Watcher) handleTLSSuccess(
ctx context.Context,
certKey, hostname, ip string,
hasPrev bool,
prev *state.CertificateState,
now time.Time,
cert *tlscheck.CertificateInfo,
) {
if hasPrev && !w.firstRun {
w.detectTLSChanges(ctx, hostname, ip, prev, cert)
}
w.checkTLSExpiry(ctx, hostname, ip, cert)
w.state.SetCertificateState(
certKey, &state.CertificateState{
CommonName: cert.CommonName,
Issuer: cert.Issuer,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: cert.SubjectAlternativeNames,
Status: "ok",
LastChecked: now,
},
)
}
func (w *Watcher) detectTLSChanges(
ctx context.Context,
hostname, ip string,
prev *state.CertificateState,
cert *tlscheck.CertificateInfo,
) {
if prev.Status == "error" {
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nTLS recovered",
hostname, ip,
)
w.notify.SendNotification(
ctx,
"TLS Recovery: "+hostname,
msg,
"success",
)
return
}
changed := prev.CommonName != cert.CommonName ||
prev.Issuer != cert.Issuer ||
!sliceEqual(
prev.SubjectAlternativeNames,
cert.SubjectAlternativeNames,
)
if !changed {
return
}
msg := fmt.Sprintf(
"Host: %s\nIP: %s\n"+
"Old CN: %s, Issuer: %s\n"+
"New CN: %s, Issuer: %s",
hostname, ip,
prev.CommonName, prev.Issuer,
cert.CommonName, cert.Issuer,
)
w.notify.SendNotification(
ctx,
"TLS Certificate Change: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) checkTLSExpiry(
ctx context.Context,
hostname, ip string,
cert *tlscheck.CertificateInfo,
) {
daysLeft := time.Until(cert.NotAfter).Hours() / hoursPerDay
warningDays := float64(w.config.TLSExpiryWarning)
if daysLeft > warningDays {
return
}
// Deduplicate expiry warnings: don't re-notify for the same
// hostname within the TLS check interval.
dedupKey := fmt.Sprintf("expiry:%s:%s", hostname, ip)
w.expiryNotifiedMu.Lock()
lastNotified, seen := w.expiryNotified[dedupKey]
if seen && time.Since(lastNotified) < w.config.TLSInterval {
w.expiryNotifiedMu.Unlock()
return
}
w.expiryNotified[dedupKey] = time.Now()
w.expiryNotifiedMu.Unlock()
msg := fmt.Sprintf(
"Host: %s\nIP: %s\nCN: %s\n"+
"Expires: %s (%.0f days)",
hostname, ip, cert.CommonName,
cert.NotAfter.Format(time.RFC3339),
daysLeft,
)
w.notify.SendNotification(
ctx,
"TLS Expiry Warning: "+hostname,
msg,
"warning",
)
}
func (w *Watcher) saveState() {
err := w.state.Save()
if err != nil {
w.log.Error("failed to save state", "error", err)
}
}
// maybeSendTestNotification sends a startup status notification
// after the first full scan completes, if SEND_TEST_NOTIFICATION
// is enabled. The message is clearly informational ("all ok")
// and not an error or anomaly alert.
func (w *Watcher) maybeSendTestNotification(ctx context.Context) {
if !w.config.SendTestNotification {
return
}
snap := w.state.GetSnapshot()
msg := fmt.Sprintf(
"dnswatcher has started and completed its initial scan.\n"+
"Monitoring %d domain(s) and %d hostname(s).\n"+
"Tracking %d port endpoint(s) and %d TLS certificate(s).\n"+
"All notification channels are working.",
len(snap.Domains),
len(snap.Hostnames),
len(snap.Ports),
len(snap.Certificates),
)
w.log.Info("sending startup test notification")
w.notify.SendNotification(
ctx,
"✅ dnswatcher startup complete",
msg,
"success",
)
}
// --- Utility functions ---
func toSet(items []string) map[string]bool {
set := make(map[string]bool, len(items))
for _, item := range items {
set[item] = true
}
return set
}
func recordsEqual(
a, b map[string][]string,
) bool {
if len(a) != len(b) {
return false
}
for k, av := range a {
bv, ok := b[k]
if !ok || !sliceEqual(av, bv) {
return false
}
}
return true
}
func sliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
aSorted := make([]string, len(a))
bSorted := make([]string, len(b))
copy(aSorted, a)
copy(bSorted, b)
sort.Strings(aSorted)
sort.Strings(bSorted)
for i := range aSorted {
if aSorted[i] != bSorted[i] {
return false
}
}
return true
}

View File

@@ -0,0 +1,904 @@
package watcher_test
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
"sneak.berlin/go/dnswatcher/internal/watcher"
)
// errNotFound is returned when mock data is missing.
var errNotFound = errors.New("not found")
// --- Mock implementations ---
type mockResolver struct {
mu sync.Mutex
nsRecords map[string][]string
allRecords map[string]map[string]map[string][]string
ipAddresses map[string][]string
lookupNSErr error
allRecordsErr error
resolveIPErr error
lookupNSCalls int
allRecordCalls int
}
func (m *mockResolver) LookupNS(
_ context.Context,
domain string,
) ([]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.lookupNSCalls++
if m.lookupNSErr != nil {
return nil, m.lookupNSErr
}
ns, ok := m.nsRecords[domain]
if !ok {
return nil, fmt.Errorf(
"%w: NS for %s", errNotFound, domain,
)
}
return ns, nil
}
func (m *mockResolver) LookupAllRecords(
_ context.Context,
hostname string,
) (map[string]map[string][]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.allRecordCalls++
if m.allRecordsErr != nil {
return nil, m.allRecordsErr
}
recs, ok := m.allRecords[hostname]
if !ok {
return nil, fmt.Errorf(
"%w: records for %s", errNotFound, hostname,
)
}
return recs, nil
}
func (m *mockResolver) ResolveIPAddresses(
_ context.Context,
hostname string,
) ([]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.resolveIPErr != nil {
return nil, m.resolveIPErr
}
ips, ok := m.ipAddresses[hostname]
if !ok {
return nil, fmt.Errorf(
"%w: IPs for %s", errNotFound, hostname,
)
}
return ips, nil
}
type mockPortChecker struct {
mu sync.Mutex
results map[string]bool
err error
calls int
}
func (m *mockPortChecker) CheckPort(
_ context.Context,
address string,
port int,
) (*portcheck.PortResult, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return nil, m.err
}
key := fmt.Sprintf("%s:%d", address, port)
open := m.results[key]
return &portcheck.PortResult{Open: open}, nil
}
type mockTLSChecker struct {
mu sync.Mutex
certs map[string]*tlscheck.CertificateInfo
err error
calls int
}
func (m *mockTLSChecker) CheckCertificate(
_ context.Context,
ip string,
hostname string,
) (*tlscheck.CertificateInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return nil, m.err
}
key := fmt.Sprintf("%s:%s", ip, hostname)
cert, ok := m.certs[key]
if !ok {
return nil, fmt.Errorf(
"%w: cert for %s", errNotFound, key,
)
}
return cert, nil
}
type notification struct {
Title string
Message string
Priority string
}
type mockNotifier struct {
mu sync.Mutex
notifications []notification
}
func (m *mockNotifier) SendNotification(
_ context.Context,
title, message, priority string,
) {
m.mu.Lock()
defer m.mu.Unlock()
m.notifications = append(m.notifications, notification{
Title: title,
Message: message,
Priority: priority,
})
}
func (m *mockNotifier) getNotifications() []notification {
m.mu.Lock()
defer m.mu.Unlock()
result := make([]notification, len(m.notifications))
copy(result, m.notifications)
return result
}
// --- Helper to build a Watcher for testing ---
type testDeps struct {
resolver *mockResolver
portChecker *mockPortChecker
tlsChecker *mockTLSChecker
notifier *mockNotifier
state *state.State
config *config.Config
}
func newTestWatcher(
t *testing.T,
cfg *config.Config,
) (*watcher.Watcher, *testDeps) {
t.Helper()
deps := &testDeps{
resolver: &mockResolver{
nsRecords: make(map[string][]string),
allRecords: make(map[string]map[string]map[string][]string),
ipAddresses: make(map[string][]string),
},
portChecker: &mockPortChecker{
results: make(map[string]bool),
},
tlsChecker: &mockTLSChecker{
certs: make(map[string]*tlscheck.CertificateInfo),
},
notifier: &mockNotifier{},
config: cfg,
}
deps.state = state.NewForTest()
w := watcher.NewForTest(
deps.config,
deps.state,
deps.resolver,
deps.portChecker,
deps.tlsChecker,
deps.notifier,
)
return w, deps
}
func defaultTestConfig(t *testing.T) *config.Config {
t.Helper()
return &config.Config{
DNSInterval: time.Hour,
TLSInterval: 12 * time.Hour,
TLSExpiryWarning: 7,
DataDir: t.TempDir(),
}
}
func TestFirstRunBaseline(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
setupBaselineMocks(deps)
w.RunOnce(t.Context())
assertNoNotifications(t, deps)
assertStatePopulated(t, deps)
}
func setupBaselineMocks(deps *testDeps) {
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns2.example.com.",
}
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
"ns2.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
"ns2.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.34",
}
deps.portChecker.results["93.184.216.34:80"] = true
deps.portChecker.results["93.184.216.34:443"] = true
deps.tlsChecker.certs["93.184.216.34:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
deps.tlsChecker.certs["93.184.216.34:example.com"] = &tlscheck.CertificateInfo{
CommonName: "example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"example.com",
},
}
}
func assertNoNotifications(
t *testing.T,
deps *testDeps,
) {
t.Helper()
notifications := deps.notifier.getNotifications()
if len(notifications) != 0 {
t.Errorf(
"expected 0 notifications on first run, got %d",
len(notifications),
)
}
}
func assertStatePopulated(
t *testing.T,
deps *testDeps,
) {
t.Helper()
snap := deps.state.GetSnapshot()
if len(snap.Domains) != 1 {
t.Errorf(
"expected 1 domain in state, got %d",
len(snap.Domains),
)
}
// Hostnames includes both explicit hostnames and domains
// (domains now also get hostname state for port/TLS checks).
if len(snap.Hostnames) < 1 {
t.Errorf(
"expected at least 1 hostname in state, got %d",
len(snap.Hostnames),
)
}
}
func TestDomainPortAndTLSChecks(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
}
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
}
deps.portChecker.results["93.184.216.34:80"] = true
deps.portChecker.results["93.184.216.34:443"] = true
deps.tlsChecker.certs["93.184.216.34:example.com"] = &tlscheck.CertificateInfo{
CommonName: "example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"example.com",
},
}
w.RunOnce(t.Context())
snap := deps.state.GetSnapshot()
// Domain should have port state populated
if len(snap.Ports) == 0 {
t.Error("expected port state for domain, got none")
}
// Domain should have certificate state populated
if len(snap.Certificates) == 0 {
t.Error("expected certificate state for domain, got none")
}
// Verify port checker was actually called
deps.portChecker.mu.Lock()
calls := deps.portChecker.calls
deps.portChecker.mu.Unlock()
if calls == 0 {
t.Error("expected port checker to be called for domain")
}
// Verify TLS checker was actually called
deps.tlsChecker.mu.Lock()
tlsCalls := deps.tlsChecker.calls
deps.tlsChecker.mu.Unlock()
if tlsCalls == 0 {
t.Error("expected TLS checker to be called for domain")
}
}
func TestNSChangeDetection(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns2.example.com.",
}
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
"ns2.example.com.": {"A": {"1.2.3.4"}},
}
deps.portChecker.results["1.2.3.4:80"] = false
deps.portChecker.results["1.2.3.4:443"] = false
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
"ns3.example.com.",
}
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
"ns3.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for NS change")
}
found := false
for _, n := range notifications {
if n.Priority == "warning" {
found = true
}
}
if !found {
t.Error("expected warning-priority NS change notification")
}
}
func TestRecordChangeDetection(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.34",
}
deps.portChecker.results["93.184.216.34:80"] = false
deps.portChecker.results["93.184.216.34:443"] = false
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.35"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"93.184.216.35",
}
deps.resolver.mu.Unlock()
deps.portChecker.mu.Lock()
deps.portChecker.results["93.184.216.35:80"] = false
deps.portChecker.results["93.184.216.35:443"] = false
deps.portChecker.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for record change")
}
}
func TestPortStateChange(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
w.RunOnce(ctx)
deps.portChecker.mu.Lock()
deps.portChecker.results["1.2.3.4:443"] = false
deps.portChecker.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for port state change")
}
}
func TestTLSExpiryWarning(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(3 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
// First run = baseline
w.RunOnce(ctx)
// Second run should warn about expiry
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
found := false
for _, n := range notifications {
if n.Priority == "warning" {
found = true
}
}
if !found {
t.Errorf(
"expected expiry warning, got: %v",
notifications,
)
}
}
func TestTLSExpiryWarningDedup(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
cfg.TLSInterval = 24 * time.Hour
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(3 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
// First run = baseline, no notifications
w.RunOnce(ctx)
// Second run should fire one expiry warning
w.RunOnce(ctx)
// Third run should NOT fire another warning (dedup)
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
expiryCount := 0
for _, n := range notifications {
if n.Title == "TLS Expiry Warning: www.example.com" {
expiryCount++
}
}
if expiryCount != 1 {
t.Errorf(
"expected exactly 1 expiry warning (dedup), got %d",
expiryCount,
)
}
}
func TestGracefulShutdown(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.DNSInterval = 100 * time.Millisecond
cfg.TLSInterval = 100 * time.Millisecond
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
}
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.portChecker.results["1.2.3.4:80"] = false
deps.portChecker.results["1.2.3.4:443"] = false
ctx, cancel := context.WithCancel(t.Context())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
time.Sleep(250 * time.Millisecond)
cancel()
select {
case <-done:
// Shut down cleanly
case <-time.After(5 * time.Second):
t.Error("watcher did not shut down within timeout")
}
}
func setupHostnameIP(
deps *testDeps,
hostname, ip string,
) {
deps.resolver.allRecords[hostname] = map[string]map[string][]string{
"ns1.example.com.": {"A": {ip}},
}
deps.portChecker.results[ip+":80"] = true
deps.portChecker.results[ip+":443"] = true
deps.tlsChecker.certs[ip+":"+hostname] = &tlscheck.CertificateInfo{
CommonName: hostname,
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{hostname},
}
}
func updateHostnameIP(deps *testDeps, hostname, ip string) {
deps.resolver.mu.Lock()
deps.resolver.allRecords[hostname] = map[string]map[string][]string{
"ns1.example.com.": {"A": {ip}},
}
deps.resolver.mu.Unlock()
deps.portChecker.mu.Lock()
deps.portChecker.results[ip+":80"] = true
deps.portChecker.results[ip+":443"] = true
deps.portChecker.mu.Unlock()
deps.tlsChecker.mu.Lock()
deps.tlsChecker.certs[ip+":"+hostname] = &tlscheck.CertificateInfo{
CommonName: hostname,
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{hostname},
}
deps.tlsChecker.mu.Unlock()
}
func TestDNSRunsBeforePortAndTLSChecks(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
setupHostnameIP(deps, "www.example.com", "10.0.0.1")
ctx := t.Context()
w.RunOnce(ctx)
snap := deps.state.GetSnapshot()
if _, ok := snap.Ports["10.0.0.1:80"]; !ok {
t.Fatal("expected port state for 10.0.0.1:80")
}
// DNS changes to a new IP; port and TLS must pick it up.
updateHostnameIP(deps, "www.example.com", "10.0.0.2")
w.RunOnce(ctx)
snap = deps.state.GetSnapshot()
if _, ok := snap.Ports["10.0.0.2:80"]; !ok {
t.Error("port check used stale DNS: missing 10.0.0.2:80")
}
certKey := "10.0.0.2:443:www.example.com"
if _, ok := snap.Certificates[certKey]; !ok {
t.Error("TLS check used stale DNS: missing " + certKey)
}
}
func TestSendTestNotification_Enabled(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.Hostnames = []string{"www.example.com"}
cfg.SendTestNotification = true
w, deps := newTestWatcher(t, cfg)
setupBaselineMocks(deps)
w.RunOnce(t.Context())
// RunOnce does not send the test notification — it is
// sent by Run after RunOnce completes. Call the exported
// RunOnce then check that no test notification was sent
// (only Run triggers it). We test the full path via Run.
notifications := deps.notifier.getNotifications()
if len(notifications) != 0 {
t.Errorf(
"RunOnce should not send test notification, got %d",
len(notifications),
)
}
}
func TestSendTestNotification_ViaRun(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.Hostnames = []string{"www.example.com"}
cfg.SendTestNotification = true
cfg.DNSInterval = 24 * time.Hour
cfg.TLSInterval = 24 * time.Hour
w, deps := newTestWatcher(t, cfg)
setupBaselineMocks(deps)
ctx, cancel := context.WithCancel(t.Context())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
// Wait for the initial scan and test notification.
time.Sleep(500 * time.Millisecond)
cancel()
<-done
notifications := deps.notifier.getNotifications()
found := false
for _, n := range notifications {
if n.Priority == "success" &&
n.Title == "✅ dnswatcher startup complete" {
found = true
}
}
if !found {
t.Errorf(
"expected startup test notification, got: %v",
notifications,
)
}
}
func TestSendTestNotification_Disabled(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
cfg.Hostnames = []string{"www.example.com"}
cfg.SendTestNotification = false
cfg.DNSInterval = 24 * time.Hour
cfg.TLSInterval = 24 * time.Hour
w, deps := newTestWatcher(t, cfg)
setupBaselineMocks(deps)
ctx, cancel := context.WithCancel(t.Context())
done := make(chan struct{})
go func() {
w.Run(ctx)
close(done)
}()
time.Sleep(500 * time.Millisecond)
cancel()
<-done
notifications := deps.notifier.getNotifications()
for _, n := range notifications {
if n.Title == "✅ dnswatcher startup complete" {
t.Error(
"test notification should not be sent when disabled",
)
}
}
}
func TestNSFailureAndRecovery(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
"ns2.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = false
deps.portChecker.results["1.2.3.4:443"] = false
ctx := t.Context()
w.RunOnce(ctx)
deps.resolver.mu.Lock()
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.mu.Unlock()
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
if len(notifications) == 0 {
t.Error("expected notification for NS disappearance")
}
}

1
static/css/tailwind.min.css vendored Normal file

File diff suppressed because one or more lines are too long

10
static/static.go Normal file
View File

@@ -0,0 +1,10 @@
// Package static provides embedded static assets.
package static
import "embed"
// Static contains the embedded static assets (CSS, JS) served
// at the /s/ URL prefix.
//
//go:embed css
var Static embed.FS