Compare commits
3 Commits
fix/gosec-
...
9a9a95581a
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a9a95581a | |||
|
|
f0ea83179f | ||
|
|
28f2d829ce |
@@ -1,5 +1,4 @@
|
||||
// Package notify provides notification delivery to Slack,
|
||||
// Mattermost, and ntfy.
|
||||
// Package notify provides notification delivery to Slack, Mattermost, and ntfy.
|
||||
package notify
|
||||
|
||||
import (
|
||||
@@ -8,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -36,66 +34,8 @@ var (
|
||||
ErrMattermostFailed = errors.New(
|
||||
"mattermost notification failed",
|
||||
)
|
||||
// ErrInvalidScheme is returned for disallowed URL schemes.
|
||||
ErrInvalidScheme = errors.New("URL scheme not allowed")
|
||||
// ErrMissingHost is returned when a URL has no host.
|
||||
ErrMissingHost = errors.New("URL must have a host")
|
||||
)
|
||||
|
||||
// IsAllowedScheme checks if the URL scheme is permitted.
|
||||
func IsAllowedScheme(scheme string) bool {
|
||||
return scheme == "https" || scheme == "http"
|
||||
}
|
||||
|
||||
// ValidateWebhookURL validates and sanitizes a webhook URL.
|
||||
// It ensures the URL has an allowed scheme (http/https),
|
||||
// a non-empty host, and returns a pre-parsed *url.URL
|
||||
// reconstructed from validated components.
|
||||
func ValidateWebhookURL(raw string) (*url.URL, error) {
|
||||
u, err := url.ParseRequestURI(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
if !IsAllowedScheme(u.Scheme) {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: %s", ErrInvalidScheme, u.Scheme,
|
||||
)
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
return nil, fmt.Errorf("%w", ErrMissingHost)
|
||||
}
|
||||
|
||||
// Reconstruct from parsed components.
|
||||
clean := &url.URL{
|
||||
Scheme: u.Scheme,
|
||||
Host: u.Host,
|
||||
Path: u.Path,
|
||||
RawQuery: u.RawQuery,
|
||||
}
|
||||
|
||||
return clean, nil
|
||||
}
|
||||
|
||||
// newRequest creates an http.Request from a pre-validated *url.URL.
|
||||
// This avoids passing URL strings to http.NewRequestWithContext,
|
||||
// which gosec flags as a potential SSRF vector.
|
||||
func newRequest(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
target *url.URL,
|
||||
body io.Reader,
|
||||
) *http.Request {
|
||||
return (&http.Request{
|
||||
Method: method,
|
||||
URL: target,
|
||||
Host: target.Host,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(body),
|
||||
}).WithContext(ctx)
|
||||
}
|
||||
|
||||
// Params contains dependencies for Service.
|
||||
type Params struct {
|
||||
fx.In
|
||||
@@ -107,7 +47,7 @@ type Params struct {
|
||||
// Service provides notification functionality.
|
||||
type Service struct {
|
||||
log *slog.Logger
|
||||
transport http.RoundTripper
|
||||
client *http.Client
|
||||
config *config.Config
|
||||
ntfyURL *url.URL
|
||||
slackWebhookURL *url.URL
|
||||
@@ -120,41 +60,33 @@ func New(
|
||||
params Params,
|
||||
) (*Service, error) {
|
||||
svc := &Service{
|
||||
log: params.Logger.Get(),
|
||||
transport: http.DefaultTransport,
|
||||
config: params.Config,
|
||||
log: params.Logger.Get(),
|
||||
client: &http.Client{
|
||||
Timeout: httpClientTimeout,
|
||||
},
|
||||
config: params.Config,
|
||||
}
|
||||
|
||||
if params.Config.NtfyTopic != "" {
|
||||
u, err := ValidateWebhookURL(
|
||||
params.Config.NtfyTopic,
|
||||
)
|
||||
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"invalid ntfy topic URL: %w", err,
|
||||
)
|
||||
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
|
||||
}
|
||||
|
||||
svc.ntfyURL = u
|
||||
}
|
||||
|
||||
if params.Config.SlackWebhook != "" {
|
||||
u, err := ValidateWebhookURL(
|
||||
params.Config.SlackWebhook,
|
||||
)
|
||||
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"invalid slack webhook URL: %w", err,
|
||||
)
|
||||
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
|
||||
}
|
||||
|
||||
svc.slackWebhookURL = u
|
||||
}
|
||||
|
||||
if params.Config.MattermostWebhook != "" {
|
||||
u, err := ValidateWebhookURL(
|
||||
params.Config.MattermostWebhook,
|
||||
)
|
||||
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"invalid mattermost webhook URL: %w", err,
|
||||
@@ -167,8 +99,7 @@ func New(
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// SendNotification sends a notification to all configured
|
||||
// endpoints.
|
||||
// SendNotification sends a notification to all configured endpoints.
|
||||
func (svc *Service) SendNotification(
|
||||
ctx context.Context,
|
||||
title, message, priority string,
|
||||
@@ -239,20 +170,20 @@ func (svc *Service) sendNtfy(
|
||||
"title", title,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
ctx, httpClientTimeout,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
body := bytes.NewBufferString(message)
|
||||
request := newRequest(
|
||||
ctx, http.MethodPost, topicURL, body,
|
||||
request, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
topicURL.String(),
|
||||
bytes.NewBufferString(message),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating ntfy request: %w", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Title", title)
|
||||
request.Header.Set("Priority", ntfyPriority(priority))
|
||||
|
||||
resp, err := svc.transport.RoundTrip(request)
|
||||
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending ntfy request: %w", err)
|
||||
}
|
||||
@@ -261,8 +192,7 @@ func (svc *Service) sendNtfy(
|
||||
|
||||
if resp.StatusCode >= httpStatusClientError {
|
||||
return fmt.Errorf(
|
||||
"%w: status %d",
|
||||
ErrNtfyFailed, resp.StatusCode,
|
||||
"%w: status %d", ErrNtfyFailed, resp.StatusCode,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -302,11 +232,6 @@ func (svc *Service) sendSlack(
|
||||
webhookURL *url.URL,
|
||||
title, message, priority string,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(
|
||||
ctx, httpClientTimeout,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
svc.log.Debug(
|
||||
"sending webhook notification",
|
||||
"url", webhookURL.String(),
|
||||
@@ -325,19 +250,22 @@ func (svc *Service) sendSlack(
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"marshaling webhook payload: %w", err,
|
||||
)
|
||||
return fmt.Errorf("marshaling webhook payload: %w", err)
|
||||
}
|
||||
|
||||
request := newRequest(
|
||||
ctx, http.MethodPost, webhookURL,
|
||||
request, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
webhookURL.String(),
|
||||
bytes.NewBuffer(body),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating webhook request: %w", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := svc.transport.RoundTrip(request)
|
||||
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending webhook request: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package notify_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/notify"
|
||||
)
|
||||
|
||||
func TestValidateWebhookURLValid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantURL string
|
||||
}{
|
||||
{
|
||||
name: "valid https URL",
|
||||
input: "https://hooks.slack.com/T00/B00",
|
||||
wantURL: "https://hooks.slack.com/T00/B00",
|
||||
},
|
||||
{
|
||||
name: "valid http URL",
|
||||
input: "http://localhost:8080/webhook",
|
||||
wantURL: "http://localhost:8080/webhook",
|
||||
},
|
||||
{
|
||||
name: "https with query",
|
||||
input: "https://ntfy.sh/topic?auth=tok",
|
||||
wantURL: "https://ntfy.sh/topic?auth=tok",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := notify.ValidateWebhookURL(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if got.String() != tt.wantURL {
|
||||
t.Errorf(
|
||||
"got %q, want %q",
|
||||
got.String(), tt.wantURL,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWebhookURLInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
invalid := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"ftp scheme", "ftp://example.com/file"},
|
||||
{"file scheme", "file:///etc/passwd"},
|
||||
{"empty string", ""},
|
||||
{"no scheme", "example.com/webhook"},
|
||||
{"no host", "https:///path"},
|
||||
}
|
||||
|
||||
for _, tt := range invalid {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := notify.ValidateWebhookURL(tt.input)
|
||||
if err == nil {
|
||||
t.Errorf(
|
||||
"expected error for %q, got %v",
|
||||
tt.input, got,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedScheme(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !notify.IsAllowedScheme("https") {
|
||||
t.Error("https should be allowed")
|
||||
}
|
||||
|
||||
if !notify.IsAllowedScheme("http") {
|
||||
t.Error("http should be allowed")
|
||||
}
|
||||
|
||||
if notify.IsAllowedScheme("ftp") {
|
||||
t.Error("ftp should not be allowed")
|
||||
}
|
||||
|
||||
if notify.IsAllowedScheme("") {
|
||||
t.Error("empty scheme should not be allowed")
|
||||
}
|
||||
}
|
||||
@@ -3,18 +3,28 @@ package portcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go.uber.org/fx"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/logger"
|
||||
)
|
||||
|
||||
// ErrNotImplemented indicates the port checker is not yet implemented.
|
||||
var ErrNotImplemented = errors.New(
|
||||
"port checker not yet implemented",
|
||||
)
|
||||
const defaultTimeout = 5 * time.Second
|
||||
|
||||
// PortResult holds the outcome of a single TCP port check.
|
||||
type PortResult struct {
|
||||
// Open indicates whether the port accepted a connection.
|
||||
Open bool
|
||||
// Error contains a description if the connection failed.
|
||||
Error string
|
||||
// Latency is the time taken for the TCP handshake.
|
||||
Latency time.Duration
|
||||
}
|
||||
|
||||
// Params contains dependencies for Checker.
|
||||
type Params struct {
|
||||
@@ -38,11 +48,96 @@ func New(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckPort tests TCP connectivity to the given address and port.
|
||||
func (c *Checker) CheckPort(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ int,
|
||||
) (bool, error) {
|
||||
return false, ErrNotImplemented
|
||||
// NewStandalone creates a Checker without fx dependencies.
|
||||
func NewStandalone() *Checker {
|
||||
return &Checker{
|
||||
log: slog.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckPort tests TCP connectivity to the given address and port.
|
||||
// It uses a 5-second timeout unless the context has an earlier
|
||||
// deadline.
|
||||
func (c *Checker) CheckPort(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
port int,
|
||||
) (*PortResult, error) {
|
||||
target := net.JoinHostPort(
|
||||
address, strconv.Itoa(port),
|
||||
)
|
||||
|
||||
deadline, hasDeadline := ctx.Deadline()
|
||||
timeout := defaultTimeout
|
||||
|
||||
if hasDeadline {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining < timeout {
|
||||
timeout = remaining
|
||||
}
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
|
||||
latency := time.Since(start)
|
||||
|
||||
if dialErr != nil {
|
||||
c.log.Debug(
|
||||
"port check failed",
|
||||
"target", target,
|
||||
"error", dialErr.Error(),
|
||||
)
|
||||
|
||||
return &PortResult{
|
||||
Open: false,
|
||||
Error: dialErr.Error(),
|
||||
Latency: latency,
|
||||
}, nil
|
||||
}
|
||||
|
||||
closeErr := conn.Close()
|
||||
if closeErr != nil {
|
||||
c.log.Debug(
|
||||
"closing connection",
|
||||
"target", target,
|
||||
"error", closeErr.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
c.log.Debug(
|
||||
"port check succeeded",
|
||||
"target", target,
|
||||
"latency", latency,
|
||||
)
|
||||
|
||||
return &PortResult{
|
||||
Open: true,
|
||||
Latency: latency,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckPorts tests TCP connectivity to multiple ports on the
|
||||
// given address. It returns a map of port number to result.
|
||||
func (c *Checker) CheckPorts(
|
||||
ctx context.Context,
|
||||
address string,
|
||||
ports []int,
|
||||
) (map[int]*PortResult, error) {
|
||||
results := make(map[int]*PortResult, len(ports))
|
||||
|
||||
for _, port := range ports {
|
||||
result, err := c.CheckPort(ctx, address, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"checking port %d: %w", port, err,
|
||||
)
|
||||
}
|
||||
|
||||
results[port] = result
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
163
internal/portcheck/portcheck_test.go
Normal file
163
internal/portcheck/portcheck_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package portcheck_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||
)
|
||||
|
||||
func listenTCP(
|
||||
t *testing.T,
|
||||
) (net.Listener, int) {
|
||||
t.Helper()
|
||||
|
||||
lc := &net.ListenConfig{}
|
||||
|
||||
ln, err := lc.Listen(
|
||||
context.Background(), "tcp", "127.0.0.1:0",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start listener: %v", err)
|
||||
}
|
||||
|
||||
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
t.Fatal("unexpected address type")
|
||||
}
|
||||
|
||||
return ln, addr.Port
|
||||
}
|
||||
|
||||
func TestCheckPortOpen(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, port := listenTCP(t)
|
||||
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(
|
||||
context.Background(), "127.0.0.1", port,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !result.Open {
|
||||
t.Error("expected port to be open")
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
t.Errorf("expected no error, got: %s", result.Error)
|
||||
}
|
||||
|
||||
if result.Latency <= 0 {
|
||||
t.Error("expected positive latency")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, port := listenTCP(t)
|
||||
_ = ln.Close()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(
|
||||
context.Background(), "127.0.0.1", port,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Open {
|
||||
t.Error("expected port to be closed")
|
||||
}
|
||||
|
||||
if result.Error == "" {
|
||||
t.Error("expected error message for closed port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(ctx, "127.0.0.1", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Open {
|
||||
t.Error("expected port to not be open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortsMultiple(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, openPort := listenTCP(t)
|
||||
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
ln2, closedPort := listenTCP(t)
|
||||
_ = ln2.Close()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
results, err := checker.CheckPorts(
|
||||
context.Background(),
|
||||
"127.0.0.1",
|
||||
[]int{openPort, closedPort},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Fatalf(
|
||||
"expected 2 results, got %d", len(results),
|
||||
)
|
||||
}
|
||||
|
||||
if !results[openPort].Open {
|
||||
t.Error("expected open port to be open")
|
||||
}
|
||||
|
||||
if results[closedPort].Open {
|
||||
t.Error("expected closed port to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPortLatencyReasonable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, port := listenTCP(t)
|
||||
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
checker := portcheck.NewStandalone()
|
||||
|
||||
result, err := checker.CheckPort(
|
||||
context.Background(), "127.0.0.1", port,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Latency > time.Second {
|
||||
t.Errorf(
|
||||
"latency too high for localhost: %v",
|
||||
result.Latency,
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user