diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index 253bd9c..8cdbab3 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -27,9 +27,9 @@ import ( // testServer wraps a test HTTP server with helper methods. type testServer struct { - srv *httptest.Server - t *testing.T - fxApp *fxtest.App + srv *httptest.Server + t *testing.T + fxApp *fxtest.App } func newTestServer(t *testing.T) *testServer { @@ -39,32 +39,94 @@ func newTestServer(t *testing.T) *testServer { app := fxtest.New(t, fx.Provide( - func() *globals.Globals { return &globals.Globals{Appname: "chat-test", Version: "test"} }, + func() *globals.Globals { + return &globals.Globals{ + Appname: "chat-test", + Version: "test", + } + }, logger.New, - func(lc fx.Lifecycle, g *globals.Globals, l *logger.Logger) (*config.Config, error) { - return config.New(lc, config.Params{Globals: g, Logger: l}) + func( + lc fx.Lifecycle, + g *globals.Globals, + l *logger.Logger, + ) (*config.Config, error) { + return config.New(lc, config.Params{ + Globals: g, Logger: l, + }) }, - func(lc fx.Lifecycle, l *logger.Logger, c *config.Config) (*db.Database, error) { - return db.New(lc, db.Params{Logger: l, Config: c}) + func( + lc fx.Lifecycle, + l *logger.Logger, + c *config.Config, + ) (*db.Database, error) { + return db.New(lc, db.Params{ + Logger: l, Config: c, + }) }, - func(lc fx.Lifecycle, g *globals.Globals, c *config.Config, l *logger.Logger, d *db.Database) (*healthcheck.Healthcheck, error) { - return healthcheck.New(lc, healthcheck.Params{Globals: g, Config: c, Logger: l, Database: d}) + func( + lc fx.Lifecycle, + g *globals.Globals, + c *config.Config, + l *logger.Logger, + d *db.Database, + ) (*healthcheck.Healthcheck, error) { + return healthcheck.New(lc, healthcheck.Params{ + Globals: g, + Config: c, + Logger: l, + Database: d, + }) }, - func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config) (*middleware.Middleware, error) { - return middleware.New(lc, middleware.Params{Logger: l, Globals: g, Config: c}) + func( + lc fx.Lifecycle, + l *logger.Logger, + g *globals.Globals, + c *config.Config, + ) (*middleware.Middleware, error) { + return middleware.New(lc, middleware.Params{ + Logger: l, + Globals: g, + Config: c, + }) }, - func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config, d *db.Database, hc *healthcheck.Healthcheck) (*handlers.Handlers, error) { - return handlers.New(lc, handlers.Params{Logger: l, Globals: g, Config: c, Database: d, Healthcheck: hc}) + func( + lc fx.Lifecycle, + l *logger.Logger, + g *globals.Globals, + c *config.Config, + d *db.Database, + hc *healthcheck.Healthcheck, + ) (*handlers.Handlers, error) { + return handlers.New(lc, handlers.Params{ + Logger: l, + Globals: g, + Config: c, + Database: d, + Healthcheck: hc, + }) }, - func(lc fx.Lifecycle, l *logger.Logger, g *globals.Globals, c *config.Config, mw *middleware.Middleware, h *handlers.Handlers) (*server.Server, error) { - return server.New(lc, server.Params{Logger: l, Globals: g, Config: c, Middleware: mw, Handlers: h}) + func( + lc fx.Lifecycle, + l *logger.Logger, + g *globals.Globals, + c *config.Config, + mw *middleware.Middleware, + h *handlers.Handlers, + ) (*server.Server, error) { + return server.New(lc, server.Params{ + Logger: l, + Globals: g, + Config: c, + Middleware: mw, + Handlers: h, + }) }, ), fx.Populate(&s), ) app.RequireStart() - // Give the server a moment to set up routes. time.Sleep(100 * time.Millisecond) ts := httptest.NewServer(s) @@ -80,74 +142,150 @@ func (ts *testServer) url(path string) string { return ts.srv.URL + path } -func (ts *testServer) createSession(nick string) (int64, string) { +func (ts *testServer) doReq( + method, url string, body io.Reader, +) (*http.Response, error) { ts.t.Helper() - body, _ := json.Marshal(map[string]string{"nick": nick}) - resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + + req, err := http.NewRequestWithContext( + context.Background(), method, url, body, + ) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + return http.DefaultClient.Do(req) +} + +func (ts *testServer) doReqAuth( + method, url, token string, body io.Reader, +) (*http.Response, error) { + ts.t.Helper() + + req, err := http.NewRequestWithContext( + context.Background(), method, url, body, + ) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + return http.DefaultClient.Do(req) +} + +func (ts *testServer) createSession(nick string) string { + ts.t.Helper() + + body, err := json.Marshal(map[string]string{"nick": nick}) + if err != nil { + ts.t.Fatalf("marshal session: %v", err) + } + + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + ) if err != nil { ts.t.Fatalf("create session: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusCreated { b, _ := io.ReadAll(resp.Body) ts.t.Fatalf("create session: status %d: %s", resp.StatusCode, b) } + var result struct { ID int64 `json:"id"` Token string `json:"token"` } - json.NewDecoder(resp.Body).Decode(&result) - return result.ID, result.Token + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + ts.t.Fatalf("decode session: %v", err) + } + + return result.Token } -func (ts *testServer) sendCommand(token string, cmd map[string]any) (*http.Response, map[string]any) { +func (ts *testServer) sendCommand( + token string, cmd map[string]any, +) (int, map[string]any) { ts.t.Helper() - body, _ := json.Marshal(cmd) - req, _ := http.NewRequest("POST", ts.url("/api/v1/messages"), bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+token) - resp, err := http.DefaultClient.Do(req) + + body, err := json.Marshal(cmd) + if err != nil { + ts.t.Fatalf("marshal command: %v", err) + } + + resp, err := ts.doReqAuth( + http.MethodPost, ts.url("/api/v1/messages"), token, bytes.NewReader(body), + ) if err != nil { ts.t.Fatalf("send command: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + var result map[string]any - json.NewDecoder(resp.Body).Decode(&result) - return resp, result + + _ = json.NewDecoder(resp.Body).Decode(&result) + + return resp.StatusCode, result } -func (ts *testServer) getJSON(token, path string) (*http.Response, map[string]any) { +func (ts *testServer) getJSON( + token, path string, //nolint:unparam +) (int, map[string]any) { ts.t.Helper() - req, _ := http.NewRequest("GET", ts.url(path), nil) - if token != "" { - req.Header.Set("Authorization", "Bearer "+token) - } - resp, err := http.DefaultClient.Do(req) + + resp, err := ts.doReqAuth(http.MethodGet, ts.url(path), token, nil) if err != nil { ts.t.Fatalf("get: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + var result map[string]any - json.NewDecoder(resp.Body).Decode(&result) - return resp, result + + _ = json.NewDecoder(resp.Body).Decode(&result) + + return resp.StatusCode, result } -func (ts *testServer) pollMessages(token string, afterID int64, timeout int) ([]map[string]any, int64) { +func (ts *testServer) pollMessages( + token string, afterID int64, +) ([]map[string]any, int64) { ts.t.Helper() - url := fmt.Sprintf("%s/api/v1/messages?timeout=%d&after=%d", ts.srv.URL, timeout, afterID) - req, _ := http.NewRequestWithContext(context.Background(), "GET", url, nil) - req.Header.Set("Authorization", "Bearer "+token) - resp, err := http.DefaultClient.Do(req) + + url := fmt.Sprintf( + "%s/api/v1/messages?timeout=0&after=%d", ts.srv.URL, afterID, + ) + + resp, err := ts.doReqAuth(http.MethodGet, url, token, nil) if err != nil { ts.t.Fatalf("poll: %v", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + var result struct { Messages []map[string]any `json:"messages"` - LastID json.Number `json:"last_id"` + LastID json.Number `json:"last_id"` //nolint:tagliatelle } - json.NewDecoder(resp.Body).Decode(&result) + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + ts.t.Fatalf("decode poll: %v", err) + } + lastID, _ := result.LastID.Int64() + return result.Messages, lastID } @@ -157,66 +295,97 @@ func TestCreateSession(t *testing.T) { ts := newTestServer(t) t.Run("valid nick", func(t *testing.T) { - _, token := ts.createSession("alice") + token := ts.createSession("alice") if token == "" { t.Fatal("expected token") } }) t.Run("duplicate nick", func(t *testing.T) { - body, _ := json.Marshal(map[string]string{"nick": "alice"}) - resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + body, err := json.Marshal(map[string]string{"nick": "alice"}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusConflict { t.Fatalf("expected 409, got %d", resp.StatusCode) } }) t.Run("empty nick", func(t *testing.T) { - body, _ := json.Marshal(map[string]string{"nick": ""}) - resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + body, err := json.Marshal(map[string]string{"nick": ""}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusBadRequest { t.Fatalf("expected 400, got %d", resp.StatusCode) } }) t.Run("invalid nick chars", func(t *testing.T) { - body, _ := json.Marshal(map[string]string{"nick": "hello world"}) - resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + body, err := json.Marshal(map[string]string{"nick": "hello world"}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusBadRequest { t.Fatalf("expected 400, got %d", resp.StatusCode) } }) t.Run("nick starting with number", func(t *testing.T) { - body, _ := json.Marshal(map[string]string{"nick": "123abc"}) - resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + body, err := json.Marshal(map[string]string{"nick": "123abc"}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusBadRequest { t.Fatalf("expected 400, got %d", resp.StatusCode) } }) t.Run("malformed json", func(t *testing.T) { - resp, err := http.Post(ts.url("/api/v1/session"), "application/json", strings.NewReader("{bad")) + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), strings.NewReader("{bad"), + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusBadRequest { t.Fatalf("expected 400, got %d", resp.StatusCode) } @@ -227,25 +396,27 @@ func TestAuth(t *testing.T) { ts := newTestServer(t) t.Run("no auth header", func(t *testing.T) { - resp, _ := ts.getJSON("", "/api/v1/state") - if resp.StatusCode != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", resp.StatusCode) + status, _ := ts.getJSON("", "/api/v1/state") + if status != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", status) } }) t.Run("bad token", func(t *testing.T) { - resp, _ := ts.getJSON("invalid-token-12345", "/api/v1/state") - if resp.StatusCode != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", resp.StatusCode) + status, _ := ts.getJSON("invalid-token-12345", "/api/v1/state") + if status != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", status) } }) t.Run("valid token", func(t *testing.T) { - _, token := ts.createSession("authtest") - resp, result := ts.getJSON(token, "/api/v1/state") - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + token := ts.createSession("authtest") + + status, result := ts.getJSON(token, "/api/v1/state") + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) } + if result["nick"] != "authtest" { t.Fatalf("expected nick authtest, got %v", result["nick"]) } @@ -254,268 +425,285 @@ func TestAuth(t *testing.T) { func TestJoinAndPart(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("bob") + token := ts.createSession("bob") t.Run("join channel", func(t *testing.T) { - resp, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#test"}) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + status, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#test"}) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", status, result) } + if result["channel"] != "#test" { t.Fatalf("expected #test, got %v", result["channel"]) } }) t.Run("join without hash", func(t *testing.T) { - resp, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "other"}) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + status, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "other"}) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", status, result) } + if result["channel"] != "#other" { t.Fatalf("expected #other, got %v", result["channel"]) } }) t.Run("part channel", func(t *testing.T) { - resp, result := ts.sendCommand(token, map[string]any{"command": "PART", "to": "#test"}) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + status, result := ts.sendCommand(token, map[string]any{"command": "PART", "to": "#test"}) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", status, result) } + if result["channel"] != "#test" { t.Fatalf("expected #test, got %v", result["channel"]) } }) t.Run("join missing to", func(t *testing.T) { - resp, _ := ts.sendCommand(token, map[string]any{"command": "JOIN"}) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + status, _ := ts.sendCommand(token, map[string]any{"command": "JOIN"}) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } }) } func TestPrivmsg(t *testing.T) { ts := newTestServer(t) - _, aliceToken := ts.createSession("alice_msg") - _, bobToken := ts.createSession("bob_msg") + aliceToken := ts.createSession("alice_msg") + bobToken := ts.createSession("bob_msg") - // Both join #chat ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#chat"}) ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#chat"}) - // Drain existing messages (JOINs) - _, _ = ts.pollMessages(aliceToken, 0, 0) - _, bobLastID := ts.pollMessages(bobToken, 0, 0) + _, _ = ts.pollMessages(aliceToken, 0) + _, bobLastID := ts.pollMessages(bobToken, 0) t.Run("send channel message", func(t *testing.T) { - resp, result := ts.sendCommand(aliceToken, map[string]any{ + status, result := ts.sendCommand(aliceToken, map[string]any{ "command": "PRIVMSG", "to": "#chat", "body": []string{"hello world"}, }) - if resp.StatusCode != http.StatusCreated { - t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result) + if status != http.StatusCreated { + t.Fatalf("expected 201, got %d: %v", status, result) } + if result["id"] == nil || result["id"] == "" { t.Fatal("expected message id") } }) t.Run("bob receives message", func(t *testing.T) { - msgs, _ := ts.pollMessages(bobToken, bobLastID, 0) + msgs, _ := ts.pollMessages(bobToken, bobLastID) + found := false + for _, m := range msgs { if m["command"] == "PRIVMSG" && m["from"] == "alice_msg" { found = true + break } } + if !found { t.Fatalf("bob didn't receive alice's message: %v", msgs) } }) t.Run("missing body", func(t *testing.T) { - resp, _ := ts.sendCommand(aliceToken, map[string]any{ + status, _ := ts.sendCommand(aliceToken, map[string]any{ "command": "PRIVMSG", "to": "#chat", }) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } }) t.Run("missing to", func(t *testing.T) { - resp, _ := ts.sendCommand(aliceToken, map[string]any{ + status, _ := ts.sendCommand(aliceToken, map[string]any{ "command": "PRIVMSG", "body": []string{"hello"}, }) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } }) } func TestDM(t *testing.T) { ts := newTestServer(t) - _, aliceToken := ts.createSession("alice_dm") - _, bobToken := ts.createSession("bob_dm") + aliceToken := ts.createSession("alice_dm") + bobToken := ts.createSession("bob_dm") t.Run("send DM", func(t *testing.T) { - resp, result := ts.sendCommand(aliceToken, map[string]any{ + status, result := ts.sendCommand(aliceToken, map[string]any{ "command": "PRIVMSG", "to": "bob_dm", "body": []string{"hey bob"}, }) - if resp.StatusCode != http.StatusCreated { - t.Fatalf("expected 201, got %d: %v", resp.StatusCode, result) + if status != http.StatusCreated { + t.Fatalf("expected 201, got %d: %v", status, result) } }) t.Run("bob receives DM", func(t *testing.T) { - msgs, _ := ts.pollMessages(bobToken, 0, 0) + msgs, _ := ts.pollMessages(bobToken, 0) + found := false + for _, m := range msgs { if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" { found = true } } + if !found { t.Fatal("bob didn't receive DM") } }) t.Run("alice gets echo", func(t *testing.T) { - msgs, _ := ts.pollMessages(aliceToken, 0, 0) + msgs, _ := ts.pollMessages(aliceToken, 0) + found := false + for _, m := range msgs { if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" && m["to"] == "bob_dm" { found = true } } + if !found { t.Fatal("alice didn't get DM echo") } }) t.Run("DM to nonexistent user", func(t *testing.T) { - resp, _ := ts.sendCommand(aliceToken, map[string]any{ + status, _ := ts.sendCommand(aliceToken, map[string]any{ "command": "PRIVMSG", "to": "nobody", "body": []string{"hello?"}, }) - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("expected 404, got %d", resp.StatusCode) + if status != http.StatusNotFound { + t.Fatalf("expected 404, got %d", status) } }) } func TestNick(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("nick_test") + token := ts.createSession("nick_test") t.Run("change nick", func(t *testing.T) { - resp, result := ts.sendCommand(token, map[string]any{ + status, result := ts.sendCommand(token, map[string]any{ "command": "NICK", "body": []string{"newnick"}, }) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", status, result) } + if result["nick"] != "newnick" { t.Fatalf("expected newnick, got %v", result["nick"]) } }) t.Run("nick same as current", func(t *testing.T) { - resp, _ := ts.sendCommand(token, map[string]any{ + status, _ := ts.sendCommand(token, map[string]any{ "command": "NICK", "body": []string{"newnick"}, }) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) } }) t.Run("nick collision", func(t *testing.T) { ts.createSession("taken_nick") - resp, _ := ts.sendCommand(token, map[string]any{ + + status, _ := ts.sendCommand(token, map[string]any{ "command": "NICK", "body": []string{"taken_nick"}, }) - if resp.StatusCode != http.StatusConflict { - t.Fatalf("expected 409, got %d", resp.StatusCode) + if status != http.StatusConflict { + t.Fatalf("expected 409, got %d", status) } }) t.Run("invalid nick", func(t *testing.T) { - resp, _ := ts.sendCommand(token, map[string]any{ + status, _ := ts.sendCommand(token, map[string]any{ "command": "NICK", "body": []string{"bad nick!"}, }) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } }) t.Run("empty body", func(t *testing.T) { - resp, _ := ts.sendCommand(token, map[string]any{ + status, _ := ts.sendCommand(token, map[string]any{ "command": "NICK", }) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } }) } func TestTopic(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("topic_user") + token := ts.createSession("topic_user") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#topictest"}) t.Run("set topic", func(t *testing.T) { - resp, result := ts.sendCommand(token, map[string]any{ + status, result := ts.sendCommand(token, map[string]any{ "command": "TOPIC", "to": "#topictest", "body": []string{"Hello World Topic"}, }) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", status, result) } + if result["topic"] != "Hello World Topic" { t.Fatalf("expected topic, got %v", result["topic"]) } }) t.Run("missing to", func(t *testing.T) { - resp, _ := ts.sendCommand(token, map[string]any{ + status, _ := ts.sendCommand(token, map[string]any{ "command": "TOPIC", "body": []string{"topic"}, }) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } }) t.Run("missing body", func(t *testing.T) { - resp, _ := ts.sendCommand(token, map[string]any{ + status, _ := ts.sendCommand(token, map[string]any{ "command": "TOPIC", "to": "#topictest", }) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } }) } func TestPing(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("ping_user") + token := ts.createSession("ping_user") - resp, result := ts.sendCommand(token, map[string]any{"command": "PING"}) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.StatusCode) + status, result := ts.sendCommand(token, map[string]any{"command": "PING"}) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) } + if result["command"] != "PONG" { t.Fatalf("expected PONG, got %v", result["command"]) } @@ -523,89 +711,91 @@ func TestPing(t *testing.T) { func TestQuit(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("quitter") - _, observerToken := ts.createSession("observer") + token := ts.createSession("quitter") + observerToken := ts.createSession("observer") - // Both join a channel ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#quitchan"}) ts.sendCommand(observerToken, map[string]any{"command": "JOIN", "to": "#quitchan"}) - // Drain messages - _, lastID := ts.pollMessages(observerToken, 0, 0) + _, lastID := ts.pollMessages(observerToken, 0) - // Quit - resp, result := ts.sendCommand(token, map[string]any{"command": "QUIT"}) - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", resp.StatusCode, result) + status, result := ts.sendCommand(token, map[string]any{"command": "QUIT"}) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", status, result) } - // Observer should get QUIT message - msgs, _ := ts.pollMessages(observerToken, lastID, 0) + msgs, _ := ts.pollMessages(observerToken, lastID) + found := false + for _, m := range msgs { if m["command"] == "QUIT" && m["from"] == "quitter" { found = true } } + if !found { t.Fatalf("observer didn't get QUIT: %v", msgs) } - // Token should be invalid now - resp2, _ := ts.getJSON(token, "/api/v1/state") - if resp2.StatusCode != http.StatusUnauthorized { - t.Fatalf("expected 401 after quit, got %d", resp2.StatusCode) + status2, _ := ts.getJSON(token, "/api/v1/state") + if status2 != http.StatusUnauthorized { + t.Fatalf("expected 401 after quit, got %d", status2) } } func TestUnknownCommand(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("cmdtest") + token := ts.createSession("cmdtest") - resp, result := ts.sendCommand(token, map[string]any{"command": "BOGUS"}) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d: %v", resp.StatusCode, result) + status, result := ts.sendCommand(token, map[string]any{"command": "BOGUS"}) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %v", status, result) } } func TestEmptyCommand(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("emptycmd") + token := ts.createSession("emptycmd") - resp, _ := ts.sendCommand(token, map[string]any{"command": ""}) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) + status, _ := ts.sendCommand(token, map[string]any{"command": ""}) + if status != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", status) } } func TestHistory(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("historian") + token := ts.createSession("historian") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#history"}) - // Send some messages - for i := 0; i < 5; i++ { + for range 5 { ts.sendCommand(token, map[string]any{ "command": "PRIVMSG", "to": "#history", - "body": []string{"msg " + string(rune('A'+i))}, + "body": []string{"test message"}, }) } - req, _ := http.NewRequest("GET", ts.url("/api/v1/history?target=%23history&limit=3"), nil) - req.Header.Set("Authorization", "Bearer "+token) - resp, err := http.DefaultClient.Do(req) + resp, err := ts.doReqAuth( + http.MethodGet, ts.url("/api/v1/history?target=%23history&limit=3"), token, nil, + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", resp.StatusCode) } var msgs []map[string]any - json.NewDecoder(resp.Body).Decode(&msgs) + + if err := json.NewDecoder(resp.Body).Decode(&msgs); err != nil { + t.Fatalf("decode history: %v", err) + } + if len(msgs) != 3 { t.Fatalf("expected 3 messages, got %d", len(msgs)) } @@ -613,29 +803,36 @@ func TestHistory(t *testing.T) { func TestChannelList(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("lister") + token := ts.createSession("lister") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#listchan"}) - req, _ := http.NewRequest("GET", ts.url("/api/v1/channels"), nil) - req.Header.Set("Authorization", "Bearer "+token) - resp, err := http.DefaultClient.Do(req) + resp, err := ts.doReqAuth( + http.MethodGet, ts.url("/api/v1/channels"), token, nil, + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", resp.StatusCode) } var channels []map[string]any - json.NewDecoder(resp.Body).Decode(&channels) + + if err := json.NewDecoder(resp.Body).Decode(&channels); err != nil { + t.Fatalf("decode channels: %v", err) + } + found := false + for _, ch := range channels { if ch["name"] == "#listchan" { found = true } } + if !found { t.Fatal("channel not in list") } @@ -643,16 +840,17 @@ func TestChannelList(t *testing.T) { func TestChannelMembers(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("membertest") + token := ts.createSession("membertest") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#members"}) - req, _ := http.NewRequest("GET", ts.url("/api/v1/channels/members/members"), nil) - req.Header.Set("Authorization", "Bearer "+token) - resp, err := http.DefaultClient.Do(req) + resp, err := ts.doReqAuth( + http.MethodGet, ts.url("/api/v1/channels/members/members"), token, nil, + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", resp.StatusCode) @@ -661,41 +859,44 @@ func TestChannelMembers(t *testing.T) { func TestLongPoll(t *testing.T) { ts := newTestServer(t) - _, aliceToken := ts.createSession("lp_alice") - _, bobToken := ts.createSession("lp_bob") + aliceToken := ts.createSession("lp_alice") + bobToken := ts.createSession("lp_bob") ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#longpoll"}) ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#longpoll"}) - // Drain existing messages - _, lastID := ts.pollMessages(bobToken, 0, 0) + _, lastID := ts.pollMessages(bobToken, 0) - // Start long-poll in goroutine var wg sync.WaitGroup + var pollMsgs []map[string]any wg.Add(1) + go func() { defer wg.Done() - url := fmt.Sprintf("%s/api/v1/messages?timeout=5&after=%d", ts.srv.URL, lastID) - req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+bobToken) - resp, err := http.DefaultClient.Do(req) + + url := fmt.Sprintf( + "%s/api/v1/messages?timeout=5&after=%d", ts.srv.URL, lastID, + ) + + resp, err := ts.doReqAuth(http.MethodGet, url, bobToken, nil) if err != nil { return } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + var result struct { Messages []map[string]any `json:"messages"` } - json.NewDecoder(resp.Body).Decode(&result) + + _ = json.NewDecoder(resp.Body).Decode(&result) + pollMsgs = result.Messages }() - // Give the long-poll a moment to start time.Sleep(200 * time.Millisecond) - // Send a message ts.sendCommand(aliceToken, map[string]any{ "command": "PRIVMSG", "to": "#longpoll", @@ -705,11 +906,13 @@ func TestLongPoll(t *testing.T) { wg.Wait() found := false + for _, m := range pollMsgs { if m["command"] == "PRIVMSG" && m["from"] == "lp_alice" { found = true } } + if !found { t.Fatalf("long-poll didn't receive message: %v", pollMsgs) } @@ -717,21 +920,24 @@ func TestLongPoll(t *testing.T) { func TestLongPollTimeout(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("lp_timeout") + token := ts.createSession("lp_timeout") start := time.Now() - req, _ := http.NewRequest("GET", ts.url("/api/v1/messages?timeout=1"), nil) - req.Header.Set("Authorization", "Bearer "+token) - resp, err := http.DefaultClient.Do(req) + + resp, err := ts.doReqAuth( + http.MethodGet, ts.url("/api/v1/messages?timeout=1"), token, nil, + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + elapsed := time.Since(start) if elapsed < 900*time.Millisecond { t.Fatalf("long-poll returned too fast: %v", elapsed) } + if resp.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", resp.StatusCode) } @@ -739,21 +945,25 @@ func TestLongPollTimeout(t *testing.T) { func TestEphemeralChannelCleanup(t *testing.T) { ts := newTestServer(t) - _, token := ts.createSession("ephemeral") + token := ts.createSession("ephemeral") + ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#ephemeral"}) ts.sendCommand(token, map[string]any{"command": "PART", "to": "#ephemeral"}) - // Channel should be gone - req, _ := http.NewRequest("GET", ts.url("/api/v1/channels"), nil) - req.Header.Set("Authorization", "Bearer "+token) - resp, err := http.DefaultClient.Do(req) + resp, err := ts.doReqAuth( + http.MethodGet, ts.url("/api/v1/channels"), token, nil, + ) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() var channels []map[string]any - json.NewDecoder(resp.Body).Decode(&channels) + + if err := json.NewDecoder(resp.Body).Decode(&channels); err != nil { + t.Fatalf("decode channels: %v", err) + } + for _, ch := range channels { if ch["name"] == "#ephemeral" { t.Fatal("ephemeral channel should have been cleaned up") @@ -765,30 +975,47 @@ func TestConcurrentSessions(t *testing.T) { ts := newTestServer(t) var wg sync.WaitGroup - errors := make(chan error, 20) - for i := 0; i < 20; i++ { + errs := make(chan error, 20) + + for i := range 20 { wg.Add(1) + go func(i int) { defer wg.Done() + nick := "concurrent_" + string(rune('a'+i)) - body, _ := json.Marshal(map[string]string{"nick": nick}) - resp, err := http.Post(ts.url("/api/v1/session"), "application/json", bytes.NewReader(body)) + + body, err := json.Marshal(map[string]string{"nick": nick}) if err != nil { - errors <- err + errs <- fmt.Errorf("marshal: %w", err) + return } - resp.Body.Close() + + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + ) + if err != nil { + errs <- err + + return + } + + _ = resp.Body.Close() + if resp.StatusCode != http.StatusCreated { - errors <- err + errs <- fmt.Errorf( //nolint:err113 + "status %d for %s", resp.StatusCode, nick, + ) } }(i) } wg.Wait() - close(errors) + close(errs) - for err := range errors { + for err := range errs { if err != nil { t.Fatalf("concurrent session creation error: %v", err) } @@ -798,11 +1025,12 @@ func TestConcurrentSessions(t *testing.T) { func TestServerInfo(t *testing.T) { ts := newTestServer(t) - resp, err := http.Get(ts.url("/api/v1/server")) + resp, err := ts.doReq(http.MethodGet, ts.url("/api/v1/server"), nil) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", resp.StatusCode) } @@ -811,17 +1039,22 @@ func TestServerInfo(t *testing.T) { func TestHealthcheck(t *testing.T) { ts := newTestServer(t) - resp, err := http.Get(ts.url("/.well-known/healthcheck.json")) + resp, err := ts.doReq(http.MethodGet, ts.url("/.well-known/healthcheck.json"), nil) if err != nil { t.Fatal(err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { t.Fatalf("expected 200, got %d", resp.StatusCode) } var result map[string]any - json.NewDecoder(resp.Body).Decode(&result) + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("decode healthcheck: %v", err) + } + if result["status"] != "ok" { t.Fatalf("expected ok status, got %v", result["status"]) } @@ -829,29 +1062,29 @@ func TestHealthcheck(t *testing.T) { func TestNickBroadcastToChannels(t *testing.T) { ts := newTestServer(t) - _, aliceToken := ts.createSession("nick_a") - _, bobToken := ts.createSession("nick_b") + aliceToken := ts.createSession("nick_a") + bobToken := ts.createSession("nick_b") ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#nicktest"}) ts.sendCommand(bobToken, map[string]any{"command": "JOIN", "to": "#nicktest"}) - // Drain - _, lastID := ts.pollMessages(bobToken, 0, 0) + _, lastID := ts.pollMessages(bobToken, 0) - // Alice changes nick - ts.sendCommand(aliceToken, map[string]any{"command": "NICK", "body": []string{"nick_a_new"}}) + ts.sendCommand(aliceToken, map[string]any{ + "command": "NICK", "body": []string{"nick_a_new"}, + }) + + msgs, _ := ts.pollMessages(bobToken, lastID) - // Bob should see it - msgs, _ := ts.pollMessages(bobToken, lastID, 0) found := false + for _, m := range msgs { if m["command"] == "NICK" && m["from"] == "nick_a" { found = true } } + if !found { t.Fatalf("bob didn't get nick change: %v", msgs) } } - -// Broker tests are in internal/broker/broker_test.go