From ab39e77015d2bbb333339c296be296a694b93d24 Mon Sep 17 00:00:00 2001 From: clawbot Date: Thu, 19 Feb 2026 13:44:20 -0800 Subject: [PATCH 1/2] feat: implement TCP port connectivity checker (closes #3) --- internal/portcheck/portcheck.go | 119 +++++++++++++++++-- internal/portcheck/portcheck_test.go | 163 +++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 12 deletions(-) create mode 100644 internal/portcheck/portcheck_test.go diff --git a/internal/portcheck/portcheck.go b/internal/portcheck/portcheck.go index 2c061b7..4977a5b 100644 --- a/internal/portcheck/portcheck.go +++ b/internal/portcheck/portcheck.go @@ -3,18 +3,28 @@ package portcheck import ( "context" - "errors" + "fmt" "log/slog" + "net" + "strconv" + "time" "go.uber.org/fx" "sneak.berlin/go/dnswatcher/internal/logger" ) -// ErrNotImplemented indicates the port checker is not yet implemented. -var ErrNotImplemented = errors.New( - "port checker not yet implemented", -) +const defaultTimeout = 5 * time.Second + +// PortResult holds the outcome of a single TCP port check. +type PortResult struct { + // Open indicates whether the port accepted a connection. + Open bool + // Error contains a description if the connection failed. + Error string + // Latency is the time taken for the TCP handshake. + Latency time.Duration +} // Params contains dependencies for Checker. type Params struct { @@ -38,11 +48,96 @@ func New( }, nil } -// CheckPort tests TCP connectivity to the given address and port. -func (c *Checker) CheckPort( - _ context.Context, - _ string, - _ int, -) (bool, error) { - return false, ErrNotImplemented +// NewStandalone creates a Checker without fx dependencies. +func NewStandalone() *Checker { + return &Checker{ + log: slog.Default(), + } +} + +// CheckPort tests TCP connectivity to the given address and port. +// It uses a 5-second timeout unless the context has an earlier +// deadline. +func (c *Checker) CheckPort( + ctx context.Context, + address string, + port int, +) (*PortResult, error) { + target := net.JoinHostPort( + address, strconv.Itoa(port), + ) + + deadline, hasDeadline := ctx.Deadline() + timeout := defaultTimeout + + if hasDeadline { + remaining := time.Until(deadline) + if remaining < timeout { + timeout = remaining + } + } + + dialer := &net.Dialer{Timeout: timeout} + + start := time.Now() + + conn, dialErr := dialer.DialContext(ctx, "tcp", target) + latency := time.Since(start) + + if dialErr != nil { + c.log.Debug( + "port check failed", + "target", target, + "error", dialErr.Error(), + ) + + return &PortResult{ + Open: false, + Error: dialErr.Error(), + Latency: latency, + }, nil + } + + closeErr := conn.Close() + if closeErr != nil { + c.log.Debug( + "closing connection", + "target", target, + "error", closeErr.Error(), + ) + } + + c.log.Debug( + "port check succeeded", + "target", target, + "latency", latency, + ) + + return &PortResult{ + Open: true, + Latency: latency, + }, nil +} + +// CheckPorts tests TCP connectivity to multiple ports on the +// given address. It returns a map of port number to result. +func (c *Checker) CheckPorts( + ctx context.Context, + address string, + ports []int, +) (map[int]*PortResult, error) { + results := make(map[int]*PortResult, len(ports)) + + for _, port := range ports { + result, err := c.CheckPort(ctx, address, port) + if err != nil { + return nil, fmt.Errorf( + "checking port %d: %w", port, err, + ) + } + + results[port] = result + } + + return results, nil } diff --git a/internal/portcheck/portcheck_test.go b/internal/portcheck/portcheck_test.go new file mode 100644 index 0000000..bd248af --- /dev/null +++ b/internal/portcheck/portcheck_test.go @@ -0,0 +1,163 @@ +package portcheck_test + +import ( + "context" + "net" + "testing" + "time" + + "sneak.berlin/go/dnswatcher/internal/portcheck" +) + +func listenTCP( + t *testing.T, +) (net.Listener, int) { + t.Helper() + + lc := &net.ListenConfig{} + + ln, err := lc.Listen( + context.Background(), "tcp", "127.0.0.1:0", + ) + if err != nil { + t.Fatalf("failed to start listener: %v", err) + } + + addr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + t.Fatal("unexpected address type") + } + + return ln, addr.Port +} + +func TestCheckPortOpen(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + + defer func() { _ = ln.Close() }() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !result.Open { + t.Error("expected port to be open") + } + + if result.Error != "" { + t.Errorf("expected no error, got: %s", result.Error) + } + + if result.Latency <= 0 { + t.Error("expected positive latency") + } +} + +func TestCheckPortClosed(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + _ = ln.Close() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Open { + t.Error("expected port to be closed") + } + + if result.Error == "" { + t.Error("expected error message for closed port") + } +} + +func TestCheckPortContextCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort(ctx, "127.0.0.1", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Open { + t.Error("expected port to not be open") + } +} + +func TestCheckPortsMultiple(t *testing.T) { + t.Parallel() + + ln, openPort := listenTCP(t) + + defer func() { _ = ln.Close() }() + + ln2, closedPort := listenTCP(t) + _ = ln2.Close() + + checker := portcheck.NewStandalone() + + results, err := checker.CheckPorts( + context.Background(), + "127.0.0.1", + []int{openPort, closedPort}, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(results) != 2 { + t.Fatalf( + "expected 2 results, got %d", len(results), + ) + } + + if !results[openPort].Open { + t.Error("expected open port to be open") + } + + if results[closedPort].Open { + t.Error("expected closed port to be closed") + } +} + +func TestCheckPortLatencyReasonable(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + + defer func() { _ = ln.Close() }() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Latency > time.Second { + t.Errorf( + "latency too high for localhost: %v", + result.Latency, + ) + } +} -- 2.45.2 From 57cd228837ad9df30a97b9d4d0f82fc1cafe6773 Mon Sep 17 00:00:00 2001 From: user Date: Fri, 20 Feb 2026 00:14:55 -0800 Subject: [PATCH 2/2] feat: make CheckPorts concurrent and add port validation - CheckPorts now runs all port checks concurrently using errgroup - Added port number validation (1-65535) with ErrInvalidPort sentinel error - Updated PortChecker interface to use *PortResult return type - Added tests for invalid port numbers (0, negative, >65535) - All checks pass (make check clean) --- go.mod | 1 + go.sum | 2 + internal/portcheck/portcheck.go | 114 ++++++++++++++++++++------- internal/portcheck/portcheck_test.go | 48 +++++++++++ internal/watcher/interfaces.go | 3 +- internal/watcher/watcher.go | 8 +- internal/watcher/watcher_test.go | 13 ++- 7 files changed, 149 insertions(+), 40 deletions(-) diff --git a/go.mod b/go.mod index 32ad532..234b2f3 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/spf13/viper v1.21.0 go.uber.org/fx v1.24.0 golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 ) require ( diff --git a/go.sum b/go.sum index 66cc528..0edfeb0 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= diff --git a/internal/portcheck/portcheck.go b/internal/portcheck/portcheck.go index 4977a5b..a57b230 100644 --- a/internal/portcheck/portcheck.go +++ b/internal/portcheck/portcheck.go @@ -3,18 +3,29 @@ package portcheck import ( "context" + "errors" "fmt" "log/slog" "net" "strconv" + "sync" "time" "go.uber.org/fx" + "golang.org/x/sync/errgroup" "sneak.berlin/go/dnswatcher/internal/logger" ) -const defaultTimeout = 5 * time.Second +const ( + minPort = 1 + maxPort = 65535 + defaultTimeout = 5 * time.Second +) + +// ErrInvalidPort is returned when a port number is outside +// the valid TCP range (1–65535). +var ErrInvalidPort = errors.New("invalid port number") // PortResult holds the outcome of a single TCP port check. type PortResult struct { @@ -55,6 +66,19 @@ func NewStandalone() *Checker { } } +// validatePort checks that a port number is within the valid +// TCP port range (1–65535). +func validatePort(port int) error { + if port < minPort || port > maxPort { + return fmt.Errorf( + "%w: %d (must be between %d and %d)", + ErrInvalidPort, port, minPort, maxPort, + ) + } + + return nil +} + // CheckPort tests TCP connectivity to the given address and port. // It uses a 5-second timeout unless the context has an earlier // deadline. @@ -63,13 +87,18 @@ func (c *Checker) CheckPort( address string, port int, ) (*PortResult, error) { + err := validatePort(port) + if err != nil { + return nil, err + } + target := net.JoinHostPort( address, strconv.Itoa(port), ) - deadline, hasDeadline := ctx.Deadline() timeout := defaultTimeout + deadline, hasDeadline := ctx.Deadline() if hasDeadline { remaining := time.Until(deadline) if remaining < timeout { @@ -77,8 +106,62 @@ func (c *Checker) CheckPort( } } - dialer := &net.Dialer{Timeout: timeout} + return c.checkConnection(ctx, target, timeout), nil +} +// CheckPorts tests TCP connectivity to multiple ports on the +// given address concurrently. It returns a map of port number +// to result. +func (c *Checker) CheckPorts( + ctx context.Context, + address string, + ports []int, +) (map[int]*PortResult, error) { + for _, port := range ports { + err := validatePort(port) + if err != nil { + return nil, err + } + } + + var mu sync.Mutex + + results := make(map[int]*PortResult, len(ports)) + + g, ctx := errgroup.WithContext(ctx) + + for _, port := range ports { + g.Go(func() error { + result, err := c.CheckPort(ctx, address, port) + if err != nil { + return fmt.Errorf( + "checking port %d: %w", port, err, + ) + } + + mu.Lock() + results[port] = result + mu.Unlock() + + return nil + }) + } + + err := g.Wait() + if err != nil { + return nil, err + } + + return results, nil +} + +// checkConnection performs the TCP dial and returns a result. +func (c *Checker) checkConnection( + ctx context.Context, + target string, + timeout time.Duration, +) *PortResult { + dialer := &net.Dialer{Timeout: timeout} start := time.Now() conn, dialErr := dialer.DialContext(ctx, "tcp", target) @@ -95,7 +178,7 @@ func (c *Checker) CheckPort( Open: false, Error: dialErr.Error(), Latency: latency, - }, nil + } } closeErr := conn.Close() @@ -116,28 +199,5 @@ func (c *Checker) CheckPort( return &PortResult{ Open: true, Latency: latency, - }, nil -} - -// CheckPorts tests TCP connectivity to multiple ports on the -// given address. It returns a map of port number to result. -func (c *Checker) CheckPorts( - ctx context.Context, - address string, - ports []int, -) (map[int]*PortResult, error) { - results := make(map[int]*PortResult, len(ports)) - - for _, port := range ports { - result, err := c.CheckPort(ctx, address, port) - if err != nil { - return nil, fmt.Errorf( - "checking port %d: %w", port, err, - ) - } - - results[port] = result } - - return results, nil } diff --git a/internal/portcheck/portcheck_test.go b/internal/portcheck/portcheck_test.go index bd248af..7ea7fc0 100644 --- a/internal/portcheck/portcheck_test.go +++ b/internal/portcheck/portcheck_test.go @@ -138,6 +138,54 @@ func TestCheckPortsMultiple(t *testing.T) { } } +func TestCheckPortInvalidPorts(t *testing.T) { + t.Parallel() + + checker := portcheck.NewStandalone() + + cases := []struct { + name string + port int + }{ + {"zero", 0}, + {"negative", -1}, + {"too high", 65536}, + {"very negative", -1000}, + {"very high", 100000}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := checker.CheckPort( + context.Background(), "127.0.0.1", tc.port, + ) + if err == nil { + t.Errorf( + "expected error for port %d, got nil", + tc.port, + ) + } + }) + } +} + +func TestCheckPortsInvalidPort(t *testing.T) { + t.Parallel() + + checker := portcheck.NewStandalone() + + _, err := checker.CheckPorts( + context.Background(), + "127.0.0.1", + []int{80, 0, 443}, + ) + if err == nil { + t.Error("expected error for invalid port in list") + } +} + func TestCheckPortLatencyReasonable(t *testing.T) { t.Parallel() diff --git a/internal/watcher/interfaces.go b/internal/watcher/interfaces.go index 695139d..dd68017 100644 --- a/internal/watcher/interfaces.go +++ b/internal/watcher/interfaces.go @@ -4,6 +4,7 @@ package watcher import ( "context" + "sneak.berlin/go/dnswatcher/internal/portcheck" "sneak.berlin/go/dnswatcher/internal/tlscheck" ) @@ -36,7 +37,7 @@ type PortChecker interface { ctx context.Context, address string, port int, - ) (bool, error) + ) (*portcheck.PortResult, error) } // TLSChecker inspects TLS certificates. diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 2493264..742834c 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort( port int, hostname string, ) { - open, err := w.portCheck.CheckPort(ctx, ip, port) + result, err := w.portCheck.CheckPort(ctx, ip, port) if err != nil { w.log.Error( "port check failed", @@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort( now := time.Now().UTC() prev, hasPrev := w.state.GetPortState(key) - if hasPrev && !w.firstRun && prev.Open != open { + if hasPrev && !w.firstRun && prev.Open != result.Open { stateStr := "closed" - if open { + if result.Open { stateStr = "open" } @@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort( } w.state.SetPortState(key, &state.PortState{ - Open: open, + Open: result.Open, Hostname: hostname, LastChecked: now, }) diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 69772aa..57b56c7 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -9,6 +9,7 @@ import ( "time" "sneak.berlin/go/dnswatcher/internal/config" + "sneak.berlin/go/dnswatcher/internal/portcheck" "sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/watcher" @@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort( _ context.Context, address string, port int, -) (bool, error) { +) (*portcheck.PortResult, error) { m.mu.Lock() defer m.mu.Unlock() m.calls++ if m.err != nil { - return false, m.err + return nil, m.err } key := fmt.Sprintf("%s:%d", address, port) - open, ok := m.results[key] + open := m.results[key] - if !ok { - return false, nil - } - - return open, nil + return &portcheck.PortResult{Open: open}, nil } type mockTLSChecker struct { -- 2.45.2