From 0ee3fd78d2cebe940968157cbee04f2ab0ee20a4 Mon Sep 17 00:00:00 2001 From: clawbot Date: Tue, 10 Feb 2026 17:53:08 -0800 Subject: [PATCH] refactor: unify all C2S commands through POST /messages All client-to-server commands now go through POST /api/v1/messages with a 'command' field. The server dispatches by command type: - PRIVMSG/NOTICE: send message to channel or user - JOIN: join channel (creates if needed) - PART: leave channel - NICK: change nickname - TOPIC: set channel topic - PING: keepalive (returns PONG) Removed separate routes: - POST /channels/join - DELETE /channels/{channel} - POST /register (renamed to POST /session) - GET /channels/all (moved to GET /channels) Added DB methods: ChangeNick, SetTopic --- internal/db/queries.go | 16 ++- internal/handlers/api.go | 247 ++++++++++++++++++++++++-------------- internal/server/routes.go | 8 +- 3 files changed, 172 insertions(+), 99 deletions(-) diff --git a/internal/db/queries.go b/internal/db/queries.go index b051617..974f4db 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -372,7 +372,21 @@ func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeI return msgs, nil } -// GetMOTD returns the server MOTD from config. +// ChangeNick updates a user's nickname. +func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error { + _, err := s.db.ExecContext(ctx, + "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) + return err +} + +// SetTopic sets the topic for a channel. +func (s *Database) SetTopic(ctx context.Context, channelName string, _ int64, topic string) error { + _, err := s.db.ExecContext(ctx, + "UPDATE channels SET topic = ? WHERE name = ?", topic, channelName) + return err +} + +// GetServerName returns the server name (unused, config provides this). func (s *Database) GetServerName() string { return "" } diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 5624f40..975b7a1 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -30,8 +30,8 @@ func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, s return uid, nick, true } -// HandleRegister creates a new user and returns the auth token. -func (s *Handlers) HandleRegister() http.HandlerFunc { +// HandleCreateSession creates a new user session and returns the auth token. +func (s *Handlers) HandleCreateSession() http.HandlerFunc { type request struct { Nick string `json:"nick"` } @@ -104,68 +104,6 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc { } } -// HandleJoinChannel joins a channel (creates it if needed). -func (s *Handlers) HandleJoinChannel() http.HandlerFunc { - type request struct { - Channel string `json:"channel"` - } - return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) - if !ok { - return - } - var req request - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) - return - } - req.Channel = strings.TrimSpace(req.Channel) - if req.Channel == "" { - s.respondJSON(w, r, map[string]string{"error": "channel name required"}, http.StatusBadRequest) - return - } - if !strings.HasPrefix(req.Channel, "#") { - req.Channel = "#" + req.Channel - } - chID, err := s.params.Database.GetOrCreateChannel(r.Context(), req.Channel) - if err != nil { - s.log.Error("get/create channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { - s.log.Error("join channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "joined", "channel": req.Channel}, http.StatusOK) - } -} - -// HandlePartChannel leaves a channel. -func (s *Handlers) HandlePartChannel() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) - if !ok { - return - } - name := "#" + chi.URLParam(r, "channel") - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", name).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { - s.log.Error("part channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "parted", "channel": name}, http.StatusOK) - } -} - // HandleChannelMembers returns members of a channel. func (s *Handlers) HandleChannelMembers() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -211,15 +149,17 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc { } } -// HandleSendMessage sends a message to a channel or user. -// The "to" field determines the target: "#channel" for channels, "nick" for DMs. -func (s *Handlers) HandleSendMessage() http.HandlerFunc { +// HandleSendCommand handles all C2S commands via POST /messages. +// The "command" field dispatches to the appropriate logic. +func (s *Handlers) HandleSendCommand() http.HandlerFunc { type request struct { - To string `json:"to"` - Content string `json:"content"` + Command string `json:"command"` + To string `json:"to"` + Params []string `json:"params,omitempty"` + Body interface{} `json:"body,omitempty"` } return func(w http.ResponseWriter, r *http.Request) { - uid, _, ok := s.requireAuth(w, r) + uid, nick, ok := s.requireAuth(w, r) if !ok { return } @@ -228,46 +168,167 @@ func (s *Handlers) HandleSendMessage() http.HandlerFunc { s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) return } - if strings.TrimSpace(req.Content) == "" { - s.respondJSON(w, r, map[string]string{"error": "content required"}, http.StatusBadRequest) - return - } + req.Command = strings.ToUpper(strings.TrimSpace(req.Command)) req.To = strings.TrimSpace(req.To) - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return + + // Helper to extract body as string lines. + bodyLines := func() []string { + switch v := req.Body.(type) { + case []interface{}: + lines := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + lines = append(lines, s) + } + } + return lines + case []string: + return v + default: + return nil + } } - if strings.HasPrefix(req.To, "#") { - // Channel message + switch req.Command { + case "PRIVMSG", "NOTICE": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) + return + } + content := strings.Join(lines, "\n") + + if strings.HasPrefix(req.To, "#") { + // Channel message + var chID int64 + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, content) + if err != nil { + s.log.Error("send message failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) + } else { + // DM + targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) + return + } + msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, content) + if err != nil { + s.log.Error("send dm failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) + } + + case "JOIN": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel) + if err != nil { + s.log.Error("get/create channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { + s.log.Error("join channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK) + + case "PART": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } var chID int64 err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID) + "SELECT id FROM channels WHERE name = ?", channel).Scan(&chID) if err != nil { s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) return } - msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, req.Content) - if err != nil { - s.log.Error("send message failed", "error", err) + if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { + s.log.Error("part channel failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) - } else { - // DM - targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) + s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK) + + case "NICK": + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest) return } - msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, req.Content) - if err != nil { - s.log.Error("send dm failed", "error", err) + newNick := strings.TrimSpace(lines[0]) + if newNick == "" || len(newNick) > 32 { + s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) + return + } + if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict) + return + } + s.log.Error("change nick failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) + s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) + + case "TOPIC": + if req.To == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + lines := bodyLines() + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest) + return + } + topic := strings.Join(lines, " ") + channel := req.To + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + if err := s.params.Database.SetTopic(r.Context(), channel, uid, topic); err != nil { + s.log.Error("set topic failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return + } + s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK) + + case "PING": + s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK) + + default: + _ = nick // suppress unused warning + s.respondJSON(w, r, map[string]string{"error": "unknown command: " + req.Command}, http.StatusBadRequest) } } } diff --git a/internal/server/routes.go b/internal/server/routes.go index e5ce9e9..e211492 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -52,18 +52,16 @@ func (s *Server) SetupRoutes() { // API v1 s.router.Route("/api/v1", func(r chi.Router) { r.Get("/server", s.h.HandleServerInfo()) - r.Post("/register", s.h.HandleRegister()) + r.Post("/session", s.h.HandleCreateSession()) // Unified state and message endpoints r.Get("/state", s.h.HandleState()) r.Get("/messages", s.h.HandleGetMessages()) - r.Post("/messages", s.h.HandleSendMessage()) + r.Post("/messages", s.h.HandleSendCommand()) r.Get("/history", s.h.HandleGetHistory()) // Channels - r.Get("/channels/all", s.h.HandleListAllChannels()) - r.Post("/channels/join", s.h.HandleJoinChannel()) - r.Delete("/channels/{channel}", s.h.HandlePartChannel()) + r.Get("/channels", s.h.HandleListAllChannels()) r.Get("/channels/{channel}/members", s.h.HandleChannelMembers()) })