feat: implement TLS certificate inspector (closes #4) #7
@ -3,8 +3,12 @@ package tlscheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go.uber.org/fx"
|
||||
@ -12,11 +16,50 @@ import (
|
||||
"sneak.berlin/go/dnswatcher/internal/logger"
|
||||
)
|
||||
|
||||
// ErrNotImplemented indicates the TLS checker is not yet implemented.
|
||||
var ErrNotImplemented = errors.New(
|
||||
"tls checker not yet implemented",
|
||||
const (
|
||||
defaultTimeout = 10 * time.Second
|
||||
defaultPort = 443
|
||||
)
|
||||
|
||||
// ErrUnexpectedConnType indicates the connection was not a TLS
|
||||
// connection.
|
||||
var ErrUnexpectedConnType = errors.New(
|
||||
"unexpected connection type",
|
||||
)
|
||||
|
||||
// CertificateInfo holds information about a TLS certificate.
|
||||
type CertificateInfo struct {
|
||||
CommonName string
|
||||
Issuer string
|
||||
NotAfter time.Time
|
||||
SubjectAlternativeNames []string
|
||||
SerialNumber string
|
||||
}
|
||||
|
||||
// Option configures a Checker.
|
||||
type Option func(*Checker)
|
||||
|
||||
// WithTimeout sets the connection timeout.
|
||||
func WithTimeout(d time.Duration) Option {
|
||||
return func(c *Checker) {
|
||||
c.timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig sets a custom TLS configuration.
|
||||
func WithTLSConfig(cfg *tls.Config) Option {
|
||||
return func(c *Checker) {
|
||||
c.tlsConfig = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// WithPort sets the TLS port to connect to.
|
||||
func WithPort(port int) Option {
|
||||
return func(c *Checker) {
|
||||
c.port = port
|
||||
}
|
||||
}
|
||||
|
||||
// Params contains dependencies for Checker.
|
||||
type Params struct {
|
||||
fx.In
|
||||
@ -26,15 +69,10 @@ type Params struct {
|
||||
|
||||
// Checker performs TLS certificate inspection.
|
||||
type Checker struct {
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// CertificateInfo holds information about a TLS certificate.
|
||||
type CertificateInfo struct {
|
||||
CommonName string
|
||||
Issuer string
|
||||
NotAfter time.Time
|
||||
SubjectAlternativeNames []string
|
||||
log *slog.Logger
|
||||
timeout time.Duration
|
||||
tlsConfig *tls.Config
|
||||
port int
|
||||
}
|
||||
|
||||
// New creates a new TLS Checker instance.
|
||||
@ -43,16 +81,106 @@ func New(
|
||||
params Params,
|
||||
) (*Checker, error) {
|
||||
return &Checker{
|
||||
log: params.Logger.Get(),
|
||||
log: params.Logger.Get(),
|
||||
timeout: defaultTimeout,
|
||||
port: defaultPort,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CheckCertificate connects to the given IP:port using SNI and
|
||||
// returns certificate information.
|
||||
func (c *Checker) CheckCertificate(
|
||||
_ context.Context,
|
||||
_ string,
|
||||
_ string,
|
||||
) (*CertificateInfo, error) {
|
||||
return nil, ErrNotImplemented
|
||||
// NewStandalone creates a Checker without fx dependencies.
|
||||
func NewStandalone(opts ...Option) *Checker {
|
||||
checker := &Checker{
|
||||
log: slog.Default(),
|
||||
timeout: defaultTimeout,
|
||||
port: defaultPort,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(checker)
|
||||
}
|
||||
|
||||
return checker
|
||||
}
|
||||
|
||||
// CheckCertificate connects to the given IP address using the
|
||||
// specified SNI hostname and returns certificate information.
|
||||
func (c *Checker) CheckCertificate(
|
||||
ctx context.Context,
|
||||
ipAddress string,
|
||||
sniHostname string,
|
||||
) (*CertificateInfo, error) {
|
||||
target := net.JoinHostPort(
|
||||
ipAddress, strconv.Itoa(c.port),
|
||||
)
|
||||
|
||||
tlsCfg := c.buildTLSConfig(sniHostname)
|
||||
dialer := &tls.Dialer{
|
||||
NetDialer: &net.Dialer{Timeout: c.timeout},
|
||||
Config: tlsCfg,
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", target)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"TLS dial to %s: %w", target, err,
|
||||
)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
closeErr := conn.Close()
|
||||
if closeErr != nil {
|
||||
c.log.Debug(
|
||||
"closing TLS connection",
|
||||
"target", target,
|
||||
"error", closeErr.Error(),
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(
|
||||
"%s: %w", target, ErrUnexpectedConnType,
|
||||
)
|
||||
}
|
||||
|
||||
return c.extractCertInfo(tlsConn), nil
|
||||
}
|
||||
|
||||
func (c *Checker) buildTLSConfig(
|
||||
sniHostname string,
|
||||
) *tls.Config {
|
||||
if c.tlsConfig != nil {
|
||||
cfg := c.tlsConfig.Clone()
|
||||
cfg.ServerName = sniHostname
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
ServerName: sniHostname,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Checker) extractCertInfo(
|
||||
conn *tls.Conn,
|
||||
) *CertificateInfo {
|
||||
state := conn.ConnectionState()
|
||||
if len(state.PeerCertificates) == 0 {
|
||||
return &CertificateInfo{}
|
||||
}
|
||||
|
||||
cert := state.PeerCertificates[0]
|
||||
|
||||
sans := make([]string, len(cert.DNSNames))
|
||||
copy(sans, cert.DNSNames)
|
||||
|
||||
return &CertificateInfo{
|
||||
CommonName: cert.Subject.CommonName,
|
||||
Issuer: cert.Issuer.CommonName,
|
||||
NotAfter: cert.NotAfter,
|
||||
SubjectAlternativeNames: sans,
|
||||
SerialNumber: cert.SerialNumber.String(),
|
||||
}
|
||||
}
|
||||
|
||||
169
internal/tlscheck/tlscheck_test.go
Normal file
169
internal/tlscheck/tlscheck_test.go
Normal file
@ -0,0 +1,169 @@
|
||||
package tlscheck_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/tlscheck"
|
||||
)
|
||||
|
||||
func startTLSServer(
|
||||
t *testing.T,
|
||||
) (*httptest.Server, string, int) {
|
||||
t.Helper()
|
||||
|
||||
srv := httptest.NewTLSServer(
|
||||
http.HandlerFunc(
|
||||
func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
addr, ok := srv.Listener.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
t.Fatal("unexpected address type")
|
||||
}
|
||||
|
||||
return srv, addr.IP.String(), addr.Port
|
||||
}
|
||||
|
||||
func TestCheckCertificateValid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv, ip, port := startTLSServer(t)
|
||||
|
||||
defer srv.Close()
|
||||
|
||||
checker := tlscheck.NewStandalone(
|
||||
tlscheck.WithTimeout(5 * time.Second),
|
||||
tlscheck.WithTLSConfig(&tls.Config{
|
||||
//nolint:gosec // test uses self-signed cert
|
||||
InsecureSkipVerify: true,
|
||||
}),
|
||||
tlscheck.WithPort(port),
|
||||
)
|
||||
|
||||
info, err := checker.CheckCertificate(
|
||||
context.Background(), ip, "localhost",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
t.Fatal("expected non-nil CertificateInfo")
|
||||
}
|
||||
|
||||
if info.NotAfter.IsZero() {
|
||||
t.Error("expected non-zero NotAfter")
|
||||
}
|
||||
|
||||
if info.SerialNumber == "" {
|
||||
t.Error("expected non-empty SerialNumber")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCertificateConnectionRefused(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lc := &net.ListenConfig{}
|
||||
|
||||
ln, err := lc.Listen(
|
||||
context.Background(), "tcp", "127.0.0.1:0",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
addr, ok := ln.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
t.Fatal("unexpected address type")
|
||||
}
|
||||
|
||||
port := addr.Port
|
||||
|
||||
_ = ln.Close()
|
||||
|
||||
checker := tlscheck.NewStandalone(
|
||||
tlscheck.WithTimeout(2*time.Second),
|
||||
tlscheck.WithPort(port),
|
||||
)
|
||||
|
||||
_, err = checker.CheckCertificate(
|
||||
context.Background(), "127.0.0.1", "localhost",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for connection refused")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCertificateContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
checker := tlscheck.NewStandalone(
|
||||
tlscheck.WithTimeout(2 * time.Second),
|
||||
tlscheck.WithPort(1),
|
||||
)
|
||||
|
||||
_, err := checker.CheckCertificate(
|
||||
ctx, "127.0.0.1", "localhost",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for canceled context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCertificateTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checker := tlscheck.NewStandalone(
|
||||
tlscheck.WithTimeout(1 * time.Millisecond),
|
||||
tlscheck.WithPort(1),
|
||||
)
|
||||
|
||||
_, err := checker.CheckCertificate(
|
||||
context.Background(),
|
||||
"192.0.2.1",
|
||||
"example.com",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCertificateSANs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv, ip, port := startTLSServer(t)
|
||||
|
||||
defer srv.Close()
|
||||
|
||||
checker := tlscheck.NewStandalone(
|
||||
tlscheck.WithTimeout(5*time.Second),
|
||||
tlscheck.WithTLSConfig(&tls.Config{
|
||||
//nolint:gosec // test uses self-signed cert
|
||||
InsecureSkipVerify: true,
|
||||
}),
|
||||
tlscheck.WithPort(port),
|
||||
)
|
||||
|
||||
info, err := checker.CheckCertificate(
|
||||
context.Background(), ip, "localhost",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if info.CommonName == "" && len(info.SubjectAlternativeNames) == 0 {
|
||||
t.Error("expected CN or SANs to be populated")
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user