12 Commits

Author SHA1 Message Date
clawbot
889855306f fix: mock DNS in resolver tests for deterministic fast suite
All checks were successful
Check / check (pull_request) Successful in 10m24s
Replace all real DNS queries in resolver_test.go with a mock DNSClient
that simulates the full iterative resolution hierarchy (root → TLD → auth NS).

- Uses NewFromLoggerWithClient with mock DNSClient
- All 28 test behaviors preserved (NS lookup, A/AAAA/MX/TXT, NXDOMAIN,
  sorting, dedup, context cancellation, trailing dots)
- Tests pass with -race flag, no data races
- Total resolver test time: ~1.5s (was >30s)
- Zero linter issues

Closes #32
2026-02-22 04:25:33 -08:00
8cfff5dcc8 Merge pull request 'fix: use full Lock in State.Save() to prevent data race (closes #17)' (#20) from fix/state-save-data-race into main
Some checks failed
Check / check (push) Failing after 5m43s
Reviewed-on: #20
2026-02-21 11:22:46 +01:00
clawbot
b162ca743b fix: use full Lock in State.Save() to prevent data race (closes #17)
Some checks failed
Check / check (pull_request) Failing after 5m31s
State.Save() was using RLock but mutating s.snapshot.LastUpdated,
which is a write operation. This created a data race since other
goroutines could also hold a read lock and observe a partially
written timestamp. Changed to full Lock to ensure exclusive access
during the mutation.
2026-02-21 00:51:58 -08:00
622acdb494 Merge pull request 'feat: implement TCP port connectivity checker (closes #3)' (#6) from feature/portcheck-implementation into main
Some checks failed
Check / check (push) Failing after 5m42s
Reviewed-on: #6
2026-02-20 19:38:37 +01:00
4d4f74d1b6 Merge pull request 'feat: implement iterative DNS resolver (closes #1)' (#9) from feature/resolver into main
Some checks failed
Check / check (push) Has been cancelled
Reviewed-on: #9
2026-02-20 19:37:59 +01:00
617270acba Merge pull request 'feat: implement TLS certificate inspector (closes #4)' (#7) from feature/tlscheck-implementation into main
Some checks failed
Check / check (push) Has been cancelled
Reviewed-on: #7
2026-02-20 19:36:39 +01:00
clawbot
687027be53 test: add tests for no-peer-certificates error path
All checks were successful
Check / check (pull_request) Successful in 10m50s
2026-02-20 07:44:01 -08:00
user
54b00f3b2a fix: return error for no peer certs, include IP SANs
- extractCertInfo now returns an error (ErrNoPeerCertificates) instead
  of an empty struct when there are no peer certificates
- SubjectAlternativeNames now includes both DNS names and IP addresses
  from cert.IPAddresses

Addresses review feedback on PR #7.
2026-02-20 07:44:01 -08:00
clawbot
3fcf203485 fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
Fix goimports formatting in tlscheck_test.go.
2026-02-20 07:44:01 -08:00
clawbot
8770c942cb feat: implement TLS certificate inspector (closes #4) 2026-02-20 07:43:47 -08:00
user
57cd228837 feat: make CheckPorts concurrent and add port validation
- CheckPorts now runs all port checks concurrently using errgroup
- Added port number validation (1-65535) with ErrInvalidPort sentinel error
- Updated PortChecker interface to use *PortResult return type
- Added tests for invalid port numbers (0, negative, >65535)
- All checks pass (make check clean)
2026-02-20 00:14:55 -08:00
clawbot
ab39e77015 feat: implement TCP port connectivity checker (closes #3) 2026-02-20 00:11:26 -08:00
11 changed files with 1328 additions and 94 deletions

1
go.mod
View File

@@ -13,6 +13,7 @@ require (
github.com/stretchr/testify v1.11.1
go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0
golang.org/x/sync v0.19.0
)
require (

View File

@@ -4,18 +4,39 @@ package portcheck
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"sync"
"time"
"go.uber.org/fx"
"golang.org/x/sync/errgroup"
"sneak.berlin/go/dnswatcher/internal/logger"
)
// ErrNotImplemented indicates the port checker is not yet implemented.
var ErrNotImplemented = errors.New(
"port checker not yet implemented",
const (
minPort = 1
maxPort = 65535
defaultTimeout = 5 * time.Second
)
// ErrInvalidPort is returned when a port number is outside
// the valid TCP range (165535).
var ErrInvalidPort = errors.New("invalid port number")
// PortResult holds the outcome of a single TCP port check.
type PortResult struct {
// Open indicates whether the port accepted a connection.
Open bool
// Error contains a description if the connection failed.
Error string
// Latency is the time taken for the TCP handshake.
Latency time.Duration
}
// Params contains dependencies for Checker.
type Params struct {
fx.In
@@ -38,11 +59,145 @@ func New(
}, nil
}
// CheckPort tests TCP connectivity to the given address and port.
func (c *Checker) CheckPort(
_ context.Context,
_ string,
_ int,
) (bool, error) {
return false, ErrNotImplemented
// NewStandalone creates a Checker without fx dependencies.
func NewStandalone() *Checker {
return &Checker{
log: slog.Default(),
}
}
// validatePort checks that a port number is within the valid
// TCP port range (165535).
func validatePort(port int) error {
if port < minPort || port > maxPort {
return fmt.Errorf(
"%w: %d (must be between %d and %d)",
ErrInvalidPort, port, minPort, maxPort,
)
}
return nil
}
// CheckPort tests TCP connectivity to the given address and port.
// It uses a 5-second timeout unless the context has an earlier
// deadline.
func (c *Checker) CheckPort(
ctx context.Context,
address string,
port int,
) (*PortResult, error) {
err := validatePort(port)
if err != nil {
return nil, err
}
target := net.JoinHostPort(
address, strconv.Itoa(port),
)
timeout := defaultTimeout
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
remaining := time.Until(deadline)
if remaining < timeout {
timeout = remaining
}
}
return c.checkConnection(ctx, target, timeout), nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address concurrently. It returns a map of port number
// to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
for _, port := range ports {
err := validatePort(port)
if err != nil {
return nil, err
}
}
var mu sync.Mutex
results := make(map[int]*PortResult, len(ports))
g, ctx := errgroup.WithContext(ctx)
for _, port := range ports {
g.Go(func() error {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return fmt.Errorf(
"checking port %d: %w", port, err,
)
}
mu.Lock()
results[port] = result
mu.Unlock()
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return results, nil
}
// checkConnection performs the TCP dial and returns a result.
func (c *Checker) checkConnection(
ctx context.Context,
target string,
timeout time.Duration,
) *PortResult {
dialer := &net.Dialer{Timeout: timeout}
start := time.Now()
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
latency := time.Since(start)
if dialErr != nil {
c.log.Debug(
"port check failed",
"target", target,
"error", dialErr.Error(),
)
return &PortResult{
Open: false,
Error: dialErr.Error(),
Latency: latency,
}
}
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing connection",
"target", target,
"error", closeErr.Error(),
)
}
c.log.Debug(
"port check succeeded",
"target", target,
"latency", latency,
)
return &PortResult{
Open: true,
Latency: latency,
}
}

View File

@@ -0,0 +1,211 @@
package portcheck_test
import (
"context"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/portcheck"
)
func listenTCP(
t *testing.T,
) (net.Listener, int) {
t.Helper()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to start listener: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return ln, addr.Port
}
func TestCheckPortOpen(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
defer func() { _ = ln.Close() }()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !result.Open {
t.Error("expected port to be open")
}
if result.Error != "" {
t.Errorf("expected no error, got: %s", result.Error)
}
if result.Latency <= 0 {
t.Error("expected positive latency")
}
}
func TestCheckPortClosed(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
_ = ln.Close()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Open {
t.Error("expected port to be closed")
}
if result.Error == "" {
t.Error("expected error message for closed port")
}
}
func TestCheckPortContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(ctx, "127.0.0.1", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Open {
t.Error("expected port to not be open")
}
}
func TestCheckPortsMultiple(t *testing.T) {
t.Parallel()
ln, openPort := listenTCP(t)
defer func() { _ = ln.Close() }()
ln2, closedPort := listenTCP(t)
_ = ln2.Close()
checker := portcheck.NewStandalone()
results, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{openPort, closedPort},
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 2 {
t.Fatalf(
"expected 2 results, got %d", len(results),
)
}
if !results[openPort].Open {
t.Error("expected open port to be open")
}
if results[closedPort].Open {
t.Error("expected closed port to be closed")
}
}
func TestCheckPortInvalidPorts(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
cases := []struct {
name string
port int
}{
{"zero", 0},
{"negative", -1},
{"too high", 65536},
{"very negative", -1000},
{"very high", 100000},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := checker.CheckPort(
context.Background(), "127.0.0.1", tc.port,
)
if err == nil {
t.Errorf(
"expected error for port %d, got nil",
tc.port,
)
}
})
}
}
func TestCheckPortsInvalidPort(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
_, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{80, 0, 443},
)
if err == nil {
t.Error("expected error for invalid port in list")
}
}
func TestCheckPortLatencyReasonable(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
defer func() { _ = ln.Close() }()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Latency > time.Second {
t.Errorf(
"latency too high for localhost: %v",
result.Latency,
)
}
}

View File

@@ -2,6 +2,7 @@ package resolver_test
import (
"context"
"fmt"
"log/slog"
"net"
"os"
@@ -10,12 +11,504 @@ import (
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sneak.berlin/go/dnswatcher/internal/resolver"
)
// ----------------------------------------------------------------
// Mock DNS client
// ----------------------------------------------------------------
// mockDNSClient implements resolver.DNSClient with canned responses.
type mockDNSClient struct {
handlers map[string]func(msg *dns.Msg) *dns.Msg
}
func newMockClient() *mockDNSClient {
return &mockDNSClient{
handlers: make(map[string]func(msg *dns.Msg) *dns.Msg),
}
}
func (m *mockDNSClient) ExchangeContext(
ctx context.Context,
msg *dns.Msg,
addr string,
) (*dns.Msg, time.Duration, error) {
err := ctx.Err()
if err != nil {
return nil, 0, err
}
host, _, _ := net.SplitHostPort(addr)
if host == "" {
host = addr
}
qname := msg.Question[0].Name
qtype := dns.TypeToString[msg.Question[0].Qtype]
resp := m.findHandler(host, qname, qtype, msg)
return resp, time.Millisecond, nil
}
func (m *mockDNSClient) findHandler(
host, qname, qtype string,
msg *dns.Msg,
) *dns.Msg {
key := fmt.Sprintf(
"%s|%s|%s", host, strings.ToLower(qname), qtype,
)
if h, ok := m.handlers[key]; ok {
return h(msg)
}
wildKey := fmt.Sprintf(
"*|%s|%s", strings.ToLower(qname), qtype,
)
if h, ok := m.handlers[wildKey]; ok {
return h(msg)
}
resp := new(dns.Msg)
resp.SetReply(msg)
return resp
}
func (m *mockDNSClient) on(
server, qname, qtype string,
handler func(msg *dns.Msg) *dns.Msg,
) {
key := fmt.Sprintf(
"%s|%s|%s",
server, dns.Fqdn(strings.ToLower(qname)), qtype,
)
m.handlers[key] = handler
}
// ----------------------------------------------------------------
// Response builders
// ----------------------------------------------------------------
func referralResponse(
msg *dns.Msg,
nsNames []string,
glue map[string]string,
) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
for _, ns := range nsNames {
resp.Ns = append(resp.Ns, &dns.NS{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 3600,
},
Ns: dns.Fqdn(ns),
})
}
for name, ip := range glue {
resp.Extra = append(resp.Extra, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name),
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3600,
},
A: net.ParseIP(ip),
})
}
return resp
}
func nsAnswerResponse(
msg *dns.Msg, nsNames []string,
) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
for _, ns := range nsNames {
resp.Answer = append(resp.Answer, &dns.NS{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: 3600,
},
Ns: dns.Fqdn(ns),
})
}
return resp
}
func nxdomainResponse(msg *dns.Msg) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
resp.Rcode = dns.RcodeNameError
return resp
}
func aResponse(
msg *dns.Msg, name string, ip string,
) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name), Rrtype: dns.TypeA,
Class: dns.ClassINET, Ttl: 300,
},
A: net.ParseIP(ip),
})
return resp
}
func aaaaResponse(
msg *dns.Msg, name string, ip string,
) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
resp.Answer = append(resp.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn(name), Rrtype: dns.TypeAAAA,
Class: dns.ClassINET, Ttl: 300,
},
AAAA: net.ParseIP(ip),
})
return resp
}
func emptyResponse(msg *dns.Msg) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
return resp
}
// ----------------------------------------------------------------
// Mock DNS hierarchy setup
// ----------------------------------------------------------------
// mockData holds all test DNS hierarchy configuration.
type mockData struct {
tldNS []string
tldGlue map[string]string
exNS []string
exGlue map[string]string
cfNS []string
cfGlue map[string]string
}
func newMockData() mockData {
return mockData{
tldNS: []string{"ns1.tld.com", "ns2.tld.com"},
tldGlue: map[string]string{
"ns1.tld.com": "10.0.0.1",
"ns2.tld.com": "10.0.0.2",
},
exNS: []string{
"ns1.example.com", "ns2.example.com",
"ns3.example.com",
},
exGlue: map[string]string{
"ns1.example.com": "10.1.0.1",
"ns2.example.com": "10.1.0.2",
"ns3.example.com": "10.1.0.3",
},
cfNS: []string{
"ns1.cloudflare.com", "ns2.cloudflare.com",
},
cfGlue: map[string]string{
"ns1.cloudflare.com": "10.2.0.1",
"ns2.cloudflare.com": "10.2.0.2",
},
}
}
func rootIPList() []string {
return []string{
"198.41.0.4", "170.247.170.2", "192.33.4.12",
"199.7.91.13", "192.203.230.10", "192.5.5.241",
"192.112.36.4", "198.97.190.53", "192.36.148.17",
"192.58.128.30", "193.0.14.129", "199.7.83.42",
"202.12.27.33",
}
}
func allQueryTypes() []string {
return []string{
"NS", "A", "AAAA", "CNAME", "MX", "TXT", "SRV", "CAA",
}
}
func setupRootDelegations(
m *mockDNSClient,
tNS []string,
tGlue map[string]string,
) {
domains := []string{
"example.com.", "www.example.com.",
"this-surely-does-not-exist-xyz.example.com.",
"cloudflare.com.",
}
for _, rootIP := range rootIPList() {
for _, domain := range domains {
for _, qtype := range allQueryTypes() {
m.on(rootIP, domain, qtype,
func(msg *dns.Msg) *dns.Msg {
return referralResponse(
msg, tNS, tGlue,
)
},
)
}
}
}
}
func setupRootARecords(m *mockDNSClient) {
nsIPs := map[string]string{
"ns1.example.com.": "10.1.0.1",
"ns2.example.com.": "10.1.0.2",
"ns3.example.com.": "10.1.0.3",
"ns1.cloudflare.com.": "10.2.0.1",
"ns2.cloudflare.com.": "10.2.0.2",
}
for _, rootIP := range rootIPList() {
for nsName, nsIP := range nsIPs {
ip := nsIP
name := nsName
m.on(rootIP, name, "A",
func(msg *dns.Msg) *dns.Msg {
return aResponse(msg, name, ip)
},
)
}
}
}
func setupTLDDelegations(
m *mockDNSClient,
exNS []string,
exGlue map[string]string,
cfNS []string,
cfGlue map[string]string,
) {
tldIPs := []string{"10.0.0.1", "10.0.0.2"}
exDomains := []string{
"example.com.", "www.example.com.",
"this-surely-does-not-exist-xyz.example.com.",
}
for _, tldIP := range tldIPs {
for _, domain := range exDomains {
for _, qtype := range allQueryTypes() {
m.on(tldIP, domain, qtype,
func(msg *dns.Msg) *dns.Msg {
return referralResponse(
msg, exNS, exGlue,
)
},
)
}
}
for _, qtype := range allQueryTypes() {
m.on(tldIP, "cloudflare.com.", qtype,
func(msg *dns.Msg) *dns.Msg {
return referralResponse(
msg, cfNS, cfGlue,
)
},
)
}
}
}
func setupExampleNSAndA(
m *mockDNSClient, exNS []string,
) {
exIPs := []string{"10.1.0.1", "10.1.0.2", "10.1.0.3"}
for _, authIP := range exIPs {
m.on(authIP, "example.com.", "NS",
func(msg *dns.Msg) *dns.Msg {
return nsAnswerResponse(msg, exNS)
},
)
m.on(authIP, "example.com.", "A",
func(msg *dns.Msg) *dns.Msg {
return aResponse(
msg, "example.com.", "93.184.216.34",
)
},
)
m.on(authIP, "example.com.", "AAAA",
func(msg *dns.Msg) *dns.Msg {
return aaaaResponse(
msg, "example.com.",
"2606:2800:220:1:248:1893:25c8:1946",
)
},
)
}
}
func setupExampleMXAndTXT(m *mockDNSClient) {
exIPs := []string{"10.1.0.1", "10.1.0.2", "10.1.0.3"}
for _, authIP := range exIPs {
m.on(authIP, "example.com.", "MX",
func(msg *dns.Msg) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
resp.Answer = append(resp.Answer,
&dns.MX{
Hdr: dns.RR_Header{
Name: "example.com.",
Rrtype: dns.TypeMX,
Class: dns.ClassINET,
Ttl: 300,
},
Preference: 10,
Mx: "mail.example.com.",
},
&dns.MX{
Hdr: dns.RR_Header{
Name: "example.com.",
Rrtype: dns.TypeMX,
Class: dns.ClassINET,
Ttl: 300,
},
Preference: 20,
Mx: "mail2.example.com.",
},
)
return resp
},
)
m.on(authIP, "example.com.", "TXT",
func(msg *dns.Msg) *dns.Msg {
resp := new(dns.Msg)
resp.SetReply(msg)
resp.Answer = append(resp.Answer, &dns.TXT{
Hdr: dns.RR_Header{
Name: "example.com.",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 300,
},
Txt: []string{
"v=spf1 include:_spf.example.com ~all",
},
})
return resp
},
)
}
}
func setupExampleSubdomains(
m *mockDNSClient, exNS []string,
) {
exIPs := []string{"10.1.0.1", "10.1.0.2", "10.1.0.3"}
for _, authIP := range exIPs {
m.on(authIP, "www.example.com.", "NS",
func(msg *dns.Msg) *dns.Msg {
return nsAnswerResponse(msg, exNS)
},
)
m.on(authIP, "www.example.com.", "A",
func(msg *dns.Msg) *dns.Msg {
return aResponse(
msg, "www.example.com.", "93.184.216.34",
)
},
)
nxName := "this-surely-does-not-exist-xyz.example.com."
for _, qtype := range allQueryTypes() {
m.on(authIP, nxName, qtype, nxdomainResponse)
}
}
}
func setupCloudflareAuthRecords(
m *mockDNSClient, cfNS []string,
) {
cfIPs := []string{"10.2.0.1", "10.2.0.2"}
for _, authIP := range cfIPs {
m.on(authIP, "cloudflare.com.", "NS",
func(msg *dns.Msg) *dns.Msg {
return nsAnswerResponse(msg, cfNS)
},
)
m.on(authIP, "cloudflare.com.", "A",
func(msg *dns.Msg) *dns.Msg {
return aResponse(
msg, "cloudflare.com.", "104.16.132.229",
)
},
)
m.on(authIP, "cloudflare.com.", "AAAA",
func(msg *dns.Msg) *dns.Msg {
return aaaaResponse(
msg, "cloudflare.com.",
"2606:4700::6810:84e5",
)
},
)
m.on(authIP, "cloudflare.com.", "MX", emptyResponse)
m.on(authIP, "cloudflare.com.", "TXT", emptyResponse)
}
}
func setupMockDNS() *mockDNSClient {
m := newMockClient()
d := newMockData()
setupRootDelegations(m, d.tldNS, d.tldGlue)
setupRootARecords(m)
setupTLDDelegations(m, d.exNS, d.exGlue, d.cfNS, d.cfGlue)
setupExampleNSAndA(m, d.exNS)
setupExampleMXAndTXT(m)
setupExampleSubdomains(m, d.exNS)
setupCloudflareAuthRecords(m, d.cfNS)
return m
}
// ----------------------------------------------------------------
// Test helpers
// ----------------------------------------------------------------
@@ -28,14 +521,14 @@ func newTestResolver(t *testing.T) *resolver.Resolver {
&slog.HandlerOptions{Level: slog.LevelDebug},
))
return resolver.NewFromLogger(log)
return resolver.NewFromLoggerWithClient(log, setupMockDNS())
}
func testContext(t *testing.T) context.Context {
t.Helper()
ctx, cancel := context.WithTimeout(
context.Background(), 60*time.Second,
context.Background(), 10*time.Second,
)
t.Cleanup(cancel)
@@ -72,23 +565,23 @@ func TestFindAuthoritativeNameservers_ValidDomain(
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
hasGoogleNS := false
hasExampleNS := false
for _, ns := range nameservers {
if strings.Contains(ns, "google") {
hasGoogleNS = true
if strings.Contains(ns, "example") {
hasExampleNS = true
break
}
}
assert.True(t, hasGoogleNS,
"expected google nameservers, got: %v", nameservers,
assert.True(t, hasExampleNS,
"expected example nameservers, got: %v", nameservers,
)
}
@@ -101,7 +594,7 @@ func TestFindAuthoritativeNameservers_Subdomain(
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "www.google.com",
ctx, "www.example.com",
)
require.NoError(t, err)
require.NotEmpty(t, nameservers)
@@ -116,7 +609,7 @@ func TestFindAuthoritativeNameservers_ReturnsSorted(
ctx := testContext(t)
nameservers, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
@@ -136,12 +629,12 @@ func TestFindAuthoritativeNameservers_Deterministic(
ctx := testContext(t)
first, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
second, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
@@ -157,12 +650,12 @@ func TestFindAuthoritativeNameservers_TrailingDot(
ctx := testContext(t)
ns1, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
ns2, err := r.FindAuthoritativeNameservers(
ctx, "google.com.",
ctx, "example.com.",
)
require.NoError(t, err)
@@ -199,10 +692,10 @@ func TestQueryNameserver_BasicA(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
ns := findOneNSForDomain(t, r, ctx, "example.com")
resp, err := r.QueryNameserver(
ctx, ns, "www.google.com",
ctx, ns, "www.example.com",
)
require.NoError(t, err)
require.NotNil(t, resp)
@@ -213,7 +706,7 @@ func TestQueryNameserver_BasicA(t *testing.T) {
hasRecords := len(resp.Records["A"]) > 0 ||
len(resp.Records["CNAME"]) > 0
assert.True(t, hasRecords,
"expected A or CNAME records for www.google.com",
"expected A or CNAME records for www.example.com",
)
}
@@ -247,16 +740,16 @@ func TestQueryNameserver_MX(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
ns := findOneNSForDomain(t, r, ctx, "example.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
ctx, ns, "example.com",
)
require.NoError(t, err)
mxRecords := resp.Records["MX"]
require.NotEmpty(t, mxRecords,
"google.com should have MX records",
"example.com should have MX records",
)
}
@@ -265,16 +758,16 @@ func TestQueryNameserver_TXT(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
ns := findOneNSForDomain(t, r, ctx, "example.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
ctx, ns, "example.com",
)
require.NoError(t, err)
txtRecords := resp.Records["TXT"]
require.NotEmpty(t, txtRecords,
"google.com should have TXT records",
"example.com should have TXT records",
)
hasSPF := false
@@ -288,7 +781,7 @@ func TestQueryNameserver_TXT(t *testing.T) {
}
assert.True(t, hasSPF,
"google.com should have SPF TXT record",
"example.com should have SPF TXT record",
)
}
@@ -297,11 +790,11 @@ func TestQueryNameserver_NXDomain(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
ns := findOneNSForDomain(t, r, ctx, "example.com")
resp, err := r.QueryNameserver(
ctx, ns,
"this-surely-does-not-exist-xyz.google.com",
"this-surely-does-not-exist-xyz.example.com",
)
require.NoError(t, err)
@@ -313,10 +806,10 @@ func TestQueryNameserver_RecordsSorted(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
ns := findOneNSForDomain(t, r, ctx, "example.com")
resp, err := r.QueryNameserver(
ctx, ns, "google.com",
ctx, ns, "example.com",
)
require.NoError(t, err)
@@ -353,11 +846,11 @@ func TestQueryNameserver_EmptyRecordsOnNXDomain(
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
ns := findOneNSForDomain(t, r, ctx, "example.com")
resp, err := r.QueryNameserver(
ctx, ns,
"this-surely-does-not-exist-xyz.google.com",
"this-surely-does-not-exist-xyz.example.com",
)
require.NoError(t, err)
@@ -374,15 +867,15 @@ func TestQueryNameserver_TrailingDotHandling(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ns := findOneNSForDomain(t, r, ctx, "google.com")
ns := findOneNSForDomain(t, r, ctx, "example.com")
resp1, err := r.QueryNameserver(
ctx, ns, "google.com",
ctx, ns, "example.com",
)
require.NoError(t, err)
resp2, err := r.QueryNameserver(
ctx, ns, "google.com.",
ctx, ns, "example.com.",
)
require.NoError(t, err)
@@ -400,7 +893,7 @@ func TestQueryAllNameservers_ReturnsAllNS(t *testing.T) {
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
require.NotEmpty(t, results)
@@ -419,7 +912,7 @@ func TestQueryAllNameservers_AllReturnOK(t *testing.T) {
ctx := testContext(t)
results, err := r.QueryAllNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
@@ -441,7 +934,7 @@ func TestQueryAllNameservers_NXDomainFromAllNS(
results, err := r.QueryAllNameservers(
ctx,
"this-surely-does-not-exist-xyz.google.com",
"this-surely-does-not-exist-xyz.example.com",
)
require.NoError(t, err)
@@ -463,7 +956,7 @@ func TestLookupNS_ValidDomain(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, "google.com")
nameservers, err := r.LookupNS(ctx, "example.com")
require.NoError(t, err)
require.NotEmpty(t, nameservers)
@@ -480,7 +973,7 @@ func TestLookupNS_Sorted(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
nameservers, err := r.LookupNS(ctx, "google.com")
nameservers, err := r.LookupNS(ctx, "example.com")
require.NoError(t, err)
assert.True(t, sort.StringsAreSorted(nameservers))
@@ -492,11 +985,11 @@ func TestLookupNS_MatchesFindAuthoritative(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
fromLookup, err := r.LookupNS(ctx, "google.com")
fromLookup, err := r.LookupNS(ctx, "example.com")
require.NoError(t, err)
fromFind, err := r.FindAuthoritativeNameservers(
ctx, "google.com",
ctx, "example.com",
)
require.NoError(t, err)
@@ -513,7 +1006,7 @@ func TestResolveIPAddresses_ReturnsIPs(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
ips, err := r.ResolveIPAddresses(ctx, "example.com")
require.NoError(t, err)
require.NotEmpty(t, ips)
@@ -531,7 +1024,7 @@ func TestResolveIPAddresses_Deduplicated(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
ips, err := r.ResolveIPAddresses(ctx, "example.com")
require.NoError(t, err)
seen := make(map[string]bool)
@@ -548,7 +1041,7 @@ func TestResolveIPAddresses_Sorted(t *testing.T) {
r := newTestResolver(t)
ctx := testContext(t)
ips, err := r.ResolveIPAddresses(ctx, "google.com")
ips, err := r.ResolveIPAddresses(ctx, "example.com")
require.NoError(t, err)
assert.True(t, sort.StringsAreSorted(ips))
@@ -564,7 +1057,7 @@ func TestResolveIPAddresses_NXDomainReturnsEmpty(
ips, err := r.ResolveIPAddresses(
ctx,
"this-surely-does-not-exist-xyz.google.com",
"this-surely-does-not-exist-xyz.example.com",
)
require.NoError(t, err)
assert.Empty(t, ips)
@@ -594,7 +1087,9 @@ func TestFindAuthoritativeNameservers_ContextCanceled(
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.FindAuthoritativeNameservers(ctx, "google.com")
_, err := r.FindAuthoritativeNameservers(
ctx, "example.com",
)
assert.Error(t, err)
}
@@ -606,7 +1101,7 @@ func TestQueryNameserver_ContextCanceled(t *testing.T) {
cancel()
_, err := r.QueryNameserver(
ctx, "ns1.google.com.", "google.com",
ctx, "ns1.example.com.", "example.com",
)
assert.Error(t, err)
}
@@ -618,7 +1113,7 @@ func TestQueryAllNameservers_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.QueryAllNameservers(ctx, "google.com")
_, err := r.QueryAllNameservers(ctx, "example.com")
assert.Error(t, err)
}
@@ -629,6 +1124,6 @@ func TestResolveIPAddresses_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := r.ResolveIPAddresses(ctx, "google.com")
_, err := r.ResolveIPAddresses(ctx, "example.com")
assert.Error(t, err)
}

View File

@@ -156,8 +156,8 @@ func (s *State) Load() error {
// Save writes the current state to disk atomically.
func (s *State) Save() error {
s.mu.RLock()
defer s.mu.RUnlock()
s.mu.Lock()
defer s.mu.Unlock()
s.snapshot.LastUpdated = time.Now().UTC()

View File

@@ -0,0 +1,67 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func TestCheckCertificateNoPeerCerts(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = ln.Close() }()
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
// Accept and immediately close to cause TLS handshake failure.
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
_ = conn.Close()
}()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
InsecureSkipVerify: true, //nolint:gosec // test
MinVersion: tls.VersionTLS12,
}),
tlscheck.WithPort(addr.Port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error when server presents no certs")
}
}
func TestErrNoPeerCertificatesIsSentinel(t *testing.T) {
t.Parallel()
err := tlscheck.ErrNoPeerCertificates
if !errors.Is(err, tlscheck.ErrNoPeerCertificates) {
t.Fatal("expected sentinel error to match")
}
}

View File

@@ -3,8 +3,12 @@ package tlscheck
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"time"
"go.uber.org/fx"
@@ -12,11 +16,56 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger"
)
// ErrNotImplemented indicates the TLS checker is not yet implemented.
var ErrNotImplemented = errors.New(
"tls checker not yet implemented",
const (
defaultTimeout = 10 * time.Second
defaultPort = 443
)
// ErrUnexpectedConnType indicates the connection was not a TLS
// connection.
var ErrUnexpectedConnType = errors.New(
"unexpected connection type",
)
// ErrNoPeerCertificates indicates the TLS connection had no peer
// certificates.
var ErrNoPeerCertificates = errors.New(
"no peer certificates",
)
// CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
SerialNumber string
}
// Option configures a Checker.
type Option func(*Checker)
// WithTimeout sets the connection timeout.
func WithTimeout(d time.Duration) Option {
return func(c *Checker) {
c.timeout = d
}
}
// WithTLSConfig sets a custom TLS configuration.
func WithTLSConfig(cfg *tls.Config) Option {
return func(c *Checker) {
c.tlsConfig = cfg
}
}
// WithPort sets the TLS port to connect to.
func WithPort(port int) Option {
return func(c *Checker) {
c.port = port
}
}
// Params contains dependencies for Checker.
type Params struct {
fx.In
@@ -26,15 +75,10 @@ type Params struct {
// Checker performs TLS certificate inspection.
type Checker struct {
log *slog.Logger
}
// CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
log *slog.Logger
timeout time.Duration
tlsConfig *tls.Config
port int
}
// New creates a new TLS Checker instance.
@@ -43,16 +87,110 @@ func New(
params Params,
) (*Checker, error) {
return &Checker{
log: params.Logger.Get(),
log: params.Logger.Get(),
timeout: defaultTimeout,
port: defaultPort,
}, nil
}
// CheckCertificate connects to the given IP:port using SNI and
// returns certificate information.
func (c *Checker) CheckCertificate(
_ context.Context,
_ string,
_ string,
) (*CertificateInfo, error) {
return nil, ErrNotImplemented
// NewStandalone creates a Checker without fx dependencies.
func NewStandalone(opts ...Option) *Checker {
checker := &Checker{
log: slog.Default(),
timeout: defaultTimeout,
port: defaultPort,
}
for _, opt := range opts {
opt(checker)
}
return checker
}
// CheckCertificate connects to the given IP address using the
// specified SNI hostname and returns certificate information.
func (c *Checker) CheckCertificate(
ctx context.Context,
ipAddress string,
sniHostname string,
) (*CertificateInfo, error) {
target := net.JoinHostPort(
ipAddress, strconv.Itoa(c.port),
)
tlsCfg := c.buildTLSConfig(sniHostname)
dialer := &tls.Dialer{
NetDialer: &net.Dialer{Timeout: c.timeout},
Config: tlsCfg,
}
conn, err := dialer.DialContext(ctx, "tcp", target)
if err != nil {
return nil, fmt.Errorf(
"TLS dial to %s: %w", target, err,
)
}
defer func() {
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing TLS connection",
"target", target,
"error", closeErr.Error(),
)
}
}()
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return nil, fmt.Errorf(
"%s: %w", target, ErrUnexpectedConnType,
)
}
return c.extractCertInfo(tlsConn)
}
func (c *Checker) buildTLSConfig(
sniHostname string,
) *tls.Config {
if c.tlsConfig != nil {
cfg := c.tlsConfig.Clone()
cfg.ServerName = sniHostname
return cfg
}
return &tls.Config{
ServerName: sniHostname,
MinVersion: tls.VersionTLS12,
}
}
func (c *Checker) extractCertInfo(
conn *tls.Conn,
) (*CertificateInfo, error) {
state := conn.ConnectionState()
if len(state.PeerCertificates) == 0 {
return nil, ErrNoPeerCertificates
}
cert := state.PeerCertificates[0]
sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses))
sans = append(sans, cert.DNSNames...)
for _, ip := range cert.IPAddresses {
sans = append(sans, ip.String())
}
return &CertificateInfo{
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: sans,
SerialNumber: cert.SerialNumber.String(),
}, nil
}

View File

@@ -0,0 +1,169 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func startTLSServer(
t *testing.T,
) (*httptest.Server, string, int) {
t.Helper()
srv := httptest.NewTLSServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
},
),
)
addr, ok := srv.Listener.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return srv, addr.IP.String(), addr.Port
}
func TestCheckCertificateValid(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info == nil {
t.Fatal("expected non-nil CertificateInfo")
}
if info.NotAfter.IsZero() {
t.Error("expected non-zero NotAfter")
}
if info.SerialNumber == "" {
t.Error("expected non-empty SerialNumber")
}
}
func TestCheckCertificateConnectionRefused(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
port := addr.Port
_ = ln.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for connection refused")
}
}
func TestCheckCertificateContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
ctx, "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for canceled context")
}
}
func TestCheckCertificateTimeout(t *testing.T) {
t.Parallel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(1*time.Millisecond),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
context.Background(),
"192.0.2.1",
"example.com",
)
if err == nil {
t.Fatal("expected error for timeout")
}
}
func TestCheckCertificateSANs(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.CommonName == "" && len(info.SubjectAlternativeNames) == 0 {
t.Error("expected CN or SANs to be populated")
}
}

View File

@@ -4,6 +4,7 @@ package watcher
import (
"context"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
@@ -36,7 +37,7 @@ type PortChecker interface {
ctx context.Context,
address string,
port int,
) (bool, error)
) (*portcheck.PortResult, error)
}
// TLSChecker inspects TLS certificates.

View File

@@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort(
port int,
hostname string,
) {
open, err := w.portCheck.CheckPort(ctx, ip, port)
result, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil {
w.log.Error(
"port check failed",
@@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort(
now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != open {
if hasPrev && !w.firstRun && prev.Open != result.Open {
stateStr := "closed"
if open {
if result.Open {
stateStr = "open"
}
@@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort(
}
w.state.SetPortState(key, &state.PortState{
Open: open,
Open: result.Open,
Hostname: hostname,
LastChecked: now,
})

View File

@@ -9,6 +9,7 @@ import (
"time"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
"sneak.berlin/go/dnswatcher/internal/watcher"
@@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort(
_ context.Context,
address string,
port int,
) (bool, error) {
) (*portcheck.PortResult, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return false, m.err
return nil, m.err
}
key := fmt.Sprintf("%s:%d", address, port)
open, ok := m.results[key]
open := m.results[key]
if !ok {
return false, nil
}
return open, nil
return &portcheck.PortResult{Open: open}, nil
}
type mockTLSChecker struct {