2 Commits

Author SHA1 Message Date
clawbot
f0ea83179f fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:42:21 -08:00
clawbot
28f2d829ce feat: implement TCP port connectivity checker (closes #3) 2026-02-19 13:44:20 -08:00
3 changed files with 326 additions and 31 deletions

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@@ -45,9 +46,12 @@ type Params struct {
// Service provides notification functionality. // Service provides notification functionality.
type Service struct { type Service struct {
log *slog.Logger log *slog.Logger
client *http.Client client *http.Client
config *config.Config config *config.Config
ntfyURL *url.URL
slackWebhookURL *url.URL
mattermostWebhookURL *url.URL
} }
// New creates a new notify Service. // New creates a new notify Service.
@@ -55,13 +59,44 @@ func New(
_ fx.Lifecycle, _ fx.Lifecycle,
params Params, params Params,
) (*Service, error) { ) (*Service, error) {
return &Service{ svc := &Service{
log: params.Logger.Get(), log: params.Logger.Get(),
client: &http.Client{ client: &http.Client{
Timeout: httpClientTimeout, Timeout: httpClientTimeout,
}, },
config: params.Config, config: params.Config,
}, nil }
if params.Config.NtfyTopic != "" {
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
if err != nil {
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
}
svc.ntfyURL = u
}
if params.Config.SlackWebhook != "" {
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
if err != nil {
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
}
svc.slackWebhookURL = u
}
if params.Config.MattermostWebhook != "" {
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
if err != nil {
return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err,
)
}
svc.mattermostWebhookURL = u
}
return svc, nil
} }
// SendNotification sends a notification to all configured endpoints. // SendNotification sends a notification to all configured endpoints.
@@ -69,13 +104,13 @@ func (svc *Service) SendNotification(
ctx context.Context, ctx context.Context,
title, message, priority string, title, message, priority string,
) { ) {
if svc.config.NtfyTopic != "" { if svc.ntfyURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendNtfy( err := svc.sendNtfy(
notifyCtx, notifyCtx,
svc.config.NtfyTopic, svc.ntfyURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -87,13 +122,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.SlackWebhook != "" { if svc.slackWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.SlackWebhook, svc.slackWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -105,13 +140,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.MattermostWebhook != "" { if svc.mattermostWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.MattermostWebhook, svc.mattermostWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@@ -126,18 +161,19 @@ func (svc *Service) SendNotification(
func (svc *Service) sendNtfy( func (svc *Service) sendNtfy(
ctx context.Context, ctx context.Context,
topic, title, message, priority string, topicURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending ntfy notification", "sending ntfy notification",
"topic", topic, "topic", topicURL.String(),
"title", title, "title", title,
) )
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
topic, topicURL.String(),
bytes.NewBufferString(message), bytes.NewBufferString(message),
) )
if err != nil { if err != nil {
@@ -147,7 +183,7 @@ func (svc *Service) sendNtfy(
request.Header.Set("Title", title) request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority)) request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending ntfy request: %w", err) return fmt.Errorf("sending ntfy request: %w", err)
} }
@@ -193,11 +229,12 @@ type SlackAttachment struct {
func (svc *Service) sendSlack( func (svc *Service) sendSlack(
ctx context.Context, ctx context.Context,
webhookURL, title, message, priority string, webhookURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending webhook notification", "sending webhook notification",
"url", webhookURL, "url", webhookURL.String(),
"title", title, "title", title,
) )
@@ -219,7 +256,7 @@ func (svc *Service) sendSlack(
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
webhookURL, webhookURL.String(),
bytes.NewBuffer(body), bytes.NewBuffer(body),
) )
if err != nil { if err != nil {
@@ -228,7 +265,7 @@ func (svc *Service) sendSlack(
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending webhook request: %w", err) return fmt.Errorf("sending webhook request: %w", err)
} }

View File

@@ -3,18 +3,28 @@ package portcheck
import ( import (
"context" "context"
"errors" "fmt"
"log/slog" "log/slog"
"net"
"strconv"
"time"
"go.uber.org/fx" "go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
// ErrNotImplemented indicates the port checker is not yet implemented. const defaultTimeout = 5 * time.Second
var ErrNotImplemented = errors.New(
"port checker not yet implemented", // PortResult holds the outcome of a single TCP port check.
) type PortResult struct {
// Open indicates whether the port accepted a connection.
Open bool
// Error contains a description if the connection failed.
Error string
// Latency is the time taken for the TCP handshake.
Latency time.Duration
}
// Params contains dependencies for Checker. // Params contains dependencies for Checker.
type Params struct { type Params struct {
@@ -38,11 +48,96 @@ func New(
}, nil }, nil
} }
// CheckPort tests TCP connectivity to the given address and port. // NewStandalone creates a Checker without fx dependencies.
func (c *Checker) CheckPort( func NewStandalone() *Checker {
_ context.Context, return &Checker{
_ string, log: slog.Default(),
_ int, }
) (bool, error) { }
return false, ErrNotImplemented
// CheckPort tests TCP connectivity to the given address and port.
// It uses a 5-second timeout unless the context has an earlier
// deadline.
func (c *Checker) CheckPort(
ctx context.Context,
address string,
port int,
) (*PortResult, error) {
target := net.JoinHostPort(
address, strconv.Itoa(port),
)
deadline, hasDeadline := ctx.Deadline()
timeout := defaultTimeout
if hasDeadline {
remaining := time.Until(deadline)
if remaining < timeout {
timeout = remaining
}
}
dialer := &net.Dialer{Timeout: timeout}
start := time.Now()
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
latency := time.Since(start)
if dialErr != nil {
c.log.Debug(
"port check failed",
"target", target,
"error", dialErr.Error(),
)
return &PortResult{
Open: false,
Error: dialErr.Error(),
Latency: latency,
}, nil
}
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing connection",
"target", target,
"error", closeErr.Error(),
)
}
c.log.Debug(
"port check succeeded",
"target", target,
"latency", latency,
)
return &PortResult{
Open: true,
Latency: latency,
}, nil
}
// CheckPorts tests TCP connectivity to multiple ports on the
// given address. It returns a map of port number to result.
func (c *Checker) CheckPorts(
ctx context.Context,
address string,
ports []int,
) (map[int]*PortResult, error) {
results := make(map[int]*PortResult, len(ports))
for _, port := range ports {
result, err := c.CheckPort(ctx, address, port)
if err != nil {
return nil, fmt.Errorf(
"checking port %d: %w", port, err,
)
}
results[port] = result
}
return results, nil
} }

View File

@@ -0,0 +1,163 @@
package portcheck_test
import (
"context"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/portcheck"
)
func listenTCP(
t *testing.T,
) (net.Listener, int) {
t.Helper()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to start listener: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return ln, addr.Port
}
func TestCheckPortOpen(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
defer func() { _ = ln.Close() }()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !result.Open {
t.Error("expected port to be open")
}
if result.Error != "" {
t.Errorf("expected no error, got: %s", result.Error)
}
if result.Latency <= 0 {
t.Error("expected positive latency")
}
}
func TestCheckPortClosed(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
_ = ln.Close()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Open {
t.Error("expected port to be closed")
}
if result.Error == "" {
t.Error("expected error message for closed port")
}
}
func TestCheckPortContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(ctx, "127.0.0.1", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Open {
t.Error("expected port to not be open")
}
}
func TestCheckPortsMultiple(t *testing.T) {
t.Parallel()
ln, openPort := listenTCP(t)
defer func() { _ = ln.Close() }()
ln2, closedPort := listenTCP(t)
_ = ln2.Close()
checker := portcheck.NewStandalone()
results, err := checker.CheckPorts(
context.Background(),
"127.0.0.1",
[]int{openPort, closedPort},
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 2 {
t.Fatalf(
"expected 2 results, got %d", len(results),
)
}
if !results[openPort].Open {
t.Error("expected open port to be open")
}
if results[closedPort].Open {
t.Error("expected closed port to be closed")
}
}
func TestCheckPortLatencyReasonable(t *testing.T) {
t.Parallel()
ln, port := listenTCP(t)
defer func() { _ = ln.Close() }()
checker := portcheck.NewStandalone()
result, err := checker.CheckPort(
context.Background(), "127.0.0.1", port,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Latency > time.Second {
t.Errorf(
"latency too high for localhost: %v",
result.Latency,
)
}
}