Files
dnswatcher/internal/state/state.go
clawbot b20e75459f
All checks passed
check / check (push) Successful in 34s
fix: track multiple hostnames per IP:port in port state (#65)
## Summary

Port state keys are `ip:port` with a single `hostname` field. When multiple hostnames resolve to the same IP (shared hosting, CDN), only one hostname was associated. This caused orphaned port state when that hostname removed the IP from DNS while the IP remained valid for other hostnames.

## Changes

### State (`internal/state/state.go`)
- `PortState.Hostname` (string) → `PortState.Hostnames` ([]string)
- Custom `UnmarshalJSON` for backward compatibility: reads old single `hostname` field and migrates to a single-element `hostnames` slice
- Added `DeletePortState` and `GetAllPortKeys` methods for cleanup

### Watcher (`internal/watcher/watcher.go`)
- Refactored `checkAllPorts` into three phases:
  1. Build IP:port → hostname associations from current DNS data
  2. Check each unique IP:port once with all associated hostnames
  3. Clean up stale port state entries with no hostname references
- Port change notifications now list all associated hostnames (`Hosts:` instead of `Host:`)
- Added `buildPortAssociations`, `parsePortKey`, and `cleanupStalePorts` helper functions

### README
- Updated state file format example: `hostname` → `hostnames` (array)
- Updated notification description to reflect multiple hostnames

## Backward Compatibility

Existing state files with the old single `hostname` string are handled gracefully via custom JSON unmarshaling — they are read as single-element `hostnames` slices.

Closes #55

Co-authored-by: clawbot <clawbot@noreply.eeqj.de>
Reviewed-on: #65
Co-authored-by: clawbot <clawbot@noreply.example.org>
Co-committed-by: clawbot <clawbot@noreply.example.org>
2026-03-02 00:32:27 +01:00

348 lines
8.0 KiB
Go

// Package state provides JSON file-based state persistence.
package state
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"sync"
"time"
"go.uber.org/fx"
"sneak.berlin/go/dnswatcher/internal/config"
"sneak.berlin/go/dnswatcher/internal/logger"
)
// Persistence settings for the on-disk state file.
const (
	// filePermissions is the mode for the state file: owner read/write only.
	filePermissions = 0o600
	// dirPermissions is the mode for the data directory: owner-only access.
	dirPermissions = 0o700
	// stateVersion is the current state file format version.
	stateVersion = 1
)
// Params is the fx parameter object consumed by New. The embedded fx.In
// marks it so that fx fills Logger and Config from the container.
type Params struct {
	fx.In

	Logger *logger.Logger
	Config *config.Config
}
// DomainState records, for one apex domain, the nameserver list that was
// last observed and the time of that observation.
type DomainState struct {
	Nameservers []string  `json:"nameservers"`
	LastChecked time.Time `json:"lastChecked"`
}
// NameserverRecordState captures what a single nameserver answered for a
// hostname: the returned records, a status string, an optional error text
// (omitted from JSON when empty), and when the query was made.
type NameserverRecordState struct {
	Records     map[string][]string `json:"records"`
	Status      string              `json:"status"`
	Error       string              `json:"error,omitempty"`
	LastChecked time.Time           `json:"lastChecked"`
}
// HostnameState groups the per-nameserver answers for one hostname, keyed
// by nameserver, together with the time the hostname was last checked.
type HostnameState struct {
	RecordsByNameserver map[string]*NameserverRecordState `json:"recordsByNameserver"`
	LastChecked         time.Time                         `json:"lastChecked"`
}
// PortState holds the monitoring state for a single port entry.
//
// Hostnames lists every hostname associated with the entry. Older state
// files stored a single "hostname" string instead; UnmarshalJSON migrates
// that form transparently.
type PortState struct {
	Open        bool      `json:"open"`
	Hostnames   []string  `json:"hostnames"`
	LastChecked time.Time `json:"lastChecked"`
}

