dnswatcher/internal/config/config.go
clawbot acae697aa2 feat: replace DOMAINS/HOSTNAMES with single TARGETS config
Single DNSWATCHER_TARGETS env var replaces the separate DOMAINS and
HOSTNAMES variables. Classification is automatic via the PSL.

Closes #10
2026-02-19 20:09:07 -08:00

194 lines
4.6 KiB
Go

// Package config provides application configuration via Viper.
package config
import (
	"errors"
	"fmt"
	"log/slog"
	"path/filepath"
	"strings"
	"time"

	"github.com/spf13/viper"
	"go.uber.org/fx"
	"sneak.berlin/go/dnswatcher/internal/globals"
	"sneak.berlin/go/dnswatcher/internal/logger"
)
// Built-in defaults, registered with viper.SetDefault so they apply when
// neither the config file nor the environment supplies a value.
const (
	// defaultPort backs the PORT setting.
	defaultPort = 8080

	// defaultDNSInterval and defaultTLSInterval back DNS_INTERVAL and
	// TLS_INTERVAL; buildConfig also falls back to them when a configured
	// value is unusable.
	defaultDNSInterval = time.Hour
	defaultTLSInterval = time.Hour * 12

	// defaultTLSExpiryWarning backs TLS_EXPIRY_WARNING.
	// NOTE(review): the unit is not visible here — presumably days; confirm.
	defaultTLSExpiryWarning = 7
)
// Params contains the dependencies New needs to build a Config. Embedding
// fx.In marks it as an fx parameter object, so fx fills each field from
// the dependency container.
type Params struct {
	fx.In

	// Globals provides Appname, which New uses as the config file name and
	// as the config directory name under /etc and $HOME/.config (falling
	// back to "dnswatcher" when Appname is empty).
	Globals *globals.Globals

	// Logger provides the base *slog.Logger (via Get). It is switched to
	// debug logging (via EnableDebugLogging) when the DEBUG setting is true.
	Logger *logger.Logger
}
// Config holds application configuration. New resolves it from the
// optional YAML config file, DNSWATCHER_-prefixed environment variables,
// and built-in defaults (see setupViper and buildConfig).
type Config struct {
	// Port is the PORT setting (default 8080).
	Port int

	// Debug is the DEBUG setting (default false). When true, New enables
	// debug logging on the injected logger.
	Debug bool

	// DataDir is the DATA_DIR setting (default "./data"). StatePath places
	// state.json inside it.
	DataDir string

	// Domains and Hostnames are the comma-separated TARGETS setting, split
	// into two lists by classifyTargets.
	// NOTE(review): per the commit note the split uses the Public Suffix
	// List — confirm in classifyTargets.
	Domains   []string
	Hostnames []string

	// Notification endpoints from SLACK_WEBHOOK, MATTERMOST_WEBHOOK and
	// NTFY_TOPIC, each defaulting to "". Presumably an empty value disables
	// that channel — confirm at the notifier.
	SlackWebhook      string
	MattermostWebhook string
	NtfyTopic         string

	// DNSInterval and TLSInterval come from DNS_INTERVAL and TLS_INTERVAL,
	// which are Go duration strings (defaults 1h and 12h). buildConfig
	// substitutes the default when a value fails to parse.
	DNSInterval time.Duration
	TLSInterval time.Duration

	// TLSExpiryWarning is the TLS_EXPIRY_WARNING setting (default 7).
	// NOTE(review): the unit is not visible here — presumably days; confirm.
	TLSExpiryWarning int

	// SentryDSN is the SENTRY_DSN setting (default "").
	SentryDSN string

	// MaintenanceMode is the MAINTENANCE_MODE setting (default false).
	MaintenanceMode bool

	// MetricsUsername and MetricsPassword come from METRICS_USERNAME and
	// METRICS_PASSWORD (default ""). Presumably these are credentials for a
	// metrics endpoint — confirm with the HTTP server.
	MetricsUsername string
	MetricsPassword string

	// params points at the Params value New received.
	params *Params

	// log starts as params.Logger.Get(). configureDebugLogging replaces it
	// after enabling debug logging.
	log *slog.Logger
}
// New constructs the application Config for fx.
//
// The application name from Globals (or "dnswatcher" when it is empty)
// selects the config file name and its search directories. Settings are
// then loaded by buildConfig, and debug logging is switched on when the
// resulting config asks for it. Any loading error is returned unchanged.
func New(_ fx.Lifecycle, params Params) (*Config, error) {
	baseLog := params.Logger.Get()

	appName := params.Globals.Appname
	if appName == "" {
		appName = "dnswatcher"
	}

	setupViper(appName)

	loaded, loadErr := buildConfig(baseLog, &params)
	if loadErr != nil {
		return nil, loadErr
	}

	configureDebugLogging(loaded, params)

	return loaded, nil
}
// setupViper prepares the global Viper instance for the named application.
//
// It sets the following, in order:
//   - a YAML config file called <name>, searched for in /etc/<name>,
//     $HOME/.config/<name> and the working directory;
//   - automatic environment lookup with the DNSWATCHER prefix, plus an
//     unprefixed PORT;
//   - a default value for every recognised setting.
func setupViper(name string) {
	viper.SetConfigName(name)
	viper.SetConfigType("yaml")

	searchDirs := []string{
		"/etc/" + name,
		"$HOME/.config/" + name,
		".",
	}
	for _, dir := range searchDirs {
		viper.AddConfigPath(dir)
	}

	viper.SetEnvPrefix("DNSWATCHER")
	viper.AutomaticEnv()

	// PORT is read without the DNSWATCHER_ prefix for compatibility.
	// viper.BindEnv only reports an error when given no key, so the result
	// is safe to discard here.
	_ = viper.BindEnv("PORT", "PORT")

	defaults := []struct {
		key   string
		value any
	}{
		{"PORT", defaultPort},
		{"DEBUG", false},
		{"DATA_DIR", "./data"},
		{"TARGETS", ""},
		{"SLACK_WEBHOOK", ""},
		{"MATTERMOST_WEBHOOK", ""},
		{"NTFY_TOPIC", ""},
		{"DNS_INTERVAL", defaultDNSInterval.String()},
		{"TLS_INTERVAL", defaultTLSInterval.String()},
		{"TLS_EXPIRY_WARNING", defaultTLSExpiryWarning},
		{"SENTRY_DSN", ""},
		{"MAINTENANCE_MODE", false},
		{"METRICS_USERNAME", ""},
		{"METRICS_PASSWORD", ""},
	}
	for _, d := range defaults {
		viper.SetDefault(d.key, d.value)
	}
}
// buildConfig reads the optional YAML config file, resolves every setting
// through Viper (config file, environment, defaults), and returns the
// assembled Config.
//
// A missing config file is not an error. A config file that exists but
// cannot be parsed is logged and returned as an error. An error from
// classifyTargets on the TARGETS list is returned wrapped. Unusable
// DNS_INTERVAL/TLS_INTERVAL values do not fail the load; see
// intervalSetting.
func buildConfig(
	log *slog.Logger,
	params *Params,
) (*Config, error) {
	err := viper.ReadInConfig()
	if err != nil {
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			log.Error("config file malformed", "error", err)

			return nil, fmt.Errorf(
				"config file malformed: %w", err,
			)
		}
	}

	dnsInterval := intervalSetting(log, "DNS_INTERVAL", defaultDNSInterval)
	tlsInterval := intervalSetting(log, "TLS_INTERVAL", defaultTLSInterval)

	targets := parseCSV(viper.GetString("TARGETS"))

	domains, hostnames, err := classifyTargets(targets)
	if err != nil {
		return nil, fmt.Errorf("classifying targets: %w", err)
	}

	cfg := &Config{
		Port:              viper.GetInt("PORT"),
		Debug:             viper.GetBool("DEBUG"),
		DataDir:           viper.GetString("DATA_DIR"),
		Domains:           domains,
		Hostnames:         hostnames,
		SlackWebhook:      viper.GetString("SLACK_WEBHOOK"),
		MattermostWebhook: viper.GetString("MATTERMOST_WEBHOOK"),
		NtfyTopic:         viper.GetString("NTFY_TOPIC"),
		DNSInterval:       dnsInterval,
		TLSInterval:       tlsInterval,
		TLSExpiryWarning:  viper.GetInt("TLS_EXPIRY_WARNING"),
		SentryDSN:         viper.GetString("SENTRY_DSN"),
		MaintenanceMode:   viper.GetBool("MAINTENANCE_MODE"),
		MetricsUsername:   viper.GetString("METRICS_USERNAME"),
		MetricsPassword:   viper.GetString("METRICS_PASSWORD"),
		params:            params,
		log:               log,
	}

	return cfg, nil
}

