All checks were successful
check / check (push) Successful in 51s
Add 32 tests covering:
- Save/Load round-trip for all state types (domains, hostnames, ports, certificates)
- Atomic write verification (no leftover temp files)
- PortState.UnmarshalJSON backward compatibility (old single-hostname format)
- Missing, corrupt, and empty state file handling
- Permission error handling (skipped when running as root)
- All getter/setter/delete methods for every state type
- GetSnapshot returns a value copy
- GetAllPortKeys enumeration
- Concurrent read/write safety with race detection
- Concurrent Save/Load safety
- File permission verification (0600)
- Multiple saves overwrite previous state
- LastUpdated timestamp updates on save
- Error field round-trips for certificates and hostnames
- Snapshot version correctness

Also adds NewForTestWithDataDir() helper for tests requiring file persistence.

Closes #70
1303 lines
29 KiB
Go
1303 lines
29 KiB
Go
package state_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"sneak.berlin/go/dnswatcher/internal/state"
|
|
)
|
|
|
|
// testHostname is the fixture hostname shared by populateState and by the
// assertions of several tests in this file.
const (
	testHostname = "www.example.com"
)
|
|
|
|
// populateState fills a State with representative test data across all categories.
|
|
func populateState(t *testing.T, s *state.State) {
|
|
t.Helper()
|
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
|
|
s.SetDomainState("example.com", &state.DomainState{
|
|
Nameservers: []string{"ns1.example.com.", "ns2.example.com."},
|
|
LastChecked: now,
|
|
})
|
|
|
|
s.SetDomainState("example.org", &state.DomainState{
|
|
Nameservers: []string{"ns1.example.org."},
|
|
LastChecked: now,
|
|
})
|
|
|
|
s.SetHostnameState(testHostname, &state.HostnameState{
|
|
RecordsByNameserver: map[string]*state.NameserverRecordState{
|
|
"ns1.example.com.": {
|
|
Records: map[string][]string{
|
|
"A": {"93.184.216.34"},
|
|
"AAAA": {"2606:2800:220:1:248:1893:25c8:1946"},
|
|
},
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
},
|
|
"ns2.example.com.": {
|
|
Records: map[string][]string{
|
|
"A": {"93.184.216.34"},
|
|
},
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
},
|
|
},
|
|
LastChecked: now,
|
|
})
|
|
|
|
s.SetPortState("93.184.216.34:80", &state.PortState{
|
|
Open: true,
|
|
Hostnames: []string{testHostname},
|
|
LastChecked: now,
|
|
})
|
|
|
|
s.SetPortState("93.184.216.34:443", &state.PortState{
|
|
Open: true,
|
|
Hostnames: []string{testHostname, "api.example.com"},
|
|
LastChecked: now,
|
|
})
|
|
|
|
s.SetCertificateState("93.184.216.34:443:www.example.com", &state.CertificateState{
|
|
CommonName: testHostname,
|
|
Issuer: "DigiCert TLS RSA SHA256 2020 CA1",
|
|
NotAfter: now.Add(365 * 24 * time.Hour),
|
|
SubjectAlternativeNames: []string{testHostname, "example.com"},
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
})
|
|
}
|
|
|
|
// TestSaveLoadRoundTrip_Domains verifies domain data survives a save/load cycle.
|
|
func TestSaveLoadRoundTrip_Domains(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
populateState(t, s)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
snap := loaded.GetSnapshot()
|
|
|
|
if len(snap.Domains) != 2 {
|
|
t.Fatalf("expected 2 domains, got %d", len(snap.Domains))
|
|
}
|
|
|
|
dom, ok := snap.Domains["example.com"]
|
|
if !ok {
|
|
t.Fatal("missing domain example.com")
|
|
}
|
|
|
|
if len(dom.Nameservers) != 2 {
|
|
t.Errorf("expected 2 nameservers, got %d", len(dom.Nameservers))
|
|
}
|
|
}
|
|
|
|
// TestSaveLoadRoundTrip_Hostnames verifies hostname data survives a save/load cycle.
|
|
func TestSaveLoadRoundTrip_Hostnames(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
populateState(t, s)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
snap := loaded.GetSnapshot()
|
|
|
|
if len(snap.Hostnames) != 1 {
|
|
t.Fatalf("expected 1 hostname, got %d", len(snap.Hostnames))
|
|
}
|
|
|
|
hn, ok := snap.Hostnames[testHostname]
|
|
if !ok {
|
|
t.Fatal("missing hostname")
|
|
}
|
|
|
|
if len(hn.RecordsByNameserver) != 2 {
|
|
t.Errorf("expected 2 NS entries, got %d", len(hn.RecordsByNameserver))
|
|
}
|
|
|
|
verifyNS1Records(t, hn)
|
|
}
|
|
|
|
// verifyNS1Records checks the ns1.example.com. records in the loaded hostname state.
|
|
func verifyNS1Records(t *testing.T, hn *state.HostnameState) {
|
|
t.Helper()
|
|
|
|
ns1, ok := hn.RecordsByNameserver["ns1.example.com."]
|
|
if !ok {
|
|
t.Fatal("missing nameserver ns1.example.com.")
|
|
}
|
|
|
|
aRecords := ns1.Records["A"]
|
|
if len(aRecords) != 1 || aRecords[0] != "93.184.216.34" {
|
|
t.Errorf("ns1 A records: got %v", aRecords)
|
|
}
|
|
|
|
aaaaRecords := ns1.Records["AAAA"]
|
|
if len(aaaaRecords) != 1 || aaaaRecords[0] != "2606:2800:220:1:248:1893:25c8:1946" {
|
|
t.Errorf("ns1 AAAA records: got %v", aaaaRecords)
|
|
}
|
|
|
|
if ns1.Status != "ok" {
|
|
t.Errorf("ns1 status: got %q, want %q", ns1.Status, "ok")
|
|
}
|
|
}
|
|
|
|
// TestSaveLoadRoundTrip_Ports verifies port data survives a save/load cycle.
|
|
func TestSaveLoadRoundTrip_Ports(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
populateState(t, s)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
snap := loaded.GetSnapshot()
|
|
|
|
if len(snap.Ports) != 2 {
|
|
t.Fatalf("expected 2 ports, got %d", len(snap.Ports))
|
|
}
|
|
|
|
port443, ok := snap.Ports["93.184.216.34:443"]
|
|
if !ok {
|
|
t.Fatal("missing port 93.184.216.34:443")
|
|
}
|
|
|
|
if !port443.Open {
|
|
t.Error("port 443: expected open")
|
|
}
|
|
|
|
if len(port443.Hostnames) != 2 {
|
|
t.Errorf("expected 2 hostnames, got %d", len(port443.Hostnames))
|
|
}
|
|
}
|
|
|
|
// TestSaveLoadRoundTrip_Certificates verifies certificate data survives a save/load cycle.
|
|
func TestSaveLoadRoundTrip_Certificates(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
populateState(t, s)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
snap := loaded.GetSnapshot()
|
|
|
|
if len(snap.Certificates) != 1 {
|
|
t.Fatalf("expected 1 certificate, got %d", len(snap.Certificates))
|
|
}
|
|
|
|
cert, ok := snap.Certificates["93.184.216.34:443:www.example.com"]
|
|
if !ok {
|
|
t.Fatal("missing certificate")
|
|
}
|
|
|
|
if cert.CommonName != testHostname {
|
|
t.Errorf("cert CN: got %q", cert.CommonName)
|
|
}
|
|
|
|
if cert.Issuer != "DigiCert TLS RSA SHA256 2020 CA1" {
|
|
t.Errorf("cert issuer: got %q", cert.Issuer)
|
|
}
|
|
|
|
if len(cert.SubjectAlternativeNames) != 2 {
|
|
t.Errorf("cert SANs: expected 2, got %d", len(cert.SubjectAlternativeNames))
|
|
}
|
|
|
|
if cert.Status != "ok" {
|
|
t.Errorf("cert status: got %q", cert.Status)
|
|
}
|
|
}
|
|
|
|
// TestSaveCreatesDirectory verifies Save creates the data directory if missing.
|
|
func TestSaveCreatesDirectory(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
nested := filepath.Join(dir, "nested", "deep")
|
|
|
|
s := state.NewForTestWithDataDir(nested)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
statePath := filepath.Join(nested, "state.json")
|
|
|
|
info, err := os.Stat(statePath)
|
|
if err != nil {
|
|
t.Fatalf("state file not found: %v", err)
|
|
}
|
|
|
|
if info.Size() == 0 {
|
|
t.Error("state file is empty")
|
|
}
|
|
}
|
|
|
|
// TestSaveAtomicWrite verifies the atomic write pattern (no leftover temp files).
|
|
func TestSaveAtomicWrite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
populateState(t, s)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
statePath := filepath.Join(dir, "state.json")
|
|
tmpPath := statePath + ".tmp"
|
|
|
|
// No .tmp file should remain after successful save.
|
|
_, err = os.Stat(tmpPath)
|
|
if !os.IsNotExist(err) {
|
|
t.Error("temp file should not exist after successful save")
|
|
}
|
|
|
|
// Modify state and save again; verify updated content.
|
|
s.SetDomainState("new.example.com", &state.DomainState{
|
|
Nameservers: []string{"ns1.new.example.com."},
|
|
LastChecked: time.Now().UTC(),
|
|
})
|
|
|
|
err = s.Save()
|
|
if err != nil {
|
|
t.Fatalf("second Save() error: %v", err)
|
|
}
|
|
|
|
_, err = os.Stat(tmpPath)
|
|
if !os.IsNotExist(err) {
|
|
t.Error("temp file should not exist after second save")
|
|
}
|
|
|
|
// Verify new domain is present by loading.
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
_, ok := loaded.GetDomainState("new.example.com")
|
|
if !ok {
|
|
t.Error("new domain not found after second save")
|
|
}
|
|
}
|
|
|
|
// TestLoadMissingFile verifies that loading a nonexistent file is not an error.
|
|
func TestLoadMissingFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTestWithDataDir(t.TempDir())
|
|
|
|
err := s.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() with missing file should not error: %v", err)
|
|
}
|
|
|
|
snap := s.GetSnapshot()
|
|
|
|
if len(snap.Domains) != 0 {
|
|
t.Error("expected empty domains")
|
|
}
|
|
|
|
if len(snap.Hostnames) != 0 {
|
|
t.Error("expected empty hostnames")
|
|
}
|
|
|
|
if len(snap.Ports) != 0 {
|
|
t.Error("expected empty ports")
|
|
}
|
|
|
|
if len(snap.Certificates) != 0 {
|
|
t.Error("expected empty certificates")
|
|
}
|
|
}
|
|
|
|
// TestLoadCorruptFile verifies that a corrupt state file returns an error.
|
|
func TestLoadCorruptFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
|
|
err := os.WriteFile(
|
|
filepath.Join(dir, "state.json"),
|
|
[]byte("not valid json{{{"),
|
|
0o600,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("writing corrupt file: %v", err)
|
|
}
|
|
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
err = s.Load()
|
|
if err == nil {
|
|
t.Fatal("Load() should return error for corrupt file")
|
|
}
|
|
}
|
|
|
|
// TestLoadEmptyFile verifies that an empty state file returns an error.
|
|
func TestLoadEmptyFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
|
|
err := os.WriteFile(filepath.Join(dir, "state.json"), []byte(""), 0o600)
|
|
if err != nil {
|
|
t.Fatalf("writing empty file: %v", err)
|
|
}
|
|
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
err = s.Load()
|
|
if err == nil {
|
|
t.Fatal("Load() should return error for empty file")
|
|
}
|
|
}
|
|
|
|
// TestLoadReadPermissionError verifies that an unreadable state file returns an error.
|
|
func TestLoadReadPermissionError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if os.Getuid() == 0 {
|
|
t.Skip("skipping permission test when running as root")
|
|
}
|
|
|
|
dir := t.TempDir()
|
|
statePath := filepath.Join(dir, "state.json")
|
|
|
|
err := os.WriteFile(statePath, []byte("{}"), 0o600)
|
|
if err != nil {
|
|
t.Fatalf("writing file: %v", err)
|
|
}
|
|
|
|
err = os.Chmod(statePath, 0o000)
|
|
if err != nil {
|
|
t.Fatalf("chmod: %v", err)
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
_ = os.Chmod(statePath, 0o600)
|
|
})
|
|
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
err = s.Load()
|
|
if err == nil {
|
|
t.Fatal("Load() should return error for unreadable file")
|
|
}
|
|
}
|
|
|
|
// TestSaveWritePermissionError verifies that Save returns an error
|
|
// when the directory is not writable.
|
|
func TestSaveWritePermissionError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if os.Getuid() == 0 {
|
|
t.Skip("skipping permission test when running as root")
|
|
}
|
|
|
|
dir := t.TempDir()
|
|
dataDir := filepath.Join(dir, "readonly")
|
|
|
|
err := os.MkdirAll(dataDir, 0o700)
|
|
if err != nil {
|
|
t.Fatalf("creating dir: %v", err)
|
|
}
|
|
|
|
//nolint:gosec // intentionally making dir read-only+execute for test
|
|
err = os.Chmod(dataDir, 0o500)
|
|
if err != nil {
|
|
t.Fatalf("chmod: %v", err)
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
//nolint:gosec // restoring write permission for cleanup
|
|
_ = os.Chmod(dataDir, 0o700)
|
|
})
|
|
|
|
s := state.NewForTestWithDataDir(dataDir)
|
|
|
|
err = s.Save()
|
|
if err == nil {
|
|
t.Fatal("Save() should return error for read-only directory")
|
|
}
|
|
}
|
|
|
|
// TestPortStateUnmarshalJSON_NewFormat verifies deserialization of the
|
|
// current multi-hostname format.
|
|
func TestPortStateUnmarshalJSON_NewFormat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
data := []byte(`{
|
|
"open": true,
|
|
"hostnames": ["www.example.com", "api.example.com"],
|
|
"lastChecked": "2026-02-19T12:00:00Z"
|
|
}`)
|
|
|
|
var ps state.PortState
|
|
|
|
err := json.Unmarshal(data, &ps)
|
|
if err != nil {
|
|
t.Fatalf("Unmarshal error: %v", err)
|
|
}
|
|
|
|
if !ps.Open {
|
|
t.Error("expected open=true")
|
|
}
|
|
|
|
if len(ps.Hostnames) != 2 {
|
|
t.Fatalf("expected 2 hostnames, got %d", len(ps.Hostnames))
|
|
}
|
|
|
|
if ps.Hostnames[0] != testHostname {
|
|
t.Errorf("hostname[0]: got %q", ps.Hostnames[0])
|
|
}
|
|
|
|
if ps.Hostnames[1] != "api.example.com" {
|
|
t.Errorf("hostname[1]: got %q", ps.Hostnames[1])
|
|
}
|
|
|
|
if ps.LastChecked.IsZero() {
|
|
t.Error("expected non-zero LastChecked")
|
|
}
|
|
}
|
|
|
|
// TestPortStateUnmarshalJSON_OldFormat verifies backward-compatible
|
|
// deserialization of the old single-hostname format.
|
|
func TestPortStateUnmarshalJSON_OldFormat(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
data := []byte(`{
|
|
"open": false,
|
|
"hostname": "legacy.example.com",
|
|
"lastChecked": "2026-01-15T10:30:00Z"
|
|
}`)
|
|
|
|
var ps state.PortState
|
|
|
|
err := json.Unmarshal(data, &ps)
|
|
if err != nil {
|
|
t.Fatalf("Unmarshal error: %v", err)
|
|
}
|
|
|
|
if ps.Open {
|
|
t.Error("expected open=false")
|
|
}
|
|
|
|
if len(ps.Hostnames) != 1 {
|
|
t.Fatalf("expected 1 hostname, got %d", len(ps.Hostnames))
|
|
}
|
|
|
|
if ps.Hostnames[0] != "legacy.example.com" {
|
|
t.Errorf("hostname: got %q", ps.Hostnames[0])
|
|
}
|
|
}
|
|
|
|
// TestPortStateUnmarshalJSON_EmptyHostnames verifies that when neither
|
|
// hostnames nor hostname is present, Hostnames remains nil.
|
|
func TestPortStateUnmarshalJSON_EmptyHostnames(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
data := []byte(`{
|
|
"open": true,
|
|
"lastChecked": "2026-02-19T12:00:00Z"
|
|
}`)
|
|
|
|
var ps state.PortState
|
|
|
|
err := json.Unmarshal(data, &ps)
|
|
if err != nil {
|
|
t.Fatalf("Unmarshal error: %v", err)
|
|
}
|
|
|
|
if len(ps.Hostnames) != 0 {
|
|
t.Errorf("expected empty hostnames, got %v", ps.Hostnames)
|
|
}
|
|
}
|
|
|
|
// TestPortStateUnmarshalJSON_InvalidJSON verifies that invalid JSON returns an error.
|
|
func TestPortStateUnmarshalJSON_InvalidJSON(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
data := []byte(`{invalid json}`)
|
|
|
|
var ps state.PortState
|
|
|
|
err := json.Unmarshal(data, &ps)
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid JSON")
|
|
}
|
|
}
|
|
|
|
// TestPortStateUnmarshalJSON_BothFormats verifies that when both "hostnames"
|
|
// and "hostname" are present, the new format takes precedence.
|
|
func TestPortStateUnmarshalJSON_BothFormats(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
data := []byte(`{
|
|
"open": true,
|
|
"hostnames": ["new.example.com"],
|
|
"hostname": "old.example.com",
|
|
"lastChecked": "2026-02-19T12:00:00Z"
|
|
}`)
|
|
|
|
var ps state.PortState
|
|
|
|
err := json.Unmarshal(data, &ps)
|
|
if err != nil {
|
|
t.Fatalf("Unmarshal error: %v", err)
|
|
}
|
|
|
|
// When hostnames is non-empty, the old hostname field is ignored.
|
|
if len(ps.Hostnames) != 1 {
|
|
t.Fatalf("expected 1 hostname, got %d", len(ps.Hostnames))
|
|
}
|
|
|
|
if ps.Hostnames[0] != "new.example.com" {
|
|
t.Errorf("hostname: got %q, want %q", ps.Hostnames[0], "new.example.com")
|
|
}
|
|
}
|
|
|
|
// TestGetSnapshot_ReturnsCopy verifies that GetSnapshot returns a value copy.
|
|
func TestGetSnapshot_ReturnsCopy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
populateState(t, s)
|
|
|
|
snap := s.GetSnapshot()
|
|
|
|
// Mutate the returned snapshot's map.
|
|
snap.Domains["mutated.example.com"] = &state.DomainState{
|
|
Nameservers: []string{"ns1.mutated.com."},
|
|
}
|
|
|
|
// The original domain data should still be accessible.
|
|
_, ok := s.GetDomainState("example.com")
|
|
if !ok {
|
|
t.Error("original domain should still exist")
|
|
}
|
|
}
|
|
|
|
// TestDomainState_GetSet tests domain state getter and setter.
|
|
func TestDomainState_GetSet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
// Get on missing key returns false.
|
|
_, ok := s.GetDomainState("nonexistent.com")
|
|
if ok {
|
|
t.Error("expected false for missing domain")
|
|
}
|
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
ds := &state.DomainState{
|
|
Nameservers: []string{"ns1.test.com."},
|
|
LastChecked: now,
|
|
}
|
|
|
|
s.SetDomainState("test.com", ds)
|
|
|
|
got, ok := s.GetDomainState("test.com")
|
|
if !ok {
|
|
t.Fatal("expected true for existing domain")
|
|
}
|
|
|
|
if len(got.Nameservers) != 1 || got.Nameservers[0] != "ns1.test.com." {
|
|
t.Errorf("nameservers: got %v", got.Nameservers)
|
|
}
|
|
|
|
if !got.LastChecked.Equal(now) {
|
|
t.Errorf("lastChecked: got %v, want %v", got.LastChecked, now)
|
|
}
|
|
|
|
// Overwrite.
|
|
ds2 := &state.DomainState{
|
|
Nameservers: []string{"ns1.test.com.", "ns2.test.com."},
|
|
LastChecked: now.Add(time.Hour),
|
|
}
|
|
|
|
s.SetDomainState("test.com", ds2)
|
|
|
|
got2, ok := s.GetDomainState("test.com")
|
|
if !ok {
|
|
t.Fatal("expected true after overwrite")
|
|
}
|
|
|
|
if len(got2.Nameservers) != 2 {
|
|
t.Errorf("expected 2 nameservers after overwrite, got %d", len(got2.Nameservers))
|
|
}
|
|
}
|
|
|
|
// TestHostnameState_GetSet tests hostname state getter and setter.
|
|
func TestHostnameState_GetSet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
_, ok := s.GetHostnameState("missing.example.com")
|
|
if ok {
|
|
t.Error("expected false for missing hostname")
|
|
}
|
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
hs := &state.HostnameState{
|
|
RecordsByNameserver: map[string]*state.NameserverRecordState{
|
|
"ns1.example.com.": {
|
|
Records: map[string][]string{"A": {"1.2.3.4"}},
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
},
|
|
},
|
|
LastChecked: now,
|
|
}
|
|
|
|
s.SetHostnameState(testHostname, hs)
|
|
|
|
got, ok := s.GetHostnameState(testHostname)
|
|
if !ok {
|
|
t.Fatal("expected true for existing hostname")
|
|
}
|
|
|
|
nsState, ok := got.RecordsByNameserver["ns1.example.com."]
|
|
if !ok {
|
|
t.Fatal("missing nameserver entry")
|
|
}
|
|
|
|
if nsState.Status != "ok" {
|
|
t.Errorf("status: got %q", nsState.Status)
|
|
}
|
|
|
|
aRecords := nsState.Records["A"]
|
|
if len(aRecords) != 1 || aRecords[0] != "1.2.3.4" {
|
|
t.Errorf("A records: got %v", aRecords)
|
|
}
|
|
}
|
|
|
|
// TestPortState_GetSetDelete tests port state getter, setter, and deletion.
|
|
func TestPortState_GetSetDelete(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
_, ok := s.GetPortState("1.2.3.4:80")
|
|
if ok {
|
|
t.Error("expected false for missing port")
|
|
}
|
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
ps := &state.PortState{
|
|
Open: true,
|
|
Hostnames: []string{testHostname},
|
|
LastChecked: now,
|
|
}
|
|
|
|
s.SetPortState("1.2.3.4:80", ps)
|
|
|
|
got, ok := s.GetPortState("1.2.3.4:80")
|
|
if !ok {
|
|
t.Fatal("expected true for existing port")
|
|
}
|
|
|
|
if !got.Open {
|
|
t.Error("expected open=true")
|
|
}
|
|
|
|
// Delete the port state.
|
|
s.DeletePortState("1.2.3.4:80")
|
|
|
|
_, ok = s.GetPortState("1.2.3.4:80")
|
|
if ok {
|
|
t.Error("expected false after delete")
|
|
}
|
|
}
|
|
|
|
// TestGetAllPortKeys tests retrieving all port state keys.
|
|
func TestGetAllPortKeys(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
keys := s.GetAllPortKeys()
|
|
if len(keys) != 0 {
|
|
t.Errorf("expected 0 keys, got %d", len(keys))
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
|
|
s.SetPortState("1.2.3.4:80", &state.PortState{
|
|
Open: true, Hostnames: []string{"a.example.com"}, LastChecked: now,
|
|
})
|
|
|
|
s.SetPortState("1.2.3.4:443", &state.PortState{
|
|
Open: true, Hostnames: []string{"a.example.com"}, LastChecked: now,
|
|
})
|
|
|
|
s.SetPortState("5.6.7.8:80", &state.PortState{
|
|
Open: false, Hostnames: []string{"b.example.com"}, LastChecked: now,
|
|
})
|
|
|
|
keys = s.GetAllPortKeys()
|
|
if len(keys) != 3 {
|
|
t.Fatalf("expected 3 keys, got %d", len(keys))
|
|
}
|
|
|
|
keySet := make(map[string]bool)
|
|
for _, k := range keys {
|
|
keySet[k] = true
|
|
}
|
|
|
|
for _, want := range []string{"1.2.3.4:80", "1.2.3.4:443", "5.6.7.8:80"} {
|
|
if !keySet[want] {
|
|
t.Errorf("missing key %q", want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestCertificateState_GetSet tests certificate state getter and setter.
|
|
func TestCertificateState_GetSet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
_, ok := s.GetCertificateState("1.2.3.4:443:www.example.com")
|
|
if ok {
|
|
t.Error("expected false for missing certificate")
|
|
}
|
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
cs := &state.CertificateState{
|
|
CommonName: testHostname,
|
|
Issuer: "Let's Encrypt Authority X3",
|
|
NotAfter: now.Add(90 * 24 * time.Hour),
|
|
SubjectAlternativeNames: []string{testHostname},
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
}
|
|
|
|
s.SetCertificateState("1.2.3.4:443:www.example.com", cs)
|
|
|
|
got, ok := s.GetCertificateState("1.2.3.4:443:www.example.com")
|
|
if !ok {
|
|
t.Fatal("expected true for existing certificate")
|
|
}
|
|
|
|
if got.CommonName != testHostname {
|
|
t.Errorf("CN: got %q", got.CommonName)
|
|
}
|
|
|
|
if got.Issuer != "Let's Encrypt Authority X3" {
|
|
t.Errorf("issuer: got %q", got.Issuer)
|
|
}
|
|
|
|
if got.Status != "ok" {
|
|
t.Errorf("status: got %q", got.Status)
|
|
}
|
|
|
|
if len(got.SubjectAlternativeNames) != 1 {
|
|
t.Errorf("SANs: expected 1, got %d", len(got.SubjectAlternativeNames))
|
|
}
|
|
}
|
|
|
|
// TestCertificateState_ErrorField tests that the error field round-trips.
|
|
func TestCertificateState_ErrorField(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
cs := &state.CertificateState{
|
|
Status: "error",
|
|
Error: "connection refused",
|
|
LastChecked: now,
|
|
}
|
|
|
|
s.SetCertificateState("10.0.0.1:443:err.example.com", cs)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
got, ok := loaded.GetCertificateState("10.0.0.1:443:err.example.com")
|
|
if !ok {
|
|
t.Fatal("missing certificate after load")
|
|
}
|
|
|
|
if got.Status != "error" {
|
|
t.Errorf("status: got %q, want %q", got.Status, "error")
|
|
}
|
|
|
|
if got.Error != "connection refused" {
|
|
t.Errorf("error field: got %q", got.Error)
|
|
}
|
|
}
|
|
|
|
// TestHostnameState_ErrorField tests that hostname error states round-trip.
|
|
func TestHostnameState_ErrorField(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
now := time.Now().UTC().Truncate(time.Second)
|
|
hs := &state.HostnameState{
|
|
RecordsByNameserver: map[string]*state.NameserverRecordState{
|
|
"ns1.example.com.": {
|
|
Records: nil,
|
|
Status: "error",
|
|
Error: "SERVFAIL",
|
|
LastChecked: now,
|
|
},
|
|
},
|
|
LastChecked: now,
|
|
}
|
|
|
|
s.SetHostnameState("fail.example.com", hs)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
got, ok := loaded.GetHostnameState("fail.example.com")
|
|
if !ok {
|
|
t.Fatal("missing hostname after load")
|
|
}
|
|
|
|
nsState := got.RecordsByNameserver["ns1.example.com."]
|
|
if nsState.Status != "error" {
|
|
t.Errorf("status: got %q, want %q", nsState.Status, "error")
|
|
}
|
|
|
|
if nsState.Error != "SERVFAIL" {
|
|
t.Errorf("error: got %q, want %q", nsState.Error, "SERVFAIL")
|
|
}
|
|
}
|
|
|
|
// TestSnapshotVersion verifies that the snapshot version is written correctly.
|
|
func TestSnapshotVersion(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
//nolint:gosec // test file with known path
|
|
data, err := os.ReadFile(filepath.Join(dir, "state.json"))
|
|
if err != nil {
|
|
t.Fatalf("reading state file: %v", err)
|
|
}
|
|
|
|
var snap state.Snapshot
|
|
|
|
err = json.Unmarshal(data, &snap)
|
|
if err != nil {
|
|
t.Fatalf("parsing state: %v", err)
|
|
}
|
|
|
|
if snap.Version != 1 {
|
|
t.Errorf("version: got %d, want 1", snap.Version)
|
|
}
|
|
|
|
if snap.LastUpdated.IsZero() {
|
|
t.Error("lastUpdated should be set after save")
|
|
}
|
|
}
|
|
|
|
// TestSaveUpdatesLastUpdated verifies that Save sets LastUpdated.
|
|
func TestSaveUpdatesLastUpdated(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
before := time.Now().UTC().Add(-time.Second)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
after := time.Now().UTC().Add(time.Second)
|
|
|
|
snap := s.GetSnapshot()
|
|
if snap.LastUpdated.Before(before) || snap.LastUpdated.After(after) {
|
|
t.Errorf(
|
|
"LastUpdated %v not in expected range [%v, %v]",
|
|
snap.LastUpdated, before, after,
|
|
)
|
|
}
|
|
}
|
|
|
|
// TestLoadPreservesExistingStateOnMissingFile verifies that when Load is
|
|
// called and the file doesn't exist, the existing in-memory state is kept.
|
|
func TestLoadPreservesExistingStateOnMissingFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
// Add some data before loading.
|
|
s.SetDomainState("pre.example.com", &state.DomainState{
|
|
Nameservers: []string{"ns1.pre.example.com."},
|
|
LastChecked: time.Now().UTC(),
|
|
})
|
|
|
|
err := s.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
// Since the file doesn't exist, the existing data should persist.
|
|
_, ok := s.GetDomainState("pre.example.com")
|
|
if !ok {
|
|
t.Error("existing domain state should persist when file is missing")
|
|
}
|
|
}
|
|
|
|
// TestConcurrentGetSet verifies that concurrent reads and writes don't race.
|
|
func TestConcurrentGetSet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
const goroutines = 20
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
wg.Add(goroutines)
|
|
|
|
for i := range goroutines {
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
|
|
now := time.Now().UTC()
|
|
key := string(rune('a' + id%26))
|
|
|
|
runConcurrentOps(s, key, now)
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// runConcurrentOps performs a series of get/set/delete operations for concurrency testing.
|
|
func runConcurrentOps(s *state.State, key string, now time.Time) {
|
|
const iterations = 50
|
|
|
|
for j := range iterations {
|
|
s.SetDomainState(key+".com", &state.DomainState{
|
|
Nameservers: []string{"ns1." + key + ".com."},
|
|
LastChecked: now,
|
|
})
|
|
|
|
s.GetDomainState(key + ".com")
|
|
|
|
s.SetPortState(key+":80", &state.PortState{
|
|
Open: j%2 == 0, Hostnames: []string{key + ".com"}, LastChecked: now,
|
|
})
|
|
|
|
s.GetPortState(key + ":80")
|
|
s.GetAllPortKeys()
|
|
s.GetSnapshot()
|
|
|
|
s.SetHostnameState(key+".example.com", &state.HostnameState{
|
|
RecordsByNameserver: map[string]*state.NameserverRecordState{
|
|
"ns1.test.": {
|
|
Records: map[string][]string{"A": {"1.2.3.4"}},
|
|
Status: "ok",
|
|
LastChecked: now,
|
|
},
|
|
},
|
|
LastChecked: now,
|
|
})
|
|
|
|
s.GetHostnameState(key + ".example.com")
|
|
|
|
s.SetCertificateState("1.2.3.4:443:"+key+".com", &state.CertificateState{
|
|
CommonName: key + ".com", Status: "ok", LastChecked: now,
|
|
})
|
|
|
|
s.GetCertificateState("1.2.3.4:443:" + key + ".com")
|
|
|
|
if j%10 == 0 {
|
|
s.DeletePortState(key + ":80")
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestConcurrentSaveLoad verifies that concurrent Save and Load
|
|
// operations don't race or corrupt state.
|
|
func TestConcurrentSaveLoad(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
populateState(t, s)
|
|
|
|
const goroutines = 10
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
wg.Add(goroutines)
|
|
|
|
for i := range goroutines {
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
|
|
if id%2 == 0 {
|
|
_ = s.Save()
|
|
} else {
|
|
_ = s.Load()
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// TestPortStateRoundTripThroughSaveLoad verifies backward-compat deserialization
|
|
// works through the full Save→Load cycle with old-format data on disk.
|
|
func TestPortStateRoundTripThroughSaveLoad(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
|
|
// Write a state file with old single-hostname port format.
|
|
oldFormatJSON := `{
|
|
"version": 1,
|
|
"lastUpdated": "2026-02-19T12:00:00Z",
|
|
"domains": {},
|
|
"hostnames": {},
|
|
"ports": {
|
|
"10.0.0.1:80": {
|
|
"open": true,
|
|
"hostname": "legacy.example.com",
|
|
"lastChecked": "2026-02-19T12:00:00Z"
|
|
},
|
|
"10.0.0.1:443": {
|
|
"open": true,
|
|
"hostnames": ["modern.example.com"],
|
|
"lastChecked": "2026-02-19T12:00:00Z"
|
|
}
|
|
},
|
|
"certificates": {}
|
|
}`
|
|
|
|
err := os.WriteFile(filepath.Join(dir, "state.json"), []byte(oldFormatJSON), 0o600)
|
|
if err != nil {
|
|
t.Fatalf("writing old format state: %v", err)
|
|
}
|
|
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
err = s.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
// Verify old format was migrated.
|
|
ps80, ok := s.GetPortState("10.0.0.1:80")
|
|
if !ok {
|
|
t.Fatal("missing port 10.0.0.1:80")
|
|
}
|
|
|
|
if len(ps80.Hostnames) != 1 || ps80.Hostnames[0] != "legacy.example.com" {
|
|
t.Errorf("old format port hostnames: got %v", ps80.Hostnames)
|
|
}
|
|
|
|
// Verify new format loaded correctly.
|
|
ps443, ok := s.GetPortState("10.0.0.1:443")
|
|
if !ok {
|
|
t.Fatal("missing port 10.0.0.1:443")
|
|
}
|
|
|
|
if len(ps443.Hostnames) != 1 || ps443.Hostnames[0] != "modern.example.com" {
|
|
t.Errorf("new format port hostnames: got %v", ps443.Hostnames)
|
|
}
|
|
}
|
|
|
|
// TestMultipleSavesOverwrite verifies that subsequent saves overwrite the
|
|
// previous state completely.
|
|
func TestMultipleSavesOverwrite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
s.SetDomainState("first.com", &state.DomainState{
|
|
Nameservers: []string{"ns1.first.com."},
|
|
LastChecked: time.Now().UTC(),
|
|
})
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("first Save() error: %v", err)
|
|
}
|
|
|
|
// Create a brand new state with the same dir and set different data.
|
|
s2 := state.NewForTestWithDataDir(dir)
|
|
|
|
s2.SetDomainState("second.com", &state.DomainState{
|
|
Nameservers: []string{"ns1.second.com."},
|
|
LastChecked: time.Now().UTC(),
|
|
})
|
|
|
|
err = s2.Save()
|
|
if err != nil {
|
|
t.Fatalf("second Save() error: %v", err)
|
|
}
|
|
|
|
// Load and verify only the second domain exists.
|
|
loaded := state.NewForTestWithDataDir(dir)
|
|
|
|
err = loaded.Load()
|
|
if err != nil {
|
|
t.Fatalf("Load() error: %v", err)
|
|
}
|
|
|
|
snap := loaded.GetSnapshot()
|
|
|
|
if _, ok := snap.Domains["first.com"]; ok {
|
|
t.Error("first.com should not exist after overwrite")
|
|
}
|
|
|
|
if _, ok := snap.Domains["second.com"]; !ok {
|
|
t.Error("second.com should exist after overwrite")
|
|
}
|
|
}
|
|
|
|
// TestNewForTest verifies the test helper creates a valid empty state.
|
|
func TestNewForTest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := state.NewForTest()
|
|
|
|
snap := s.GetSnapshot()
|
|
|
|
if snap.Version != 1 {
|
|
t.Errorf("version: got %d, want 1", snap.Version)
|
|
}
|
|
|
|
if snap.Domains == nil {
|
|
t.Error("Domains map should be initialized")
|
|
}
|
|
|
|
if snap.Hostnames == nil {
|
|
t.Error("Hostnames map should be initialized")
|
|
}
|
|
|
|
if snap.Ports == nil {
|
|
t.Error("Ports map should be initialized")
|
|
}
|
|
|
|
if snap.Certificates == nil {
|
|
t.Error("Certificates map should be initialized")
|
|
}
|
|
}
|
|
|
|
// TestSaveFilePermissions verifies the saved file has restricted permissions.
|
|
func TestSaveFilePermissions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
s := state.NewForTestWithDataDir(dir)
|
|
|
|
err := s.Save()
|
|
if err != nil {
|
|
t.Fatalf("Save() error: %v", err)
|
|
}
|
|
|
|
info, err := os.Stat(filepath.Join(dir, "state.json"))
|
|
if err != nil {
|
|
t.Fatalf("stat error: %v", err)
|
|
}
|
|
|
|
perm := info.Mode().Perm()
|
|
if perm != 0o600 {
|
|
t.Errorf("file permissions: got %o, want 600", perm)
|
|
}
|
|
}
|