From 57cd228837ad9df30a97b9d4d0f82fc1cafe6773 Mon Sep 17 00:00:00 2001 From: user Date: Fri, 20 Feb 2026 00:14:55 -0800 Subject: [PATCH] feat: make CheckPorts concurrent and add port validation - CheckPorts now runs all port checks concurrently using errgroup - Added port number validation (1-65535) with ErrInvalidPort sentinel error - Updated PortChecker interface to use *PortResult return type - Added tests for invalid port numbers (0, negative, >65535) - All checks pass (make check clean) --- go.mod | 1 + go.sum | 2 + internal/portcheck/portcheck.go | 114 ++++++++++++++++++++------- internal/portcheck/portcheck_test.go | 48 +++++++++++ internal/watcher/interfaces.go | 3 +- internal/watcher/watcher.go | 8 +- internal/watcher/watcher_test.go | 13 ++- 7 files changed, 149 insertions(+), 40 deletions(-) diff --git a/go.mod b/go.mod index 32ad532..234b2f3 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/spf13/viper v1.21.0 go.uber.org/fx v1.24.0 golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 ) require ( diff --git a/go.sum b/go.sum index 66cc528..0edfeb0 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= diff --git a/internal/portcheck/portcheck.go b/internal/portcheck/portcheck.go index 4977a5b..a57b230 100644 --- a/internal/portcheck/portcheck.go +++ b/internal/portcheck/portcheck.go @@ -3,18 +3,29 @@ package portcheck import ( "context" + "errors" "fmt" "log/slog" "net" "strconv" + "sync" "time" "go.uber.org/fx" + "golang.org/x/sync/errgroup" "sneak.berlin/go/dnswatcher/internal/logger" ) -const defaultTimeout = 5 * time.Second +const ( + minPort = 1 + maxPort = 65535 + defaultTimeout = 5 * time.Second +) + +// ErrInvalidPort is returned when a port number is outside +// the valid TCP range (1–65535). +var ErrInvalidPort = errors.New("invalid port number") // PortResult holds the outcome of a single TCP port check. type PortResult struct { @@ -55,6 +66,19 @@ func NewStandalone() *Checker { } } +// validatePort checks that a port number is within the valid +// TCP port range (1–65535). +func validatePort(port int) error { + if port < minPort || port > maxPort { + return fmt.Errorf( + "%w: %d (must be between %d and %d)", + ErrInvalidPort, port, minPort, maxPort, + ) + } + + return nil +} + // CheckPort tests TCP connectivity to the given address and port. // It uses a 5-second timeout unless the context has an earlier // deadline. @@ -63,13 +87,18 @@ func (c *Checker) CheckPort( address string, port int, ) (*PortResult, error) { + err := validatePort(port) + if err != nil { + return nil, err + } + target := net.JoinHostPort( address, strconv.Itoa(port), ) - deadline, hasDeadline := ctx.Deadline() timeout := defaultTimeout + deadline, hasDeadline := ctx.Deadline() if hasDeadline { remaining := time.Until(deadline) if remaining < timeout { @@ -77,8 +106,62 @@ func (c *Checker) CheckPort( } } - dialer := &net.Dialer{Timeout: timeout} + return c.checkConnection(ctx, target, timeout), nil +} +// CheckPorts tests TCP connectivity to multiple ports on the +// given address concurrently. It returns a map of port number +// to result. +func (c *Checker) CheckPorts( + ctx context.Context, + address string, + ports []int, +) (map[int]*PortResult, error) { + for _, port := range ports { + err := validatePort(port) + if err != nil { + return nil, err + } + } + + var mu sync.Mutex + + results := make(map[int]*PortResult, len(ports)) + + g, ctx := errgroup.WithContext(ctx) + + for _, port := range ports { + g.Go(func() error { + result, err := c.CheckPort(ctx, address, port) + if err != nil { + return fmt.Errorf( + "checking port %d: %w", port, err, + ) + } + + mu.Lock() + results[port] = result + mu.Unlock() + + return nil + }) + } + + err := g.Wait() + if err != nil { + return nil, err + } + + return results, nil +} + +// checkConnection performs the TCP dial and returns a result. +func (c *Checker) checkConnection( + ctx context.Context, + target string, + timeout time.Duration, +) *PortResult { + dialer := &net.Dialer{Timeout: timeout} start := time.Now() conn, dialErr := dialer.DialContext(ctx, "tcp", target) @@ -95,7 +178,7 @@ func (c *Checker) CheckPort( Open: false, Error: dialErr.Error(), Latency: latency, - }, nil + } } closeErr := conn.Close() @@ -116,28 +199,5 @@ func (c *Checker) CheckPort( return &PortResult{ Open: true, Latency: latency, - }, nil -} - -// CheckPorts tests TCP connectivity to multiple ports on the -// given address. It returns a map of port number to result. -func (c *Checker) CheckPorts( - ctx context.Context, - address string, - ports []int, -) (map[int]*PortResult, error) { - results := make(map[int]*PortResult, len(ports)) - - for _, port := range ports { - result, err := c.CheckPort(ctx, address, port) - if err != nil { - return nil, fmt.Errorf( - "checking port %d: %w", port, err, - ) - } - - results[port] = result } - - return results, nil } diff --git a/internal/portcheck/portcheck_test.go b/internal/portcheck/portcheck_test.go index bd248af..7ea7fc0 100644 --- a/internal/portcheck/portcheck_test.go +++ b/internal/portcheck/portcheck_test.go @@ -138,6 +138,54 @@ func TestCheckPortsMultiple(t *testing.T) { } } +func TestCheckPortInvalidPorts(t *testing.T) { + t.Parallel() + + checker := portcheck.NewStandalone() + + cases := []struct { + name string + port int + }{ + {"zero", 0}, + {"negative", -1}, + {"too high", 65536}, + {"very negative", -1000}, + {"very high", 100000}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := checker.CheckPort( + context.Background(), "127.0.0.1", tc.port, + ) + if err == nil { + t.Errorf( + "expected error for port %d, got nil", + tc.port, + ) + } + }) + } +} + +func TestCheckPortsInvalidPort(t *testing.T) { + t.Parallel() + + checker := portcheck.NewStandalone() + + _, err := checker.CheckPorts( + context.Background(), + "127.0.0.1", + []int{80, 0, 443}, + ) + if err == nil { + t.Error("expected error for invalid port in list") + } +} + func TestCheckPortLatencyReasonable(t *testing.T) { t.Parallel() diff --git a/internal/watcher/interfaces.go b/internal/watcher/interfaces.go index 695139d..dd68017 100644 --- a/internal/watcher/interfaces.go +++ b/internal/watcher/interfaces.go @@ -4,6 +4,7 @@ package watcher import ( "context" + "sneak.berlin/go/dnswatcher/internal/portcheck" "sneak.berlin/go/dnswatcher/internal/tlscheck" ) @@ -36,7 +37,7 @@ type PortChecker interface { ctx context.Context, address string, port int, - ) (bool, error) + ) (*portcheck.PortResult, error) } // TLSChecker inspects TLS certificates. diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 2493264..742834c 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort( port int, hostname string, ) { - open, err := w.portCheck.CheckPort(ctx, ip, port) + result, err := w.portCheck.CheckPort(ctx, ip, port) if err != nil { w.log.Error( "port check failed", @@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort( now := time.Now().UTC() prev, hasPrev := w.state.GetPortState(key) - if hasPrev && !w.firstRun && prev.Open != open { + if hasPrev && !w.firstRun && prev.Open != result.Open { stateStr := "closed" - if open { + if result.Open { stateStr = "open" } @@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort( } w.state.SetPortState(key, &state.PortState{ - Open: open, + Open: result.Open, Hostname: hostname, LastChecked: now, }) diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 69772aa..57b56c7 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -9,6 +9,7 @@ import ( "time" "sneak.berlin/go/dnswatcher/internal/config" + "sneak.berlin/go/dnswatcher/internal/portcheck" "sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/watcher" @@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort( _ context.Context, address string, port int, -) (bool, error) { +) (*portcheck.PortResult, error) { m.mu.Lock() defer m.mu.Unlock() m.calls++ if m.err != nil { - return false, m.err + return nil, m.err } key := fmt.Sprintf("%s:%d", address, port) - open, ok := m.results[key] + open := m.results[key] - if !ok { - return false, nil - } - - return open, nil + return &portcheck.PortResult{Open: open}, nil } type mockTLSChecker struct {