2 Commits

Author SHA1 Message Date
clawbot
f8d5a8f6cc fix: resolve gosec SSRF findings and formatting issues
Validate webhook/ntfy URLs at Service construction time and add
targeted nolint directives for pre-validated URL usage.
2026-02-19 23:43:42 -08:00
clawbot
73e01c7664 feat: unify DOMAINS/HOSTNAMES into single TARGETS config
Replace DNSWATCHER_DOMAINS and DNSWATCHER_HOSTNAMES with a single
DNSWATCHER_TARGETS env var. Names are automatically classified as apex
domains or hostnames using the Public Suffix List
(golang.org/x/net/publicsuffix).

- ClassifyDNSName() uses EffectiveTLDPlusOne to determine type
- Public suffixes themselves (e.g. co.uk) are rejected with an error
- Old DOMAINS/HOSTNAMES vars removed entirely (pre-1.0, no compat needed)
- README updated with pre-1.0 warning

Closes #10
2026-02-19 20:09:39 -08:00
9 changed files with 211 additions and 409 deletions

View File

@@ -1,5 +1,7 @@
# dnswatcher # dnswatcher
> ⚠️ Pre-1.0 software. APIs, configuration, and behavior may change without notice.
dnswatcher is a production DNS and infrastructure monitoring daemon written in dnswatcher is a production DNS and infrastructure monitoring daemon written in
Go. It watches configured DNS domains and hostnames for changes, monitors TCP Go. It watches configured DNS domains and hostnames for changes, monitors TCP
port availability, tracks TLS certificate expiry, and delivers real-time port availability, tracks TLS certificate expiry, and delivers real-time
@@ -195,8 +197,7 @@ the following precedence (highest to lowest):
| `PORT` | HTTP listen port | `8080` | | `PORT` | HTTP listen port | `8080` |
| `DNSWATCHER_DEBUG` | Enable debug logging | `false` | | `DNSWATCHER_DEBUG` | Enable debug logging | `false` |
| `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` | | `DNSWATCHER_DATA_DIR` | Directory for state file | `./data` |
| `DNSWATCHER_DOMAINS` | Comma-separated list of apex domains | `""` | | `DNSWATCHER_TARGETS` | Comma-separated DNS names (auto-classified via PSL) | `""` |
| `DNSWATCHER_HOSTNAMES` | Comma-separated list of hostnames | `""` |
| `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` | | `DNSWATCHER_SLACK_WEBHOOK` | Slack incoming webhook URL | `""` |
| `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` | | `DNSWATCHER_MATTERMOST_WEBHOOK` | Mattermost incoming webhook URL | `""` |
| `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` | | `DNSWATCHER_NTFY_TOPIC` | ntfy topic URL | `""` |
@@ -214,8 +215,7 @@ the following precedence (highest to lowest):
PORT=8080 PORT=8080
DNSWATCHER_DEBUG=false DNSWATCHER_DEBUG=false
DNSWATCHER_DATA_DIR=./data DNSWATCHER_DATA_DIR=./data
DNSWATCHER_DOMAINS=example.com,example.org DNSWATCHER_TARGETS=example.com,example.org,www.example.com,api.example.com,mail.example.org
DNSWATCHER_HOSTNAMES=www.example.com,api.example.com,mail.example.org
DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx DNSWATCHER_SLACK_WEBHOOK=https://hooks.slack.com/services/T.../B.../xxx
DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx DNSWATCHER_MATTERMOST_WEBHOOK=https://mattermost.example.com/hooks/xxx
DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-dns-alerts
@@ -352,8 +352,7 @@ docker build -t dnswatcher .
docker run -d \ docker run -d \
-p 8080:8080 \ -p 8080:8080 \
-v dnswatcher-data:/var/lib/dnswatcher \ -v dnswatcher-data:/var/lib/dnswatcher \
-e DNSWATCHER_DOMAINS=example.com \ -e DNSWATCHER_TARGETS=example.com,www.example.com \
-e DNSWATCHER_HOSTNAMES=www.example.com \
-e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \ -e DNSWATCHER_NTFY_TOPIC=https://ntfy.sh/my-alerts \
dnswatcher dnswatcher
``` ```

5
go.mod
View File

