feat: make CheckPorts concurrent and add port validation

- CheckPorts now runs all port checks concurrently using errgroup
- Added port number validation (1-65535) with ErrInvalidPort sentinel error
- Updated PortChecker interface to use *PortResult return type
- Added tests for invalid port numbers (0, negative, >65535)
- All checks pass (make check clean)
This commit is contained in:
user 2026-02-20 00:14:55 -08:00
parent ab39e77015
commit 57cd228837
7 changed files with 149 additions and 40 deletions

1
go.mod
View File

@ -11,6 +11,7 @@ require (
github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0
golang.org/x/sync v0.19.0
)
require (

2
go.sum
View File

@ -76,6 +76,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=

View File

@ -3,18 +3,29 @@ package portcheck
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"sync"
"time"
"go.uber.org/fx"
"golang.org/x/sync/errgroup"
"sneak.berlin/go/dnswatcher/internal/logger"
)
const defaultTimeout = 5 * time.Second
const (
minPort = 1
maxPort = 65535
defaultTimeout = 5 * time.Second
)
// ErrInvalidPort is returned when a port number is outside
// the valid TCP range (165535).
var ErrInvalidPort = errors.New("invalid port number")
// PortResult holds the outcome of a single TCP port check.
type PortResult struct {
@ -55,6 +66,19 @@ func NewStandalone() *Checker {
}
}
// validatePort checks that a port number is within the valid
// TCP port range (165535).
func validatePort(port int) error {
if port < minPort || port > maxPort {
return fmt.Errorf(
"%w: %d (must be between %d and %d)",
ErrInvalidPort, port, minPort, maxPort,
)
}
return nil
}
// CheckPort tests TCP connectivity to the given address and port.
// It uses a 5-second timeout unless the context has an earlier
// deadline.
@ -63,13 +87,18 @@ func (c *Checker) CheckPort(
address string,
port int,
) (*PortResult, error) {
err := validatePort(port)
if err != nil {
return nil, err
}
target := net.JoinHostPort(
address, strconv.Itoa(port),
)
deadline, hasDeadline := ctx.Deadline()
timeout := defaultTimeout
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
remaining := time.Until(deadline)
if remaining < timeout {
@ -77,8 +106,62 @@ func (c *Checker) CheckPort(
}
}
dialer := &net.Dialer{Timeout: timeout}
return c.checkConnection(ctx, target, timeout), nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address concurrently. It returns a map of port number
// to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
for _, port := range ports {
err := validatePort(port)
if err != nil {
return nil, err
}
}
var mu sync.Mutex
results := make(map[int]*PortResult, len(ports))
g, ctx := errgroup.WithContext(ctx)
for _, port := range ports {
g.Go(func() error {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return fmt.Errorf(
"checking port %d: %w", port, err,
)
}
mu.Lock()
results[port] = result
mu.Unlock()
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return results, nil
}
// checkConnection performs the TCP dial and returns a result.
func (c *Checker) checkConnection(
ctx context.Context,
target string,
timeout time.Duration,
) *PortResult {
dialer := &net.Dialer{Timeout: timeout}
start := time.Now()
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
@ -95,7 +178,7 @@ func (c *Checker) CheckPort(
Open: false,
Error: dialErr.Error(),
Latency: latency,
}, nil
}
}
closeErr := conn.Close()
@ -116,28 +199,5 @@ func (c *Checker) CheckPort(
return &PortResult{
Open: true,
Latency: latency,
}, nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address. It returns a map of port number to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
results := make(map[int]*PortResult, len(ports))
for _, port := range ports {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return nil, fmt.Errorf(
"checking port %d: %w", port, err,
)
}
results[port] = result
}
return results, nil
}

View File

@ -138,6 +138,54 @@ func TestCheckPortsMultiple(t *testing.T) {
}
}
func TestCheckPortInvalidPorts(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
cases := []struct {
name string
port int
}{
{"zero", 0},
{"negative", -1},
{"too high", 65536},
{"very negative", -1000},
{"very high", 100000},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := checker.CheckPort(
context.Background(), "127.0.0.1", tc.port,
)
if err == nil {
t.Errorf(
"expected error for port %d, got nil",
tc.port,
)
}
})
}
}
func TestCheckPortsInvalidPort(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
_, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{80, 0, 443},
)
if err == nil {
t.Error("expected error for invalid port in list")
}
}
func TestCheckPortLatencyReasonable(t *testing.T) {
t.Parallel()

View File

@ -4,6 +4,7 @@ package watcher
import (
"context"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
@ -36,7 +37,7 @@ type PortChecker interface {
ctx context.Context,
address string,
port int,
) (bool, error)
) (*portcheck.PortResult, error)
}
// TLSChecker inspects TLS certificates.

View File

@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort(
port int,
hostname string,
) {
open, err := w.portCheck.CheckPort(ctx, ip, port)
result, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil {
w.log.Error(
"port check failed",
@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort(
now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != open {
if hasPrev && !w.firstRun && prev.Open != result.Open {
stateStr := "closed"
if open {
if result.Open {
stateStr = "open"
}
@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort(
}
w.state.SetPortState(key, &state.PortState{
Open: open,
Open: result.Open,
Hostname: hostname,
LastChecked: now,
})

View File

@ -9,6 +9,7 @@ import (
"time"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
"sneak.berlin/go/dnswatcher/internal/watcher"
@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort(
_ context.Context,
address string,
port int,
) (bool, error) {
) (*portcheck.PortResult, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.calls++
if m.err != nil {
return false, m.err
return nil, m.err
}
key := fmt.Sprintf("%s:%d", address, port)
open, ok := m.results[key]
open := m.results[key]
if !ok {
return false, nil
}
return open, nil
return &portcheck.PortResult{Open: open}, nil
}
type mockTLSChecker struct {