// intervalSetting returns the Viper setting key parsed as a Go duration
// string (for example "30m" or "1h").
//
// If the value does not parse, or parses to zero or a negative duration,
// it logs a warning and returns fallback. A misconfiguration is therefore
// visible but not fatal. Non-positive values are rejected because they
// cannot serve as a period for recurring work; time.NewTicker, for one,
// panics on them.
func intervalSetting(
	log *slog.Logger,
	key string,
	fallback time.Duration,
) time.Duration {
	raw := viper.GetString(key)

	interval, err := time.ParseDuration(raw)
	if err != nil {
		log.Warn(
			"invalid interval, using default",
			"key", key,
			"value", raw,
			"default", fallback,
			"error", err,
		)

		return fallback
	}

	if interval <= 0 {
		log.Warn(
			"interval must be positive, using default",
			"key", key,
			"value", raw,
			"default", fallback,
		)

		return fallback
	}

	return interval
}
// parseCSV splits a comma-separated list into its entries, trimming
// surrounding whitespace from each and dropping entries that are blank.
// An empty input yields nil. A non-empty input with no usable entries
// (e.g. " , ") yields an empty, non-nil slice.
func parseCSV(input string) []string {
	if len(input) == 0 {
		return nil
	}

	entries := make([]string, 0, strings.Count(input, ",")+1)

	rest := input
	for {
		field, tail, more := strings.Cut(rest, ",")
		if item := strings.TrimSpace(field); item != "" {
			entries = append(entries, item)
		}

		if !more {
			break
		}

		rest = tail
	}

	return entries
}
// configureDebugLogging does nothing unless cfg.Debug is set. When it is
// set, it calls EnableDebugLogging on the injected logger and refreshes
// cfg's cached *slog.Logger from it.
func configureDebugLogging(cfg *Config, params Params) {
	if !cfg.Debug {
		return
	}

	params.Logger.EnableDebugLogging()
	cfg.log = params.Logger.Get()
}
// StatePath returns the path of the state JSON file, "state.json" inside
// DataDir.
//
// The path is built with filepath.Join, which has two effects:
//   - a trailing separator on DataDir does not produce a doubled slash;
//   - an empty DataDir yields "state.json", relative to the working
//     directory, rather than the filesystem-root path "/state.json".
func (c *Config) StatePath() string {
	return filepath.Join(c.DataDir, "state.json")
}