Merge branch 'main' into feature/watcher-implementation
This commit is contained in:
85
internal/config/classify.go
Normal file
85
internal/config/classify.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/publicsuffix"
|
||||
)
|
||||
|
||||
// DNSNameType indicates whether a DNS name is an apex domain or a hostname.
|
||||
type DNSNameType int
|
||||
|
||||
const (
|
||||
// DNSNameTypeDomain indicates the name is an apex (eTLD+1) domain.
|
||||
DNSNameTypeDomain DNSNameType = iota
|
||||
// DNSNameTypeHostname indicates the name is a subdomain/hostname.
|
||||
DNSNameTypeHostname
|
||||
)
|
||||
|
||||
// ErrEmptyDNSName is returned when an empty string is passed to ClassifyDNSName.
|
||||
var ErrEmptyDNSName = errors.New("empty DNS name")
|
||||
|
||||
// String returns the string representation of a DNSNameType.
|
||||
func (t DNSNameType) String() string {
|
||||
switch t {
|
||||
case DNSNameTypeDomain:
|
||||
return "domain"
|
||||
case DNSNameTypeHostname:
|
||||
return "hostname"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ClassifyDNSName determines whether a DNS name is an apex domain or a
|
||||
// hostname (subdomain) using the Public Suffix List. It returns an error
|
||||
// if the input is empty or is itself a public suffix (e.g. "co.uk").
|
||||
func ClassifyDNSName(name string) (DNSNameType, error) {
|
||||
name = strings.ToLower(strings.TrimSuffix(strings.TrimSpace(name), "."))
|
||||
|
||||
if name == "" {
|
||||
return 0, ErrEmptyDNSName
|
||||
}
|
||||
|
||||
etld1, err := publicsuffix.EffectiveTLDPlusOne(name)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid DNS name %q: %w", name, err)
|
||||
}
|
||||
|
||||
if name == etld1 {
|
||||
return DNSNameTypeDomain, nil
|
||||
}
|
||||
|
||||
return DNSNameTypeHostname, nil
|
||||
}
|
||||
|
||||
// ClassifyTargets splits a list of DNS names into apex domains and
|
||||
// hostnames using the Public Suffix List. It returns an error if any
|
||||
// name cannot be classified.
|
||||
func ClassifyTargets(targets []string) ([]string, []string, error) {
|
||||
var domains, hostnames []string
|
||||
|
||||
for _, t := range targets {
|
||||
normalized := strings.ToLower(strings.TrimSuffix(strings.TrimSpace(t), "."))
|
||||
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
typ, classErr := ClassifyDNSName(normalized)
|
||||
if classErr != nil {
|
||||
return nil, nil, classErr
|
||||
}
|
||||
|
||||
switch typ {
|
||||
case DNSNameTypeDomain:
|
||||
domains = append(domains, normalized)
|
||||
case DNSNameTypeHostname:
|
||||
hostnames = append(hostnames, normalized)
|
||||
}
|
||||
}
|
||||
|
||||
return domains, hostnames, nil
|
||||
}
|
||||
83
internal/config/classify_test.go
Normal file
83
internal/config/classify_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"sneak.berlin/go/dnswatcher/internal/config"
|
||||
)
|
||||
|
||||
func TestClassifyDNSName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want config.DNSNameType
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "apex domain simple", input: "example.com", want: config.DNSNameTypeDomain},
|
||||
{name: "hostname simple", input: "www.example.com", want: config.DNSNameTypeHostname},
|
||||
{name: "apex domain multi-part TLD", input: "example.co.uk", want: config.DNSNameTypeDomain},
|
||||
{name: "hostname multi-part TLD", input: "api.example.co.uk", want: config.DNSNameTypeHostname},
|
||||
{name: "public suffix itself", input: "co.uk", wantErr: true},
|
||||
{name: "empty string", input: "", wantErr: true},
|
||||
{name: "deeply nested hostname", input: "a.b.c.example.com", want: config.DNSNameTypeHostname},
|
||||
{name: "trailing dot stripped", input: "example.com.", want: config.DNSNameTypeDomain},
|
||||
{name: "uppercase normalized", input: "WWW.Example.COM", want: config.DNSNameTypeHostname},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := config.ClassifyDNSName(tt.input)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("ClassifyDNSName(%q) expected error, got %v", tt.input, got)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ClassifyDNSName(%q) unexpected error: %v", tt.input, err)
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("ClassifyDNSName(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyTargets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
domains, hostnames, err := config.ClassifyTargets([]string{
|
||||
"example.com",
|
||||
"www.example.com",
|
||||
"example.co.uk",
|
||||
"api.example.co.uk",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(domains) != 2 {
|
||||
t.Errorf("expected 2 domains, got %d: %v", len(domains), domains)
|
||||
}
|
||||
|
||||
if len(hostnames) != 2 {
|
||||
t.Errorf("expected 2 hostnames, got %d: %v", len(hostnames), hostnames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyTargetsRejectsPublicSuffix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, _, err := config.ClassifyTargets([]string{"co.uk"})
|
||||
if err == nil {
|
||||
t.Error("expected error for public suffix, got nil")
|
||||
}
|
||||
}
|
||||
@@ -89,8 +89,7 @@ func setupViper(name string) {
|
||||
viper.SetDefault("PORT", defaultPort)
|
||||
viper.SetDefault("DEBUG", false)
|
||||
viper.SetDefault("DATA_DIR", "./data")
|
||||
viper.SetDefault("DOMAINS", "")
|
||||
viper.SetDefault("HOSTNAMES", "")
|
||||
viper.SetDefault("TARGETS", "")
|
||||
viper.SetDefault("SLACK_WEBHOOK", "")
|
||||
viper.SetDefault("MATTERMOST_WEBHOOK", "")
|
||||
viper.SetDefault("NTFY_TOPIC", "")
|
||||
@@ -133,12 +132,19 @@ func buildConfig(
|
||||
tlsInterval = defaultTLSInterval
|
||||
}
|
||||
|
||||
domains, hostnames, err := ClassifyTargets(
|
||||
parseCSV(viper.GetString("TARGETS")),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid targets configuration: %w", err)
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
Port: viper.GetInt("PORT"),
|
||||
Debug: viper.GetBool("DEBUG"),
|
||||
DataDir: viper.GetString("DATA_DIR"),
|
||||
Domains: parseCSV(viper.GetString("DOMAINS")),
|
||||
Hostnames: parseCSV(viper.GetString("HOSTNAMES")),
|
||||
Domains: domains,
|
||||
Hostnames: hostnames,
|
||||
SlackWebhook: viper.GetString("SLACK_WEBHOOK"),
|
||||
MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"),
|
||||
NtfyTopic: viper.GetString("NTFY_TOPIC"),
|
||||
|
||||
Reference in New Issue
Block a user