// UnmarshalJSON implements json.Unmarshaler. It accepts both the current
// multi-hostname format ("hostnames": [...]) and the legacy single-hostname
// format ("hostname": "..."), for backward compatibility with existing
// state files.
//
// If "hostnames" is absent, null, or empty, a non-empty string in the
// legacy "hostname" field becomes a single-element Hostnames slice. A legacy
// value of the wrong JSON type is ignored rather than reported as an error.
//
// A JSON null leaves the receiver unchanged. This follows the encoding/json
// convention for Unmarshaler implementations.
func (ps *PortState) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	// The anonymous struct has no methods, so decoding into it cannot
	// recurse back into this UnmarshalJSON. The legacy field is captured raw
	// in the same pass, so a type mismatch there cannot fail the whole decode.
	var aux struct {
		Open           bool            `json:"open"`
		Hostnames      []string        `json:"hostnames"`
		LastChecked    time.Time       `json:"lastChecked"`
		LegacyHostname json.RawMessage `json:"hostname"`
	}

	err := json.Unmarshal(data, &aux)
	if err != nil {
		return fmt.Errorf("unmarshaling port state: %w", err)
	}

	ps.Open = aux.Open
	ps.Hostnames = aux.Hostnames
	ps.LastChecked = aux.LastChecked

	if len(ps.Hostnames) == 0 && len(aux.LegacyHostname) > 0 {
		var legacy string
		// Best effort: a non-string legacy value (number, object, null)
		// simply leaves Hostnames empty.
		if json.Unmarshal(aux.LegacyHostname, &legacy) == nil && legacy != "" {
			ps.Hostnames = []string{legacy}
		}
	}

	return nil
}
// CertificateState is the last observed TLS certificate summary for one
// monitored endpoint: subject, issuer, expiry, SANs, a status string, an
// optional error text (omitted from JSON when empty), and the check time.
type CertificateState struct {
	CommonName              string    `json:"commonName"`
	Issuer                  string    `json:"issuer"`
	NotAfter                time.Time `json:"notAfter"`
	SubjectAlternativeNames []string  `json:"subjectAlternativeNames"`
	Status                  string    `json:"status"`
	Error                   string    `json:"error,omitempty"`
	LastChecked             time.Time `json:"lastChecked"`
}
// Snapshot is the full monitoring state as written to and read from the
// state file: a format version, the last save time, and one map per kind of
// monitored object.
type Snapshot struct {
	Version      int                          `json:"version"`
	LastUpdated  time.Time                    `json:"lastUpdated"`
	Domains      map[string]*DomainState      `json:"domains"`
	Hostnames    map[string]*HostnameState    `json:"hostnames"`
	Ports        map[string]*PortState        `json:"ports"`
	Certificates map[string]*CertificateState `json:"certificates"`
}
// State owns the in-memory Snapshot and persists it to a JSON file at the
// path given by config. The methods in this package guard snapshot with mu.
type State struct {
	mu       sync.RWMutex
	snapshot *Snapshot
	log      *slog.Logger
	config   *config.Config
}
// New builds a State with an empty snapshot at the current format version.
// It registers lifecycle hooks so that the file is loaded when the fx app
// starts and saved when it stops. It currently never returns an error.
func New(
	lifecycle fx.Lifecycle,
	params Params,
) (*State, error) {
	empty := &Snapshot{
		Version:      stateVersion,
		Domains:      make(map[string]*DomainState),
		Hostnames:    make(map[string]*HostnameState),
		Ports:        make(map[string]*PortState),
		Certificates: make(map[string]*CertificateState),
	}

	st := &State{
		log:      params.Logger.Get(),
		config:   params.Config,
		snapshot: empty,
	}

	hooks := fx.Hook{
		OnStart: func(context.Context) error { return st.Load() },
		OnStop:  func(context.Context) error { return st.Save() },
	}
	lifecycle.Append(hooks)

	return st, nil
}
// Load reads the state file at config.StatePath() and replaces the
// in-memory snapshot with its contents.
//
// A missing file is not an error: the current snapshot is kept and an
// informational message is logged. Any other read error, or content that is
// not valid JSON for a Snapshot, is returned wrapped.
//
// Sections missing from the file, or encoded as null, are initialized to
// empty maps. This keeps later Set* calls from panicking on a nil-map write.
func (s *State) Load() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	path := s.config.StatePath()

	//nolint:gosec // path is from trusted config
	data, err := os.ReadFile(path)
	if err != nil {
		if os.IsNotExist(err) {
			s.log.Info(
				"no existing state file, starting fresh",
				"path", path,
			)

			return nil
		}

		return fmt.Errorf("reading state file: %w", err)
	}

	var snapshot Snapshot

	err = json.Unmarshal(data, &snapshot)
	if err != nil {
		return fmt.Errorf("parsing state file: %w", err)
	}

	// json.Unmarshal leaves a map nil when its key is absent or null (e.g.
	// "{}" or a file written before a section existed). Writing to a nil map
	// panics, so normalize here.
	if snapshot.Domains == nil {
		snapshot.Domains = make(map[string]*DomainState)
	}

	if snapshot.Hostnames == nil {
		snapshot.Hostnames = make(map[string]*HostnameState)
	}

	if snapshot.Ports == nil {
		snapshot.Ports = make(map[string]*PortState)
	}

	if snapshot.Certificates == nil {
		snapshot.Certificates = make(map[string]*CertificateState)
	}

	s.snapshot = &snapshot
	s.log.Info("loaded state from disk", "path", path)

	return nil
}
// Save stamps LastUpdated, serializes the snapshot as indented JSON, and
// replaces the state file by writing a sibling ".tmp" file and renaming it
// over the target.
//
// The data directory is created with dirPermissions if needed. The temp
// file is fsynced before the rename, so the rename cannot expose a file
// whose contents were never flushed. On failure the temp file is removed
// and the previous state file is left in place.
func (s *State) Save() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.snapshot.LastUpdated = time.Now().UTC()

	data, err := json.MarshalIndent(s.snapshot, "", " ")
	if err != nil {
		return fmt.Errorf("marshaling state: %w", err)
	}

	path := s.config.StatePath()

	err = os.MkdirAll(filepath.Dir(path), dirPermissions)
	if err != nil {
		return fmt.Errorf("creating data directory: %w", err)
	}

	err = writeStateFileAtomic(path, data)
	if err != nil {
		return err
	}

	s.log.Debug("state saved to disk", "path", path)

	return nil
}

