From c5bf16055e22d573c4a603c53078f227ff7c4c9a Mon Sep 17 00:00:00 2001 From: clawbot Date: Wed, 4 Mar 2026 11:26:05 +0100 Subject: [PATCH] test(state): add comprehensive test coverage for internal/state package (#80) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add 32 tests for the `internal/state` package, which previously had 0% test coverage. ### Tests added: **Save/Load round-trip:** - Domain, hostname, port, and certificate data all survive save→load cycles - Error fields (omitempty) round-trip correctly - Backward-compatible PortState deserialization (old single-hostname → new multi-hostname format) **Edge cases:** - Missing state file: returns nil error, keeps existing in-memory state - Corrupt state file: returns parse error - Empty state file: returns parse error - Permission errors (read/write): properly reported, skipped when running as root in Docker **Atomic write:** - No leftover .tmp files after successful save - Updated content verified after second save **Getter/setter coverage:** - Domain: get, set, overwrite - Hostname: get, set with nested nameserver records - Port: get, set, delete - Certificate: get, set - GetAllPortKeys enumeration - GetSnapshot returns value copy **Concurrency:** - 20 goroutines × 50 iterations of concurrent get/set/delete with race detector - 10 goroutines doing concurrent Save/Load **Other:** - Snapshot version written correctly - LastUpdated timestamp set on save - File permissions are 0600 - Multiple saves overwrite previous state completely - NewForTest helper creates valid empty state - Save creates nested data directories Also adds `NewForTestWithDataDir()` to the test helper for tests requiring file persistence. Closes [issue #70](https://git.eeqj.de/sneak/dnswatcher/issues/70) Co-authored-by: clawbot Reviewed-on: https://git.eeqj.de/sneak/dnswatcher/pulls/80 Co-authored-by: clawbot Co-committed-by: clawbot --- internal/state/state_test.go | 1302 +++++++++++++++++++++++++++ internal/state/state_test_helper.go | 16 + 2 files changed, 1318 insertions(+) create mode 100644 internal/state/state_test.go diff --git a/internal/state/state_test.go b/internal/state/state_test.go new file mode 100644 index 0000000..5e95eb9 --- /dev/null +++ b/internal/state/state_test.go @@ -0,0 +1,1302 @@ +package state_test + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "sneak.berlin/go/dnswatcher/internal/state" +) + +const testHostname = "www.example.com" + +// populateState fills a State with representative test data across all categories. +func populateState(t *testing.T, s *state.State) { + t.Helper() + + now := time.Now().UTC().Truncate(time.Second) + + s.SetDomainState("example.com", &state.DomainState{ + Nameservers: []string{"ns1.example.com.", "ns2.example.com."}, + LastChecked: now, + }) + + s.SetDomainState("example.org", &state.DomainState{ + Nameservers: []string{"ns1.example.org."}, + LastChecked: now, + }) + + s.SetHostnameState(testHostname, &state.HostnameState{ + RecordsByNameserver: map[string]*state.NameserverRecordState{ + "ns1.example.com.": { + Records: map[string][]string{ + "A": {"93.184.216.34"}, + "AAAA": {"2606:2800:220:1:248:1893:25c8:1946"}, + }, + Status: "ok", + LastChecked: now, + }, + "ns2.example.com.": { + Records: map[string][]string{ + "A": {"93.184.216.34"}, + }, + Status: "ok", + LastChecked: now, + }, + }, + LastChecked: now, + }) + + s.SetPortState("93.184.216.34:80", &state.PortState{ + Open: true, + Hostnames: []string{testHostname}, + LastChecked: now, + }) + + s.SetPortState("93.184.216.34:443", &state.PortState{ + Open: true, + Hostnames: []string{testHostname, "api.example.com"}, + LastChecked: now, + }) + + s.SetCertificateState("93.184.216.34:443:www.example.com", &state.CertificateState{ + CommonName: testHostname, + Issuer: "DigiCert TLS RSA SHA256 2020 CA1", + NotAfter: now.Add(365 * 24 * time.Hour), + SubjectAlternativeNames: []string{testHostname, "example.com"}, + Status: "ok", + LastChecked: now, + }) +} + +// TestSaveLoadRoundTrip_Domains verifies domain data survives a save/load cycle. +func TestSaveLoadRoundTrip_Domains(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + populateState(t, s) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + snap := loaded.GetSnapshot() + + if len(snap.Domains) != 2 { + t.Fatalf("expected 2 domains, got %d", len(snap.Domains)) + } + + dom, ok := snap.Domains["example.com"] + if !ok { + t.Fatal("missing domain example.com") + } + + if len(dom.Nameservers) != 2 { + t.Errorf("expected 2 nameservers, got %d", len(dom.Nameservers)) + } +} + +// TestSaveLoadRoundTrip_Hostnames verifies hostname data survives a save/load cycle. +func TestSaveLoadRoundTrip_Hostnames(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + populateState(t, s) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + snap := loaded.GetSnapshot() + + if len(snap.Hostnames) != 1 { + t.Fatalf("expected 1 hostname, got %d", len(snap.Hostnames)) + } + + hn, ok := snap.Hostnames[testHostname] + if !ok { + t.Fatal("missing hostname") + } + + if len(hn.RecordsByNameserver) != 2 { + t.Errorf("expected 2 NS entries, got %d", len(hn.RecordsByNameserver)) + } + + verifyNS1Records(t, hn) +} + +// verifyNS1Records checks the ns1.example.com. records in the loaded hostname state. +func verifyNS1Records(t *testing.T, hn *state.HostnameState) { + t.Helper() + + ns1, ok := hn.RecordsByNameserver["ns1.example.com."] + if !ok { + t.Fatal("missing nameserver ns1.example.com.") + } + + aRecords := ns1.Records["A"] + if len(aRecords) != 1 || aRecords[0] != "93.184.216.34" { + t.Errorf("ns1 A records: got %v", aRecords) + } + + aaaaRecords := ns1.Records["AAAA"] + if len(aaaaRecords) != 1 || aaaaRecords[0] != "2606:2800:220:1:248:1893:25c8:1946" { + t.Errorf("ns1 AAAA records: got %v", aaaaRecords) + } + + if ns1.Status != "ok" { + t.Errorf("ns1 status: got %q, want %q", ns1.Status, "ok") + } +} + +// TestSaveLoadRoundTrip_Ports verifies port data survives a save/load cycle. +func TestSaveLoadRoundTrip_Ports(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + populateState(t, s) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + snap := loaded.GetSnapshot() + + if len(snap.Ports) != 2 { + t.Fatalf("expected 2 ports, got %d", len(snap.Ports)) + } + + port443, ok := snap.Ports["93.184.216.34:443"] + if !ok { + t.Fatal("missing port 93.184.216.34:443") + } + + if !port443.Open { + t.Error("port 443: expected open") + } + + if len(port443.Hostnames) != 2 { + t.Errorf("expected 2 hostnames, got %d", len(port443.Hostnames)) + } +} + +// TestSaveLoadRoundTrip_Certificates verifies certificate data survives a save/load cycle. +func TestSaveLoadRoundTrip_Certificates(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + populateState(t, s) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + snap := loaded.GetSnapshot() + + if len(snap.Certificates) != 1 { + t.Fatalf("expected 1 certificate, got %d", len(snap.Certificates)) + } + + cert, ok := snap.Certificates["93.184.216.34:443:www.example.com"] + if !ok { + t.Fatal("missing certificate") + } + + if cert.CommonName != testHostname { + t.Errorf("cert CN: got %q", cert.CommonName) + } + + if cert.Issuer != "DigiCert TLS RSA SHA256 2020 CA1" { + t.Errorf("cert issuer: got %q", cert.Issuer) + } + + if len(cert.SubjectAlternativeNames) != 2 { + t.Errorf("cert SANs: expected 2, got %d", len(cert.SubjectAlternativeNames)) + } + + if cert.Status != "ok" { + t.Errorf("cert status: got %q", cert.Status) + } +} + +// TestSaveCreatesDirectory verifies Save creates the data directory if missing. +func TestSaveCreatesDirectory(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + nested := filepath.Join(dir, "nested", "deep") + + s := state.NewForTestWithDataDir(nested) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + statePath := filepath.Join(nested, "state.json") + + info, err := os.Stat(statePath) + if err != nil { + t.Fatalf("state file not found: %v", err) + } + + if info.Size() == 0 { + t.Error("state file is empty") + } +} + +// TestSaveAtomicWrite verifies the atomic write pattern (no leftover temp files). +func TestSaveAtomicWrite(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + populateState(t, s) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + statePath := filepath.Join(dir, "state.json") + tmpPath := statePath + ".tmp" + + // No .tmp file should remain after successful save. + _, err = os.Stat(tmpPath) + if !os.IsNotExist(err) { + t.Error("temp file should not exist after successful save") + } + + // Modify state and save again; verify updated content. + s.SetDomainState("new.example.com", &state.DomainState{ + Nameservers: []string{"ns1.new.example.com."}, + LastChecked: time.Now().UTC(), + }) + + err = s.Save() + if err != nil { + t.Fatalf("second Save() error: %v", err) + } + + _, err = os.Stat(tmpPath) + if !os.IsNotExist(err) { + t.Error("temp file should not exist after second save") + } + + // Verify new domain is present by loading. + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + _, ok := loaded.GetDomainState("new.example.com") + if !ok { + t.Error("new domain not found after second save") + } +} + +// TestLoadMissingFile verifies that loading a nonexistent file is not an error. +func TestLoadMissingFile(t *testing.T) { + t.Parallel() + + s := state.NewForTestWithDataDir(t.TempDir()) + + err := s.Load() + if err != nil { + t.Fatalf("Load() with missing file should not error: %v", err) + } + + snap := s.GetSnapshot() + + if len(snap.Domains) != 0 { + t.Error("expected empty domains") + } + + if len(snap.Hostnames) != 0 { + t.Error("expected empty hostnames") + } + + if len(snap.Ports) != 0 { + t.Error("expected empty ports") + } + + if len(snap.Certificates) != 0 { + t.Error("expected empty certificates") + } +} + +// TestLoadCorruptFile verifies that a corrupt state file returns an error. +func TestLoadCorruptFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + err := os.WriteFile( + filepath.Join(dir, "state.json"), + []byte("not valid json{{{"), + 0o600, + ) + if err != nil { + t.Fatalf("writing corrupt file: %v", err) + } + + s := state.NewForTestWithDataDir(dir) + + err = s.Load() + if err == nil { + t.Fatal("Load() should return error for corrupt file") + } +} + +// TestLoadEmptyFile verifies that an empty state file returns an error. +func TestLoadEmptyFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + err := os.WriteFile(filepath.Join(dir, "state.json"), []byte(""), 0o600) + if err != nil { + t.Fatalf("writing empty file: %v", err) + } + + s := state.NewForTestWithDataDir(dir) + + err = s.Load() + if err == nil { + t.Fatal("Load() should return error for empty file") + } +} + +// TestLoadReadPermissionError verifies that an unreadable state file returns an error. +func TestLoadReadPermissionError(t *testing.T) { + t.Parallel() + + if os.Getuid() == 0 { + t.Skip("skipping permission test when running as root") + } + + dir := t.TempDir() + statePath := filepath.Join(dir, "state.json") + + err := os.WriteFile(statePath, []byte("{}"), 0o600) + if err != nil { + t.Fatalf("writing file: %v", err) + } + + err = os.Chmod(statePath, 0o000) + if err != nil { + t.Fatalf("chmod: %v", err) + } + + t.Cleanup(func() { + _ = os.Chmod(statePath, 0o600) + }) + + s := state.NewForTestWithDataDir(dir) + + err = s.Load() + if err == nil { + t.Fatal("Load() should return error for unreadable file") + } +} + +// TestSaveWritePermissionError verifies that Save returns an error +// when the directory is not writable. +func TestSaveWritePermissionError(t *testing.T) { + t.Parallel() + + if os.Getuid() == 0 { + t.Skip("skipping permission test when running as root") + } + + dir := t.TempDir() + dataDir := filepath.Join(dir, "readonly") + + err := os.MkdirAll(dataDir, 0o700) + if err != nil { + t.Fatalf("creating dir: %v", err) + } + + //nolint:gosec // intentionally making dir read-only+execute for test + err = os.Chmod(dataDir, 0o500) + if err != nil { + t.Fatalf("chmod: %v", err) + } + + t.Cleanup(func() { + //nolint:gosec // restoring write permission for cleanup + _ = os.Chmod(dataDir, 0o700) + }) + + s := state.NewForTestWithDataDir(dataDir) + + err = s.Save() + if err == nil { + t.Fatal("Save() should return error for read-only directory") + } +} + +// TestPortStateUnmarshalJSON_NewFormat verifies deserialization of the +// current multi-hostname format. +func TestPortStateUnmarshalJSON_NewFormat(t *testing.T) { + t.Parallel() + + data := []byte(`{ + "open": true, + "hostnames": ["www.example.com", "api.example.com"], + "lastChecked": "2026-02-19T12:00:00Z" + }`) + + var ps state.PortState + + err := json.Unmarshal(data, &ps) + if err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if !ps.Open { + t.Error("expected open=true") + } + + if len(ps.Hostnames) != 2 { + t.Fatalf("expected 2 hostnames, got %d", len(ps.Hostnames)) + } + + if ps.Hostnames[0] != testHostname { + t.Errorf("hostname[0]: got %q", ps.Hostnames[0]) + } + + if ps.Hostnames[1] != "api.example.com" { + t.Errorf("hostname[1]: got %q", ps.Hostnames[1]) + } + + if ps.LastChecked.IsZero() { + t.Error("expected non-zero LastChecked") + } +} + +// TestPortStateUnmarshalJSON_OldFormat verifies backward-compatible +// deserialization of the old single-hostname format. +func TestPortStateUnmarshalJSON_OldFormat(t *testing.T) { + t.Parallel() + + data := []byte(`{ + "open": false, + "hostname": "legacy.example.com", + "lastChecked": "2026-01-15T10:30:00Z" + }`) + + var ps state.PortState + + err := json.Unmarshal(data, &ps) + if err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if ps.Open { + t.Error("expected open=false") + } + + if len(ps.Hostnames) != 1 { + t.Fatalf("expected 1 hostname, got %d", len(ps.Hostnames)) + } + + if ps.Hostnames[0] != "legacy.example.com" { + t.Errorf("hostname: got %q", ps.Hostnames[0]) + } +} + +// TestPortStateUnmarshalJSON_EmptyHostnames verifies that when neither +// hostnames nor hostname is present, Hostnames remains nil. +func TestPortStateUnmarshalJSON_EmptyHostnames(t *testing.T) { + t.Parallel() + + data := []byte(`{ + "open": true, + "lastChecked": "2026-02-19T12:00:00Z" + }`) + + var ps state.PortState + + err := json.Unmarshal(data, &ps) + if err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if len(ps.Hostnames) != 0 { + t.Errorf("expected empty hostnames, got %v", ps.Hostnames) + } +} + +// TestPortStateUnmarshalJSON_InvalidJSON verifies that invalid JSON returns an error. +func TestPortStateUnmarshalJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + data := []byte(`{invalid json}`) + + var ps state.PortState + + err := json.Unmarshal(data, &ps) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +// TestPortStateUnmarshalJSON_BothFormats verifies that when both "hostnames" +// and "hostname" are present, the new format takes precedence. +func TestPortStateUnmarshalJSON_BothFormats(t *testing.T) { + t.Parallel() + + data := []byte(`{ + "open": true, + "hostnames": ["new.example.com"], + "hostname": "old.example.com", + "lastChecked": "2026-02-19T12:00:00Z" + }`) + + var ps state.PortState + + err := json.Unmarshal(data, &ps) + if err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + // When hostnames is non-empty, the old hostname field is ignored. + if len(ps.Hostnames) != 1 { + t.Fatalf("expected 1 hostname, got %d", len(ps.Hostnames)) + } + + if ps.Hostnames[0] != "new.example.com" { + t.Errorf("hostname: got %q, want %q", ps.Hostnames[0], "new.example.com") + } +} + +// TestGetSnapshot_ReturnsCopy verifies that GetSnapshot returns a value copy. +func TestGetSnapshot_ReturnsCopy(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + populateState(t, s) + + snap := s.GetSnapshot() + + // Mutate the returned snapshot's map. + snap.Domains["mutated.example.com"] = &state.DomainState{ + Nameservers: []string{"ns1.mutated.com."}, + } + + // The original domain data should still be accessible. + _, ok := s.GetDomainState("example.com") + if !ok { + t.Error("original domain should still exist") + } +} + +// TestDomainState_GetSet tests domain state getter and setter. +func TestDomainState_GetSet(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + // Get on missing key returns false. + _, ok := s.GetDomainState("nonexistent.com") + if ok { + t.Error("expected false for missing domain") + } + + now := time.Now().UTC().Truncate(time.Second) + ds := &state.DomainState{ + Nameservers: []string{"ns1.test.com."}, + LastChecked: now, + } + + s.SetDomainState("test.com", ds) + + got, ok := s.GetDomainState("test.com") + if !ok { + t.Fatal("expected true for existing domain") + } + + if len(got.Nameservers) != 1 || got.Nameservers[0] != "ns1.test.com." { + t.Errorf("nameservers: got %v", got.Nameservers) + } + + if !got.LastChecked.Equal(now) { + t.Errorf("lastChecked: got %v, want %v", got.LastChecked, now) + } + + // Overwrite. + ds2 := &state.DomainState{ + Nameservers: []string{"ns1.test.com.", "ns2.test.com."}, + LastChecked: now.Add(time.Hour), + } + + s.SetDomainState("test.com", ds2) + + got2, ok := s.GetDomainState("test.com") + if !ok { + t.Fatal("expected true after overwrite") + } + + if len(got2.Nameservers) != 2 { + t.Errorf("expected 2 nameservers after overwrite, got %d", len(got2.Nameservers)) + } +} + +// TestHostnameState_GetSet tests hostname state getter and setter. +func TestHostnameState_GetSet(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + _, ok := s.GetHostnameState("missing.example.com") + if ok { + t.Error("expected false for missing hostname") + } + + now := time.Now().UTC().Truncate(time.Second) + hs := &state.HostnameState{ + RecordsByNameserver: map[string]*state.NameserverRecordState{ + "ns1.example.com.": { + Records: map[string][]string{"A": {"1.2.3.4"}}, + Status: "ok", + LastChecked: now, + }, + }, + LastChecked: now, + } + + s.SetHostnameState(testHostname, hs) + + got, ok := s.GetHostnameState(testHostname) + if !ok { + t.Fatal("expected true for existing hostname") + } + + nsState, ok := got.RecordsByNameserver["ns1.example.com."] + if !ok { + t.Fatal("missing nameserver entry") + } + + if nsState.Status != "ok" { + t.Errorf("status: got %q", nsState.Status) + } + + aRecords := nsState.Records["A"] + if len(aRecords) != 1 || aRecords[0] != "1.2.3.4" { + t.Errorf("A records: got %v", aRecords) + } +} + +// TestPortState_GetSetDelete tests port state getter, setter, and deletion. +func TestPortState_GetSetDelete(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + _, ok := s.GetPortState("1.2.3.4:80") + if ok { + t.Error("expected false for missing port") + } + + now := time.Now().UTC().Truncate(time.Second) + ps := &state.PortState{ + Open: true, + Hostnames: []string{testHostname}, + LastChecked: now, + } + + s.SetPortState("1.2.3.4:80", ps) + + got, ok := s.GetPortState("1.2.3.4:80") + if !ok { + t.Fatal("expected true for existing port") + } + + if !got.Open { + t.Error("expected open=true") + } + + // Delete the port state. + s.DeletePortState("1.2.3.4:80") + + _, ok = s.GetPortState("1.2.3.4:80") + if ok { + t.Error("expected false after delete") + } +} + +// TestGetAllPortKeys tests retrieving all port state keys. +func TestGetAllPortKeys(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + keys := s.GetAllPortKeys() + if len(keys) != 0 { + t.Errorf("expected 0 keys, got %d", len(keys)) + } + + now := time.Now().UTC() + + s.SetPortState("1.2.3.4:80", &state.PortState{ + Open: true, Hostnames: []string{"a.example.com"}, LastChecked: now, + }) + + s.SetPortState("1.2.3.4:443", &state.PortState{ + Open: true, Hostnames: []string{"a.example.com"}, LastChecked: now, + }) + + s.SetPortState("5.6.7.8:80", &state.PortState{ + Open: false, Hostnames: []string{"b.example.com"}, LastChecked: now, + }) + + keys = s.GetAllPortKeys() + if len(keys) != 3 { + t.Fatalf("expected 3 keys, got %d", len(keys)) + } + + keySet := make(map[string]bool) + for _, k := range keys { + keySet[k] = true + } + + for _, want := range []string{"1.2.3.4:80", "1.2.3.4:443", "5.6.7.8:80"} { + if !keySet[want] { + t.Errorf("missing key %q", want) + } + } +} + +// TestCertificateState_GetSet tests certificate state getter and setter. +func TestCertificateState_GetSet(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + _, ok := s.GetCertificateState("1.2.3.4:443:www.example.com") + if ok { + t.Error("expected false for missing certificate") + } + + now := time.Now().UTC().Truncate(time.Second) + cs := &state.CertificateState{ + CommonName: testHostname, + Issuer: "Let's Encrypt Authority X3", + NotAfter: now.Add(90 * 24 * time.Hour), + SubjectAlternativeNames: []string{testHostname}, + Status: "ok", + LastChecked: now, + } + + s.SetCertificateState("1.2.3.4:443:www.example.com", cs) + + got, ok := s.GetCertificateState("1.2.3.4:443:www.example.com") + if !ok { + t.Fatal("expected true for existing certificate") + } + + if got.CommonName != testHostname { + t.Errorf("CN: got %q", got.CommonName) + } + + if got.Issuer != "Let's Encrypt Authority X3" { + t.Errorf("issuer: got %q", got.Issuer) + } + + if got.Status != "ok" { + t.Errorf("status: got %q", got.Status) + } + + if len(got.SubjectAlternativeNames) != 1 { + t.Errorf("SANs: expected 1, got %d", len(got.SubjectAlternativeNames)) + } +} + +// TestCertificateState_ErrorField tests that the error field round-trips. +func TestCertificateState_ErrorField(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + now := time.Now().UTC().Truncate(time.Second) + cs := &state.CertificateState{ + Status: "error", + Error: "connection refused", + LastChecked: now, + } + + s.SetCertificateState("10.0.0.1:443:err.example.com", cs) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + got, ok := loaded.GetCertificateState("10.0.0.1:443:err.example.com") + if !ok { + t.Fatal("missing certificate after load") + } + + if got.Status != "error" { + t.Errorf("status: got %q, want %q", got.Status, "error") + } + + if got.Error != "connection refused" { + t.Errorf("error field: got %q", got.Error) + } +} + +// TestHostnameState_ErrorField tests that hostname error states round-trip. +func TestHostnameState_ErrorField(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + now := time.Now().UTC().Truncate(time.Second) + hs := &state.HostnameState{ + RecordsByNameserver: map[string]*state.NameserverRecordState{ + "ns1.example.com.": { + Records: nil, + Status: "error", + Error: "SERVFAIL", + LastChecked: now, + }, + }, + LastChecked: now, + } + + s.SetHostnameState("fail.example.com", hs) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + got, ok := loaded.GetHostnameState("fail.example.com") + if !ok { + t.Fatal("missing hostname after load") + } + + nsState := got.RecordsByNameserver["ns1.example.com."] + if nsState.Status != "error" { + t.Errorf("status: got %q, want %q", nsState.Status, "error") + } + + if nsState.Error != "SERVFAIL" { + t.Errorf("error: got %q, want %q", nsState.Error, "SERVFAIL") + } +} + +// TestSnapshotVersion verifies that the snapshot version is written correctly. +func TestSnapshotVersion(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + //nolint:gosec // test file with known path + data, err := os.ReadFile(filepath.Join(dir, "state.json")) + if err != nil { + t.Fatalf("reading state file: %v", err) + } + + var snap state.Snapshot + + err = json.Unmarshal(data, &snap) + if err != nil { + t.Fatalf("parsing state: %v", err) + } + + if snap.Version != 1 { + t.Errorf("version: got %d, want 1", snap.Version) + } + + if snap.LastUpdated.IsZero() { + t.Error("lastUpdated should be set after save") + } +} + +// TestSaveUpdatesLastUpdated verifies that Save sets LastUpdated. +func TestSaveUpdatesLastUpdated(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + before := time.Now().UTC().Add(-time.Second) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + after := time.Now().UTC().Add(time.Second) + + snap := s.GetSnapshot() + if snap.LastUpdated.Before(before) || snap.LastUpdated.After(after) { + t.Errorf( + "LastUpdated %v not in expected range [%v, %v]", + snap.LastUpdated, before, after, + ) + } +} + +// TestLoadPreservesExistingStateOnMissingFile verifies that when Load is +// called and the file doesn't exist, the existing in-memory state is kept. +func TestLoadPreservesExistingStateOnMissingFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + // Add some data before loading. + s.SetDomainState("pre.example.com", &state.DomainState{ + Nameservers: []string{"ns1.pre.example.com."}, + LastChecked: time.Now().UTC(), + }) + + err := s.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // Since the file doesn't exist, the existing data should persist. + _, ok := s.GetDomainState("pre.example.com") + if !ok { + t.Error("existing domain state should persist when file is missing") + } +} + +// TestConcurrentGetSet verifies that concurrent reads and writes don't race. +func TestConcurrentGetSet(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + const goroutines = 20 + + var wg sync.WaitGroup + + wg.Add(goroutines) + + for i := range goroutines { + go func(id int) { + defer wg.Done() + + now := time.Now().UTC() + key := string(rune('a' + id%26)) + + runConcurrentOps(s, key, now) + }(i) + } + + wg.Wait() +} + +// runConcurrentOps performs a series of get/set/delete operations for concurrency testing. +func runConcurrentOps(s *state.State, key string, now time.Time) { + const iterations = 50 + + for j := range iterations { + s.SetDomainState(key+".com", &state.DomainState{ + Nameservers: []string{"ns1." + key + ".com."}, + LastChecked: now, + }) + + s.GetDomainState(key + ".com") + + s.SetPortState(key+":80", &state.PortState{ + Open: j%2 == 0, Hostnames: []string{key + ".com"}, LastChecked: now, + }) + + s.GetPortState(key + ":80") + s.GetAllPortKeys() + s.GetSnapshot() + + s.SetHostnameState(key+".example.com", &state.HostnameState{ + RecordsByNameserver: map[string]*state.NameserverRecordState{ + "ns1.test.": { + Records: map[string][]string{"A": {"1.2.3.4"}}, + Status: "ok", + LastChecked: now, + }, + }, + LastChecked: now, + }) + + s.GetHostnameState(key + ".example.com") + + s.SetCertificateState("1.2.3.4:443:"+key+".com", &state.CertificateState{ + CommonName: key + ".com", Status: "ok", LastChecked: now, + }) + + s.GetCertificateState("1.2.3.4:443:" + key + ".com") + + if j%10 == 0 { + s.DeletePortState(key + ":80") + } + } +} + +// TestConcurrentSaveLoad verifies that concurrent Save and Load +// operations don't race or corrupt state. +func TestConcurrentSaveLoad(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + populateState(t, s) + + const goroutines = 10 + + var wg sync.WaitGroup + + wg.Add(goroutines) + + for i := range goroutines { + go func(id int) { + defer wg.Done() + + if id%2 == 0 { + _ = s.Save() + } else { + _ = s.Load() + } + }(i) + } + + wg.Wait() +} + +// TestPortStateRoundTripThroughSaveLoad verifies backward-compat deserialization +// works through the full Save→Load cycle with old-format data on disk. +func TestPortStateRoundTripThroughSaveLoad(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + // Write a state file with old single-hostname port format. + oldFormatJSON := `{ + "version": 1, + "lastUpdated": "2026-02-19T12:00:00Z", + "domains": {}, + "hostnames": {}, + "ports": { + "10.0.0.1:80": { + "open": true, + "hostname": "legacy.example.com", + "lastChecked": "2026-02-19T12:00:00Z" + }, + "10.0.0.1:443": { + "open": true, + "hostnames": ["modern.example.com"], + "lastChecked": "2026-02-19T12:00:00Z" + } + }, + "certificates": {} + }` + + err := os.WriteFile(filepath.Join(dir, "state.json"), []byte(oldFormatJSON), 0o600) + if err != nil { + t.Fatalf("writing old format state: %v", err) + } + + s := state.NewForTestWithDataDir(dir) + + err = s.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // Verify old format was migrated. + ps80, ok := s.GetPortState("10.0.0.1:80") + if !ok { + t.Fatal("missing port 10.0.0.1:80") + } + + if len(ps80.Hostnames) != 1 || ps80.Hostnames[0] != "legacy.example.com" { + t.Errorf("old format port hostnames: got %v", ps80.Hostnames) + } + + // Verify new format loaded correctly. + ps443, ok := s.GetPortState("10.0.0.1:443") + if !ok { + t.Fatal("missing port 10.0.0.1:443") + } + + if len(ps443.Hostnames) != 1 || ps443.Hostnames[0] != "modern.example.com" { + t.Errorf("new format port hostnames: got %v", ps443.Hostnames) + } +} + +// TestMultipleSavesOverwrite verifies that subsequent saves overwrite the +// previous state completely. +func TestMultipleSavesOverwrite(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + s.SetDomainState("first.com", &state.DomainState{ + Nameservers: []string{"ns1.first.com."}, + LastChecked: time.Now().UTC(), + }) + + err := s.Save() + if err != nil { + t.Fatalf("first Save() error: %v", err) + } + + // Create a brand new state with the same dir and set different data. + s2 := state.NewForTestWithDataDir(dir) + + s2.SetDomainState("second.com", &state.DomainState{ + Nameservers: []string{"ns1.second.com."}, + LastChecked: time.Now().UTC(), + }) + + err = s2.Save() + if err != nil { + t.Fatalf("second Save() error: %v", err) + } + + // Load and verify only the second domain exists. + loaded := state.NewForTestWithDataDir(dir) + + err = loaded.Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + snap := loaded.GetSnapshot() + + if _, ok := snap.Domains["first.com"]; ok { + t.Error("first.com should not exist after overwrite") + } + + if _, ok := snap.Domains["second.com"]; !ok { + t.Error("second.com should exist after overwrite") + } +} + +// TestNewForTest verifies the test helper creates a valid empty state. +func TestNewForTest(t *testing.T) { + t.Parallel() + + s := state.NewForTest() + + snap := s.GetSnapshot() + + if snap.Version != 1 { + t.Errorf("version: got %d, want 1", snap.Version) + } + + if snap.Domains == nil { + t.Error("Domains map should be initialized") + } + + if snap.Hostnames == nil { + t.Error("Hostnames map should be initialized") + } + + if snap.Ports == nil { + t.Error("Ports map should be initialized") + } + + if snap.Certificates == nil { + t.Error("Certificates map should be initialized") + } +} + +// TestSaveFilePermissions verifies the saved file has restricted permissions. +func TestSaveFilePermissions(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + s := state.NewForTestWithDataDir(dir) + + err := s.Save() + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + info, err := os.Stat(filepath.Join(dir, "state.json")) + if err != nil { + t.Fatalf("stat error: %v", err) + } + + perm := info.Mode().Perm() + if perm != 0o600 { + t.Errorf("file permissions: got %o, want 600", perm) + } +} diff --git a/internal/state/state_test_helper.go b/internal/state/state_test_helper.go index 2142bfb..7e9aac7 100644 --- a/internal/state/state_test_helper.go +++ b/internal/state/state_test_helper.go @@ -20,3 +20,19 @@ func NewForTest() *State { config: &config.Config{DataDir: ""}, } } + +// NewForTestWithDataDir creates a State backed by the given directory +// for tests that need file persistence. +func NewForTestWithDataDir(dataDir string) *State { + return &State{ + log: slog.Default(), + snapshot: &Snapshot{ + Version: stateVersion, + Domains: make(map[string]*DomainState), + Hostnames: make(map[string]*HostnameState), + Ports: make(map[string]*PortState), + Certificates: make(map[string]*CertificateState), + }, + config: &config.Config{DataDir: dataDir}, + } +}