// Package tlscheck provides TLS certificate inspection. package tlscheck import ( "context" "crypto/tls" "errors" "fmt" "log/slog" "net" "strconv" "time" "go.uber.org/fx" "sneak.berlin/go/dnswatcher/internal/logger" ) const ( defaultTimeout = 10 * time.Second defaultPort = 443 ) // ErrUnexpectedConnType indicates the connection was not a TLS // connection. var ErrUnexpectedConnType = errors.New( "unexpected connection type", ) // CertificateInfo holds information about a TLS certificate. type CertificateInfo struct { CommonName string Issuer string NotAfter time.Time SubjectAlternativeNames []string SerialNumber string } // Option configures a Checker. type Option func(*Checker) // WithTimeout sets the connection timeout. func WithTimeout(d time.Duration) Option { return func(c *Checker) { c.timeout = d } } // WithTLSConfig sets a custom TLS configuration. func WithTLSConfig(cfg *tls.Config) Option { return func(c *Checker) { c.tlsConfig = cfg } } // WithPort sets the TLS port to connect to. func WithPort(port int) Option { return func(c *Checker) { c.port = port } } // Params contains dependencies for Checker. type Params struct { fx.In Logger *logger.Logger } // Checker performs TLS certificate inspection. type Checker struct { log *slog.Logger timeout time.Duration tlsConfig *tls.Config port int } // New creates a new TLS Checker instance. func New( _ fx.Lifecycle, params Params, ) (*Checker, error) { return &Checker{ log: params.Logger.Get(), timeout: defaultTimeout, port: defaultPort, }, nil } // NewStandalone creates a Checker without fx dependencies. func NewStandalone(opts ...Option) *Checker { checker := &Checker{ log: slog.Default(), timeout: defaultTimeout, port: defaultPort, } for _, opt := range opts { opt(checker) } return checker } // CheckCertificate connects to the given IP address using the // specified SNI hostname and returns certificate information. func (c *Checker) CheckCertificate( ctx context.Context, ipAddress string, sniHostname string, ) (*CertificateInfo, error) { target := net.JoinHostPort( ipAddress, strconv.Itoa(c.port), ) tlsCfg := c.buildTLSConfig(sniHostname) dialer := &tls.Dialer{ NetDialer: &net.Dialer{Timeout: c.timeout}, Config: tlsCfg, } conn, err := dialer.DialContext(ctx, "tcp", target) if err != nil { return nil, fmt.Errorf( "TLS dial to %s: %w", target, err, ) } defer func() { closeErr := conn.Close() if closeErr != nil { c.log.Debug( "closing TLS connection", "target", target, "error", closeErr.Error(), ) } }() tlsConn, ok := conn.(*tls.Conn) if !ok { return nil, fmt.Errorf( "%s: %w", target, ErrUnexpectedConnType, ) } return c.extractCertInfo(tlsConn), nil } func (c *Checker) buildTLSConfig( sniHostname string, ) *tls.Config { if c.tlsConfig != nil { cfg := c.tlsConfig.Clone() cfg.ServerName = sniHostname return cfg } return &tls.Config{ ServerName: sniHostname, MinVersion: tls.VersionTLS12, } } func (c *Checker) extractCertInfo( conn *tls.Conn, ) *CertificateInfo { state := conn.ConnectionState() if len(state.PeerCertificates) == 0 { return &CertificateInfo{} } cert := state.PeerCertificates[0] sans := make([]string, len(cert.DNSNames)) copy(sans, cert.DNSNames) return &CertificateInfo{ CommonName: cert.Subject.CommonName, Issuer: cert.Issuer.CommonName, NotAfter: cert.NotAfter, SubjectAlternativeNames: sans, SerialNumber: cert.SerialNumber.String(), } }