package state_test import ( "encoding/json" "os" "path/filepath" "sync" "testing" "time" "sneak.berlin/go/dnswatcher/internal/state" ) const testHostname = "www.example.com" // populateState fills a State with representative test data across all categories. func populateState(t *testing.T, s *state.State) { t.Helper() now := time.Now().UTC().Truncate(time.Second) s.SetDomainState("example.com", &state.DomainState{ Nameservers: []string{"ns1.example.com.", "ns2.example.com."}, LastChecked: now, }) s.SetDomainState("example.org", &state.DomainState{ Nameservers: []string{"ns1.example.org."}, LastChecked: now, }) s.SetHostnameState(testHostname, &state.HostnameState{ RecordsByNameserver: map[string]*state.NameserverRecordState{ "ns1.example.com.": { Records: map[string][]string{ "A": {"93.184.216.34"}, "AAAA": {"2606:2800:220:1:248:1893:25c8:1946"}, }, Status: "ok", LastChecked: now, }, "ns2.example.com.": { Records: map[string][]string{ "A": {"93.184.216.34"}, }, Status: "ok", LastChecked: now, }, }, LastChecked: now, }) s.SetPortState("93.184.216.34:80", &state.PortState{ Open: true, Hostnames: []string{testHostname}, LastChecked: now, }) s.SetPortState("93.184.216.34:443", &state.PortState{ Open: true, Hostnames: []string{testHostname, "api.example.com"}, LastChecked: now, }) s.SetCertificateState("93.184.216.34:443:www.example.com", &state.CertificateState{ CommonName: testHostname, Issuer: "DigiCert TLS RSA SHA256 2020 CA1", NotAfter: now.Add(365 * 24 * time.Hour), SubjectAlternativeNames: []string{testHostname, "example.com"}, Status: "ok", LastChecked: now, }) } // TestSaveLoadRoundTrip_Domains verifies domain data survives a save/load cycle. func TestSaveLoadRoundTrip_Domains(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) populateState(t, s) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } snap := loaded.GetSnapshot() if len(snap.Domains) != 2 { t.Fatalf("expected 2 domains, got %d", len(snap.Domains)) } dom, ok := snap.Domains["example.com"] if !ok { t.Fatal("missing domain example.com") } if len(dom.Nameservers) != 2 { t.Errorf("expected 2 nameservers, got %d", len(dom.Nameservers)) } } // TestSaveLoadRoundTrip_Hostnames verifies hostname data survives a save/load cycle. func TestSaveLoadRoundTrip_Hostnames(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) populateState(t, s) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } snap := loaded.GetSnapshot() if len(snap.Hostnames) != 1 { t.Fatalf("expected 1 hostname, got %d", len(snap.Hostnames)) } hn, ok := snap.Hostnames[testHostname] if !ok { t.Fatal("missing hostname") } if len(hn.RecordsByNameserver) != 2 { t.Errorf("expected 2 NS entries, got %d", len(hn.RecordsByNameserver)) } verifyNS1Records(t, hn) } // verifyNS1Records checks the ns1.example.com. records in the loaded hostname state. func verifyNS1Records(t *testing.T, hn *state.HostnameState) { t.Helper() ns1, ok := hn.RecordsByNameserver["ns1.example.com."] if !ok { t.Fatal("missing nameserver ns1.example.com.") } aRecords := ns1.Records["A"] if len(aRecords) != 1 || aRecords[0] != "93.184.216.34" { t.Errorf("ns1 A records: got %v", aRecords) } aaaaRecords := ns1.Records["AAAA"] if len(aaaaRecords) != 1 || aaaaRecords[0] != "2606:2800:220:1:248:1893:25c8:1946" { t.Errorf("ns1 AAAA records: got %v", aaaaRecords) } if ns1.Status != "ok" { t.Errorf("ns1 status: got %q, want %q", ns1.Status, "ok") } } // TestSaveLoadRoundTrip_Ports verifies port data survives a save/load cycle. func TestSaveLoadRoundTrip_Ports(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) populateState(t, s) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } snap := loaded.GetSnapshot() if len(snap.Ports) != 2 { t.Fatalf("expected 2 ports, got %d", len(snap.Ports)) } port443, ok := snap.Ports["93.184.216.34:443"] if !ok { t.Fatal("missing port 93.184.216.34:443") } if !port443.Open { t.Error("port 443: expected open") } if len(port443.Hostnames) != 2 { t.Errorf("expected 2 hostnames, got %d", len(port443.Hostnames)) } } // TestSaveLoadRoundTrip_Certificates verifies certificate data survives a save/load cycle. func TestSaveLoadRoundTrip_Certificates(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) populateState(t, s) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } snap := loaded.GetSnapshot() if len(snap.Certificates) != 1 { t.Fatalf("expected 1 certificate, got %d", len(snap.Certificates)) } cert, ok := snap.Certificates["93.184.216.34:443:www.example.com"] if !ok { t.Fatal("missing certificate") } if cert.CommonName != testHostname { t.Errorf("cert CN: got %q", cert.CommonName) } if cert.Issuer != "DigiCert TLS RSA SHA256 2020 CA1" { t.Errorf("cert issuer: got %q", cert.Issuer) } if len(cert.SubjectAlternativeNames) != 2 { t.Errorf("cert SANs: expected 2, got %d", len(cert.SubjectAlternativeNames)) } if cert.Status != "ok" { t.Errorf("cert status: got %q", cert.Status) } } // TestSaveCreatesDirectory verifies Save creates the data directory if missing. func TestSaveCreatesDirectory(t *testing.T) { t.Parallel() dir := t.TempDir() nested := filepath.Join(dir, "nested", "deep") s := state.NewForTestWithDataDir(nested) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } statePath := filepath.Join(nested, "state.json") info, err := os.Stat(statePath) if err != nil { t.Fatalf("state file not found: %v", err) } if info.Size() == 0 { t.Error("state file is empty") } } // TestSaveAtomicWrite verifies the atomic write pattern (no leftover temp files). func TestSaveAtomicWrite(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) populateState(t, s) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } statePath := filepath.Join(dir, "state.json") tmpPath := statePath + ".tmp" // No .tmp file should remain after successful save. _, err = os.Stat(tmpPath) if !os.IsNotExist(err) { t.Error("temp file should not exist after successful save") } // Modify state and save again; verify updated content. s.SetDomainState("new.example.com", &state.DomainState{ Nameservers: []string{"ns1.new.example.com."}, LastChecked: time.Now().UTC(), }) err = s.Save() if err != nil { t.Fatalf("second Save() error: %v", err) } _, err = os.Stat(tmpPath) if !os.IsNotExist(err) { t.Error("temp file should not exist after second save") } // Verify new domain is present by loading. loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } _, ok := loaded.GetDomainState("new.example.com") if !ok { t.Error("new domain not found after second save") } } // TestLoadMissingFile verifies that loading a nonexistent file is not an error. func TestLoadMissingFile(t *testing.T) { t.Parallel() s := state.NewForTestWithDataDir(t.TempDir()) err := s.Load() if err != nil { t.Fatalf("Load() with missing file should not error: %v", err) } snap := s.GetSnapshot() if len(snap.Domains) != 0 { t.Error("expected empty domains") } if len(snap.Hostnames) != 0 { t.Error("expected empty hostnames") } if len(snap.Ports) != 0 { t.Error("expected empty ports") } if len(snap.Certificates) != 0 { t.Error("expected empty certificates") } } // TestLoadCorruptFile verifies that a corrupt state file returns an error. func TestLoadCorruptFile(t *testing.T) { t.Parallel() dir := t.TempDir() err := os.WriteFile( filepath.Join(dir, "state.json"), []byte("not valid json{{{"), 0o600, ) if err != nil { t.Fatalf("writing corrupt file: %v", err) } s := state.NewForTestWithDataDir(dir) err = s.Load() if err == nil { t.Fatal("Load() should return error for corrupt file") } } // TestLoadEmptyFile verifies that an empty state file returns an error. func TestLoadEmptyFile(t *testing.T) { t.Parallel() dir := t.TempDir() err := os.WriteFile(filepath.Join(dir, "state.json"), []byte(""), 0o600) if err != nil { t.Fatalf("writing empty file: %v", err) } s := state.NewForTestWithDataDir(dir) err = s.Load() if err == nil { t.Fatal("Load() should return error for empty file") } } // TestLoadReadPermissionError verifies that an unreadable state file returns an error. func TestLoadReadPermissionError(t *testing.T) { t.Parallel() if os.Getuid() == 0 { t.Skip("skipping permission test when running as root") } dir := t.TempDir() statePath := filepath.Join(dir, "state.json") err := os.WriteFile(statePath, []byte("{}"), 0o600) if err != nil { t.Fatalf("writing file: %v", err) } err = os.Chmod(statePath, 0o000) if err != nil { t.Fatalf("chmod: %v", err) } t.Cleanup(func() { _ = os.Chmod(statePath, 0o600) }) s := state.NewForTestWithDataDir(dir) err = s.Load() if err == nil { t.Fatal("Load() should return error for unreadable file") } } // TestSaveWritePermissionError verifies that Save returns an error // when the directory is not writable. func TestSaveWritePermissionError(t *testing.T) { t.Parallel() if os.Getuid() == 0 { t.Skip("skipping permission test when running as root") } dir := t.TempDir() dataDir := filepath.Join(dir, "readonly") err := os.MkdirAll(dataDir, 0o700) if err != nil { t.Fatalf("creating dir: %v", err) } //nolint:gosec // intentionally making dir read-only+execute for test err = os.Chmod(dataDir, 0o500) if err != nil { t.Fatalf("chmod: %v", err) } t.Cleanup(func() { //nolint:gosec // restoring write permission for cleanup _ = os.Chmod(dataDir, 0o700) }) s := state.NewForTestWithDataDir(dataDir) err = s.Save() if err == nil { t.Fatal("Save() should return error for read-only directory") } } // TestPortStateUnmarshalJSON_NewFormat verifies deserialization of the // current multi-hostname format. func TestPortStateUnmarshalJSON_NewFormat(t *testing.T) { t.Parallel() data := []byte(`{ "open": true, "hostnames": ["www.example.com", "api.example.com"], "lastChecked": "2026-02-19T12:00:00Z" }`) var ps state.PortState err := json.Unmarshal(data, &ps) if err != nil { t.Fatalf("Unmarshal error: %v", err) } if !ps.Open { t.Error("expected open=true") } if len(ps.Hostnames) != 2 { t.Fatalf("expected 2 hostnames, got %d", len(ps.Hostnames)) } if ps.Hostnames[0] != testHostname { t.Errorf("hostname[0]: got %q", ps.Hostnames[0]) } if ps.Hostnames[1] != "api.example.com" { t.Errorf("hostname[1]: got %q", ps.Hostnames[1]) } if ps.LastChecked.IsZero() { t.Error("expected non-zero LastChecked") } } // TestPortStateUnmarshalJSON_OldFormat verifies backward-compatible // deserialization of the old single-hostname format. func TestPortStateUnmarshalJSON_OldFormat(t *testing.T) { t.Parallel() data := []byte(`{ "open": false, "hostname": "legacy.example.com", "lastChecked": "2026-01-15T10:30:00Z" }`) var ps state.PortState err := json.Unmarshal(data, &ps) if err != nil { t.Fatalf("Unmarshal error: %v", err) } if ps.Open { t.Error("expected open=false") } if len(ps.Hostnames) != 1 { t.Fatalf("expected 1 hostname, got %d", len(ps.Hostnames)) } if ps.Hostnames[0] != "legacy.example.com" { t.Errorf("hostname: got %q", ps.Hostnames[0]) } } // TestPortStateUnmarshalJSON_EmptyHostnames verifies that when neither // hostnames nor hostname is present, Hostnames remains nil. func TestPortStateUnmarshalJSON_EmptyHostnames(t *testing.T) { t.Parallel() data := []byte(`{ "open": true, "lastChecked": "2026-02-19T12:00:00Z" }`) var ps state.PortState err := json.Unmarshal(data, &ps) if err != nil { t.Fatalf("Unmarshal error: %v", err) } if len(ps.Hostnames) != 0 { t.Errorf("expected empty hostnames, got %v", ps.Hostnames) } } // TestPortStateUnmarshalJSON_InvalidJSON verifies that invalid JSON returns an error. func TestPortStateUnmarshalJSON_InvalidJSON(t *testing.T) { t.Parallel() data := []byte(`{invalid json}`) var ps state.PortState err := json.Unmarshal(data, &ps) if err == nil { t.Fatal("expected error for invalid JSON") } } // TestPortStateUnmarshalJSON_BothFormats verifies that when both "hostnames" // and "hostname" are present, the new format takes precedence. func TestPortStateUnmarshalJSON_BothFormats(t *testing.T) { t.Parallel() data := []byte(`{ "open": true, "hostnames": ["new.example.com"], "hostname": "old.example.com", "lastChecked": "2026-02-19T12:00:00Z" }`) var ps state.PortState err := json.Unmarshal(data, &ps) if err != nil { t.Fatalf("Unmarshal error: %v", err) } // When hostnames is non-empty, the old hostname field is ignored. if len(ps.Hostnames) != 1 { t.Fatalf("expected 1 hostname, got %d", len(ps.Hostnames)) } if ps.Hostnames[0] != "new.example.com" { t.Errorf("hostname: got %q, want %q", ps.Hostnames[0], "new.example.com") } } // TestGetSnapshot_ReturnsCopy verifies that GetSnapshot returns a value copy. func TestGetSnapshot_ReturnsCopy(t *testing.T) { t.Parallel() s := state.NewForTest() populateState(t, s) snap := s.GetSnapshot() // Mutate the returned snapshot's map. snap.Domains["mutated.example.com"] = &state.DomainState{ Nameservers: []string{"ns1.mutated.com."}, } // The original domain data should still be accessible. _, ok := s.GetDomainState("example.com") if !ok { t.Error("original domain should still exist") } } // TestDomainState_GetSet tests domain state getter and setter. func TestDomainState_GetSet(t *testing.T) { t.Parallel() s := state.NewForTest() // Get on missing key returns false. _, ok := s.GetDomainState("nonexistent.com") if ok { t.Error("expected false for missing domain") } now := time.Now().UTC().Truncate(time.Second) ds := &state.DomainState{ Nameservers: []string{"ns1.test.com."}, LastChecked: now, } s.SetDomainState("test.com", ds) got, ok := s.GetDomainState("test.com") if !ok { t.Fatal("expected true for existing domain") } if len(got.Nameservers) != 1 || got.Nameservers[0] != "ns1.test.com." { t.Errorf("nameservers: got %v", got.Nameservers) } if !got.LastChecked.Equal(now) { t.Errorf("lastChecked: got %v, want %v", got.LastChecked, now) } // Overwrite. ds2 := &state.DomainState{ Nameservers: []string{"ns1.test.com.", "ns2.test.com."}, LastChecked: now.Add(time.Hour), } s.SetDomainState("test.com", ds2) got2, ok := s.GetDomainState("test.com") if !ok { t.Fatal("expected true after overwrite") } if len(got2.Nameservers) != 2 { t.Errorf("expected 2 nameservers after overwrite, got %d", len(got2.Nameservers)) } } // TestHostnameState_GetSet tests hostname state getter and setter. func TestHostnameState_GetSet(t *testing.T) { t.Parallel() s := state.NewForTest() _, ok := s.GetHostnameState("missing.example.com") if ok { t.Error("expected false for missing hostname") } now := time.Now().UTC().Truncate(time.Second) hs := &state.HostnameState{ RecordsByNameserver: map[string]*state.NameserverRecordState{ "ns1.example.com.": { Records: map[string][]string{"A": {"1.2.3.4"}}, Status: "ok", LastChecked: now, }, }, LastChecked: now, } s.SetHostnameState(testHostname, hs) got, ok := s.GetHostnameState(testHostname) if !ok { t.Fatal("expected true for existing hostname") } nsState, ok := got.RecordsByNameserver["ns1.example.com."] if !ok { t.Fatal("missing nameserver entry") } if nsState.Status != "ok" { t.Errorf("status: got %q", nsState.Status) } aRecords := nsState.Records["A"] if len(aRecords) != 1 || aRecords[0] != "1.2.3.4" { t.Errorf("A records: got %v", aRecords) } } // TestPortState_GetSetDelete tests port state getter, setter, and deletion. func TestPortState_GetSetDelete(t *testing.T) { t.Parallel() s := state.NewForTest() _, ok := s.GetPortState("1.2.3.4:80") if ok { t.Error("expected false for missing port") } now := time.Now().UTC().Truncate(time.Second) ps := &state.PortState{ Open: true, Hostnames: []string{testHostname}, LastChecked: now, } s.SetPortState("1.2.3.4:80", ps) got, ok := s.GetPortState("1.2.3.4:80") if !ok { t.Fatal("expected true for existing port") } if !got.Open { t.Error("expected open=true") } // Delete the port state. s.DeletePortState("1.2.3.4:80") _, ok = s.GetPortState("1.2.3.4:80") if ok { t.Error("expected false after delete") } } // TestGetAllPortKeys tests retrieving all port state keys. func TestGetAllPortKeys(t *testing.T) { t.Parallel() s := state.NewForTest() keys := s.GetAllPortKeys() if len(keys) != 0 { t.Errorf("expected 0 keys, got %d", len(keys)) } now := time.Now().UTC() s.SetPortState("1.2.3.4:80", &state.PortState{ Open: true, Hostnames: []string{"a.example.com"}, LastChecked: now, }) s.SetPortState("1.2.3.4:443", &state.PortState{ Open: true, Hostnames: []string{"a.example.com"}, LastChecked: now, }) s.SetPortState("5.6.7.8:80", &state.PortState{ Open: false, Hostnames: []string{"b.example.com"}, LastChecked: now, }) keys = s.GetAllPortKeys() if len(keys) != 3 { t.Fatalf("expected 3 keys, got %d", len(keys)) } keySet := make(map[string]bool) for _, k := range keys { keySet[k] = true } for _, want := range []string{"1.2.3.4:80", "1.2.3.4:443", "5.6.7.8:80"} { if !keySet[want] { t.Errorf("missing key %q", want) } } } // TestCertificateState_GetSet tests certificate state getter and setter. func TestCertificateState_GetSet(t *testing.T) { t.Parallel() s := state.NewForTest() _, ok := s.GetCertificateState("1.2.3.4:443:www.example.com") if ok { t.Error("expected false for missing certificate") } now := time.Now().UTC().Truncate(time.Second) cs := &state.CertificateState{ CommonName: testHostname, Issuer: "Let's Encrypt Authority X3", NotAfter: now.Add(90 * 24 * time.Hour), SubjectAlternativeNames: []string{testHostname}, Status: "ok", LastChecked: now, } s.SetCertificateState("1.2.3.4:443:www.example.com", cs) got, ok := s.GetCertificateState("1.2.3.4:443:www.example.com") if !ok { t.Fatal("expected true for existing certificate") } if got.CommonName != testHostname { t.Errorf("CN: got %q", got.CommonName) } if got.Issuer != "Let's Encrypt Authority X3" { t.Errorf("issuer: got %q", got.Issuer) } if got.Status != "ok" { t.Errorf("status: got %q", got.Status) } if len(got.SubjectAlternativeNames) != 1 { t.Errorf("SANs: expected 1, got %d", len(got.SubjectAlternativeNames)) } } // TestCertificateState_ErrorField tests that the error field round-trips. func TestCertificateState_ErrorField(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) now := time.Now().UTC().Truncate(time.Second) cs := &state.CertificateState{ Status: "error", Error: "connection refused", LastChecked: now, } s.SetCertificateState("10.0.0.1:443:err.example.com", cs) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } got, ok := loaded.GetCertificateState("10.0.0.1:443:err.example.com") if !ok { t.Fatal("missing certificate after load") } if got.Status != "error" { t.Errorf("status: got %q, want %q", got.Status, "error") } if got.Error != "connection refused" { t.Errorf("error field: got %q", got.Error) } } // TestHostnameState_ErrorField tests that hostname error states round-trip. func TestHostnameState_ErrorField(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) now := time.Now().UTC().Truncate(time.Second) hs := &state.HostnameState{ RecordsByNameserver: map[string]*state.NameserverRecordState{ "ns1.example.com.": { Records: nil, Status: "error", Error: "SERVFAIL", LastChecked: now, }, }, LastChecked: now, } s.SetHostnameState("fail.example.com", hs) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } got, ok := loaded.GetHostnameState("fail.example.com") if !ok { t.Fatal("missing hostname after load") } nsState := got.RecordsByNameserver["ns1.example.com."] if nsState.Status != "error" { t.Errorf("status: got %q, want %q", nsState.Status, "error") } if nsState.Error != "SERVFAIL" { t.Errorf("error: got %q, want %q", nsState.Error, "SERVFAIL") } } // TestSnapshotVersion verifies that the snapshot version is written correctly. func TestSnapshotVersion(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } //nolint:gosec // test file with known path data, err := os.ReadFile(filepath.Join(dir, "state.json")) if err != nil { t.Fatalf("reading state file: %v", err) } var snap state.Snapshot err = json.Unmarshal(data, &snap) if err != nil { t.Fatalf("parsing state: %v", err) } if snap.Version != 1 { t.Errorf("version: got %d, want 1", snap.Version) } if snap.LastUpdated.IsZero() { t.Error("lastUpdated should be set after save") } } // TestSaveUpdatesLastUpdated verifies that Save sets LastUpdated. func TestSaveUpdatesLastUpdated(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) before := time.Now().UTC().Add(-time.Second) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } after := time.Now().UTC().Add(time.Second) snap := s.GetSnapshot() if snap.LastUpdated.Before(before) || snap.LastUpdated.After(after) { t.Errorf( "LastUpdated %v not in expected range [%v, %v]", snap.LastUpdated, before, after, ) } } // TestLoadPreservesExistingStateOnMissingFile verifies that when Load is // called and the file doesn't exist, the existing in-memory state is kept. func TestLoadPreservesExistingStateOnMissingFile(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) // Add some data before loading. s.SetDomainState("pre.example.com", &state.DomainState{ Nameservers: []string{"ns1.pre.example.com."}, LastChecked: time.Now().UTC(), }) err := s.Load() if err != nil { t.Fatalf("Load() error: %v", err) } // Since the file doesn't exist, the existing data should persist. _, ok := s.GetDomainState("pre.example.com") if !ok { t.Error("existing domain state should persist when file is missing") } } // TestConcurrentGetSet verifies that concurrent reads and writes don't race. func TestConcurrentGetSet(t *testing.T) { t.Parallel() s := state.NewForTest() const goroutines = 20 var wg sync.WaitGroup wg.Add(goroutines) for i := range goroutines { go func(id int) { defer wg.Done() now := time.Now().UTC() key := string(rune('a' + id%26)) runConcurrentOps(s, key, now) }(i) } wg.Wait() } // runConcurrentOps performs a series of get/set/delete operations for concurrency testing. func runConcurrentOps(s *state.State, key string, now time.Time) { const iterations = 50 for j := range iterations { s.SetDomainState(key+".com", &state.DomainState{ Nameservers: []string{"ns1." + key + ".com."}, LastChecked: now, }) s.GetDomainState(key + ".com") s.SetPortState(key+":80", &state.PortState{ Open: j%2 == 0, Hostnames: []string{key + ".com"}, LastChecked: now, }) s.GetPortState(key + ":80") s.GetAllPortKeys() s.GetSnapshot() s.SetHostnameState(key+".example.com", &state.HostnameState{ RecordsByNameserver: map[string]*state.NameserverRecordState{ "ns1.test.": { Records: map[string][]string{"A": {"1.2.3.4"}}, Status: "ok", LastChecked: now, }, }, LastChecked: now, }) s.GetHostnameState(key + ".example.com") s.SetCertificateState("1.2.3.4:443:"+key+".com", &state.CertificateState{ CommonName: key + ".com", Status: "ok", LastChecked: now, }) s.GetCertificateState("1.2.3.4:443:" + key + ".com") if j%10 == 0 { s.DeletePortState(key + ":80") } } } // TestConcurrentSaveLoad verifies that concurrent Save and Load // operations don't race or corrupt state. func TestConcurrentSaveLoad(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) populateState(t, s) const goroutines = 10 var wg sync.WaitGroup wg.Add(goroutines) for i := range goroutines { go func(id int) { defer wg.Done() if id%2 == 0 { _ = s.Save() } else { _ = s.Load() } }(i) } wg.Wait() } // TestPortStateRoundTripThroughSaveLoad verifies backward-compat deserialization // works through the full Save→Load cycle with old-format data on disk. func TestPortStateRoundTripThroughSaveLoad(t *testing.T) { t.Parallel() dir := t.TempDir() // Write a state file with old single-hostname port format. oldFormatJSON := `{ "version": 1, "lastUpdated": "2026-02-19T12:00:00Z", "domains": {}, "hostnames": {}, "ports": { "10.0.0.1:80": { "open": true, "hostname": "legacy.example.com", "lastChecked": "2026-02-19T12:00:00Z" }, "10.0.0.1:443": { "open": true, "hostnames": ["modern.example.com"], "lastChecked": "2026-02-19T12:00:00Z" } }, "certificates": {} }` err := os.WriteFile(filepath.Join(dir, "state.json"), []byte(oldFormatJSON), 0o600) if err != nil { t.Fatalf("writing old format state: %v", err) } s := state.NewForTestWithDataDir(dir) err = s.Load() if err != nil { t.Fatalf("Load() error: %v", err) } // Verify old format was migrated. ps80, ok := s.GetPortState("10.0.0.1:80") if !ok { t.Fatal("missing port 10.0.0.1:80") } if len(ps80.Hostnames) != 1 || ps80.Hostnames[0] != "legacy.example.com" { t.Errorf("old format port hostnames: got %v", ps80.Hostnames) } // Verify new format loaded correctly. ps443, ok := s.GetPortState("10.0.0.1:443") if !ok { t.Fatal("missing port 10.0.0.1:443") } if len(ps443.Hostnames) != 1 || ps443.Hostnames[0] != "modern.example.com" { t.Errorf("new format port hostnames: got %v", ps443.Hostnames) } } // TestMultipleSavesOverwrite verifies that subsequent saves overwrite the // previous state completely. func TestMultipleSavesOverwrite(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) s.SetDomainState("first.com", &state.DomainState{ Nameservers: []string{"ns1.first.com."}, LastChecked: time.Now().UTC(), }) err := s.Save() if err != nil { t.Fatalf("first Save() error: %v", err) } // Create a brand new state with the same dir and set different data. s2 := state.NewForTestWithDataDir(dir) s2.SetDomainState("second.com", &state.DomainState{ Nameservers: []string{"ns1.second.com."}, LastChecked: time.Now().UTC(), }) err = s2.Save() if err != nil { t.Fatalf("second Save() error: %v", err) } // Load and verify only the second domain exists. loaded := state.NewForTestWithDataDir(dir) err = loaded.Load() if err != nil { t.Fatalf("Load() error: %v", err) } snap := loaded.GetSnapshot() if _, ok := snap.Domains["first.com"]; ok { t.Error("first.com should not exist after overwrite") } if _, ok := snap.Domains["second.com"]; !ok { t.Error("second.com should exist after overwrite") } } // TestNewForTest verifies the test helper creates a valid empty state. func TestNewForTest(t *testing.T) { t.Parallel() s := state.NewForTest() snap := s.GetSnapshot() if snap.Version != 1 { t.Errorf("version: got %d, want 1", snap.Version) } if snap.Domains == nil { t.Error("Domains map should be initialized") } if snap.Hostnames == nil { t.Error("Hostnames map should be initialized") } if snap.Ports == nil { t.Error("Ports map should be initialized") } if snap.Certificates == nil { t.Error("Certificates map should be initialized") } } // TestSaveFilePermissions verifies the saved file has restricted permissions. func TestSaveFilePermissions(t *testing.T) { t.Parallel() dir := t.TempDir() s := state.NewForTestWithDataDir(dir) err := s.Save() if err != nil { t.Fatalf("Save() error: %v", err) } info, err := os.Stat(filepath.Join(dir, "state.json")) if err != nil { t.Fatalf("stat error: %v", err) } perm := info.Mode().Perm() if perm != 0o600 { t.Errorf("file permissions: got %o, want 600", perm) } }