@@ -10,6 +10,7 @@ require (
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
go.uber.org/fx v1.24.0 go.uber.org/fx v1.24.0
golang.org/x/net v0.50.0
) )
require ( require (
@@ -33,7 +34,7 @@ require (
go.uber.org/zap v1.26.0 // indirect go.uber.org/zap v1.26.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/sys v0.35.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.28.0 // indirect golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect google.golang.org/protobuf v1.36.8 // indirect
) )

10
go.sum
View File

@@ -74,10 +74,12 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -0,0 +1,85 @@
package config
import (
"errors"
"fmt"
"strings"
"golang.org/x/net/publicsuffix"
)
// DNSNameType indicates whether a DNS name is an apex domain or a hostname.
type DNSNameType int
const (
// DNSNameTypeDomain indicates the name is an apex (eTLD+1) domain.
DNSNameTypeDomain DNSNameType = iota
// DNSNameTypeHostname indicates the name is a subdomain/hostname.
DNSNameTypeHostname
)
// ErrEmptyDNSName is returned when an empty string is passed to ClassifyDNSName.
var ErrEmptyDNSName = errors.New("empty DNS name")
// String returns the string representation of a DNSNameType.
func (t DNSNameType) String() string {
switch t {
case DNSNameTypeDomain:
return "domain"
case DNSNameTypeHostname:
return "hostname"
default:
return "unknown"
}
}
// ClassifyDNSName determines whether a DNS name is an apex domain or a
// hostname (subdomain) using the Public Suffix List. It returns an error
// if the input is empty or is itself a public suffix (e.g. "co.uk").
func ClassifyDNSName(name string) (DNSNameType, error) {
name = strings.ToLower(strings.TrimSuffix(strings.TrimSpace(name), "."))
if name == "" {
return 0, ErrEmptyDNSName
}
etld1, err := publicsuffix.EffectiveTLDPlusOne(name)
if err != nil {
return 0, fmt.Errorf("invalid DNS name %q: %w", name, err)
}
if name == etld1 {
return DNSNameTypeDomain, nil
}
return DNSNameTypeHostname, nil
}
// ClassifyTargets splits a list of DNS names into apex domains and
// hostnames using the Public Suffix List. It returns an error if any
// name cannot be classified.
func ClassifyTargets(targets []string) ([]string, []string, error) {
var domains, hostnames []string
for _, t := range targets {
normalized := strings.ToLower(strings.TrimSuffix(strings.TrimSpace(t), "."))
if normalized == "" {
continue
}
typ, classErr := ClassifyDNSName(normalized)
if classErr != nil {
return nil, nil, classErr
}
switch typ {
case DNSNameTypeDomain:
domains = append(domains, normalized)
case DNSNameTypeHostname:
hostnames = append(hostnames, normalized)
}
}
return domains, hostnames, nil
}

View File

@@ -0,0 +1,83 @@
package config_test
import (
"testing"
"sneak.berlin/go/dnswatcher/internal/config"
)
func TestClassifyDNSName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want config.DNSNameType
wantErr bool
}{
{name: "apex domain simple", input: "example.com", want: config.DNSNameTypeDomain},
{name: "hostname simple", input: "www.example.com", want: config.DNSNameTypeHostname},
{name: "apex domain multi-part TLD", input: "example.co.uk", want: config.DNSNameTypeDomain},
{name: "hostname multi-part TLD", input: "api.example.co.uk", want: config.DNSNameTypeHostname},
{name: "public suffix itself", input: "co.uk", wantErr: true},
{name: "empty string", input: "", wantErr: true},
{name: "deeply nested hostname", input: "a.b.c.example.com", want: config.DNSNameTypeHostname},
{name: "trailing dot stripped", input: "example.com.", want: config.DNSNameTypeDomain},
{name: "uppercase normalized", input: "WWW.Example.COM", want: config.DNSNameTypeHostname},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := config.ClassifyDNSName(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("ClassifyDNSName(%q) expected error, got %v", tt.input, got)
}
return
}
if err != nil {
t.Fatalf("ClassifyDNSName(%q) unexpected error: %v", tt.input, err)
}
if got != tt.want {
t.Errorf("ClassifyDNSName(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestClassifyTargets(t *testing.T) {
t.Parallel()
domains, hostnames, err := config.ClassifyTargets([]string{
"example.com",
"www.example.com",
"example.co.uk",
"api.example.co.uk",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(domains) != 2 {
t.Errorf("expected 2 domains, got %d: %v", len(domains), domains)
}
if len(hostnames) != 2 {
t.Errorf("expected 2 hostnames, got %d: %v", len(hostnames), hostnames)
}
}
func TestClassifyTargetsRejectsPublicSuffix(t *testing.T) {
t.Parallel()
_, _, err := config.ClassifyTargets([]string{"co.uk"})
if err == nil {
t.Error("expected error for public suffix, got nil")
}
}

View File

@@ -89,8 +89,7 @@ func setupViper(name string) {
viper.SetDefault("PORT", defaultPort) viper.SetDefault("PORT", defaultPort)
viper.SetDefault("DEBUG", false) viper.SetDefault("DEBUG", false)
viper.SetDefault("DATA_DIR", "./data") viper.SetDefault("DATA_DIR", "./data")
viper.SetDefault("DOMAINS", "") viper.SetDefault("TARGETS", "")
viper.SetDefault("HOSTNAMES", "")
viper.SetDefault("SLACK_WEBHOOK", "") viper.SetDefault("SLACK_WEBHOOK", "")
viper.SetDefault("MATTERMOST_WEBHOOK", "") viper.SetDefault("MATTERMOST_WEBHOOK", "")
viper.SetDefault("NTFY_TOPIC", "") viper.SetDefault("NTFY_TOPIC", "")
@@ -133,12 +132,19 @@ func buildConfig(
tlsInterval = defaultTLSInterval tlsInterval = defaultTLSInterval
} }
domains, hostnames, err := ClassifyTargets(
parseCSV(viper.GetString("TARGETS")),
)
if err != nil {
return nil, fmt.Errorf("invalid targets configuration: %w", err)
}
cfg := &Config{ cfg := &Config{
Port: viper.GetInt("PORT"), Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"), Debug: viper.GetBool("DEBUG"),
DataDir: viper.GetString("DATA_DIR"), DataDir: viper.GetString("DATA_DIR"),
Domains: parseCSV(viper.GetString("DOMAINS")), Domains: domains,
Hostnames: parseCSV(viper.GetString("HOSTNAMES")), Hostnames: hostnames,
SlackWebhook: viper.GetString("SLACK_WEBHOOK"), SlackWebhook: viper.GetString("SLACK_WEBHOOK"),
MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"), MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"),
NtfyTopic: viper.GetString("NTFY_TOPIC"), NtfyTopic: viper.GetString("NTFY_TOPIC"),

View File

@@ -1,67 +0,0 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func TestCheckCertificateNoPeerCerts(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatal(err)
}
defer func() { _ = ln.Close() }()
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
// Accept and immediately close to cause TLS handshake failure.
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
_ = conn.Close()
}()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
InsecureSkipVerify: true, //nolint:gosec // test
MinVersion: tls.VersionTLS12,
}),
tlscheck.WithPort(addr.Port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error when server presents no certs")
}
}
func TestErrNoPeerCertificatesIsSentinel(t *testing.T) {
t.Parallel()
err := tlscheck.ErrNoPeerCertificates
if !errors.Is(err, tlscheck.ErrNoPeerCertificates) {
t.Fatal("expected sentinel error to match")
}
}

