// Package portcheck provides TCP port connectivity checking. package portcheck import ( "context" "errors" "fmt" "log/slog" "net" "strconv" "sync" "time" "go.uber.org/fx" "golang.org/x/sync/errgroup" "sneak.berlin/go/dnswatcher/internal/logger" ) const ( minPort = 1 maxPort = 65535 defaultTimeout = 5 * time.Second ) // ErrInvalidPort is returned when a port number is outside // the valid TCP range (1–65535). var ErrInvalidPort = errors.New("invalid port number") // PortResult holds the outcome of a single TCP port check. type PortResult struct { // Open indicates whether the port accepted a connection. Open bool // Error contains a description if the connection failed. Error string // Latency is the time taken for the TCP handshake. Latency time.Duration } // Params contains dependencies for Checker. type Params struct { fx.In Logger *logger.Logger } // Checker performs TCP port connectivity checks. type Checker struct { log *slog.Logger } // New creates a new port Checker instance. func New( _ fx.Lifecycle, params Params, ) (*Checker, error) { return &Checker{ log: params.Logger.Get(), }, nil } // NewStandalone creates a Checker without fx dependencies. func NewStandalone() *Checker { return &Checker{ log: slog.Default(), } } // validatePort checks that a port number is within the valid // TCP port range (1–65535). func validatePort(port int) error { if port < minPort || port > maxPort { return fmt.Errorf( "%w: %d (must be between %d and %d)", ErrInvalidPort, port, minPort, maxPort, ) } return nil } // CheckPort tests TCP connectivity to the given address and port. // It uses a 5-second timeout unless the context has an earlier // deadline. func (c *Checker) CheckPort( ctx context.Context, address string, port int, ) (*PortResult, error) { err := validatePort(port) if err != nil { return nil, err } target := net.JoinHostPort( address, strconv.Itoa(port), ) timeout := defaultTimeout deadline, hasDeadline := ctx.Deadline() if hasDeadline { remaining := time.Until(deadline) if remaining < timeout { timeout = remaining } } return c.checkConnection(ctx, target, timeout), nil } // CheckPorts tests TCP connectivity to multiple ports on the // given address concurrently. It returns a map of port number // to result. func (c *Checker) CheckPorts( ctx context.Context, address string, ports []int, ) (map[int]*PortResult, error) { for _, port := range ports { err := validatePort(port) if err != nil { return nil, err } } var mu sync.Mutex results := make(map[int]*PortResult, len(ports)) g, ctx := errgroup.WithContext(ctx) for _, port := range ports { g.Go(func() error { result, err := c.CheckPort(ctx, address, port) if err != nil { return fmt.Errorf( "checking port %d: %w", port, err, ) } mu.Lock() results[port] = result mu.Unlock() return nil }) } err := g.Wait() if err != nil { return nil, err } return results, nil } // checkConnection performs the TCP dial and returns a result. func (c *Checker) checkConnection( ctx context.Context, target string, timeout time.Duration, ) *PortResult { dialer := &net.Dialer{Timeout: timeout} start := time.Now() conn, dialErr := dialer.DialContext(ctx, "tcp", target) latency := time.Since(start) if dialErr != nil { c.log.Debug( "port check failed", "target", target, "error", dialErr.Error(), ) return &PortResult{ Open: false, Error: dialErr.Error(), Latency: latency, } } closeErr := conn.Close() if closeErr != nil { c.log.Debug( "closing connection", "target", target, "error", closeErr.Error(), ) } c.log.Debug( "port check succeeded", "target", target, "latency", latency, ) return &PortResult{ Open: true, Latency: latency, } }