From 6c1d652308174f3c0641ee1f7e745c51960ac013 Mon Sep 17 00:00:00 2001 From: clawbot Date: Tue, 10 Feb 2026 18:16:23 -0800 Subject: [PATCH] refactor: clean up handlers, add input validation, remove raw SQL from handlers - Merge fanOut/fanOutDirect into single fanOut method - Move channel lookup to db.GetChannelByName - Add regex validation for nicks and channel names - Split HandleSendCommand into per-command helper methods - Add charset to Content-Type header - Add sentinel error for unauthorized - Cap history limit to 500 - Skip NICK change if new == old - Add empty command check --- internal/db/queries.go | 7 + internal/handlers/api.go | 464 ++++++++++++++++++---------------- internal/handlers/handlers.go | 8 +- 3 files changed, 252 insertions(+), 227 deletions(-) diff --git a/internal/db/queries.go b/internal/db/queries.go index cbe9c16..83ec801 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -77,6 +77,13 @@ func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error return id, err } +// GetChannelByName returns the channel ID for a given name. +func (s *Database) GetChannelByName(ctx context.Context, name string) (int64, error) { + var id int64 + err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id) + return id, err +} + // GetOrCreateChannel returns the channel id, creating it if needed. func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) { var id int64 diff --git a/internal/handlers/api.go b/internal/handlers/api.go index d02f734..e67e366 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -1,9 +1,9 @@ package handlers import ( - "database/sql" "encoding/json" "net/http" + "regexp" "strconv" "strings" "time" @@ -11,13 +11,19 @@ import ( "github.com/go-chi/chi" ) +var validNickRe = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_\-\[\]\\^{}|` + "`" + `]{0,31}$`) +var validChannelRe = regexp.MustCompile(`^#[a-zA-Z0-9_\-]{1,63}$`) + // authUser extracts the user from the Authorization header (Bearer token). func (s *Handlers) authUser(r *http.Request) (int64, string, error) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { - return 0, "", sql.ErrNoRows + return 0, "", errUnauthorized } token := strings.TrimPrefix(auth, "Bearer ") + if token == "" { + return 0, "", errUnauthorized + } return s.params.Database.GetUserByToken(r.Context(), token) } @@ -31,28 +37,13 @@ func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, s } // fanOut stores a message and enqueues it to all specified user IDs, then notifies them. -func (s *Handlers) fanOut(ctx *http.Request, command, from, to string, body json.RawMessage, userIDs []int64) error { - dbID, _, err := s.params.Database.InsertMessage(ctx.Context(), command, from, to, body, nil) - if err != nil { - return err - } - for _, uid := range userIDs { - if err := s.params.Database.EnqueueMessage(ctx.Context(), uid, dbID); err != nil { - s.log.Error("enqueue failed", "error", err, "user_id", uid) - } - s.broker.Notify(uid) - } - return nil -} - -// fanOutRaw stores and fans out, returning the message DB ID. -func (s *Handlers) fanOutDirect(ctx *http.Request, command, from, to string, body json.RawMessage, userIDs []int64) (int64, string, error) { - dbID, msgUUID, err := s.params.Database.InsertMessage(ctx.Context(), command, from, to, body, nil) +func (s *Handlers) fanOut(r *http.Request, command, from, to string, body json.RawMessage, userIDs []int64) (int64, string, error) { + dbID, msgUUID, err := s.params.Database.InsertMessage(r.Context(), command, from, to, body, nil) if err != nil { return 0, "", err } for _, uid := range userIDs { - if err := s.params.Database.EnqueueMessage(ctx.Context(), uid, dbID); err != nil { + if err := s.params.Database.EnqueueMessage(r.Context(), uid, dbID); err != nil { s.log.Error("enqueue failed", "error", err, "user_id", uid) } s.broker.Notify(uid) @@ -60,18 +51,6 @@ func (s *Handlers) fanOutDirect(ctx *http.Request, command, from, to string, bod return dbID, msgUUID, nil } -// getChannelMembers gets all member IDs for a channel by name. -func (s *Handlers) getChannelMemberIDs(r *http.Request, channelName string) (int64, []int64, error) { - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", channelName).Scan(&chID) - if err != nil { - return 0, nil, err - } - ids, err := s.params.Database.GetChannelMemberIDs(r.Context(), chID) - return chID, ids, err -} - // HandleCreateSession creates a new user session and returns the auth token. func (s *Handlers) HandleCreateSession() http.HandlerFunc { type request struct { @@ -85,12 +64,12 @@ func (s *Handlers) HandleCreateSession() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req request if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) + s.respondJSON(w, r, map[string]string{"error": "invalid request body"}, http.StatusBadRequest) return } req.Nick = strings.TrimSpace(req.Nick) - if req.Nick == "" || len(req.Nick) > 32 { - s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) + if !validNickRe.MatchString(req.Nick) { + s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 chars, start with letter/underscore, contain only [a-zA-Z0-9_\\-[]\\^{}|`]"}, http.StatusBadRequest) return } id, token, err := s.params.Database.CreateUser(r.Context(), req.Nick) @@ -153,9 +132,7 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc { return } name := "#" + chi.URLParam(r, "channel") - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", name).Scan(&chID) + chID, err := s.params.Database.GetChannelByName(r.Context(), name) if err != nil { s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) return @@ -179,7 +156,7 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc { } afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64) timeout, _ := strconv.Atoi(r.URL.Query().Get("timeout")) - if timeout <= 0 { + if timeout < 0 { timeout = 0 } if timeout > 30 { @@ -245,12 +222,17 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { } var req request if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) + s.respondJSON(w, r, map[string]string{"error": "invalid request body"}, http.StatusBadRequest) return } req.Command = strings.ToUpper(strings.TrimSpace(req.Command)) req.To = strings.TrimSpace(req.To) + if req.Command == "" { + s.respondJSON(w, r, map[string]string{"error": "command required"}, http.StatusBadRequest) + return + } + bodyLines := func() []string { if req.Body == nil { return nil @@ -264,202 +246,236 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { switch req.Command { case "PRIVMSG", "NOTICE": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) - return - } - - if strings.HasPrefix(req.To, "#") { - // Channel message — fan out to all channel members. - _, memberIDs, err := s.getChannelMemberIDs(r, req.To) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - _, msgUUID, err := s.fanOutDirect(r, req.Command, nick, req.To, req.Body, memberIDs) - if err != nil { - s.log.Error("send message failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) - } else { - // DM — fan out to recipient + sender. - targetUID, err := s.params.Database.GetUserByNick(r.Context(), req.To) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) - return - } - recipients := []int64{targetUID} - if targetUID != uid { - recipients = append(recipients, uid) // echo to sender - } - _, msgUUID, err := s.fanOutDirect(r, req.Command, nick, req.To, req.Body, recipients) - if err != nil { - s.log.Error("send dm failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) - } - + s.handlePrivmsgOrNotice(w, r, uid, nick, req.Command, req.To, req.Body, bodyLines) case "JOIN": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel) - if err != nil { - s.log.Error("get/create channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { - s.log.Error("join channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - // Broadcast JOIN to all channel members (including the joiner). - memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), chID) - _ = s.fanOut(r, "JOIN", nick, channel, nil, memberIDs) - s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK) - + s.handleJoin(w, r, uid, nick, req.To) case "PART": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", channel).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - // Broadcast PART before removing the member. - memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), chID) - _ = s.fanOut(r, "PART", nick, channel, req.Body, memberIDs) - - if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { - s.log.Error("part channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - // Delete channel if empty (ephemeral). - _ = s.params.Database.DeleteChannelIfEmpty(r.Context(), chID) - s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK) - + s.handlePart(w, r, uid, nick, req.To, req.Body) case "NICK": - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest) - return - } - newNick := strings.TrimSpace(lines[0]) - if newNick == "" || len(newNick) > 32 { - s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) - return - } - if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict) - return - } - s.log.Error("change nick failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - // Broadcast NICK to all channels the user is in. - channels, _ := s.params.Database.GetAllChannelMembershipsForUser(r.Context(), uid) - notified := map[int64]bool{uid: true} - body, _ := json.Marshal([]string{newNick}) - // Notify self. - dbID, _, _ := s.params.Database.InsertMessage(r.Context(), "NICK", nick, "", json.RawMessage(body), nil) - _ = s.params.Database.EnqueueMessage(r.Context(), uid, dbID) - s.broker.Notify(uid) - - for _, ch := range channels { - memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), ch.ID) - for _, mid := range memberIDs { - if !notified[mid] { - notified[mid] = true - _ = s.params.Database.EnqueueMessage(r.Context(), mid, dbID) - s.broker.Notify(mid) - } - } - } - s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) - + s.handleNick(w, r, uid, nick, bodyLines) case "TOPIC": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest) - return - } - topic := strings.Join(lines, " ") - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - if err := s.params.Database.SetTopic(r.Context(), channel, topic); err != nil { - s.log.Error("set topic failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - // Broadcast TOPIC to channel members. - _, memberIDs, _ := s.getChannelMemberIDs(r, channel) - _ = s.fanOut(r, "TOPIC", nick, channel, req.Body, memberIDs) - s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK) - + s.handleTopic(w, r, nick, req.To, req.Body, bodyLines) case "QUIT": - // Broadcast QUIT to all channels, then remove user. - channels, _ := s.params.Database.GetAllChannelMembershipsForUser(r.Context(), uid) - notified := map[int64]bool{} - var dbID int64 - if len(channels) > 0 { - dbID, _, _ = s.params.Database.InsertMessage(r.Context(), "QUIT", nick, "", req.Body, nil) - } - for _, ch := range channels { - memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), ch.ID) - for _, mid := range memberIDs { - if mid != uid && !notified[mid] { - notified[mid] = true - _ = s.params.Database.EnqueueMessage(r.Context(), mid, dbID) - s.broker.Notify(mid) - } - } - _ = s.params.Database.PartChannel(r.Context(), ch.ID, uid) - _ = s.params.Database.DeleteChannelIfEmpty(r.Context(), ch.ID) - } - _ = s.params.Database.DeleteUser(r.Context(), uid) - s.respondJSON(w, r, map[string]string{"status": "quit"}, http.StatusOK) - + s.handleQuit(w, r, uid, nick, req.Body) case "PING": s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK) - default: s.respondJSON(w, r, map[string]string{"error": "unknown command: " + req.Command}, http.StatusBadRequest) } } } +func (s *Handlers) handlePrivmsgOrNotice(w http.ResponseWriter, r *http.Request, uid int64, nick, command, to string, body json.RawMessage, bodyLines func() []string) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) + return + } + + if strings.HasPrefix(to, "#") { + // Channel message. + chID, err := s.params.Database.GetChannelByName(r.Context(), to) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + memberIDs, err := s.params.Database.GetChannelMemberIDs(r.Context(), chID) + if err != nil { + s.log.Error("get channel members failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + _, msgUUID, err := s.fanOut(r, command, nick, to, body, memberIDs) + if err != nil { + s.log.Error("send message failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) + } else { + // DM. + targetUID, err := s.params.Database.GetUserByNick(r.Context(), to) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) + return + } + recipients := []int64{targetUID} + if targetUID != uid { + recipients = append(recipients, uid) // echo to sender + } + _, msgUUID, err := s.fanOut(r, command, nick, to, body, recipients) + if err != nil { + s.log.Error("send dm failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) + } +} + +func (s *Handlers) handleJoin(w http.ResponseWriter, r *http.Request, uid int64, nick, to string) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + channel := to + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + if !validChannelRe.MatchString(channel) { + s.respondJSON(w, r, map[string]string{"error": "invalid channel name"}, http.StatusBadRequest) + return + } + + chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel) + if err != nil { + s.log.Error("get/create channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { + s.log.Error("join channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + // Broadcast JOIN to all channel members (including the joiner). + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), chID) + _, _, _ = s.fanOut(r, "JOIN", nick, channel, nil, memberIDs) + s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK) +} + +func (s *Handlers) handlePart(w http.ResponseWriter, r *http.Request, uid int64, nick, to string, body json.RawMessage) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + channel := to + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + chID, err := s.params.Database.GetChannelByName(r.Context(), channel) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + // Broadcast PART before removing the member. + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), chID) + _, _, _ = s.fanOut(r, "PART", nick, channel, body, memberIDs) + + if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { + s.log.Error("part channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + // Delete channel if empty (ephemeral). + _ = s.params.Database.DeleteChannelIfEmpty(r.Context(), chID) + s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK) +} + +func (s *Handlers) handleNick(w http.ResponseWriter, r *http.Request, uid int64, nick string, bodyLines func() []string) { + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest) + return + } + newNick := strings.TrimSpace(lines[0]) + if !validNickRe.MatchString(newNick) { + s.respondJSON(w, r, map[string]string{"error": "invalid nick"}, http.StatusBadRequest) + return + } + if newNick == nick { + s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) + return + } + + if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict) + return + } + s.log.Error("change nick failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + // Broadcast NICK to all channels the user is in. + channels, _ := s.params.Database.GetAllChannelMembershipsForUser(r.Context(), uid) + notified := map[int64]bool{uid: true} + body, _ := json.Marshal([]string{newNick}) + dbID, _, _ := s.params.Database.InsertMessage(r.Context(), "NICK", nick, "", json.RawMessage(body), nil) + _ = s.params.Database.EnqueueMessage(r.Context(), uid, dbID) + s.broker.Notify(uid) + + for _, ch := range channels { + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), ch.ID) + for _, mid := range memberIDs { + if !notified[mid] { + notified[mid] = true + _ = s.params.Database.EnqueueMessage(r.Context(), mid, dbID) + s.broker.Notify(mid) + } + } + } + s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) +} + +func (s *Handlers) handleTopic(w http.ResponseWriter, r *http.Request, nick, to string, body json.RawMessage, bodyLines func() []string) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest) + return + } + topic := strings.Join(lines, " ") + channel := to + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + if err := s.params.Database.SetTopic(r.Context(), channel, topic); err != nil { + s.log.Error("set topic failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + chID, err := s.params.Database.GetChannelByName(r.Context(), channel) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), chID) + _, _, _ = s.fanOut(r, "TOPIC", nick, channel, body, memberIDs) + s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK) +} + +func (s *Handlers) handleQuit(w http.ResponseWriter, r *http.Request, uid int64, nick string, body json.RawMessage) { + channels, _ := s.params.Database.GetAllChannelMembershipsForUser(r.Context(), uid) + notified := map[int64]bool{} + var dbID int64 + if len(channels) > 0 { + dbID, _, _ = s.params.Database.InsertMessage(r.Context(), "QUIT", nick, "", body, nil) + } + for _, ch := range channels { + memberIDs, _ := s.params.Database.GetChannelMemberIDs(r.Context(), ch.ID) + for _, mid := range memberIDs { + if mid != uid && !notified[mid] { + notified[mid] = true + _ = s.params.Database.EnqueueMessage(r.Context(), mid, dbID) + s.broker.Notify(mid) + } + } + _ = s.params.Database.PartChannel(r.Context(), ch.ID, uid) + _ = s.params.Database.DeleteChannelIfEmpty(r.Context(), ch.ID) + } + _ = s.params.Database.DeleteUser(r.Context(), uid) + s.respondJSON(w, r, map[string]string{"status": "quit"}, http.StatusOK) +} + // HandleGetHistory returns message history for a specific target. func (s *Handlers) HandleGetHistory() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -474,7 +490,7 @@ func (s *Handlers) HandleGetHistory() http.HandlerFunc { } beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64) limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) - if limit <= 0 { + if limit <= 0 || limit > 500 { limit = 50 } msgs, err := s.params.Database.GetHistory(r.Context(), target, beforeID, limit) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 92e5234..9ac9162 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -4,6 +4,7 @@ package handlers import ( "context" "encoding/json" + "errors" "log/slog" "net/http" @@ -16,6 +17,8 @@ import ( "go.uber.org/fx" ) +var errUnauthorized = errors.New("unauthorized") + // Params defines the dependencies for creating Handlers. type Params struct { fx.In @@ -53,12 +56,11 @@ func New(lc fx.Lifecycle, params Params) (*Handlers, error) { } func (s *Handlers) respondJSON(w http.ResponseWriter, _ *http.Request, data any, status int) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) if data != nil { - err := json.NewEncoder(w).Encode(data) - if err != nil { + if err := json.NewEncoder(w).Encode(data); err != nil { s.log.Error("json encode error", "error", err) } }