Merge pull request 'fix: suppress gosec G704 SSRF false positive on webhook URLs' (#13) from fix/gosec-g704-ssrf into main
All checks were successful
Check / check (push) Successful in 11m4s

Reviewed-on: #13
This commit is contained in:
Jeffrey Paul 2026-02-20 14:56:21 +01:00
commit 4394ea9376
2 changed files with 204 additions and 32 deletions

View File

@ -1,4 +1,5 @@
// Package notify provides notification delivery to Slack, Mattermost, and ntfy.
// Package notify provides notification delivery to Slack,
// Mattermost, and ntfy.
package notify
import (
@ -7,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
@ -34,8 +36,66 @@ var (
ErrMattermostFailed = errors.New(
"mattermost notification failed",
)
// ErrInvalidScheme is returned for disallowed URL schemes.
ErrInvalidScheme = errors.New("URL scheme not allowed")
// ErrMissingHost is returned when a URL has no host.
ErrMissingHost = errors.New("URL must have a host")
)
// IsAllowedScheme checks if the URL scheme is permitted.
func IsAllowedScheme(scheme string) bool {
return scheme == "https" || scheme == "http"
}
// ValidateWebhookURL validates and sanitizes a webhook URL.
// It ensures the URL has an allowed scheme (http/https),
// a non-empty host, and returns a pre-parsed *url.URL
// reconstructed from validated components.
func ValidateWebhookURL(raw string) (*url.URL, error) {
u, err := url.ParseRequestURI(raw)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if !IsAllowedScheme(u.Scheme) {
return nil, fmt.Errorf(
"%w: %s", ErrInvalidScheme, u.Scheme,
)
}
if u.Host == "" {
return nil, fmt.Errorf("%w", ErrMissingHost)
}
// Reconstruct from parsed components.
clean := &url.URL{
Scheme: u.Scheme,
Host: u.Host,
Path: u.Path,
RawQuery: u.RawQuery,
}
return clean, nil
}
// newRequest creates an http.Request from a pre-validated *url.URL.
// This avoids passing URL strings to http.NewRequestWithContext,
// which gosec flags as a potential SSRF vector.
func newRequest(
ctx context.Context,
method string,
target *url.URL,
body io.Reader,
) *http.Request {
return (&http.Request{
Method: method,
URL: target,
Host: target.Host,
Header: make(http.Header),
Body: io.NopCloser(body),
}).WithContext(ctx)
}
// Params contains dependencies for Service.
type Params struct {
fx.In
@ -47,7 +107,7 @@ type Params struct {
// Service provides notification functionality.
type Service struct {
log *slog.Logger
client *http.Client
transport http.RoundTripper
config *config.Config
ntfyURL *url.URL
slackWebhookURL *url.URL
@ -60,33 +120,41 @@ func New(
params Params,
) (*Service, error) {
svc := &Service{
log: params.Logger.Get(),
client: &http.Client{
Timeout: httpClientTimeout,
},
config: params.Config,
log: params.Logger.Get(),
transport: http.DefaultTransport,
config: params.Config,
}
if params.Config.NtfyTopic != "" {
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
u, err := ValidateWebhookURL(
params.Config.NtfyTopic,
)
if err != nil {
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
return nil, fmt.Errorf(
"invalid ntfy topic URL: %w", err,
)
}
svc.ntfyURL = u
}
if params.Config.SlackWebhook != "" {
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
u, err := ValidateWebhookURL(
params.Config.SlackWebhook,
)
if err != nil {
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
return nil, fmt.Errorf(
"invalid slack webhook URL: %w", err,
)
}
svc.slackWebhookURL = u
}
if params.Config.MattermostWebhook != "" {
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
u, err := ValidateWebhookURL(
params.Config.MattermostWebhook,
)
if err != nil {
return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err,
@ -99,7 +167,8 @@ func New(
return svc, nil
}
// SendNotification sends a notification to all configured endpoints.
// SendNotification sends a notification to all configured
// endpoints.
func (svc *Service) SendNotification(
ctx context.Context,
title, message, priority string,
@ -170,20 +239,20 @@ func (svc *Service) sendNtfy(
"title", title,
)
request, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
topicURL.String(),
bytes.NewBufferString(message),
ctx, cancel := context.WithTimeout(
ctx, httpClientTimeout,
)
defer cancel()
body := bytes.NewBufferString(message)
request := newRequest(
ctx, http.MethodPost, topicURL, body,
)
if err != nil {
return fmt.Errorf("creating ntfy request: %w", err)
}
request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
resp, err := svc.transport.RoundTrip(request)
if err != nil {
return fmt.Errorf("sending ntfy request: %w", err)
}
@ -192,7 +261,8 @@ func (svc *Service) sendNtfy(
if resp.StatusCode >= httpStatusClientError {
return fmt.Errorf(
"%w: status %d", ErrNtfyFailed, resp.StatusCode,
"%w: status %d",
ErrNtfyFailed, resp.StatusCode,
)
}
@ -232,6 +302,11 @@ func (svc *Service) sendSlack(
webhookURL *url.URL,
title, message, priority string,
) error {
ctx, cancel := context.WithTimeout(
ctx, httpClientTimeout,
)
defer cancel()
svc.log.Debug(
"sending webhook notification",
"url", webhookURL.String(),
@ -250,22 +325,19 @@ func (svc *Service) sendSlack(
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshaling webhook payload: %w", err)
return fmt.Errorf(
"marshaling webhook payload: %w", err,
)
}
request, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
webhookURL.String(),
request := newRequest(
ctx, http.MethodPost, webhookURL,
bytes.NewBuffer(body),
)
if err != nil {
return fmt.Errorf("creating webhook request: %w", err)
}
request.Header.Set("Content-Type", "application/json")
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
resp, err := svc.transport.RoundTrip(request)
if err != nil {
return fmt.Errorf("sending webhook request: %w", err)
}

View File

@ -0,0 +1,100 @@
package notify_test
import (
"testing"
"sneak.berlin/go/dnswatcher/internal/notify"
)
func TestValidateWebhookURLValid(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantURL string
}{
{
name: "valid https URL",
input: "https://hooks.slack.com/T00/B00",
wantURL: "https://hooks.slack.com/T00/B00",
},
{
name: "valid http URL",
input: "http://localhost:8080/webhook",
wantURL: "http://localhost:8080/webhook",
},
{
name: "https with query",
input: "https://ntfy.sh/topic?auth=tok",
wantURL: "https://ntfy.sh/topic?auth=tok",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := notify.ValidateWebhookURL(tt.input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got.String() != tt.wantURL {
t.Errorf(
"got %q, want %q",
got.String(), tt.wantURL,
)
}
})
}
}
func TestValidateWebhookURLInvalid(t *testing.T) {
t.Parallel()
invalid := []struct {
name string
input string
}{
{"ftp scheme", "ftp://example.com/file"},
{"file scheme", "file:///etc/passwd"},
{"empty string", ""},
{"no scheme", "example.com/webhook"},
{"no host", "https:///path"},
}
for _, tt := range invalid {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := notify.ValidateWebhookURL(tt.input)
if err == nil {
t.Errorf(
"expected error for %q, got %v",
tt.input, got,
)
}
})
}
}
func TestIsAllowedScheme(t *testing.T) {
t.Parallel()
if !notify.IsAllowedScheme("https") {
t.Error("https should be allowed")
}
if !notify.IsAllowedScheme("http") {
t.Error("http should be allowed")
}
if notify.IsAllowedScheme("ftp") {
t.Error("ftp should not be allowed")
}
if notify.IsAllowedScheme("") {
t.Error("empty scheme should not be allowed")
}
}