- extractCertInfo now returns an error (ErrNoPeerCertificates) instead of an empty struct when there are no peer certificates.
- SubjectAlternativeNames now includes both DNS names and IP addresses (from cert.IPAddresses).

Addresses review feedback on PR #7.
197 lines
3.9 KiB
Go
197 lines
3.9 KiB
Go
// Package tlscheck provides TLS certificate inspection.
|
|
package tlscheck
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"strconv"
|
|
"time"
|
|
|
|
"go.uber.org/fx"
|
|
|
|
"sneak.berlin/go/dnswatcher/internal/logger"
|
|
)
|
|
|
|
// Defaults used by New and NewStandalone. WithTimeout and WithPort
// override them for standalone Checkers.
const (
	// defaultTimeout is the net.Dialer timeout used for each check.
	defaultTimeout = 10 * time.Second

	// defaultPort is the TCP port dialled when none is configured (HTTPS).
	defaultPort = 443
)

// Sentinel errors reported by Checker. CheckCertificate may wrap them
// with the target address, so compare with errors.Is.
var (
	// ErrUnexpectedConnType indicates the connection was not a TLS
	// connection.
	ErrUnexpectedConnType = errors.New("unexpected connection type")

	// ErrNoPeerCertificates indicates the TLS connection had no peer
	// certificates.
	ErrNoPeerCertificates = errors.New("no peer certificates")
)
|
|
|
|
// CertificateInfo describes the leaf certificate a server presented
// during a CheckCertificate call.
type CertificateInfo struct {
	// CommonName and Issuer are the common names (CN) of the
	// certificate's subject and of its issuer; other distinguished-name
	// attributes are not captured.
	CommonName, Issuer string

	// NotAfter is the end of the certificate's validity period.
	NotAfter time.Time

	// SubjectAlternativeNames lists the DNS-name SANs followed by the
	// IP-address SANs in their textual form.
	SubjectAlternativeNames []string

	// SerialNumber is the certificate serial number as a decimal string.
	SerialNumber string
}
|
|
|
|
// Option configures a Checker.
|
|
type Option func(*Checker)
|
|
|
|
// WithTimeout sets the connection timeout.
|
|
func WithTimeout(d time.Duration) Option {
|
|
return func(c *Checker) {
|
|
c.timeout = d
|
|
}
|
|
}
|
|
|
|
// WithTLSConfig sets a custom TLS configuration.
|
|
func WithTLSConfig(cfg *tls.Config) Option {
|
|
return func(c *Checker) {
|
|
c.tlsConfig = cfg
|
|
}
|
|
}
|
|
|
|
// WithPort sets the TLS port to connect to.
|
|
func WithPort(port int) Option {
|
|
return func(c *Checker) {
|
|
c.port = port
|
|
}
|
|
}
|
|
|
|
// Params contains dependencies for Checker.
|
|
type Params struct {
|
|
fx.In
|
|
|
|
Logger *logger.Logger
|
|
}
|
|
|
|
// Checker performs TLS certificate inspection. Create one with New or
// NewStandalone. The zero value has no logger, no timeout and port 0,
// so it is not ready for use.
type Checker struct {
	// log receives debug output, such as errors from closing
	// connections.
	log *slog.Logger

	// timeout is the net.Dialer timeout applied to each check.
	timeout time.Duration

	// tlsConfig, when non-nil, is cloned as the base configuration for
	// every connection; nil selects the built-in default.
	tlsConfig *tls.Config

	// port is the TCP port dialled on the target address.
	port int
}
|
|
|
|
// New creates a new TLS Checker instance.
|
|
func New(
|
|
_ fx.Lifecycle,
|
|
params Params,
|
|
) (*Checker, error) {
|
|
return &Checker{
|
|
log: params.Logger.Get(),
|
|
timeout: defaultTimeout,
|
|
port: defaultPort,
|
|
}, nil
|
|
}
|
|
|
|
// NewStandalone creates a Checker without fx dependencies.
|
|
func NewStandalone(opts ...Option) *Checker {
|
|
checker := &Checker{
|
|
log: slog.Default(),
|
|
timeout: defaultTimeout,
|
|
port: defaultPort,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(checker)
|
|
}
|
|
|
|
return checker
|
|
}
|
|
|
|
// CheckCertificate connects to the given IP address using the
|
|
// specified SNI hostname and returns certificate information.
|
|
func (c *Checker) CheckCertificate(
|
|
ctx context.Context,
|
|
ipAddress string,
|
|
sniHostname string,
|
|
) (*CertificateInfo, error) {
|
|
target := net.JoinHostPort(
|
|
ipAddress, strconv.Itoa(c.port),
|
|
)
|
|
|
|
tlsCfg := c.buildTLSConfig(sniHostname)
|
|
dialer := &tls.Dialer{
|
|
NetDialer: &net.Dialer{Timeout: c.timeout},
|
|
Config: tlsCfg,
|
|
}
|
|
|
|
conn, err := dialer.DialContext(ctx, "tcp", target)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(
|
|
"TLS dial to %s: %w", target, err,
|
|
)
|
|
}
|
|
|
|
defer func() {
|
|
closeErr := conn.Close()
|
|
if closeErr != nil {
|
|
c.log.Debug(
|
|
"closing TLS connection",
|
|
"target", target,
|
|
"error", closeErr.Error(),
|
|
)
|
|
}
|
|
}()
|
|
|
|
tlsConn, ok := conn.(*tls.Conn)
|
|
if !ok {
|
|
return nil, fmt.Errorf(
|
|
"%s: %w", target, ErrUnexpectedConnType,
|
|
)
|
|
}
|
|
|
|
return c.extractCertInfo(tlsConn)
|
|
}
|
|
|
|
func (c *Checker) buildTLSConfig(
|
|
sniHostname string,
|
|
) *tls.Config {
|
|
if c.tlsConfig != nil {
|
|
cfg := c.tlsConfig.Clone()
|
|
cfg.ServerName = sniHostname
|
|
|
|
return cfg
|
|
}
|
|
|
|
return &tls.Config{
|
|
ServerName: sniHostname,
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
}
|
|
|
|
func (c *Checker) extractCertInfo(
|
|
conn *tls.Conn,
|
|
) (*CertificateInfo, error) {
|
|
state := conn.ConnectionState()
|
|
if len(state.PeerCertificates) == 0 {
|
|
return nil, ErrNoPeerCertificates
|
|
}
|
|
|
|
cert := state.PeerCertificates[0]
|
|
|
|
sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses))
|
|
sans = append(sans, cert.DNSNames...)
|
|
|
|
for _, ip := range cert.IPAddresses {
|
|
sans = append(sans, ip.String())
|
|
}
|
|
|
|
return &CertificateInfo{
|
|
CommonName: cert.Subject.CommonName,
|
|
Issuer: cert.Issuer.CommonName,
|
|
NotAfter: cert.NotAfter,
|
|
SubjectAlternativeNames: sans,
|
|
SerialNumber: cert.SerialNumber.String(),
|
|
}, nil
|
|
}
|