diff --git a/.golangci.yml b/.golangci.yml index 4ff0955..34a8e31 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,33 +7,26 @@ run: linters: default: all disable: - - exhaustruct - - depguard - - godot - - wsl - - wrapcheck - - varnamelen - - dupl - - paralleltest - - nlreturn - - tagliatelle - - goconst - - funlen - - maintidx - - cyclop - - gocognit - - lll - settings: - lll: - line-length: 88 - funlen: - lines: 80 - statements: 50 - cyclop: - max-complexity: 15 - dupl: - threshold: 100 + # Genuinely incompatible with project patterns + - exhaustruct # Requires all struct fields + - depguard # Dependency allow/block lists + - godot # Requires comments to end with periods + - wsl # Deprecated, replaced by wsl_v5 + - wrapcheck # Too verbose for internal packages + - varnamelen # Short names like db, id are idiomatic Go + +linters-settings: + lll: + line-length: 88 + funlen: + lines: 80 + statements: 50 + cyclop: + max-complexity: 15 + dupl: + threshold: 100 issues: + exclude-use-default: false max-issues-per-linter: 0 max-same-issues: 0 diff --git a/cmd/chat-cli/api/client.go b/cmd/chat-cli/api/client.go index dbb63ba..aea4ed6 100644 --- a/cmd/chat-cli/api/client.go +++ b/cmd/chat-cli/api/client.go @@ -101,17 +101,7 @@ func (c *Client) PollMessages( ) * time.Second, } - params := url.Values{} - if afterID > 0 { - params.Set( - "after", - strconv.FormatInt(afterID, 10), - ) - } - - params.Set("timeout", strconv.Itoa(timeout)) - - path := "/api/v1/messages?" + params.Encode() + path := c.buildPollPath(afterID, timeout) req, err := http.NewRequestWithContext( context.Background(), @@ -125,38 +115,14 @@ func (c *Client) PollMessages( req.Header.Set("Authorization", "Bearer "+c.Token) - resp, err := client.Do(req) + resp, err := client.Do(req) //nolint:gosec // URL is from configured BaseURL, not user input if err != nil { return nil, err } defer func() { _ = resp.Body.Close() }() - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode >= httpErrThreshold { - return nil, fmt.Errorf( - "%w %d: %s", - errHTTP, resp.StatusCode, string(data), - ) - } - - var wrapped MessagesResponse - - err = json.Unmarshal(data, &wrapped) - if err != nil { - return nil, fmt.Errorf( - "decode messages: %w", err, - ) - } - - return &PollResult{ - Messages: wrapped.Messages, - LastID: wrapped.LastID, - }, nil + return c.decodePollResponse(resp) } // JoinChannel joins a channel. @@ -239,6 +205,52 @@ func (c *Client) GetServerInfo() (*ServerInfo, error) { return &info, nil } +func (c *Client) buildPollPath( + afterID int64, timeout int, +) string { + params := url.Values{} + if afterID > 0 { + params.Set( + "after", + strconv.FormatInt(afterID, 10), + ) + } + + params.Set("timeout", strconv.Itoa(timeout)) + + return "/api/v1/messages?" + params.Encode() +} + +func (c *Client) decodePollResponse( + resp *http.Response, +) (*PollResult, error) { + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode >= httpErrThreshold { + return nil, fmt.Errorf( + "%w %d: %s", + errHTTP, resp.StatusCode, string(data), + ) + } + + var wrapped MessagesResponse + + err = json.Unmarshal(data, &wrapped) + if err != nil { + return nil, fmt.Errorf( + "decode messages: %w", err, + ) + } + + return &PollResult{ + Messages: wrapped.Messages, + LastID: wrapped.LastID, + }, nil +} + func (c *Client) do( method, path string, body any, @@ -272,7 +284,7 @@ func (c *Client) do( ) } - resp, err := c.HTTPClient.Do(req) + resp, err := c.HTTPClient.Do(req) //nolint:gosec // URL is from configured BaseURL, not user input if err != nil { return nil, fmt.Errorf("http: %w", err) } diff --git a/cmd/chat-cli/main.go b/cmd/chat-cli/main.go index d263539..6da0038 100644 --- a/cmd/chat-cli/main.go +++ b/cmd/chat-cli/main.go @@ -123,36 +123,40 @@ func (a *App) handleCommand(text string) { } func (a *App) dispatchCommand(cmd, args string) { - switch cmd { - case "/connect": - a.cmdConnect(args) - case "/nick": - a.cmdNick(args) - case "/join": - a.cmdJoin(args) - case "/part": - a.cmdPart(args) - case "/msg": - a.cmdMsg(args) - case "/query": - a.cmdQuery(args) - case "/topic": - a.cmdTopic(args) - case "/names": - a.cmdNames() - case "/list": - a.cmdList() - case "/window", "/w": - a.cmdWindow(args) - case "/quit": - a.cmdQuit() - case "/help": - a.cmdHelp() - default: - a.ui.AddStatus( - "[red]Unknown command: " + cmd, - ) + argCmds := map[string]func(string){ + "/connect": a.cmdConnect, + "/nick": a.cmdNick, + "/join": a.cmdJoin, + "/part": a.cmdPart, + "/msg": a.cmdMsg, + "/query": a.cmdQuery, + "/topic": a.cmdTopic, + "/window": a.cmdWindow, + "/w": a.cmdWindow, } + + if fn, ok := argCmds[cmd]; ok { + fn(args) + + return + } + + noArgCmds := map[string]func(){ + "/names": a.cmdNames, + "/list": a.cmdList, + "/quit": a.cmdQuit, + "/help": a.cmdHelp, + } + + if fn, ok := noArgCmds[cmd]; ok { + fn() + + return + } + + a.ui.AddStatus( + "[red]Unknown command: " + cmd, + ) } func (a *App) cmdConnect(serverURL string) { diff --git a/cmd/chat-cli/ui.go b/cmd/chat-cli/ui.go index f27b50e..847114b 100644 --- a/cmd/chat-cli/ui.go +++ b/cmd/chat-cli/ui.go @@ -80,6 +80,7 @@ func (ui *UI) AddLine(bufferName, line string) { cur := ui.buffers[ui.currentBuffer] if cur != buf { buf.Unread++ + ui.refreshStatusBar() } diff --git a/internal/db/queries_test.go b/internal/db/queries_test.go index 99c3f36..0cae346 100644 --- a/internal/db/queries_test.go +++ b/internal/db/queries_test.go @@ -349,10 +349,12 @@ func TestSetTopic(t *testing.T) { } } -func TestInsertAndPollMessages(t *testing.T) { - t.Parallel() +func insertTestMessage( + t *testing.T, + database *db.Database, +) (int64, int64) { + t.Helper() - database := setupTestDB(t) ctx := t.Context() uid, _, err := database.CreateUser(ctx, "poller") @@ -374,10 +376,19 @@ func TestInsertAndPollMessages(t *testing.T) { t.Fatal(err) } + return uid, dbID +} + +func TestInsertAndPollMessages(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + uid, _ := insertTestMessage(t, database) + const batchSize = 10 msgs, lastQID, err := database.PollMessages( - ctx, uid, 0, batchSize, + t.Context(), uid, 0, batchSize, ) if err != nil { t.Fatal(err) @@ -400,7 +411,7 @@ func TestInsertAndPollMessages(t *testing.T) { } msgs, _, _ = database.PollMessages( - ctx, uid, lastQID, batchSize, + t.Context(), uid, lastQID, batchSize, ) if len(msgs) != 0 { diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 21c9dad..065467a 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -49,8 +49,8 @@ func (s *Handlers) requireAuth( ) (int64, string, bool) { uid, nick, err := s.authUser(r) if err != nil { - s.respondJSON(w, r, - map[string]string{"error": "unauthorized"}, + s.respondError(w, r, + "unauthorized", http.StatusUnauthorized) return 0, "", false @@ -120,10 +120,8 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc { err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - s.respondJSON(w, r, - map[string]string{ - "error": "invalid request body", - }, + s.respondError(w, r, + "invalid request body", http.StatusBadRequest) return @@ -132,10 +130,8 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc { req.Nick = strings.TrimSpace(req.Nick) if !validNickRe.MatchString(req.Nick) { - s.respondJSON(w, r, - map[string]string{ - "error": "invalid nick format", - }, + s.respondError(w, r, + "invalid nick format", http.StatusBadRequest) return @@ -145,21 +141,7 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc { r.Context(), req.Nick, ) if err != nil { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondJSON(w, r, - map[string]string{ - "error": "nick already taken", - }, - http.StatusConflict) - - return - } - - s.log.Error("create user failed", "error", err) - - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, - http.StatusInternalServerError) + s.handleCreateUserError(w, r, err) return } @@ -170,6 +152,36 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc { } } +func (s *Handlers) respondError( + w http.ResponseWriter, + r *http.Request, + msg string, + code int, +) { + s.respondJSON(w, r, + map[string]string{"error": msg}, code) +} + +func (s *Handlers) handleCreateUserError( + w http.ResponseWriter, + r *http.Request, + err error, +) { + if strings.Contains(err.Error(), "UNIQUE") { + s.respondError(w, r, + "nick already taken", + http.StatusConflict) + + return + } + + s.log.Error("create user failed", "error", err) + + s.respondError(w, r, + "internal error", + http.StatusInternalServerError) +} + // HandleState returns the current user's info and channels. func (s *Handlers) HandleState() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -184,8 +196,8 @@ func (s *Handlers) HandleState() http.HandlerFunc { if err != nil { s.log.Error("list channels failed", "error", err) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -215,8 +227,8 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc { "list all channels failed", "error", err, ) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -240,10 +252,8 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc { r.Context(), name, ) if err != nil { - s.respondJSON(w, r, - map[string]string{ - "error": "channel not found", - }, + s.respondError(w, r, + "channel not found", http.StatusNotFound) return @@ -257,8 +267,8 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc { "channel members failed", "error", err, ) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -299,8 +309,8 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc { "poll messages failed", "error", err, ) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -350,8 +360,8 @@ func (s *Handlers) longPoll( if err != nil { s.log.Error("poll messages failed", "error", err) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -382,10 +392,8 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - s.respondJSON(w, r, - map[string]string{ - "error": "invalid request body", - }, + s.respondError(w, r, + "invalid request body", http.StatusBadRequest) return @@ -397,10 +405,8 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { req.To = strings.TrimSpace(req.To) if req.Command == "" { - s.respondJSON(w, r, - map[string]string{ - "error": "command required", - }, + s.respondError(w, r, + "command required", http.StatusBadRequest) return @@ -476,8 +482,8 @@ func (s *Handlers) handlePrivmsg( bodyLines func() []string, ) { if to == "" { - s.respondJSON(w, r, - map[string]string{"error": "to field required"}, + s.respondError(w, r, + "to field required", http.StatusBadRequest) return @@ -485,8 +491,8 @@ func (s *Handlers) handlePrivmsg( lines := bodyLines() if len(lines) == 0 { - s.respondJSON(w, r, - map[string]string{"error": "body required"}, + s.respondError(w, r, + "body required", http.StatusBadRequest) return @@ -514,8 +520,8 @@ func (s *Handlers) handleChannelMsg( r.Context(), to, ) if err != nil { - s.respondJSON(w, r, - map[string]string{"error": "channel not found"}, + s.respondError(w, r, + "channel not found", http.StatusNotFound) return @@ -529,8 +535,8 @@ func (s *Handlers) handleChannelMsg( "get channel members failed", "error", err, ) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -542,8 +548,8 @@ func (s *Handlers) handleChannelMsg( if err != nil { s.log.Error("send message failed", "error", err) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -565,8 +571,8 @@ func (s *Handlers) handleDirectMsg( r.Context(), to, ) if err != nil { - s.respondJSON(w, r, - map[string]string{"error": "user not found"}, + s.respondError(w, r, + "user not found", http.StatusNotFound) return @@ -583,8 +589,8 @@ func (s *Handlers) handleDirectMsg( if err != nil { s.log.Error("send dm failed", "error", err) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return @@ -595,6 +601,27 @@ func (s *Handlers) handleDirectMsg( http.StatusCreated) } +func normalizeChannel(name string) string { + if !strings.HasPrefix(name, "#") { + return "#" + name + } + + return name +} + +func (s *Handlers) internalError( + w http.ResponseWriter, + r *http.Request, + msg string, + err error, +) { + s.log.Error(msg, "error", err) + + s.respondError(w, r, + "internal error", + http.StatusInternalServerError) +} + func (s *Handlers) handleJoin( w http.ResponseWriter, r *http.Request, @@ -602,23 +629,18 @@ func (s *Handlers) handleJoin( nick, to string, ) { if to == "" { - s.respondJSON(w, r, - map[string]string{"error": "to field required"}, + s.respondError(w, r, + "to field required", http.StatusBadRequest) return } - channel := to - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } + channel := normalizeChannel(to) if !validChannelRe.MatchString(channel) { - s.respondJSON(w, r, - map[string]string{ - "error": "invalid channel name", - }, + s.respondError(w, r, + "invalid channel name", http.StatusBadRequest) return @@ -628,13 +650,8 @@ func (s *Handlers) handleJoin( r.Context(), channel, ) if err != nil { - s.log.Error( - "get/create channel failed", "error", err, - ) - - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, - http.StatusInternalServerError) + s.internalError(w, r, + "get/create channel failed", err) return } @@ -643,11 +660,8 @@ func (s *Handlers) handleJoin( r.Context(), chID, uid, ) if err != nil { - s.log.Error("join channel failed", "error", err) - - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, - http.StatusInternalServerError) + s.internalError(w, r, + "join channel failed", err) return } @@ -676,31 +690,36 @@ func (s *Handlers) handlePart( body json.RawMessage, ) { if to == "" { - s.respondJSON(w, r, - map[string]string{"error": "to field required"}, + s.respondError(w, r, + "to field required", http.StatusBadRequest) return } - channel := to - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } + channel := normalizeChannel(to) chID, err := s.params.Database.GetChannelByName( r.Context(), channel, ) if err != nil { - s.respondJSON(w, r, - map[string]string{ - "error": "channel not found", - }, + s.respondError(w, r, + "channel not found", http.StatusNotFound) return } + s.partAndCleanup(w, r, chID, uid, nick, channel, body) +} + +func (s *Handlers) partAndCleanup( + w http.ResponseWriter, + r *http.Request, + chID, uid int64, + nick, channel string, + body json.RawMessage, +) { memberIDs, _ := s.params.Database.GetChannelMemberIDs( r.Context(), chID, ) @@ -709,15 +728,12 @@ func (s *Handlers) handlePart( r, "PART", nick, channel, body, memberIDs, ) - err = s.params.Database.PartChannel( + err := s.params.Database.PartChannel( r.Context(), chID, uid, ) if err != nil { - s.log.Error("part channel failed", "error", err) - - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, - http.StatusInternalServerError) + s.internalError(w, r, + "part channel failed", err) return } @@ -743,10 +759,8 @@ func (s *Handlers) handleNick( ) { lines := bodyLines() if len(lines) == 0 { - s.respondJSON(w, r, - map[string]string{ - "error": "body required (new nick)", - }, + s.respondError(w, r, + "body required (new nick)", http.StatusBadRequest) return @@ -755,8 +769,8 @@ func (s *Handlers) handleNick( newNick := strings.TrimSpace(lines[0]) if !validNickRe.MatchString(newNick) { - s.respondJSON(w, r, - map[string]string{"error": "invalid nick"}, + s.respondError(w, r, + "invalid nick", http.StatusBadRequest) return @@ -776,21 +790,7 @@ func (s *Handlers) handleNick( r.Context(), uid, newNick, ) if err != nil { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondJSON(w, r, - map[string]string{ - "error": "nick already in use", - }, - http.StatusConflict) - - return - } - - s.log.Error("change nick failed", "error", err) - - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, - http.StatusInternalServerError) + s.handleChangeNickError(w, r, err) return } @@ -804,6 +804,22 @@ func (s *Handlers) handleNick( http.StatusOK) } +func (s *Handlers) handleChangeNickError( + w http.ResponseWriter, + r *http.Request, + err error, +) { + if strings.Contains(err.Error(), "UNIQUE") { + s.respondError(w, r, + "nick already in use", + http.StatusConflict) + + return + } + + s.internalError(w, r, "change nick failed", err) +} + func (s *Handlers) broadcastNick( r *http.Request, uid int64, @@ -858,8 +874,8 @@ func (s *Handlers) handleTopic( bodyLines func() []string, ) { if to == "" { - s.respondJSON(w, r, - map[string]string{"error": "to field required"}, + s.respondError(w, r, + "to field required", http.StatusBadRequest) return @@ -867,43 +883,40 @@ func (s *Handlers) handleTopic( lines := bodyLines() if len(lines) == 0 { - s.respondJSON(w, r, - map[string]string{ - "error": "body required (topic text)", - }, + s.respondError(w, r, + "body required (topic text)", http.StatusBadRequest) return } topic := strings.Join(lines, " ") - - channel := to - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } + channel := normalizeChannel(to) err := s.params.Database.SetTopic( r.Context(), channel, topic, ) if err != nil { - s.log.Error("set topic failed", "error", err) - - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, - http.StatusInternalServerError) + s.internalError(w, r, "set topic failed", err) return } + s.broadcastTopic(w, r, nick, channel, topic, body) +} + +func (s *Handlers) broadcastTopic( + w http.ResponseWriter, + r *http.Request, + nick, channel, topic string, + body json.RawMessage, +) { chID, err := s.params.Database.GetChannelByName( r.Context(), channel, ) if err != nil { - s.respondJSON(w, r, - map[string]string{ - "error": "channel not found", - }, + s.respondError(w, r, + "channel not found", http.StatusNotFound) return @@ -991,10 +1004,8 @@ func (s *Handlers) HandleGetHistory() http.HandlerFunc { target := r.URL.Query().Get("target") if target == "" { - s.respondJSON(w, r, - map[string]string{ - "error": "target required", - }, + s.respondError(w, r, + "target required", http.StatusBadRequest) return @@ -1019,8 +1030,8 @@ func (s *Handlers) HandleGetHistory() http.HandlerFunc { "get history failed", "error", err, ) - s.respondJSON(w, r, - map[string]string{"error": "internal error"}, + s.respondError(w, r, + "internal error", http.StatusInternalServerError) return diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index ce3b4b1..da40025 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -26,6 +26,10 @@ import ( "go.uber.org/fx/fxtest" ) +const cmdPrivmsg = "PRIVMSG" + +var viperMu sync.Mutex //nolint:gochecknoglobals // serializes viper access in parallel tests + // testServer wraps a test HTTP server with helper methods. type testServer struct { srv *httptest.Server @@ -33,100 +37,125 @@ type testServer struct { fxApp *fxtest.App } +func testGlobals() *globals.Globals { + return &globals.Globals{ + Appname: "chat-test", + Version: "test", + } +} + +func testConfigFactory( + dbURL string, +) func(fx.Lifecycle, *globals.Globals, *logger.Logger) (*config.Config, error) { + return func( + lc fx.Lifecycle, + g *globals.Globals, + l *logger.Logger, + ) (*config.Config, error) { + viperMu.Lock() + + c, err := config.New(lc, config.Params{ + Globals: g, Logger: l, + }) + + viperMu.Unlock() + + if err != nil { + return nil, err + } + + c.DBURL = dbURL + + return c, nil + } +} + +func testDB( + lc fx.Lifecycle, + l *logger.Logger, + c *config.Config, +) (*db.Database, error) { + return db.New(lc, db.Params{ + Logger: l, Config: c, + }) +} + +func testHealthcheck( + lc fx.Lifecycle, + g *globals.Globals, + c *config.Config, + l *logger.Logger, + d *db.Database, +) (*healthcheck.Healthcheck, error) { + return healthcheck.New(lc, healthcheck.Params{ + Globals: g, + Config: c, + Logger: l, + Database: d, + }) +} + +func testMiddleware( + lc fx.Lifecycle, + l *logger.Logger, + g *globals.Globals, + c *config.Config, +) (*middleware.Middleware, error) { + return middleware.New(lc, middleware.Params{ + Logger: l, + Globals: g, + Config: c, + }) +} + +func testHandlers( + lc fx.Lifecycle, + l *logger.Logger, + g *globals.Globals, + c *config.Config, + d *db.Database, + hc *healthcheck.Healthcheck, +) (*handlers.Handlers, error) { + return handlers.New(lc, handlers.Params{ + Logger: l, + Globals: g, + Config: c, + Database: d, + Healthcheck: hc, + }) +} + +func testServer2( + lc fx.Lifecycle, + l *logger.Logger, + g *globals.Globals, + c *config.Config, + mw *middleware.Middleware, + h *handlers.Handlers, +) (*server.Server, error) { + return server.New(lc, server.Params{ + Logger: l, + Globals: g, + Config: c, + Middleware: mw, + Handlers: h, + }) +} + func newTestServer(t *testing.T) *testServer { t.Helper() - // Use a unique DB per test to avoid SQLite BUSY and state leaks. dbPath := filepath.Join(t.TempDir(), "test.db") - t.Setenv("DBURL", "file:"+dbPath+"?_journal_mode=WAL&_busy_timeout=5000") + dbURL := "file:" + dbPath + "?_journal_mode=WAL&_busy_timeout=5000" var s *server.Server app := fxtest.New(t, fx.Provide( - func() *globals.Globals { - return &globals.Globals{ - Appname: "chat-test", - Version: "test", - } - }, - logger.New, - func( - lc fx.Lifecycle, - g *globals.Globals, - l *logger.Logger, - ) (*config.Config, error) { - return config.New(lc, config.Params{ - Globals: g, Logger: l, - }) - }, - func( - lc fx.Lifecycle, - l *logger.Logger, - c *config.Config, - ) (*db.Database, error) { - return db.New(lc, db.Params{ - Logger: l, Config: c, - }) - }, - func( - lc fx.Lifecycle, - g *globals.Globals, - c *config.Config, - l *logger.Logger, - d *db.Database, - ) (*healthcheck.Healthcheck, error) { - return healthcheck.New(lc, healthcheck.Params{ - Globals: g, - Config: c, - Logger: l, - Database: d, - }) - }, - func( - lc fx.Lifecycle, - l *logger.Logger, - g *globals.Globals, - c *config.Config, - ) (*middleware.Middleware, error) { - return middleware.New(lc, middleware.Params{ - Logger: l, - Globals: g, - Config: c, - }) - }, - func( - lc fx.Lifecycle, - l *logger.Logger, - g *globals.Globals, - c *config.Config, - d *db.Database, - hc *healthcheck.Healthcheck, - ) (*handlers.Handlers, error) { - return handlers.New(lc, handlers.Params{ - Logger: l, - Globals: g, - Config: c, - Database: d, - Healthcheck: hc, - }) - }, - func( - lc fx.Lifecycle, - l *logger.Logger, - g *globals.Globals, - c *config.Config, - mw *middleware.Middleware, - h *handlers.Handlers, - ) (*server.Server, error) { - return server.New(lc, server.Params{ - Logger: l, - Globals: g, - Config: c, - Middleware: mw, - Handlers: h, - }) - }, + testGlobals, logger.New, + testConfigFactory(dbURL), testDB, + testHealthcheck, testMiddleware, + testHandlers, testServer2, ), fx.Populate(&s), ) @@ -135,6 +164,7 @@ func newTestServer(t *testing.T) *testServer { time.Sleep(100 * time.Millisecond) ts := httptest.NewServer(s) + t.Cleanup(func() { ts.Close() app.RequireStop() @@ -147,14 +177,20 @@ func (ts *testServer) url(path string) string { return ts.srv.URL + path } +func newReqWithCtx( + method, url string, body io.Reader, +) (*http.Request, error) { + return http.NewRequestWithContext( + context.Background(), method, url, body, + ) +} + func (ts *testServer) doReq( method, url string, body io.Reader, ) (*http.Response, error) { ts.t.Helper() - req, err := http.NewRequestWithContext( - context.Background(), method, url, body, - ) + req, err := newReqWithCtx(method, url, body) if err != nil { return nil, fmt.Errorf("new request: %w", err) } @@ -163,7 +199,7 @@ func (ts *testServer) doReq( req.Header.Set("Content-Type", "application/json") } - return http.DefaultClient.Do(req) + return http.DefaultClient.Do(req) //nolint:gosec // test helper with test server URL } func (ts *testServer) doReqAuth( @@ -171,9 +207,7 @@ func (ts *testServer) doReqAuth( ) (*http.Response, error) { ts.t.Helper() - req, err := http.NewRequestWithContext( - context.Background(), method, url, body, - ) + req, err := newReqWithCtx(method, url, body) if err != nil { return nil, fmt.Errorf("new request: %w", err) } @@ -186,7 +220,7 @@ func (ts *testServer) doReqAuth( req.Header.Set("Authorization", "Bearer "+token) } - return http.DefaultClient.Do(req) + return http.DefaultClient.Do(req) //nolint:gosec // test helper with test server URL } func (ts *testServer) createSession(nick string) string { @@ -203,6 +237,7 @@ func (ts *testServer) createSession(nick string) string { if err != nil { ts.t.Fatalf("create session: %v", err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusCreated { @@ -215,7 +250,8 @@ func (ts *testServer) createSession(nick string) string { Token string `json:"token"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { ts.t.Fatalf("decode session: %v", err) } @@ -238,6 +274,7 @@ func (ts *testServer) sendCommand( if err != nil { ts.t.Fatalf("send command: %v", err) } + defer func() { _ = resp.Body.Close() }() var result map[string]any @@ -256,6 +293,7 @@ func (ts *testServer) getJSON( if err != nil { ts.t.Fatalf("get: %v", err) } + defer func() { _ = resp.Body.Close() }() var result map[string]any @@ -278,6 +316,7 @@ func (ts *testServer) pollMessages( if err != nil { ts.t.Fatalf("poll: %v", err) } + defer func() { _ = resp.Body.Close() }() var result struct { @@ -285,7 +324,8 @@ func (ts *testServer) pollMessages( LastID json.Number `json:"last_id"` //nolint:tagliatelle } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { ts.t.Fatalf("decode poll: %v", err) } @@ -294,12 +334,43 @@ func (ts *testServer) pollMessages( return result.Messages, lastID } +func postSessionExpect( + t *testing.T, + ts *testServer, + nick string, + wantStatus int, +) { + t.Helper() + + body, err := json.Marshal(map[string]string{"nick": nick}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + resp, err := ts.doReq( + http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != wantStatus { + t.Fatalf("expected %d, got %d", wantStatus, resp.StatusCode) + } +} + // --- Tests --- func TestCreateSession(t *testing.T) { + t.Parallel() + ts := newTestServer(t) t.Run("valid nick", func(t *testing.T) { + t.Parallel() + token := ts.createSession("alice") if token == "" { t.Fatal("expected token") @@ -307,88 +378,42 @@ func TestCreateSession(t *testing.T) { }) t.Run("duplicate nick", func(t *testing.T) { - body, err := json.Marshal(map[string]string{"nick": "alice"}) - if err != nil { - t.Fatalf("marshal: %v", err) - } + t.Parallel() - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), - ) - if err != nil { - t.Fatal(err) - } - defer func() { _ = resp.Body.Close() }() + ts2 := newTestServer(t) + ts2.createSession("dupnick") - if resp.StatusCode != http.StatusConflict { - t.Fatalf("expected 409, got %d", resp.StatusCode) - } + postSessionExpect(t, ts2, "dupnick", http.StatusConflict) }) t.Run("empty nick", func(t *testing.T) { - body, err := json.Marshal(map[string]string{"nick": ""}) - if err != nil { - t.Fatalf("marshal: %v", err) - } + t.Parallel() - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), - ) - if err != nil { - t.Fatal(err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) - } + postSessionExpect(t, ts, "", http.StatusBadRequest) }) t.Run("invalid nick chars", func(t *testing.T) { - body, err := json.Marshal(map[string]string{"nick": "hello world"}) - if err != nil { - t.Fatalf("marshal: %v", err) - } + t.Parallel() - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), - ) - if err != nil { - t.Fatal(err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) - } + postSessionExpect(t, ts, "hello world", http.StatusBadRequest) }) t.Run("nick starting with number", func(t *testing.T) { - body, err := json.Marshal(map[string]string{"nick": "123abc"}) - if err != nil { - t.Fatalf("marshal: %v", err) - } + t.Parallel() - resp, err := ts.doReq( - http.MethodPost, ts.url("/api/v1/session"), bytes.NewReader(body), - ) - if err != nil { - t.Fatal(err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.StatusCode) - } + postSessionExpect(t, ts, "123abc", http.StatusBadRequest) }) t.Run("malformed json", func(t *testing.T) { + t.Parallel() + resp, err := ts.doReq( http.MethodPost, ts.url("/api/v1/session"), strings.NewReader("{bad"), ) if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusBadRequest { @@ -398,9 +423,13 @@ func TestCreateSession(t *testing.T) { } func TestAuth(t *testing.T) { + t.Parallel() + ts := newTestServer(t) t.Run("no auth header", func(t *testing.T) { + t.Parallel() + status, _ := ts.getJSON("", "/api/v1/state") if status != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", status) @@ -408,6 +437,8 @@ func TestAuth(t *testing.T) { }) t.Run("bad token", func(t *testing.T) { + t.Parallel() + status, _ := ts.getJSON("invalid-token-12345", "/api/v1/state") if status != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", status) @@ -415,6 +446,8 @@ func TestAuth(t *testing.T) { }) t.Run("valid token", func(t *testing.T) { + t.Parallel() + token := ts.createSession("authtest") status, result := ts.getJSON(token, "/api/v1/state") @@ -429,10 +462,14 @@ func TestAuth(t *testing.T) { } func TestJoinAndPart(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("bob") t.Run("join channel", func(t *testing.T) { + t.Parallel() + status, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#test"}) if status != http.StatusOK { t.Fatalf("expected 200, got %d: %v", status, result) @@ -444,6 +481,8 @@ func TestJoinAndPart(t *testing.T) { }) t.Run("join without hash", func(t *testing.T) { + t.Parallel() + status, result := ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "other"}) if status != http.StatusOK { t.Fatalf("expected 200, got %d: %v", status, result) @@ -455,17 +494,25 @@ func TestJoinAndPart(t *testing.T) { }) t.Run("part channel", func(t *testing.T) { - status, result := ts.sendCommand(token, map[string]any{"command": "PART", "to": "#test"}) + t.Parallel() + + ts2 := newTestServer(t) + tok := ts2.createSession("partuser") + ts2.sendCommand(tok, map[string]any{"command": "JOIN", "to": "#partchan"}) + + status, result := ts2.sendCommand(tok, map[string]any{"command": "PART", "to": "#partchan"}) if status != http.StatusOK { t.Fatalf("expected 200, got %d: %v", status, result) } - if result["channel"] != "#test" { - t.Fatalf("expected #test, got %v", result["channel"]) + if result["channel"] != "#partchan" { + t.Fatalf("expected #partchan, got %v", result["channel"]) } }) t.Run("join missing to", func(t *testing.T) { + t.Parallel() + status, _ := ts.sendCommand(token, map[string]any{"command": "JOIN"}) if status != http.StatusBadRequest { t.Fatalf("expected 400, got %d", status) @@ -473,7 +520,9 @@ func TestJoinAndPart(t *testing.T) { }) } -func TestPrivmsg(t *testing.T) { +func TestPrivmsgChannel(t *testing.T) { + t.Parallel() + ts := newTestServer(t) aliceToken := ts.createSession("alice_msg") bobToken := ts.createSession("bob_msg") @@ -484,43 +533,50 @@ func TestPrivmsg(t *testing.T) { _, _ = ts.pollMessages(aliceToken, 0) _, bobLastID := ts.pollMessages(bobToken, 0) - t.Run("send channel message", func(t *testing.T) { - status, result := ts.sendCommand(aliceToken, map[string]any{ - "command": "PRIVMSG", - "to": "#chat", - "body": []string{"hello world"}, - }) - if status != http.StatusCreated { - t.Fatalf("expected 201, got %d: %v", status, result) - } - - if result["id"] == nil || result["id"] == "" { - t.Fatal("expected message id") - } + status, result := ts.sendCommand(aliceToken, map[string]any{ + "command": cmdPrivmsg, + "to": "#chat", + "body": []string{"hello world"}, }) + if status != http.StatusCreated { + t.Fatalf("expected 201, got %d: %v", status, result) + } - t.Run("bob receives message", func(t *testing.T) { - msgs, _ := ts.pollMessages(bobToken, bobLastID) + if result["id"] == nil || result["id"] == "" { + t.Fatal("expected message id") + } - found := false + msgs, _ := ts.pollMessages(bobToken, bobLastID) - for _, m := range msgs { - if m["command"] == "PRIVMSG" && m["from"] == "alice_msg" { - found = true + found := false - break - } + for _, m := range msgs { + if m["command"] == cmdPrivmsg && m["from"] == "alice_msg" { + found = true + + break } + } - if !found { - t.Fatalf("bob didn't receive alice's message: %v", msgs) - } - }) + if !found { + t.Fatalf("bob didn't receive alice's message: %v", msgs) + } +} + +func TestPrivmsgErrors(t *testing.T) { + t.Parallel() + + ts := newTestServer(t) + aliceToken := ts.createSession("alice_msg2") + + ts.sendCommand(aliceToken, map[string]any{"command": "JOIN", "to": "#chat2"}) t.Run("missing body", func(t *testing.T) { + t.Parallel() + status, _ := ts.sendCommand(aliceToken, map[string]any{ - "command": "PRIVMSG", - "to": "#chat", + "command": cmdPrivmsg, + "to": "#chat2", }) if status != http.StatusBadRequest { t.Fatalf("expected 400, got %d", status) @@ -528,8 +584,10 @@ func TestPrivmsg(t *testing.T) { }) t.Run("missing to", func(t *testing.T) { + t.Parallel() + status, _ := ts.sendCommand(aliceToken, map[string]any{ - "command": "PRIVMSG", + "command": cmdPrivmsg, "body": []string{"hello"}, }) if status != http.StatusBadRequest { @@ -538,95 +596,124 @@ func TestPrivmsg(t *testing.T) { }) } -func TestDM(t *testing.T) { +func TestDMSend(t *testing.T) { + t.Parallel() + ts := newTestServer(t) aliceToken := ts.createSession("alice_dm") bobToken := ts.createSession("bob_dm") - t.Run("send DM", func(t *testing.T) { - status, result := ts.sendCommand(aliceToken, map[string]any{ - "command": "PRIVMSG", - "to": "bob_dm", - "body": []string{"hey bob"}, - }) - if status != http.StatusCreated { - t.Fatalf("expected 201, got %d: %v", status, result) - } + status, result := ts.sendCommand(aliceToken, map[string]any{ + "command": cmdPrivmsg, + "to": "bob_dm", + "body": []string{"hey bob"}, }) + if status != http.StatusCreated { + t.Fatalf("expected 201, got %d: %v", status, result) + } - t.Run("bob receives DM", func(t *testing.T) { - msgs, _ := ts.pollMessages(bobToken, 0) + msgs, _ := ts.pollMessages(bobToken, 0) - found := false + found := false - for _, m := range msgs { - if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" { - found = true - } + for _, m := range msgs { + if m["command"] == cmdPrivmsg && m["from"] == "alice_dm" { + found = true } + } - if !found { - t.Fatal("bob didn't receive DM") - } - }) - - t.Run("alice gets echo", func(t *testing.T) { - msgs, _ := ts.pollMessages(aliceToken, 0) - - found := false - - for _, m := range msgs { - if m["command"] == "PRIVMSG" && m["from"] == "alice_dm" && m["to"] == "bob_dm" { - found = true - } - } - - if !found { - t.Fatal("alice didn't get DM echo") - } - }) - - t.Run("DM to nonexistent user", func(t *testing.T) { - status, _ := ts.sendCommand(aliceToken, map[string]any{ - "command": "PRIVMSG", - "to": "nobody", - "body": []string{"hello?"}, - }) - if status != http.StatusNotFound { - t.Fatalf("expected 404, got %d", status) - } - }) + if !found { + t.Fatal("bob didn't receive DM") + } } -func TestNick(t *testing.T) { +func TestDMEcho(t *testing.T) { + t.Parallel() + + ts := newTestServer(t) + aliceToken := ts.createSession("alice_echo") + ts.createSession("bob_echo") + + ts.sendCommand(aliceToken, map[string]any{ + "command": cmdPrivmsg, + "to": "bob_echo", + "body": []string{"hey bob"}, + }) + + msgs, _ := ts.pollMessages(aliceToken, 0) + + found := false + + for _, m := range msgs { + if m["command"] == cmdPrivmsg && m["from"] == "alice_echo" && m["to"] == "bob_echo" { + found = true + } + } + + if !found { + t.Fatal("alice didn't get DM echo") + } +} + +func TestDMNonexistent(t *testing.T) { + t.Parallel() + + ts := newTestServer(t) + aliceToken := ts.createSession("alice_noone") + + status, _ := ts.sendCommand(aliceToken, map[string]any{ + "command": cmdPrivmsg, + "to": "nobody", + "body": []string{"hello?"}, + }) + if status != http.StatusNotFound { + t.Fatalf("expected 404, got %d", status) + } +} + +func TestNickChange(t *testing.T) { + t.Parallel() + + ts := newTestServer(t) + tok := ts.createSession("nick_change") + + status, result := ts.sendCommand(tok, map[string]any{ + "command": "NICK", + "body": []string{"newnick"}, + }) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d: %v", status, result) + } + + if result["nick"] != "newnick" { + t.Fatalf("expected newnick, got %v", result["nick"]) + } +} + +func TestNickSameAsCurrent(t *testing.T) { + t.Parallel() + + ts := newTestServer(t) + tok := ts.createSession("samenick") + + status, _ := ts.sendCommand(tok, map[string]any{ + "command": "NICK", + "body": []string{"samenick"}, + }) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) + } +} + +func TestNickErrors(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("nick_test") - t.Run("change nick", func(t *testing.T) { - status, result := ts.sendCommand(token, map[string]any{ - "command": "NICK", - "body": []string{"newnick"}, - }) - if status != http.StatusOK { - t.Fatalf("expected 200, got %d: %v", status, result) - } + t.Run("collision", func(t *testing.T) { + t.Parallel() - if result["nick"] != "newnick" { - t.Fatalf("expected newnick, got %v", result["nick"]) - } - }) - - t.Run("nick same as current", func(t *testing.T) { - status, _ := ts.sendCommand(token, map[string]any{ - "command": "NICK", - "body": []string{"newnick"}, - }) - if status != http.StatusOK { - t.Fatalf("expected 200, got %d", status) - } - }) - - t.Run("nick collision", func(t *testing.T) { ts.createSession("taken_nick") status, _ := ts.sendCommand(token, map[string]any{ @@ -638,7 +725,9 @@ func TestNick(t *testing.T) { } }) - t.Run("invalid nick", func(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + t.Parallel() + status, _ := ts.sendCommand(token, map[string]any{ "command": "NICK", "body": []string{"bad nick!"}, @@ -649,6 +738,8 @@ func TestNick(t *testing.T) { }) t.Run("empty body", func(t *testing.T) { + t.Parallel() + status, _ := ts.sendCommand(token, map[string]any{ "command": "NICK", }) @@ -659,12 +750,16 @@ func TestNick(t *testing.T) { } func TestTopic(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("topic_user") ts.sendCommand(token, map[string]any{"command": "JOIN", "to": "#topictest"}) t.Run("set topic", func(t *testing.T) { + t.Parallel() + status, result := ts.sendCommand(token, map[string]any{ "command": "TOPIC", "to": "#topictest", @@ -680,6 +775,8 @@ func TestTopic(t *testing.T) { }) t.Run("missing to", func(t *testing.T) { + t.Parallel() + status, _ := ts.sendCommand(token, map[string]any{ "command": "TOPIC", "body": []string{"topic"}, @@ -690,6 +787,8 @@ func TestTopic(t *testing.T) { }) t.Run("missing body", func(t *testing.T) { + t.Parallel() + status, _ := ts.sendCommand(token, map[string]any{ "command": "TOPIC", "to": "#topictest", @@ -701,6 +800,8 @@ func TestTopic(t *testing.T) { } func TestPing(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("ping_user") @@ -715,6 +816,8 @@ func TestPing(t *testing.T) { } func TestQuit(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("quitter") observerToken := ts.createSession("observer") @@ -750,6 +853,8 @@ func TestQuit(t *testing.T) { } func TestUnknownCommand(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("cmdtest") @@ -760,6 +865,8 @@ func TestUnknownCommand(t *testing.T) { } func TestEmptyCommand(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("emptycmd") @@ -770,6 +877,8 @@ func TestEmptyCommand(t *testing.T) { } func TestHistory(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("historian") @@ -777,7 +886,7 @@ func TestHistory(t *testing.T) { for range 5 { ts.sendCommand(token, map[string]any{ - "command": "PRIVMSG", + "command": cmdPrivmsg, "to": "#history", "body": []string{"test message"}, }) @@ -789,6 +898,7 @@ func TestHistory(t *testing.T) { if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { @@ -797,7 +907,8 @@ func TestHistory(t *testing.T) { var msgs []map[string]any - if err := json.NewDecoder(resp.Body).Decode(&msgs); err != nil { + err = json.NewDecoder(resp.Body).Decode(&msgs) + if err != nil { t.Fatalf("decode history: %v", err) } @@ -807,6 +918,8 @@ func TestHistory(t *testing.T) { } func TestChannelList(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("lister") @@ -818,6 +931,7 @@ func TestChannelList(t *testing.T) { if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { @@ -826,7 +940,8 @@ func TestChannelList(t *testing.T) { var channels []map[string]any - if err := json.NewDecoder(resp.Body).Decode(&channels); err != nil { + err = json.NewDecoder(resp.Body).Decode(&channels) + if err != nil { t.Fatalf("decode channels: %v", err) } @@ -844,6 +959,8 @@ func TestChannelList(t *testing.T) { } func TestChannelMembers(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("membertest") @@ -855,6 +972,7 @@ func TestChannelMembers(t *testing.T) { if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { @@ -862,7 +980,45 @@ func TestChannelMembers(t *testing.T) { } } +func asyncPoll( + ts *testServer, + token string, + afterID int64, +) <-chan []map[string]any { + ch := make(chan []map[string]any, 1) + + go func() { + url := fmt.Sprintf( + "%s/api/v1/messages?timeout=5&after=%d", + ts.srv.URL, afterID, + ) + + resp, err := ts.doReqAuth( + http.MethodGet, url, token, nil, + ) + if err != nil { + ch <- nil + + return + } + + defer func() { _ = resp.Body.Close() }() + + var result struct { + Messages []map[string]any `json:"messages"` + } + + _ = json.NewDecoder(resp.Body).Decode(&result) + + ch <- result.Messages + }() + + return ch +} + func TestLongPoll(t *testing.T) { + t.Parallel() + ts := newTestServer(t) aliceToken := ts.createSession("lp_alice") bobToken := ts.createSession("lp_bob") @@ -872,58 +1028,34 @@ func TestLongPoll(t *testing.T) { _, lastID := ts.pollMessages(bobToken, 0) - var wg sync.WaitGroup - - var pollMsgs []map[string]any - - wg.Add(1) - - go func() { - defer wg.Done() - - url := fmt.Sprintf( - "%s/api/v1/messages?timeout=5&after=%d", ts.srv.URL, lastID, - ) - - resp, err := ts.doReqAuth(http.MethodGet, url, bobToken, nil) - if err != nil { - return - } - defer func() { _ = resp.Body.Close() }() - - var result struct { - Messages []map[string]any `json:"messages"` - } - - _ = json.NewDecoder(resp.Body).Decode(&result) - - pollMsgs = result.Messages - }() + pollMsgs := asyncPoll(ts, bobToken, lastID) time.Sleep(200 * time.Millisecond) ts.sendCommand(aliceToken, map[string]any{ - "command": "PRIVMSG", + "command": cmdPrivmsg, "to": "#longpoll", "body": []string{"wake up!"}, }) - wg.Wait() + msgs := <-pollMsgs found := false - for _, m := range pollMsgs { - if m["command"] == "PRIVMSG" && m["from"] == "lp_alice" { + for _, m := range msgs { + if m["command"] == cmdPrivmsg && m["from"] == "lp_alice" { found = true } } if !found { - t.Fatalf("long-poll didn't receive message: %v", pollMsgs) + t.Fatalf("long-poll didn't receive message: %v", msgs) } } func TestLongPollTimeout(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("lp_timeout") @@ -935,6 +1067,7 @@ func TestLongPollTimeout(t *testing.T) { if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() elapsed := time.Since(start) @@ -949,6 +1082,8 @@ func TestLongPollTimeout(t *testing.T) { } func TestEphemeralChannelCleanup(t *testing.T) { + t.Parallel() + ts := newTestServer(t) token := ts.createSession("ephemeral") @@ -961,11 +1096,13 @@ func TestEphemeralChannelCleanup(t *testing.T) { if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() var channels []map[string]any - if err := json.NewDecoder(resp.Body).Decode(&channels); err != nil { + err = json.NewDecoder(resp.Body).Decode(&channels) + if err != nil { t.Fatalf("decode channels: %v", err) } @@ -977,6 +1114,8 @@ func TestEphemeralChannelCleanup(t *testing.T) { } func TestConcurrentSessions(t *testing.T) { + t.Parallel() + ts := newTestServer(t) var wg sync.WaitGroup @@ -986,10 +1125,10 @@ func TestConcurrentSessions(t *testing.T) { for i := range 20 { wg.Add(1) - go func(i int) { + go func(idx int) { defer wg.Done() - nick := "concurrent_" + string(rune('a'+i)) + nick := fmt.Sprintf("concurrent_%d", idx) body, err := json.Marshal(map[string]string{"nick": nick}) if err != nil { @@ -1028,12 +1167,15 @@ func TestConcurrentSessions(t *testing.T) { } func TestServerInfo(t *testing.T) { + t.Parallel() + ts := newTestServer(t) resp, err := ts.doReq(http.MethodGet, ts.url("/api/v1/server"), nil) if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { @@ -1042,12 +1184,15 @@ func TestServerInfo(t *testing.T) { } func TestHealthcheck(t *testing.T) { + t.Parallel() + ts := newTestServer(t) resp, err := ts.doReq(http.MethodGet, ts.url("/.well-known/healthcheck.json"), nil) if err != nil { t.Fatal(err) } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { @@ -1056,7 +1201,8 @@ func TestHealthcheck(t *testing.T) { var result map[string]any - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { t.Fatalf("decode healthcheck: %v", err) } @@ -1066,6 +1212,8 @@ func TestHealthcheck(t *testing.T) { } func TestNickBroadcastToChannels(t *testing.T) { + t.Parallel() + ts := newTestServer(t) aliceToken := ts.createSession("nick_a") bobToken := ts.createSession("nick_b")