refactor: merge /me + /channels into /state, unify message endpoints

- HandleState returns user info (id, nick) + joined channels in one response
- HandleGetMessages now serves unified message stream (was HandlePoll)
- HandleSendMessage accepts 'to' field for routing to #channel or nick
- HandleGetHistory supports scrollback for channels and DMs
- Remove separate HandleMe, HandleListChannels, HandleSendDM, HandleGetDMs, HandlePoll
This commit is contained in:
user
2026-02-10 10:20:00 -08:00
parent ac933d07d2
commit 7361e8bd9b
2 changed files with 220 additions and 142 deletions

View File

@@ -275,6 +275,103 @@ func (s *Database) PollMessages(ctx context.Context, userID int64, afterID int64
return msgs, nil
}
// GetMessagesBefore returns channel messages before a given ID (for history scrollback).
func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, beforeID int64, limit int) ([]MessageInfo, error) {
if limit <= 0 {
limit = 50
}
var query string
var args []any
if beforeID > 0 {
query = `SELECT m.id, c.name, u.nick, m.content, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ?
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, beforeID, limit}
} else {
query = `SELECT m.id, c.name, u.nick, m.content, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN channels c ON c.id = m.channel_id
WHERE m.channel_id = ? AND m.is_dm = 0
ORDER BY m.id DESC LIMIT ?`
args = []any{channelID, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil {
return nil, err
}
msgs = append(msgs, m)
}
if msgs == nil {
msgs = []MessageInfo{}
}
// Reverse to ascending order
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
}
return msgs, nil
}
// GetDMsBefore returns DMs between two users before a given ID (for history scrollback).
func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeID int64, limit int) ([]MessageInfo, error) {
if limit <= 0 {
limit = 50
}
var query string
var args []any
if beforeID > 0 {
query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1 AND m.id < ?
AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{beforeID, userA, userB, userB, userA, limit}
} else {
query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at
FROM messages m
INNER JOIN users u ON u.id = m.user_id
INNER JOIN users t ON t.id = m.dm_target_id
WHERE m.is_dm = 1
AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
ORDER BY m.id DESC LIMIT ?`
args = []any{userA, userB, userB, userA, limit}
}
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var msgs []MessageInfo
for rows.Next() {
var m MessageInfo
if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil {
return nil, err
}
m.IsDM = true
msgs = append(msgs, m)
}
if msgs == nil {
msgs = []MessageInfo{}
}
// Reverse to ascending order
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
msgs[i], msgs[j] = msgs[j], msgs[i]
}
return msgs, nil
}
// GetMOTD returns the server MOTD from config.
func (s *Database) GetServerName() string {
return ""

View File

@@ -7,6 +7,7 @@ import (
"strconv"
"strings"
"git.eeqj.de/sneak/chat/internal/db"
"github.com/go-chi/chi"
)
@@ -64,35 +65,25 @@ func (s *Handlers) HandleRegister() http.HandlerFunc {
}
}
// HandleMe returns the current user's info.
func (s *Handlers) HandleMe() http.HandlerFunc {
// HandleState returns the current user's info and joined channels.
func (s *Handlers) HandleState() http.HandlerFunc {
type response struct {
ID int64 `json:"id"`
Nick string `json:"nick"`
ID int64 `json:"id"`
Nick string `json:"nick"`
Channels []db.ChannelInfo `json:"channels"`
}
return func(w http.ResponseWriter, r *http.Request) {
uid, nick, ok := s.requireAuth(w, r)
if !ok {
return
}
s.respondJSON(w, r, &response{ID: uid, Nick: nick}, http.StatusOK)
}
}
// HandleListChannels returns channels the user has joined.
func (s *Handlers) HandleListChannels() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
if !ok {
return
}
channels, err := s.params.Database.ListChannels(r.Context(), uid)
if err != nil {
s.log.Error("list channels failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, channels, http.StatusOK)
s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK)
}
}
@@ -200,132 +191,9 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
}
}
// HandleGetMessages returns messages for a channel.
// HandleGetMessages returns all new messages (channel + DM) for the user via long-polling.
// This is the single unified message stream — replaces separate channel/DM/poll endpoints.
func (s *Handlers) HandleGetMessages() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
_, _, ok := s.requireAuth(w, r)
if !ok {
return
}
name := "#" + chi.URLParam(r, "channel")
var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", name).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
return
}
afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64)
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
msgs, err := s.params.Database.GetMessages(r.Context(), chID, afterID, limit)
if err != nil {
s.log.Error("get messages failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
}
}
// HandleSendMessage sends a message to a channel.
func (s *Handlers) HandleSendMessage() http.HandlerFunc {
type request struct {
Content string `json:"content"`
}
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
if !ok {
return
}
name := "#" + chi.URLParam(r, "channel")
var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", name).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
return
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Content) == "" {
s.respondJSON(w, r, map[string]string{"error": "content required"}, http.StatusBadRequest)
return
}
msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, req.Content)
if err != nil {
s.log.Error("send message failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
}
}
// HandleSendDM sends a direct message to a user.
func (s *Handlers) HandleSendDM() http.HandlerFunc {
type request struct {
Content string `json:"content"`
}
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
if !ok {
return
}
targetNick := chi.URLParam(r, "nick")
targetID, err := s.params.Database.GetUserByNick(r.Context(), targetNick)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
return
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Content) == "" {
s.respondJSON(w, r, map[string]string{"error": "content required"}, http.StatusBadRequest)
return
}
msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, req.Content)
if err != nil {
s.log.Error("send dm failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
}
}
// HandleGetDMs returns direct messages with a user.
func (s *Handlers) HandleGetDMs() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
if !ok {
return
}
targetNick := chi.URLParam(r, "nick")
targetID, err := s.params.Database.GetUserByNick(r.Context(), targetNick)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
return
}
afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64)
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
msgs, err := s.params.Database.GetDMs(r.Context(), uid, targetID, afterID, limit)
if err != nil {
s.log.Error("get dms failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
}
}
// HandlePoll returns all new messages (channels + DMs) for the user.
func (s *Handlers) HandlePoll() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
if !ok {
@@ -335,7 +203,7 @@ func (s *Handlers) HandlePoll() http.HandlerFunc {
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterID, limit)
if err != nil {
s.log.Error("poll messages failed", "error", err)
s.log.Error("get messages failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
@@ -343,6 +211,119 @@ func (s *Handlers) HandlePoll() http.HandlerFunc {
}
}
// HandleSendMessage sends a message to a channel or user.
// The "to" field determines the target: "#channel" for channels, "nick" for DMs.
func (s *Handlers) HandleSendMessage() http.HandlerFunc {
type request struct {
To string `json:"to"`
Content string `json:"content"`
}
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
if !ok {
return
}
var req request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Content) == "" {
s.respondJSON(w, r, map[string]string{"error": "content required"}, http.StatusBadRequest)
return
}
req.To = strings.TrimSpace(req.To)
if req.To == "" {
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
return
}
if strings.HasPrefix(req.To, "#") {
// Channel message
var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
return
}
msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, req.Content)
if err != nil {
s.log.Error("send message failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
} else {
// DM
targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
return
}
msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, req.Content)
if err != nil {
s.log.Error("send dm failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
}
}
}
// HandleGetHistory returns message history for a specific target (channel or DM).
func (s *Handlers) HandleGetHistory() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
uid, _, ok := s.requireAuth(w, r)
if !ok {
return
}
target := r.URL.Query().Get("target")
if target == "" {
s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest)
return
}
beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64)
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
if limit <= 0 {
limit = 50
}
if strings.HasPrefix(target, "#") {
// Channel history
var chID int64
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
"SELECT id FROM channels WHERE name = ?", target).Scan(&chID)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
return
}
msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeID, limit)
if err != nil {
s.log.Error("get history failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
} else {
// DM history
targetID, err := s.params.Database.GetUserByNick(r.Context(), target)
if err != nil {
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
return
}
msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetID, beforeID, limit)
if err != nil {
s.log.Error("get dm history failed", "error", err)
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
return
}
s.respondJSON(w, r, msgs, http.StatusOK)
}
}
}
// HandleServerInfo returns server metadata (MOTD, name).
func (s *Handlers) HandleServerInfo() http.HandlerFunc {
type response struct {