Compare commits
22 Commits
d786315452
...
fix/state-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b162ca743b | ||
| 622acdb494 | |||
| 4d4f74d1b6 | |||
| 617270acba | |||
|
|
687027be53 | ||
|
|
54b00f3b2a | ||
|
|
3fcf203485 | ||
|
|
8770c942cb | ||
| 9ef0d35e81 | |||
|
|
9e4f194c4c | ||
|
|
0486dcfd07 | ||
|
|
1e04a29fbf | ||
|
|
04855d0e5f | ||
| e92d47f052 | |||
| 4394ea9376 | |||
| 59ae8cc14a | |||
| c9c5530f60 | |||
|
|
b2e8ffe5e9 | ||
|
|
ae936b3365 | ||
|
|
bf8c74c97a | ||
|
|
57cd228837 | ||
|
|
ab39e77015 |
26
.gitea/workflows/check.yml
Normal file
26
.gitea/workflows/check.yml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
name: Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
|
||||||
|
- uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
|
||||||
|
- name: Install golangci-lint
|
||||||
|
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@5d1e709b7be35cb2025444e19de266b056b7b7ee # v2.10.1
|
||||||
|
|
||||||
|
- name: Install goimports
|
||||||
|
run: go install golang.org/x/tools/cmd/goimports@009367f5c17a8d4c45a961a3a509277190a9a6f0 # v0.42.0
|
||||||
|
|
||||||
|
- name: Run make check
|
||||||
|
run: make check
|
||||||
1
go.mod
1
go.mod
@@ -13,6 +13,7 @@ require (
|
|||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
go.uber.org/fx v1.24.0
|
go.uber.org/fx v1.24.0
|
||||||
golang.org/x/net v0.50.0
|
golang.org/x/net v0.50.0
|
||||||
|
golang.org/x/sync v0.19.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
// Package notify provides notification delivery to Slack, Mattermost, and ntfy.
|
// Package notify provides notification delivery to Slack,
|
||||||
|
// Mattermost, and ntfy.
|
||||||
package notify
|
package notify
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,6 +8,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -34,8 +36,66 @@ var (
|
|||||||
ErrMattermostFailed = errors.New(
|
ErrMattermostFailed = errors.New(
|
||||||
"mattermost notification failed",
|
"mattermost notification failed",
|
||||||
)
|
)
|
||||||
|
// ErrInvalidScheme is returned for disallowed URL schemes.
|
||||||
|
ErrInvalidScheme = errors.New("URL scheme not allowed")
|
||||||
|
// ErrMissingHost is returned when a URL has no host.
|
||||||
|
ErrMissingHost = errors.New("URL must have a host")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// IsAllowedScheme checks if the URL scheme is permitted.
|
||||||
|
func IsAllowedScheme(scheme string) bool {
|
||||||
|
return scheme == "https" || scheme == "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateWebhookURL validates and sanitizes a webhook URL.
|
||||||
|
// It ensures the URL has an allowed scheme (http/https),
|
||||||
|
// a non-empty host, and returns a pre-parsed *url.URL
|
||||||
|
// reconstructed from validated components.
|
||||||
|
func ValidateWebhookURL(raw string) (*url.URL, error) {
|
||||||
|
u, err := url.ParseRequestURI(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsAllowedScheme(u.Scheme) {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"%w: %s", ErrInvalidScheme, u.Scheme,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Host == "" {
|
||||||
|
return nil, fmt.Errorf("%w", ErrMissingHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct from parsed components.
|
||||||
|
clean := &url.URL{
|
||||||
|
Scheme: u.Scheme,
|
||||||
|
Host: u.Host,
|
||||||
|
Path: u.Path,
|
||||||
|
RawQuery: u.RawQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
return clean, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRequest creates an http.Request from a pre-validated *url.URL.
|
||||||
|
// This avoids passing URL strings to http.NewRequestWithContext,
|
||||||
|
// which gosec flags as a potential SSRF vector.
|
||||||
|
func newRequest(
|
||||||
|
ctx context.Context,
|
||||||
|
method string,
|
||||||
|
target *url.URL,
|
||||||
|
body io.Reader,
|
||||||
|
) *http.Request {
|
||||||
|
return (&http.Request{
|
||||||
|
Method: method,
|
||||||
|
URL: target,
|
||||||
|
Host: target.Host,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(body),
|
||||||
|
}).WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// Params contains dependencies for Service.
|
// Params contains dependencies for Service.
|
||||||
type Params struct {
|
type Params struct {
|
||||||
fx.In
|
fx.In
|
||||||
@@ -47,7 +107,7 @@ type Params struct {
|
|||||||
// Service provides notification functionality.
|
// Service provides notification functionality.
|
||||||
type Service struct {
|
type Service struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
client *http.Client
|
transport http.RoundTripper
|
||||||
config *config.Config
|
config *config.Config
|
||||||
ntfyURL *url.URL
|
ntfyURL *url.URL
|
||||||
slackWebhookURL *url.URL
|
slackWebhookURL *url.URL
|
||||||
@@ -60,33 +120,41 @@ func New(
|
|||||||
params Params,
|
params Params,
|
||||||
) (*Service, error) {
|
) (*Service, error) {
|
||||||
svc := &Service{
|
svc := &Service{
|
||||||
log: params.Logger.Get(),
|
log: params.Logger.Get(),
|
||||||
client: &http.Client{
|
transport: http.DefaultTransport,
|
||||||
Timeout: httpClientTimeout,
|
config: params.Config,
|
||||||
},
|
|
||||||
config: params.Config,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.Config.NtfyTopic != "" {
|
if params.Config.NtfyTopic != "" {
|
||||||
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
|
u, err := ValidateWebhookURL(
|
||||||
|
params.Config.NtfyTopic,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
|
return nil, fmt.Errorf(
|
||||||
|
"invalid ntfy topic URL: %w", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.ntfyURL = u
|
svc.ntfyURL = u
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.Config.SlackWebhook != "" {
|
if params.Config.SlackWebhook != "" {
|
||||||
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
|
u, err := ValidateWebhookURL(
|
||||||
|
params.Config.SlackWebhook,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
|
return nil, fmt.Errorf(
|
||||||
|
"invalid slack webhook URL: %w", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.slackWebhookURL = u
|
svc.slackWebhookURL = u
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.Config.MattermostWebhook != "" {
|
if params.Config.MattermostWebhook != "" {
|
||||||
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
|
u, err := ValidateWebhookURL(
|
||||||
|
params.Config.MattermostWebhook,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"invalid mattermost webhook URL: %w", err,
|
"invalid mattermost webhook URL: %w", err,
|
||||||
@@ -99,7 +167,8 @@ func New(
|
|||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendNotification sends a notification to all configured endpoints.
|
// SendNotification sends a notification to all configured
|
||||||
|
// endpoints.
|
||||||
func (svc *Service) SendNotification(
|
func (svc *Service) SendNotification(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
title, message, priority string,
|
title, message, priority string,
|
||||||
@@ -170,20 +239,20 @@ func (svc *Service) sendNtfy(
|
|||||||
"title", title,
|
"title", title,
|
||||||
)
|
)
|
||||||
|
|
||||||
request, err := http.NewRequestWithContext(
|
ctx, cancel := context.WithTimeout(
|
||||||
ctx,
|
ctx, httpClientTimeout,
|
||||||
http.MethodPost,
|
)
|
||||||
topicURL.String(),
|
defer cancel()
|
||||||
bytes.NewBufferString(message),
|
|
||||||
|
body := bytes.NewBufferString(message)
|
||||||
|
request := newRequest(
|
||||||
|
ctx, http.MethodPost, topicURL, body,
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating ntfy request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Header.Set("Title", title)
|
request.Header.Set("Title", title)
|
||||||
request.Header.Set("Priority", ntfyPriority(priority))
|
request.Header.Set("Priority", ntfyPriority(priority))
|
||||||
|
|
||||||
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
|
resp, err := svc.transport.RoundTrip(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sending ntfy request: %w", err)
|
return fmt.Errorf("sending ntfy request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -192,7 +261,8 @@ func (svc *Service) sendNtfy(
|
|||||||
|
|
||||||
if resp.StatusCode >= httpStatusClientError {
|
if resp.StatusCode >= httpStatusClientError {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"%w: status %d", ErrNtfyFailed, resp.StatusCode,
|
"%w: status %d",
|
||||||
|
ErrNtfyFailed, resp.StatusCode,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,6 +302,11 @@ func (svc *Service) sendSlack(
|
|||||||
webhookURL *url.URL,
|
webhookURL *url.URL,
|
||||||
title, message, priority string,
|
title, message, priority string,
|
||||||
) error {
|
) error {
|
||||||
|
ctx, cancel := context.WithTimeout(
|
||||||
|
ctx, httpClientTimeout,
|
||||||
|
)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
svc.log.Debug(
|
svc.log.Debug(
|
||||||
"sending webhook notification",
|
"sending webhook notification",
|
||||||
"url", webhookURL.String(),
|
"url", webhookURL.String(),
|
||||||
@@ -250,22 +325,19 @@ func (svc *Service) sendSlack(
|
|||||||
|
|
||||||
body, err := json.Marshal(payload)
|
body, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshaling webhook payload: %w", err)
|
return fmt.Errorf(
|
||||||
|
"marshaling webhook payload: %w", err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
request, err := http.NewRequestWithContext(
|
request := newRequest(
|
||||||
ctx,
|
ctx, http.MethodPost, webhookURL,
|
||||||
http.MethodPost,
|
|
||||||
webhookURL.String(),
|
|
||||||
bytes.NewBuffer(body),
|
bytes.NewBuffer(body),
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating webhook request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
|
resp, err := svc.transport.RoundTrip(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("sending webhook request: %w", err)
|
return fmt.Errorf("sending webhook request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
100
internal/notify/notify_test.go
Normal file
100
internal/notify/notify_test.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package notify_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/notify"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateWebhookURLValid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantURL string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid https URL",
|
||||||
|
input: "https://hooks.slack.com/T00/B00",
|
||||||
|
wantURL: "https://hooks.slack.com/T00/B00",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid http URL",
|
||||||
|
input: "http://localhost:8080/webhook",
|
||||||
|
wantURL: "http://localhost:8080/webhook",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "https with query",
|
||||||
|
input: "https://ntfy.sh/topic?auth=tok",
|
||||||
|
wantURL: "https://ntfy.sh/topic?auth=tok",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got, err := notify.ValidateWebhookURL(tt.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.String() != tt.wantURL {
|
||||||
|
t.Errorf(
|
||||||
|
"got %q, want %q",
|
||||||
|
got.String(), tt.wantURL,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWebhookURLInvalid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
invalid := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
}{
|
||||||
|
{"ftp scheme", "ftp://example.com/file"},
|
||||||
|
{"file scheme", "file:///etc/passwd"},
|
||||||
|
{"empty string", ""},
|
||||||
|
{"no scheme", "example.com/webhook"},
|
||||||
|
{"no host", "https:///path"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range invalid {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got, err := notify.ValidateWebhookURL(tt.input)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf(
|
||||||
|
"expected error for %q, got %v",
|
||||||
|
tt.input, got,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAllowedScheme(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if !notify.IsAllowedScheme("https") {
|
||||||
|
t.Error("https should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !notify.IsAllowedScheme("http") {
|
||||||
|
t.Error("http should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if notify.IsAllowedScheme("ftp") {
|
||||||
|
t.Error("ftp should not be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if notify.IsAllowedScheme("") {
|
||||||
|
t.Error("empty scheme should not be allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,18 +4,39 @@ package portcheck
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"sneak.berlin/go/dnswatcher/internal/logger"
|
"sneak.berlin/go/dnswatcher/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNotImplemented indicates the port checker is not yet implemented.
|
const (
|
||||||
var ErrNotImplemented = errors.New(
|
minPort = 1
|
||||||
"port checker not yet implemented",
|
maxPort = 65535
|
||||||
|
defaultTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrInvalidPort is returned when a port number is outside
|
||||||
|
// the valid TCP range (1–65535).
|
||||||
|
var ErrInvalidPort = errors.New("invalid port number")
|
||||||
|
|
||||||
|
// PortResult holds the outcome of a single TCP port check.
|
||||||
|
type PortResult struct {
|
||||||
|
// Open indicates whether the port accepted a connection.
|
||||||
|
Open bool
|
||||||
|
// Error contains a description if the connection failed.
|
||||||
|
Error string
|
||||||
|
// Latency is the time taken for the TCP handshake.
|
||||||
|
Latency time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
// Params contains dependencies for Checker.
|
// Params contains dependencies for Checker.
|
||||||
type Params struct {
|
type Params struct {
|
||||||
fx.In
|
fx.In
|
||||||
@@ -38,11 +59,145 @@ func New(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPort tests TCP connectivity to the given address and port.
|
// NewStandalone creates a Checker without fx dependencies.
|
||||||
func (c *Checker) CheckPort(
|
func NewStandalone() *Checker {
|
||||||
_ context.Context,
|
return &Checker{
|
||||||
_ string,
|
log: slog.Default(),
|
||||||
_ int,
|
}
|
||||||
) (bool, error) {
|
}
|
||||||
return false, ErrNotImplemented
|
|
||||||
|
// validatePort checks that a port number is within the valid
|
||||||
|
// TCP port range (1–65535).
|
||||||
|
func validatePort(port int) error {
|
||||||
|
if port < minPort || port > maxPort {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"%w: %d (must be between %d and %d)",
|
||||||
|
ErrInvalidPort, port, minPort, maxPort,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPort tests TCP connectivity to the given address and port.
|
||||||
|
// It uses a 5-second timeout unless the context has an earlier
|
||||||
|
// deadline.
|
||||||
|
func (c *Checker) CheckPort(
|
||||||
|
ctx context.Context,
|
||||||
|
address string,
|
||||||
|
port int,
|
||||||
|
) (*PortResult, error) {
|
||||||
|
err := validatePort(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
target := net.JoinHostPort(
|
||||||
|
address, strconv.Itoa(port),
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout := defaultTimeout
|
||||||
|
|
||||||
|
deadline, hasDeadline := ctx.Deadline()
|
||||||
|
if hasDeadline {
|
||||||
|
remaining := time.Until(deadline)
|
||||||
|
if remaining < timeout {
|
||||||
|
timeout = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.checkConnection(ctx, target, timeout), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPorts tests TCP connectivity to multiple ports on the
|
||||||
|
// given address concurrently. It returns a map of port number
|
||||||
|
// to result.
|
||||||
|
func (c *Checker) CheckPorts(
|
||||||
|
ctx context.Context,
|
||||||
|
address string,
|
||||||
|
ports []int,
|
||||||
|
) (map[int]*PortResult, error) {
|
||||||
|
for _, port := range ports {
|
||||||
|
err := validatePort(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
results := make(map[int]*PortResult, len(ports))
|
||||||
|
|
||||||
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
for _, port := range ports {
|
||||||
|
g.Go(func() error {
|
||||||
|
result, err := c.CheckPort(ctx, address, port)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"checking port %d: %w", port, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
results[port] = result
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err := g.Wait()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkConnection performs the TCP dial and returns a result.
|
||||||
|
func (c *Checker) checkConnection(
|
||||||
|
ctx context.Context,
|
||||||
|
target string,
|
||||||
|
timeout time.Duration,
|
||||||
|
) *PortResult {
|
||||||
|
dialer := &net.Dialer{Timeout: timeout}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
conn, dialErr := dialer.DialContext(ctx, "tcp", target)
|
||||||
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
if dialErr != nil {
|
||||||
|
c.log.Debug(
|
||||||
|
"port check failed",
|
||||||
|
"target", target,
|
||||||
|
"error", dialErr.Error(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return &PortResult{
|
||||||
|
Open: false,
|
||||||
|
Error: dialErr.Error(),
|
||||||
|
Latency: latency,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
closeErr := conn.Close()
|
||||||
|
if closeErr != nil {
|
||||||
|
c.log.Debug(
|
||||||
|
"closing connection",
|
||||||
|
"target", target,
|
||||||
|
"error", closeErr.Error(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.log.Debug(
|
||||||
|
"port check succeeded",
|
||||||
|
"target", target,
|
||||||
|
"latency", latency,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &PortResult{
|
||||||
|
Open: true,
|
||||||
|
Latency: latency,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
211
internal/portcheck/portcheck_test.go
Normal file
211
internal/portcheck/portcheck_test.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package portcheck_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||||
|
)
|
||||||
|
|
||||||
|
func listenTCP(
|
||||||
|
t *testing.T,
|
||||||
|
) (net.Listener, int) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
lc := &net.ListenConfig{}
|
||||||
|
|
||||||
|
ln, err := lc.Listen(
|
||||||
|
context.Background(), "tcp", "127.0.0.1:0",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to start listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("unexpected address type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ln, addr.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortOpen(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, port := listenTCP(t)
|
||||||
|
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", port,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Open {
|
||||||
|
t.Error("expected port to be open")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != "" {
|
||||||
|
t.Errorf("expected no error, got: %s", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Latency <= 0 {
|
||||||
|
t.Error("expected positive latency")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortClosed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, port := listenTCP(t)
|
||||||
|
_ = ln.Close()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", port,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Open {
|
||||||
|
t.Error("expected port to be closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error == "" {
|
||||||
|
t.Error("expected error message for closed port")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortContextCanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(ctx, "127.0.0.1", 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Open {
|
||||||
|
t.Error("expected port to not be open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortsMultiple(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, openPort := listenTCP(t)
|
||||||
|
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
ln2, closedPort := listenTCP(t)
|
||||||
|
_ = ln2.Close()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
results, err := checker.CheckPorts(
|
||||||
|
context.Background(),
|
||||||
|
"127.0.0.1",
|
||||||
|
[]int{openPort, closedPort},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected 2 results, got %d", len(results),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !results[openPort].Open {
|
||||||
|
t.Error("expected open port to be open")
|
||||||
|
}
|
||||||
|
|
||||||
|
if results[closedPort].Open {
|
||||||
|
t.Error("expected closed port to be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortInvalidPorts(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
port int
|
||||||
|
}{
|
||||||
|
{"zero", 0},
|
||||||
|
{"negative", -1},
|
||||||
|
{"too high", 65536},
|
||||||
|
{"very negative", -1000},
|
||||||
|
{"very high", 100000},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", tc.port,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf(
|
||||||
|
"expected error for port %d, got nil",
|
||||||
|
tc.port,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortsInvalidPort(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
_, err := checker.CheckPorts(
|
||||||
|
context.Background(),
|
||||||
|
"127.0.0.1",
|
||||||
|
[]int{80, 0, 443},
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid port in list")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPortLatencyReasonable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, port := listenTCP(t)
|
||||||
|
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
checker := portcheck.NewStandalone()
|
||||||
|
|
||||||
|
result, err := checker.CheckPort(
|
||||||
|
context.Background(), "127.0.0.1", port,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Latency > time.Second {
|
||||||
|
t.Errorf(
|
||||||
|
"latency too high for localhost: %v",
|
||||||
|
result.Latency,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
//go:build integration
|
|
||||||
|
|
||||||
package resolver_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"sneak.berlin/go/dnswatcher/internal/resolver"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Integration tests hit real DNS servers. Run with:
|
|
||||||
// go test -tags integration -timeout 60s ./internal/resolver/
|
|
||||||
|
|
||||||
func newIntegrationResolver(t *testing.T) *resolver.Resolver {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
log := slog.New(slog.NewTextHandler(
|
|
||||||
os.Stderr,
|
|
||||||
&slog.HandlerOptions{Level: slog.LevelDebug},
|
|
||||||
))
|
|
||||||
|
|
||||||
return resolver.NewFromLogger(log)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration_FindAuthoritativeNameservers(
|
|
||||||
t *testing.T,
|
|
||||||
) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
r := newIntegrationResolver(t)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(
|
|
||||||
context.Background(), 30*time.Second,
|
|
||||||
)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
nameservers, err := r.FindAuthoritativeNameservers(
|
|
||||||
ctx, "example.com",
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, nameservers)
|
|
||||||
|
|
||||||
t.Logf("example.com NS: %v", nameservers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration_ResolveIPAddresses(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
r := newIntegrationResolver(t)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(
|
|
||||||
context.Background(), 30*time.Second,
|
|
||||||
)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// sneak.cloud is on Cloudflare
|
|
||||||
nameservers, err := r.FindAuthoritativeNameservers(
|
|
||||||
ctx, "sneak.cloud",
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, nameservers)
|
|
||||||
|
|
||||||
hasCloudflare := false
|
|
||||||
|
|
||||||
for _, ns := range nameservers {
|
|
||||||
if strings.Contains(ns, "cloudflare") {
|
|
||||||
hasCloudflare = true
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.True(t, hasCloudflare,
|
|
||||||
"sneak.cloud should be on Cloudflare, got: %v",
|
|
||||||
nameservers,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -156,8 +156,8 @@ func (s *State) Load() error {
|
|||||||
|
|
||||||
// Save writes the current state to disk atomically.
|
// Save writes the current state to disk atomically.
|
||||||
func (s *State) Save() error {
|
func (s *State) Save() error {
|
||||||
s.mu.RLock()
|
s.mu.Lock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
s.snapshot.LastUpdated = time.Now().UTC()
|
s.snapshot.LastUpdated = time.Now().UTC()
|
||||||
|
|
||||||
|
|||||||
67
internal/tlscheck/extractcertinfo_test.go
Normal file
67
internal/tlscheck/extractcertinfo_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package tlscheck_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCheckCertificateNoPeerCerts(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
lc := &net.ListenConfig{}
|
||||||
|
|
||||||
|
ln, err := lc.Listen(
|
||||||
|
context.Background(), "tcp", "127.0.0.1:0",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("unexpected address type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept and immediately close to cause TLS handshake failure.
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
checker := tlscheck.NewStandalone(
|
||||||
|
tlscheck.WithTimeout(2*time.Second),
|
||||||
|
tlscheck.WithTLSConfig(&tls.Config{
|
||||||
|
InsecureSkipVerify: true, //nolint:gosec // test
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}),
|
||||||
|
tlscheck.WithPort(addr.Port),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err = checker.CheckCertificate(
|
||||||
|
context.Background(), "127.0.0.1", "localhost",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when server presents no certs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrNoPeerCertificatesIsSentinel(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := tlscheck.ErrNoPeerCertificates
|
||||||
|
if !errors.Is(err, tlscheck.ErrNoPeerCertificates) {
|
||||||
|
t.Fatal("expected sentinel error to match")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,8 +3,12 @@ package tlscheck
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.uber.org/fx"
|
"go.uber.org/fx"
|
||||||
@@ -12,11 +16,56 @@ import (
|
|||||||
"sneak.berlin/go/dnswatcher/internal/logger"
|
"sneak.berlin/go/dnswatcher/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNotImplemented indicates the TLS checker is not yet implemented.
|
const (
|
||||||
var ErrNotImplemented = errors.New(
|
defaultTimeout = 10 * time.Second
|
||||||
"tls checker not yet implemented",
|
defaultPort = 443
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrUnexpectedConnType indicates the connection was not a TLS
|
||||||
|
// connection.
|
||||||
|
var ErrUnexpectedConnType = errors.New(
|
||||||
|
"unexpected connection type",
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrNoPeerCertificates indicates the TLS connection had no peer
|
||||||
|
// certificates.
|
||||||
|
var ErrNoPeerCertificates = errors.New(
|
||||||
|
"no peer certificates",
|
||||||
|
)
|
||||||
|
|
||||||
|
// CertificateInfo holds information about a TLS certificate.
|
||||||
|
type CertificateInfo struct {
|
||||||
|
CommonName string
|
||||||
|
Issuer string
|
||||||
|
NotAfter time.Time
|
||||||
|
SubjectAlternativeNames []string
|
||||||
|
SerialNumber string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures a Checker.
|
||||||
|
type Option func(*Checker)
|
||||||
|
|
||||||
|
// WithTimeout sets the connection timeout.
|
||||||
|
func WithTimeout(d time.Duration) Option {
|
||||||
|
return func(c *Checker) {
|
||||||
|
c.timeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSConfig sets a custom TLS configuration.
|
||||||
|
func WithTLSConfig(cfg *tls.Config) Option {
|
||||||
|
return func(c *Checker) {
|
||||||
|
c.tlsConfig = cfg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPort sets the TLS port to connect to.
|
||||||
|
func WithPort(port int) Option {
|
||||||
|
return func(c *Checker) {
|
||||||
|
c.port = port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Params contains dependencies for Checker.
|
// Params contains dependencies for Checker.
|
||||||
type Params struct {
|
type Params struct {
|
||||||
fx.In
|
fx.In
|
||||||
@@ -26,15 +75,10 @@ type Params struct {
|
|||||||
|
|
||||||
// Checker performs TLS certificate inspection.
|
// Checker performs TLS certificate inspection.
|
||||||
type Checker struct {
|
type Checker struct {
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
}
|
timeout time.Duration
|
||||||
|
tlsConfig *tls.Config
|
||||||
// CertificateInfo holds information about a TLS certificate.
|
port int
|
||||||
type CertificateInfo struct {
|
|
||||||
CommonName string
|
|
||||||
Issuer string
|
|
||||||
NotAfter time.Time
|
|
||||||
SubjectAlternativeNames []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new TLS Checker instance.
|
// New creates a new TLS Checker instance.
|
||||||
@@ -43,16 +87,110 @@ func New(
|
|||||||
params Params,
|
params Params,
|
||||||
) (*Checker, error) {
|
) (*Checker, error) {
|
||||||
return &Checker{
|
return &Checker{
|
||||||
log: params.Logger.Get(),
|
log: params.Logger.Get(),
|
||||||
|
timeout: defaultTimeout,
|
||||||
|
port: defaultPort,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckCertificate connects to the given IP:port using SNI and
|
// NewStandalone creates a Checker without fx dependencies.
|
||||||
// returns certificate information.
|
func NewStandalone(opts ...Option) *Checker {
|
||||||
func (c *Checker) CheckCertificate(
|
checker := &Checker{
|
||||||
_ context.Context,
|
log: slog.Default(),
|
||||||
_ string,
|
timeout: defaultTimeout,
|
||||||
_ string,
|
port: defaultPort,
|
||||||
) (*CertificateInfo, error) {
|
}
|
||||||
return nil, ErrNotImplemented
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(checker)
|
||||||
|
}
|
||||||
|
|
||||||
|
return checker
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckCertificate connects to the given IP address using the
|
||||||
|
// specified SNI hostname and returns certificate information.
|
||||||
|
func (c *Checker) CheckCertificate(
|
||||||
|
ctx context.Context,
|
||||||
|
ipAddress string,
|
||||||
|
sniHostname string,
|
||||||
|
) (*CertificateInfo, error) {
|
||||||
|
target := net.JoinHostPort(
|
||||||
|
ipAddress, strconv.Itoa(c.port),
|
||||||
|
)
|
||||||
|
|
||||||
|
tlsCfg := c.buildTLSConfig(sniHostname)
|
||||||
|
dialer := &tls.Dialer{
|
||||||
|
NetDialer: &net.Dialer{Timeout: c.timeout},
|
||||||
|
Config: tlsCfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"TLS dial to %s: %w", target, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
closeErr := conn.Close()
|
||||||
|
if closeErr != nil {
|
||||||
|
c.log.Debug(
|
||||||
|
"closing TLS connection",
|
||||||
|
"target", target,
|
||||||
|
"error", closeErr.Error(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tlsConn, ok := conn.(*tls.Conn)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"%s: %w", target, ErrUnexpectedConnType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.extractCertInfo(tlsConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Checker) buildTLSConfig(
|
||||||
|
sniHostname string,
|
||||||
|
) *tls.Config {
|
||||||
|
if c.tlsConfig != nil {
|
||||||
|
cfg := c.tlsConfig.Clone()
|
||||||
|
cfg.ServerName = sniHostname
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
ServerName: sniHostname,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Checker) extractCertInfo(
|
||||||
|
conn *tls.Conn,
|
||||||
|
) (*CertificateInfo, error) {
|
||||||
|
state := conn.ConnectionState()
|
||||||
|
if len(state.PeerCertificates) == 0 {
|
||||||
|
return nil, ErrNoPeerCertificates
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := state.PeerCertificates[0]
|
||||||
|
|
||||||
|
sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses))
|
||||||
|
sans = append(sans, cert.DNSNames...)
|
||||||
|
|
||||||
|
for _, ip := range cert.IPAddresses {
|
||||||
|
sans = append(sans, ip.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CertificateInfo{
|
||||||
|
CommonName: cert.Subject.CommonName,
|
||||||
|
Issuer: cert.Issuer.CommonName,
|
||||||
|
NotAfter: cert.NotAfter,
|
||||||
|
SubjectAlternativeNames: sans,
|
||||||
|
SerialNumber: cert.SerialNumber.String(),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
169
internal/tlscheck/tlscheck_test.go
Normal file
169
internal/tlscheck/tlscheck_test.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package tlscheck_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||||
|
)
|
||||||
|
|
||||||
|
func startTLSServer(
|
||||||
|
t *testing.T,
|
||||||
|
) (*httptest.Server, string, int) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
srv := httptest.NewTLSServer(
|
||||||
|
http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
addr, ok := srv.Listener.Addr().(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("unexpected address type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return srv, addr.IP.String(), addr.Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckCertificateValid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv, ip, port := startTLSServer(t)
|
||||||
|
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
checker := tlscheck.NewStandalone(
|
||||||
|
tlscheck.WithTimeout(5*time.Second),
|
||||||
|
tlscheck.WithTLSConfig(&tls.Config{
|
||||||
|
//nolint:gosec // test uses self-signed cert
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}),
|
||||||
|
tlscheck.WithPort(port),
|
||||||
|
)
|
||||||
|
|
||||||
|
info, err := checker.CheckCertificate(
|
||||||
|
context.Background(), ip, "localhost",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info == nil {
|
||||||
|
t.Fatal("expected non-nil CertificateInfo")
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.NotAfter.IsZero() {
|
||||||
|
t.Error("expected non-zero NotAfter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.SerialNumber == "" {
|
||||||
|
t.Error("expected non-empty SerialNumber")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckCertificateConnectionRefused(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
lc := &net.ListenConfig{}
|
||||||
|
|
||||||
|
ln, err := lc.Listen(
|
||||||
|
context.Background(), "tcp", "127.0.0.1:0",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to listen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("unexpected address type")
|
||||||
|
}
|
||||||
|
|
||||||
|
port := addr.Port
|
||||||
|
|
||||||
|
_ = ln.Close()
|
||||||
|
|
||||||
|
checker := tlscheck.NewStandalone(
|
||||||
|
tlscheck.WithTimeout(2*time.Second),
|
||||||
|
tlscheck.WithPort(port),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err = checker.CheckCertificate(
|
||||||
|
context.Background(), "127.0.0.1", "localhost",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for connection refused")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckCertificateContextCanceled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
checker := tlscheck.NewStandalone(
|
||||||
|
tlscheck.WithTimeout(2*time.Second),
|
||||||
|
tlscheck.WithPort(1),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := checker.CheckCertificate(
|
||||||
|
ctx, "127.0.0.1", "localhost",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for canceled context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckCertificateTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
checker := tlscheck.NewStandalone(
|
||||||
|
tlscheck.WithTimeout(1*time.Millisecond),
|
||||||
|
tlscheck.WithPort(1),
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := checker.CheckCertificate(
|
||||||
|
context.Background(),
|
||||||
|
"192.0.2.1",
|
||||||
|
"example.com",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckCertificateSANs(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv, ip, port := startTLSServer(t)
|
||||||
|
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
checker := tlscheck.NewStandalone(
|
||||||
|
tlscheck.WithTimeout(5*time.Second),
|
||||||
|
tlscheck.WithTLSConfig(&tls.Config{
|
||||||
|
//nolint:gosec // test uses self-signed cert
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}),
|
||||||
|
tlscheck.WithPort(port),
|
||||||
|
)
|
||||||
|
|
||||||
|
info, err := checker.CheckCertificate(
|
||||||
|
context.Background(), ip, "localhost",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.CommonName == "" && len(info.SubjectAlternativeNames) == 0 {
|
||||||
|
t.Error("expected CN or SANs to be populated")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ package watcher
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,7 +37,7 @@ type PortChecker interface {
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
address string,
|
address string,
|
||||||
port int,
|
port int,
|
||||||
) (bool, error)
|
) (*portcheck.PortResult, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSChecker inspects TLS certificates.
|
// TLSChecker inspects TLS certificates.
|
||||||
|
|||||||
@@ -477,7 +477,7 @@ func (w *Watcher) checkSinglePort(
|
|||||||
port int,
|
port int,
|
||||||
hostname string,
|
hostname string,
|
||||||
) {
|
) {
|
||||||
open, err := w.portCheck.CheckPort(ctx, ip, port)
|
result, err := w.portCheck.CheckPort(ctx, ip, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Error(
|
w.log.Error(
|
||||||
"port check failed",
|
"port check failed",
|
||||||
@@ -493,9 +493,9 @@ func (w *Watcher) checkSinglePort(
|
|||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
prev, hasPrev := w.state.GetPortState(key)
|
prev, hasPrev := w.state.GetPortState(key)
|
||||||
|
|
||||||
if hasPrev && !w.firstRun && prev.Open != open {
|
if hasPrev && !w.firstRun && prev.Open != result.Open {
|
||||||
stateStr := "closed"
|
stateStr := "closed"
|
||||||
if open {
|
if result.Open {
|
||||||
stateStr = "open"
|
stateStr = "open"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -513,7 +513,7 @@ func (w *Watcher) checkSinglePort(
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.state.SetPortState(key, &state.PortState{
|
w.state.SetPortState(key, &state.PortState{
|
||||||
Open: open,
|
Open: result.Open,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
LastChecked: now,
|
LastChecked: now,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sneak.berlin/go/dnswatcher/internal/config"
|
"sneak.berlin/go/dnswatcher/internal/config"
|
||||||
|
"sneak.berlin/go/dnswatcher/internal/portcheck"
|
||||||
"sneak.berlin/go/dnswatcher/internal/state"
|
"sneak.berlin/go/dnswatcher/internal/state"
|
||||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||||
"sneak.berlin/go/dnswatcher/internal/watcher"
|
"sneak.berlin/go/dnswatcher/internal/watcher"
|
||||||
@@ -109,24 +110,20 @@ func (m *mockPortChecker) CheckPort(
|
|||||||
_ context.Context,
|
_ context.Context,
|
||||||
address string,
|
address string,
|
||||||
port int,
|
port int,
|
||||||
) (bool, error) {
|
) (*portcheck.PortResult, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
m.calls++
|
m.calls++
|
||||||
|
|
||||||
if m.err != nil {
|
if m.err != nil {
|
||||||
return false, m.err
|
return nil, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s:%d", address, port)
|
key := fmt.Sprintf("%s:%d", address, port)
|
||||||
open, ok := m.results[key]
|
open := m.results[key]
|
||||||
|
|
||||||
if !ok {
|
return &portcheck.PortResult{Open: open}, nil
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return open, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockTLSChecker struct {
|
type mockTLSChecker struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user