MVP 1.0: IRC-over-HTTP chat server #10

Merged
sneak merged 18 commits from feature/mvp-1.0 into main 2026-02-27 07:21:35 +01:00
3 changed files with 1328 additions and 0 deletions
Showing only changes of commit fbeede563d - Show all commits

View File

@@ -0,0 +1,94 @@
package broker
import (
"sync"
"testing"
"time"
)
func TestNewBroker(t *testing.T) {
b := New()
if b == nil {
t.Fatal("expected non-nil broker")
}
}
func TestWaitAndNotify(t *testing.T) {
b := New()
ch := b.Wait(1)
go func() {
time.Sleep(10 * time.Millisecond)
b.Notify(1)
}()
select {
case <-ch:
case <-time.After(2 * time.Second):
t.Fatal("timeout")
}
}
func TestNotifyWithoutWaiters(t *testing.T) {
b := New()
b.Notify(42) // should not panic
}
func TestRemove(t *testing.T) {
b := New()
ch := b.Wait(1)
b.Remove(1, ch)
b.Notify(1)
select {
case <-ch:
t.Fatal("should not receive after remove")
case <-time.After(50 * time.Millisecond):
}
}
func TestMultipleWaiters(t *testing.T) {
b := New()
ch1 := b.Wait(1)
ch2 := b.Wait(1)
b.Notify(1)
select {
case <-ch1:
case <-time.After(time.Second):
t.Fatal("ch1 timeout")
}
select {
case <-ch2:
case <-time.After(time.Second):
t.Fatal("ch2 timeout")
}
}
func TestConcurrentWaitNotify(t *testing.T) {
b := New()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(uid int64) {
defer wg.Done()
ch := b.Wait(uid)
b.Notify(uid)
select {
case <-ch:
case <-time.After(time.Second):
t.Error("timeout")
}
}(int64(i % 10))
}
wg.Wait()
}
func TestRemoveNonexistent(t *testing.T) {
b := New()
ch := make(chan struct{}, 1)
b.Remove(999, ch) // should not panic
}

338
internal/db/queries_test.go Normal file
View File

