3 Commits

Author SHA1 Message Date
9a9a95581a Merge branch 'main' into feature/portcheck-implementation 2026-02-20 09:06:53 +01:00
clawbot
f0ea83179f fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:42:21 -08:00
clawbot
28f2d829ce feat: implement TCP port connectivity checker (closes #3) 2026-02-19 13:44:20 -08:00
32 changed files with 1491 additions and 3342 deletions

View File

@@ -1,6 +0,0 @@
.git/
bin/
*.md
LICENSE
.editorconfig
.gitignore

View File

@@ -1,12 +0,0 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[Makefile]
indent_style = tab

View File

@@ -1,9 +0,0 @@
name: check
on: [push]
jobs:
check:
runs-on: ubuntu-latest
steps:
# actions/checkout v4.2.2, 2026-02-28
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- run: docker build .

68
CLAUDE.md Normal file
View File

@@ -0,0 +1,68 @@
# Repository Rules
Last Updated 2026-01-08
These rules MUST be followed at all times, it is very important.
* Never use `git add -A` - add specific changes to a deliberate commit. A
commit should contain one change. After each change, make a commit with a
good one-line summary.
* NEVER modify the linter config without asking first.
* NEVER modify tests to exclude special cases or otherwise get them to pass
without asking first. In almost all cases, the code should be changed,
NOT the tests. If you think the test needs to be changed, make your case
for that and ask for permission to proceed, then stop. You need explicit
user approval to modify existing tests. (You do not need user approval
for writing NEW tests.)
* When linting, assume the linter config is CORRECT, and that each item
output by the linter is something that legitimately needs fixing in the
code.
* When running tests, use `make test`.
* Before commits, run `make check`. This runs `make lint` and `make test`
and `make check-fmt`. Any issues discovered MUST be resolved before
committing unless explicitly told otherwise.
* When fixing a bug, write a failing test for the bug FIRST. Add
appropriate logging to the test to ensure it is written correctly. Commit
that. Then go about fixing the bug until the test passes (without
modifying the test further). Then commit that.
* When adding a new feature, do the same - implement a test first (TDD). It
doesn't have to be super complex. Commit the test, then commit the
feature.
* When adding a new feature, use a feature branch. When the feature is
completely finished and the code is up to standards (passes `make check`)
then and only then can the feature branch be merged into `main` and the
branch deleted.
* Write godoc documentation comments for all exported types and functions as
you go along.
* ALWAYS be consistent in naming. If you name something one thing in one
place, name it the EXACT SAME THING in another place.
* Be descriptive and specific in naming. `wl` is bad;
`SourceHostWhitelist` is good. `ConnsPerHost` is bad;
`MaxConnectionsPerHost` is good.
* This is not prototype or teaching code - this is designed for production.
Any security issues (such as denial of service) or other web
vulnerabilities are P1 bugs and must be added to TODO.md at the top.
* As this is production code, no stubbing of implementations unless
specifically instructed. We need working implementations.
* Avoid vendoring deps unless specifically instructed to. NEVER commit
the vendor directory, NEVER commit compiled binaries. If these
directories or files exist, add them to .gitignore (and commit the
.gitignore) if they are not already in there. Keep the entire git
repository (with history) small - under 20MiB, unless you specifically
must commit larger files (e.g. test fixture example media files). Only
OUR source code and immediately supporting files (such as test examples)
goes into the repo/history.

1225
CONVENTIONS.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,11 @@
# Build stage # Build stage
# golang 1.25-alpine, 2026-02-28 FROM golang:1.25-alpine AS builder
FROM golang@sha256:f6751d823c26342f9506c03797d2527668d095b0a15f1862cddb4d927a7a4ced AS builder
RUN apk add --no-cache git make gcc musl-dev binutils-gold RUN apk add --no-cache git make gcc musl-dev
# golangci-lint v2.10.1 # Install golangci-lint v2
RUN go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@5d1e709b7be35cb2025444e19de266b056b7b7ee RUN go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
# goimports v0.42.0 RUN go install golang.org/x/tools/cmd/goimports@latest
RUN go install golang.org/x/tools/cmd/goimports@009367f5c17a8d4c45a961a3a509277190a9a6f0
WORKDIR /src WORKDIR /src
COPY go.mod go.sum ./ COPY go.mod go.sum ./
@@ -22,8 +20,7 @@ RUN make check
RUN make build RUN make build
# Runtime stage # Runtime stage
# alpine 3.21, 2026-02-28 FROM alpine:3.21
FROM alpine@sha256:c3f8e73fdb79deaebaa2037150150191b9dcbfba68b4a46d70103204c53f4709
RUN apk add --no-cache ca-certificates tzdata RUN apk add --no-cache ca-certificates tzdata

View File

@@ -1,4 +1,4 @@
.PHONY: all build lint fmt fmt-check test check clean hooks docker .PHONY: all build lint fmt test check clean
BINARY := dnswatcher BINARY := dnswatcher
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
@@ -17,11 +17,8 @@ fmt:
gofmt -s -w . gofmt -s -w .
goimports -w . goimports -w .
fmt-check:
@test -z "$$(gofmt -l .)" || (echo "gofmt: files not formatted:" && gofmt -l . && exit 1)
test: test:
go test -v -race -timeout 30s -cover ./... go test -v -race -cover ./...
# Check runs all validation without making changes # Check runs all validation without making changes
# Used by CI and Docker build - fails if anything is wrong # Used by CI and Docker build - fails if anything is wrong
@@ -31,19 +28,10 @@ check:
@echo "==> Running linter..." @echo "==> Running linter..."
golangci-lint run --config .golangci.yml ./... golangci-lint run --config .golangci.yml ./...
@echo "==> Running tests..." @echo "==> Running tests..."
go test -v -race -timeout 30s ./... go test -v -race ./...
@echo "==> Building..." @echo "==> Building..."
go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/dnswatcher go build -ldflags "$(LDFLAGS)" -o /dev/null ./cmd/dnswatcher
@echo "==> All checks passed!" @echo "==> All checks passed!"
clean: clean:
rm -rf bin/ rm -rf bin/
hooks:
@echo '#!/bin/sh' > .git/hooks/pre-commit
@echo 'make check' >> .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit
@echo "Pre-commit hook installed."
docker:
docker build .

View File

@@ -1,10 +1,9 @@
# dnswatcher # dnswatcher
dnswatcher is a pre-1.0 Go daemon by [@sneak](https://sneak.berlin) that monitors DNS records, TCP port availability, and TLS certificates, delivering real-time change notifications via Slack, Mattermost, and ntfy webhooks.
> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice. > ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice.
dnswatcher watches configured DNS domains and hostnames for changes, monitors TCP dnswatcher is a production DNS and infrastructure monitoring daemon written in
Go. It watches configured DNS domains and hostnames for changes, monitors TCP
port availability, tracks TLS certificate expiry, and delivers real-time port availability, tracks TLS certificate expiry, and delivers real-time
notifications via Slack, Mattermost, and/or ntfy webhooks. notifications via Slack, Mattermost, and/or ntfy webhooks.
@@ -110,8 +109,8 @@ includes:
- **NS recoveries**: Which nameserver recovered, which hostname/domain. - **NS recoveries**: Which nameserver recovered, which hostname/domain.
- **NS inconsistencies**: Which nameservers disagree, what each one - **NS inconsistencies**: Which nameservers disagree, what each one
returned, which hostname affected. returned, which hostname affected.
- **Port changes**: Which IP:port, old state, new state, all associated - **Port changes**: Which IP:port, old state, new state, associated
hostnames. hostname.
- **TLS expiry warnings**: Which certificate, days remaining, CN, - **TLS expiry warnings**: Which certificate, days remaining, CN,
issuer, associated hostname and IP. issuer, associated hostname and IP.
- **TLS certificate changes**: Old and new CN/issuer/SANs, associated - **TLS certificate changes**: Old and new CN/issuer/SANs, associated
@@ -290,12 +289,12 @@ not as a merged view, to enable inconsistency detection.
"ports": { "ports": {
"93.184.216.34:80": { "93.184.216.34:80": {
"open": true, "open": true,
"hostnames": ["www.example.com"], "hostname": "www.example.com",
"lastChecked": "2026-02-19T12:00:00Z" "lastChecked": "2026-02-19T12:00:00Z"
}, },
"93.184.216.34:443": { "93.184.216.34:443": {
"open": true, "open": true,
"hostnames": ["www.example.com"], "hostname": "www.example.com",
"lastChecked": "2026-02-19T12:00:00Z" "lastChecked": "2026-02-19T12:00:00Z"
} }
}, },
@@ -367,15 +366,9 @@ docker run -d \
triggering change notifications). triggering change notifications).
2. **Initial check**: Immediately perform all DNS, port, and TLS checks 2. **Initial check**: Immediately perform all DNS, port, and TLS checks
on startup. on startup.
3. **Periodic checks** (DNS always runs first): 3. **Periodic checks**:
- DNS checks: every `DNSWATCHER_DNS_INTERVAL` (default 1h). Also - DNS and port checks: every `DNSWATCHER_DNS_INTERVAL` (default 1h).
re-run before every TLS check cycle to ensure fresh IPs. - TLS checks: every `DNSWATCHER_TLS_INTERVAL` (default 12h).
- Port checks: every `DNSWATCHER_DNS_INTERVAL`, after DNS completes.
- TLS checks: every `DNSWATCHER_TLS_INTERVAL` (default 12h), after
DNS completes.
- Port and TLS checks always use freshly resolved IP addresses from
the DNS phase that immediately precedes them — never stale IPs
from a previous cycle.
4. **On change detection**: Send notifications to all configured 4. **On change detection**: Send notifications to all configured
endpoints, update in-memory state, persist to disk. endpoints, update in-memory state, persist to disk.
5. **Shutdown**: Persist final state to disk, complete in-flight 5. **Shutdown**: Persist final state to disk, complete in-flight
@@ -383,27 +376,9 @@ docker run -d \
--- ---
## Planned Future Features (Post-1.0)
- **DNSSEC validation**: Validate the DNSSEC chain of trust during
iterative resolution and report DNSSEC failures as notifications.
---
## Project Structure ## Project Structure
Follows the conventions defined in `REPO_POLICIES.md`, adapted from the Follows the conventions defined in `CONVENTIONS.md`, adapted from the
[upaas](https://git.eeqj.de/sneak/upaas) project template. Uses uber/fx [upaas](https://git.eeqj.de/sneak/upaas) project template. Uses uber/fx
for dependency injection, go-chi for HTTP routing, slog for logging, and for dependency injection, go-chi for HTTP routing, slog for logging, and
Viper for configuration. Viper for configuration.
---
## License
License has not yet been chosen for this project. Pending decision by the
author (MIT, GPL, or WTFPL).
## Author
[@sneak](https://sneak.berlin)

View File

@@ -1,188 +0,0 @@
---
title: Repository Policies
last_modified: 2026-02-22
---
This document covers repository structure, tooling, and workflow standards. Code
style conventions are in separate documents:
- [Code Styleguide](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE.md)
(general, bash, Docker)
- [Go](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_GO.md)
- [JavaScript](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_JS.md)
- [Python](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/CODE_STYLEGUIDE_PYTHON.md)
- [Go HTTP Server Conventions](https://git.eeqj.de/sneak/prompts/raw/branch/main/prompts/GO_HTTP_SERVER_CONVENTIONS.md)
---
- Cross-project documentation (such as this file) must include
`last_modified: YYYY-MM-DD` in the YAML front matter so it can be kept in sync
with the authoritative source as policies evolve.
- **ALL external references must be pinned by cryptographic hash.** This
includes Docker base images, Go modules, npm packages, GitHub Actions, and
anything else fetched from a remote source. Version tags (`@v4`, `@latest`,
`:3.21`, etc.) are server-mutable and therefore remote code execution
vulnerabilities. The ONLY acceptable way to reference an external dependency
is by its content hash (Docker `@sha256:...`, Go module hash in `go.sum`, npm
integrity hash in lockfile, GitHub Actions `@<commit-sha>`). No exceptions.
This also means never `curl | bash` to install tools like pyenv, nvm, rustup,
etc. Instead, download a specific release archive from GitHub, verify its hash
(hardcoded in the Dockerfile or script), and only then install. Unverified
install scripts are arbitrary remote code execution. This is the single most
important rule in this document. Double-check every external reference in
every file before committing. There are zero exceptions to this rule.
- Every repo with software must have a root `Makefile` with these targets:
`make test`, `make lint`, `make fmt` (writes), `make fmt-check` (read-only),
`make check` (prereqs: `test`, `lint`, `fmt-check`), `make docker`, and
`make hooks` (installs pre-commit hook). A model Makefile is at
`https://git.eeqj.de/sneak/prompts/raw/branch/main/Makefile`.
- Always use Makefile targets (`make fmt`, `make test`, `make lint`, etc.)
instead of invoking the underlying tools directly. The Makefile is the single
source of truth for how these operations are run.
- The Makefile is authoritative documentation for how the repo is used. Beyond
the required targets above, it should have targets for every common operation:
running a local development server (`make run`, `make dev`), re-initializing
or migrating the database (`make db-reset`, `make migrate`), building
artifacts (`make build`), generating code, seeding data, or anything else a
developer would do regularly. If someone checks out the repo and types
`make<tab>`, they should see every meaningful operation available. A new
contributor should be able to understand the entire development workflow by
reading the Makefile.
- Every repo should have a `Dockerfile`. All Dockerfiles must run `make check`
as a build step so the build fails if the branch is not green. For non-server
repos, the Dockerfile should bring up a development environment and run
`make check`. For server repos, `make check` should run as an early build
stage before the final image is assembled.
- Every repo should have a Gitea Actions workflow (`.gitea/workflows/`) that
runs `docker build .` on push. Since the Dockerfile already runs `make check`,
a successful build implies all checks pass.
- Use platform-standard formatters: `black` for Python, `prettier` for
JS/CSS/Markdown/HTML, `go fmt` for Go. Always use default configuration with
two exceptions: four-space indents (except Go), and `proseWrap: always` for
Markdown (hard-wrap at 80 columns). Documentation and writing repos (Markdown,
HTML, CSS) should also have `.prettierrc` and `.prettierignore`.
- Pre-commit hook: `make check` if local testing is possible, otherwise
`make lint && make fmt-check`. The Makefile should provide a `make hooks`
target to install the pre-commit hook.
- All repos with software must have tests that run via the platform-standard
test framework (`go test`, `pytest`, `jest`/`vitest`, etc.). If no meaningful
tests exist yet, add the most minimal test possible — e.g. importing the
module under test to verify it compiles/parses. There is no excuse for
`make test` to be a no-op.
- `make test` must complete in under 20 seconds. Add a 30-second timeout in the
Makefile.
- Docker builds must complete in under 5 minutes.
- `make check` must not modify any files in the repo. Tests may use temporary
directories.
- `main` must always pass `make check`, no exceptions.
- Never commit secrets. `.env` files, credentials, API keys, and private keys
must be in `.gitignore`. No exceptions.
- `.gitignore` should be comprehensive from the start: OS files (`.DS_Store`),
editor files (`.swp`, `*~`), language build artifacts, and `node_modules/`.
Fetch the standard `.gitignore` from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.gitignore` when setting up
a new repo.
- Never use `git add -A` or `git add .`. Always stage files explicitly by name.
- Never force-push to `main`.
- Make all changes on a feature branch. You can do whatever you want on a
feature branch.
- `.golangci.yml` is standardized and must _NEVER_ be modified by an agent, only
manually by the user. Fetch from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/.golangci.yml`.
- When pinning images or packages by hash, add a comment above the reference
with the version and date (YYYY-MM-DD).
- Use `yarn`, not `npm`.
- Write all dates as YYYY-MM-DD (ISO 8601).
- Simple projects should be configured with environment variables.
- Dockerized web services listen on port 8080 by default, overridable with
`PORT`.
- `README.md` is the primary documentation. Required sections:
- **Description**: First line must include the project name, purpose,
category (web server, SPA, CLI tool, etc.), license, and author. Example:
"µPaaS is an MIT-licensed Go web application by @sneak that receives
git-frontend webhooks and deploys applications via Docker in realtime."
- **Getting Started**: Copy-pasteable install/usage code block.
- **Rationale**: Why does this exist?
- **Design**: How is the program structured?
- **TODO**: Update meticulously, even between commits. When planning, put
the todo list in the README so a new agent can pick up where the last one
left off.
- **License**: MIT, GPL, or WTFPL. Ask the user for new projects. Include a
`LICENSE` file in the repo root and a License section in the README.
- **Author**: [@sneak](https://sneak.berlin).
- First commit of a new repo should contain only `README.md`.
- Go module root: `sneak.berlin/go/<name>`. Always run `go mod tidy` before
committing.
- Use SemVer.
- Database migrations live in `internal/db/migrations/` and must be embedded in
the binary.
- `000_migration.sql` — contains ONLY the creation of the migrations tracking
table itself. Nothing else.
- `001_schema.sql` — the full application schema.
- **Pre-1.0.0:** never add additional migration files (002, 003, etc.). There
is no installed base to migrate. Edit `001_schema.sql` directly.
- **Post-1.0.0:** add new numbered migration files for each schema change.
Never edit existing migrations after release.
- All repos should have an `.editorconfig` enforcing the project's indentation
settings.
- Avoid putting files in the repo root unless necessary. Root should contain
only project-level config files (`README.md`, `Makefile`, `Dockerfile`,
`LICENSE`, `.gitignore`, `.editorconfig`, `REPO_POLICIES.md`, and
language-specific config). Everything else goes in a subdirectory. Canonical
subdirectory names:
- `bin/` — executable scripts and tools
- `cmd/` — Go command entrypoints
- `configs/` — configuration templates and examples
- `deploy/` — deployment manifests (k8s, compose, terraform)
- `docs/` — documentation and markdown (README.md stays in root)
- `internal/` — Go internal packages
- `internal/db/migrations/` — database migrations
- `pkg/` — Go library packages
- `share/` — systemd units, data files
- `static/` — static assets (images, fonts, etc.)
- `web/` — web frontend source
- When setting up a new repo, files from the `prompts` repo may be used as
templates. Fetch them from
`https://git.eeqj.de/sneak/prompts/raw/branch/main/<path>`.
- New repos must contain at minimum:
- `README.md`, `.git`, `.gitignore`, `.editorconfig`
- `LICENSE`, `REPO_POLICIES.md` (copy from the `prompts` repo)
- `Makefile`
- `Dockerfile`, `.dockerignore`
- `.gitea/workflows/check.yml`
- Go: `go.mod`, `go.sum`, `.golangci.yml`
- JS: `package.json`, `yarn.lock`, `.prettierrc`, `.prettierignore`
- Python: `pyproject.toml`

View File

@@ -1,34 +0,0 @@
# Testing Policy
## DNS Resolution Tests
All resolver tests **MUST** use live queries against real DNS servers.
No mocking of the DNS client layer is permitted.
### Rationale
The resolver performs iterative resolution from root nameservers through
the full delegation chain. Mocked responses cannot faithfully represent
the variety of real-world DNS behavior (truncation, referrals, glue
records, DNSSEC, varied response times, EDNS, etc.). Testing against
real servers ensures the resolver works correctly in production.
### Constraints
- Tests hit real DNS infrastructure and require network access
- Test duration depends on network conditions; timeout tuning keeps
the suite within the 30-second target
- Query timeout is calibrated to 3× maximum antipodal RTT (~300ms)
plus processing margin
- Root server fan-out is limited to reduce parallel query load
- Flaky failures from transient network issues are acceptable and
should be investigated as potential resolver bugs, not papered over
with mocks or skip flags
### What NOT to do
- **Do not mock `DNSClient`** for resolver tests (the mock constructor
exists for unit-testing other packages that consume the resolver)
- **Do not add `-short` flags** to skip slow tests
- **Do not increase `-timeout`** to hide hanging queries
- **Do not modify linter configuration** to suppress findings

9
go.mod
View File

@@ -7,24 +7,19 @@ require (
github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/cors v1.2.2 github.com/go-chi/cors v1.2.2
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/miekg/dns v1.1.72
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0 golang.org/x/net v0.50.0
golang.org/x/sync v0.19.0
) )
require ( require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect
@@ -39,11 +34,7 @@ require (
go.uber.org/zap v1.26.0 // indirect go.uber.org/zap v1.26.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

8
go.sum
View File

@@ -28,8 +28,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
@@ -76,18 +74,12 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -1,60 +0,0 @@
package handlers
import (
"net/http"
"time"
)
// domainResponse represents a single domain in the API response.
type domainResponse struct {
Domain string `json:"domain"`
Nameservers []string `json:"nameservers,omitempty"`
LastChecked string `json:"lastChecked,omitempty"`
Status string `json:"status"`
}
// domainsResponse is the top-level response for GET /api/v1/domains.
type domainsResponse struct {
Domains []domainResponse `json:"domains"`
}
// HandleDomains returns the configured domains and their status.
func (h *Handlers) HandleDomains() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
configured := h.config.Domains
snapshot := h.state.GetSnapshot()
domains := make(
[]domainResponse, 0, len(configured),
)
for _, domain := range configured {
dr := domainResponse{
Domain: domain,
Status: "pending",
}
ds, ok := snapshot.Domains[domain]
if ok {
dr.Nameservers = ds.Nameservers
dr.Status = "ok"
if !ds.LastChecked.IsZero() {
dr.LastChecked = ds.LastChecked.
Format(time.RFC3339)
}
}
domains = append(domains, dr)
}
h.respondJSON(
writer, request,
&domainsResponse{Domains: domains},
http.StatusOK,
)
}
}

View File

@@ -8,11 +8,9 @@ import (
"go.uber.org/fx" "go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/globals" "sneak.berlin/go/dnswatcher/internal/globals"
"sneak.berlin/go/dnswatcher/internal/healthcheck" "sneak.berlin/go/dnswatcher/internal/healthcheck"
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
"sneak.berlin/go/dnswatcher/internal/state"
) )
// Params contains dependencies for Handlers. // Params contains dependencies for Handlers.
@@ -22,8 +20,6 @@ type Params struct {
Logger *logger.Logger Logger *logger.Logger
Globals *globals.Globals Globals *globals.Globals
Healthcheck *healthcheck.Healthcheck Healthcheck *healthcheck.Healthcheck
State *state.State
Config *config.Config
} }
// Handlers provides HTTP request handlers. // Handlers provides HTTP request handlers.
@@ -32,8 +28,6 @@ type Handlers struct {
params *Params params *Params
globals *globals.Globals globals *globals.Globals
hc *healthcheck.Healthcheck hc *healthcheck.Healthcheck
state *state.State
config *config.Config
} }
// New creates a new Handlers instance. // New creates a new Handlers instance.
@@ -43,8 +37,6 @@ func New(_ fx.Lifecycle, params Params) (*Handlers, error) {
params: &params, params: &params,
globals: params.Globals, globals: params.Globals,
hc: params.Healthcheck, hc: params.Healthcheck,
state: params.State,
config: params.Config,
}, nil }, nil
} }
@@ -52,7 +44,7 @@ func (h *Handlers) respondJSON(
writer http.ResponseWriter, writer http.ResponseWriter,
_ *http.Request, _ *http.Request,
data any, data any,
status int, //nolint:unparam // general-purpose utility; status varies in future use status int,
) { ) {
writer.Header().Set("Content-Type", "application/json") writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(status) writer.WriteHeader(status)

View File

@@ -1,120 +0,0 @@
package handlers
import (
"net/http"
"sort"
"time"
"sneak.berlin/go/dnswatcher/internal/state"
)
// nameserverRecordResponse represents one nameserver's records
// in the API response.
type nameserverRecordResponse struct {
Nameserver string `json:"nameserver"`
Records map[string][]string `json:"records"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
LastChecked string `json:"lastChecked,omitempty"`
}
// hostnameResponse represents a single hostname in the API response.
type hostnameResponse struct {
Hostname string `json:"hostname"`
Nameservers []nameserverRecordResponse `json:"nameservers,omitempty"`
LastChecked string `json:"lastChecked,omitempty"`
Status string `json:"status"`
}
// hostnamesResponse is the top-level response for
// GET /api/v1/hostnames.
type hostnamesResponse struct {
Hostnames []hostnameResponse `json:"hostnames"`
}
// HandleHostnames returns the configured hostnames and their status.
func (h *Handlers) HandleHostnames() http.HandlerFunc {
return func(
writer http.ResponseWriter,
request *http.Request,
) {
configured := h.config.Hostnames
snapshot := h.state.GetSnapshot()
hostnames := make(
[]hostnameResponse, 0, len(configured),
)
for _, hostname := range configured {
hr := hostnameResponse{
Hostname: hostname,
Status: "pending",
}
hs, ok := snapshot.Hostnames[hostname]
if ok {
hr.Status = "ok"
if !hs.LastChecked.IsZero() {
hr.LastChecked = hs.LastChecked.
Format(time.RFC3339)
}
hr.Nameservers = buildNameserverRecords(
hs,
)
}
hostnames = append(hostnames, hr)
}
h.respondJSON(
writer, request,
&hostnamesResponse{Hostnames: hostnames},
http.StatusOK,
)
}
}
// buildNameserverRecords converts the per-nameserver state map
// into a sorted slice for deterministic JSON output.
func buildNameserverRecords(
hs *state.HostnameState,
) []nameserverRecordResponse {
if hs.RecordsByNameserver == nil {
return nil
}
nsNames := make(
[]string, 0, len(hs.RecordsByNameserver),
)
for ns := range hs.RecordsByNameserver {
nsNames = append(nsNames, ns)
}
sort.Strings(nsNames)
records := make(
[]nameserverRecordResponse, 0, len(nsNames),
)
for _, ns := range nsNames {
nsr := hs.RecordsByNameserver[ns]
entry := nameserverRecordResponse{
Nameserver: ns,
Records: nsr.Records,
Status: nsr.Status,
Error: nsr.Error,
}
if !nsr.LastChecked.IsZero() {
entry.LastChecked = nsr.LastChecked.
Format(time.RFC3339)
}
records = append(records, entry)
}
return records
}

View File

@@ -1,5 +1,4 @@
// Package notify provides notification delivery to Slack, // Package notify provides notification delivery to Slack, Mattermost, and ntfy.
// Mattermost, and ntfy.
package notify package notify
import ( import (
@@ -8,7 +7,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url" "net/url"
@@ -36,66 +34,8 @@ var (
ErrMattermostFailed = errors.New( ErrMattermostFailed = errors.New(
"mattermost notification failed", "mattermost notification failed",
) )
// ErrInvalidScheme is returned for disallowed URL schemes.
ErrInvalidScheme = errors.New("URL scheme not allowed")
// ErrMissingHost is returned when a URL has no host.
ErrMissingHost = errors.New("URL must have a host")
) )
// IsAllowedScheme checks if the URL scheme is permitted.
func IsAllowedScheme(scheme string) bool {
return scheme == "https" || scheme == "http"
}
// ValidateWebhookURL validates and sanitizes a webhook URL.
// It ensures the URL has an allowed scheme (http/https),
// a non-empty host, and returns a pre-parsed *url.URL
// reconstructed from validated components.
func ValidateWebhookURL(raw string) (*url.URL, error) {
u, err := url.ParseRequestURI(raw)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if !IsAllowedScheme(u.Scheme) {
return nil, fmt.Errorf(
"%w: %s", ErrInvalidScheme, u.Scheme,
)
}
if u.Host == "" {
return nil, fmt.Errorf("%w", ErrMissingHost)
}
// Reconstruct from parsed components.
clean := &url.URL{
Scheme: u.Scheme,
Host: u.Host,
Path: u.Path,
RawQuery: u.RawQuery,
}
return clean, nil
}
// newRequest creates an http.Request from a pre-validated *url.URL.
// This avoids passing URL strings to http.NewRequestWithContext,
// which gosec flags as a potential SSRF vector.
func newRequest(
ctx context.Context,
method string,
target *url.URL,
body io.Reader,
) *http.Request {
return (&http.Request{
Method: method,
URL: target,
Host: target.Host,
Header: make(http.Header),
Body: io.NopCloser(body),
}).WithContext(ctx)
}
// Params contains dependencies for Service. // Params contains dependencies for Service.
type Params struct { type Params struct {
fx.In fx.In
@@ -107,7 +47,7 @@ type Params struct {
// Service provides notification functionality. // Service provides notification functionality.
type Service struct { type Service struct {
log *slog.Logger log *slog.Logger
transport http.RoundTripper client *http.Client
config *config.Config config *config.Config
ntfyURL *url.URL ntfyURL *url.URL
slackWebhookURL *url.URL slackWebhookURL *url.URL
@@ -121,40 +61,32 @@ func New(
) (*Service, error) { ) (*Service, error) {
svc := &Service{ svc := &Service{
log: params.Logger.Get(), log: params.Logger.Get(),
transport: http.DefaultTransport, client: &http.Client{
Timeout: httpClientTimeout,
},
config: params.Config, config: params.Config,
} }
if params.Config.NtfyTopic != "" { if params.Config.NtfyTopic != "" {
u, err := ValidateWebhookURL( u, err := url.ParseRequestURI(params.Config.NtfyTopic)
params.Config.NtfyTopic,
)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
"invalid ntfy topic URL: %w", err,
)
} }
svc.ntfyURL = u svc.ntfyURL = u
} }
if params.Config.SlackWebhook != "" { if params.Config.SlackWebhook != "" {
u, err := ValidateWebhookURL( u, err := url.ParseRequestURI(params.Config.SlackWebhook)
params.Config.SlackWebhook,
)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
"invalid slack webhook URL: %w", err,
)
} }
svc.slackWebhookURL = u svc.slackWebhookURL = u
} }
if params.Config.MattermostWebhook != "" { if params.Config.MattermostWebhook != "" {
u, err := ValidateWebhookURL( u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
params.Config.MattermostWebhook,
)
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err, "invalid mattermost webhook URL: %w", err,
@@ -167,8 +99,7 @@ func New(
return svc, nil return svc, nil
} }
// SendNotification sends a notification to all configured // SendNotification sends a notification to all configured endpoints.
// endpoints.
func (svc *Service) SendNotification( func (svc *Service) SendNotification(
ctx context.Context, ctx context.Context,
title, message, priority string, title, message, priority string,
@@ -239,20 +170,20 @@ func (svc *Service) sendNtfy(
"title", title, "title", title,
) )
ctx, cancel := context.WithTimeout( request, err := http.NewRequestWithContext(
ctx, httpClientTimeout, ctx,
) http.MethodPost,
defer cancel() topicURL.String(),
bytes.NewBufferString(message),
body := bytes.NewBufferString(message)
request := newRequest(
ctx, http.MethodPost, topicURL, body,
) )
if err != nil {
return fmt.Errorf("creating ntfy request: %w", err)
}
request.Header.Set("Title", title) request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority)) request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.transport.RoundTrip(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending ntfy request: %w", err) return fmt.Errorf("sending ntfy request: %w", err)
} }
@@ -261,8 +192,7 @@ func (svc *Service) sendNtfy(
if resp.StatusCode >= httpStatusClientError { if resp.StatusCode >= httpStatusClientError {
return fmt.Errorf( return fmt.Errorf(
"%w: status %d", "%w: status %d", ErrNtfyFailed, resp.StatusCode,
ErrNtfyFailed, resp.StatusCode,
) )
} }
@@ -302,11 +232,6 @@ func (svc *Service) sendSlack(
webhookURL *url.URL, webhookURL *url.URL,
title, message, priority string, title, message, priority string,
) error { ) error {
ctx, cancel := context.WithTimeout(
ctx, httpClientTimeout,
)
defer cancel()
svc.log.Debug( svc.log.Debug(
"sending webhook notification", "sending webhook notification",
"url", webhookURL.String(), "url", webhookURL.String(),
@@ -325,19 +250,22 @@ func (svc *Service) sendSlack(
body, err := json.Marshal(payload) body, err := json.Marshal(payload)
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf("marshaling webhook payload: %w", err)
"marshaling webhook payload: %w", err,
)
} }
request := newRequest( request, err := http.NewRequestWithContext(
ctx, http.MethodPost, webhookURL, ctx,
http.MethodPost,
webhookURL.String(),
bytes.NewBuffer(body), bytes.NewBuffer(body),
) )
if err != nil {
return fmt.Errorf("creating webhook request: %w", err)
}
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
resp, err := svc.transport.RoundTrip(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending webhook request: %w", err) return fmt.Errorf("sending webhook request: %w", err)
} }

View File

@@ -1,100 +0,0 @@
package notify_test
import (
"testing"
"sneak.berlin/go/dnswatcher/internal/notify"
)
func TestValidateWebhookURLValid(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantURL string
}{
{
name: "valid https URL",
input: "https://hooks.slack.com/T00/B00",
wantURL: "https://hooks.slack.com/T00/B00",
},
{
name: "valid http URL",
input: "http://localhost:8080/webhook",
wantURL: "http://localhost:8080/webhook",
},
{
name: "https with query",
input: "https://ntfy.sh/topic?auth=tok",
wantURL: "https://ntfy.sh/topic?auth=tok",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := notify.ValidateWebhookURL(tt.input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got.String() != tt.wantURL {
t.Errorf(
"got %q, want %q",
got.String(), tt.wantURL,
)
}
})
}
}
func TestValidateWebhookURLInvalid(t *testing.T) {
t.Parallel()
invalid := []struct {
name string
input string
}{
{"ftp scheme", "ftp://example.com/file"},
{"file scheme", "file:///etc/passwd"},
{"empty string", ""},
{"no scheme", "example.com/webhook"},
{"no host", "https:///path"},
}
for _, tt := range invalid {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := notify.ValidateWebhookURL(tt.input)
if err == nil {
t.Errorf(
"expected error for %q, got %v",
tt.input, got,
)
}
})
}
}
func TestIsAllowedScheme(t *testing.T) {
t.Parallel()
if !notify.IsAllowedScheme("https") {
t.Error("https should be allowed")
}
if !notify.IsAllowedScheme("http") {
t.Error("http should be allowed")
}
if notify.IsAllowedScheme("ftp") {
t.Error("ftp should not be allowed")
}
if notify.IsAllowedScheme("") {
t.Error("empty scheme should not be allowed")
}
}

View File

@@ -3,29 +3,18 @@ package portcheck
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
"strconv" "strconv"
"sync"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
"golang.org/x/sync/errgroup"
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
const ( const defaultTimeout = 5 * time.Second
minPort = 1
maxPort = 65535
defaultTimeout = 5 * time.Second
)
// ErrInvalidPort is returned when a port number is outside
// the valid TCP range (165535).
var ErrInvalidPort = errors.New("invalid port number")
// PortResult holds the outcome of a single TCP port check. // PortResult holds the outcome of a single TCP port check.
type PortResult struct { type PortResult struct {
@@ -66,19 +55,6 @@ func NewStandalone() *Checker {
} }
} }
// validatePort checks that a port number is within the valid
// TCP port range (165535).
func validatePort(port int) error {
if port < minPort || port > maxPort {
return fmt.Errorf(
"%w: %d (must be between %d and %d)",
ErrInvalidPort, port, minPort, maxPort,
)
}
return nil
}
// CheckPort tests TCP connectivity to the given address and port. // CheckPort tests TCP connectivity to the given address and port.
// It uses a 5-second timeout unless the context has an earlier // It uses a 5-second timeout unless the context has an earlier
// deadline. // deadline.
@@ -87,18 +63,13 @@ func (c *Checker) CheckPort(
address string, address string,
port int, port int,
) (*PortResult, error) { ) (*PortResult, error) {
err := validatePort(port)
if err != nil {
return nil, err
}
target := net.JoinHostPort( target := net.JoinHostPort(
address, strconv.Itoa(port), address, strconv.Itoa(port),
) )
deadline, hasDeadline := ctx.Deadline()
timeout := defaultTimeout timeout := defaultTimeout
deadline, hasDeadline := ctx.Deadline()
if hasDeadline { if hasDeadline {
remaining := time.Until(deadline) remaining := time.Until(deadline)
if remaining < timeout { if remaining < timeout {
@@ -106,62 +77,8 @@ func (c *Checker) CheckPort(
} }
} }
return c.checkConnection(ctx, target, timeout), nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address concurrently. It returns a map of port number
// to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
for _, port := range ports {
err := validatePort(port)
if err != nil {
return nil, err
}
}
var mu sync.Mutex
results := make(map[int]*PortResult, len(ports))
g, ctx := errgroup.WithContext(ctx)
for _, port := range ports {
g.Go(func() error {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return fmt.Errorf(
"checking port %d: %w", port, err,
)
}
mu.Lock()
results[port] = result
mu.Unlock()
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return results, nil
}
// checkConnection performs the TCP dial and returns a result.
func (c *Checker) checkConnection(
ctx context.Context,
target string,
timeout time.Duration,
) *PortResult {
dialer := &net.Dialer{Timeout: timeout} dialer := &net.Dialer{Timeout: timeout}
start := time.Now() start := time.Now()
conn, dialErr := dialer.DialContext(ctx, "tcp", target) conn, dialErr := dialer.DialContext(ctx, "tcp", target)
@@ -178,7 +95,7 @@ func (c *Checker) checkConnection(
Open: false, Open: false,
Error: dialErr.Error(), Error: dialErr.Error(),
Latency: latency, Latency: latency,
} }, nil
} }
closeErr := conn.Close() closeErr := conn.Close()
@@ -199,5 +116,28 @@ func (c *Checker) checkConnection(
return &PortResult{ return &PortResult{
Open: true, Open: true,
Latency: latency, Latency: latency,
} }, nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address. It returns a map of port number to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
results := make(map[int]*PortResult, len(ports))
for _, port := range ports {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return nil, fmt.Errorf(
"checking port %d: %w", port, err,
)
}
results[port] = result
}
return results, nil
} }

View File

@@ -138,54 +138,6 @@ func TestCheckPortsMultiple(t *testing.T) {
} }
} }
func TestCheckPortInvalidPorts(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
cases := []struct {
name string
port int
}{
{"zero", 0},
{"negative", -1},
{"too high", 65536},
{"very negative", -1000},
{"very high", 100000},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := checker.CheckPort(
context.Background(), "127.0.0.1", tc.port,
)
if err == nil {
t.Errorf(
"expected error for port %d, got nil",
tc.port,
)
}
})
}
}
func TestCheckPortsInvalidPort(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
_, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{80, 0, 443},
)
if err == nil {
t.Error("expected error for invalid port in list")
}
}
func TestCheckPortLatencyReasonable(t *testing.T) { func TestCheckPortLatencyReasonable(t *testing.T) {
t.Parallel() t.Parallel()

View File

@@ -1,48 +0,0 @@
package resolver
import (
"context"
"time"
"github.com/miekg/dns"
)
// DNSClient abstracts DNS wire-protocol exchanges so the resolver
// can be tested without hitting real nameservers.
type DNSClient interface {
ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error)
}
// udpClient wraps a real dns.Client for production use.
type udpClient struct {
timeout time.Duration
}
func (c *udpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}
// tcpClient wraps a real dns.Client using TCP.
type tcpClient struct {
timeout time.Duration
}
func (c *tcpClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
cl := &dns.Client{Net: "tcp", Timeout: c.timeout}
return cl.ExchangeContext(ctx, msg, addr)
}

View File

@@ -1,22 +0,0 @@
package resolver
import "errors"
// Sentinel errors returned by the resolver.
var (
// ErrNoNameservers is returned when no authoritative NS
// could be discovered for a domain.
ErrNoNameservers = errors.New(
"no authoritative nameservers found",
)
// ErrCNAMEDepthExceeded is returned when a CNAME chain
// exceeds MaxCNAMEDepth.
ErrCNAMEDepthExceeded = errors.New(
"CNAME chain depth exceeded",
)
// ErrContextCanceled wraps context cancellation for the
// resolver's iterative queries.
ErrContextCanceled = errors.New("context canceled")
)

View File

@@ -1,803 +0,0 @@
package resolver
import (
"context"
"errors"
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/miekg/dns"
)
const (
queryTimeoutDuration = 2 * time.Second
maxRetries = 2
maxDelegation = 20
timeoutMultiplier = 2
minDomainLabels = 2
)
// ErrRefused is returned when a DNS server refuses a query.
var ErrRefused = errors.New("dns query refused")
func rootServerList() []string {
return []string{
"198.41.0.4", // a.root-servers.net
"170.247.170.2", // b
"192.33.4.12", // c
"199.7.91.13", // d
"192.203.230.10", // e
"192.5.5.241", // f
"192.112.36.4", // g
"198.97.190.53", // h
"192.36.148.17", // i
"192.58.128.30", // j
"193.0.14.129", // k
"199.7.83.42", // l
"202.12.27.33", // m
}
}
func checkCtx(ctx context.Context) error {
err := ctx.Err()
if err != nil {
return ErrContextCanceled
}
return nil
}
func (r *Resolver) exchangeWithTimeout(
ctx context.Context,
msg *dns.Msg,
addr string,
attempt int,
) (*dns.Msg, error) {
_ = attempt // timeout escalation handled by client config
resp, _, err := r.client.ExchangeContext(ctx, msg, addr)
return resp, err
}
func (r *Resolver) tryExchange(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, error) {
var resp *dns.Msg
var err error
for attempt := range maxRetries {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err = r.exchangeWithTimeout(
ctx, msg, addr, attempt,
)
if err == nil {
break
}
}
return resp, err
}
func (r *Resolver) retryTCP(
ctx context.Context,
msg *dns.Msg,
addr string,
resp *dns.Msg,
) *dns.Msg {
if !resp.Truncated {
return resp
}
tcpResp, _, tcpErr := r.tcp.ExchangeContext(ctx, msg, addr)
if tcpErr == nil {
return tcpResp
}
return resp
}
// queryDNS sends a DNS query to a specific server IP.
// Tries non-recursive first, falls back to recursive on
// REFUSED (handles DNS interception environments).
func (r *Resolver) queryDNS(
ctx context.Context,
serverIP string,
name string,
qtype uint16,
) (*dns.Msg, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
name = dns.Fqdn(name)
addr := net.JoinHostPort(serverIP, "53")
msg := new(dns.Msg)
msg.SetQuestion(name, qtype)
msg.RecursionDesired = false
resp, err := r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf("query %s @%s: %w", name, serverIP, err)
}
if resp.Rcode == dns.RcodeRefused {
msg.RecursionDesired = true
resp, err = r.tryExchange(ctx, msg, addr)
if err != nil {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, err,
)
}
if resp.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf(
"query %s @%s: %w", name, serverIP, ErrRefused,
)
}
}
resp = r.retryTCP(ctx, msg, addr, resp)
return resp, nil
}
func extractNSSet(rrs []dns.RR) []string {
nsSet := make(map[string]bool)
for _, rr := range rrs {
if ns, ok := rr.(*dns.NS); ok {
nsSet[strings.ToLower(ns.Ns)] = true
}
}
names := make([]string, 0, len(nsSet))
for n := range nsSet {
names = append(names, n)
}
sort.Strings(names)
return names
}
func extractGlue(rrs []dns.RR) map[string][]net.IP {
glue := make(map[string][]net.IP)
for _, rr := range rrs {
switch r := rr.(type) {
case *dns.A:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.A)
case *dns.AAAA:
name := strings.ToLower(r.Hdr.Name)
glue[name] = append(glue[name], r.AAAA)
}
}
return glue
}
func glueIPs(nsNames []string, glue map[string][]net.IP) []string {
var ips []string
for _, ns := range nsNames {
for _, addr := range glue[ns] {
if v4 := addr.To4(); v4 != nil {
ips = append(ips, v4.String())
}
}
}
return ips
}
func (r *Resolver) followDelegation(
ctx context.Context,
domain string,
servers []string,
) ([]string, error) {
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryServers(
ctx, servers, domain, dns.TypeNS,
)
if err != nil {
return nil, err
}
ansNS := extractNSSet(resp.Answer)
if len(ansNS) > 0 {
return ansNS, nil
}
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
return r.resolveNSIterative(ctx, domain)
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
nextServers = r.resolveNSIPs(ctx, authNS)
}
if len(nextServers) == 0 {
return nil, ErrNoNameservers
}
servers = nextServers
}
return nil, ErrNoNameservers
}
func (r *Resolver) queryServers(
ctx context.Context,
servers []string,
name string,
qtype uint16,
) (*dns.Msg, error) {
var lastErr error
for _, ip := range servers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryDNS(ctx, ip, name, qtype)
if err == nil {
return resp, nil
}
lastErr = err
}
return nil, fmt.Errorf("all servers failed: %w", lastErr)
}
func (r *Resolver) resolveNSIPs(
ctx context.Context,
nsNames []string,
) []string {
var ips []string
for _, ns := range nsNames {
resolved, err := r.resolveARecord(ctx, ns)
if err == nil {
ips = append(ips, resolved...)
}
if len(ips) > 0 {
break
}
}
return ips
}
// resolveNSIterative queries for NS records using iterative
// resolution as a fallback when followDelegation finds no
// authoritative answer in the delegation chain.
func (r *Resolver) resolveNSIterative(
ctx context.Context,
domain string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
domain = dns.Fqdn(domain)
servers := rootServerList()
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryServers(
ctx, servers, domain, dns.TypeNS,
)
if err != nil {
return nil, err
}
nsNames := extractNSSet(resp.Answer)
if len(nsNames) > 0 {
return nsNames, nil
}
// Follow delegation.
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
break
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
break
}
servers = nextServers
}
return nil, ErrNoNameservers
}
// resolveARecord resolves a hostname to IPv4 addresses using
// iterative resolution through the delegation chain.
func (r *Resolver) resolveARecord(
ctx context.Context,
hostname string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
hostname = dns.Fqdn(hostname)
servers := rootServerList()
for range maxDelegation {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.queryServers(
ctx, servers, hostname, dns.TypeA,
)
if err != nil {
return nil, fmt.Errorf(
"resolving %s: %w", hostname, err,
)
}
// Check for A records in the answer section.
var ips []string
for _, rr := range resp.Answer {
if a, ok := rr.(*dns.A); ok {
ips = append(ips, a.A.String())
}
}
if len(ips) > 0 {
return ips, nil
}
// Follow delegation if present.
authNS := extractNSSet(resp.Ns)
if len(authNS) == 0 {
break
}
glue := extractGlue(resp.Extra)
nextServers := glueIPs(authNS, glue)
if len(nextServers) == 0 {
// Resolve NS IPs iteratively — but guard
// against infinite recursion by using only
// already-resolved servers.
break
}
servers = nextServers
}
return nil, fmt.Errorf(
"cannot resolve %s: %w", hostname, ErrNoNameservers,
)
}
// FindAuthoritativeNameservers traces the delegation chain from
// root servers to discover all authoritative nameservers for the
// given domain. Walks up the label hierarchy for subdomains.
func (r *Resolver) FindAuthoritativeNameservers(
ctx context.Context,
domain string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
domain = dns.Fqdn(strings.ToLower(domain))
labels := dns.SplitDomainName(domain)
for i := range labels {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
candidate := strings.Join(labels[i:], ".") + "."
nsNames, err := r.followDelegation(
ctx, candidate, rootServerList(),
)
if err == nil && len(nsNames) > 0 {
sort.Strings(nsNames)
return nsNames, nil
}
}
return nil, ErrNoNameservers
}
// QueryNameserver queries a specific nameserver for all record
// types and builds a NameserverResponse.
func (r *Resolver) QueryNameserver(
ctx context.Context,
nsHostname string,
hostname string,
) (*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
nsIPs, err := r.resolveARecord(ctx, nsHostname)
if err != nil {
return nil, fmt.Errorf("resolving NS %s: %w", nsHostname, err)
}
hostname = dns.Fqdn(hostname)
return r.queryAllTypes(ctx, nsHostname, nsIPs[0], hostname)
}
// QueryNameserverIP queries a nameserver by its IP address directly,
// bypassing NS hostname resolution.
func (r *Resolver) QueryNameserverIP(
ctx context.Context,
nsHostname string,
nsIP string,
hostname string,
) (*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
hostname = dns.Fqdn(hostname)
return r.queryAllTypes(ctx, nsHostname, nsIP, hostname)
}
func (r *Resolver) queryAllTypes(
ctx context.Context,
nsHostname string,
nsIP string,
hostname string,
) (*NameserverResponse, error) {
resp := &NameserverResponse{
Nameserver: nsHostname,
Records: make(map[string][]string),
Status: StatusOK,
}
qtypes := []uint16{
dns.TypeA, dns.TypeAAAA, dns.TypeCNAME,
dns.TypeMX, dns.TypeTXT, dns.TypeSRV,
dns.TypeCAA, dns.TypeNS,
}
state := r.queryEachType(ctx, nsIP, hostname, qtypes, resp)
classifyResponse(resp, state)
return resp, nil
}
type queryState struct {
gotNXDomain bool
gotSERVFAIL bool
gotTimeout bool
hasRecords bool
}
func (r *Resolver) queryEachType(
ctx context.Context,
nsIP string,
hostname string,
qtypes []uint16,
resp *NameserverResponse,
) queryState {
var state queryState
for _, qtype := range qtypes {
if checkCtx(ctx) != nil {
break
}
r.querySingleType(ctx, nsIP, hostname, qtype, resp, &state)
}
for k := range resp.Records {
sort.Strings(resp.Records[k])
}
return state
}
func (r *Resolver) querySingleType(
ctx context.Context,
nsIP string,
hostname string,
qtype uint16,
resp *NameserverResponse,
state *queryState,
) {
msg, err := r.queryDNS(ctx, nsIP, hostname, qtype)
if err != nil {
if isTimeout(err) {
state.gotTimeout = true
}
return
}
if msg.Rcode == dns.RcodeNameError {
state.gotNXDomain = true
return
}
if msg.Rcode == dns.RcodeServerFailure {
state.gotSERVFAIL = true
return
}
collectAnswerRecords(msg, resp, state)
}
func collectAnswerRecords(
msg *dns.Msg,
resp *NameserverResponse,
state *queryState,
) {
for _, rr := range msg.Answer {
val := extractRecordValue(rr)
if val == "" {
continue
}
typeName := dns.TypeToString[rr.Header().Rrtype]
resp.Records[typeName] = append(
resp.Records[typeName], val,
)
state.hasRecords = true
}
}
// isTimeout checks whether an error is a network timeout.
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func classifyResponse(resp *NameserverResponse, state queryState) {
switch {
case state.gotNXDomain && !state.hasRecords:
resp.Status = StatusNXDomain
case state.gotTimeout && !state.hasRecords:
resp.Status = StatusTimeout
resp.Error = "all queries timed out"
case state.gotSERVFAIL && !state.hasRecords:
resp.Status = StatusError
resp.Error = "server returned SERVFAIL"
case !state.hasRecords && !state.gotNXDomain:
resp.Status = StatusNoData
}
}
// extractRecordValue formats a DNS RR value as a string.
func extractRecordValue(rr dns.RR) string {
switch r := rr.(type) {
case *dns.A:
return r.A.String()
case *dns.AAAA:
return r.AAAA.String()
case *dns.CNAME:
return r.Target
case *dns.MX:
return fmt.Sprintf("%d %s", r.Preference, r.Mx)
case *dns.TXT:
return strings.Join(r.Txt, "")
case *dns.SRV:
return fmt.Sprintf(
"%d %d %d %s",
r.Priority, r.Weight, r.Port, r.Target,
)
case *dns.CAA:
return fmt.Sprintf(
"%d %s \"%s\"", r.Flag, r.Tag, r.Value,
)
case *dns.NS:
return r.Ns
default:
return ""
}
}
// parentDomain returns the registerable parent domain.
func parentDomain(hostname string) string {
hostname = dns.Fqdn(strings.ToLower(hostname))
labels := dns.SplitDomainName(hostname)
if len(labels) <= minDomainLabels {
return strings.Join(labels, ".") + "."
}
return strings.Join(
labels[len(labels)-minDomainLabels:], ".",
) + "."
}
// QueryAllNameservers discovers auth NSes for the hostname's
// parent domain, then queries each one independently.
func (r *Resolver) QueryAllNameservers(
ctx context.Context,
hostname string,
) (map[string]*NameserverResponse, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
parent := parentDomain(hostname)
nameservers, err := r.FindAuthoritativeNameservers(ctx, parent)
if err != nil {
return nil, err
}
return r.queryEachNS(ctx, nameservers, hostname)
}
func (r *Resolver) queryEachNS(
ctx context.Context,
nameservers []string,
hostname string,
) (map[string]*NameserverResponse, error) {
results := make(map[string]*NameserverResponse)
for _, ns := range nameservers {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
resp, err := r.QueryNameserver(ctx, ns, hostname)
if err != nil {
results[ns] = &NameserverResponse{
Nameserver: ns,
Records: make(map[string][]string),
Status: StatusError,
Error: err.Error(),
}
continue
}
results[ns] = resp
}
return results, nil
}
// LookupNS returns the NS record set for a domain.
func (r *Resolver) LookupNS(
ctx context.Context,
domain string,
) ([]string, error) {
return r.FindAuthoritativeNameservers(ctx, domain)
}
// LookupAllRecords performs iterative resolution to find all DNS
// records for the given hostname, keyed by authoritative nameserver.
func (r *Resolver) LookupAllRecords(
ctx context.Context,
hostname string,
) (map[string]map[string][]string, error) {
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
out := make(map[string]map[string][]string, len(results))
for ns, resp := range results {
out[ns] = resp.Records
}
return out, nil
}
// ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains up to MaxCNAMEDepth.
func (r *Resolver) ResolveIPAddresses(
ctx context.Context,
hostname string,
) ([]string, error) {
if checkCtx(ctx) != nil {
return nil, ErrContextCanceled
}
return r.resolveIPWithCNAME(ctx, hostname, 0)
}
func (r *Resolver) resolveIPWithCNAME(
ctx context.Context,
hostname string,
depth int,
) ([]string, error) {
if depth > MaxCNAMEDepth {
return nil, ErrCNAMEDepthExceeded
}
results, err := r.QueryAllNameservers(ctx, hostname)
if err != nil {
return nil, err
}
ips, cnameTarget := collectIPs(results)
if len(ips) == 0 && cnameTarget != "" {
return r.resolveIPWithCNAME(ctx, cnameTarget, depth+1)
}
sort.Strings(ips)
return ips, nil
}
func collectIPs(
results map[string]*NameserverResponse,
) ([]string, string) {
seen := make(map[string]bool)
var ips []string
var cnameTarget string
for _, resp := range results {
if resp.Status == StatusNXDomain {
continue
}
for _, ip := range resp.Records["A"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
for _, ip := range resp.Records["AAAA"] {
if !seen[ip] {
seen[ip] = true
ips = append(ips, ip)
}
}
if len(resp.Records["CNAME"]) > 0 && cnameTarget == "" {
cnameTarget = resp.Records["CNAME"][0]
}
}
return ips, cnameTarget
}

View File

@@ -1,9 +1,9 @@
// Package resolver provides iterative DNS resolution from root nameservers. // Package resolver provides iterative DNS resolution from root nameservers.
// It traces the full delegation chain from IANA root servers through TLD
// and domain nameservers, never relying on upstream recursive resolvers.
package resolver package resolver
import ( import (
"context"
"errors"
"log/slog" "log/slog"
"go.uber.org/fx" "go.uber.org/fx"
@@ -11,17 +11,8 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
// Query status constants matching the state model. // ErrNotImplemented indicates the resolver is not yet implemented.
const ( var ErrNotImplemented = errors.New("resolver not yet implemented")
StatusOK = "ok"
StatusError = "error"
StatusNXDomain = "nxdomain"
StatusNoData = "nodata"
StatusTimeout = "timeout"
)
// MaxCNAMEDepth is the maximum CNAME chain depth to follow.
const MaxCNAMEDepth = 10
// Params contains dependencies for Resolver. // Params contains dependencies for Resolver.
type Params struct { type Params struct {
@@ -30,54 +21,44 @@ type Params struct {
Logger *logger.Logger Logger *logger.Logger
} }
// NameserverResponse holds one nameserver's response for a query.
type NameserverResponse struct {
Nameserver string
Records map[string][]string
Status string
Error string
}
// Resolver performs iterative DNS resolution from root servers. // Resolver performs iterative DNS resolution from root servers.
type Resolver struct { type Resolver struct {
log *slog.Logger log *slog.Logger
client DNSClient
tcp DNSClient
} }
// New creates a new Resolver instance for use with uber/fx. // New creates a new Resolver instance.
func New( func New(
_ fx.Lifecycle, _ fx.Lifecycle,
params Params, params Params,
) (*Resolver, error) { ) (*Resolver, error) {
return &Resolver{ return &Resolver{
log: params.Logger.Get(), log: params.Logger.Get(),
client: &udpClient{timeout: queryTimeoutDuration},
tcp: &tcpClient{timeout: queryTimeoutDuration},
}, nil }, nil
} }
// NewFromLogger creates a Resolver directly from an slog.Logger, // LookupNS performs iterative resolution to find authoritative
// useful for testing without the fx lifecycle. // nameservers for the given domain.
func NewFromLogger(log *slog.Logger) *Resolver { func (r *Resolver) LookupNS(
return &Resolver{ _ context.Context,
log: log, _ string,
client: &udpClient{timeout: queryTimeoutDuration}, ) ([]string, error) {
tcp: &tcpClient{timeout: queryTimeoutDuration}, return nil, ErrNotImplemented
}
} }
// NewFromLoggerWithClient creates a Resolver with a custom DNS // LookupAllRecords performs iterative resolution to find all DNS
// client, useful for testing with mock DNS responses. // records for the given hostname, keyed by authoritative nameserver.
func NewFromLoggerWithClient( func (r *Resolver) LookupAllRecords(
log *slog.Logger, _ context.Context,
client DNSClient, _ string,
) *Resolver { ) (map[string]map[string][]string, error) {
return &Resolver{ return nil, ErrNotImplemented
log: log,
client: client,
tcp: client,
}
} }
// Method implementations are in iterative.go. // ResolveIPAddresses resolves a hostname to all IPv4 and IPv6
// addresses, following CNAME chains.
func (r *Resolver) ResolveIPAddresses(
_ context.Context,
_ string,
) ([]string, error) {
return nil, ErrNotImplemented
}

View File

@@ -1,688 +0,0 @@
package resolver_test
import (
"context"
"log/slog"
"net"
"os"
"sort"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/resolver"
)
// ----------------------------------------------------------------
// Test helpers
// ----------------------------------------------------------------
func newTestResolver(t *testing.T) *resolver.Resolver {
t.Helper()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
return resolver.NewFromLogger(log)
}
func testContext(t *testing.T) context.Context {
t.Helper()
ctx, cancel := context.WithTimeout(
context.Background(), 60*time.Second,
)
t.Cleanup(cancel)
return ctx
}
func findOneNSForDomain(
t *testing.T,
r *resolver.Resolver,
ctx context.Context, //nolint:revive // test helper
domain string,
) string {
t.Helper()
nameservers, err := r.FindAuthoritativeNameservers(
ctx, domain,
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
return nameservers[0]
}
// ----------------------------------------------------------------
// FindAuthoritativeNameservers tests
// ----------------------------------------------------------------
func TestFindAuthoritativeNameservers_ValidDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
hasGoogleNS := false
for _, ns := range nameservers {
if strings.Contains(ns, "google") {
hasGoogleNS = true
break
}
}
assert.True(t, hasGoogleNS,
"expected google nameservers, got: %v", nameservers,
)
}
func TestFindAuthoritativeNameservers_Subdomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "www.google.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
}
func TestFindAuthoritativeNameservers_ReturnsSorted(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
assert.True(
t,
sort.StringsAreSorted(nameservers),
"nameservers should be sorted, got: %v", nameservers,
)
}
func TestFindAuthoritativeNameservers_Deterministic(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
first, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
second, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
assert.Equal(t, first, second)
}
func TestFindAuthoritativeNameservers_TrailingDot(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns1, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
ns2, err := r.FindAuthoritativeNameservers(
ctx, "google.com.",
)
require.NoError(t, err)
assert.Equal(t, ns1, ns2)
}
func TestFindAuthoritativeNameservers_CloudflareDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "cloudflare.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
for _, ns := range nameservers {
assert.True(t, strings.HasSuffix(ns, "."),
"NS should be FQDN with trailing dot: %s", ns,
)
}
}
// ----------------------------------------------------------------
// QueryNameserver tests
// ----------------------------------------------------------------
func TestQueryNameserver_BasicA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "www.google.com",
)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, resolver.StatusOK, resp.Status)
assert.Equal(t, ns, resp.Nameserver)
hasRecords := len(resp.Records["A"]) > 0 ||
len(resp.Records["CNAME"]) > 0
assert.True(t, hasRecords,
"expected A or CNAME records for www.google.com",
)
}
func TestQueryNameserver_AAAA(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "cloudflare.com")
resp, err := r.QueryNameserver(
ctx, ns, "cloudflare.com",
)
require.NoError(t, err)
aaaaRecords := resp.Records["AAAA"]
require.NotEmpty(t, aaaaRecords,
"cloudflare.com should have AAAA records",
)
for _, ip := range aaaaRecords {
parsed := net.ParseIP(ip)
require.NotNil(t, parsed,
"should be valid IP: %s", ip,
)
}
}
func TestQueryNameserver_MX(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
mxRecords := resp.Records["MX"]
require.NotEmpty(t, mxRecords,
"google.com should have MX records",
)
}
func TestQueryNameserver_TXT(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
txtRecords := resp.Records["TXT"]
require.NotEmpty(t, txtRecords,
"google.com should have TXT records",
)
hasSPF := false
for _, txt := range txtRecords {
if strings.Contains(txt, "v=spf1") {
hasSPF = true
break
}
}
assert.True(t, hasSPF,
"google.com should have SPF TXT record",
)
}
func TestQueryNameserver_NXDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
assert.Equal(t, resolver.StatusNXDomain, resp.Status)
}
func TestQueryNameserver_RecordsSorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
for recordType, values := range resp.Records {
assert.True(
t,
sort.StringsAreSorted(values),
"%s records should be sorted", recordType,
)
}
}
func TestQueryNameserver_ResponseIncludesNameserver(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "cloudflare.com")
resp, err := r.QueryNameserver(
ctx, ns, "cloudflare.com",
)
require.NoError(t, err)
assert.Equal(t, ns, resp.Nameserver)
}
func TestQueryNameserver_EmptyRecordsOnNXDomain(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp, err := r.QueryNameserver(
ctx, ns,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
totalRecords := 0
for _, values := range resp.Records {
totalRecords += len(values)
}
assert.Zero(t, totalRecords)
}
func TestQueryNameserver_TrailingDotHandling(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
resp1, err := r.QueryNameserver(
ctx, ns, "google.com",
)
require.NoError(t, err)
resp2, err := r.QueryNameserver(
ctx, ns, "google.com.",
)
require.NoError(t, err)
assert.Equal(t, resp1.Status, resp2.Status)
}
// ----------------------------------------------------------------
// QueryAllNameservers tests
// ----------------------------------------------------------------
func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx, "google.com",
)
require.NoError(t, err)
require.NotEmpty(t, results)
assert.GreaterOrEqual(t, len(results), 2)
for ns, resp := range results {
assert.Equal(t, ns, resp.Nameserver)
}
}
func TestQueryAllNameservers_AllReturnOK(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx, "google.com",
)
require.NoError(t, err)
for ns, resp := range results {
assert.Equal(
t, resolver.StatusOK, resp.Status,
"NS %s should return OK", ns,
)
}
}
func TestQueryAllNameservers_NXDomainFromAllNS(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
for ns, resp := range results {
assert.Equal(
t, resolver.StatusNXDomain, resp.Status,
"NS %s should return nxdomain", ns,
)
}
}
// ----------------------------------------------------------------
// LookupNS tests
// ----------------------------------------------------------------
func TestLookupNS_ValidDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, "google.com")
require.NoError(t, err)
require.NotEmpty(t, nameservers)
for _, ns := range nameservers {
assert.True(t, strings.HasSuffix(ns, "."),
"NS should have trailing dot: %s", ns,
)
}
}
func TestLookupNS_Sorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, "google.com")
require.NoError(t, err)
assert.True(t, sort.StringsAreSorted(nameservers))
}
func TestLookupNS_MatchesFindAuthoritative(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
fromLookup, err := r.LookupNS(ctx, "google.com")
require.NoError(t, err)
fromFind, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
)
require.NoError(t, err)
assert.Equal(t, fromFind, fromLookup)
}
// ----------------------------------------------------------------
// ResolveIPAddresses tests
// ----------------------------------------------------------------
func TestResolveIPAddresses_ReturnsIPs(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
require.NoError(t, err)
require.NotEmpty(t, ips)
for _, ip := range ips {
parsed := net.ParseIP(ip)
assert.NotNil(t, parsed,
"should be valid IP: %s", ip,
)
}
}
func TestResolveIPAddresses_Deduplicated(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
require.NoError(t, err)
seen := make(map[string]bool)
for _, ip := range ips {
assert.False(t, seen[ip], "duplicate IP: %s", ip)
seen[ip] = true
}
}
func TestResolveIPAddresses_Sorted(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
require.NoError(t, err)
assert.True(t, sort.StringsAreSorted(ips))
}
func TestResolveIPAddresses_NXDomainReturnsEmpty(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(
ctx,
"this-surely-does-not-exist-xyz.google.com",
)
require.NoError(t, err)
assert.Empty(t, ips)
}
func TestResolveIPAddresses_CloudflareDomain(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "cloudflare.com")
require.NoError(t, err)
require.NotEmpty(t, ips)
}
// ----------------------------------------------------------------
// Context cancellation tests
// ----------------------------------------------------------------
func TestFindAuthoritativeNameservers_ContextCanceled(
t *testing.T,
) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.FindAuthoritativeNameservers(ctx, "google.com")
assert.Error(t, err)
}
func TestQueryNameserver_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.QueryNameserver(
ctx, "ns1.google.com.", "google.com",
)
assert.Error(t, err)
}
func TestQueryAllNameservers_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.QueryAllNameservers(ctx, "google.com")
assert.Error(t, err)
}
// ----------------------------------------------------------------
// Timeout tests
// ----------------------------------------------------------------
func TestQueryNameserverIP_Timeout(t *testing.T) {
t.Parallel()
log := slog.New(slog.NewTextHandler(
os.Stderr,
&slog.HandlerOptions{Level: slog.LevelDebug},
))
r := resolver.NewFromLoggerWithClient(
log, &timeoutClient{},
)
ctx, cancel := context.WithTimeout(
context.Background(), 10*time.Second,
)
t.Cleanup(cancel)
// Query any IP — the client always returns a timeout error.
resp, err := r.QueryNameserverIP(
ctx, "unreachable.test.", "192.0.2.1",
"example.com",
)
require.NoError(t, err)
assert.Equal(t, resolver.StatusTimeout, resp.Status)
assert.NotEmpty(t, resp.Error)
}
// timeoutClient simulates DNS timeout errors for testing.
type timeoutClient struct{}
func (c *timeoutClient) ExchangeContext(
_ context.Context,
_ *dns.Msg,
_ string,
) (*dns.Msg, time.Duration, error) {
return nil, 0, &net.OpError{
Op: "read",
Net: "udp",
Err: &timeoutError{},
}
}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
func TestResolveIPAddresses_ContextCanceled(t *testing.T) {
t.Parallel()
r := newTestResolver(t)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.ResolveIPAddresses(ctx, "google.com")
assert.Error(t, err)
}

View File

@@ -28,8 +28,6 @@ func (s *Server) SetupRoutes() {
// API v1 routes // API v1 routes
s.router.Route("/api/v1", func(r chi.Router) { s.router.Route("/api/v1", func(r chi.Router) {
r.Get("/status", s.handlers.HandleStatus()) r.Get("/status", s.handlers.HandleStatus())
r.Get("/domains", s.handlers.HandleDomains())
r.Get("/hostnames", s.handlers.HandleHostnames())
}) })
// Metrics endpoint (optional, with basic auth) // Metrics endpoint (optional, with basic auth)

View File

@@ -57,47 +57,8 @@ type HostnameState struct {
// PortState holds the monitoring state for a port. // PortState holds the monitoring state for a port.
type PortState struct { type PortState struct {
Open bool `json:"open"` Open bool `json:"open"`
Hostnames []string `json:"hostnames"`
LastChecked time.Time `json:"lastChecked"`
}
// UnmarshalJSON implements custom unmarshaling to handle both
// the old single-hostname format and the new multi-hostname
// format for backward compatibility with existing state files.
func (ps *PortState) UnmarshalJSON(data []byte) error {
// Use an alias to prevent infinite recursion.
type portStateAlias struct {
Open bool `json:"open"`
Hostnames []string `json:"hostnames"`
LastChecked time.Time `json:"lastChecked"`
}
var alias portStateAlias
err := json.Unmarshal(data, &alias)
if err != nil {
return fmt.Errorf("unmarshaling port state: %w", err)
}
ps.Open = alias.Open
ps.Hostnames = alias.Hostnames
ps.LastChecked = alias.LastChecked
// If Hostnames is empty, try reading the old single-hostname
// format for backward compatibility.
if len(ps.Hostnames) == 0 {
var old struct {
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
} LastChecked time.Time `json:"lastChecked"`
// Best-effort: ignore errors since the main unmarshal
// already succeeded.
if json.Unmarshal(data, &old) == nil && old.Hostname != "" {
ps.Hostnames = []string{old.Hostname}
}
}
return nil
} }
// CertificateState holds TLS certificate monitoring state. // CertificateState holds TLS certificate monitoring state.
@@ -195,8 +156,8 @@ func (s *State) Load() error {
// Save writes the current state to disk atomically. // Save writes the current state to disk atomically.
func (s *State) Save() error { func (s *State) Save() error {
s.mu.Lock() s.mu.RLock()
defer s.mu.Unlock() defer s.mu.RUnlock()
s.snapshot.LastUpdated = time.Now().UTC() s.snapshot.LastUpdated = time.Now().UTC()
@@ -302,27 +263,6 @@ func (s *State) GetPortState(key string) (*PortState, bool) {
return ps, ok return ps, ok
} }
// DeletePortState removes a port state entry.
func (s *State) DeletePortState(key string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.snapshot.Ports, key)
}
// GetAllPortKeys returns all port state keys.
func (s *State) GetAllPortKeys() []string {
s.mu.RLock()
defer s.mu.RUnlock()
keys := make([]string, 0, len(s.snapshot.Ports))
for k := range s.snapshot.Ports {
keys = append(keys, k)
}
return keys
}
// SetCertificateState updates the state for a certificate. // SetCertificateState updates the state for a certificate.
func (s *State) SetCertificateState( func (s *State) SetCertificateState(
key string, key string,

View File

@@ -1,67 +0,0 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func TestCheckCertificateNoPeerCerts(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = ln.Close() }()
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
// Accept and immediately close to cause TLS handshake failure.
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
_ = conn.Close()
}()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
InsecureSkipVerify: true, //nolint:gosec // test
MinVersion: tls.VersionTLS12,
}),
tlscheck.WithPort(addr.Port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error when server presents no certs")
}
}
func TestErrNoPeerCertificatesIsSentinel(t *testing.T) {
t.Parallel()
err := tlscheck.ErrNoPeerCertificates
if !errors.Is(err, tlscheck.ErrNoPeerCertificates) {
t.Fatal("expected sentinel error to match")
}
}

View File

@@ -3,12 +3,8 @@ package tlscheck
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"net"
"strconv"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@@ -16,56 +12,11 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
const ( // ErrNotImplemented indicates the TLS checker is not yet implemented.
defaultTimeout = 10 * time.Second var ErrNotImplemented = errors.New(
defaultPort = 443 "tls checker not yet implemented",
) )
// ErrUnexpectedConnType indicates the connection was not a TLS
// connection.
var ErrUnexpectedConnType = errors.New(
"unexpected connection type",
)
// ErrNoPeerCertificates indicates the TLS connection had no peer
// certificates.
var ErrNoPeerCertificates = errors.New(
"no peer certificates",
)
// CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
SerialNumber string
}
// Option configures a Checker.
type Option func(*Checker)
// WithTimeout sets the connection timeout.
func WithTimeout(d time.Duration) Option {
return func(c *Checker) {
c.timeout = d
}
}
// WithTLSConfig sets a custom TLS configuration.
func WithTLSConfig(cfg *tls.Config) Option {
return func(c *Checker) {
c.tlsConfig = cfg
}
}
// WithPort sets the TLS port to connect to.
func WithPort(port int) Option {
return func(c *Checker) {
c.port = port
}
}
// Params contains dependencies for Checker. // Params contains dependencies for Checker.
type Params struct { type Params struct {
fx.In fx.In
@@ -76,9 +27,14 @@ type Params struct {
// Checker performs TLS certificate inspection. // Checker performs TLS certificate inspection.
type Checker struct { type Checker struct {
log *slog.Logger log *slog.Logger
timeout time.Duration }
tlsConfig *tls.Config
port int // CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
} }
// New creates a new TLS Checker instance. // New creates a new TLS Checker instance.
@@ -88,109 +44,15 @@ func New(
) (*Checker, error) { ) (*Checker, error) {
return &Checker{ return &Checker{
log: params.Logger.Get(), log: params.Logger.Get(),
timeout: defaultTimeout,
port: defaultPort,
}, nil }, nil
} }
// NewStandalone creates a Checker without fx dependencies. // CheckCertificate connects to the given IP:port using SNI and
func NewStandalone(opts ...Option) *Checker { // returns certificate information.
checker := &Checker{
log: slog.Default(),
timeout: defaultTimeout,
port: defaultPort,
}
for _, opt := range opts {
opt(checker)
}
return checker
}
// CheckCertificate connects to the given IP address using the
// specified SNI hostname and returns certificate information.
func (c *Checker) CheckCertificate( func (c *Checker) CheckCertificate(
ctx context.Context, _ context.Context,
ipAddress string, _ string,
sniHostname string, _ string,
) (*CertificateInfo, error) { ) (*CertificateInfo, error) {
target := net.JoinHostPort( return nil, ErrNotImplemented
ipAddress, strconv.Itoa(c.port),
)
tlsCfg := c.buildTLSConfig(sniHostname)
dialer := &tls.Dialer{
NetDialer: &net.Dialer{Timeout: c.timeout},
Config: tlsCfg,
}
conn, err := dialer.DialContext(ctx, "tcp", target)
if err != nil {
return nil, fmt.Errorf(
"TLS dial to %s: %w", target, err,
)
}
defer func() {
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing TLS connection",
"target", target,
"error", closeErr.Error(),
)
}
}()
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return nil, fmt.Errorf(
"%s: %w", target, ErrUnexpectedConnType,
)
}
return c.extractCertInfo(tlsConn)
}
func (c *Checker) buildTLSConfig(
sniHostname string,
) *tls.Config {
if c.tlsConfig != nil {
cfg := c.tlsConfig.Clone()
cfg.ServerName = sniHostname
return cfg
}
return &tls.Config{
ServerName: sniHostname,
MinVersion: tls.VersionTLS12,
}
}
func (c *Checker) extractCertInfo(
conn *tls.Conn,
) (*CertificateInfo, error) {
state := conn.ConnectionState()
if len(state.PeerCertificates) == 0 {
return nil, ErrNoPeerCertificates
}
cert := state.PeerCertificates[0]
sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses))
sans = append(sans, cert.DNSNames...)
for _, ip := range cert.IPAddresses {
sans = append(sans, ip.String())
}
return &CertificateInfo{
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: sans,
SerialNumber: cert.SerialNumber.String(),
}, nil
} }

View File

@@ -1,169 +0,0 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func startTLSServer(
t *testing.T,
) (*httptest.Server, string, int) {
t.Helper()
srv := httptest.NewTLSServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
},
),
)
addr, ok := srv.Listener.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return srv, addr.IP.String(), addr.Port
}
func TestCheckCertificateValid(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info == nil {
t.Fatal("expected non-nil CertificateInfo")
}
if info.NotAfter.IsZero() {
t.Error("expected non-zero NotAfter")
}
if info.SerialNumber == "" {
t.Error("expected non-empty SerialNumber")
}
}
func TestCheckCertificateConnectionRefused(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
port := addr.Port
_ = ln.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for connection refused")
}
}
func TestCheckCertificateContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
ctx, "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for canceled context")
}
}
func TestCheckCertificateTimeout(t *testing.T) {
t.Parallel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(1*time.Millisecond),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
context.Background(),
"192.0.2.1",
"example.com",
)
if err == nil {
t.Fatal("expected error for timeout")
}
}
func TestCheckCertificateSANs(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.CommonName == "" && len(info.SubjectAlternativeNames) == 0 {
t.Error("expected CN or SANs to be populated")
}
}

View File

@@ -4,7 +4,6 @@ package watcher
import ( import (
"context" "context"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/tlscheck"
) )
@@ -37,7 +36,7 @@ type PortChecker interface {
ctx context.Context, ctx context.Context,
address string, address string,
port int, port int,
) (*portcheck.PortResult, error) ) (bool, error)
} }
// TLSChecker inspects TLS certificates. // TLSChecker inspects TLS certificates.

View File

@@ -6,7 +6,6 @@ import (
"log/slog" "log/slog"
"sort" "sort"
"strings" "strings"
"sync"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@@ -50,8 +49,6 @@ type Watcher struct {
notify Notifier notify Notifier
cancel context.CancelFunc cancel context.CancelFunc
firstRun bool firstRun bool
expiryNotifiedMu sync.Mutex
expiryNotified map[string]time.Time
} }
// New creates a new Watcher instance wired into the fx lifecycle. // New creates a new Watcher instance wired into the fx lifecycle.
@@ -68,19 +65,16 @@ func New(
tlsCheck: params.TLSCheck, tlsCheck: params.TLSCheck,
notify: params.Notify, notify: params.Notify,
firstRun: true, firstRun: true,
expiryNotified: make(map[string]time.Time),
} }
lifecycle.Append(fx.Hook{ lifecycle.Append(fx.Hook{
OnStart: func(_ context.Context) error { OnStart: func(startCtx context.Context) error {
// Use context.Background() — the fx startup context ctx, cancel := context.WithCancel(
// expires after startup completes, so deriving from it context.WithoutCancel(startCtx),
// would cancel the watcher immediately. The watcher's )
// lifetime is controlled by w.cancel in OnStop.
ctx, cancel := context.WithCancel(context.Background())
w.cancel = cancel w.cancel = cancel
go w.Run(ctx) //nolint:contextcheck // intentionally not derived from startCtx go w.Run(ctx)
return nil return nil
}, },
@@ -114,7 +108,6 @@ func NewForTest(
tlsCheck: tc, tlsCheck: tc,
notify: n, notify: n,
firstRun: true, firstRun: true,
expiryNotified: make(map[string]time.Time),
} }
} }
@@ -143,16 +136,9 @@ func (w *Watcher) Run(ctx context.Context) {
return return
case <-dnsTicker.C: case <-dnsTicker.C:
w.runDNSChecks(ctx) w.runDNSAndPortChecks(ctx)
w.checkAllPorts(ctx)
w.saveState() w.saveState()
case <-tlsTicker.C: case <-tlsTicker.C:
// Run DNS first so TLS checks use freshly
// resolved IP addresses, not stale ones from
// a previous cycle.
w.runDNSChecks(ctx)
w.runTLSChecks(ctx) w.runTLSChecks(ctx)
w.saveState() w.saveState()
} }
@@ -160,26 +146,10 @@ func (w *Watcher) Run(ctx context.Context) {
} }
// RunOnce performs a single complete monitoring cycle. // RunOnce performs a single complete monitoring cycle.
// DNS checks run first so that port and TLS checks use
// freshly resolved IP addresses. Port checks run before
// TLS because TLS checks only target IPs with an open
// port 443.
func (w *Watcher) RunOnce(ctx context.Context) { func (w *Watcher) RunOnce(ctx context.Context) {
w.detectFirstRun() w.detectFirstRun()
w.runDNSAndPortChecks(ctx)
// Phase 1: DNS resolution must complete first so that
// subsequent checks use fresh IP addresses.
w.runDNSChecks(ctx)
// Phase 2: Port checks populate port state that TLS
// checks depend on (TLS only targets IPs where port
// 443 is open).
w.checkAllPorts(ctx)
// Phase 3: TLS checks use fresh DNS IPs and current
// port state.
w.runTLSChecks(ctx) w.runTLSChecks(ctx)
w.saveState() w.saveState()
w.firstRun = false w.firstRun = false
} }
@@ -196,11 +166,7 @@ func (w *Watcher) detectFirstRun() {
} }
} }
// runDNSChecks performs DNS resolution for all configured domains func (w *Watcher) runDNSAndPortChecks(ctx context.Context) {
// and hostnames, updating state with freshly resolved records.
// This must complete before port or TLS checks run so those
// checks operate on current IP addresses.
func (w *Watcher) runDNSChecks(ctx context.Context) {
for _, domain := range w.config.Domains { for _, domain := range w.config.Domains {
w.checkDomain(ctx, domain) w.checkDomain(ctx, domain)
} }
@@ -208,6 +174,8 @@ func (w *Watcher) runDNSChecks(ctx context.Context) {
for _, hostname := range w.config.Hostnames { for _, hostname := range w.config.Hostnames {
w.checkHostname(ctx, hostname) w.checkHostname(ctx, hostname)
} }
w.checkAllPorts(ctx)
} }
func (w *Watcher) checkDomain( func (w *Watcher) checkDomain(
@@ -238,28 +206,6 @@ func (w *Watcher) checkDomain(
Nameservers: nameservers, Nameservers: nameservers,
LastChecked: now, LastChecked: now,
}) })
// Also look up A/AAAA records for the apex domain so that
// port and TLS checks (which read HostnameState) can find
// the domain's IP addresses.
records, err := w.resolver.LookupAllRecords(ctx, domain)
if err != nil {
w.log.Error(
"failed to lookup records for domain",
"domain", domain,
"error", err,
)
return
}
prevHS, hasPrevHS := w.state.GetHostnameState(domain)
if hasPrevHS && !w.firstRun {
w.detectHostnameChanges(ctx, domain, prevHS, records)
}
newState := buildHostnameState(records, now)
w.state.SetHostnameState(domain, newState)
} }
func (w *Watcher) detectNSChanges( func (w *Watcher) detectNSChanges(
@@ -475,94 +421,24 @@ func (w *Watcher) detectInconsistencies(
} }
func (w *Watcher) checkAllPorts(ctx context.Context) { func (w *Watcher) checkAllPorts(ctx context.Context) {
// Phase 1: Build current IP:port → hostname associations for _, hostname := range w.config.Hostnames {
// from fresh DNS data. w.checkPortsForHostname(ctx, hostname)
associations := w.buildPortAssociations()
// Phase 2: Check each unique IP:port and update state
// with the full set of associated hostnames.
for key, hostnames := range associations {
ip, port := parsePortKey(key)
if port == 0 {
continue
} }
w.checkSinglePort(ctx, ip, port, hostnames) for _, domain := range w.config.Domains {
w.checkPortsForHostname(ctx, domain)
} }
// Phase 3: Remove port state entries that no longer have
// any hostname referencing them.
w.cleanupStalePorts(associations)
} }
// buildPortAssociations constructs a map from IP:port keys to func (w *Watcher) checkPortsForHostname(
// the sorted set of hostnames currently resolving to that IP. ctx context.Context,
func (w *Watcher) buildPortAssociations() map[string][]string { hostname string,
assoc := make(map[string]map[string]bool) ) {
ips := w.collectIPs(hostname)
allNames := make(
[]string, 0,
len(w.config.Hostnames)+len(w.config.Domains),
)
allNames = append(allNames, w.config.Hostnames...)
allNames = append(allNames, w.config.Domains...)
for _, name := range allNames {
ips := w.collectIPs(name)
for _, ip := range ips { for _, ip := range ips {
for _, port := range monitoredPorts { for _, port := range monitoredPorts {
key := fmt.Sprintf("%s:%d", ip, port) w.checkSinglePort(ctx, ip, port, hostname)
if assoc[key] == nil {
assoc[key] = make(map[string]bool)
}
assoc[key][name] = true
}
}
}
result := make(map[string][]string, len(assoc))
for key, set := range assoc {
hostnames := make([]string, 0, len(set))
for h := range set {
hostnames = append(hostnames, h)
}
sort.Strings(hostnames)
result[key] = hostnames
}
return result
}
// parsePortKey splits an "ip:port" key into its components.
func parsePortKey(key string) (string, int) {
lastColon := strings.LastIndex(key, ":")
if lastColon < 0 {
return key, 0
}
ip := key[:lastColon]
var p int
_, err := fmt.Sscanf(key[lastColon+1:], "%d", &p)
if err != nil {
return ip, 0
}
return ip, p
}
// cleanupStalePorts removes port state entries that are no
// longer referenced by any hostname in the current DNS data.
func (w *Watcher) cleanupStalePorts(
currentAssociations map[string][]string,
) {
for _, key := range w.state.GetAllPortKeys() {
if _, exists := currentAssociations[key]; !exists {
w.state.DeletePortState(key)
} }
} }
} }
@@ -599,9 +475,9 @@ func (w *Watcher) checkSinglePort(
ctx context.Context, ctx context.Context,
ip string, ip string,
port int, port int,
hostnames []string, hostname string,
) { ) {
result, err := w.portCheck.CheckPort(ctx, ip, port) open, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil { if err != nil {
w.log.Error( w.log.Error(
"port check failed", "port check failed",
@@ -617,15 +493,15 @@ func (w *Watcher) checkSinglePort(
now := time.Now().UTC() now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key) prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != result.Open { if hasPrev && !w.firstRun && prev.Open != open {
stateStr := "closed" stateStr := "closed"
if result.Open { if open {
stateStr = "open" stateStr = "open"
} }
msg := fmt.Sprintf( msg := fmt.Sprintf(
"Hosts: %s\nAddress: %s\nPort now %s", "Host: %s\nAddress: %s\nPort now %s",
strings.Join(hostnames, ", "), key, stateStr, hostname, key, stateStr,
) )
w.notify.SendNotification( w.notify.SendNotification(
@@ -637,8 +513,8 @@ func (w *Watcher) checkSinglePort(
} }
w.state.SetPortState(key, &state.PortState{ w.state.SetPortState(key, &state.PortState{
Open: result.Open, Open: open,
Hostnames: hostnames, Hostname: hostname,
LastChecked: now, LastChecked: now,
}) })
} }
@@ -815,22 +691,6 @@ func (w *Watcher) checkTLSExpiry(
return return
} }
// Deduplicate expiry warnings: don't re-notify for the same
// hostname within the TLS check interval.
dedupKey := fmt.Sprintf("expiry:%s:%s", hostname, ip)
w.expiryNotifiedMu.Lock()
lastNotified, seen := w.expiryNotified[dedupKey]
if seen && time.Since(lastNotified) < w.config.TLSInterval {
w.expiryNotifiedMu.Unlock()
return
}
w.expiryNotified[dedupKey] = time.Now()
w.expiryNotifiedMu.Unlock()
msg := fmt.Sprintf( msg := fmt.Sprintf(
"Host: %s\nIP: %s\nCN: %s\n"+ "Host: %s\nIP: %s\nCN: %s\n"+
"Expires: %s (%.0f days)", "Expires: %s (%.0f days)",

View File

@@ -9,7 +9,6 @@ import (
"time" "time"
"sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/tlscheck"
"sneak.berlin/go/dnswatcher/internal/watcher" "sneak.berlin/go/dnswatcher/internal/watcher"
@@ -110,20 +109,24 @@ func (m *mockPortChecker) CheckPort(
_ context.Context, _ context.Context,
address string, address string,
port int, port int,
) (*portcheck.PortResult, error) { ) (bool, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.calls++ m.calls++
if m.err != nil { if m.err != nil {
return nil, m.err return false, m.err
} }
key := fmt.Sprintf("%s:%d", address, port) key := fmt.Sprintf("%s:%d", address, port)
open := m.results[key] open, ok := m.results[key]
return &portcheck.PortResult{Open: open}, nil if !ok {
return false, nil
}
return open, nil
} }
type mockTLSChecker struct { type mockTLSChecker struct {
@@ -273,10 +276,6 @@ func setupBaselineMocks(deps *testDeps) {
"ns1.example.com.", "ns1.example.com.",
"ns2.example.com.", "ns2.example.com.",
} }
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
"ns2.example.com.": {"A": {"93.184.216.34"}},
}
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{ deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}}, "ns1.example.com.": {"A": {"93.184.216.34"}},
"ns2.example.com.": {"A": {"93.184.216.34"}}, "ns2.example.com.": {"A": {"93.184.216.34"}},
@@ -294,14 +293,6 @@ func setupBaselineMocks(deps *testDeps) {
"www.example.com", "www.example.com",
}, },
} }
deps.tlsChecker.certs["93.184.216.34:example.com"] = &tlscheck.CertificateInfo{
CommonName: "example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"example.com",
},
}
} }
func assertNoNotifications( func assertNoNotifications(
@@ -334,74 +325,14 @@ func assertStatePopulated(
) )
} }
// Hostnames includes both explicit hostnames and domains if len(snap.Hostnames) != 1 {
// (domains now also get hostname state for port/TLS checks).
if len(snap.Hostnames) < 1 {
t.Errorf( t.Errorf(
"expected at least 1 hostname in state, got %d", "expected 1 hostname in state, got %d",
len(snap.Hostnames), len(snap.Hostnames),
) )
} }
} }
func TestDomainPortAndTLSChecks(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Domains = []string{"example.com"}
w, deps := newTestWatcher(t, cfg)
deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.",
}
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"93.184.216.34"}},
}
deps.portChecker.results["93.184.216.34:80"] = true
deps.portChecker.results["93.184.216.34:443"] = true
deps.tlsChecker.certs["93.184.216.34:example.com"] = &tlscheck.CertificateInfo{
CommonName: "example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"example.com",
},
}
w.RunOnce(t.Context())
snap := deps.state.GetSnapshot()
// Domain should have port state populated
if len(snap.Ports) == 0 {
t.Error("expected port state for domain, got none")
}
// Domain should have certificate state populated
if len(snap.Certificates) == 0 {
t.Error("expected certificate state for domain, got none")
}
// Verify port checker was actually called
deps.portChecker.mu.Lock()
calls := deps.portChecker.calls
deps.portChecker.mu.Unlock()
if calls == 0 {
t.Error("expected port checker to be called for domain")
}
// Verify TLS checker was actually called
deps.tlsChecker.mu.Lock()
tlsCalls := deps.tlsChecker.calls
deps.tlsChecker.mu.Unlock()
if tlsCalls == 0 {
t.Error("expected TLS checker to be called for domain")
}
}
func TestNSChangeDetection(t *testing.T) { func TestNSChangeDetection(t *testing.T) {
t.Parallel() t.Parallel()
@@ -414,12 +345,6 @@ func TestNSChangeDetection(t *testing.T) {
"ns1.example.com.", "ns1.example.com.",
"ns2.example.com.", "ns2.example.com.",
} }
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
"ns2.example.com.": {"A": {"1.2.3.4"}},
}
deps.portChecker.results["1.2.3.4:80"] = false
deps.portChecker.results["1.2.3.4:443"] = false
ctx := t.Context() ctx := t.Context()
w.RunOnce(ctx) w.RunOnce(ctx)
@@ -429,10 +354,6 @@ func TestNSChangeDetection(t *testing.T) {
"ns1.example.com.", "ns1.example.com.",
"ns3.example.com.", "ns3.example.com.",
} }
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
"ns3.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.mu.Unlock() deps.resolver.mu.Unlock()
w.RunOnce(ctx) w.RunOnce(ctx)
@@ -588,61 +509,6 @@ func TestTLSExpiryWarning(t *testing.T) {
} }
} }
func TestTLSExpiryWarningDedup(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
cfg.TLSInterval = 24 * time.Hour
w, deps := newTestWatcher(t, cfg)
deps.resolver.allRecords["www.example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.resolver.ipAddresses["www.example.com"] = []string{
"1.2.3.4",
}
deps.portChecker.results["1.2.3.4:80"] = true
deps.portChecker.results["1.2.3.4:443"] = true
deps.tlsChecker.certs["1.2.3.4:www.example.com"] = &tlscheck.CertificateInfo{
CommonName: "www.example.com",
Issuer: "DigiCert",
NotAfter: time.Now().Add(3 * 24 * time.Hour),
SubjectAlternativeNames: []string{
"www.example.com",
},
}
ctx := t.Context()
// First run = baseline, no notifications
w.RunOnce(ctx)
// Second run should fire one expiry warning
w.RunOnce(ctx)
// Third run should NOT fire another warning (dedup)
w.RunOnce(ctx)
notifications := deps.notifier.getNotifications()
expiryCount := 0
for _, n := range notifications {
if n.Title == "TLS Expiry Warning: www.example.com" {
expiryCount++
}
}
if expiryCount != 1 {
t.Errorf(
"expected exactly 1 expiry warning (dedup), got %d",
expiryCount,
)
}
}
func TestGracefulShutdown(t *testing.T) { func TestGracefulShutdown(t *testing.T) {
t.Parallel() t.Parallel()
@@ -656,11 +522,6 @@ func TestGracefulShutdown(t *testing.T) {
deps.resolver.nsRecords["example.com"] = []string{ deps.resolver.nsRecords["example.com"] = []string{
"ns1.example.com.", "ns1.example.com.",
} }
deps.resolver.allRecords["example.com"] = map[string]map[string][]string{
"ns1.example.com.": {"A": {"1.2.3.4"}},
}
deps.portChecker.results["1.2.3.4:80"] = false
deps.portChecker.results["1.2.3.4:443"] = false
ctx, cancel := context.WithCancel(t.Context()) ctx, cancel := context.WithCancel(t.Context())
@@ -682,80 +543,6 @@ func TestGracefulShutdown(t *testing.T) {
} }
} }
func setupHostnameIP(
deps *testDeps,
hostname, ip string,
) {
deps.resolver.allRecords[hostname] = map[string]map[string][]string{
"ns1.example.com.": {"A": {ip}},
}
deps.portChecker.results[ip+":80"] = true
deps.portChecker.results[ip+":443"] = true
deps.tlsChecker.certs[ip+":"+hostname] = &tlscheck.CertificateInfo{
CommonName: hostname,
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{hostname},
}
}
func updateHostnameIP(deps *testDeps, hostname, ip string) {
deps.resolver.mu.Lock()
deps.resolver.allRecords[hostname] = map[string]map[string][]string{
"ns1.example.com.": {"A": {ip}},
}
deps.resolver.mu.Unlock()
deps.portChecker.mu.Lock()
deps.portChecker.results[ip+":80"] = true
deps.portChecker.results[ip+":443"] = true
deps.portChecker.mu.Unlock()
deps.tlsChecker.mu.Lock()
deps.tlsChecker.certs[ip+":"+hostname] = &tlscheck.CertificateInfo{
CommonName: hostname,
Issuer: "DigiCert",
NotAfter: time.Now().Add(90 * 24 * time.Hour),
SubjectAlternativeNames: []string{hostname},
}
deps.tlsChecker.mu.Unlock()
}
func TestDNSRunsBeforePortAndTLSChecks(t *testing.T) {
t.Parallel()
cfg := defaultTestConfig(t)
cfg.Hostnames = []string{"www.example.com"}
w, deps := newTestWatcher(t, cfg)
setupHostnameIP(deps, "www.example.com", "10.0.0.1")
ctx := t.Context()
w.RunOnce(ctx)
snap := deps.state.GetSnapshot()
if _, ok := snap.Ports["10.0.0.1:80"]; !ok {
t.Fatal("expected port state for 10.0.0.1:80")
}
// DNS changes to a new IP; port and TLS must pick it up.
updateHostnameIP(deps, "www.example.com", "10.0.0.2")
w.RunOnce(ctx)
snap = deps.state.GetSnapshot()
if _, ok := snap.Ports["10.0.0.2:80"]; !ok {
t.Error("port check used stale DNS: missing 10.0.0.2:80")
}
certKey := "10.0.0.2:443:www.example.com"
if _, ok := snap.Certificates[certKey]; !ok {
t.Error("TLS check used stale DNS: missing " + certKey)
}
}
func TestNSFailureAndRecovery(t *testing.T) { func TestNSFailureAndRecovery(t *testing.T) {
t.Parallel() t.Parallel()