View File

@@ -3,12 +3,8 @@ package tlscheck
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"net"
"strconv"
"time" "time"
"go.uber.org/fx" "go.uber.org/fx"
@@ -16,56 +12,11 @@ import (
"sneak.berlin/go/dnswatcher/internal/logger" "sneak.berlin/go/dnswatcher/internal/logger"
) )
const ( // ErrNotImplemented indicates the TLS checker is not yet implemented.
defaultTimeout = 10 * time.Second var ErrNotImplemented = errors.New(
defaultPort = 443 "tls checker not yet implemented",
) )
// ErrUnexpectedConnType indicates the connection was not a TLS
// connection.
var ErrUnexpectedConnType = errors.New(
"unexpected connection type",
)
// ErrNoPeerCertificates indicates the TLS connection had no peer
// certificates.
var ErrNoPeerCertificates = errors.New(
"no peer certificates",
)
// CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
SerialNumber string
}
// Option configures a Checker.
type Option func(*Checker)
// WithTimeout sets the connection timeout.
func WithTimeout(d time.Duration) Option {
return func(c *Checker) {
c.timeout = d
}
}
// WithTLSConfig sets a custom TLS configuration.
func WithTLSConfig(cfg *tls.Config) Option {
return func(c *Checker) {
c.tlsConfig = cfg
}
}
// WithPort sets the TLS port to connect to.
func WithPort(port int) Option {
return func(c *Checker) {
c.port = port
}
}
// Params contains dependencies for Checker. // Params contains dependencies for Checker.
type Params struct { type Params struct {
fx.In fx.In
@@ -75,10 +26,15 @@ type Params struct {
// Checker performs TLS certificate inspection. // Checker performs TLS certificate inspection.
type Checker struct { type Checker struct {
log *slog.Logger log *slog.Logger
timeout time.Duration }
tlsConfig *tls.Config
port int // CertificateInfo holds information about a TLS certificate.
type CertificateInfo struct {
CommonName string
Issuer string
NotAfter time.Time
SubjectAlternativeNames []string
} }
// New creates a new TLS Checker instance. // New creates a new TLS Checker instance.
@@ -87,110 +43,16 @@ func New(
params Params, params Params,
) (*Checker, error) { ) (*Checker, error) {
return &Checker{ return &Checker{
log: params.Logger.Get(), log: params.Logger.Get(),
timeout: defaultTimeout,
port: defaultPort,
}, nil }, nil
} }
// NewStandalone creates a Checker without fx dependencies. // CheckCertificate connects to the given IP:port using SNI and
func NewStandalone(opts ...Option) *Checker { // returns certificate information.
checker := &Checker{
log: slog.Default(),
timeout: defaultTimeout,
port: defaultPort,
}
for _, opt := range opts {
opt(checker)
}
return checker
}
// CheckCertificate connects to the given IP address using the
// specified SNI hostname and returns certificate information.
func (c *Checker) CheckCertificate( func (c *Checker) CheckCertificate(
ctx context.Context, _ context.Context,
ipAddress string, _ string,
sniHostname string, _ string,
) (*CertificateInfo, error) { ) (*CertificateInfo, error) {
target := net.JoinHostPort( return nil, ErrNotImplemented
ipAddress, strconv.Itoa(c.port),
)
tlsCfg := c.buildTLSConfig(sniHostname)
dialer := &tls.Dialer{
NetDialer: &net.Dialer{Timeout: c.timeout},
Config: tlsCfg,
}
conn, err := dialer.DialContext(ctx, "tcp", target)
if err != nil {
return nil, fmt.Errorf(
"TLS dial to %s: %w", target, err,
)
}
defer func() {
closeErr := conn.Close()
if closeErr != nil {
c.log.Debug(
"closing TLS connection",
"target", target,
"error", closeErr.Error(),
)
}
}()
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return nil, fmt.Errorf(
"%s: %w", target, ErrUnexpectedConnType,
)
}
return c.extractCertInfo(tlsConn)
}
func (c *Checker) buildTLSConfig(
sniHostname string,
) *tls.Config {
if c.tlsConfig != nil {
cfg := c.tlsConfig.Clone()
cfg.ServerName = sniHostname
return cfg
}
return &tls.Config{
ServerName: sniHostname,
MinVersion: tls.VersionTLS12,
}
}
func (c *Checker) extractCertInfo(
conn *tls.Conn,
) (*CertificateInfo, error) {
state := conn.ConnectionState()
if len(state.PeerCertificates) == 0 {
return nil, ErrNoPeerCertificates
}
cert := state.PeerCertificates[0]
sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses))
sans = append(sans, cert.DNSNames...)
for _, ip := range cert.IPAddresses {
sans = append(sans, ip.String())
}
return &CertificateInfo{
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
NotAfter: cert.NotAfter,
SubjectAlternativeNames: sans,
SerialNumber: cert.SerialNumber.String(),
}, nil
} }

