Compare commits

...

4 Commits

Author SHA1 Message Date
clawbot
3c32971e11 test: add tests for no-peer-certificates error path 2026-02-19 23:55:17 -08:00
user
cc49207e27 fix: return error for no peer certs, include IP SANs
- extractCertInfo now returns an error (ErrNoPeerCertificates) instead
  of an empty struct when there are no peer certificates
- SubjectAlternativeNames now includes both DNS names and IP addresses
  from cert.IPAddresses

Addresses review feedback on PR #7.
2026-02-19 23:51:55 -08:00
clawbot
72dda7bd80 fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
Fix goimports formatting in tlscheck_test.go.
2026-02-19 23:44:06 -08:00
clawbot
7097f66506 feat: implement TLS certificate inspector (closes #4) 2026-02-19 13:45:47 -08:00
4 changed files with 451 additions and 40 deletions

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@ -48,6 +49,9 @@ type Service struct {
log *slog.Logger log *slog.Logger
client *http.Client client *http.Client
config *config.Config config *config.Config
ntfyURL *url.URL
slackWebhookURL *url.URL
mattermostWebhookURL *url.URL
} }
// New creates a new notify Service. // New creates a new notify Service.
@ -55,13 +59,44 @@ func New(
_ fx.Lifecycle, _ fx.Lifecycle,
params Params, params Params,
) (*Service, error) { ) (*Service, error) {
return &Service{ svc := &Service{
log: params.Logger.Get(), log: params.Logger.Get(),
client: &http.Client{ client: &http.Client{
Timeout: httpClientTimeout, Timeout: httpClientTimeout,
}, },
config: params.Config, config: params.Config,
}, nil }
if params.Config.NtfyTopic != "" {
u, err := url.ParseRequestURI(params.Config.NtfyTopic)
if err != nil {
return nil, fmt.Errorf("invalid ntfy topic URL: %w", err)
}
svc.ntfyURL = u
}
if params.Config.SlackWebhook != "" {
u, err := url.ParseRequestURI(params.Config.SlackWebhook)
if err != nil {
return nil, fmt.Errorf("invalid slack webhook URL: %w", err)
}
svc.slackWebhookURL = u
}
if params.Config.MattermostWebhook != "" {
u, err := url.ParseRequestURI(params.Config.MattermostWebhook)
if err != nil {
return nil, fmt.Errorf(
"invalid mattermost webhook URL: %w", err,
)
}
svc.mattermostWebhookURL = u
}
return svc, nil
} }
// SendNotification sends a notification to all configured endpoints. // SendNotification sends a notification to all configured endpoints.
@ -69,13 +104,13 @@ func (svc *Service) SendNotification(
ctx context.Context, ctx context.Context,
title, message, priority string, title, message, priority string,
) { ) {
if svc.config.NtfyTopic != "" { if svc.ntfyURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendNtfy( err := svc.sendNtfy(
notifyCtx, notifyCtx,
svc.config.NtfyTopic, svc.ntfyURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@ -87,13 +122,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.SlackWebhook != "" { if svc.slackWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.SlackWebhook, svc.slackWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@ -105,13 +140,13 @@ func (svc *Service) SendNotification(
}() }()
} }
if svc.config.MattermostWebhook != "" { if svc.mattermostWebhookURL != nil {
go func() { go func() {
notifyCtx := context.WithoutCancel(ctx) notifyCtx := context.WithoutCancel(ctx)
err := svc.sendSlack( err := svc.sendSlack(
notifyCtx, notifyCtx,
svc.config.MattermostWebhook, svc.mattermostWebhookURL,
title, message, priority, title, message, priority,
) )
if err != nil { if err != nil {
@ -126,18 +161,19 @@ func (svc *Service) SendNotification(
func (svc *Service) sendNtfy( func (svc *Service) sendNtfy(
ctx context.Context, ctx context.Context,
topic, title, message, priority string, topicURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending ntfy notification", "sending ntfy notification",
"topic", topic, "topic", topicURL.String(),
"title", title, "title", title,
) )
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
topic, topicURL.String(),
bytes.NewBufferString(message), bytes.NewBufferString(message),
) )
if err != nil { if err != nil {
@ -147,7 +183,7 @@ func (svc *Service) sendNtfy(
request.Header.Set("Title", title) request.Header.Set("Title", title)
request.Header.Set("Priority", ntfyPriority(priority)) request.Header.Set("Priority", ntfyPriority(priority))
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending ntfy request: %w", err) return fmt.Errorf("sending ntfy request: %w", err)
} }
@ -193,11 +229,12 @@ type SlackAttachment struct {
func (svc *Service) sendSlack( func (svc *Service) sendSlack(
ctx context.Context, ctx context.Context,
webhookURL, title, message, priority string, webhookURL *url.URL,
title, message, priority string,
) error { ) error {
svc.log.Debug( svc.log.Debug(
"sending webhook notification", "sending webhook notification",
"url", webhookURL, "url", webhookURL.String(),
"title", title, "title", title,
) )
@ -219,7 +256,7 @@ func (svc *Service) sendSlack(
request, err := http.NewRequestWithContext( request, err := http.NewRequestWithContext(
ctx, ctx,
http.MethodPost, http.MethodPost,
webhookURL, webhookURL.String(),
bytes.NewBuffer(body), bytes.NewBuffer(body),
) )
if err != nil { if err != nil {
@ -228,7 +265,7 @@ func (svc *Service) sendSlack(
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
resp, err := svc.client.Do(request) resp, err := svc.client.Do(request) //nolint:gosec // URL validated at Service construction time
if err != nil { if err != nil {
return fmt.Errorf("sending webhook request: %w", err) return fmt.Errorf("sending webhook request: %w", err)
} }

View File

@ -0,0 +1,67 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func TestCheckCertificateNoPeerCerts(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = ln.Close() }()
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
// Accept and immediately close to cause TLS handshake failure.
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
_ = conn.Close()
}()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
InsecureSkipVerify: true, //nolint:gosec // test
MinVersion: tls.VersionTLS12,
}),
tlscheck.WithPort(addr.Port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error when server presents no certs")
}
}
func TestErrNoPeerCertificatesIsSentinel(t *testing.T) {
t.Parallel()
err := tlscheck.ErrNoPeerCertificates
if !errors.Is(err, tlscheck.ErrNoPeerCertificates) {
t.Fatal("expected sentinel error to match")
}
}

View File

@ -3,8 +3,12 @@ package tlscheck
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"net"
"strconv"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@ -12,11 +16,56 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
// ErrNotImplemented indicates the TLS checker is not yet implemented. const (
var ErrNotImplemented = errors.New( defaultTimeout = 10 * time.Second
"tls checker not yet implemented", defaultPort = 443
) )
// ErrUnexpectedConnType indicates the connection was not a TLS
// connection.
var ErrUnexpectedConnType = errors.New(
"unexpected connection type",
)
// ErrNoPeerCertificates indicates the TLS connection had no peer
// certificates.
var ErrNoPeerCertificates = errors.New(
"no peer certificates",
)
// CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
SerialNumber string
}
// Option configures a Checker.
type Option func(*Checker)
// WithTimeout sets the connection timeout.
func WithTimeout(d time.Duration) Option {
return func(c *Checker) {
c.timeout = d
}
}
// WithTLSConfig sets a custom TLS configuration.
func WithTLSConfig(cfg *tls.Config) Option {
return func(c *Checker) {
c.tlsConfig = cfg
}
}
// WithPort sets the TLS port to connect to.
func WithPort(port int) Option {
return func(c *Checker) {
c.port = port
}
}
// Params contains dependencies for Checker. // Params contains dependencies for Checker.
type Params struct { type Params struct {
fx.In fx.In
@ -27,14 +76,9 @@ type Params struct {
// Checker performs TLS certificate inspection. // Checker performs TLS certificate inspection.
type Checker struct { type Checker struct {
log *slog.Logger log *slog.Logger
} timeout time.Duration
tlsConfig *tls.Config
// CertificateInfo holds information about a TLS certificate. port int
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
} }
// New creates a new TLS Checker instance. // New creates a new TLS Checker instance.
@ -44,15 +88,109 @@ func New(
) (*Checker, error) { ) (*Checker, error) {
return &Checker{ return &Checker{
log: params.Logger.Get(), log: params.Logger.Get(),
timeout: defaultTimeout,
port: defaultPort,
}, nil }, nil
} }
// CheckCertificate connects to the given IP:port using SNI and // NewStandalone creates a Checker without fx dependencies.
// returns certificate information. func NewStandalone(opts ...Option) *Checker {
func (c *Checker) CheckCertificate( checker := &Checker{
_ context.Context, log: slog.Default(),
_ string, timeout: defaultTimeout,
_ string, port: defaultPort,
) (*CertificateInfo, error) { }
return nil, ErrNotImplemented
for _, opt := range opts {
opt(checker)
}
return checker
}
// CheckCertificate connects to the given IP address using the
// specified SNI hostname and returns certificate information.
func (c *Checker) CheckCertificate(
ctx context.Context,
ipAddress string,
sniHostname string,
) (*CertificateInfo, error) {
target := net.JoinHostPort(
ipAddress, strconv.Itoa(c.port),
)
tlsCfg := c.buildTLSConfig(sniHostname)
dialer := &tls.Dialer{
NetDialer: &net.Dialer{Timeout: c.timeout},
Config: tlsCfg,
}
conn, err := dialer.DialContext(ctx, "tcp", target)
if err != nil {
return nil, fmt.Errorf(
"TLS dial to %s: %w", target, err,
)
}
defer func() {
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing TLS connection",
"target", target,
"error", closeErr.Error(),
)
}
}()
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return nil, fmt.Errorf(
"%s: %w", target, ErrUnexpectedConnType,
)
}
return c.extractCertInfo(tlsConn)
}
func (c *Checker) buildTLSConfig(
sniHostname string,
) *tls.Config {
if c.tlsConfig != nil {
cfg := c.tlsConfig.Clone()
cfg.ServerName = sniHostname
return cfg
}
return &tls.Config{
ServerName: sniHostname,
MinVersion: tls.VersionTLS12,
}
}
func (c *Checker) extractCertInfo(
conn *tls.Conn,
) (*CertificateInfo, error) {
state := conn.ConnectionState()
if len(state.PeerCertificates) == 0 {
return nil, ErrNoPeerCertificates
}
cert := state.PeerCertificates[0]
sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses))
sans = append(sans, cert.DNSNames...)
for _, ip := range cert.IPAddresses {
sans = append(sans, ip.String())
}
return &CertificateInfo{
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: sans,
SerialNumber: cert.SerialNumber.String(),
}, nil
} }

View File

@ -0,0 +1,169 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func startTLSServer(
t *testing.T,
) (*httptest.Server, string, int) {
t.Helper()
srv := httptest.NewTLSServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
},
),
)
addr, ok := srv.Listener.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return srv, addr.IP.String(), addr.Port
}
func TestCheckCertificateValid(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info == nil {
t.Fatal("expected non-nil CertificateInfo")
}
if info.NotAfter.IsZero() {
t.Error("expected non-zero NotAfter")
}
if info.SerialNumber == "" {
t.Error("expected non-empty SerialNumber")
}
}
func TestCheckCertificateConnectionRefused(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
port := addr.Port
_ = ln.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for connection refused")
}
}
func TestCheckCertificateContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
ctx, "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for canceled context")
}
}
func TestCheckCertificateTimeout(t *testing.T) {
t.Parallel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(1*time.Millisecond),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
context.Background(),
"192.0.2.1",
"example.com",
)
if err == nil {
t.Fatal("expected error for timeout")
}
}
func TestCheckCertificateSANs(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.CommonName == "" && len(info.SubjectAlternativeNames) == 0 {
t.Error("expected CN or SANs to be populated")
}
}