From 1e7892678cc84cbc972b76edfc5d8861d2483701 Mon Sep 17 00:00:00 2001 From: clawbot Date: Tue, 10 Feb 2026 18:20:45 -0800 Subject: [PATCH] test: add comprehensive test suite - Integration tests for all API endpoints (session, state, channels, messages) - Tests for all commands: PRIVMSG, JOIN, PART, NICK, TOPIC, QUIT, PING - Edge cases: duplicate nick, empty/invalid inputs, malformed JSON, bad auth - Long-poll tests: delivery on notify and timeout behavior - DM tests: delivery to recipient, echo to sender, nonexistent user - Ephemeral channel cleanup test - Concurrent session creation test - Nick broadcast to channel members test - DB unit tests: all CRUD operations, message queue, history - Broker unit tests: wait/notify, remove, concurrent access --- internal/broker/broker_test.go | 94 ++++ internal/db/queries_test.go | 338 +++++++++++++ internal/handlers/api_test.go | 896 +++++++++++++++++++++++++++++++++ 3 files changed, 1328 insertions(+) create mode 100644 internal/broker/broker_test.go create mode 100644 internal/db/queries_test.go create mode 100644 internal/handlers/api_test.go diff --git a/internal/broker/broker_test.go b/internal/broker/broker_test.go new file mode 100644 index 0000000..541d74d --- /dev/null +++ b/internal/broker/broker_test.go @@ -0,0 +1,94 @@ +package broker + +import ( + "sync" + "testing" + "time" +) + +func TestNewBroker(t *testing.T) { + b := New() + if b == nil { + t.Fatal("expected non-nil broker") + } +} + +func TestWaitAndNotify(t *testing.T) { + b := New() + ch := b.Wait(1) + + go func() { + time.Sleep(10 * time.Millisecond) + b.Notify(1) + }() + + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatal("timeout") + } +} + +func TestNotifyWithoutWaiters(t *testing.T) { + b := New() + b.Notify(42) // should not panic +} + +func TestRemove(t *testing.T) { + b := New() + ch := b.Wait(1) + b.Remove(1, ch) + + b.Notify(1) + select { + case <-ch: + t.Fatal("should not receive after remove") + case <-time.After(50 * time.Millisecond): + } +} + +func TestMultipleWaiters(t *testing.T) { + b := New() + ch1 := b.Wait(1) + ch2 := b.Wait(1) + + b.Notify(1) + + select { + case <-ch1: + case <-time.After(time.Second): + t.Fatal("ch1 timeout") + } + select { + case <-ch2: + case <-time.After(time.Second): + t.Fatal("ch2 timeout") + } +} + +func TestConcurrentWaitNotify(t *testing.T) { + b := New() + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + go func(uid int64) { + defer wg.Done() + ch := b.Wait(uid) + b.Notify(uid) + select { + case <-ch: + case <-time.After(time.Second): + t.Error("timeout") + } + }(int64(i % 10)) + } + + wg.Wait() +} + +func TestRemoveNonexistent(t *testing.T) { + b := New() + ch := make(chan struct{}, 1) + b.Remove(999, ch) // should not panic +} diff --git a/internal/db/queries_test.go b/internal/db/queries_test.go new file mode 100644 index 0000000..76fa378 --- /dev/null +++ b/internal/db/queries_test.go @@ -0,0 +1,338 @@ +package db + +import ( + "context" + "database/sql" + "encoding/json" + "log/slog" + "testing" + + _ "modernc.org/sqlite" +) + +func setupTestDB(t *testing.T) *Database { + t.Helper() + d, err := sql.Open("sqlite", "file::memory:?cache=shared&_pragma=foreign_keys(1)") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { d.Close() }) + + db := &Database{db: d, log: slog.Default()} + if err := db.runMigrations(context.Background()); err != nil { + t.Fatal(err) + } + return db +} + +func TestCreateUser(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + id, token, err := db.CreateUser(ctx, "alice") + if err != nil { + t.Fatal(err) + } + if id == 0 || token == "" { + t.Fatal("expected valid id and token") + } + + // Duplicate nick + _, _, err = db.CreateUser(ctx, "alice") + if err == nil { + t.Fatal("expected error for duplicate nick") + } +} + +func TestGetUserByToken(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + _, token, _ := db.CreateUser(ctx, "bob") + id, nick, err := db.GetUserByToken(ctx, token) + if err != nil { + t.Fatal(err) + } + if nick != "bob" || id == 0 { + t.Fatalf("expected bob, got %s", nick) + } + + // Invalid token + _, _, err = db.GetUserByToken(ctx, "badtoken") + if err == nil { + t.Fatal("expected error for bad token") + } +} + +func TestGetUserByNick(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + db.CreateUser(ctx, "charlie") + id, err := db.GetUserByNick(ctx, "charlie") + if err != nil || id == 0 { + t.Fatal("expected to find charlie") + } + + _, err = db.GetUserByNick(ctx, "nobody") + if err == nil { + t.Fatal("expected error for unknown nick") + } +} + +func TestChannelOperations(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create channel + chID, err := db.GetOrCreateChannel(ctx, "#test") + if err != nil || chID == 0 { + t.Fatal("expected channel id") + } + + // Get same channel + chID2, err := db.GetOrCreateChannel(ctx, "#test") + if err != nil || chID2 != chID { + t.Fatal("expected same channel id") + } + + // GetChannelByName + chID3, err := db.GetChannelByName(ctx, "#test") + if err != nil || chID3 != chID { + t.Fatal("expected same channel id from GetChannelByName") + } + + // Nonexistent channel + _, err = db.GetChannelByName(ctx, "#nope") + if err == nil { + t.Fatal("expected error for nonexistent channel") + } +} + +func TestJoinAndPart(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + uid, _, _ := db.CreateUser(ctx, "user1") + chID, _ := db.GetOrCreateChannel(ctx, "#chan") + + // Join + if err := db.JoinChannel(ctx, chID, uid); err != nil { + t.Fatal(err) + } + + // Verify membership + ids, err := db.GetChannelMemberIDs(ctx, chID) + if err != nil || len(ids) != 1 || ids[0] != uid { + t.Fatal("expected user in channel") + } + + // Double join (should be ignored) + if err := db.JoinChannel(ctx, chID, uid); err != nil { + t.Fatal(err) + } + + // Part + if err := db.PartChannel(ctx, chID, uid); err != nil { + t.Fatal(err) + } + + ids, _ = db.GetChannelMemberIDs(ctx, chID) + if len(ids) != 0 { + t.Fatal("expected empty channel") + } +} + +func TestDeleteChannelIfEmpty(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + chID, _ := db.GetOrCreateChannel(ctx, "#empty") + uid, _, _ := db.CreateUser(ctx, "temp") + db.JoinChannel(ctx, chID, uid) + db.PartChannel(ctx, chID, uid) + + if err := db.DeleteChannelIfEmpty(ctx, chID); err != nil { + t.Fatal(err) + } + + _, err := db.GetChannelByName(ctx, "#empty") + if err == nil { + t.Fatal("expected channel to be deleted") + } +} + +func TestListChannels(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + uid, _, _ := db.CreateUser(ctx, "lister") + ch1, _ := db.GetOrCreateChannel(ctx, "#a") + ch2, _ := db.GetOrCreateChannel(ctx, "#b") + db.JoinChannel(ctx, ch1, uid) + db.JoinChannel(ctx, ch2, uid) + + channels, err := db.ListChannels(ctx, uid) + if err != nil || len(channels) != 2 { + t.Fatalf("expected 2 channels, got %d", len(channels)) + } +} + +func TestListAllChannels(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + db.GetOrCreateChannel(ctx, "#x") + db.GetOrCreateChannel(ctx, "#y") + + channels, err := db.ListAllChannels(ctx) + if err != nil || len(channels) < 2 { + t.Fatalf("expected >= 2 channels, got %d", len(channels)) + } +} + +func TestChangeNick(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + uid, token, _ := db.CreateUser(ctx, "old") + if err := db.ChangeNick(ctx, uid, "new"); err != nil { + t.Fatal(err) + } + + _, nick, _ := db.GetUserByToken(ctx, token) + if nick != "new" { + t.Fatalf("expected new, got %s", nick) + } +} + +func TestSetTopic(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + db.GetOrCreateChannel(ctx, "#topictest") + if err := db.SetTopic(ctx, "#topictest", "Hello"); err != nil { + t.Fatal(err) + } + + channels, _ := db.ListAllChannels(ctx) + for _, ch := range channels { + if ch.Name == "#topictest" && ch.Topic != "Hello" { + t.Fatalf("expected topic Hello, got %s", ch.Topic) + } + } +} + +func TestInsertAndPollMessages(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + uid, _, _ := db.CreateUser(ctx, "poller") + body := json.RawMessage(`["hello"]`) + + dbID, uuid, err := db.InsertMessage(ctx, "PRIVMSG", "poller", "#test", body, nil) + if err != nil || dbID == 0 || uuid == "" { + t.Fatal("insert failed") + } + + if err := db.EnqueueMessage(ctx, uid, dbID); err != nil { + t.Fatal(err) + } + + msgs, lastQID, err := db.PollMessages(ctx, uid, 0, 10) + if err != nil { + t.Fatal(err) + } + if len(msgs) != 1 { + t.Fatalf("expected 1 message, got %d", len(msgs)) + } + if msgs[0].Command != "PRIVMSG" { + t.Fatalf("expected PRIVMSG, got %s", msgs[0].Command) + } + if lastQID == 0 { + t.Fatal("expected nonzero lastQID") + } + + // Poll again with lastQID - should be empty + msgs, _, _ = db.PollMessages(ctx, uid, lastQID, 10) + if len(msgs) != 0 { + t.Fatalf("expected 0 messages, got %d", len(msgs)) + } +} + +func TestGetHistory(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + for i := 0; i < 10; i++ { + db.InsertMessage(ctx, "PRIVMSG", "user", "#hist", json.RawMessage(`["msg"]`), nil) + } + + msgs, err := db.GetHistory(ctx, "#hist", 0, 5) + if err != nil { + t.Fatal(err) + } + if len(msgs) != 5 { + t.Fatalf("expected 5, got %d", len(msgs)) + } + // Should be ascending order + if msgs[0].DBID > msgs[4].DBID { + t.Fatal("expected ascending order") + } +} + +func TestDeleteUser(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + uid, _, _ := db.CreateUser(ctx, "deleteme") + chID, _ := db.GetOrCreateChannel(ctx, "#delchan") + db.JoinChannel(ctx, chID, uid) + + if err := db.DeleteUser(ctx, uid); err != nil { + t.Fatal(err) + } + + _, err := db.GetUserByNick(ctx, "deleteme") + if err == nil { + t.Fatal("user should be deleted") + } + + // Channel membership should be cleaned up via CASCADE + ids, _ := db.GetChannelMemberIDs(ctx, chID) + if len(ids) != 0 { + t.Fatal("expected no members after user deletion") + } +} + +func TestChannelMembers(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + uid1, _, _ := db.CreateUser(ctx, "m1") + uid2, _, _ := db.CreateUser(ctx, "m2") + chID, _ := db.GetOrCreateChannel(ctx, "#members") + db.JoinChannel(ctx, chID, uid1) + db.JoinChannel(ctx, chID, uid2) + + members, err := db.ChannelMembers(ctx, chID) + if err != nil || len(members) != 2 { + t.Fatalf("expected 2 members, got %d", len(members)) + } +} + +func TestGetAllChannelMembershipsForUser(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + uid, _, _ := db.CreateUser(ctx, "multi") + ch1, _ := db.GetOrCreateChannel(ctx, "#m1") + ch2, _ := db.GetOrCreateChannel(ctx, "#m2") + db.JoinChannel(ctx, ch1, uid) + db.JoinChannel(ctx, ch2, uid) + + channels, err := db.GetAllChannelMembershipsForUser(ctx, uid) + if err != nil || len(channels) != 2 { + t.Fatalf("expected 2 channels, got %d", len(channels)) + } +} diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go new file mode 100644 index 0000000..43c145c --- /dev/null +++ b/internal/handlers/api_test.go @@ -0,0 +1,896 @@ +package handlers_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "git.eeqj.de/sneak/chat/internal/broker" + "git.eeqj.de/sneak/chat/internal/config" + "git.eeqj.de/sneak/chat/internal/db" + "git.eeqj.de/sneak/chat/internal/globals" + "git.eeqj.de/sneak/chat/internal/handlers" + "git.eeqj.de/sneak/chat/internal/healthcheck" + "git.eeqj.de/sneak/chat/internal/logger" + "git.eeqj.de/sneak/chat/internal/middleware" + "git.eeqj.de/sneak/chat/internal/server" + "go.uber.org/fx" + "go.uber.org/fx/fxtest" +) + +// testServer wraps a test HTTP server with helper methods. +type testServer struct { + srv *httptest.Server + t *testing.T + fxApp *fxtest.App +} + +func newTestServer(t *testing.T) *testServer { + t.Helper() + + var s *server.Server + + app := fxtest.New(t, + fx.Provide( + func() *globals.Globals { return &globals.Globals{Appname: "chat-test", Version: "test"} }, + logger.New, + func(lc fx.Lifecycle, g *globals.Globals, l *logger.Logger) (*config.Config, error) { + return config.New(lc, config.Params{Globals: g, Logger: l}) + }, + func(lc fx.Lifecycle, l *logger.Logger, c *config.Config) (*db.Database, error) { + return db.New(lc, db.Params{Logger: l, Config: c}) + }, + func(lc fx.Lifecycle, g *globals.Globals, c *config.Config, l *logger.Logger, d *db.Database) (*healthcheck.Healthcheck, error) { + return healthcheck.New(lc, healthcheck.Params{Globals: g, Config: c, Logger: l, Database: d}) + }, + func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config) (*middleware.Middleware, error) { + return middleware.New(lc, middleware.Params{Logger: l, Globals: g, Config: c}) + }, + func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config, d *db.Database, hc *healthcheck.Healthcheck) (*handlers.Handlers, error) { + return handlers.New(lc, handlers.Params{Logger: l, Globals: g, Config: c, Database: d, Healthcheck: hc}) + }, + func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config, mw *middleware.Middleware, h *handlers.Handlers) (*server.Server, error) { + return server.New(lc, server.Params{Logger: l, Globals: g, Config: c, Middleware: mw, Handlers: h}) + }, + ), + fx.Populate(&s), + ) + + app.RequireStart() + // Give the server a moment to set up routes. + time.Sleep(100 * time.Millisecond) + + ts := httptest.NewServer(s) + t.Cleanup(func() { + ts.Close() + app.RequireStop() + }) + + return &testServer{srv: ts, t: t, fxApp: app} +} + +func (ts *testServer) url(path string) string { + return ts.srv.URL + path +} + +func (ts *testServer) createSession(nick string) (int64, string) { + ts.t.Helper() + body, _ := json.Marshal(map[string]string{"nick": nick}) + resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + if err != nil { + ts.t.Fatalf("create session: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + b, _ := io.ReadAll(resp.Body) + ts.t.Fatalf("create session: status %d: %s", resp.StatusCode, b) + } + var result struct { + ID int64 `json:"id"` + Token string `json:"token"` + } + json.NewDecoder(resp.Body).Decode(&result) + return result.ID, result.Token +} + +func (ts *testServer) sendCommand(token string, cmd map[string]any) (*http.Response, map[string]any) { + ts.t.Helper() + body, _ := json.Marshal(cmd) + req, _ := http.NewRequest("POST", ts.url("/api/v1/messages"), bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + ts.t.Fatalf("send command: %v", err) + } + defer resp.Body.Close() + var result map[string]any + json.NewDecoder(resp.Body).Decode(&result) + return resp, result +} + +func (ts *testServer) getJSON(token, path string) (*http.Response, map[string]any) { + ts.t.Helper() + req, _ := http.NewRequest("GET", ts.url(path), nil) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + ts.t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + var result map[string]any + json.NewDecoder(resp.Body).Decode(&result) + return resp, result +} + +func (ts *testServer) pollMessages(token string, afterID int64, timeout int) ([]map[string]any, int64) { + ts.t.Helper() + url := fmt.Sprintf("%s/api/v1/messages?timeout=%d&after=%d", ts.srv.URL, timeout, afterID) + req, _ := http.NewRequestWithContext(context.Background(), "GET", url, nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + ts.t.Fatalf("poll: %v", err) + } + defer resp.Body.Close() + var result struct { + Messages []map[string]any `json:"messages"` + LastID json.Number `json:"last_id"` + } + json.NewDecoder(resp.Body).Decode(&result) + lastID, _ := result.LastID.Int64() + return result.Messages, lastID +} + +// --- Tests --- + +func TestCreateSession(t *testing.T) { + ts := newTestServer(t) + + t.Run("valid nick", func(t *testing.T) { + _, token := ts.createSession("alice") + if token == "" { + t.Fatal("expected token") + } + }) + + t.Run("duplicate nick", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"nick": "alice"}) + resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusConflict { + t.Fatalf("expected 409, got %d", resp.StatusCode) + } + }) + + t.Run("empty nick", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"nick": ""}) + resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) + + t.Run("invalid nick chars", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"nick": "hello world"}) + resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) + + t.Run("nick starting with number", func(t *testing.T) { + body, _ := json.Marshal(map[string]string{"nick": "123abc"}) + resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) + + t.Run("malformed json", func(t *testing.T) { + resp, err := http.Post(ts.url("/api/v1/session"), "application/json", strings.NewReader("{bad")) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) +} + +func TestAuth(t *testing.T) { + ts := newTestServer(t) + + t.Run("no auth header", func(t *testing.T) { + resp, _ := ts.getJSON("", "/api/v1/state") + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } + }) + + t.Run("bad token", func(t *testing.T) { + resp, _ := ts.getJSON("invalid-token-12345", "/api/v1/state") + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } + }) + + t.Run("valid token", func(t *testing.T) { + _, token := ts.createSession("authtest") + resp, result := ts.getJSON(token, "/api/v1/state") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + if result["nick"] != "authtest" { + t.Fatalf("expected nick authtest, got %v", result["nick"]) + } + }) +} + +func TestJoinAndPart(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("bob") + + t.Run("join channel", func(t *testing.T) { + resp, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#test"}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + } + if result["channel"] != "#test" { + t.Fatalf("expected #test, got %v", result["channel"]) + } + }) + + t.Run("join without hash", func(t *testing.T) { + resp, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "other"}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + } + if result["channel"] != "#other" { + t.Fatalf("expected #other, got %v", result["channel"]) + } + }) + + t.Run("part channel", func(t *testing.T) { + resp, result := ts.sendCommand(token, map[string]any{"command": "PART", "to": "#test"}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + } + if result["channel"] != "#test" { + t.Fatalf("expected #test, got %v", result["channel"]) + } + }) + + t.Run("join missing to", func(t *testing.T) { + resp, _ := ts.sendCommand(token, map[string]any{"command": "JOIN"}) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) +} + +func TestPrivmsg(t *testing.T) { + ts := newTestServer(t) + _, aliceToken := ts.createSession("alice_msg") + _, bobToken := ts.createSession("bob_msg") + + // Both join #chat + ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#chat"}) + ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#chat"}) + + // Drain existing messages (JOINs) + _, _ = ts.pollMessages(aliceToken, 0, 0) + _, bobLastID := ts.pollMessages(bobToken, 0, 0) + + t.Run("send channel message", func(t *testing.T) { + resp, result := ts.sendCommand(aliceToken, map[string]any{ + "command": "PRIVMSG", + "to": "#chat", + "body": []string{"hello world"}, + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result) + } + if result["id"] == nil || result["id"] == "" { + t.Fatal("expected message id") + } + }) + + t.Run("bob receives message", func(t *testing.T) { + msgs, _ := ts.pollMessages(bobToken, bobLastID, 0) + found := false + for _, m := range msgs { + if m["command"] == "PRIVMSG" && m["from"] == "alice_msg" { + found = true + break + } + } + if !found { + t.Fatalf("bob didn't receive alice's message: %v", msgs) + } + }) + + t.Run("missing body", func(t *testing.T) { + resp, _ := ts.sendCommand(aliceToken, map[string]any{ + "command": "PRIVMSG", + "to": "#chat", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) + + t.Run("missing to", func(t *testing.T) { + resp, _ := ts.sendCommand(aliceToken, map[string]any{ + "command": "PRIVMSG", + "body": []string{"hello"}, + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) +} + +func TestDM(t *testing.T) { + ts := newTestServer(t) + _, aliceToken := ts.createSession("alice_dm") + _, bobToken := ts.createSession("bob_dm") + + t.Run("send DM", func(t *testing.T) { + resp, result := ts.sendCommand(aliceToken, map[string]any{ + "command": "PRIVMSG", + "to": "bob_dm", + "body": []string{"hey bob"}, + }) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result) + } + }) + + t.Run("bob receives DM", func(t *testing.T) { + msgs, _ := ts.pollMessages(bobToken, 0, 0) + found := false + for _, m := range msgs { + if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" { + found = true + } + } + if !found { + t.Fatal("bob didn't receive DM") + } + }) + + t.Run("alice gets echo", func(t *testing.T) { + msgs, _ := ts.pollMessages(aliceToken, 0, 0) + found := false + for _, m := range msgs { + if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" && m["to"] == "bob_dm" { + found = true + } + } + if !found { + t.Fatal("alice didn't get DM echo") + } + }) + + t.Run("DM to nonexistent user", func(t *testing.T) { + resp, _ := ts.sendCommand(aliceToken, map[string]any{ + "command": "PRIVMSG", + "to": "nobody", + "body": []string{"hello?"}, + }) + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + }) +} + +func TestNick(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("nick_test") + + t.Run("change nick", func(t *testing.T) { + resp, result := ts.sendCommand(token, map[string]any{ + "command": "NICK", + "body": []string{"newnick"}, + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + } + if result["nick"] != "newnick" { + t.Fatalf("expected newnick, got %v", result["nick"]) + } + }) + + t.Run("nick same as current", func(t *testing.T) { + resp, _ := ts.sendCommand(token, map[string]any{ + "command": "NICK", + "body": []string{"newnick"}, + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("nick collision", func(t *testing.T) { + ts.createSession("taken_nick") + resp, _ := ts.sendCommand(token, map[string]any{ + "command": "NICK", + "body": []string{"taken_nick"}, + }) + if resp.StatusCode != http.StatusConflict { + t.Fatalf("expected 409, got %d", resp.StatusCode) + } + }) + + t.Run("invalid nick", func(t *testing.T) { + resp, _ := ts.sendCommand(token, map[string]any{ + "command": "NICK", + "body": []string{"bad nick!"}, + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) + + t.Run("empty body", func(t *testing.T) { + resp, _ := ts.sendCommand(token, map[string]any{ + "command": "NICK", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) +} + +func TestTopic(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("topic_user") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#topictest"}) + + t.Run("set topic", func(t *testing.T) { + resp, result := ts.sendCommand(token, map[string]any{ + "command": "TOPIC", + "to": "#topictest", + "body": []string{"Hello World Topic"}, + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + } + if result["topic"] != "Hello World Topic" { + t.Fatalf("expected topic, got %v", result["topic"]) + } + }) + + t.Run("missing to", func(t *testing.T) { + resp, _ := ts.sendCommand(token, map[string]any{ + "command": "TOPIC", + "body": []string{"topic"}, + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) + + t.Run("missing body", func(t *testing.T) { + resp, _ := ts.sendCommand(token, map[string]any{ + "command": "TOPIC", + "to": "#topictest", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + }) +} + +func TestPing(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("ping_user") + + resp, result := ts.sendCommand(token, map[string]any{"command": "PING"}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + if result["command"] != "PONG" { + t.Fatalf("expected PONG, got %v", result["command"]) + } +} + +func TestQuit(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("quitter") + _, observerToken := ts.createSession("observer") + + // Both join a channel + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#quitchan"}) + ts.sendCommand(observerToken, map[string]any{"command": "JOIN", "to": "#quitchan"}) + + // Drain messages + _, lastID := ts.pollMessages(observerToken, 0, 0) + + // Quit + resp, result := ts.sendCommand(token, map[string]any{"command": "QUIT"}) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + } + + // Observer should get QUIT message + msgs, _ := ts.pollMessages(observerToken, lastID, 0) + found := false + for _, m := range msgs { + if m["command"] == "QUIT" && m["from"] == "quitter" { + found = true + } + } + if !found { + t.Fatalf("observer didn't get QUIT: %v", msgs) + } + + // Token should be invalid now + resp2, _ := ts.getJSON(token, "/api/v1/state") + if resp2.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401 after quit, got %d", resp2.StatusCode) + } +} + +func TestUnknownCommand(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("cmdtest") + + resp, result := ts.sendCommand(token, map[string]any{"command": "BOGUS"}) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result) + } +} + +func TestEmptyCommand(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("emptycmd") + + resp, _ := ts.sendCommand(token, map[string]any{"command": ""}) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestHistory(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("historian") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#history"}) + + // Send some messages + for i := 0; i < 5; i++ { + ts.sendCommand(token, map[string]any{ + "command": "PRIVMSG", + "to": "#history", + "body": []string{"msg " + string(rune('A'+i))}, + }) + } + + req, _ := http.NewRequest("GET", ts.url("/api/v1/history?target=%23history&limit=3"), nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + var msgs []map[string]any + json.NewDecoder(resp.Body).Decode(&msgs) + if len(msgs) != 3 { + t.Fatalf("expected 3 messages, got %d", len(msgs)) + } +} + +func TestChannelList(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("lister") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#listchan"}) + + req, _ := http.NewRequest("GET", ts.url("/api/v1/channels"), nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + var channels []map[string]any + json.NewDecoder(resp.Body).Decode(&channels) + found := false + for _, ch := range channels { + if ch["name"] == "#listchan" { + found = true + } + } + if !found { + t.Fatal("channel not in list") + } +} + +func TestChannelMembers(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("membertest") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#members"}) + + req, _ := http.NewRequest("GET", ts.url("/api/v1/channels/members/members"), nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } +} + +func TestLongPoll(t *testing.T) { + ts := newTestServer(t) + _, aliceToken := ts.createSession("lp_alice") + _, bobToken := ts.createSession("lp_bob") + + ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#longpoll"}) + ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#longpoll"}) + + // Drain existing messages + _, lastID := ts.pollMessages(bobToken, 0, 0) + + // Start long-poll in goroutine + var wg sync.WaitGroup + var pollMsgs []map[string]any + + wg.Add(1) + go func() { + defer wg.Done() + url := fmt.Sprintf("%s/api/v1/messages?timeout=5&after=%d", ts.srv.URL, lastID) + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("Authorization", "Bearer "+bobToken) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + var result struct { + Messages []map[string]any `json:"messages"` + } + json.NewDecoder(resp.Body).Decode(&result) + pollMsgs = result.Messages + }() + + // Give the long-poll a moment to start + time.Sleep(200 * time.Millisecond) + + // Send a message + ts.sendCommand(aliceToken, map[string]any{ + "command": "PRIVMSG", + "to": "#longpoll", + "body": []string{"wake up!"}, + }) + + wg.Wait() + + found := false + for _, m := range pollMsgs { + if m["command"] == "PRIVMSG" && m["from"] == "lp_alice" { + found = true + } + } + if !found { + t.Fatalf("long-poll didn't receive message: %v", pollMsgs) + } +} + +func TestLongPollTimeout(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("lp_timeout") + + start := time.Now() + req, _ := http.NewRequest("GET", ts.url("/api/v1/messages?timeout=1"), nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + elapsed := time.Since(start) + + if elapsed < 900*time.Millisecond { + t.Fatalf("long-poll returned too fast: %v", elapsed) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } +} + +func TestEphemeralChannelCleanup(t *testing.T) { + ts := newTestServer(t) + _, token := ts.createSession("ephemeral") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#ephemeral"}) + ts.sendCommand(token, map[string]any{"command": "PART", "to": "#ephemeral"}) + + // Channel should be gone + req, _ := http.NewRequest("GET", ts.url("/api/v1/channels"), nil) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + var channels []map[string]any + json.NewDecoder(resp.Body).Decode(&channels) + for _, ch := range channels { + if ch["name"] == "#ephemeral" { + t.Fatal("ephemeral channel should have been cleaned up") + } + } +} + +func TestConcurrentSessions(t *testing.T) { + ts := newTestServer(t) + + var wg sync.WaitGroup + errors := make(chan error, 20) + + for i := 0; i < 20; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + nick := "concurrent_" + string(rune('a'+i)) + body, _ := json.Marshal(map[string]string{"nick": nick}) + resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + if err != nil { + errors <- err + return + } + resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + errors <- err + } + }(i) + } + + wg.Wait() + close(errors) + + for err := range errors { + if err != nil { + t.Fatalf("concurrent session creation error: %v", err) + } + } +} + +func TestServerInfo(t *testing.T) { + ts := newTestServer(t) + + resp, err := http.Get(ts.url("/api/v1/server")) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } +} + +func TestHealthcheck(t *testing.T) { + ts := newTestServer(t) + + resp, err := http.Get(ts.url("/.well-known/healthcheck.json")) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + var result map[string]any + json.NewDecoder(resp.Body).Decode(&result) + if result["status"] != "ok" { + t.Fatalf("expected ok status, got %v", result["status"]) + } +} + +func TestNickBroadcastToChannels(t *testing.T) { + ts := newTestServer(t) + _, aliceToken := ts.createSession("nick_a") + _, bobToken := ts.createSession("nick_b") + + ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#nicktest"}) + ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#nicktest"}) + + // Drain + _, lastID := ts.pollMessages(bobToken, 0, 0) + + // Alice changes nick + ts.sendCommand(aliceToken, map[string]any{"command": "NICK", "body": []string{"nick_a_new"}}) + + // Bob should see it + msgs, _ := ts.pollMessages(bobToken, lastID, 0) + found := false + for _, m := range msgs { + if m["command"] == "NICK" && m["from"] == "nick_a" { + found = true + } + } + if !found { + t.Fatalf("bob didn't get nick change: %v", msgs) + } +} + +// Broker unit tests + +func TestBrokerNotifyWithoutWaiters(t *testing.T) { + b := broker.New() + // Should not panic + b.Notify(999) +} + +func TestBrokerWaitAndNotify(t *testing.T) { + b := broker.New() + ch := b.Wait(1) + + go func() { + time.Sleep(50 * time.Millisecond) + b.Notify(1) + }() + + select { + case <-ch: + // ok + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for notification") + } +} + +func TestBrokerRemove(t *testing.T) { + b := broker.New() + ch := b.Wait(1) + b.Remove(1, ch) + // Notify should not send to removed channel + b.Notify(1) + + select { + case <-ch: + t.Fatal("should not receive after remove") + case <-time.After(100 * time.Millisecond): + // ok + } +}