package handlers import ( "encoding/json" "net/http" "regexp" "strconv" "strings" "time" "github.com/go-chi/chi" ) var validNickRe = regexp.MustCompile( `^[a-zA-Z_][a-zA-Z0-9_\-\[\]\\^{}|` + "`" + `]{0,31}$`, ) var validChannelRe = regexp.MustCompile( `^#[a-zA-Z0-9_\-]{1,63}$`, ) const ( maxLongPollTimeout = 30 pollMessageLimit = 100 ) // authUser extracts the user from the Authorization header. func (s *Handlers) authUser( r *http.Request, ) (int64, string, error) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { return 0, "", errUnauthorized } token := strings.TrimPrefix(auth, "Bearer ") if token == "" { return 0, "", errUnauthorized } return s.params.Database.GetUserByToken( r.Context(), token, ) } func (s *Handlers) requireAuth( w http.ResponseWriter, r *http.Request, ) (int64, string, bool) { uid, nick, err := s.authUser(r) if err != nil { s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized) return 0, "", false } return uid, nick, true } // fanOut stores a message and enqueues it to all specified // user IDs, then notifies them. func (s *Handlers) fanOut( r *http.Request, command, from, to string, body json.RawMessage, userIDs []int64, ) (string, error) { dbID, msgUUID, err := s.params.Database.InsertMessage( r.Context(), command, from, to, body, nil, ) if err != nil { return "", err } for _, uid := range userIDs { err = s.params.Database.EnqueueMessage( r.Context(), uid, dbID, ) if err != nil { s.log.Error("enqueue failed", "error", err, "user_id", uid) } s.broker.Notify(uid) } return msgUUID, nil } // fanOutSilent calls fanOut and discards the return values. func (s *Handlers) fanOutSilent( r *http.Request, command, from, to string, body json.RawMessage, userIDs []int64, ) error { _, err := s.fanOut( r, command, from, to, body, userIDs, ) return err } // HandleCreateSession creates a new user session. func (s *Handlers) HandleCreateSession() http.HandlerFunc { type request struct { Nick string `json:"nick"` } type response struct { ID int64 `json:"id"` Nick string `json:"nick"` Token string `json:"token"` } return func(w http.ResponseWriter, r *http.Request) { var req request err := json.NewDecoder(r.Body).Decode(&req) if err != nil { s.respondJSON(w, r, map[string]string{ "error": "invalid request body", }, http.StatusBadRequest) return } req.Nick = strings.TrimSpace(req.Nick) if !validNickRe.MatchString(req.Nick) { s.respondJSON(w, r, map[string]string{ "error": "invalid nick format", }, http.StatusBadRequest) return } id, token, err := s.params.Database.CreateUser( r.Context(), req.Nick, ) if err != nil { if strings.Contains(err.Error(), "UNIQUE") { s.respondJSON(w, r, map[string]string{ "error": "nick already taken", }, http.StatusConflict) return } s.log.Error("create user failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated) } } // HandleState returns the current user's info and channels. func (s *Handlers) HandleState() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } channels, err := s.params.Database.ListChannels( r.Context(), uid, ) if err != nil { s.log.Error("list channels failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, map[string]any{ "id": uid, "nick": nick, "channels": channels, }, http.StatusOK) } } // HandleListAllChannels returns all channels on the server. func (s *Handlers) HandleListAllChannels() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, _, ok := s.requireAuth(w, r) if !ok { return } channels, err := s.params.Database.ListAllChannels( r.Context(), ) if err != nil { s.log.Error( "list all channels failed", "error", err, ) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, channels, http.StatusOK) } } // HandleChannelMembers returns members of a channel. func (s *Handlers) HandleChannelMembers() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, _, ok := s.requireAuth(w, r) if !ok { return } name := "#" + chi.URLParam(r, "channel") chID, err := s.params.Database.GetChannelByName( r.Context(), name, ) if err != nil { s.respondJSON(w, r, map[string]string{ "error": "channel not found", }, http.StatusNotFound) return } members, err := s.params.Database.ChannelMembers( r.Context(), chID, ) if err != nil { s.log.Error( "channel members failed", "error", err, ) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, members, http.StatusOK) } } // HandleGetMessages returns messages via long-polling. func (s *Handlers) HandleGetMessages() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, _, ok := s.requireAuth(w, r) if !ok { return } afterID, _ := strconv.ParseInt( r.URL.Query().Get("after"), 10, 64, ) timeout, _ := strconv.Atoi( r.URL.Query().Get("timeout"), ) if timeout < 0 { timeout = 0 } if timeout > maxLongPollTimeout { timeout = maxLongPollTimeout } msgs, lastQID, err := s.params.Database.PollMessages( r.Context(), uid, afterID, pollMessageLimit, ) if err != nil { s.log.Error( "poll messages failed", "error", err, ) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } if len(msgs) > 0 || timeout == 0 { s.respondJSON(w, r, map[string]any{ "messages": msgs, "last_id": lastQID, }, http.StatusOK) return } s.longPoll(w, r, uid, afterID, timeout) } } func (s *Handlers) longPoll( w http.ResponseWriter, r *http.Request, uid, afterID int64, timeout int, ) { waitCh := s.broker.Wait(uid) timer := time.NewTimer( time.Duration(timeout) * time.Second, ) defer timer.Stop() select { case <-waitCh: case <-timer.C: case <-r.Context().Done(): s.broker.Remove(uid, waitCh) return } s.broker.Remove(uid, waitCh) msgs, lastQID, err := s.params.Database.PollMessages( r.Context(), uid, afterID, pollMessageLimit, ) if err != nil { s.log.Error("poll messages failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, map[string]any{ "messages": msgs, "last_id": lastQID, }, http.StatusOK) } // HandleSendCommand handles all C2S commands. func (s *Handlers) HandleSendCommand() http.HandlerFunc { type request struct { Command string `json:"command"` To string `json:"to"` Body json.RawMessage `json:"body,omitempty"` Meta json.RawMessage `json:"meta,omitempty"` } return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } var req request err := json.NewDecoder(r.Body).Decode(&req) if err != nil { s.respondJSON(w, r, map[string]string{ "error": "invalid request body", }, http.StatusBadRequest) return } req.Command = strings.ToUpper( strings.TrimSpace(req.Command), ) req.To = strings.TrimSpace(req.To) if req.Command == "" { s.respondJSON(w, r, map[string]string{ "error": "command required", }, http.StatusBadRequest) return } bodyLines := func() []string { if req.Body == nil { return nil } var lines []string err := json.Unmarshal(req.Body, &lines) if err != nil { return nil } return lines } s.dispatchCommand( w, r, uid, nick, req.Command, req.To, req.Body, bodyLines, ) } } func (s *Handlers) dispatchCommand( w http.ResponseWriter, r *http.Request, uid int64, nick, command, to string, body json.RawMessage, bodyLines func() []string, ) { switch command { case "PRIVMSG", "NOTICE": s.handlePrivmsg( w, r, uid, nick, command, to, body, bodyLines, ) case "JOIN": s.handleJoin(w, r, uid, nick, to) case "PART": s.handlePart(w, r, uid, nick, to, body) case "NICK": s.handleNick(w, r, uid, nick, bodyLines) case "TOPIC": s.handleTopic(w, r, nick, to, body, bodyLines) case "QUIT": s.handleQuit(w, r, uid, nick, body) case "PING": s.respondJSON(w, r, map[string]string{ "command": "PONG", "from": s.params.Config.ServerName, }, http.StatusOK) default: s.respondJSON(w, r, map[string]string{ "error": "unknown command: " + command, }, http.StatusBadRequest) } } func (s *Handlers) handlePrivmsg( w http.ResponseWriter, r *http.Request, uid int64, nick, command, to string, body json.RawMessage, bodyLines func() []string, ) { if to == "" { s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) return } lines := bodyLines() if len(lines) == 0 { s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) return } if strings.HasPrefix(to, "#") { s.handleChannelMsg( w, r, uid, nick, command, to, body, ) return } s.handleDirectMsg(w, r, uid, nick, command, to, body) } func (s *Handlers) handleChannelMsg( w http.ResponseWriter, r *http.Request, _ int64, nick, command, to string, body json.RawMessage, ) { chID, err := s.params.Database.GetChannelByName( r.Context(), to, ) if err != nil { s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) return } memberIDs, err := s.params.Database.GetChannelMemberIDs( r.Context(), chID, ) if err != nil { s.log.Error( "get channel members failed", "error", err, ) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } msgUUID, err := s.fanOut( r, command, nick, to, body, memberIDs, ) if err != nil { s.log.Error("send message failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) } func (s *Handlers) handleDirectMsg( w http.ResponseWriter, r *http.Request, uid int64, nick, command, to string, body json.RawMessage, ) { targetUID, err := s.params.Database.GetUserByNick( r.Context(), to, ) if err != nil { s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) return } recipients := []int64{targetUID} if targetUID != uid { recipients = append(recipients, uid) } msgUUID, err := s.fanOut( r, command, nick, to, body, recipients, ) if err != nil { s.log.Error("send dm failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, map[string]string{"id": msgUUID, "status": "sent"}, http.StatusCreated) } func (s *Handlers) handleJoin( w http.ResponseWriter, r *http.Request, uid int64, nick, to string, ) { if to == "" { s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) return } channel := to if !strings.HasPrefix(channel, "#") { channel = "#" + channel } if !validChannelRe.MatchString(channel) { s.respondJSON(w, r, map[string]string{ "error": "invalid channel name", }, http.StatusBadRequest) return } chID, err := s.params.Database.GetOrCreateChannel( r.Context(), channel, ) if err != nil { s.log.Error( "get/create channel failed", "error", err, ) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } err = s.params.Database.JoinChannel( r.Context(), chID, uid, ) if err != nil { s.log.Error("join channel failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } memberIDs, _ := s.params.Database.GetChannelMemberIDs( r.Context(), chID, ) _ = s.fanOutSilent( r, "JOIN", nick, channel, nil, memberIDs, ) s.respondJSON(w, r, map[string]string{ "status": "joined", "channel": channel, }, http.StatusOK) } func (s *Handlers) handlePart( w http.ResponseWriter, r *http.Request, uid int64, nick, to string, body json.RawMessage, ) { if to == "" { s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) return } channel := to if !strings.HasPrefix(channel, "#") { channel = "#" + channel } chID, err := s.params.Database.GetChannelByName( r.Context(), channel, ) if err != nil { s.respondJSON(w, r, map[string]string{ "error": "channel not found", }, http.StatusNotFound) return } memberIDs, _ := s.params.Database.GetChannelMemberIDs( r.Context(), chID, ) _ = s.fanOutSilent( r, "PART", nick, channel, body, memberIDs, ) err = s.params.Database.PartChannel( r.Context(), chID, uid, ) if err != nil { s.log.Error("part channel failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } _ = s.params.Database.DeleteChannelIfEmpty( r.Context(), chID, ) s.respondJSON(w, r, map[string]string{ "status": "parted", "channel": channel, }, http.StatusOK) } func (s *Handlers) handleNick( w http.ResponseWriter, r *http.Request, uid int64, nick string, bodyLines func() []string, ) { lines := bodyLines() if len(lines) == 0 { s.respondJSON(w, r, map[string]string{ "error": "body required (new nick)", }, http.StatusBadRequest) return } newNick := strings.TrimSpace(lines[0]) if !validNickRe.MatchString(newNick) { s.respondJSON(w, r, map[string]string{"error": "invalid nick"}, http.StatusBadRequest) return } if newNick == nick { s.respondJSON(w, r, map[string]string{ "status": "ok", "nick": newNick, }, http.StatusOK) return } err := s.params.Database.ChangeNick( r.Context(), uid, newNick, ) if err != nil { if strings.Contains(err.Error(), "UNIQUE") { s.respondJSON(w, r, map[string]string{ "error": "nick already in use", }, http.StatusConflict) return } s.log.Error("change nick failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.broadcastNick(r, uid, nick, newNick) s.respondJSON(w, r, map[string]string{ "status": "ok", "nick": newNick, }, http.StatusOK) } func (s *Handlers) broadcastNick( r *http.Request, uid int64, oldNick, newNick string, ) { channels, _ := s.params.Database. GetAllChannelMembershipsForUser(r.Context(), uid) notified := map[int64]bool{uid: true} nickBody, err := json.Marshal([]string{newNick}) if err != nil { s.log.Error("marshal nick body", "error", err) return } dbID, _, _ := s.params.Database.InsertMessage( r.Context(), "NICK", oldNick, "", json.RawMessage(nickBody), nil, ) _ = s.params.Database.EnqueueMessage( r.Context(), uid, dbID, ) s.broker.Notify(uid) for _, ch := range channels { memberIDs, _ := s.params.Database. GetChannelMemberIDs(r.Context(), ch.ID) for _, mid := range memberIDs { if !notified[mid] { notified[mid] = true _ = s.params.Database.EnqueueMessage( r.Context(), mid, dbID, ) s.broker.Notify(mid) } } } } func (s *Handlers) handleTopic( w http.ResponseWriter, r *http.Request, nick, to string, body json.RawMessage, bodyLines func() []string, ) { if to == "" { s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) return } lines := bodyLines() if len(lines) == 0 { s.respondJSON(w, r, map[string]string{ "error": "body required (topic text)", }, http.StatusBadRequest) return } topic := strings.Join(lines, " ") channel := to if !strings.HasPrefix(channel, "#") { channel = "#" + channel } err := s.params.Database.SetTopic( r.Context(), channel, topic, ) if err != nil { s.log.Error("set topic failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } chID, err := s.params.Database.GetChannelByName( r.Context(), channel, ) if err != nil { s.respondJSON(w, r, map[string]string{ "error": "channel not found", }, http.StatusNotFound) return } memberIDs, _ := s.params.Database.GetChannelMemberIDs( r.Context(), chID, ) _ = s.fanOutSilent( r, "TOPIC", nick, channel, body, memberIDs, ) s.respondJSON(w, r, map[string]string{ "status": "ok", "topic": topic, }, http.StatusOK) } func (s *Handlers) handleQuit( w http.ResponseWriter, r *http.Request, uid int64, nick string, body json.RawMessage, ) { channels, _ := s.params.Database. GetAllChannelMembershipsForUser(r.Context(), uid) notified := map[int64]bool{} var dbID int64 if len(channels) > 0 { dbID, _, _ = s.params.Database.InsertMessage( r.Context(), "QUIT", nick, "", body, nil, ) } for _, ch := range channels { memberIDs, _ := s.params.Database. GetChannelMemberIDs(r.Context(), ch.ID) for _, mid := range memberIDs { if mid != uid && !notified[mid] { notified[mid] = true _ = s.params.Database.EnqueueMessage( r.Context(), mid, dbID, ) s.broker.Notify(mid) } } _ = s.params.Database.PartChannel( r.Context(), ch.ID, uid, ) _ = s.params.Database.DeleteChannelIfEmpty( r.Context(), ch.ID, ) } _ = s.params.Database.DeleteUser(r.Context(), uid) s.respondJSON(w, r, map[string]string{"status": "quit"}, http.StatusOK) } const ( defaultHistLimit = 50 maxHistLimit = 500 ) // HandleGetHistory returns message history for a target. func (s *Handlers) HandleGetHistory() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, _, ok := s.requireAuth(w, r) if !ok { return } target := r.URL.Query().Get("target") if target == "" { s.respondJSON(w, r, map[string]string{ "error": "target required", }, http.StatusBadRequest) return } beforeID, _ := strconv.ParseInt( r.URL.Query().Get("before"), 10, 64, ) limit, _ := strconv.Atoi( r.URL.Query().Get("limit"), ) if limit <= 0 || limit > maxHistLimit { limit = defaultHistLimit } msgs, err := s.params.Database.GetHistory( r.Context(), target, beforeID, limit, ) if err != nil { s.log.Error( "get history failed", "error", err, ) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) return } s.respondJSON(w, r, msgs, http.StatusOK) } } // HandleServerInfo returns server metadata. func (s *Handlers) HandleServerInfo() http.HandlerFunc { type response struct { Name string `json:"name"` MOTD string `json:"motd"` } return func(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, r, &response{ Name: s.params.Config.ServerName, MOTD: s.params.Config.MOTD, }, http.StatusOK) } }