feat: implement TCP port connectivity checker (closes #3) #6

Merged
sneak merged 2 commits from feature/portcheck-implementation into main 2026-02-20 19:38:37 +01:00
7 changed files with 149 additions and 40 deletions
Showing only changes of commit 57cd228837 - Show all commits

1
go.mod
View File

@ -11,6 +11,7 @@ require (
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0 golang.org/x/net v0.50.0
golang.org/x/sync v0.19.0
) )
require ( require (

2
go.sum
View File

@ -76,6 +76,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=

View File

@ -3,18 +3,29 @@ package portcheck
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
"strconv" "strconv"
"sync"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
"golang.org/x/sync/errgroup"
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
const defaultTimeout = 5 * time.Second const (
minPort = 1
maxPort = 65535
defaultTimeout = 5 * time.Second
)
// ErrInvalidPort is returned when a port number is outside
// the valid TCP range (165535).
var ErrInvalidPort = errors.New("invalid port number")
// PortResult holds the outcome of a single TCP port check. // PortResult holds the outcome of a single TCP port check.
type PortResult struct { type PortResult struct {
@ -55,6 +66,19 @@ func NewStandalone() *Checker {
} }
} }
// validatePort checks that a port number is within the valid
// TCP port range (165535).
func validatePort(port int) error {
if port < minPort || port > maxPort {
return fmt.Errorf(
"%w: %d (must be between %d and %d)",
ErrInvalidPort, port, minPort, maxPort,
)
}
return nil
}
// CheckPort tests TCP connectivity to the given address and port. // CheckPort tests TCP connectivity to the given address and port.
// It uses a 5-second timeout unless the context has an earlier // It uses a 5-second timeout unless the context has an earlier
// deadline. // deadline.
@ -63,13 +87,18 @@ func (c *Checker) CheckPort(
address string, address string,
port int, port int,

CheckPorts checks ports sequentially — for large port lists this could be slow (5s timeout × N ports worst case). Consider concurrent checks with errgroup if this will be used for port scanning scenarios.

`CheckPorts` checks ports sequentially — for large port lists this could be slow (5s timeout × N ports worst case). Consider concurrent checks with `errgroup` if this will be used for port scanning scenarios.
) (*PortResult, error) { ) (*PortResult, error) {
err := validatePort(port)
if err != nil {
return nil, err
}
target := net.JoinHostPort( target := net.JoinHostPort(
address, strconv.Itoa(port), address, strconv.Itoa(port),
) )
deadline, hasDeadline := ctx.Deadline()
timeout := defaultTimeout timeout := defaultTimeout
deadline, hasDeadline := ctx.Deadline()
if hasDeadline { if hasDeadline {
remaining := time.Until(deadline) remaining := time.Until(deadline)
if remaining < timeout { if remaining < timeout {
@ -77,8 +106,62 @@ func (c *Checker) CheckPort(
} }
} }
dialer := &net.Dialer{Timeout: timeout} return c.checkConnection(ctx, target, timeout), nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address concurrently. It returns a map of port number
// to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
for _, port := range ports {
err := validatePort(port)
if err != nil {
return nil, err
}
}
var mu sync.Mutex
results := make(map[int]*PortResult, len(ports))
g, ctx := errgroup.WithContext(ctx)
for _, port := range ports {
g.Go(func() error {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return fmt.Errorf(
"checking port %d: %w", port, err,
)
}
mu.Lock()
results[port] = result
mu.Unlock()
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return results, nil
}
// checkConnection performs the TCP dial and returns a result.
func (c *Checker) checkConnection(
ctx context.Context,
target string,
timeout time.Duration,
) *PortResult {
dialer := &net.Dialer{Timeout: timeout}
start := time.Now() start := time.Now()
conn, dialErr := dialer.DialContext(ctx, "tcp", target) conn, dialErr := dialer.DialContext(ctx, "tcp", target)
@ -95,7 +178,7 @@ func (c *Checker) CheckPort(
Open: false, Open: false,
Error: dialErr.Error(), Error: dialErr.Error(),
Latency: latency, Latency: latency,
}, nil }
} }
closeErr := conn.Close() closeErr := conn.Close()
@ -116,28 +199,5 @@ func (c *Checker) CheckPort(
return &PortResult{ return &PortResult{
Open: true, Open: true,
Latency: latency, Latency: latency,
}, nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address. It returns a map of port number to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
results := make(map[int]*PortResult, len(ports))
for _, port := range ports {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return nil, fmt.Errorf(
"checking port %d: %w", port, err,
)
}
results[port] = result
} }
return results, nil
} }

View File

@ -138,6 +138,54 @@ func TestCheckPortsMultiple(t *testing.T) {
} }
} }
func TestCheckPortInvalidPorts(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
cases := []struct {
name string
port int
}{
{"zero", 0},
{"negative", -1},
{"too high", 65536},
{"very negative", -1000},
{"very high", 100000},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := checker.CheckPort(
context.Background(), "127.0.0.1", tc.port,
)
if err == nil {
t.Errorf(
"expected error for port %d, got nil",
tc.port,
)
}
})
}
}
func TestCheckPortsInvalidPort(t *testing.T) {
t.Parallel()
checker := portcheck.NewStandalone()
_, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{80, 0, 443},
)
if err == nil {
t.Error("expected error for invalid port in list")
}
}
func TestCheckPortLatencyReasonable(t *testing.T) { func TestCheckPortLatencyReasonable(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -4,6 +4,7 @@ package watcher
import ( import (
"context" "context"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/tlscheck"
) )
@ -36,7 +37,7 @@ type PortChecker interface {
ctx context.Context, ctx context.Context,
address string, address string,
port int, port int,
) (bool, error) ) (*portcheck.PortResult, error)
} }
// TLSChecker inspects TLS certificates. // TLSChecker inspects TLS certificates.

View File

@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort(
port int, port int,
hostname string, hostname string,
) { ) {
open, err := w.portCheck.CheckPort(ctx, ip, port) result, err := w.portCheck.CheckPort(ctx, ip, port)
if err != nil { if err != nil {
w.log.Error( w.log.Error(
"port check failed", "port check failed",
@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort(
now := time.Now().UTC() now := time.Now().UTC()
prev, hasPrev := w.state.GetPortState(key) prev, hasPrev := w.state.GetPortState(key)
if hasPrev && !w.firstRun && prev.Open != open { if hasPrev && !w.firstRun && prev.Open != result.Open {
stateStr := "closed" stateStr := "closed"
if open { if result.Open {
stateStr = "open" stateStr = "open"
} }
@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort(
} }
w.state.SetPortState(key, &state.PortState{ w.state.SetPortState(key, &state.PortState{
Open: open, Open: result.Open,
Hostname: hostname, Hostname: hostname,
LastChecked: now, LastChecked: now,
}) })

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"sneak.berlin/go/dnswatcher/internal/config" "sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/portcheck"
"sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/state"
"sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/tlscheck"
"sneak.berlin/go/dnswatcher/internal/watcher" "sneak.berlin/go/dnswatcher/internal/watcher"
@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort(
_ context.Context, _ context.Context,
address string, address string,
port int, port int,
) (bool, error) { ) (*portcheck.PortResult, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.calls++ m.calls++
if m.err != nil { if m.err != nil {
return false, m.err return nil, m.err
} }
key := fmt.Sprintf("%s:%d", address, port) key := fmt.Sprintf("%s:%d", address, port)
open, ok := m.results[key] open := m.results[key]
if !ok { return &portcheck.PortResult{Open: open}, nil
return false, nil
}
return open, nil
} }
type mockTLSChecker struct { type mockTLSChecker struct {