diff --git a/internal/db/queries.go b/internal/db/queries.go index af7b83b..b051617 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -275,6 +275,103 @@ func (s *Database) PollMessages(ctx context.Context, userID int64, afterID int64 return msgs, nil } +// GetMessagesBefore returns channel messages before a given ID (for history scrollback). +func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, beforeID int64, limit int) ([]MessageInfo, error) { + if limit <= 0 { + limit = 50 + } + var query string + var args []any + if beforeID > 0 { + query = `SELECT m.id, c.name, u.nick, m.content, m.created_at + FROM messages m + INNER JOIN users u ON u.id = m.user_id + INNER JOIN channels c ON c.id = m.channel_id + WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ? + ORDER BY m.id DESC LIMIT ?` + args = []any{channelID, beforeID, limit} + } else { + query = `SELECT m.id, c.name, u.nick, m.content, m.created_at + FROM messages m + INNER JOIN users u ON u.id = m.user_id + INNER JOIN channels c ON c.id = m.channel_id + WHERE m.channel_id = ? AND m.is_dm = 0 + ORDER BY m.id DESC LIMIT ?` + args = []any{channelID, limit} + } + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + var msgs []MessageInfo + for rows.Next() { + var m MessageInfo + if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { + return nil, err + } + msgs = append(msgs, m) + } + if msgs == nil { + msgs = []MessageInfo{} + } + // Reverse to ascending order + for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { + msgs[i], msgs[j] = msgs[j], msgs[i] + } + return msgs, nil +} + +// GetDMsBefore returns DMs between two users before a given ID (for history scrollback). +func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeID int64, limit int) ([]MessageInfo, error) { + if limit <= 0 { + limit = 50 + } + var query string + var args []any + if beforeID > 0 { + query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at + FROM messages m + INNER JOIN users u ON u.id = m.user_id + INNER JOIN users t ON t.id = m.dm_target_id + WHERE m.is_dm = 1 AND m.id < ? + AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) + ORDER BY m.id DESC LIMIT ?` + args = []any{beforeID, userA, userB, userB, userA, limit} + } else { + query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at + FROM messages m + INNER JOIN users u ON u.id = m.user_id + INNER JOIN users t ON t.id = m.dm_target_id + WHERE m.is_dm = 1 + AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) + ORDER BY m.id DESC LIMIT ?` + args = []any{userA, userB, userB, userA, limit} + } + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + var msgs []MessageInfo + for rows.Next() { + var m MessageInfo + if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil { + return nil, err + } + m.IsDM = true + msgs = append(msgs, m) + } + if msgs == nil { + msgs = []MessageInfo{} + } + // Reverse to ascending order + for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { + msgs[i], msgs[j] = msgs[j], msgs[i] + } + return msgs, nil +} + // GetMOTD returns the server MOTD from config. func (s *Database) GetServerName() string { return "" diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 22487d4..5624f40 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" + "git.eeqj.de/sneak/chat/internal/db" "github.com/go-chi/chi" ) @@ -64,35 +65,25 @@ func (s *Handlers) HandleRegister() http.HandlerFunc { } } -// HandleMe returns the current user's info. -func (s *Handlers) HandleMe() http.HandlerFunc { +// HandleState returns the current user's info and joined channels. +func (s *Handlers) HandleState() http.HandlerFunc { type response struct { - ID int64 `json:"id"` - Nick string `json:"nick"` + ID int64 `json:"id"` + Nick string `json:"nick"` + Channels []db.ChannelInfo `json:"channels"` } return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } - s.respondJSON(w, r, &response{ID: uid, Nick: nick}, http.StatusOK) - } -} - -// HandleListChannels returns channels the user has joined. -func (s *Handlers) HandleListChannels() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) - if !ok { - return - } channels, err := s.params.Database.ListChannels(r.Context(), uid) if err != nil { s.log.Error("list channels failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - s.respondJSON(w, r, channels, http.StatusOK) + s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK) } } @@ -200,132 +191,9 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc { } } -// HandleGetMessages returns messages for a channel. +// HandleGetMessages returns all new messages (channel + DM) for the user via long-polling. +// This is the single unified message stream — replaces separate channel/DM/poll endpoints. func (s *Handlers) HandleGetMessages() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - _, _, ok := s.requireAuth(w, r) - if !ok { - return - } - name := "#" + chi.URLParam(r, "channel") - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", name).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64) - limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) - msgs, err := s.params.Database.GetMessages(r.Context(), chID, afterID, limit) - if err != nil { - s.log.Error("get messages failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, msgs, http.StatusOK) - } -} - -// HandleSendMessage sends a message to a channel. -func (s *Handlers) HandleSendMessage() http.HandlerFunc { - type request struct { - Content string `json:"content"` - } - return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) - if !ok { - return - } - name := "#" + chi.URLParam(r, "channel") - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", name).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) - return - } - if strings.TrimSpace(req.Content) == "" { - s.respondJSON(w, r, map[string]string{"error": "content required"}, http.StatusBadRequest) - return - } - msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, req.Content) - if err != nil { - s.log.Error("send message failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) - } -} - -// HandleSendDM sends a direct message to a user. -func (s *Handlers) HandleSendDM() http.HandlerFunc { - type request struct { - Content string `json:"content"` - } - return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) - if !ok { - return - } - targetNick := chi.URLParam(r, "nick") - targetID, err := s.params.Database.GetUserByNick(r.Context(), targetNick) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) - return - } - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) - return - } - if strings.TrimSpace(req.Content) == "" { - s.respondJSON(w, r, map[string]string{"error": "content required"}, http.StatusBadRequest) - return - } - msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, req.Content) - if err != nil { - s.log.Error("send dm failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) - } -} - -// HandleGetDMs returns direct messages with a user. -func (s *Handlers) HandleGetDMs() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) - if !ok { - return - } - targetNick := chi.URLParam(r, "nick") - targetID, err := s.params.Database.GetUserByNick(r.Context(), targetNick) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) - return - } - afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64) - limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) - msgs, err := s.params.Database.GetDMs(r.Context(), uid, targetID, afterID, limit) - if err != nil { - s.log.Error("get dms failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, msgs, http.StatusOK) - } -} - -// HandlePoll returns all new messages (channels + DMs) for the user. -func (s *Handlers) HandlePoll() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, _, ok := s.requireAuth(w, r) if !ok { @@ -335,7 +203,7 @@ func (s *Handlers) HandlePoll() http.HandlerFunc { limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterID, limit) if err != nil { - s.log.Error("poll messages failed", "error", err) + s.log.Error("get messages failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } @@ -343,6 +211,119 @@ func (s *Handlers) HandlePoll() http.HandlerFunc { } } +// HandleSendMessage sends a message to a channel or user. +// The "to" field determines the target: "#channel" for channels, "nick" for DMs. +func (s *Handlers) HandleSendMessage() http.HandlerFunc { + type request struct { + To string `json:"to"` + Content string `json:"content"` + } + return func(w http.ResponseWriter, r *http.Request) { + uid, _, ok := s.requireAuth(w, r) + if !ok { + return + } + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Content) == "" { + s.respondJSON(w, r, map[string]string{"error": "content required"}, http.StatusBadRequest) + return + } + req.To = strings.TrimSpace(req.To) + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + + if strings.HasPrefix(req.To, "#") { + // Channel message + var chID int64 + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, req.Content) + if err != nil { + s.log.Error("send message failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) + } else { + // DM + targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) + return + } + msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, req.Content) + if err != nil { + s.log.Error("send dm failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) + } + } +} + +// HandleGetHistory returns message history for a specific target (channel or DM). +func (s *Handlers) HandleGetHistory() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + uid, _, ok := s.requireAuth(w, r) + if !ok { + return + } + target := r.URL.Query().Get("target") + if target == "" { + s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest) + return + } + beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64) + limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + if limit <= 0 { + limit = 50 + } + + if strings.HasPrefix(target, "#") { + // Channel history + var chID int64 + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", target).Scan(&chID) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeID, limit) + if err != nil { + s.log.Error("get history failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, msgs, http.StatusOK) + } else { + // DM history + targetID, err := s.params.Database.GetUserByNick(r.Context(), target) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) + return + } + msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetID, beforeID, limit) + if err != nil { + s.log.Error("get dm history failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, msgs, http.StatusOK) + } + } +} + // HandleServerInfo returns server metadata (MOTD, name). func (s *Handlers) HandleServerInfo() http.HandlerFunc { type response struct {