package handlers

import (
	"database/sql"
	"encoding/json"
	"net/http"
	"strconv"
	"strings"

	"git.eeqj.de/sneak/chat/internal/db"
	"github.com/go-chi/chi"
)

const (
	maxNickLen     = 32
	defaultHistory = 50
)

// authUser resolves the caller from an "Authorization: Bearer <token>"
// header and returns the user ID and nick reported by GetUserByToken.
//
// The scheme name is matched case-insensitively (auth-scheme names are
// case-insensitive in HTTP), and whitespace around the token is
// ignored. A missing or malformed header, or an empty token, yields
// sql.ErrNoRows without touching the database, so callers can treat it
// exactly like an unknown token.
func (s *Handlers) authUser(
	r *http.Request,
) (int64, string, error) {
	const scheme = "Bearer "

	auth := r.Header.Get("Authorization")
	if len(auth) < len(scheme) ||
		!strings.EqualFold(auth[:len(scheme)], scheme) {
		return 0, "", sql.ErrNoRows
	}

	token := strings.TrimSpace(auth[len(scheme):])
	if token == "" {
		return 0, "", sql.ErrNoRows
	}

	return s.params.Database.GetUserByToken(r.Context(), token)
}

// requireAuth authenticates the request via authUser. On any failure
// it writes a 401 {"error":"unauthorized"} response and returns
// ok=false; the calling handler must then return without writing.
func (s *Handlers) requireAuth(
	w http.ResponseWriter,
	r *http.Request,
) (int64, string, bool) {
	uid, nick, err := s.authUser(r)
	if err != nil {
		s.respondError(w, r, "unauthorized", http.StatusUnauthorized)

		return 0, "", false
	}

	return uid, nick, true
}

// respondError writes {"error": msg} with the given HTTP status code.
func (s *Handlers) respondError(
	w http.ResponseWriter,
	r *http.Request,
	msg string,
	code int,
) {
	s.respondJSON(w, r, map[string]string{"error": msg}, code)
}

// internalError logs err under msg and answers with a generic 500 so
// internal details are not exposed to the client.
func (s *Handlers) internalError(
	w http.ResponseWriter,
	r *http.Request,
	msg string,
	err error,
) {
	s.log.Error(msg, "error", err)
	s.respondError(w, r, "internal error", http.StatusInternalServerError)
}

// bodyLines normalizes the decoded JSON "body" field of a command into
// text lines.
//
// Accepted shapes:
//   - a JSON array (decoded as []any): string elements are kept in
//     order and non-string elements are skipped;
//   - a []string, when a sendRequest is built directly in Go;
//   - a single non-empty JSON string, treated as one line.
//
// Anything else (null, numbers, objects, "") yields nil.
func bodyLines(body any) []string {
	switch v := body.(type) {
	case []any:
		lines := make([]string, 0, len(v))

		for _, item := range v {
			if line, ok := item.(string); ok {
				lines = append(lines, line)
			}
		}

		return lines
	case []string:
		return v
	case string:
		if v == "" {
			return nil
		}

		return []string{v}
	default:
		return nil
	}
}