// writeStateFileAtomic writes data to path+".tmp", syncs and closes it, then
// renames it over path. On failure the temp file is removed.
func writeStateFileAtomic(path string, data []byte) error {
	tmpPath := path + ".tmp"

	//nolint:gosec // path is from trusted config
	f, err := os.OpenFile(
		tmpPath,
		os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
		filePermissions,
	)
	if err != nil {
		return fmt.Errorf("writing temp state file: %w", err)
	}

	_, err = f.Write(data)
	if err == nil {
		// Flush to stable storage before the rename makes it visible.
		err = f.Sync()
	}

	closeErr := f.Close()
	if err == nil {
		err = closeErr
	}

	if err != nil {
		// Best effort: the write error is the one worth reporting.
		_ = os.Remove(tmpPath)

		return fmt.Errorf("writing temp state file: %w", err)
	}

	err = os.Rename(tmpPath, path)
	if err != nil {
		// Best effort: the rename error is the one worth reporting.
		_ = os.Remove(tmpPath)

		return fmt.Errorf("renaming state file: %w", err)
	}

	return nil
}
// GetSnapshot returns a copy of the current snapshot that callers may read
// without holding the State lock.
//
// The four top-level maps are cloned under the read lock. A caller iterating
// the returned maps therefore cannot race with concurrent Set*/Delete*
// calls, which write to State's own maps. The map values are the same
// pointers that State holds, so callers should treat them as read-only.
func (s *State) GetSnapshot() Snapshot {
	s.mu.RLock()
	defer s.mu.RUnlock()

	snap := *s.snapshot
	snap.Domains = cloneStateMap(s.snapshot.Domains)
	snap.Hostnames = cloneStateMap(s.snapshot.Hostnames)
	snap.Ports = cloneStateMap(s.snapshot.Ports)
	snap.Certificates = cloneStateMap(s.snapshot.Certificates)

	return snap
}

// cloneStateMap returns a shallow copy of m; a nil map stays nil.
func cloneStateMap[V any](m map[string]V) map[string]V {
	if m == nil {
		return nil
	}

	out := make(map[string]V, len(m))
	for k, v := range m {
		out[k] = v
	}

	return out
}
// SetDomainState stores ds as the state for domain, replacing any previous
// entry, while holding the write lock.
func (s *State) SetDomainState(domain string, ds *DomainState) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.snapshot.Domains[domain] = ds
}

// GetDomainState looks up the state for domain under the read lock. The
// boolean reports whether an entry exists.
func (s *State) GetDomainState(domain string) (found *DomainState, ok bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	found, ok = s.snapshot.Domains[domain]

	return found, ok
}
// SetHostnameState stores hs as the state for hostname, replacing any
// previous entry, while holding the write lock.
func (s *State) SetHostnameState(hostname string, hs *HostnameState) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.snapshot.Hostnames[hostname] = hs
}

// GetHostnameState looks up the state for hostname under the read lock.
// The boolean reports whether an entry exists.
func (s *State) GetHostnameState(hostname string) (found *HostnameState, ok bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	found, ok = s.snapshot.Hostnames[hostname]

	return found, ok
}
// SetPortState stores ps under key, replacing any previous entry, while
// holding the write lock. Keys are expected to be "ip:port" strings.
func (s *State) SetPortState(key string, ps *PortState) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.snapshot.Ports[key] = ps
}

// GetPortState looks up the port entry stored under key, using the read
// lock. The boolean reports whether an entry exists.
func (s *State) GetPortState(key string) (found *PortState, ok bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	found, ok = s.snapshot.Ports[key]

	return found, ok
}
// DeletePortState drops the port entry stored under key while holding the
// write lock. Removing a key that is not present is a no-op.
func (s *State) DeletePortState(key string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	ports := s.snapshot.Ports
	delete(ports, key)
}
// GetAllPortKeys returns a freshly allocated slice holding every key in the
// port map, collected under the read lock. Order follows Go map iteration
// and is therefore unspecified.
func (s *State) GetAllPortKeys() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	ports := s.snapshot.Ports
	result := make([]string, len(ports))

	i := 0
	for key := range ports {
		result[i] = key
		i++
	}

	return result
}
// SetCertificateState stores cs under key, replacing any previous entry,
// while holding the write lock.
func (s *State) SetCertificateState(key string, cs *CertificateState) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.snapshot.Certificates[key] = cs
}

// GetCertificateState looks up the certificate entry stored under key,
// using the read lock. The boolean reports whether an entry exists.
func (s *State) GetCertificateState(key string) (found *CertificateState, ok bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	found, ok = s.snapshot.Certificates[key]

	return found, ok
}