feat: make CheckPorts concurrent and add port validation
- CheckPorts now runs all port checks concurrently using errgroup - Added port number validation (1-65535) with ErrInvalidPort sentinel error - Updated PortChecker interface to use *PortResult return type - Added tests for invalid port numbers (0, negative, >65535) - All checks pass (make check clean)
This commit is contained in:
@@ -4,6 +4,7 @@ package watcher
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||
)
|
||||
|
||||
@@ -36,7 +37,7 @@ type PortChecker interface {
|
||||
ctx context.Context,
|
||||
address string,
|
||||
port int,
|
||||
) (bool, error)
|
||||
) (*portcheck.PortResult, error)
|
||||
}
|
||||
|
||||
// TLSChecker inspects TLS certificates.
|
||||
|
||||
@@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort(
|
||||
port int,
|
||||
hostname string,
|
||||
) {
|
||||
open, err := w.portCheck.CheckPort(ctx, ip, port)
|
||||
result, err := w.portCheck.CheckPort(ctx, ip, port)
|
||||
if err != nil {
|
||||
w.log.Error(
|
||||
"port check failed",
|
||||
@@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort(
|
||||
now := time.Now().UTC()
|
||||
prev, hasPrev := w.state.GetPortState(key)
|
||||
|
||||
if hasPrev && !w.firstRun && prev.Open != open {
|
||||
if hasPrev && !w.firstRun && prev.Open != result.Open {
|
||||
stateStr := "closed"
|
||||
if open {
|
||||
if result.Open {
|
||||
stateStr = "open"
|
||||
}
|
||||
|
||||
@@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort(
|
||||
}
|
||||
|
||||
w.state.SetPortState(key, &state.PortState{
|
||||
Open: open,
|
||||
Open: result.Open,
|
||||
Hostname: hostname,
|
||||
LastChecked: now,
|
||||
})
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/config"
|
||||
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||
"sneak.berlin/go/dnswatcher/internal/state"
|
||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||
"sneak.berlin/go/dnswatcher/internal/watcher"
|
||||
@@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort(
|
||||
_ context.Context,
|
||||
address string,
|
||||
port int,
|
||||
) (bool, error) {
|
||||
) (*portcheck.PortResult, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.calls++
|
||||
|
||||
if m.err != nil {
|
||||
return false, m.err
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s:%d", address, port)
|
||||
open, ok := m.results[key]
|
||||
open := m.results[key]
|
||||
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return open, nil
|
||||
return &portcheck.PortResult{Open: open}, nil
|
||||
}
|
||||
|
||||
type mockTLSChecker struct {
|
||||
|
||||
Reference in New Issue
Block a user