- CheckPorts now runs all port checks concurrently using errgroup - Added port number validation (1-65535) with ErrInvalidPort sentinel error - Updated PortChecker interface to use *PortResult return type - Added tests for invalid port numbers (0, negative, >65535) - All checks pass (make check clean)
204 lines
3.8 KiB
Go
204 lines
3.8 KiB
Go
// Package portcheck provides TCP port connectivity checking.
|
||
package portcheck
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"log/slog"
|
||
"net"
|
||
"strconv"
|
||
"sync"
|
||
"time"
|
||
|
||
"go.uber.org/fx"
|
||
"golang.org/x/sync/errgroup"
|
||
|
||
"sneak.berlin/go/dnswatcher/internal/logger"
|
||
)
|
||
|
||
const (
|
||
minPort = 1
|
||
maxPort = 65535
|
||
defaultTimeout = 5 * time.Second
|
||
)
|
||
|
||
var (
	// ErrInvalidPort is the sentinel wrapped by errors reporting a port
	// number outside the valid TCP range (1–65535). Match it with
	// errors.Is.
	ErrInvalidPort = errors.New("invalid port number")
)
|
||
|
||
// PortResult holds the outcome of a single TCP port check.
|
||
type PortResult struct {
|
||
// Open indicates whether the port accepted a connection.
|
||
Open bool
|
||
// Error contains a description if the connection failed.
|
||
Error string
|
||
// Latency is the time taken for the TCP handshake.
|
||
Latency time.Duration
|
||
}
|
||
|
||
// Params contains dependencies for Checker, supplied by fx when New is
// used as a constructor.
type Params struct {
	// Embedding fx.In marks this as an fx parameter struct, so fx fills
	// in the fields below instead of passing them as separate arguments.
	fx.In

	// Logger provides (via Get) the *slog.Logger that New stores on the
	// Checker.
	Logger *logger.Logger
}
|
||
|
||
// Checker performs TCP port connectivity checks. Construct it with New
// (under fx) or NewStandalone.
//
// NOTE(review): the zero value has a nil log, and the check methods call
// c.log.Debug unconditionally, so a zero Checker is presumably unusable;
// confirm before relying on it.
type Checker struct {
	// log receives per-check debug messages.
	log *slog.Logger
}
|
||
|
||
// New creates a new port Checker instance.
|
||
func New(
|
||
_ fx.Lifecycle,
|
||
params Params,
|
||
) (*Checker, error) {
|
||
return &Checker{
|
||
log: params.Logger.Get(),
|
||
}, nil
|
||
}
|
||
|
||
// NewStandalone returns a Checker that needs no fx wiring. Its debug
// output goes to slog.Default().
func NewStandalone() *Checker {
	c := new(Checker)
	c.log = slog.Default()

	return c
}
|
||
|
||
// validatePort returns nil when port lies in [minPort, maxPort]
// (1–65535). Otherwise it returns an error that wraps ErrInvalidPort
// and names both the offending value and the allowed range.
func validatePort(port int) error {
	inRange := port >= minPort && port <= maxPort
	if inRange {
		return nil
	}

	return fmt.Errorf(
		"%w: %d (must be between %d and %d)",
		ErrInvalidPort, port, minPort, maxPort,
	)
}
|
||
|
||
// CheckPort dials address:port over TCP and reports whether the
// connection succeeded.
//
// An out-of-range port yields a nil result and an error wrapping
// ErrInvalidPort. Otherwise the error is nil: dial failures, including
// a cancelled context, are reported through PortResult.Open and
// PortResult.Error instead. The dial is bounded by defaultTimeout (5s)
// or by the time left until ctx's deadline, whichever is shorter.
func (c *Checker) CheckPort(
	ctx context.Context,
	address string,
	port int,
) (*PortResult, error) {
	validationErr := validatePort(port)
	if validationErr != nil {
		return nil, validationErr
	}

	limit := defaultTimeout
	if deadline, ok := ctx.Deadline(); ok {
		limit = min(limit, time.Until(deadline))
	}

	hostPort := net.JoinHostPort(address, strconv.Itoa(port))

	return c.checkConnection(ctx, hostPort, limit), nil
}
|
||
|
||
// CheckPorts tests TCP connectivity to multiple ports on the
|
||
// given address concurrently. It returns a map of port number
|
||
// to result.
|
||
func (c *Checker) CheckPorts(
|
||
ctx context.Context,
|
||
address string,
|
||
ports []int,
|
||
) (map[int]*PortResult, error) {
|
||
for _, port := range ports {
|
||
err := validatePort(port)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
var mu sync.Mutex
|
||
|
||
results := make(map[int]*PortResult, len(ports))
|
||
|
||
g, ctx := errgroup.WithContext(ctx)
|
||
|
||
for _, port := range ports {
|
||
g.Go(func() error {
|
||
result, err := c.CheckPort(ctx, address, port)
|
||
if err != nil {
|
||
return fmt.Errorf(
|
||
"checking port %d: %w", port, err,
|
||
)
|
||
}
|
||
|
||
mu.Lock()
|
||
results[port] = result
|
||
mu.Unlock()
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
err := g.Wait()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// checkConnection dials target over TCP with the given timeout and
// describes the outcome. A successful connection is closed straight
// away; a failure to close it is only logged. Latency in the result is
// the wall-clock time spent in DialContext, whether or not the dial
// succeeded.
func (c *Checker) checkConnection(
	ctx context.Context,
	target string,
	timeout time.Duration,
) *PortResult {
	d := net.Dialer{Timeout: timeout}

	began := time.Now()
	conn, err := d.DialContext(ctx, "tcp", target)
	elapsed := time.Since(began)

	if err != nil {
		reason := err.Error()

		c.log.Debug(
			"port check failed",
			"target", target,
			"error", reason,
		)

		return &PortResult{Open: false, Error: reason, Latency: elapsed}
	}

	// The completed handshake already proves the port is open; the
	// connection itself is not needed, so release it right away.
	c.closeProbe(conn, target)

	c.log.Debug(
		"port check succeeded",
		"target", target,
		"latency", elapsed,
	)

	return &PortResult{Open: true, Latency: elapsed}
}

// closeProbe closes a probe connection. A close error is logged rather
// than returned because the check result does not depend on it.
func (c *Checker) closeProbe(conn net.Conn, target string) {
	closeErr := conn.Close()
	if closeErr != nil {
		c.log.Debug(
			"closing connection",
			"target", target,
			"error", closeErr.Error(),
		)
	}
}
|