diff --git a/go.mod b/go.mod index 32ad532..234b2f3 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/spf13/viper v1.21.0 go.uber.org/fx v1.24.0 golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 ) require ( diff --git a/go.sum b/go.sum index 66cc528..0edfeb0 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,8 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= diff --git a/internal/portcheck/portcheck.go b/internal/portcheck/portcheck.go index 2c061b7..a57b230 100644 --- a/internal/portcheck/portcheck.go +++ b/internal/portcheck/portcheck.go @@ -4,18 +4,39 @@ package portcheck import ( "context" "errors" + "fmt" "log/slog" + "net" + "strconv" + "sync" + "time" "go.uber.org/fx" + "golang.org/x/sync/errgroup" "sneak.berlin/go/dnswatcher/internal/logger" ) -// ErrNotImplemented indicates the port checker is not yet implemented. -var ErrNotImplemented = errors.New( - "port checker not yet implemented", +const ( + minPort = 1 + maxPort = 65535 + defaultTimeout = 5 * time.Second ) +// ErrInvalidPort is returned when a port number is outside +// the valid TCP range (1–65535). +var ErrInvalidPort = errors.New("invalid port number") + +// PortResult holds the outcome of a single TCP port check. +type PortResult struct { + // Open indicates whether the port accepted a connection. + Open bool + // Error contains a description if the connection failed. + Error string + // Latency is the time taken for the TCP handshake. + Latency time.Duration +} + // Params contains dependencies for Checker. type Params struct { fx.In @@ -38,11 +59,145 @@ func New( }, nil } -// CheckPort tests TCP connectivity to the given address and port. -func (c *Checker) CheckPort( - _ context.Context, - _ string, - _ int, -) (bool, error) { - return false, ErrNotImplemented +// NewStandalone creates a Checker without fx dependencies. +func NewStandalone() *Checker { + return &Checker{ + log: slog.Default(), + } +} + +// validatePort checks that a port number is within the valid +// TCP port range (1–65535). +func validatePort(port int) error { + if port < minPort || port > maxPort { + return fmt.Errorf( + "%w: %d (must be between %d and %d)", + ErrInvalidPort, port, minPort, maxPort, + ) + } + + return nil +} + +// CheckPort tests TCP connectivity to the given address and port. +// It uses a 5-second timeout unless the context has an earlier +// deadline. +func (c *Checker) CheckPort( + ctx context.Context, + address string, + port int, +) (*PortResult, error) { + err := validatePort(port) + if err != nil { + return nil, err + } + + target := net.JoinHostPort( + address, strconv.Itoa(port), + ) + + timeout := defaultTimeout + + deadline, hasDeadline := ctx.Deadline() + if hasDeadline { + remaining := time.Until(deadline) + if remaining < timeout { + timeout = remaining + } + } + + return c.checkConnection(ctx, target, timeout), nil +} + +// CheckPorts tests TCP connectivity to multiple ports on the +// given address concurrently. It returns a map of port number +// to result. +func (c *Checker) CheckPorts( + ctx context.Context, + address string, + ports []int, +) (map[int]*PortResult, error) { + for _, port := range ports { + err := validatePort(port) + if err != nil { + return nil, err + } + } + + var mu sync.Mutex + + results := make(map[int]*PortResult, len(ports)) + + g, ctx := errgroup.WithContext(ctx) + + for _, port := range ports { + g.Go(func() error { + result, err := c.CheckPort(ctx, address, port) + if err != nil { + return fmt.Errorf( + "checking port %d: %w", port, err, + ) + } + + mu.Lock() + results[port] = result + mu.Unlock() + + return nil + }) + } + + err := g.Wait() + if err != nil { + return nil, err + } + + return results, nil +} + +// checkConnection performs the TCP dial and returns a result. +func (c *Checker) checkConnection( + ctx context.Context, + target string, + timeout time.Duration, +) *PortResult { + dialer := &net.Dialer{Timeout: timeout} + start := time.Now() + + conn, dialErr := dialer.DialContext(ctx, "tcp", target) + latency := time.Since(start) + + if dialErr != nil { + c.log.Debug( + "port check failed", + "target", target, + "error", dialErr.Error(), + ) + + return &PortResult{ + Open: false, + Error: dialErr.Error(), + Latency: latency, + } + } + + closeErr := conn.Close() + if closeErr != nil { + c.log.Debug( + "closing connection", + "target", target, + "error", closeErr.Error(), + ) + } + + c.log.Debug( + "port check succeeded", + "target", target, + "latency", latency, + ) + + return &PortResult{ + Open: true, + Latency: latency, + } } diff --git a/internal/portcheck/portcheck_test.go b/internal/portcheck/portcheck_test.go new file mode 100644 index 0000000..7ea7fc0 --- /dev/null +++ b/internal/portcheck/portcheck_test.go @@ -0,0 +1,211 @@ +package portcheck_test + +import ( + "context" + "net" + "testing" + "time" + + "sneak.berlin/go/dnswatcher/internal/portcheck" +) + +func listenTCP( + t *testing.T, +) (net.Listener, int) { + t.Helper() + + lc := &net.ListenConfig{} + + ln, err := lc.Listen( + context.Background(), "tcp", "127.0.0.1:0", + ) + if err != nil { + t.Fatalf("failed to start listener: %v", err) + } + + addr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + t.Fatal("unexpected address type") + } + + return ln, addr.Port +} + +func TestCheckPortOpen(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + + defer func() { _ = ln.Close() }() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !result.Open { + t.Error("expected port to be open") + } + + if result.Error != "" { + t.Errorf("expected no error, got: %s", result.Error) + } + + if result.Latency <= 0 { + t.Error("expected positive latency") + } +} + +func TestCheckPortClosed(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + _ = ln.Close() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Open { + t.Error("expected port to be closed") + } + + if result.Error == "" { + t.Error("expected error message for closed port") + } +} + +func TestCheckPortContextCanceled(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort(ctx, "127.0.0.1", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Open { + t.Error("expected port to not be open") + } +} + +func TestCheckPortsMultiple(t *testing.T) { + t.Parallel() + + ln, openPort := listenTCP(t) + + defer func() { _ = ln.Close() }() + + ln2, closedPort := listenTCP(t) + _ = ln2.Close() + + checker := portcheck.NewStandalone() + + results, err := checker.CheckPorts( + context.Background(), + "127.0.0.1", + []int{openPort, closedPort}, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(results) != 2 { + t.Fatalf( + "expected 2 results, got %d", len(results), + ) + } + + if !results[openPort].Open { + t.Error("expected open port to be open") + } + + if results[closedPort].Open { + t.Error("expected closed port to be closed") + } +} + +func TestCheckPortInvalidPorts(t *testing.T) { + t.Parallel() + + checker := portcheck.NewStandalone() + + cases := []struct { + name string + port int + }{ + {"zero", 0}, + {"negative", -1}, + {"too high", 65536}, + {"very negative", -1000}, + {"very high", 100000}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := checker.CheckPort( + context.Background(), "127.0.0.1", tc.port, + ) + if err == nil { + t.Errorf( + "expected error for port %d, got nil", + tc.port, + ) + } + }) + } +} + +func TestCheckPortsInvalidPort(t *testing.T) { + t.Parallel() + + checker := portcheck.NewStandalone() + + _, err := checker.CheckPorts( + context.Background(), + "127.0.0.1", + []int{80, 0, 443}, + ) + if err == nil { + t.Error("expected error for invalid port in list") + } +} + +func TestCheckPortLatencyReasonable(t *testing.T) { + t.Parallel() + + ln, port := listenTCP(t) + + defer func() { _ = ln.Close() }() + + checker := portcheck.NewStandalone() + + result, err := checker.CheckPort( + context.Background(), "127.0.0.1", port, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Latency > time.Second { + t.Errorf( + "latency too high for localhost: %v", + result.Latency, + ) + } +} diff --git a/internal/watcher/interfaces.go b/internal/watcher/interfaces.go index 695139d..dd68017 100644 --- a/internal/watcher/interfaces.go +++ b/internal/watcher/interfaces.go @@ -4,6 +4,7 @@ package watcher import ( "context" + "sneak.berlin/go/dnswatcher/internal/portcheck" "sneak.berlin/go/dnswatcher/internal/tlscheck" ) @@ -36,7 +37,7 @@ type PortChecker interface { ctx context.Context, address string, port int, - ) (bool, error) + ) (*portcheck.PortResult, error) } // TLSChecker inspects TLS certificates. diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 2493264..742834c 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort( port int, hostname string, ) { - open, err := w.portCheck.CheckPort(ctx, ip, port) + result, err := w.portCheck.CheckPort(ctx, ip, port) if err != nil { w.log.Error( "port check failed", @@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort( now := time.Now().UTC() prev, hasPrev := w.state.GetPortState(key) - if hasPrev && !w.firstRun && prev.Open != open { + if hasPrev && !w.firstRun && prev.Open != result.Open { stateStr := "closed" - if open { + if result.Open { stateStr = "open" } @@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort( } w.state.SetPortState(key, &state.PortState{ - Open: open, + Open: result.Open, Hostname: hostname, LastChecked: now, }) diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 69772aa..57b56c7 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -9,6 +9,7 @@ import ( "time" "sneak.berlin/go/dnswatcher/internal/config" + "sneak.berlin/go/dnswatcher/internal/portcheck" "sneak.berlin/go/dnswatcher/internal/state" "sneak.berlin/go/dnswatcher/internal/tlscheck" "sneak.berlin/go/dnswatcher/internal/watcher" @@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort( _ context.Context, address string, port int, -) (bool, error) { +) (*portcheck.PortResult, error) { m.mu.Lock() defer m.mu.Unlock() m.calls++ if m.err != nil { - return false, m.err + return nil, m.err } key := fmt.Sprintf("%s:%d", address, port) - open, ok := m.results[key] + open := m.results[key] - if !ok { - return false, nil - } - - return open, nil + return &portcheck.PortResult{Open: open}, nil } type mockTLSChecker struct {