diff --git a/internal/config/config.go b/internal/config/config.go index 0acf89e..8f54210 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,6 +15,12 @@ import ( "sneak.berlin/go/dnswatcher/internal/logger" ) +// ErrNoTargets is returned when DNSWATCHER_TARGETS is empty or unset. +var ErrNoTargets = errors.New( + "no targets configured: set DNSWATCHER_TARGETS to a comma-separated " + + "list of DNS names to monitor", +) + // Default configuration values. const ( defaultPort = 8080 @@ -118,25 +124,9 @@ func buildConfig( } } - dnsInterval, err := time.ParseDuration( - viper.GetString("DNS_INTERVAL"), - ) + domains, hostnames, err := classifyAndValidateTargets() if err != nil { - dnsInterval = defaultDNSInterval - } - - tlsInterval, err := time.ParseDuration( - viper.GetString("TLS_INTERVAL"), - ) - if err != nil { - tlsInterval = defaultTLSInterval - } - - domains, hostnames, err := ClassifyTargets( - parseCSV(viper.GetString("TARGETS")), - ) - if err != nil { - return nil, fmt.Errorf("invalid targets configuration: %w", err) + return nil, err } cfg := &Config{ @@ -148,8 +138,8 @@ func buildConfig( SlackWebhook: viper.GetString("SLACK_WEBHOOK"), MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"), NtfyTopic: viper.GetString("NTFY_TOPIC"), - DNSInterval: dnsInterval, - TLSInterval: tlsInterval, + DNSInterval: parseDurationOrDefault("DNS_INTERVAL", defaultDNSInterval), + TLSInterval: parseDurationOrDefault("TLS_INTERVAL", defaultTLSInterval), TLSExpiryWarning: viper.GetInt("TLS_EXPIRY_WARNING"), SentryDSN: viper.GetString("SENTRY_DSN"), MaintenanceMode: viper.GetBool("MAINTENANCE_MODE"), @@ -162,6 +152,32 @@ func buildConfig( return cfg, nil } +func classifyAndValidateTargets() ([]string, []string, error) { + domains, hostnames, err := ClassifyTargets( + parseCSV(viper.GetString("TARGETS")), + ) + if err != nil { + return nil, nil, fmt.Errorf( + "invalid targets configuration: %w", err, + ) + } + + if len(domains) == 0 && len(hostnames) == 0 { + return nil, nil, ErrNoTargets + } + + return domains, hostnames, nil +} + +func parseDurationOrDefault(key string, fallback time.Duration) time.Duration { + d, err := time.ParseDuration(viper.GetString(key)) + if err != nil { + return fallback + } + + return d +} + func parseCSV(input string) []string { if input == "" { return nil diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..fafaa16 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,87 @@ +package config_test + +import ( + "errors" + "testing" + + "go.uber.org/fx" + + "sneak.berlin/go/dnswatcher/internal/config" + "sneak.berlin/go/dnswatcher/internal/globals" + "sneak.berlin/go/dnswatcher/internal/logger" +) + +func TestNewReturnsErrNoTargetsWhenEmpty(t *testing.T) { + // Cannot use t.Parallel() because t.Setenv modifies the process + // environment. + t.Setenv("DNSWATCHER_TARGETS", "") + t.Setenv("DNSWATCHER_DATA_DIR", t.TempDir()) + + var cfg *config.Config + + app := fx.New( + fx.Provide( + func() *globals.Globals { + return &globals.Globals{ + Appname: "dnswatcher-test-empty", + } + }, + logger.New, + config.New, + ), + fx.Populate(&cfg), + fx.NopLogger, + ) + + err := app.Err() + if err == nil { + t.Fatal( + "expected error when DNSWATCHER_TARGETS is empty, got nil", + ) + } + + if !errors.Is(err, config.ErrNoTargets) { + t.Errorf("expected ErrNoTargets, got: %v", err) + } +} + +func TestNewSucceedsWithTargets(t *testing.T) { + // Cannot use t.Parallel() because t.Setenv modifies the process + // environment. + t.Setenv("DNSWATCHER_TARGETS", "example.com") + t.Setenv("DNSWATCHER_DATA_DIR", t.TempDir()) + + // Prevent loading a local config file by changing to a temp dir. + t.Chdir(t.TempDir()) + + var cfg *config.Config + + app := fx.New( + fx.Provide( + func() *globals.Globals { + return &globals.Globals{ + Appname: "dnswatcher-test-ok", + } + }, + logger.New, + config.New, + ), + fx.Populate(&cfg), + fx.NopLogger, + ) + + err := app.Err() + if err != nil { + t.Fatalf( + "expected no error with valid targets, got: %v", + err, + ) + } + + if len(cfg.Domains) != 1 || cfg.Domains[0] != "example.com" { + t.Errorf( + "expected [example.com], got domains=%v", + cfg.Domains, + ) + } +}