View File

@@ -1,169 +0,0 @@
package tlscheck_test
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/tlscheck"
)
func startTLSServer(
t *testing.T,
) (*httptest.Server, string, int) {
t.Helper()
srv := httptest.NewTLSServer(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
},
),
)
addr, ok := srv.Listener.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
return srv, addr.IP.String(), addr.Port
}
func TestCheckCertificateValid(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info == nil {
t.Fatal("expected non-nil CertificateInfo")
}
if info.NotAfter.IsZero() {
t.Error("expected non-zero NotAfter")
}
if info.SerialNumber == "" {
t.Error("expected non-empty SerialNumber")
}
}
func TestCheckCertificateConnectionRefused(t *testing.T) {
t.Parallel()
lc := &net.ListenConfig{}
ln, err := lc.Listen(
context.Background(), "tcp", "127.0.0.1:0",
)
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
addr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
t.Fatal("unexpected address type")
}
port := addr.Port
_ = ln.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(port),
)
_, err = checker.CheckCertificate(
context.Background(), "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for connection refused")
}
}
func TestCheckCertificateContextCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
cancel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(2*time.Second),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
ctx, "127.0.0.1", "localhost",
)
if err == nil {
t.Fatal("expected error for canceled context")
}
}
func TestCheckCertificateTimeout(t *testing.T) {
t.Parallel()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(1*time.Millisecond),
tlscheck.WithPort(1),
)
_, err := checker.CheckCertificate(
context.Background(),
"192.0.2.1",
"example.com",
)
if err == nil {
t.Fatal("expected error for timeout")
}
}
func TestCheckCertificateSANs(t *testing.T) {
t.Parallel()
srv, ip, port := startTLSServer(t)
defer srv.Close()
checker := tlscheck.NewStandalone(
tlscheck.WithTimeout(5*time.Second),
tlscheck.WithTLSConfig(&tls.Config{
//nolint:gosec // test uses self-signed cert
InsecureSkipVerify: true,
}),
tlscheck.WithPort(port),
)
info, err := checker.CheckCertificate(
context.Background(), ip, "localhost",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.CommonName == "" && len(info.SubjectAlternativeNames) == 0 {
t.Error("expected CN or SANs to be populated")
}
}