From ab39e77015d2bbb333339c296be296a694b93d24 Mon Sep 17 00:00:00 2001 From: clawbot Date: Thu, 19 Feb 2026 13:44:20 -0800 Subject: [PATCH] feat: implement TCP port connectivity checker (closes #3) --- internal/portcheck/portcheck.go | 119 +++++++++++++++++-- internal/portcheck/portcheck_test.go | 163 +++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 12 deletions(-) create mode 100644 internal/portcheck/portcheck_test.go diff --git a/internal/portcheck/portcheck.go b/internal/portcheck/portcheck.go index 2c061b7..4977a5b 100644 --- a/internal/portcheck/portcheck.go +++ b/internal/portcheck/portcheck.go @@ -3,18 +3,28 @@ package portcheck import ( "context" - "errors" + "fmt" "log/slog" + "net" + "strconv" + "time" "go.uber.org/fx" "sneak.berlin/go/dnswatcher/internal/logger" ) -// ErrNotImplemented indicates the port checker is not yet implemented. -var ErrNotImplemented = errors.New( - "port checker not yet implemented", -) +const defaultTimeout = 5 * time.Second + +// PortResult holds the outcome of a single TCP port check. +type PortResult struct { + // Open indicates whether the port accepted a connection. + Open bool + // Error contains a description if the connection failed. + Error string + // Latency is the time taken for the TCP handshake. + Latency time.Duration +} // Params contains dependencies for Checker. type Params struct { @@ -38,11 +48,96 @@ func New( }, nil } -// CheckPort tests TCP connectivity to the given address and port. -func (c *Checker) CheckPort( - _ context.Context, - _ string, - _ int, -) (bool, error) { - return false, ErrNotImplemented +// NewStandalone creates a Checker without fx dependencies. +func NewStandalone() *Checker { + return &Checker{ + log: slog.Default(), + } +} + +// CheckPort tests TCP connectivity to the given address and port. +// It uses a 5-second timeout unless the context has an earlier +// deadline. +func (c *Checker) CheckPort( + ctx context.Context, + address string, + port int, +) (*PortResult, error) { + target := net.JoinHostPort( + address, strconv.Itoa(port), + ) + + deadline, hasDeadline := ctx.Deadline() + timeout := defaultTimeout + + if hasDeadline { + remaining := time.Until(deadline) + if remaining < timeout { + timeout = remaining + } + } + + dialer := &net.Dialer{Timeout: timeout} + + start := time.Now() + + conn, dialErr := dialer.DialContext(ctx, "tcp", target) + latency := time.Since(start) + + if dialErr != nil { + c.log.Debug( + "port check failed", + "target", target, + "error", dialErr.Error(), + ) + + return &PortResult{ + Open: false, + Error: dialErr.Error(), + Latency: latency, + }, nil + } + + closeErr := conn.Close() + if closeErr != nil { + c.log.Debug( + "closing connection", + "target", target, + "error", closeErr.Error(), + ) + } + + c.log.Debug( + "port check succeeded", + "target", target, + "latency", latency, + ) + + return &PortResult{ + Open: true, + Latency: latency, + }, nil +} + +// CheckPorts tests TCP connectivity to multiple ports on the +// given address. It returns a map of port number to result. +func (c *Checker) CheckPorts( + ctx context.Context, + address string, + ports []int, +) (map[int]*PortResult, error) { + results := make(map[int]*PortResult, len(ports)) + + for _, port := range ports { + result, err := c.CheckPort(ctx, address, port) + if err != nil { + return nil, fmt.Errorf( + "checking port %d: %w", port, err, + ) + } + + results[port] = result + } + + return results, nil } diff --git a/internal/portcheck/portcheck_test.go b/internal/portcheck/portcheck_test.go new file mode 100644 index 0000000..bd248af --- /dev/null +++ b/internal/portcheck/portcheck_test.go @@ -0,0 +1,163 @@ +package portcheck_test + +import ( + "context" + "net" + "testing" + "time" + + "sneak.berlin/go/dnswatcher/internal/portcheck" +) + +func listenTCP( + t *testing.T, +) (net.Listener, int) { + t.Helper() + + lc := &net.ListenConfig{} + + ln, err := lc.Listen( + context.Background(), "tcp", "127.0.0.1:0", + ) + if err != nil { + t.Fatalf("failed to start listener: %v", err) + } + + addr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + t.Fatal("unexpected address type") + } + + return ln, addr.Port +} + +func TestCheckPortOpen(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + + defer func() { _ = ln.Close() }() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !result.Open { + t.Error("expected port to be open") + } + + if result.Error != "" { + t.Errorf("expected no error, got: %s", result.Error) + } + + if result.Latency <= 0 { + t.Error("expected positive latency") + } +} + +func TestCheckPortClosed(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + _ = ln.Close() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Open { + t.Error("expected port to be closed") + } + + if result.Error == "" { + t.Error("expected error message for closed port") + } +} + +func TestCheckPortContextCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort(ctx, "127.0.0.1", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Open { + t.Error("expected port to not be open") + } +} + +func TestCheckPortsMultiple(t *testing.T) { + t.Parallel() + + ln, openPort := listenTCP(t) + + defer func() { _ = ln.Close() }() + + ln2, closedPort := listenTCP(t) + _ = ln2.Close() + + checker := portcheck.NewStandalone() + + results, err := checker.CheckPorts( + context.Background(), + "127.0.0.1", + []int{openPort, closedPort}, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(results) != 2 { + t.Fatalf( + "expected 2 results, got %d", len(results), + ) + } + + if !results[openPort].Open { + t.Error("expected open port to be open") + } + + if results[closedPort].Open { + t.Error("expected closed port to be closed") + } +} + +func TestCheckPortLatencyReasonable(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + + defer func() { _ = ln.Close() }() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Latency > time.Second { + t.Errorf( + "latency too high for localhost: %v", + result.Latency, + ) + } +}