dnswatcher/internal/tlscheck/tlscheck.go

187 lines
3.6 KiB
Go

// Package tlscheck provides TLS certificate inspection.
package tlscheck
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"time"
"go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/logger"
)
// Package defaults used by both New and NewStandalone. Only
// NewStandalone applies Options, so only standalone checkers can
// override them.
const (
	// defaultPort is the standard HTTPS port dialled on each target IP.
	defaultPort = 443
	// defaultTimeout is used as the net.Dialer timeout for each check;
	// per crypto/tls it bounds the TCP connect and the TLS handshake
	// together.
	defaultTimeout = 10 * time.Second
)
// ErrUnexpectedConnType is returned, wrapped with the target address,
// by CheckCertificate when the dialer yields a connection that is not a
// *tls.Conn. Match it with errors.Is.
var ErrUnexpectedConnType = errors.New("unexpected connection type")
// CertificateInfo summarises the leaf certificate a TLS server presented
// during a CertificateCheck handshake. When the server presented no
// certificates at all, every field is left at its zero value.
type CertificateInfo struct {
	// CommonName and Issuer are the subject's and the issuer's common
	// names respectively (CN attribute only, not the full DN).
	CommonName, Issuer string
	// NotAfter is the end of the certificate's validity period.
	NotAfter time.Time
	// SubjectAlternativeNames holds the DNS-name SANs only; IP, email
	// and URI SANs are not included.
	SubjectAlternativeNames []string
	// SerialNumber is the certificate serial number in decimal.
	SerialNumber string
}
// Option is a functional option that adjusts a Checker while
// NewStandalone builds it; see WithTimeout, WithTLSConfig and WithPort.
type Option func(*Checker)
// WithTimeout returns an Option that replaces the per-check dial timeout
// (defaultTimeout unless overridden). The value is handed to
// net.Dialer.Timeout, where zero means "no timeout", leaving only the
// caller's context to bound the check.
func WithTimeout(d time.Duration) Option {
	setTimeout := func(target *Checker) {
		target.timeout = d
	}

	return Option(setTimeout)
}
// WithTLSConfig returns an Option that installs cfg as the base TLS
// configuration. Each check works on a clone of cfg whose ServerName is
// overwritten with the SNI hostname passed to CheckCertificate. Passing
// nil restores the built-in default, which requires TLS 1.2 or newer.
func WithTLSConfig(cfg *tls.Config) Option {
	setConfig := func(target *Checker) {
		target.tlsConfig = cfg
	}

	return Option(setConfig)
}
// WithPort returns an Option that changes the TCP port dialled on each
// target IP (defaultPort unless overridden). The port is not validated
// here.
func WithPort(port int) Option {
	setPort := func(target *Checker) {
		target.port = port
	}

	return Option(setPort)
}
// Params contains dependencies for Checker. Embedding fx.In marks it as
// an fx parameter object, so fx populates each exported field from the
// dependency container when it calls New.
type Params struct {
	fx.In
	// Logger is the application logger; New calls Logger.Get to obtain
	// the *slog.Logger the Checker writes to.
	Logger *logger.Logger
}
// Checker performs TLS certificate inspection. Construct it with New
// (under fx) or NewStandalone rather than using the zero value: the zero
// value has a nil logger, which CheckCertificate dereferences when
// closing a connection fails.
type Checker struct {
	// log receives debug output, such as errors from closing a
	// connection.
	log *slog.Logger
	// timeout is passed as net.Dialer.Timeout for every check.
	timeout time.Duration
	// tlsConfig, when non-nil, is cloned as the base configuration for
	// every check; nil selects the built-in default (TLS 1.2 minimum).
	tlsConfig *tls.Config
	// port is the TCP port dialled on the target IP address.
	port int
}
// New builds a Checker for the fx dependency graph. The lifecycle
// parameter keeps the constructor in the usual fx shape but is unused;
// nothing here registers start or stop hooks. Options are not applied,
// so fx-built checkers always use defaultTimeout, defaultPort and the
// default TLS configuration. The returned error is always nil.
func New(
	_ fx.Lifecycle,
	params Params,
) (*Checker, error) {
	c := new(Checker)
	c.log = params.Logger.Get()
	c.timeout = defaultTimeout
	c.port = defaultPort

	return c, nil
}
// NewStandalone builds a Checker without fx. It starts from the package
// defaults (slog.Default logger, defaultTimeout, defaultPort, built-in
// TLS configuration) and then applies opts in order. When two options
// set the same field, the later one wins.
func NewStandalone(opts ...Option) *Checker {
	c := &Checker{log: slog.Default()}
	c.timeout, c.port = defaultTimeout, defaultPort

	for i := range opts {
		opts[i](c)
	}

	return c
}
// CheckCertificate dials ipAddress on the checker's port and performs a
// TLS handshake that presents sniHostname as the SNI server name. It
// returns a summary of the leaf certificate the server sent.
//
// The handshake uses the configuration from buildTLSConfig. With the
// default configuration, the certificate chain and hostname are verified
// against sniHostname, and a verification failure surfaces as a dial
// error. Both ctx and the checker's timeout bound the connect and
// handshake.
//
// Errors:
//   - dial and handshake failures are wrapped as "TLS dial to <addr>: ...";
//   - ErrUnexpectedConnType (wrapped) is returned if the dialer yields
//     anything other than *tls.Conn.
//
// The connection is always closed before returning. A close error is
// only logged at debug level.
func (c *Checker) CheckCertificate(
	ctx context.Context,
	ipAddress string,
	sniHostname string,
) (*CertificateInfo, error) {
	addr := net.JoinHostPort(ipAddress, strconv.Itoa(c.port))

	d := tls.Dialer{
		NetDialer: &net.Dialer{Timeout: c.timeout},
		Config:    c.buildTLSConfig(sniHostname),
	}

	raw, dialErr := d.DialContext(ctx, "tcp", addr)
	if dialErr != nil {
		return nil, fmt.Errorf("TLS dial to %s: %w", addr, dialErr)
	}

	// Closing is best-effort: the certificate data has already been read
	// by the time this runs, so a close failure is only worth a debug log.
	defer func() {
		if err := raw.Close(); err != nil {
			c.log.Debug(
				"closing TLS connection",
				"target", addr,
				"error", err.Error(),
			)
		}
	}()

	// tls.Dialer documents that it always returns *tls.Conn; the check
	// guards against that contract ever changing.
	tlsConn, isTLS := raw.(*tls.Conn)
	if !isTLS {
		return nil, fmt.Errorf("%s: %w", addr, ErrUnexpectedConnType)
	}

	return c.extractCertInfo(tlsConn), nil
}
// buildTLSConfig returns the *tls.Config for a single check, with
// ServerName always set to sniHostname (even when it is empty).
//
// A user-supplied config (see WithTLSConfig) is cloned, so the shared
// original is never mutated by concurrent checks. Without one, a fresh
// config requiring TLS 1.2 or newer is used. A user-supplied config's own
// MinVersion is left as the caller set it.
func (c *Checker) buildTLSConfig(
	sniHostname string,
) *tls.Config {
	var cfg *tls.Config
	if base := c.tlsConfig; base != nil {
		cfg = base.Clone()
	} else {
		cfg = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	cfg.ServerName = sniHostname

	return cfg
}
// extractCertInfo summarises the first peer certificate from conn's
// completed handshake. crypto/tls documents that certificate as the leaf
// the connection was verified against.
//
// If the peer presented no certificates, it returns a zero-valued
// *CertificateInfo rather than nil. The DNS SAN list is copied so the
// result does not alias the parsed certificate. The copy is always
// non-nil (possibly empty), so it encodes as [] rather than null in JSON.
func (c *Checker) extractCertInfo(
	conn *tls.Conn,
) *CertificateInfo {
	peers := conn.ConnectionState().PeerCertificates
	if len(peers) == 0 {
		return &CertificateInfo{}
	}

	leaf := peers[0]
	info := &CertificateInfo{
		CommonName:              leaf.Subject.CommonName,
		Issuer:                  leaf.Issuer.CommonName,
		NotAfter:                leaf.NotAfter,
		SubjectAlternativeNames: append([]string{}, leaf.DNSNames...),
		SerialNumber:            leaf.SerialNumber.String(),
	}

	return info
}