@@ -0,0 +1,338 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"log/slog"
"testing"
_ "modernc.org/sqlite"
)
func setupTestDB(t *testing.T) *Database {
t.Helper()
d, err := sql.Open("sqlite", "file::memory:?cache=shared&_pragma=foreign_keys(1)")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { d.Close() })
db := &Database{db: d, log: slog.Default()}
if err := db.runMigrations(context.Background()); err != nil {
t.Fatal(err)
}
return db
}
func TestCreateUser(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
id, token, err := db.CreateUser(ctx, "alice")
if err != nil {
t.Fatal(err)
}
if id == 0 || token == "" {
t.Fatal("expected valid id and token")
}
// Duplicate nick
_, _, err = db.CreateUser(ctx, "alice")
if err == nil {
t.Fatal("expected error for duplicate nick")
}
}
func TestGetUserByToken(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
_, token, _ := db.CreateUser(ctx, "bob")
id, nick, err := db.GetUserByToken(ctx, token)
if err != nil {
t.Fatal(err)
}
if nick != "bob" || id == 0 {
t.Fatalf("expected bob, got %s", nick)
}
// Invalid token
_, _, err = db.GetUserByToken(ctx, "badtoken")
if err == nil {
t.Fatal("expected error for bad token")
}
}
func TestGetUserByNick(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
db.CreateUser(ctx, "charlie")
id, err := db.GetUserByNick(ctx, "charlie")
if err != nil || id == 0 {
t.Fatal("expected to find charlie")
}
_, err = db.GetUserByNick(ctx, "nobody")
if err == nil {
t.Fatal("expected error for unknown nick")
}
}
func TestChannelOperations(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
// Create channel
chID, err := db.GetOrCreateChannel(ctx, "#test")
if err != nil || chID == 0 {
t.Fatal("expected channel id")
}
// Get same channel
chID2, err := db.GetOrCreateChannel(ctx, "#test")
if err != nil || chID2 != chID {
t.Fatal("expected same channel id")
}
// GetChannelByName
chID3, err := db.GetChannelByName(ctx, "#test")
if err != nil || chID3 != chID {
t.Fatal("expected same channel id from GetChannelByName")
}
// Nonexistent channel
_, err = db.GetChannelByName(ctx, "#nope")
if err == nil {
t.Fatal("expected error for nonexistent channel")
}
}
func TestJoinAndPart(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
uid, _, _ := db.CreateUser(ctx, "user1")
chID, _ := db.GetOrCreateChannel(ctx, "#chan")
// Join
if err := db.JoinChannel(ctx, chID, uid); err != nil {
t.Fatal(err)
}
// Verify membership
ids, err := db.GetChannelMemberIDs(ctx, chID)
if err != nil || len(ids) != 1 || ids[0] != uid {
t.Fatal("expected user in channel")
}
// Double join (should be ignored)
if err := db.JoinChannel(ctx, chID, uid); err != nil {
t.Fatal(err)
}
// Part
if err := db.PartChannel(ctx, chID, uid); err != nil {
t.Fatal(err)
}
ids, _ = db.GetChannelMemberIDs(ctx, chID)
if len(ids) != 0 {
t.Fatal("expected empty channel")
}
}
func TestDeleteChannelIfEmpty(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
chID, _ := db.GetOrCreateChannel(ctx, "#empty")
uid, _, _ := db.CreateUser(ctx, "temp")
db.JoinChannel(ctx, chID, uid)
db.PartChannel(ctx, chID, uid)
if err := db.DeleteChannelIfEmpty(ctx, chID); err != nil {
t.Fatal(err)
}
_, err := db.GetChannelByName(ctx, "#empty")
if err == nil {
t.Fatal("expected channel to be deleted")
}
}
func TestListChannels(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
uid, _, _ := db.CreateUser(ctx, "lister")
ch1, _ := db.GetOrCreateChannel(ctx, "#a")
ch2, _ := db.GetOrCreateChannel(ctx, "#b")
db.JoinChannel(ctx, ch1, uid)
db.JoinChannel(ctx, ch2, uid)
channels, err := db.ListChannels(ctx, uid)
if err != nil || len(channels) != 2 {
t.Fatalf("expected 2 channels, got %d", len(channels))
}
}
func TestListAllChannels(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
db.GetOrCreateChannel(ctx, "#x")
db.GetOrCreateChannel(ctx, "#y")
channels, err := db.ListAllChannels(ctx)
if err != nil || len(channels) < 2 {
t.Fatalf("expected >= 2 channels, got %d", len(channels))
}
}
func TestChangeNick(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
uid, token, _ := db.CreateUser(ctx, "old")
if err := db.ChangeNick(ctx, uid, "new"); err != nil {
t.Fatal(err)
}
_, nick, _ := db.GetUserByToken(ctx, token)
if nick != "new" {
t.Fatalf("expected new, got %s", nick)
}
}
func TestSetTopic(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
db.GetOrCreateChannel(ctx, "#topictest")
if err := db.SetTopic(ctx, "#topictest", "Hello"); err != nil {
t.Fatal(err)
}
channels, _ := db.ListAllChannels(ctx)
for _, ch := range channels {
if ch.Name == "#topictest" && ch.Topic != "Hello" {
t.Fatalf("expected topic Hello, got %s", ch.Topic)
}
}
}
func TestInsertAndPollMessages(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
uid, _, _ := db.CreateUser(ctx, "poller")
body := json.RawMessage(`["hello"]`)
dbID, uuid, err := db.InsertMessage(ctx, "PRIVMSG", "poller", "#test", body, nil)
if err != nil || dbID == 0 || uuid == "" {
t.Fatal("insert failed")
}
if err := db.EnqueueMessage(ctx, uid, dbID); err != nil {
t.Fatal(err)
}
msgs, lastQID, err := db.PollMessages(ctx, uid, 0, 10)
if err != nil {
t.Fatal(err)
}
if len(msgs) != 1 {
t.Fatalf("expected 1 message, got %d", len(msgs))
}
if msgs[0].Command != "PRIVMSG" {
t.Fatalf("expected PRIVMSG, got %s", msgs[0].Command)
}
if lastQID == 0 {
t.Fatal("expected nonzero lastQID")
}
// Poll again with lastQID - should be empty
msgs, _, _ = db.PollMessages(ctx, uid, lastQID, 10)
if len(msgs) != 0 {
t.Fatalf("expected 0 messages, got %d", len(msgs))
}
}
func TestGetHistory(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
for i := 0; i < 10; i++ {
db.InsertMessage(ctx, "PRIVMSG", "user", "#hist", json.RawMessage(`["msg"]`), nil)
}
msgs, err := db.GetHistory(ctx, "#hist", 0, 5)
if err != nil {
t.Fatal(err)
}
if len(msgs) != 5 {
t.Fatalf("expected 5, got %d", len(msgs))
}
// Should be ascending order
if msgs[0].DBID > msgs[4].DBID {
t.Fatal("expected ascending order")
}
}
func TestDeleteUser(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
uid, _, _ := db.CreateUser(ctx, "deleteme")
chID, _ := db.GetOrCreateChannel(ctx, "#delchan")
db.JoinChannel(ctx, chID, uid)
if err := db.DeleteUser(ctx, uid); err != nil {
t.Fatal(err)
}
_, err := db.GetUserByNick(ctx, "deleteme")
if err == nil {
t.Fatal("user should be deleted")
}
// Channel membership should be cleaned up via CASCADE
ids, _ := db.GetChannelMemberIDs(ctx, chID)
if len(ids) != 0 {
t.Fatal("expected no members after user deletion")
}
}
func TestChannelMembers(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
uid1, _, _ := db.CreateUser(ctx, "m1")
uid2, _, _ := db.CreateUser(ctx, "m2")
chID, _ := db.GetOrCreateChannel(ctx, "#members")
db.JoinChannel(ctx, chID, uid1)
db.JoinChannel(ctx, chID, uid2)
members, err := db.ChannelMembers(ctx, chID)
if err != nil || len(members) != 2 {
t.Fatalf("expected 2 members, got %d", len(members))
}
}
func TestGetAllChannelMembershipsForUser(t *testing.T) {
db := setupTestDB(t)
ctx := context.Background()
uid, _, _ := db.CreateUser(ctx, "multi")
ch1, _ := db.GetOrCreateChannel(ctx, "#m1")
ch2, _ := db.GetOrCreateChannel(ctx, "#m2")
db.JoinChannel(ctx, ch1, uid)
db.JoinChannel(ctx, ch2, uid)
channels, err := db.GetAllChannelMembershipsForUser(ctx, uid)
if err != nil || len(channels) != 2 {
t.Fatalf("expected 2 channels, got %d", len(channels))
}
}