// HandleCreateSession registers a new user under the requested nick
// and responds 201 Created with the user's ID, nick, and auth token.
func (s *Handlers) HandleCreateSession() http.HandlerFunc { type request struct { Nick string `json:"nick"` } type response struct { ID int64 `json:"id"` Nick string `json:"nick"` Token string `json:"token"` } return func(w http.ResponseWriter, r *http.Request) { var req request err := json.NewDecoder(r.Body).Decode(&req) if err != nil { s.respondError( w, r, "invalid request", http.StatusBadRequest, ) return } req.Nick = strings.TrimSpace(req.Nick) if req.Nick == "" || len(req.Nick) > maxNickLen { s.respondError( w, r, "nick must be 1-32 characters", http.StatusBadRequest, ) return } id, token, err := s.params.Database.CreateUser( r.Context(), req.Nick, ) if err != nil { if strings.Contains(err.Error(), "UNIQUE") { s.respondError( w, r, "nick already taken", http.StatusConflict, ) return } s.internalError(w, r, "create user failed", err) return } s.respondJSON( w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated, ) } } // HandleState returns the current user's info and joined // channels. func (s *Handlers) HandleState() http.HandlerFunc { type response struct { ID int64 `json:"id"` Nick string `json:"nick"` Channels []db.ChannelInfo `json:"channels"` } return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } channels, err := s.params.Database.ListChannels( r.Context(), uid, ) if err != nil { s.internalError( w, r, "list channels failed", err, ) return } s.respondJSON( w, r, &response{ ID: uid, Nick: nick, Channels: channels, }, http.StatusOK, ) } } // HandleListAllChannels returns all channels on the server. 
func (s *Handlers) HandleListAllChannels() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, _, ok := s.requireAuth(w, r) if !ok { return } channels, err := s.params.Database.ListAllChannels( r.Context(), ) if err != nil { s.internalError( w, r, "list all channels failed", err, ) return } s.respondJSON(w, r, channels, http.StatusOK) } } // HandleChannelMembers returns members of a channel. func (s *Handlers) HandleChannelMembers() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, _, ok := s.requireAuth(w, r) if !ok { return } name := "#" + chi.URLParam(r, "channel") var chID int64 err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query r.Context(), "SELECT id FROM channels WHERE name = ?", name, ).Scan(&chID) if err != nil { s.respondError( w, r, "channel not found", http.StatusNotFound, ) return } members, err := s.params.Database.ChannelMembers( r.Context(), chID, ) if err != nil { s.internalError( w, r, "channel members failed", err, ) return } s.respondJSON(w, r, members, http.StatusOK) } } // HandleGetMessages returns all new messages (channel + DM) // for the user via long-polling. func (s *Handlers) HandleGetMessages() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, _, ok := s.requireAuth(w, r) if !ok { return } afterID, _ := strconv.ParseInt( r.URL.Query().Get("after"), 10, 64, ) limit, _ := strconv.Atoi( r.URL.Query().Get("limit"), ) msgs, err := s.params.Database.PollMessages( r.Context(), uid, afterID, limit, ) if err != nil { s.internalError( w, r, "get messages failed", err, ) return } s.respondJSON(w, r, msgs, http.StatusOK) } } type sendRequest struct { Command string `json:"command"` To string `json:"to"` Params []string `json:"params,omitempty"` Body any `json:"body,omitempty"` } // HandleSendCommand handles all C2S commands via POST // /messages. 
func (s *Handlers) HandleSendCommand() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } var req sendRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { s.respondError( w, r, "invalid request", http.StatusBadRequest, ) return } req.Command = strings.ToUpper( strings.TrimSpace(req.Command), ) req.To = strings.TrimSpace(req.To) s.dispatchCommand(w, r, uid, nick, &req) } } func (s *Handlers) dispatchCommand( w http.ResponseWriter, r *http.Request, uid int64, nick string, req *sendRequest, ) { switch req.Command { case "PRIVMSG", "NOTICE": s.handlePrivmsg(w, r, uid, req) case "JOIN": s.handleJoin(w, r, uid, req) case "PART": s.handlePart(w, r, uid, req) case "NICK": s.handleNick(w, r, uid, req) case "TOPIC": s.handleTopic(w, r, uid, req) case "PING": s.respondJSON( w, r, map[string]string{ "command": "PONG", "from": s.params.Config.ServerName, }, http.StatusOK, ) default: _ = nick s.respondError( w, r, "unknown command: "+req.Command, http.StatusBadRequest, ) } } func (s *Handlers) handlePrivmsg( w http.ResponseWriter, r *http.Request, uid int64, req *sendRequest, ) { if req.To == "" { s.respondError( w, r, "to field required", http.StatusBadRequest, ) return } lines := bodyLines(req.Body) if len(lines) == 0 { s.respondError( w, r, "body required", http.StatusBadRequest, ) return } content := strings.Join(lines, "\n") if strings.HasPrefix(req.To, "#") { s.sendChannelMsg(w, r, uid, req.To, content) } else { s.sendDM(w, r, uid, req.To, content) } } func (s *Handlers) sendChannelMsg( w http.ResponseWriter, r *http.Request, uid int64, channel, content string, ) { var chID int64 err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query r.Context(), "SELECT id FROM channels WHERE name = ?", channel, ).Scan(&chID) if err != nil { s.respondError( w, r, "channel not found", http.StatusNotFound, ) return } msgID, err := s.params.Database.SendMessage( 
r.Context(), chID, uid, content, ) if err != nil { s.internalError(w, r, "send message failed", err) return } s.respondJSON( w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated, ) } func (s *Handlers) sendDM( w http.ResponseWriter, r *http.Request, uid int64, toNick, content string, ) { targetID, err := s.params.Database.GetUserByNick( r.Context(), toNick, ) if err != nil { s.respondError( w, r, "user not found", http.StatusNotFound, ) return } msgID, err := s.params.Database.SendDM( r.Context(), uid, targetID, content, ) if err != nil { s.internalError(w, r, "send dm failed", err) return } s.respondJSON( w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated, ) } func (s *Handlers) handleJoin( w http.ResponseWriter, r *http.Request, uid int64, req *sendRequest, ) { if req.To == "" { s.respondError( w, r, "to field required", http.StatusBadRequest, ) return } channel := req.To if !strings.HasPrefix(channel, "#") { channel = "#" + channel } chID, err := s.params.Database.GetOrCreateChannel( r.Context(), channel, ) if err != nil { s.internalError( w, r, "get/create channel failed", err, ) return } err = s.params.Database.JoinChannel( r.Context(), chID, uid, ) if err != nil { s.internalError(w, r, "join channel failed", err) return } s.respondJSON( w, r, map[string]string{ "status": "joined", "channel": channel, }, http.StatusOK, ) } func (s *Handlers) handlePart( w http.ResponseWriter, r *http.Request, uid int64, req *sendRequest, ) { if req.To == "" { s.respondError( w, r, "to field required", http.StatusBadRequest, ) return } channel := req.To if !strings.HasPrefix(channel, "#") { channel = "#" + channel } var chID int64 err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query r.Context(), "SELECT id FROM channels WHERE name = ?", channel, ).Scan(&chID) if err != nil { s.respondError( w, r, "channel not found", http.StatusNotFound, ) return } err = s.params.Database.PartChannel( r.Context(), chID, 
uid, ) if err != nil { s.internalError(w, r, "part channel failed", err) return } s.respondJSON( w, r, map[string]string{ "status": "parted", "channel": channel, }, http.StatusOK, ) } func (s *Handlers) handleNick( w http.ResponseWriter, r *http.Request, uid int64, req *sendRequest, ) { lines := bodyLines(req.Body) if len(lines) == 0 { s.respondError( w, r, "body required (new nick)", http.StatusBadRequest, ) return } newNick := strings.TrimSpace(lines[0]) if newNick == "" || len(newNick) > maxNickLen { s.respondError( w, r, "nick must be 1-32 characters", http.StatusBadRequest, ) return } err := s.params.Database.ChangeNick( r.Context(), uid, newNick, ) if err != nil { if strings.Contains(err.Error(), "UNIQUE") { s.respondError( w, r, "nick already in use", http.StatusConflict, ) return } s.internalError(w, r, "change nick failed", err) return } s.respondJSON( w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK, ) } func (s *Handlers) handleTopic( w http.ResponseWriter, r *http.Request, uid int64, req *sendRequest, ) { if req.To == "" { s.respondError( w, r, "to field required", http.StatusBadRequest, ) return } lines := bodyLines(req.Body) if len(lines) == 0 { s.respondError( w, r, "body required (topic text)", http.StatusBadRequest, ) return } topic := strings.Join(lines, " ") channel := req.To if !strings.HasPrefix(channel, "#") { channel = "#" + channel } err := s.params.Database.SetTopic( r.Context(), channel, uid, topic, ) if err != nil { s.internalError(w, r, "set topic failed", err) return } s.respondJSON( w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK, ) } // HandleGetHistory returns message history for a specific // target (channel or DM). 
func (s *Handlers) HandleGetHistory() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { uid, _, ok := s.requireAuth(w, r) if !ok { return } target := r.URL.Query().Get("target") if target == "" { s.respondError( w, r, "target required", http.StatusBadRequest, ) return } beforeID, _ := strconv.ParseInt( r.URL.Query().Get("before"), 10, 64, ) limit, _ := strconv.Atoi( r.URL.Query().Get("limit"), ) if limit <= 0 { limit = defaultHistory } if strings.HasPrefix(target, "#") { s.getChannelHistory( w, r, target, beforeID, limit, ) } else { s.getDMHistory( w, r, uid, target, beforeID, limit, ) } } } func (s *Handlers) getChannelHistory( w http.ResponseWriter, r *http.Request, target string, beforeID int64, limit int, ) { var chID int64 err := s.params.Database.GetDB().QueryRowContext( //nolint:gosec // parameterized query r.Context(), "SELECT id FROM channels WHERE name = ?", target, ).Scan(&chID) if err != nil { s.respondError( w, r, "channel not found", http.StatusNotFound, ) return } msgs, err := s.params.Database.GetMessagesBefore( r.Context(), chID, beforeID, limit, ) if err != nil { s.internalError(w, r, "get history failed", err) return } s.respondJSON(w, r, msgs, http.StatusOK) } func (s *Handlers) getDMHistory( w http.ResponseWriter, r *http.Request, uid int64, target string, beforeID int64, limit int, ) { targetID, err := s.params.Database.GetUserByNick( r.Context(), target, ) if err != nil { s.respondError( w, r, "user not found", http.StatusNotFound, ) return } msgs, err := s.params.Database.GetDMsBefore( r.Context(), uid, targetID, beforeID, limit, ) if err != nil { s.internalError( w, r, "get dm history failed", err, ) return } s.respondJSON(w, r, msgs, http.StatusOK) } // HandleServerInfo returns server metadata (MOTD, name). 
func (s *Handlers) HandleServerInfo() http.HandlerFunc { type response struct { Name string `json:"name"` MOTD string `json:"motd"` } return func(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, r, &response{ Name: s.params.Config.ServerName, MOTD: s.params.Config.MOTD, }, http.StatusOK) } }