feat: implement TCP port connectivity checker (closes #3) #6
1
go.mod
1
go.mod
@ -11,6 +11,7 @@ require (
|
|||||||
github.com/spf13/viper v1.21.0
|
github.com/spf13/viper v1.21.0
|
||||||
go.uber.org/fx v1.24.0
|
go.uber.org/fx v1.24.0
|
||||||
golang.org/x/net v0.50.0
|
golang.org/x/net v0.50.0
|
||||||
|
golang.org/x/sync v0.19.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|||||||
2
go.sum
2
go.sum
@ -76,6 +76,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
|||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
|
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
|
||||||
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
|
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
|
||||||
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||||
|
|||||||
@ -3,18 +3,29 @@ package portcheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"sneak.berlin/go/dnswatcher/internal/logger"
|
"sneak.berlin/go/dnswatcher/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultTimeout = 5 * time.Second
|
const (
|
||||||
|
minPort = 1
|
||||||
|
maxPort = 65535
|
||||||
|
defaultTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidPort is returned when a port number is outside
|
||||||
|
// the valid TCP range (1–65535).
|
||||||
|
var ErrInvalidPort = errors.New("invalid port number")
|
||||||
|
|
||||||
// PortResult holds the outcome of a single TCP port check.
|
// PortResult holds the outcome of a single TCP port check.
|
||||||
type PortResult struct {
|
type PortResult struct {
|
||||||
@ -55,6 +66,19 @@ func NewStandalone() *Checker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validatePort checks that a port number is within the valid
|
||||||
|
// TCP port range (1–65535).
|
||||||
|
func validatePort(port int) error {
|
||||||
|
if port < minPort || port > maxPort {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"%w: %d (must be between %d and %d)",
|
||||||
|
ErrInvalidPort, port, minPort, maxPort,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CheckPort tests TCP connectivity to the given address and port.
|
// CheckPort tests TCP connectivity to the given address and port.
|
||||||
// It uses a 5-second timeout unless the context has an earlier
|
// It uses a 5-second timeout unless the context has an earlier
|
||||||
// deadline.
|
// deadline.
|
||||||
@ -63,13 +87,18 @@ func (c *Checker) CheckPort(
|
|||||||
address string,
|
address string,
|
||||||
port int,
|
port int,
|
||||||
|
|
|||||||
) (*PortResult, error) {
|
) (*PortResult, error) {
|
||||||
|
err := validatePort(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
target := net.JoinHostPort(
|
target := net.JoinHostPort(
|
||||||
address, strconv.Itoa(port),
|
address, strconv.Itoa(port),
|
||||||
)
|
)
|
||||||
|
|
||||||
deadline, hasDeadline := ctx.Deadline()
|
|
||||||
timeout := defaultTimeout
|
timeout := defaultTimeout
|
||||||
|
|
||||||
|
deadline, hasDeadline := ctx.Deadline()
|
||||||
if hasDeadline {
|
if hasDeadline {
|
||||||
remaining := time.Until(deadline)
|
remaining := time.Until(deadline)
|
||||||
if remaining < timeout {
|
if remaining < timeout {
|
||||||
@ -77,8 +106,62 @@ func (c *Checker) CheckPort(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer := &net.Dialer{Timeout: timeout}
|
return c.checkConnection(ctx, target, timeout), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPorts tests TCP connectivity to multiple ports on the
|
||||||
|
// given address concurrently. It returns a map of port number
|
||||||
|
// to result.
|
||||||
|
func (c *Checker) CheckPorts(
|
||||||
|
ctx context.Context,
|
||||||
|
address string,
|
||||||
|
ports []int,
|
||||||
|
) (map[int]*PortResult, error) {
|
||||||
|
for _, port := range ports {
|
||||||
|
err := validatePort(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
results := make(map[int]*PortResult, len(ports))
|
||||||
|
|
||||||
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
for _, port := range ports {
|
||||||
|
g.Go(func() error {
|
||||||
|
result, err := c.CheckPort(ctx, address, port)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"checking port %d: %w", port, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
results[port] = result
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err := g.Wait()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkConnection performs the TCP dial and returns a result.
|
||||||
|
func (c *Checker) checkConnection(
|
||||||
|
ctx context.Context,
|
||||||
|
target string,
|
||||||
|
timeout time.Duration,
|
||||||
|
) *PortResult {
|
||||||
|
dialer := &net.Dialer{Timeout: timeout}
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
|
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
|
||||||
@ -95,7 +178,7 @@ func (c *Checker) CheckPort(
|
|||||||
Open: false,
|
Open: false,
|
||||||
Error: dialErr.Error(),
|
Error: dialErr.Error(),
|
||||||
Latency: latency,
|
Latency: latency,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
closeErr := conn.Close()
|
closeErr := conn.Close()
|
||||||
@ -116,28 +199,5 @@ func (c *Checker) CheckPort(
|
|||||||
return &PortResult{
|
return &PortResult{
|
||||||
Open: true,
|
Open: true,
|
||||||
Latency: latency,
|
Latency: latency,
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPorts tests TCP connectivity to multiple ports on the
|
|
||||||
// given address. It returns a map of port number to result.
|
|
||||||
func (c *Checker) CheckPorts(
|
|
||||||
ctx context.Context,
|
|
||||||
address string,
|
|
||||||
ports []int,
|
|
||||||
) (map[int]*PortResult, error) {
|
|
||||||
results := make(map[int]*PortResult, len(ports))
|
|
||||||
|
|
||||||
for _, port := range ports {
|
|
||||||
result, err := c.CheckPort(ctx, address, port)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf(
|
|
||||||
"checking port %d: %w", port, err,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
results[port] = result
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -138,6 +138,54 @@ func TestCheckPortsMultiple(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCheckPortInvalidPorts(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
port int
|
||||||
|
}{
|
||||||
|
{"zero", 0},
|
||||||
|
{"negative", -1},
|
||||||
|
{"too high", 65536},
|
||||||
|
{"very negative", -1000},
|
||||||
|
{"very high", 100000},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", tc.port,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf(
|
||||||
|
"expected error for port %d, got nil",
|
||||||
|
tc.port,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortsInvalidPort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
_, err := checker.CheckPorts(
|
||||||
|
context.Background(),
|
||||||
|
"127.0.0.1",
|
||||||
|
[]int{80, 0, 443},
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid port in list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCheckPortLatencyReasonable(t *testing.T) {
|
func TestCheckPortLatencyReasonable(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ package watcher
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ type PortChecker interface {
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
address string,
|
address string,
|
||||||
port int,
|
port int,
|
||||||
) (bool, error)
|
) (*portcheck.PortResult, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSChecker inspects TLS certificates.
|
// TLSChecker inspects TLS certificates.
|
||||||
|
|||||||
@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort(
|
|||||||
port int,
|
port int,
|
||||||
hostname string,
|
hostname string,
|
||||||
) {
|
) {
|
||||||
open, err := w.portCheck.CheckPort(ctx, ip, port)
|
result, err := w.portCheck.CheckPort(ctx, ip, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Error(
|
w.log.Error(
|
||||||
"port check failed",
|
"port check failed",
|
||||||
@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort(
|
|||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
prev, hasPrev := w.state.GetPortState(key)
|
prev, hasPrev := w.state.GetPortState(key)
|
||||||
|
|
||||||
if hasPrev && !w.firstRun && prev.Open != open {
|
if hasPrev && !w.firstRun && prev.Open != result.Open {
|
||||||
stateStr := "closed"
|
stateStr := "closed"
|
||||||
if open {
|
if result.Open {
|
||||||
stateStr = "open"
|
stateStr = "open"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort(
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.state.SetPortState(key, &state.PortState{
|
w.state.SetPortState(key, &state.PortState{
|
||||||
Open: open,
|
Open: result.Open,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
LastChecked: now,
|
LastChecked: now,
|
||||||
})
|
})
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sneak.berlin/go/dnswatcher/internal/config"
|
"sneak.berlin/go/dnswatcher/internal/config"
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||||
"sneak.berlin/go/dnswatcher/internal/state"
|
"sneak.berlin/go/dnswatcher/internal/state"
|
||||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||||
"sneak.berlin/go/dnswatcher/internal/watcher"
|
"sneak.berlin/go/dnswatcher/internal/watcher"
|
||||||
@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort(
|
|||||||
_ context.Context,
|
_ context.Context,
|
||||||
address string,
|
address string,
|
||||||
port int,
|
port int,
|
||||||
) (bool, error) {
|
) (*portcheck.PortResult, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
m.calls++
|
m.calls++
|
||||||
|
|
||||||
if m.err != nil {
|
if m.err != nil {
|
||||||
return false, m.err
|
return nil, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s:%d", address, port)
|
key := fmt.Sprintf("%s:%d", address, port)
|
||||||
open, ok := m.results[key]
|
open := m.results[key]
|
||||||
|
|
||||||
if !ok {
|
return &portcheck.PortResult{Open: open}, nil
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return open, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockTLSChecker struct {
|
type mockTLSChecker struct {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user
CheckPortschecks ports sequentially — for large port lists this could be slow (5s timeout × N ports worst case). Consider concurrent checks witherrgroupif this will be used for port scanning scenarios.