View File

@@ -0,0 +1,896 @@
package handlers_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"git.eeqj.de/sneak/chat/internal/broker"
"git.eeqj.de/sneak/chat/internal/config"
"git.eeqj.de/sneak/chat/internal/db"
"git.eeqj.de/sneak/chat/internal/globals"
"git.eeqj.de/sneak/chat/internal/handlers"
"git.eeqj.de/sneak/chat/internal/healthcheck"
"git.eeqj.de/sneak/chat/internal/logger"
"git.eeqj.de/sneak/chat/internal/middleware"
"git.eeqj.de/sneak/chat/internal/server"
"go.uber.org/fx"
"go.uber.org/fx/fxtest"
)
// testServer wraps a test HTTP server with helper methods.
type testServer struct {
srv *httptest.Server
t *testing.T
fxApp *fxtest.App
}
func newTestServer(t *testing.T) *testServer {
t.Helper()
var s *server.Server
app := fxtest.New(t,
fx.Provide(
func() *globals.Globals { return &globals.Globals{Appname: "chat-test", Version: "test"} },
logger.New,
func(lc fx.Lifecycle, g *globals.Globals, l *logger.Logger) (*config.Config, error) {
return config.New(lc, config.Params{Globals: g, Logger: l})
},
func(lc fx.Lifecycle, l *logger.Logger, c *config.Config) (*db.Database, error) {
return db.New(lc, db.Params{Logger: l, Config: c})
},
func(lc fx.Lifecycle, g *globals.Globals, c *config.Config, l *logger.Logger, d *db.Database) (*healthcheck.Healthcheck, error) {
return healthcheck.New(lc, healthcheck.Params{Globals: g, Config: c, Logger: l, Database: d})
},
func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config) (*middleware.Middleware, error) {
return middleware.New(lc, middleware.Params{Logger: l, Globals: g, Config: c})
},
func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config, d *db.Database, hc *healthcheck.Healthcheck) (*handlers.Handlers, error) {
return handlers.New(lc, handlers.Params{Logger: l, Globals: g, Config: c, Database: d, Healthcheck: hc})
},
func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config, mw *middleware.Middleware, h *handlers.Handlers) (*server.Server, error) {
return server.New(lc, server.Params{Logger: l, Globals: g, Config: c, Middleware: mw, Handlers: h})
},
),
fx.Populate(&s),
)
app.RequireStart()
// Give the server a moment to set up routes.
time.Sleep(100 * time.Millisecond)
ts := httptest.NewServer(s)
t.Cleanup(func() {
ts.Close()
app.RequireStop()
})
return &testServer{srv: ts, t: t, fxApp: app}
}
func (ts *testServer) url(path string) string {
return ts.srv.URL + path
}
func (ts *testServer) createSession(nick string) (int64, string) {
ts.t.Helper()
body, _ := json.Marshal(map[string]string{"nick": nick})
resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body))
if err != nil {
ts.t.Fatalf("create session: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
b, _ := io.ReadAll(resp.Body)
ts.t.Fatalf("create session: status %d: %s", resp.StatusCode, b)
}
var result struct {
ID int64 `json:"id"`
Token string `json:"token"`
}
json.NewDecoder(resp.Body).Decode(&result)
return result.ID, result.Token
}
func (ts *testServer) sendCommand(token string, cmd map[string]any) (*http.Response, map[string]any) {
ts.t.Helper()
body, _ := json.Marshal(cmd)
req, _ := http.NewRequest("POST", ts.url("/api/v1/messages"), bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
ts.t.Fatalf("send command: %v", err)
}
defer resp.Body.Close()
var result map[string]any
json.NewDecoder(resp.Body).Decode(&result)
return resp, result
}
func (ts *testServer) getJSON(token, path string) (*http.Response, map[string]any) {
ts.t.Helper()
req, _ := http.NewRequest("GET", ts.url(path), nil)
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
ts.t.Fatalf("get: %v", err)
}
defer resp.Body.Close()
var result map[string]any
json.NewDecoder(resp.Body).Decode(&result)
return resp, result
}
func (ts *testServer) pollMessages(token string, afterID int64, timeout int) ([]map[string]any, int64) {
ts.t.Helper()
url := fmt.Sprintf("%s/api/v1/messages?timeout=%d&after=%d", ts.srv.URL, timeout, afterID)
req, _ := http.NewRequestWithContext(context.Background(), "GET", url, nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
ts.t.Fatalf("poll: %v", err)
}
defer resp.Body.Close()
var result struct {
Messages []map[string]any `json:"messages"`
LastID json.Number `json:"last_id"`
}
json.NewDecoder(resp.Body).Decode(&result)
lastID, _ := result.LastID.Int64()
return result.Messages, lastID
}
// --- Tests ---
func TestCreateSession(t *testing.T) {
ts := newTestServer(t)
t.Run("valid nick", func(t *testing.T) {
_, token := ts.createSession("alice")
if token == "" {
t.Fatal("expected token")
}
})
t.Run("duplicate nick", func(t *testing.T) {
body, _ := json.Marshal(map[string]string{"nick": "alice"})
resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusConflict {
t.Fatalf("expected 409, got %d", resp.StatusCode)
}
})
t.Run("empty nick", func(t *testing.T) {
body, _ := json.Marshal(map[string]string{"nick": ""})
resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
t.Run("invalid nick chars", func(t *testing.T) {
body, _ := json.Marshal(map[string]string{"nick": "hello world"})
resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
t.Run("nick starting with number", func(t *testing.T) {
body, _ := json.Marshal(map[string]string{"nick": "123abc"})
resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
t.Run("malformed json", func(t *testing.T) {
resp, err := http.Post(ts.url("/api/v1/session"), "application/json", strings.NewReader("{bad"))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
}
func TestAuth(t *testing.T) {
ts := newTestServer(t)
t.Run("no auth header", func(t *testing.T) {
resp, _ := ts.getJSON("", "/api/v1/state")
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", resp.StatusCode)
}
})
t.Run("bad token", func(t *testing.T) {
resp, _ := ts.getJSON("invalid-token-12345", "/api/v1/state")
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", resp.StatusCode)
}
})
t.Run("valid token", func(t *testing.T) {
_, token := ts.createSession("authtest")
resp, result := ts.getJSON(token, "/api/v1/state")
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if result["nick"] != "authtest" {
t.Fatalf("expected nick authtest, got %v", result["nick"])
}
})
}
func TestJoinAndPart(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("bob")
t.Run("join channel", func(t *testing.T) {
resp, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#test"})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result)
}
if result["channel"] != "#test" {
t.Fatalf("expected #test, got %v", result["channel"])
}
})
t.Run("join without hash", func(t *testing.T) {
resp, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "other"})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result)
}
if result["channel"] != "#other" {
t.Fatalf("expected #other, got %v", result["channel"])
}
})
t.Run("part channel", func(t *testing.T) {
resp, result := ts.sendCommand(token, map[string]any{"command": "PART", "to": "#test"})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result)
}
if result["channel"] != "#test" {
t.Fatalf("expected #test, got %v", result["channel"])
}
})
t.Run("join missing to", func(t *testing.T) {
resp, _ := ts.sendCommand(token, map[string]any{"command": "JOIN"})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
}
func TestPrivmsg(t *testing.T) {
ts := newTestServer(t)
_, aliceToken := ts.createSession("alice_msg")
_, bobToken := ts.createSession("bob_msg")
// Both join #chat
ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#chat"})
ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#chat"})
// Drain existing messages (JOINs)
_, _ = ts.pollMessages(aliceToken, 0, 0)
_, bobLastID := ts.pollMessages(bobToken, 0, 0)
t.Run("send channel message", func(t *testing.T) {
resp, result := ts.sendCommand(aliceToken, map[string]any{
"command": "PRIVMSG",
"to": "#chat",
"body": []string{"hello world"},
})
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
if result["id"] == nil || result["id"] == "" {
t.Fatal("expected message id")
}
})
t.Run("bob receives message", func(t *testing.T) {
msgs, _ := ts.pollMessages(bobToken, bobLastID, 0)
found := false
for _, m := range msgs {
if m["command"] == "PRIVMSG" && m["from"] == "alice_msg" {
found = true
break
}
}
if !found {
t.Fatalf("bob didn't receive alice's message: %v", msgs)
}
})
t.Run("missing body", func(t *testing.T) {
resp, _ := ts.sendCommand(aliceToken, map[string]any{
"command": "PRIVMSG",
"to": "#chat",
})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
t.Run("missing to", func(t *testing.T) {
resp, _ := ts.sendCommand(aliceToken, map[string]any{
"command": "PRIVMSG",
"body": []string{"hello"},
})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
}
func TestDM(t *testing.T) {
ts := newTestServer(t)
_, aliceToken := ts.createSession("alice_dm")
_, bobToken := ts.createSession("bob_dm")
t.Run("send DM", func(t *testing.T) {
resp, result := ts.sendCommand(aliceToken, map[string]any{
"command": "PRIVMSG",
"to": "bob_dm",
"body": []string{"hey bob"},
})
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result)
}
})
t.Run("bob receives DM", func(t *testing.T) {
msgs, _ := ts.pollMessages(bobToken, 0, 0)
found := false
for _, m := range msgs {
if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" {
found = true
}
}
if !found {
t.Fatal("bob didn't receive DM")
}
})
t.Run("alice gets echo", func(t *testing.T) {
msgs, _ := ts.pollMessages(aliceToken, 0, 0)
found := false
for _, m := range msgs {
if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" && m["to"] == "bob_dm" {
found = true
}
}
if !found {
t.Fatal("alice didn't get DM echo")
}
})
t.Run("DM to nonexistent user", func(t *testing.T) {
resp, _ := ts.sendCommand(aliceToken, map[string]any{
"command": "PRIVMSG",
"to": "nobody",
"body": []string{"hello?"},
})
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("expected 404, got %d", resp.StatusCode)
}
})
}
func TestNick(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("nick_test")
t.Run("change nick", func(t *testing.T) {
resp, result := ts.sendCommand(token, map[string]any{
"command": "NICK",
"body": []string{"newnick"},
})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result)
}
if result["nick"] != "newnick" {
t.Fatalf("expected newnick, got %v", result["nick"])
}
})
t.Run("nick same as current", func(t *testing.T) {
resp, _ := ts.sendCommand(token, map[string]any{
"command": "NICK",
"body": []string{"newnick"},
})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
})
t.Run("nick collision", func(t *testing.T) {
ts.createSession("taken_nick")
resp, _ := ts.sendCommand(token, map[string]any{
"command": "NICK",
"body": []string{"taken_nick"},
})
if resp.StatusCode != http.StatusConflict {
t.Fatalf("expected 409, got %d", resp.StatusCode)
}
})
t.Run("invalid nick", func(t *testing.T) {
resp, _ := ts.sendCommand(token, map[string]any{
"command": "NICK",
"body": []string{"bad nick!"},
})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
t.Run("empty body", func(t *testing.T) {
resp, _ := ts.sendCommand(token, map[string]any{
"command": "NICK",
})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
}
func TestTopic(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("topic_user")
ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#topictest"})
t.Run("set topic", func(t *testing.T) {
resp, result := ts.sendCommand(token, map[string]any{
"command": "TOPIC",
"to": "#topictest",
"body": []string{"Hello World Topic"},
})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result)
}
if result["topic"] != "Hello World Topic" {
t.Fatalf("expected topic, got %v", result["topic"])
}
})
t.Run("missing to", func(t *testing.T) {
resp, _ := ts.sendCommand(token, map[string]any{
"command": "TOPIC",
"body": []string{"topic"},
})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
t.Run("missing body", func(t *testing.T) {
resp, _ := ts.sendCommand(token, map[string]any{
"command": "TOPIC",
"to": "#topictest",
})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
})
}
func TestPing(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("ping_user")
resp, result := ts.sendCommand(token, map[string]any{"command": "PING"})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
if result["command"] != "PONG" {
t.Fatalf("expected PONG, got %v", result["command"])
}
}
func TestQuit(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("quitter")
_, observerToken := ts.createSession("observer")
// Both join a channel
ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#quitchan"})
ts.sendCommand(observerToken, map[string]any{"command": "JOIN", "to": "#quitchan"})
// Drain messages
_, lastID := ts.pollMessages(observerToken, 0, 0)
// Quit
resp, result := ts.sendCommand(token, map[string]any{"command": "QUIT"})
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result)
}
// Observer should get QUIT message
msgs, _ := ts.pollMessages(observerToken, lastID, 0)
found := false
for _, m := range msgs {
if m["command"] == "QUIT" && m["from"] == "quitter" {
found = true
}
}
if !found {
t.Fatalf("observer didn't get QUIT: %v", msgs)
}
// Token should be invalid now
resp2, _ := ts.getJSON(token, "/api/v1/state")
if resp2.StatusCode != http.StatusUnauthorized {
t.Fatalf("expected 401 after quit, got %d", resp2.StatusCode)
}
}
func TestUnknownCommand(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("cmdtest")
resp, result := ts.sendCommand(token, map[string]any{"command": "BOGUS"})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result)
}
}
func TestEmptyCommand(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("emptycmd")
resp, _ := ts.sendCommand(token, map[string]any{"command": ""})
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.StatusCode)
}
}
func TestHistory(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("historian")
ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#history"})
// Send some messages
for i := 0; i < 5; i++ {
ts.sendCommand(token, map[string]any{
"command": "PRIVMSG",
"to": "#history",
"body": []string{"msg " + string(rune('A'+i))},
})
}
req, _ := http.NewRequest("GET", ts.url("/api/v1/history?target=%23history&limit=3"), nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
var msgs []map[string]any
json.NewDecoder(resp.Body).Decode(&msgs)
if len(msgs) != 3 {
t.Fatalf("expected 3 messages, got %d", len(msgs))
}
}
func TestChannelList(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("lister")
ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#listchan"})
req, _ := http.NewRequest("GET", ts.url("/api/v1/channels"), nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
var channels []map[string]any
json.NewDecoder(resp.Body).Decode(&channels)
found := false
for _, ch := range channels {
if ch["name"] == "#listchan" {
found = true
}
}
if !found {
t.Fatal("channel not in list")
}
}
func TestChannelMembers(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("membertest")
ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#members"})
req, _ := http.NewRequest("GET", ts.url("/api/v1/channels/members/members"), nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
}
func TestLongPoll(t *testing.T) {
ts := newTestServer(t)
_, aliceToken := ts.createSession("lp_alice")
_, bobToken := ts.createSession("lp_bob")
ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#longpoll"})
ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#longpoll"})
// Drain existing messages
_, lastID := ts.pollMessages(bobToken, 0, 0)
// Start long-poll in goroutine
var wg sync.WaitGroup
var pollMsgs []map[string]any
wg.Add(1)
go func() {
defer wg.Done()
url := fmt.Sprintf("%s/api/v1/messages?timeout=5&after=%d", ts.srv.URL, lastID)
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+bobToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
var result struct {
Messages []map[string]any `json:"messages"`
}
json.NewDecoder(resp.Body).Decode(&result)
pollMsgs = result.Messages
}()
// Give the long-poll a moment to start
time.Sleep(200 * time.Millisecond)
// Send a message
ts.sendCommand(aliceToken, map[string]any{
"command": "PRIVMSG",
"to": "#longpoll",
"body": []string{"wake up!"},
})
wg.Wait()
found := false
for _, m := range pollMsgs {
if m["command"] == "PRIVMSG" && m["from"] == "lp_alice" {
found = true
}
}
if !found {
t.Fatalf("long-poll didn't receive message: %v", pollMsgs)
}
}
func TestLongPollTimeout(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("lp_timeout")
start := time.Now()
req, _ := http.NewRequest("GET", ts.url("/api/v1/messages?timeout=1"), nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
elapsed := time.Since(start)
if elapsed < 900*time.Millisecond {
t.Fatalf("long-poll returned too fast: %v", elapsed)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
}
func TestEphemeralChannelCleanup(t *testing.T) {
ts := newTestServer(t)
_, token := ts.createSession("ephemeral")
ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#ephemeral"})
ts.sendCommand(token, map[string]any{"command": "PART", "to": "#ephemeral"})
// Channel should be gone
req, _ := http.NewRequest("GET", ts.url("/api/v1/channels"), nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
var channels []map[string]any
json.NewDecoder(resp.Body).Decode(&channels)
for _, ch := range channels {
if ch["name"] == "#ephemeral" {
t.Fatal("ephemeral channel should have been cleaned up")
}
}
}
func TestConcurrentSessions(t *testing.T) {
ts := newTestServer(t)
var wg sync.WaitGroup
errors := make(chan error, 20)
for i := 0; i < 20; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
nick := "concurrent_" + string(rune('a'+i))
body, _ := json.Marshal(map[string]string{"nick": nick})
resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body))
if err != nil {
errors <- err
return
}
resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
errors <- err
}
}(i)
}
wg.Wait()
close(errors)
for err := range errors {
if err != nil {
t.Fatalf("concurrent session creation error: %v", err)
}
}
}
func TestServerInfo(t *testing.T) {
ts := newTestServer(t)
resp, err := http.Get(ts.url("/api/v1/server"))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
}
func TestHealthcheck(t *testing.T) {
ts := newTestServer(t)
resp, err := http.Get(ts.url("/.well-known/healthcheck.json"))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.StatusCode)
}
var result map[string]any
json.NewDecoder(resp.Body).Decode(&result)
if result["status"] != "ok" {
t.Fatalf("expected ok status, got %v", result["status"])
}
}
func TestNickBroadcastToChannels(t *testing.T) {
ts := newTestServer(t)
_, aliceToken := ts.createSession("nick_a")
_, bobToken := ts.createSession("nick_b")
ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#nicktest"})
ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#nicktest"})
// Drain
_, lastID := ts.pollMessages(bobToken, 0, 0)
// Alice changes nick
ts.sendCommand(aliceToken, map[string]any{"command": "NICK", "body": []string{"nick_a_new"}})
// Bob should see it
msgs, _ := ts.pollMessages(bobToken, lastID, 0)
found := false
for _, m := range msgs {
if m["command"] == "NICK" && m["from"] == "nick_a" {
found = true
}
}
if !found {
t.Fatalf("bob didn't get nick change: %v", msgs)
}
}
// Broker unit tests
func TestBrokerNotifyWithoutWaiters(t *testing.T) {
b := broker.New()
// Should not panic
b.Notify(999)
}
func TestBrokerWaitAndNotify(t *testing.T) {
b := broker.New()
ch := b.Wait(1)
go func() {
time.Sleep(50 * time.Millisecond)
b.Notify(1)
}()
select {
case <-ch:
// ok
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for notification")
}
}
func TestBrokerRemove(t *testing.T) {
b := broker.New()
ch := b.Wait(1)
b.Remove(1, ch)
// Notify should not send to removed channel
b.Notify(1)
select {
case <-ch:
t.Fatal("should not receive after remove")
case <-time.After(100 * time.Millisecond):
// ok
}
}