feat: make CheckPorts concurrent and add port validation

- CheckPorts now runs all port checks concurrently using errgroup
- Added port number validation (1-65535) with ErrInvalidPort sentinel error
- Updated PortChecker interface to use *PortResult return type
- Added tests for invalid port numbers (0, negative, >65535)
- All checks pass (make check clean)
This commit is contained in:
user
2026-02-20 00:14:55 -08:00
parent ab39e77015
commit 57cd228837
7 changed files with 149 additions and 40 deletions

View File

@@ -3,18 +3,29 @@ package portcheck
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"sync"
"time"
"go.uber.org/fx"
"golang.org/x/sync/errgroup"
"sneak.berlin/go/dnswatcher/internal/logger"
)
const defaultTimeout = 5 * time.Second
const (
minPort = 1
maxPort = 65535
defaultTimeout = 5 * time.Second
)
// ErrInvalidPort is returned when a port number is outside
// the valid TCP range (165535).
var ErrInvalidPort = errors.New("invalid port number")
// PortResult holds the outcome of a single TCP port check.
type PortResult struct {
@@ -55,6 +66,19 @@ func NewStandalone() *Checker {
}
}
// validatePort checks that a port number is within the valid
// TCP port range (165535).
func validatePort(port int) error {
if port < minPort || port > maxPort {
return fmt.Errorf(
"%w: %d (must be between %d and %d)",
ErrInvalidPort, port, minPort, maxPort,
)
}
return nil
}
// CheckPort tests TCP connectivity to the given address and port.
// It uses a 5-second timeout unless the context has an earlier
// deadline.
@@ -63,13 +87,18 @@ func (c *Checker) CheckPort(
address string,
port int,
) (*PortResult, error) {
err := validatePort(port)
if err != nil {
return nil, err
}
target := net.JoinHostPort(
address, strconv.Itoa(port),
)
deadline, hasDeadline := ctx.Deadline()
timeout := defaultTimeout
deadline, hasDeadline := ctx.Deadline()
if hasDeadline {
remaining := time.Until(deadline)
if remaining < timeout {
@@ -77,8 +106,62 @@ func (c *Checker) CheckPort(
}
}
dialer := &net.Dialer{Timeout: timeout}
return c.checkConnection(ctx, target, timeout), nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address concurrently. It returns a map of port number
// to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
for _, port := range ports {
err := validatePort(port)
if err != nil {
return nil, err
}
}
var mu sync.Mutex
results := make(map[int]*PortResult, len(ports))
g, ctx := errgroup.WithContext(ctx)
for _, port := range ports {
g.Go(func() error {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return fmt.Errorf(
"checking port %d: %w", port, err,
)
}
mu.Lock()
results[port] = result
mu.Unlock()
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return results, nil
}
// checkConnection performs the TCP dial and returns a result.
func (c *Checker) checkConnection(
ctx context.Context,
target string,
timeout time.Duration,
) *PortResult {
dialer := &net.Dialer{Timeout: timeout}
start := time.Now()
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
@@ -95,7 +178,7 @@ func (c *Checker) CheckPort(
Open: false,
Error: dialErr.Error(),
Latency: latency,
}, nil
}
}
closeErr := conn.Close()
@@ -116,28 +199,5 @@ func (c *Checker) CheckPort(
return &PortResult{
Open: true,
Latency: latency,
}, nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address. It returns a map of port number to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
results := make(map[int]*PortResult, len(ports))
for _, port := range ports {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return nil, fmt.Errorf(
"checking port %d: %w", port, err,
)
}
results[port] = result
}
return results, nil
}