dnswatcher/internal/portcheck/portcheck.go

144 lines
2.7 KiB
Go

// Package portcheck provides TCP port connectivity checking.
package portcheck
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"strconv"
	"time"

	"go.uber.org/fx"
	"sneak.berlin/go/dnswatcher/internal/logger"
)
// defaultTimeout bounds how long a single TCP connection attempt may
// take.
const defaultTimeout = 5 * time.Second
// PortResult holds the outcome of a single TCP port check.
type PortResult struct {
// Open reports whether the TCP dial succeeded, i.e. something
// accepted the connection on the port.
Open bool
// Error holds the text of the dial error when the connection could
// not be established; it is left empty when Open is true.
Error string
// Latency is the elapsed time measured around the dial call. It is
// recorded for failed attempts as well as successful ones.
Latency time.Duration
}
// Params contains the dependencies New needs to build a Checker.
// NOTE(review): embedding fx.In presumably marks this as an fx
// parameter struct whose fields are injected — confirm against the
// fx documentation.
type Params struct {
fx.In
// Logger provides, via Get, the *slog.Logger the Checker logs to.
Logger *logger.Logger
}
// Checker performs TCP port connectivity checks. Build one with New
// (logger taken from Params) or NewStandalone (default slog logger).
type Checker struct {
// log receives the debug-level messages emitted for every check.
log *slog.Logger
}
// New builds a Checker that logs through the logger carried in
// params. The lifecycle argument is accepted but not used, and the
// returned error is always nil.
func New(
	_ fx.Lifecycle,
	params Params,
) (*Checker, error) {
	checker := &Checker{log: params.Logger.Get()}

	return checker, nil
}
// NewStandalone builds a Checker that needs no fx wiring: it logs
// through the process-wide default slog logger.
func NewStandalone() *Checker {
	checker := new(Checker)
	checker.log = slog.Default()

	return checker
}
// CheckPort tests TCP connectivity to the given address and port.
//
// A dial that fails — refused, unreachable or timed out — is not an
// error: it is reported as a PortResult with Open set to false and
// Error holding the dial error text. Each attempt is limited to
// defaultTimeout; a context deadline that falls earlier shortens it,
// because DialContext stops at whichever comes first.
//
// A non-nil error is returned only when ctx was cancelled, since the
// failed dial then says nothing about the port itself. That error
// wraps context.Canceled and the returned result is nil.
func (c *Checker) CheckPort(
	ctx context.Context,
	address string,
	port int,
) (*PortResult, error) {
	target := net.JoinHostPort(address, strconv.Itoa(port))

	// DialContext already honours the earlier of Timeout and the
	// context deadline, so no manual deadline arithmetic is needed.
	dialer := net.Dialer{Timeout: defaultTimeout}

	start := time.Now()
	conn, dialErr := dialer.DialContext(ctx, "tcp", target)
	latency := time.Since(start)

	if dialErr != nil {
		// Cancellation means the caller gave up, not that the port is
		// closed; surface it instead of reporting a false negative.
		if ctxErr := ctx.Err(); errors.Is(ctxErr, context.Canceled) {
			return nil, fmt.Errorf("checking %s: %w", target, ctxErr)
		}

		c.log.Debug(
			"port check failed",
			"target", target,
			"error", dialErr.Error(),
		)

		return &PortResult{
			Open:    false,
			Error:   dialErr.Error(),
			Latency: latency,
		}, nil
	}

	// The handshake already succeeded, so a failure to close does not
	// change the verdict; it is only logged.
	if closeErr := conn.Close(); closeErr != nil {
		c.log.Debug(
			"closing connection",
			"target", target,
			"error", closeErr.Error(),
		)
	}

	c.log.Debug(
		"port check succeeded",
		"target", target,
		"latency", latency,
	)

	return &PortResult{
		Open:    true,
		Latency: latency,
	}, nil
}
// CheckPorts tests TCP connectivity to each of the given ports on
// address, one after another, and returns the results keyed by port
// number. A port listed more than once is checked each time and keeps
// the result of its last check.
//
// The context is consulted before every port: once it is cancelled or
// past its deadline the remaining ports are not dialled and an error
// wrapping ctx.Err() is returned, so ports that were never checked
// are not reported as closed. An error from CheckPort is wrapped with
// the port number. In both cases the returned map is nil.
func (c *Checker) CheckPorts(
	ctx context.Context,
	address string,
	ports []int,
) (map[int]*PortResult, error) {
	results := make(map[int]*PortResult, len(ports))

	for _, port := range ports {
		// Stop early instead of dialling with a dead context and
		// recording the resulting failure as a closed port.
		if err := ctx.Err(); err != nil {
			return nil, fmt.Errorf("checking port %d: %w", port, err)
		}

		result, err := c.CheckPort(ctx, address, port)
		if err != nil {
			return nil, fmt.Errorf("checking port %d: %w", port, err)
		}

		results[port] = result
	}

	return results, nil
}