feat: implement TCP port connectivity checker (closes #3) #6
1
go.mod
1
go.mod
@ -11,6 +11,7 @@ require (
|
||||
github.com/spf13/viper v1.21.0
|
||||
go.uber.org/fx v1.24.0
|
||||
golang.org/x/net v0.50.0
|
||||
golang.org/x/sync v0.19.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
2
go.sum
2
go.sum
@ -76,6 +76,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
|
||||
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
|
||||
@ -4,18 +4,39 @@ package portcheck
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/fx"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/logger"
|
||||
)
|
||||
|
||||
// ErrNotImplemented indicates the port checker is not yet implemented.
|
||||
var ErrNotImplemented = errors.New(
|
||||
"port checker not yet implemented",
|
||||
const (
|
||||
minPort = 1
|
||||
maxPort = 65535
|
||||
defaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// ErrInvalidPort is returned when a port number is outside
|
||||
// the valid TCP range (1–65535).
|
||||
var ErrInvalidPort = errors.New("invalid port number")
|
||||
|
||||
// PortResult holds the outcome of a single TCP port check.
|
||||
type PortResult struct {
|
||||
// Open indicates whether the port accepted a connection.
|
||||
Open bool
|
||||
// Error contains a description if the connection failed.
|
||||
Error string
|
||||
// Latency is the time taken for the TCP handshake.
|
||||
Latency time.Duration
|
||||
}
|
||||
|
||||
// Params contains dependencies for Checker.
|
||||
type Params struct {
|
||||
fx.In
|
||||
@ -38,11 +59,145 @@ func New(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckPort tests TCP connectivity to the given address and port.
|
||||
func (c *Checker) CheckPort(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ int,
|
||||
) (bool, error) {
|
||||
return false, ErrNotImplemented
|
||||
// NewStandalone creates a Checker without fx dependencies.
|
||||
func NewStandalone() *Checker {
|
||||
return &Checker{
|
||||
log: slog.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
// validatePort checks that a port number is within the valid
|
||||
// TCP port range (1–65535).
|
||||
func validatePort(port int) error {
|
||||
if port < minPort || port > maxPort {
|
||||
return fmt.Errorf(
|
||||
"%w: %d (must be between %d and %d)",
|
||||
ErrInvalidPort, port, minPort, maxPort,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPort tests TCP connectivity to the given address and port.
|
||||
// It uses a 5-second timeout unless the context has an earlier
|
||||
// deadline.
|
||||
func (c *Checker) CheckPort(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
port int,
|
||||
|
|
||||
) (*PortResult, error) {
|
||||
err := validatePort(port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
target := net.JoinHostPort(
|
||||
address, strconv.Itoa(port),
|
||||
)
|
||||
|
||||
timeout := defaultTimeout
|
||||
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
if hasDeadline {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining < timeout {
|
||||
timeout = remaining
|
||||
}
|
||||
}
|
||||
|
||||
return c.checkConnection(ctx, target, timeout), nil
|
||||
}
|
||||
|
||||
// CheckPorts tests TCP connectivity to multiple ports on the
|
||||
// given address concurrently. It returns a map of port number
|
||||
// to result.
|
||||
func (c *Checker) CheckPorts(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
ports []int,
|
||||
) (map[int]*PortResult, error) {
|
||||
for _, port := range ports {
|
||||
err := validatePort(port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
results := make(map[int]*PortResult, len(ports))
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
for _, port := range ports {
|
||||
g.Go(func() error {
|
||||
result, err := c.CheckPort(ctx, address, port)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"checking port %d: %w", port, err,
|
||||
)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
results[port] = result
|
||||
mu.Unlock()
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err := g.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// checkConnection performs the TCP dial and returns a result.
|
||||
func (c *Checker) checkConnection(
|
||||
ctx context.Context,
|
||||
target string,
|
||||
timeout time.Duration,
|
||||
) *PortResult {
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
start := time.Now()
|
||||
|
||||
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
|
||||
latency := time.Since(start)
|
||||
|
||||
if dialErr != nil {
|
||||
c.log.Debug(
|
||||
"port check failed",
|
||||
"target", target,
|
||||
"error", dialErr.Error(),
|
||||
)
|
||||
|
||||
return &PortResult{
|
||||
Open: false,
|
||||
Error: dialErr.Error(),
|
||||
Latency: latency,
|
||||
}
|
||||
}
|
||||
|
||||
closeErr := conn.Close()
|
||||
if closeErr != nil {
|
||||
c.log.Debug(
|
||||
"closing connection",
|
||||
"target", target,
|
||||
"error", closeErr.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
c.log.Debug(
|
||||
"port check succeeded",
|
||||
"target", target,
|
||||
"latency", latency,
|
||||
)
|
||||
|
||||
return &PortResult{
|
||||
Open: true,
|
||||
Latency: latency,
|
||||
}
|
||||
}
|
||||
|
||||
211
internal/portcheck/portcheck_test.go
Normal file
211
internal/portcheck/portcheck_test.go
Normal file
@ -0,0 +1,211 @@
|
||||
package portcheck_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||
)
|
||||
|
||||
func listenTCP(
|
||||
t *testing.T,
|
||||
) (net.Listener, int) {
|
||||
t.Helper()
|
||||
|
||||
lc := &net.ListenConfig{}
|
||||
|
||||
ln, err := lc.Listen(
|
||||
context.Background(), "tcp", "127.0.0.1:0",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start listener: %v", err)
|
||||
}
|
||||
|
||||
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
t.Fatal("unexpected address type")
|
||||
}
|
||||
|
||||
return ln, addr.Port
|
||||
}
|
||||
|
||||
func TestCheckPortOpen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, port := listenTCP(t)
|
||||
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(
|
||||
context.Background(), "127.0.0.1", port,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !result.Open {
|
||||
t.Error("expected port to be open")
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
t.Errorf("expected no error, got: %s", result.Error)
|
||||
}
|
||||
|
||||
if result.Latency <= 0 {
|
||||
t.Error("expected positive latency")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, port := listenTCP(t)
|
||||
_ = ln.Close()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(
|
||||
context.Background(), "127.0.0.1", port,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Open {
|
||||
t.Error("expected port to be closed")
|
||||
}
|
||||
|
||||
if result.Error == "" {
|
||||
t.Error("expected error message for closed port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(ctx, "127.0.0.1", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Open {
|
||||
t.Error("expected port to not be open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortsMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, openPort := listenTCP(t)
|
||||
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
ln2, closedPort := listenTCP(t)
|
||||
_ = ln2.Close()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
results, err := checker.CheckPorts(
|
||||
context.Background(),
|
||||
"127.0.0.1",
|
||||
[]int{openPort, closedPort},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Fatalf(
|
||||
"expected 2 results, got %d", len(results),
|
||||
)
|
||||
}
|
||||
|
||||
if !results[openPort].Open {
|
||||
t.Error("expected open port to be open")
|
||||
}
|
||||
|
||||
if results[closedPort].Open {
|
||||
t.Error("expected closed port to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortInvalidPorts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
port int
|
||||
}{
|
||||
{"zero", 0},
|
||||
{"negative", -1},
|
||||
{"too high", 65536},
|
||||
{"very negative", -1000},
|
||||
{"very high", 100000},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := checker.CheckPort(
|
||||
context.Background(), "127.0.0.1", tc.port,
|
||||
)
|
||||
if err == nil {
|
||||
t.Errorf(
|
||||
"expected error for port %d, got nil",
|
||||
tc.port,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortsInvalidPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
_, err := checker.CheckPorts(
|
||||
context.Background(),
|
||||
"127.0.0.1",
|
||||
[]int{80, 0, 443},
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid port in list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortLatencyReasonable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, port := listenTCP(t)
|
||||
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(
|
||||
context.Background(), "127.0.0.1", port,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Latency > time.Second {
|
||||
t.Errorf(
|
||||
"latency too high for localhost: %v",
|
||||
result.Latency,
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@ package watcher
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||
)
|
||||
|
||||
@ -36,7 +37,7 @@ type PortChecker interface {
|
||||
ctx context.Context,
|
||||
address string,
|
||||
port int,
|
||||
) (bool, error)
|
||||
) (*portcheck.PortResult, error)
|
||||
}
|
||||
|
||||
// TLSChecker inspects TLS certificates.
|
||||
|
||||
@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort(
|
||||
port int,
|
||||
hostname string,
|
||||
) {
|
||||
open, err := w.portCheck.CheckPort(ctx, ip, port)
|
||||
result, err := w.portCheck.CheckPort(ctx, ip, port)
|
||||
if err != nil {
|
||||
w.log.Error(
|
||||
"port check failed",
|
||||
@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort(
|
||||
now := time.Now().UTC()
|
||||
prev, hasPrev := w.state.GetPortState(key)
|
||||
|
||||
if hasPrev && !w.firstRun && prev.Open != open {
|
||||
if hasPrev && !w.firstRun && prev.Open != result.Open {
|
||||
stateStr := "closed"
|
||||
if open {
|
||||
if result.Open {
|
||||
stateStr = "open"
|
||||
}
|
||||
|
||||
@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort(
|
||||
}
|
||||
|
||||
w.state.SetPortState(key, &state.PortState{
|
||||
Open: open,
|
||||
Open: result.Open,
|
||||
Hostname: hostname,
|
||||
LastChecked: now,
|
||||
})
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/config"
|
||||
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||
"sneak.berlin/go/dnswatcher/internal/state"
|
||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||
"sneak.berlin/go/dnswatcher/internal/watcher"
|
||||
@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort(
|
||||
_ context.Context,
|
||||
address string,
|
||||
port int,
|
||||
) (bool, error) {
|
||||
) (*portcheck.PortResult, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.calls++
|
||||
|
||||
if m.err != nil {
|
||||
return false, m.err
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s:%d", address, port)
|
||||
open, ok := m.results[key]
|
||||
open := m.results[key]
|
||||
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return open, nil
|
||||
return &portcheck.PortResult{Open: open}, nil
|
||||
}
|
||||
|
||||
type mockTLSChecker struct {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user
CheckPortschecks ports sequentially — for large port lists this could be slow (5s timeout × N ports worst case). Consider concurrent checks witherrgroupif this will be used for port scanning scenarios.