Files
dnswatcher/internal/state/state_test.go
clawbot 0774688c34
All checks were successful
check / check (push) Successful in 51s
test(state): add comprehensive test coverage for internal/state package
Add 32 tests covering:
- Save/Load round-trip for all state types (domains, hostnames, ports, certificates)
- Atomic write verification (no leftover temp files)
- PortState.UnmarshalJSON backward compatibility (old single-hostname format)
- Missing, corrupt, and empty state file handling
- Permission error handling (skipped when running as root)
- All getter/setter/delete methods for every state type
- GetSnapshot returns a value copy
- GetAllPortKeys enumeration
- Concurrent read/write safety with race detection
- Concurrent Save/Load safety
- File permission verification (0600)
- Multiple saves overwrite previous state
- LastUpdated timestamp updates on save
- Error field round-trips for certificates and hostnames
- Snapshot version correctness

Also adds NewForTestWithDataDir() helper for tests requiring file persistence.

Closes #70
2026-03-01 23:54:25 -08:00

1303 lines
29 KiB
Go

package state_test
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"testing"
"time"
"sneak.berlin/go/dnswatcher/internal/state"
)
// testHostname is the hostname shared by the fixtures and assertions in this file.
const testHostname = "www.example.com"
// populateState seeds s with the fixture shared by several tests in this
// file: two domains, one hostname answered by two nameservers, two open
// ports, and one certificate.
func populateState(t *testing.T, s *state.State) {
	t.Helper()

	// A single whole-second UTC timestamp is reused for every entry.
	checkedAt := time.Now().UTC().Truncate(time.Second)

	// Domains.
	exampleCom := &state.DomainState{
		Nameservers: []string{"ns1.example.com.", "ns2.example.com."},
		LastChecked: checkedAt,
	}
	exampleOrg := &state.DomainState{
		Nameservers: []string{"ns1.example.org."},
		LastChecked: checkedAt,
	}

	s.SetDomainState("example.com", exampleCom)
	s.SetDomainState("example.org", exampleOrg)

	// Hostname: ns1 reports A and AAAA records, ns2 reports A only.
	viaNS1 := &state.NameserverRecordState{
		Records: map[string][]string{
			"A":    {"93.184.216.34"},
			"AAAA": {"2606:2800:220:1:248:1893:25c8:1946"},
		},
		Status:      "ok",
		LastChecked: checkedAt,
	}
	viaNS2 := &state.NameserverRecordState{
		Records: map[string][]string{
			"A": {"93.184.216.34"},
		},
		Status:      "ok",
		LastChecked: checkedAt,
	}

	s.SetHostnameState(testHostname, &state.HostnameState{
		RecordsByNameserver: map[string]*state.NameserverRecordState{
			"ns1.example.com.": viaNS1,
			"ns2.example.com.": viaNS2,
		},
		LastChecked: checkedAt,
	})

	// Ports.
	httpPort := &state.PortState{
		Open:        true,
		Hostnames:   []string{testHostname},
		LastChecked: checkedAt,
	}
	httpsPort := &state.PortState{
		Open:        true,
		Hostnames:   []string{testHostname, "api.example.com"},
		LastChecked: checkedAt,
	}

	s.SetPortState("93.184.216.34:80", httpPort)
	s.SetPortState("93.184.216.34:443", httpsPort)

	// Certificate expiring 365 days after the shared timestamp.
	certificate := &state.CertificateState{
		CommonName:              testHostname,
		Issuer:                  "DigiCert TLS RSA SHA256 2020 CA1",
		NotAfter:                checkedAt.Add(365 * 24 * time.Hour),
		SubjectAlternativeNames: []string{testHostname, "example.com"},
		Status:                  "ok",
		LastChecked:             checkedAt,
	}

	s.SetCertificateState("93.184.216.34:443:www.example.com", certificate)
}
// TestSaveLoadRoundTrip_Domains checks that domain entries written by Save
// are read back by Load on a fresh State pointed at the same directory.
func TestSaveLoadRoundTrip_Domains(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()

	original := state.NewForTestWithDataDir(dataDir)
	populateState(t, original)

	saveErr := original.Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	reloaded := state.NewForTestWithDataDir(dataDir)

	loadErr := reloaded.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	domains := reloaded.GetSnapshot().Domains
	if count := len(domains); count != 2 {
		t.Fatalf("expected 2 domains, got %d", count)
	}

	exampleCom, found := domains["example.com"]
	if !found {
		t.Fatal("missing domain example.com")
	}

	if count := len(exampleCom.Nameservers); count != 2 {
		t.Errorf("expected 2 nameservers, got %d", count)
	}
}
// TestSaveLoadRoundTrip_Hostnames checks that the hostname entry and its
// per-nameserver records written by Save are read back by Load.
func TestSaveLoadRoundTrip_Hostnames(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()

	original := state.NewForTestWithDataDir(dataDir)
	populateState(t, original)

	saveErr := original.Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	reloaded := state.NewForTestWithDataDir(dataDir)

	loadErr := reloaded.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	hostnames := reloaded.GetSnapshot().Hostnames
	if count := len(hostnames); count != 1 {
		t.Fatalf("expected 1 hostname, got %d", count)
	}

	entry, found := hostnames[testHostname]
	if !found {
		t.Fatal("missing hostname")
	}

	if count := len(entry.RecordsByNameserver); count != 2 {
		t.Errorf("expected 2 NS entries, got %d", count)
	}

	// Detailed record checks for ns1 live in a helper.
	verifyNS1Records(t, entry)
}
// verifyNS1Records asserts that hn holds the expected A record, AAAA record,
// and status for nameserver ns1.example.com.
func verifyNS1Records(t *testing.T, hn *state.HostnameState) {
	t.Helper()

	entry, present := hn.RecordsByNameserver["ns1.example.com."]
	if !present {
		t.Fatal("missing nameserver ns1.example.com.")
	}

	// Each record type must hold exactly one value.
	ipv4 := entry.Records["A"]
	if !(len(ipv4) == 1 && ipv4[0] == "93.184.216.34") {
		t.Errorf("ns1 A records: got %v", ipv4)
	}

	ipv6 := entry.Records["AAAA"]
	if !(len(ipv6) == 1 && ipv6[0] == "2606:2800:220:1:248:1893:25c8:1946") {
		t.Errorf("ns1 AAAA records: got %v", ipv6)
	}

	if status := entry.Status; status != "ok" {
		t.Errorf("ns1 status: got %q, want %q", status, "ok")
	}
}
// TestSaveLoadRoundTrip_Ports checks that port entries written by Save are
// read back by Load, including the open flag and hostname list of port 443.
func TestSaveLoadRoundTrip_Ports(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()

	original := state.NewForTestWithDataDir(dataDir)
	populateState(t, original)

	saveErr := original.Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	reloaded := state.NewForTestWithDataDir(dataDir)

	loadErr := reloaded.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	ports := reloaded.GetSnapshot().Ports
	if count := len(ports); count != 2 {
		t.Fatalf("expected 2 ports, got %d", count)
	}

	httpsPort, found := ports["93.184.216.34:443"]
	if !found {
		t.Fatal("missing port 93.184.216.34:443")
	}

	if !httpsPort.Open {
		t.Error("port 443: expected open")
	}

	if count := len(httpsPort.Hostnames); count != 2 {
		t.Errorf("expected 2 hostnames, got %d", count)
	}
}
// TestSaveLoadRoundTrip_Certificates checks that the certificate entry
// written by Save is read back by Load with its common name, issuer, SAN
// count, and status intact.
func TestSaveLoadRoundTrip_Certificates(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()

	original := state.NewForTestWithDataDir(dataDir)
	populateState(t, original)

	saveErr := original.Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	reloaded := state.NewForTestWithDataDir(dataDir)

	loadErr := reloaded.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	certificates := reloaded.GetSnapshot().Certificates
	if count := len(certificates); count != 1 {
		t.Fatalf("expected 1 certificate, got %d", count)
	}

	loadedCert, found := certificates["93.184.216.34:443:www.example.com"]
	if !found {
		t.Fatal("missing certificate")
	}

	if name := loadedCert.CommonName; name != testHostname {
		t.Errorf("cert CN: got %q", name)
	}

	if issuer := loadedCert.Issuer; issuer != "DigiCert TLS RSA SHA256 2020 CA1" {
		t.Errorf("cert issuer: got %q", issuer)
	}

	if count := len(loadedCert.SubjectAlternativeNames); count != 2 {
		t.Errorf("cert SANs: expected 2, got %d", count)
	}

	if status := loadedCert.Status; status != "ok" {
		t.Errorf("cert status: got %q", status)
	}
}
// TestSaveCreatesDirectory checks that Save succeeds when the data directory
// does not exist yet and leaves a non-empty state.json inside it.
func TestSaveCreatesDirectory(t *testing.T) {
	t.Parallel()

	// Two path levels below the temp dir, neither of which exists yet.
	missingDir := filepath.Join(t.TempDir(), "nested", "deep")

	saveErr := state.NewForTestWithDataDir(missingDir).Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	fileInfo, statErr := os.Stat(filepath.Join(missingDir, "state.json"))
	if statErr != nil {
		t.Fatalf("state file not found: %v", statErr)
	}

	if fileInfo.Size() == 0 {
		t.Error("state file is empty")
	}
}
// TestSaveAtomicWrite checks that no state.json.tmp file is left behind
// after each of two successful saves, and that the second save's content is
// what a subsequent Load observes.
func TestSaveAtomicWrite(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()
	leftover := filepath.Join(dataDir, "state.json") + ".tmp"

	// expectNoLeftover reports message unless stat says the temp file is absent.
	expectNoLeftover := func(message string) {
		t.Helper()

		_, statErr := os.Stat(leftover)
		if !os.IsNotExist(statErr) {
			t.Error(message)
		}
	}

	st := state.NewForTestWithDataDir(dataDir)
	populateState(t, st)

	firstErr := st.Save()
	if firstErr != nil {
		t.Fatalf("Save() error: %v", firstErr)
	}

	expectNoLeftover("temp file should not exist after successful save")

	// Change the in-memory state so the second save writes different content.
	st.SetDomainState("new.example.com", &state.DomainState{
		Nameservers: []string{"ns1.new.example.com."},
		LastChecked: time.Now().UTC(),
	})

	secondErr := st.Save()
	if secondErr != nil {
		t.Fatalf("second Save() error: %v", secondErr)
	}

	expectNoLeftover("temp file should not exist after second save")

	// A fresh State must see the domain added before the second save.
	reloaded := state.NewForTestWithDataDir(dataDir)

	loadErr := reloaded.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	if _, found := reloaded.GetDomainState("new.example.com"); !found {
		t.Error("new domain not found after second save")
	}
}
// TestLoadMissingFile checks that Load on a directory without a state file
// returns no error and leaves every snapshot map empty.
func TestLoadMissingFile(t *testing.T) {
	t.Parallel()

	st := state.NewForTestWithDataDir(t.TempDir())

	loadErr := st.Load()
	if loadErr != nil {
		t.Fatalf("Load() with missing file should not error: %v", loadErr)
	}

	snapshot := st.GetSnapshot()

	// One row per category, checked in the same order as they are reported.
	checks := []struct {
		size    int
		message string
	}{
		{len(snapshot.Domains), "expected empty domains"},
		{len(snapshot.Hostnames), "expected empty hostnames"},
		{len(snapshot.Ports), "expected empty ports"},
		{len(snapshot.Certificates), "expected empty certificates"},
	}

	for _, check := range checks {
		if check.size != 0 {
			t.Error(check.message)
		}
	}
}
// TestLoadCorruptFile checks that Load reports an error when state.json
// contains text that is not valid JSON.
func TestLoadCorruptFile(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()
	stateFile := filepath.Join(dataDir, "state.json")
	garbage := []byte("not valid json{{{")

	writeErr := os.WriteFile(stateFile, garbage, 0o600)
	if writeErr != nil {
		t.Fatalf("writing corrupt file: %v", writeErr)
	}

	loadErr := state.NewForTestWithDataDir(dataDir).Load()
	if loadErr == nil {
		t.Fatal("Load() should return error for corrupt file")
	}
}
// TestLoadEmptyFile checks that Load reports an error when state.json exists
// but is zero bytes long.
func TestLoadEmptyFile(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()
	stateFile := filepath.Join(dataDir, "state.json")

	writeErr := os.WriteFile(stateFile, []byte(""), 0o600)
	if writeErr != nil {
		t.Fatalf("writing empty file: %v", writeErr)
	}

	loadErr := state.NewForTestWithDataDir(dataDir).Load()
	if loadErr == nil {
		t.Fatal("Load() should return error for empty file")
	}
}
// TestLoadReadPermissionError checks that Load reports an error when the
// state file has mode 000. The test is skipped when the process runs as uid 0.
func TestLoadReadPermissionError(t *testing.T) {
	t.Parallel()

	if os.Getuid() == 0 {
		t.Skip("skipping permission test when running as root")
	}

	dataDir := t.TempDir()
	stateFile := filepath.Join(dataDir, "state.json")

	writeErr := os.WriteFile(stateFile, []byte("{}"), 0o600)
	if writeErr != nil {
		t.Fatalf("writing file: %v", writeErr)
	}

	// Strip every permission bit from the file.
	chmodErr := os.Chmod(stateFile, 0o000)
	if chmodErr != nil {
		t.Fatalf("chmod: %v", chmodErr)
	}

	// Put the owner permissions back once the test finishes.
	t.Cleanup(func() {
		_ = os.Chmod(stateFile, 0o600)
	})

	loadErr := state.NewForTestWithDataDir(dataDir).Load()
	if loadErr == nil {
		t.Fatal("Load() should return error for unreadable file")
	}
}
// TestSaveWritePermissionError checks that Save reports an error when the
// data directory has mode 500 (no write bit). The test is skipped when the
// process runs as uid 0.
func TestSaveWritePermissionError(t *testing.T) {
	t.Parallel()

	if os.Getuid() == 0 {
		t.Skip("skipping permission test when running as root")
	}

	lockedDir := filepath.Join(t.TempDir(), "readonly")

	mkdirErr := os.MkdirAll(lockedDir, 0o700)
	if mkdirErr != nil {
		t.Fatalf("creating dir: %v", mkdirErr)
	}

	//nolint:gosec // intentionally making dir read-only+execute for test
	chmodErr := os.Chmod(lockedDir, 0o500)
	if chmodErr != nil {
		t.Fatalf("chmod: %v", chmodErr)
	}

	// Give the owner write access back once the test finishes.
	t.Cleanup(func() {
		//nolint:gosec // restoring write permission for cleanup
		_ = os.Chmod(lockedDir, 0o700)
	})

	saveErr := state.NewForTestWithDataDir(lockedDir).Save()
	if saveErr == nil {
		t.Fatal("Save() should return error for read-only directory")
	}
}
// TestPortStateUnmarshalJSON_NewFormat decodes a PortState whose JSON uses
// the "hostnames" array and checks the open flag, both hostnames in order,
// and that LastChecked was populated.
func TestPortStateUnmarshalJSON_NewFormat(t *testing.T) {
	t.Parallel()

	const payload = `{
"open": true,
"hostnames": ["www.example.com", "api.example.com"],
"lastChecked": "2026-02-19T12:00:00Z"
}`

	var decoded state.PortState

	decodeErr := json.Unmarshal([]byte(payload), &decoded)
	if decodeErr != nil {
		t.Fatalf("Unmarshal error: %v", decodeErr)
	}

	if !decoded.Open {
		t.Error("expected open=true")
	}

	names := decoded.Hostnames
	if count := len(names); count != 2 {
		t.Fatalf("expected 2 hostnames, got %d", count)
	}

	if first := names[0]; first != testHostname {
		t.Errorf("hostname[0]: got %q", first)
	}

	if second := names[1]; second != "api.example.com" {
		t.Errorf("hostname[1]: got %q", second)
	}

	if decoded.LastChecked.IsZero() {
		t.Error("expected non-zero LastChecked")
	}
}
// TestPortStateUnmarshalJSON_OldFormat decodes a PortState whose JSON uses
// the single "hostname" string and checks that it ends up as the only
// element of Hostnames.
func TestPortStateUnmarshalJSON_OldFormat(t *testing.T) {
	t.Parallel()

	const payload = `{
"open": false,
"hostname": "legacy.example.com",
"lastChecked": "2026-01-15T10:30:00Z"
}`

	var decoded state.PortState

	decodeErr := json.Unmarshal([]byte(payload), &decoded)
	if decodeErr != nil {
		t.Fatalf("Unmarshal error: %v", decodeErr)
	}

	if decoded.Open {
		t.Error("expected open=false")
	}

	names := decoded.Hostnames
	if count := len(names); count != 1 {
		t.Fatalf("expected 1 hostname, got %d", count)
	}

	if only := names[0]; only != "legacy.example.com" {
		t.Errorf("hostname: got %q", only)
	}
}
// TestPortStateUnmarshalJSON_EmptyHostnames decodes a PortState whose JSON
// has neither "hostnames" nor "hostname" and checks that Hostnames has
// length zero.
func TestPortStateUnmarshalJSON_EmptyHostnames(t *testing.T) {
	t.Parallel()

	const payload = `{
"open": true,
"lastChecked": "2026-02-19T12:00:00Z"
}`

	var decoded state.PortState

	decodeErr := json.Unmarshal([]byte(payload), &decoded)
	if decodeErr != nil {
		t.Fatalf("Unmarshal error: %v", decodeErr)
	}

	if names := decoded.Hostnames; len(names) != 0 {
		t.Errorf("expected empty hostnames, got %v", names)
	}
}
// TestPortStateUnmarshalJSON_InvalidJSON checks that decoding malformed JSON
// into a PortState yields an error.
func TestPortStateUnmarshalJSON_InvalidJSON(t *testing.T) {
	t.Parallel()

	var decoded state.PortState

	malformed := []byte(`{invalid json}`)

	decodeErr := json.Unmarshal(malformed, &decoded)
	if decodeErr == nil {
		t.Fatal("expected error for invalid JSON")
	}
}
// TestPortStateUnmarshalJSON_BothFormats decodes a PortState whose JSON
// carries both "hostnames" and "hostname" and checks that only the value
// from "hostnames" is kept.
func TestPortStateUnmarshalJSON_BothFormats(t *testing.T) {
	t.Parallel()

	const payload = `{
"open": true,
"hostnames": ["new.example.com"],
"hostname": "old.example.com",
"lastChecked": "2026-02-19T12:00:00Z"
}`

	var decoded state.PortState

	decodeErr := json.Unmarshal([]byte(payload), &decoded)
	if decodeErr != nil {
		t.Fatalf("Unmarshal error: %v", decodeErr)
	}

	// The legacy field must not add a second entry or replace the first.
	names := decoded.Hostnames
	if count := len(names); count != 1 {
		t.Fatalf("expected 1 hostname, got %d", count)
	}

	if only := names[0]; only != "new.example.com" {
		t.Errorf("hostname: got %q, want %q", only, "new.example.com")
	}
}
// TestGetSnapshot_ReturnsCopy adds an entry to the Domains map of a snapshot
// and then checks that a domain stored beforehand is still retrievable from
// the State.
//
// NOTE(review): the test does not assert that the injected key is absent
// from the State, so it would also pass if the snapshot shared its maps with
// the State — confirm which isolation guarantee is intended.
func TestGetSnapshot_ReturnsCopy(t *testing.T) {
	t.Parallel()

	st := state.NewForTest()
	populateState(t, st)

	snapshot := st.GetSnapshot()

	// Write through the snapshot only; the State's setters are not used.
	injected := &state.DomainState{
		Nameservers: []string{"ns1.mutated.com."},
	}
	snapshot.Domains["mutated.example.com"] = injected

	if _, found := st.GetDomainState("example.com"); !found {
		t.Error("original domain should still exist")
	}
}
// TestDomainState_GetSet exercises GetDomainState and SetDomainState: a
// lookup miss, a store followed by a hit, and an overwrite of the same key.
func TestDomainState_GetSet(t *testing.T) {
	t.Parallel()

	st := state.NewForTest()

	// Nothing has been stored yet, so the lookup must miss.
	if _, found := st.GetDomainState("nonexistent.com"); found {
		t.Error("expected false for missing domain")
	}

	checkedAt := time.Now().UTC().Truncate(time.Second)

	st.SetDomainState("test.com", &state.DomainState{
		Nameservers: []string{"ns1.test.com."},
		LastChecked: checkedAt,
	})

	stored, found := st.GetDomainState("test.com")
	if !found {
		t.Fatal("expected true for existing domain")
	}

	if ns := stored.Nameservers; !(len(ns) == 1 && ns[0] == "ns1.test.com.") {
		t.Errorf("nameservers: got %v", ns)
	}

	if !stored.LastChecked.Equal(checkedAt) {
		t.Errorf("lastChecked: got %v, want %v", stored.LastChecked, checkedAt)
	}

	// Store a second value under the same key.
	st.SetDomainState("test.com", &state.DomainState{
		Nameservers: []string{"ns1.test.com.", "ns2.test.com."},
		LastChecked: checkedAt.Add(time.Hour),
	})

	replaced, found := st.GetDomainState("test.com")
	if !found {
		t.Fatal("expected true after overwrite")
	}

	if count := len(replaced.Nameservers); count != 2 {
		t.Errorf("expected 2 nameservers after overwrite, got %d", count)
	}
}
// TestHostnameState_GetSet exercises GetHostnameState and SetHostnameState:
// a lookup miss, then a store followed by a hit whose nameserver entry is
// inspected.
func TestHostnameState_GetSet(t *testing.T) {
	t.Parallel()

	st := state.NewForTest()

	if _, found := st.GetHostnameState("missing.example.com"); found {
		t.Error("expected false for missing hostname")
	}

	checkedAt := time.Now().UTC().Truncate(time.Second)

	viaNS1 := &state.NameserverRecordState{
		Records:     map[string][]string{"A": {"1.2.3.4"}},
		Status:      "ok",
		LastChecked: checkedAt,
	}

	st.SetHostnameState(testHostname, &state.HostnameState{
		RecordsByNameserver: map[string]*state.NameserverRecordState{
			"ns1.example.com.": viaNS1,
		},
		LastChecked: checkedAt,
	})

	stored, found := st.GetHostnameState(testHostname)
	if !found {
		t.Fatal("expected true for existing hostname")
	}

	entry, found := stored.RecordsByNameserver["ns1.example.com."]
	if !found {
		t.Fatal("missing nameserver entry")
	}

	if status := entry.Status; status != "ok" {
		t.Errorf("status: got %q", status)
	}

	if ipv4 := entry.Records["A"]; !(len(ipv4) == 1 && ipv4[0] == "1.2.3.4") {
		t.Errorf("A records: got %v", ipv4)
	}
}
// TestPortState_GetSetDelete exercises GetPortState, SetPortState, and
// DeletePortState on a single key: miss, store and hit, delete and miss.
func TestPortState_GetSetDelete(t *testing.T) {
	t.Parallel()

	st := state.NewForTest()

	if _, found := st.GetPortState("1.2.3.4:80"); found {
		t.Error("expected false for missing port")
	}

	checkedAt := time.Now().UTC().Truncate(time.Second)

	st.SetPortState("1.2.3.4:80", &state.PortState{
		Open:        true,
		Hostnames:   []string{testHostname},
		LastChecked: checkedAt,
	})

	stored, found := st.GetPortState("1.2.3.4:80")
	if !found {
		t.Fatal("expected true for existing port")
	}

	if !stored.Open {
		t.Error("expected open=true")
	}

	// After deletion the same key must miss again.
	st.DeletePortState("1.2.3.4:80")

	if _, found = st.GetPortState("1.2.3.4:80"); found {
		t.Error("expected false after delete")
	}
}
// TestGetAllPortKeys checks that GetAllPortKeys returns nothing for a fresh
// State and exactly the three stored keys, in any order, afterwards.
func TestGetAllPortKeys(t *testing.T) {
	t.Parallel()

	st := state.NewForTest()

	if count := len(st.GetAllPortKeys()); count != 0 {
		t.Errorf("expected 0 keys, got %d", count)
	}

	checkedAt := time.Now().UTC()

	fixtures := []struct {
		key  string
		port *state.PortState
	}{
		{"1.2.3.4:80", &state.PortState{
			Open: true, Hostnames: []string{"a.example.com"}, LastChecked: checkedAt,
		}},
		{"1.2.3.4:443", &state.PortState{
			Open: true, Hostnames: []string{"a.example.com"}, LastChecked: checkedAt,
		}},
		{"5.6.7.8:80", &state.PortState{
			Open: false, Hostnames: []string{"b.example.com"}, LastChecked: checkedAt,
		}},
	}

	for _, fixture := range fixtures {
		st.SetPortState(fixture.key, fixture.port)
	}

	returned := st.GetAllPortKeys()
	if count := len(returned); count != 3 {
		t.Fatalf("expected 3 keys, got %d", count)
	}

	// Order is not checked, so compare by membership.
	seen := make(map[string]struct{})
	for _, key := range returned {
		seen[key] = struct{}{}
	}

	for _, want := range []string{"1.2.3.4:80", "1.2.3.4:443", "5.6.7.8:80"} {
		if _, present := seen[want]; !present {
			t.Errorf("missing key %q", want)
		}
	}
}
// TestCertificateState_GetSet exercises GetCertificateState and
// SetCertificateState: a lookup miss, then a store followed by a hit whose
// fields are compared with what was stored.
func TestCertificateState_GetSet(t *testing.T) {
	t.Parallel()

	st := state.NewForTest()

	if _, found := st.GetCertificateState("1.2.3.4:443:www.example.com"); found {
		t.Error("expected false for missing certificate")
	}

	checkedAt := time.Now().UTC().Truncate(time.Second)

	st.SetCertificateState("1.2.3.4:443:www.example.com", &state.CertificateState{
		CommonName:              testHostname,
		Issuer:                  "Let's Encrypt Authority X3",
		NotAfter:                checkedAt.Add(90 * 24 * time.Hour),
		SubjectAlternativeNames: []string{testHostname},
		Status:                  "ok",
		LastChecked:             checkedAt,
	})

	stored, found := st.GetCertificateState("1.2.3.4:443:www.example.com")
	if !found {
		t.Fatal("expected true for existing certificate")
	}

	if name := stored.CommonName; name != testHostname {
		t.Errorf("CN: got %q", name)
	}

	if issuer := stored.Issuer; issuer != "Let's Encrypt Authority X3" {
		t.Errorf("issuer: got %q", issuer)
	}

	if status := stored.Status; status != "ok" {
		t.Errorf("status: got %q", status)
	}

	if count := len(stored.SubjectAlternativeNames); count != 1 {
		t.Errorf("SANs: expected 1, got %d", count)
	}
}
// TestCertificateState_ErrorField checks that a certificate entry carrying
// an "error" status and an error message keeps both across Save and Load.
func TestCertificateState_ErrorField(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()
	checkedAt := time.Now().UTC().Truncate(time.Second)

	original := state.NewForTestWithDataDir(dataDir)
	original.SetCertificateState("10.0.0.1:443:err.example.com", &state.CertificateState{
		Status:      "error",
		Error:       "connection refused",
		LastChecked: checkedAt,
	})

	saveErr := original.Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	reloaded := state.NewForTestWithDataDir(dataDir)

	loadErr := reloaded.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	stored, found := reloaded.GetCertificateState("10.0.0.1:443:err.example.com")
	if !found {
		t.Fatal("missing certificate after load")
	}

	if status := stored.Status; status != "error" {
		t.Errorf("status: got %q, want %q", status, "error")
	}

	if message := stored.Error; message != "connection refused" {
		t.Errorf("error field: got %q", message)
	}
}
// TestHostnameState_ErrorField checks that a per-nameserver entry carrying
// an "error" status and an error message keeps both across Save and Load.
func TestHostnameState_ErrorField(t *testing.T) {
	t.Parallel()

	dir := t.TempDir()
	s := state.NewForTestWithDataDir(dir)
	now := time.Now().UTC().Truncate(time.Second)

	hs := &state.HostnameState{
		RecordsByNameserver: map[string]*state.NameserverRecordState{
			"ns1.example.com.": {
				Records:     nil,
				Status:      "error",
				Error:       "SERVFAIL",
				LastChecked: now,
			},
		},
		LastChecked: now,
	}
	s.SetHostnameState("fail.example.com", hs)

	err := s.Save()
	if err != nil {
		t.Fatalf("Save() error: %v", err)
	}

	loaded := state.NewForTestWithDataDir(dir)

	err = loaded.Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}

	got, ok := loaded.GetHostnameState("fail.example.com")
	if !ok {
		t.Fatal("missing hostname after load")
	}

	// Indexing a map with an absent key yields a nil pointer here, so check
	// presence before dereferencing; otherwise a lost entry would panic the
	// test binary instead of producing a readable failure.
	nsState, ok := got.RecordsByNameserver["ns1.example.com."]
	if !ok || nsState == nil {
		t.Fatal("missing nameserver entry after load")
	}

	if nsState.Status != "error" {
		t.Errorf("status: got %q, want %q", nsState.Status, "error")
	}

	if nsState.Error != "SERVFAIL" {
		t.Errorf("error: got %q, want %q", nsState.Error, "SERVFAIL")
	}
}
// TestSnapshotVersion saves an empty State, decodes state.json directly, and
// checks that the file records version 1 and a non-zero LastUpdated.
func TestSnapshotVersion(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()

	saveErr := state.NewForTestWithDataDir(dataDir).Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	stateFile := filepath.Join(dataDir, "state.json")

	//nolint:gosec // test file with known path
	raw, readErr := os.ReadFile(stateFile)
	if readErr != nil {
		t.Fatalf("reading state file: %v", readErr)
	}

	// Decode the bytes ourselves rather than going through Load.
	var onDisk state.Snapshot

	parseErr := json.Unmarshal(raw, &onDisk)
	if parseErr != nil {
		t.Fatalf("parsing state: %v", parseErr)
	}

	if version := onDisk.Version; version != 1 {
		t.Errorf("version: got %d, want 1", version)
	}

	if onDisk.LastUpdated.IsZero() {
		t.Error("lastUpdated should be set after save")
	}
}
// TestSaveUpdatesLastUpdated checks that after Save the in-memory snapshot's
// LastUpdated lies within a window spanning one second before the call to
// one second after it.
func TestSaveUpdatesLastUpdated(t *testing.T) {
	t.Parallel()

	st := state.NewForTestWithDataDir(t.TempDir())

	// The bounds are widened by a second on each side.
	earliest := time.Now().UTC().Add(-time.Second)

	saveErr := st.Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	latest := time.Now().UTC().Add(time.Second)

	stamped := st.GetSnapshot().LastUpdated

	withinWindow := !stamped.Before(earliest) && !stamped.After(latest)
	if !withinWindow {
		t.Errorf(
			"LastUpdated %v not in expected range [%v, %v]",
			stamped, earliest, latest,
		)
	}
}
// TestLoadPreservesExistingStateOnMissingFile checks that a domain set in
// memory is still present after Load runs against a directory that has no
// state file.
func TestLoadPreservesExistingStateOnMissingFile(t *testing.T) {
	t.Parallel()

	st := state.NewForTestWithDataDir(t.TempDir())

	// Seed in-memory data; nothing is written to disk.
	seeded := &state.DomainState{
		Nameservers: []string{"ns1.pre.example.com."},
		LastChecked: time.Now().UTC(),
	}
	st.SetDomainState("pre.example.com", seeded)

	loadErr := st.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	// With no file to read, Load must leave the seeded entry in place.
	if _, found := st.GetDomainState("pre.example.com"); !found {
		t.Error("existing domain state should persist when file is missing")
	}
}
// TestConcurrentGetSet runs 20 goroutines that each hammer the State's
// getters, setters, and delete through runConcurrentOps, and waits for all
// of them to finish. It makes no assertions of its own.
func TestConcurrentGetSet(t *testing.T) {
	t.Parallel()

	st := state.NewForTest()

	const workers = 20

	var pending sync.WaitGroup

	for worker := range workers {
		pending.Add(1)

		go func() {
			defer pending.Done()

			startedAt := time.Now().UTC()
			// Single-letter key derived from the worker index, cycling a..z.
			letter := string(rune('a' + worker%26))

			runConcurrentOps(st, letter, startedAt)
		}()
	}

	pending.Wait()
}
// runConcurrentOps runs 50 rounds of set/get calls for the domain, port,
// hostname, and certificate entries derived from key, plus GetAllPortKeys
// and GetSnapshot, deleting the port entry on every tenth round.
func runConcurrentOps(s *state.State, key string, now time.Time) {
	const iterations = 50

	// Keys derived from key; they are identical on every round.
	domainKey := key + ".com"
	portKey := key + ":80"
	hostKey := key + ".example.com"
	certKey := "1.2.3.4:443:" + key + ".com"

	for round := range iterations {
		s.SetDomainState(domainKey, &state.DomainState{
			Nameservers: []string{"ns1." + key + ".com."},
			LastChecked: now,
		})
		s.GetDomainState(domainKey)

		// The open flag alternates between rounds.
		isOpen := round%2 == 0
		s.SetPortState(portKey, &state.PortState{
			Open: isOpen, Hostnames: []string{key + ".com"}, LastChecked: now,
		})
		s.GetPortState(portKey)
		s.GetAllPortKeys()
		s.GetSnapshot()

		s.SetHostnameState(hostKey, &state.HostnameState{
			RecordsByNameserver: map[string]*state.NameserverRecordState{
				"ns1.test.": {
					Records:     map[string][]string{"A": {"1.2.3.4"}},
					Status:      "ok",
					LastChecked: now,
				},
			},
			LastChecked: now,
		})
		s.GetHostnameState(hostKey)

		s.SetCertificateState(certKey, &state.CertificateState{
			CommonName: key + ".com", Status: "ok", LastChecked: now,
		})
		s.GetCertificateState(certKey)

		if round%10 == 0 {
			s.DeletePortState(portKey)
		}
	}
}
// TestConcurrentSaveLoad runs ten goroutines against one State, alternating
// between Save and Load, and waits for all of them to finish.
//
// NOTE(review): the returned errors are discarded and nothing is asserted
// afterwards, so this presumably relies on the race detector to report
// problems — confirm.
func TestConcurrentSaveLoad(t *testing.T) {
	t.Parallel()

	st := state.NewForTestWithDataDir(t.TempDir())
	populateState(t, st)

	const workers = 10

	var pending sync.WaitGroup

	for worker := range workers {
		// Even-numbered workers save, odd-numbered workers load.
		operation := st.Load
		if worker%2 == 0 {
			operation = st.Save
		}

		pending.Add(1)

		go func() {
			defer pending.Done()

			_ = operation()
		}()
	}

	pending.Wait()
}
// TestPortStateRoundTripThroughSaveLoad writes a state file by hand that
// mixes a port in the old single-"hostname" format with one in the current
// "hostnames" format, then checks that Load exposes both with a one-element
// Hostnames slice.
func TestPortStateRoundTripThroughSaveLoad(t *testing.T) {
	t.Parallel()

	// Port 80 uses the legacy field; port 443 uses the current one.
	const onDisk = `{
"version": 1,
"lastUpdated": "2026-02-19T12:00:00Z",
"domains": {},
"hostnames": {},
"ports": {
"10.0.0.1:80": {
"open": true,
"hostname": "legacy.example.com",
"lastChecked": "2026-02-19T12:00:00Z"
},
"10.0.0.1:443": {
"open": true,
"hostnames": ["modern.example.com"],
"lastChecked": "2026-02-19T12:00:00Z"
}
},
"certificates": {}
}`

	dataDir := t.TempDir()
	stateFile := filepath.Join(dataDir, "state.json")

	writeErr := os.WriteFile(stateFile, []byte(onDisk), 0o600)
	if writeErr != nil {
		t.Fatalf("writing old format state: %v", writeErr)
	}

	st := state.NewForTestWithDataDir(dataDir)

	loadErr := st.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	// Legacy entry: the single hostname must appear as the only element.
	legacy, found := st.GetPortState("10.0.0.1:80")
	if !found {
		t.Fatal("missing port 10.0.0.1:80")
	}

	if names := legacy.Hostnames; !(len(names) == 1 && names[0] == "legacy.example.com") {
		t.Errorf("old format port hostnames: got %v", names)
	}

	// Current-format entry: loaded as written.
	modern, found := st.GetPortState("10.0.0.1:443")
	if !found {
		t.Fatal("missing port 10.0.0.1:443")
	}

	if names := modern.Hostnames; !(len(names) == 1 && names[0] == "modern.example.com") {
		t.Errorf("new format port hostnames: got %v", names)
	}
}
// TestMultipleSavesOverwrite saves two independent State values to the same
// directory, one after the other, and checks that a later Load sees only the
// domain from the second save.
func TestMultipleSavesOverwrite(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()

	firstWriter := state.NewForTestWithDataDir(dataDir)
	firstWriter.SetDomainState("first.com", &state.DomainState{
		Nameservers: []string{"ns1.first.com."},
		LastChecked: time.Now().UTC(),
	})

	firstErr := firstWriter.Save()
	if firstErr != nil {
		t.Fatalf("first Save() error: %v", firstErr)
	}

	// The second writer never loads, so it knows nothing about first.com.
	secondWriter := state.NewForTestWithDataDir(dataDir)
	secondWriter.SetDomainState("second.com", &state.DomainState{
		Nameservers: []string{"ns1.second.com."},
		LastChecked: time.Now().UTC(),
	})

	secondErr := secondWriter.Save()
	if secondErr != nil {
		t.Fatalf("second Save() error: %v", secondErr)
	}

	reader := state.NewForTestWithDataDir(dataDir)

	loadErr := reader.Load()
	if loadErr != nil {
		t.Fatalf("Load() error: %v", loadErr)
	}

	domains := reader.GetSnapshot().Domains

	if _, stale := domains["first.com"]; stale {
		t.Error("first.com should not exist after overwrite")
	}

	if _, current := domains["second.com"]; !current {
		t.Error("second.com should exist after overwrite")
	}
}
// TestNewForTest checks that NewForTest yields a State whose snapshot has
// version 1 and four non-nil maps.
func TestNewForTest(t *testing.T) {
	t.Parallel()

	snapshot := state.NewForTest().GetSnapshot()

	if version := snapshot.Version; version != 1 {
		t.Errorf("version: got %d, want 1", version)
	}

	// One row per map, checked in the same order as they are reported.
	checks := []struct {
		unset   bool
		message string
	}{
		{snapshot.Domains == nil, "Domains map should be initialized"},
		{snapshot.Hostnames == nil, "Hostnames map should be initialized"},
		{snapshot.Ports == nil, "Ports map should be initialized"},
		{snapshot.Certificates == nil, "Certificates map should be initialized"},
	}

	for _, check := range checks {
		if check.unset {
			t.Error(check.message)
		}
	}
}
// TestSaveFilePermissions checks that the state.json written by Save has
// permission bits 0600.
func TestSaveFilePermissions(t *testing.T) {
	t.Parallel()

	dataDir := t.TempDir()

	saveErr := state.NewForTestWithDataDir(dataDir).Save()
	if saveErr != nil {
		t.Fatalf("Save() error: %v", saveErr)
	}

	fileInfo, statErr := os.Stat(filepath.Join(dataDir, "state.json"))
	if statErr != nil {
		t.Fatalf("stat error: %v", statErr)
	}

	// Compare only the permission bits, not the file-type bits.
	if mode := fileInfo.Mode().Perm(); mode != 0o600 {
		t.Errorf("file permissions: got %o, want 600", mode)
	}
}