feat: implement TCP port connectivity checker (closes #3) #6
@ -3,18 +3,28 @@ package portcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
|
|
||||||
"sneak.berlin/go/dnswatcher/internal/logger"
|
"sneak.berlin/go/dnswatcher/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNotImplemented indicates the port checker is not yet implemented.
|
const defaultTimeout = 5 * time.Second
|
||||||
var ErrNotImplemented = errors.New(
|
|
||||||
"port checker not yet implemented",
|
// PortResult holds the outcome of a single TCP port check.
|
||||||
)
|
type PortResult struct {
|
||||||
|
// Open indicates whether the port accepted a connection.
|
||||||
|
Open bool
|
||||||
|
// Error contains a description if the connection failed.
|
||||||
|
Error string
|
||||||
|
// Latency is the time taken for the TCP handshake.
|
||||||
|
Latency time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
// Params contains dependencies for Checker.
|
// Params contains dependencies for Checker.
|
||||||
type Params struct {
|
type Params struct {
|
||||||
@ -38,11 +48,96 @@ func New(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPort tests TCP connectivity to the given address and port.
|
// NewStandalone creates a Checker without fx dependencies.
|
||||||
func (c *Checker) CheckPort(
|
func NewStandalone() *Checker {
|
||||||
_ context.Context,
|
return &Checker{
|
||||||
_ string,
|
log: slog.Default(),
|
||||||
_ int,
|
}
|
||||||
) (bool, error) {
|
}
|
||||||
return false, ErrNotImplemented
|
|
||||||
|
// CheckPort tests TCP connectivity to the given address and port.
|
||||||
|
|
|||||||
|
// It uses a 5-second timeout unless the context has an earlier
|
||||||
|
// deadline.
|
||||||
|
func (c *Checker) CheckPort(
|
||||||
|
ctx context.Context,
|
||||||
|
address string,
|
||||||
|
port int,
|
||||||
|
) (*PortResult, error) {
|
||||||
|
target := net.JoinHostPort(
|
||||||
|
address, strconv.Itoa(port),
|
||||||
|
)
|
||||||
|
|
||||||
|
deadline, hasDeadline := ctx.Deadline()
|
||||||
|
timeout := defaultTimeout
|
||||||
|
|
||||||
|
if hasDeadline {
|
||||||
|
remaining := time.Until(deadline)
|
||||||
|
if remaining < timeout {
|
||||||
|
timeout = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{Timeout: timeout}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
|
||||||
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
if dialErr != nil {
|
||||||
|
c.log.Debug(
|
||||||
|
clawbot
commented
`CheckPorts` checks ports sequentially — for large port lists this could be slow (5s timeout × N ports worst case). Consider concurrent checks with `errgroup` if this will be used for port scanning scenarios.
|
|||||||
|
"port check failed",
|
||||||
|
"target", target,
|
||||||
|
"error", dialErr.Error(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return &PortResult{
|
||||||
|
Open: false,
|
||||||
|
Error: dialErr.Error(),
|
||||||
|
Latency: latency,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
closeErr := conn.Close()
|
||||||
|
if closeErr != nil {
|
||||||
|
c.log.Debug(
|
||||||
|
"closing connection",
|
||||||
|
"target", target,
|
||||||
|
"error", closeErr.Error(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Debug(
|
||||||
|
"port check succeeded",
|
||||||
|
"target", target,
|
||||||
|
"latency", latency,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &PortResult{
|
||||||
|
Open: true,
|
||||||
|
Latency: latency,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPorts tests TCP connectivity to multiple ports on the
|
||||||
|
// given address. It returns a map of port number to result.
|
||||||
|
func (c *Checker) CheckPorts(
|
||||||
|
ctx context.Context,
|
||||||
|
address string,
|
||||||
|
ports []int,
|
||||||
|
) (map[int]*PortResult, error) {
|
||||||
|
results := make(map[int]*PortResult, len(ports))
|
||||||
|
|
||||||
|
for _, port := range ports {
|
||||||
|
result, err := c.CheckPort(ctx, address, port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"checking port %d: %w", port, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
results[port] = result
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|||||||
163
internal/portcheck/portcheck_test.go
Normal file
163
internal/portcheck/portcheck_test.go
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
package portcheck_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||||
|
)
|
||||||
|
|
||||||
|
func listenTCP(
|
||||||
|
t *testing.T,
|
||||||
|
) (net.Listener, int) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
lc := &net.ListenConfig{}
|
||||||
|
|
||||||
|
ln, err := lc.Listen(
|
||||||
|
context.Background(), "tcp", "127.0.0.1:0",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to start listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("unexpected address type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ln, addr.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortOpen(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, port := listenTCP(t)
|
||||||
|
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", port,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Open {
|
||||||
|
t.Error("expected port to be open")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != "" {
|
||||||
|
t.Errorf("expected no error, got: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Latency <= 0 {
|
||||||
|
t.Error("expected positive latency")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortClosed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, port := listenTCP(t)
|
||||||
|
_ = ln.Close()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", port,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Open {
|
||||||
|
t.Error("expected port to be closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error == "" {
|
||||||
|
t.Error("expected error message for closed port")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortContextCanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(ctx, "127.0.0.1", 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Open {
|
||||||
|
t.Error("expected port to not be open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortsMultiple(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, openPort := listenTCP(t)
|
||||||
|
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
ln2, closedPort := listenTCP(t)
|
||||||
|
_ = ln2.Close()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
results, err := checker.CheckPorts(
|
||||||
|
context.Background(),
|
||||||
|
"127.0.0.1",
|
||||||
|
[]int{openPort, closedPort},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected 2 results, got %d", len(results),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !results[openPort].Open {
|
||||||
|
t.Error("expected open port to be open")
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[closedPort].Open {
|
||||||
|
t.Error("expected closed port to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortLatencyReasonable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, port := listenTCP(t)
|
||||||
|
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", port,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Latency > time.Second {
|
||||||
|
t.Errorf(
|
||||||
|
"latency too high for localhost: %v",
|
||||||
|
result.Latency,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user
Minor: no validation that
portis in valid range (1-65535).net.Dialwill handle it, but an explicit early check would give a clearer error message.