Files
chat/internal/handlers/api.go
clawbot cb30b1b054
Some checks failed
check / check (push) Has been cancelled
fix: pin golangci-lint v2 by commit SHA, remove unused nolint directives
- Fix golangci-lint import path to v2/cmd/golangci-lint
- Pin to v2.8.0 commit SHA (e2e40021c9007020676c93680a36e3ab06c6cd33)
- Use CGO_ENABLED=0 for golangci-lint install
- Remove unused //nolint:gosec directives flagged by nolintlint
2026-02-26 20:22:12 -08:00

793 lines
14 KiB
Go

package handlers
import (
"database/sql"
"encoding/json"
"net/http"
"strconv"
"strings"
"git.eeqj.de/sneak/chat/internal/db"
"github.com/go-chi/chi"
)
// maxNickLen is the largest nick length accepted when a session is
// created or a NICK command renames a user.
const maxNickLen = 32

// defaultHistory is the page size HandleGetHistory falls back to when
// the request carries no positive "limit" query parameter.
const defaultHistory = 50
// authUser resolves the caller from an "Authorization: Bearer <token>"
// header and returns the user's ID and nick as reported by
// Database.GetUserByToken.
//
// The auth scheme is matched case-insensitively, as RFC 7235 §2.1
// requires, and whitespace around the token is ignored. A missing
// header, a different scheme, or an empty token yields sql.ErrNoRows
// without touching the database; callers such as requireAuth only test
// for a non-nil error.
func (s *Handlers) authUser(
	r *http.Request,
) (int64, string, error) {
	scheme, token, found := strings.Cut(
		r.Header.Get("Authorization"), " ",
	)
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return 0, "", sql.ErrNoRows
	}

	token = strings.TrimSpace(token)
	if token == "" {
		// Never look up the empty string as a token.
		return 0, "", sql.ErrNoRows
	}

	return s.params.Database.GetUserByToken(r.Context(), token)
}
// requireAuth authenticates the request via authUser. On success it
// returns the user's ID, nick, and true. On failure it has already
// written a 401 {"error": "unauthorized"} response and returns
// (0, "", false), so the calling handler should simply return.
func (s *Handlers) requireAuth(
	w http.ResponseWriter,
	r *http.Request,
) (int64, string, bool) {
	userID, userNick, authErr := s.authUser(r)
	if authErr == nil {
		return userID, userNick, true
	}

	s.respondError(w, r, "unauthorized", http.StatusUnauthorized)

	return 0, "", false
}
// respondError writes the JSON object {"error": msg} with HTTP status
// code.
func (s *Handlers) respondError(
	w http.ResponseWriter,
	r *http.Request,
	msg string,
	code int,
) {
	payload := map[string]string{"error": msg}
	s.respondJSON(w, r, payload, code)
}
// internalError logs err under msg and replies with a generic 500
// {"error": "internal error"}. Neither msg nor err is included in the
// response body.
func (s *Handlers) internalError(
	w http.ResponseWriter,
	r *http.Request,
	msg string,
	err error,
) {
	s.log.Error(msg, "error", err)

	const clientMsg = "internal error"
	s.respondError(w, r, clientMsg, http.StatusInternalServerError)
}
// bodyLines normalizes the free-form "body" field of a command request
// into text lines.
//
// Accepted shapes:
//   - a JSON array (decoded by encoding/json as []any): string
//     elements are kept in order and other elements are skipped;
//   - a []string, which is only reachable when the value was built in
//     Go rather than decoded from JSON;
//   - a single non-empty string, treated as a one-line body.
//
// Anything else (nil, numbers, objects, the empty string) yields nil.
func bodyLines(body any) []string {
	switch v := body.(type) {
	case []any:
		lines := make([]string, 0, len(v))

		for _, item := range v {
			if line, ok := item.(string); ok {
				lines = append(lines, line)
			}
		}

		return lines
	case []string:
		return v
	case string:
		// Clients naturally send a one-line body (e.g. a new nick) as
		// a bare JSON string; treat it as a single line instead of
		// rejecting it as missing.
		if v == "" {
			return nil
		}

		return []string{v}
	default:
		return nil
	}
}
// HandleCreateSession registers a new nick and replies 201 with the new
// user's ID, nick, and auth token.
//
// The nick is trimmed of surrounding whitespace. It must be 1 to
// maxNickLen characters long, counted in Unicode code points to match
// the error message. It must also not start with "#": PRIVMSG and
// history requests route "#"-prefixed targets to channels, so such a
// user could never be addressed directly.
//
// Responses:
//   - 400 for an unparsable/oversized body or an invalid nick;
//   - 409 when the database reports a unique-constraint failure;
//   - 500 (logged) for any other database error.
func (s *Handlers) HandleCreateSession() http.HandlerFunc {
	// maxCreateSessionBody bounds the body of this unauthenticated
	// endpoint. The only payload is {"nick": ...}, so anything larger
	// is rejected instead of being buffered by the JSON decoder.
	const maxCreateSessionBody = 4096

	type request struct {
		Nick string `json:"nick"`
	}

	type response struct {
		ID    int64  `json:"id"`
		Nick  string `json:"nick"`
		Token string `json:"token"`
	}

	return func(w http.ResponseWriter, r *http.Request) {
		r.Body = http.MaxBytesReader(w, r.Body, maxCreateSessionBody)

		var req request

		err := json.NewDecoder(r.Body).Decode(&req)
		if err != nil {
			s.respondError(
				w, r, "invalid request",
				http.StatusBadRequest,
			)

			return
		}

		req.Nick = strings.TrimSpace(req.Nick)

		// Count runes, not bytes, so non-ASCII nicks are measured in
		// characters as the error message promises.
		nickLen := len([]rune(req.Nick))
		if nickLen == 0 || nickLen > maxNickLen {
			s.respondError(
				w, r,
				"nick must be 1-"+strconv.Itoa(maxNickLen)+" characters",
				http.StatusBadRequest,
			)

			return
		}

		if strings.HasPrefix(req.Nick, "#") {
			s.respondError(
				w, r, "nick must not start with #",
				http.StatusBadRequest,
			)

			return
		}

		id, token, err := s.params.Database.CreateUser(
			r.Context(), req.Nick,
		)
		if err != nil {
			// NOTE(review): conflict detection relies on the driver's
			// error text containing "UNIQUE" (presumably SQLite's
			// "UNIQUE constraint failed") — confirm for the driver in use.
			if strings.Contains(err.Error(), "UNIQUE") {
				s.respondError(
					w, r, "nick already taken",
					http.StatusConflict,
				)

				return
			}

			s.internalError(w, r, "create user failed", err)

			return
		}

		s.respondJSON(
			w, r,
			&response{ID: id, Nick: req.Nick, Token: token},
			http.StatusCreated,
		)
	}
}
// HandleState reports the authenticated user's ID and nick together
// with the channel list Database.ListChannels returns for that user
// (the user's joined channels).
func (s *Handlers) HandleState() http.HandlerFunc {
	type response struct {
		ID       int64            `json:"id"`
		Nick     string           `json:"nick"`
		Channels []db.ChannelInfo `json:"channels"`
	}

	return func(w http.ResponseWriter, r *http.Request) {
		userID, userNick, authed := s.requireAuth(w, r)
		if !authed {
			return
		}

		joined, listErr := s.params.Database.ListChannels(
			r.Context(), userID,
		)
		if listErr != nil {
			s.internalError(w, r, "list channels failed", listErr)

			return
		}

		out := &response{ID: userID, Nick: userNick, Channels: joined}
		s.respondJSON(w, r, out, http.StatusOK)
	}
}
// HandleListAllChannels responds with every channel returned by
// Database.ListAllChannels. Authentication is required, but the
// caller's identity does not affect the result.
func (s *Handlers) HandleListAllChannels() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if _, _, authed := s.requireAuth(w, r); !authed {
			return
		}

		all, listErr := s.params.Database.ListAllChannels(r.Context())
		if listErr != nil {
			s.internalError(w, r, "list all channels failed", listErr)

			return
		}

		s.respondJSON(w, r, all, http.StatusOK)
	}
}
// HandleChannelMembers lists the members of the channel named by the
// {channel} URL parameter, which is given without its leading "#".
//
// It replies 404 when no channel has that name and 500 (logged) when
// the lookup or the member query fails. A database failure is no
// longer reported to the client as "channel not found".
func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		_, _, ok := s.requireAuth(w, r)
		if !ok {
			return
		}

		name := "#" + chi.URLParam(r, "channel")

		var chID int64

		err := s.params.Database.GetDB().QueryRowContext(
			r.Context(),
			"SELECT id FROM channels WHERE name = ?",
			name,
		).Scan(&chID)
		if err == sql.ErrNoRows {
			// Row.Scan returns sql.ErrNoRows unwrapped when nothing
			// matches; any other error is a real database failure.
			s.respondError(
				w, r, "channel not found",
				http.StatusNotFound,
			)

			return
		}

		if err != nil {
			s.internalError(w, r, "channel lookup failed", err)

			return
		}

		members, err := s.params.Database.ChannelMembers(
			r.Context(), chID,
		)
		if err != nil {
			s.internalError(
				w, r, "channel members failed", err,
			)

			return
		}

		s.respondJSON(w, r, members, http.StatusOK)
	}
}
// HandleGetMessages returns the caller's messages (channel and DM) as
// produced by Database.PollMessages.
//
// Optional query parameters, each 0 when absent:
//   - after: cursor passed to PollMessages (presumably the last message
//     ID the client has already seen — confirm in PollMessages);
//   - limit: maximum count, passed through unchanged.
//
// A parameter that is present but is not a base-10 integer (including
// one that overflows) is rejected with 400. Previously such values were
// silently turned into 0 or MaxInt64.
//
// NOTE(review): the earlier comment called this long-polling. Nothing
// in this handler waits, so any blocking must happen inside
// PollMessages — confirm.
func (s *Handlers) HandleGetMessages() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		uid, _, ok := s.requireAuth(w, r)
		if !ok {
			return
		}

		query := r.URL.Query()

		var afterID int64

		if raw := query.Get("after"); raw != "" {
			parsed, err := strconv.ParseInt(raw, 10, 64)
			if err != nil {
				s.respondError(
					w, r, "invalid after parameter",
					http.StatusBadRequest,
				)

				return
			}

			afterID = parsed
		}

		var limit int

		if raw := query.Get("limit"); raw != "" {
			parsed, err := strconv.Atoi(raw)
			if err != nil {
				s.respondError(
					w, r, "invalid limit parameter",
					http.StatusBadRequest,
				)

				return
			}

			limit = parsed
		}

		msgs, err := s.params.Database.PollMessages(
			r.Context(), uid, afterID, limit,
		)
		if err != nil {
			s.internalError(
				w, r, "get messages failed", err,
			)

			return
		}

		s.respondJSON(w, r, msgs, http.StatusOK)
	}
}
// sendRequest is the JSON body accepted by HandleSendCommand.
type sendRequest struct {
	// Command is the command verb, e.g. "PRIVMSG" or "JOIN".
	// HandleSendCommand trims and upper-cases it before dispatch.
	Command string `json:"command"`

	// To is the command target, either a channel name or a nick.
	// HandleSendCommand trims surrounding whitespace.
	To string `json:"to"`

	// Params carries extra command arguments. None of the handlers
	// reached from dispatchCommand currently read it.
	Params []string `json:"params,omitempty"`

	// Body is free-form JSON that command handlers convert to text
	// lines with bodyLines.
	Body any `json:"body,omitempty"`
}
// HandleSendCommand is the single POST /messages endpoint for every
// client-to-server command. It authenticates the caller and decodes a
// sendRequest. It then normalizes Command (trimmed, upper-cased) and To
// (trimmed) before handing off to dispatchCommand.
func (s *Handlers) HandleSendCommand() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		uid, nick, ok := s.requireAuth(w, r)
		if !ok {
			return
		}

		req := new(sendRequest)
		if decodeErr := json.NewDecoder(r.Body).Decode(req); decodeErr != nil {
			s.respondError(w, r, "invalid request", http.StatusBadRequest)

			return
		}

		trimmedCmd := strings.TrimSpace(req.Command)
		req.Command = strings.ToUpper(trimmedCmd)
		req.To = strings.TrimSpace(req.To)

		s.dispatchCommand(w, r, uid, nick, req)
	}
}
// dispatchCommand routes a normalized sendRequest to its handler:
//   - PRIVMSG and NOTICE share handlePrivmsg;
//   - JOIN, PART, NICK, and TOPIC have dedicated handlers;
//   - PING is answered inline with {"command": "PONG", "from": <server
//     name>};
//   - any other verb gets a 400 "unknown command: <verb>".
//
// nick is not used by any current command.
func (s *Handlers) dispatchCommand(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	nick string,
	req *sendRequest,
) {
	_ = nick // unused by every command below

	switch verb := req.Command; verb {
	case "PRIVMSG", "NOTICE":
		s.handlePrivmsg(w, r, uid, req)
	case "JOIN":
		s.handleJoin(w, r, uid, req)
	case "PART":
		s.handlePart(w, r, uid, req)
	case "NICK":
		s.handleNick(w, r, uid, req)
	case "TOPIC":
		s.handleTopic(w, r, uid, req)
	case "PING":
		pong := map[string]string{
			"command": "PONG",
			"from":    s.params.Config.ServerName,
		}
		s.respondJSON(w, r, pong, http.StatusOK)
	default:
		s.respondError(
			w, r,
			"unknown command: "+verb,
			http.StatusBadRequest,
		)
	}
}
// handlePrivmsg serves both PRIVMSG and NOTICE. It requires a target
// and a non-empty body, and joins the body lines with "\n". A target
// starting with "#" goes to sendChannelMsg; any other target is treated
// as a nick and goes to sendDM.
func (s *Handlers) handlePrivmsg(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	req *sendRequest,
) {
	target := req.To
	if target == "" {
		s.respondError(w, r, "to field required", http.StatusBadRequest)

		return
	}

	lines := bodyLines(req.Body)
	if len(lines) == 0 {
		s.respondError(w, r, "body required", http.StatusBadRequest)

		return
	}

	text := strings.Join(lines, "\n")

	if !strings.HasPrefix(target, "#") {
		s.sendDM(w, r, uid, target, text)

		return
	}

	s.sendChannelMsg(w, r, uid, target, text)
}
// sendChannelMsg stores content as a message from uid in the named
// channel and replies 201 with {"id": <message id>, "status": "sent"}.
//
// It replies 404 when no channel has that name and 500 (logged) when
// the lookup or the insert fails. A database failure is no longer
// reported to the client as "channel not found".
//
// NOTE(review): no membership check is made here; whether
// Database.SendMessage enforces one is not visible from this handler.
func (s *Handlers) sendChannelMsg(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	channel, content string,
) {
	var chID int64

	err := s.params.Database.GetDB().QueryRowContext(
		r.Context(),
		"SELECT id FROM channels WHERE name = ?",
		channel,
	).Scan(&chID)
	if err == sql.ErrNoRows {
		// Row.Scan returns sql.ErrNoRows unwrapped when nothing
		// matches; any other error is a real database failure.
		s.respondError(
			w, r, "channel not found",
			http.StatusNotFound,
		)

		return
	}

	if err != nil {
		s.internalError(w, r, "channel lookup failed", err)

		return
	}

	msgID, err := s.params.Database.SendMessage(
		r.Context(), chID, uid, content,
	)
	if err != nil {
		s.internalError(w, r, "send message failed", err)

		return
	}

	s.respondJSON(
		w, r,
		map[string]any{"id": msgID, "status": "sent"},
		http.StatusCreated,
	)
}
// sendDM stores content as a direct message from uid to the user named
// toNick and replies 201 with {"id": <message id>, "status": "sent"}.
// Any error from the nick lookup is reported as 404 "user not found";
// an error from SendDM is logged and reported as 500.
func (s *Handlers) sendDM(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	toNick, content string,
) {
	ctx := r.Context()

	recipientID, lookupErr := s.params.Database.GetUserByNick(ctx, toNick)
	if lookupErr != nil {
		s.respondError(w, r, "user not found", http.StatusNotFound)

		return
	}

	newID, sendErr := s.params.Database.SendDM(
		ctx, uid, recipientID, content,
	)
	if sendErr != nil {
		s.internalError(w, r, "send dm failed", sendErr)

		return
	}

	receipt := map[string]any{"id": newID, "status": "sent"}
	s.respondJSON(w, r, receipt, http.StatusCreated)
}
// handleJoin adds the caller to a channel, creating the channel first
// through GetOrCreateChannel when needed. A target without a leading
// "#" gets one. The reply echoes the normalized name:
// {"status": "joined", "channel": name}.
func (s *Handlers) handleJoin(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	req *sendRequest,
) {
	if req.To == "" {
		s.respondError(w, r, "to field required", http.StatusBadRequest)

		return
	}

	channel := req.To
	if !strings.HasPrefix(channel, "#") {
		channel = "#" + channel
	}

	ctx := r.Context()

	chID, createErr := s.params.Database.GetOrCreateChannel(ctx, channel)
	if createErr != nil {
		s.internalError(w, r, "get/create channel failed", createErr)

		return
	}

	if joinErr := s.params.Database.JoinChannel(ctx, chID, uid); joinErr != nil {
		s.internalError(w, r, "join channel failed", joinErr)

		return
	}

	reply := map[string]string{"status": "joined", "channel": channel}
	s.respondJSON(w, r, reply, http.StatusOK)
}
// handlePart removes the caller from a channel. A target without a
// leading "#" gets one.
//
// It replies 404 when no channel has that name and 500 (logged) when
// the lookup or PartChannel fails. On success it replies
// {"status": "parted", "channel": name}. A database failure during the
// lookup is no longer reported to the client as "channel not found".
func (s *Handlers) handlePart(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	req *sendRequest,
) {
	if req.To == "" {
		s.respondError(
			w, r, "to field required",
			http.StatusBadRequest,
		)

		return
	}

	channel := req.To
	if !strings.HasPrefix(channel, "#") {
		channel = "#" + channel
	}

	var chID int64

	err := s.params.Database.GetDB().QueryRowContext(
		r.Context(),
		"SELECT id FROM channels WHERE name = ?",
		channel,
	).Scan(&chID)
	if err == sql.ErrNoRows {
		// Row.Scan returns sql.ErrNoRows unwrapped when nothing
		// matches; any other error is a real database failure.
		s.respondError(
			w, r, "channel not found",
			http.StatusNotFound,
		)

		return
	}

	if err != nil {
		s.internalError(w, r, "channel lookup failed", err)

		return
	}

	err = s.params.Database.PartChannel(
		r.Context(), chID, uid,
	)
	if err != nil {
		s.internalError(w, r, "part channel failed", err)

		return
	}

	s.respondJSON(
		w, r,
		map[string]string{
			"status": "parted", "channel": channel,
		},
		http.StatusOK,
	)
}
// handleNick renames the caller to the first line of the request body.
//
// The new nick is trimmed of surrounding whitespace. It must be 1 to
// maxNickLen characters long, counted in Unicode code points to match
// the error message. It must also not start with "#": PRIVMSG and
// history requests route "#"-prefixed targets to channels, so such a
// user could never be addressed directly.
//
// It replies 409 when the database reports a unique-constraint failure,
// 500 (logged) for other errors, and {"status": "ok", "nick": newNick}
// on success.
func (s *Handlers) handleNick(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	req *sendRequest,
) {
	lines := bodyLines(req.Body)
	if len(lines) == 0 {
		s.respondError(
			w, r, "body required (new nick)",
			http.StatusBadRequest,
		)

		return
	}

	newNick := strings.TrimSpace(lines[0])

	// Count runes, not bytes, so non-ASCII nicks are measured in
	// characters as the error message promises.
	nickLen := len([]rune(newNick))
	if nickLen == 0 || nickLen > maxNickLen {
		s.respondError(
			w, r,
			"nick must be 1-"+strconv.Itoa(maxNickLen)+" characters",
			http.StatusBadRequest,
		)

		return
	}

	if strings.HasPrefix(newNick, "#") {
		s.respondError(
			w, r, "nick must not start with #",
			http.StatusBadRequest,
		)

		return
	}

	err := s.params.Database.ChangeNick(
		r.Context(), uid, newNick,
	)
	if err != nil {
		// NOTE(review): conflict detection relies on the driver's error
		// text containing "UNIQUE" — confirm for the driver in use.
		if strings.Contains(err.Error(), "UNIQUE") {
			s.respondError(
				w, r, "nick already in use",
				http.StatusConflict,
			)

			return
		}

		s.internalError(w, r, "change nick failed", err)

		return
	}

	s.respondJSON(
		w, r,
		map[string]string{"status": "ok", "nick": newNick},
		http.StatusOK,
	)
}
// handleTopic sets a channel topic on behalf of uid. A target without a
// leading "#" gets one, and the body lines are joined with single
// spaces to form the topic. Any SetTopic error is logged and reported
// as 500. On success it replies {"status": "ok", "topic": topic}.
//
// NOTE(review): whether SetTopic checks that the channel exists or that
// uid may change its topic is not visible from this handler.
func (s *Handlers) handleTopic(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	req *sendRequest,
) {
	if req.To == "" {
		s.respondError(w, r, "to field required", http.StatusBadRequest)

		return
	}

	lines := bodyLines(req.Body)
	if len(lines) == 0 {
		s.respondError(
			w, r, "body required (topic text)", http.StatusBadRequest,
		)

		return
	}

	channel := req.To
	if !strings.HasPrefix(channel, "#") {
		channel = "#" + channel
	}

	topic := strings.Join(lines, " ")

	setErr := s.params.Database.SetTopic(r.Context(), channel, uid, topic)
	if setErr != nil {
		s.internalError(w, r, "set topic failed", setErr)

		return
	}

	reply := map[string]string{"status": "ok", "topic": topic}
	s.respondJSON(w, r, reply, http.StatusOK)
}
// HandleGetHistory returns message history for one target: a channel
// when "target" starts with "#", otherwise the DM conversation with
// that nick.
//
// Query parameters:
//   - target (required): channel name or nick;
//   - before (optional, default 0): pagination cursor passed to the
//     history query;
//   - limit (optional): page size. Absent or non-positive values fall
//     back to defaultHistory.
//
// A before/limit value that is present but is not a base-10 integer
// (including one that overflows) is rejected with 400. Previously such
// values were silently replaced.
//
// NOTE(review): limit has no upper bound here; consider capping it if
// the history queries do not.
func (s *Handlers) HandleGetHistory() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		uid, _, ok := s.requireAuth(w, r)
		if !ok {
			return
		}

		query := r.URL.Query()

		target := query.Get("target")
		if target == "" {
			s.respondError(
				w, r, "target required",
				http.StatusBadRequest,
			)

			return
		}

		var beforeID int64

		if raw := query.Get("before"); raw != "" {
			parsed, err := strconv.ParseInt(raw, 10, 64)
			if err != nil {
				s.respondError(
					w, r, "invalid before parameter",
					http.StatusBadRequest,
				)

				return
			}

			beforeID = parsed
		}

		limit := defaultHistory

		if raw := query.Get("limit"); raw != "" {
			parsed, err := strconv.Atoi(raw)
			if err != nil {
				s.respondError(
					w, r, "invalid limit parameter",
					http.StatusBadRequest,
				)

				return
			}

			if parsed > 0 {
				limit = parsed
			}
		}

		if strings.HasPrefix(target, "#") {
			s.getChannelHistory(
				w, r, target, beforeID, limit,
			)

			return
		}

		s.getDMHistory(
			w, r, uid, target, beforeID, limit,
		)
	}
}
// getChannelHistory serves HandleGetHistory for a "#channel" target.
// It returns up to limit messages from Database.GetMessagesBefore,
// using beforeID as the pagination cursor.
//
// It replies 404 when no channel has that name and 500 (logged) when
// the lookup or the history query fails. A database failure is no
// longer reported to the client as "channel not found".
//
// NOTE(review): no membership check is made, so any authenticated user
// can read any channel's history — confirm this is intended.
func (s *Handlers) getChannelHistory(
	w http.ResponseWriter,
	r *http.Request,
	target string,
	beforeID int64,
	limit int,
) {
	var chID int64

	err := s.params.Database.GetDB().QueryRowContext(
		r.Context(),
		"SELECT id FROM channels WHERE name = ?",
		target,
	).Scan(&chID)
	if err == sql.ErrNoRows {
		// Row.Scan returns sql.ErrNoRows unwrapped when nothing
		// matches; any other error is a real database failure.
		s.respondError(
			w, r, "channel not found",
			http.StatusNotFound,
		)

		return
	}

	if err != nil {
		s.internalError(w, r, "channel lookup failed", err)

		return
	}

	msgs, err := s.params.Database.GetMessagesBefore(
		r.Context(), chID, beforeID, limit,
	)
	if err != nil {
		s.internalError(w, r, "get history failed", err)

		return
	}

	s.respondJSON(w, r, msgs, http.StatusOK)
}
// getDMHistory serves HandleGetHistory for a nick target. It resolves
// the nick to a user ID, then returns what Database.GetDMsBefore yields
// for uid and that user, with beforeID as the pagination cursor and
// limit as the page size. Any nick lookup error is reported as 404
// "user not found".
func (s *Handlers) getDMHistory(
	w http.ResponseWriter,
	r *http.Request,
	uid int64,
	target string,
	beforeID int64,
	limit int,
) {
	ctx := r.Context()

	peerID, lookupErr := s.params.Database.GetUserByNick(ctx, target)
	if lookupErr != nil {
		s.respondError(w, r, "user not found", http.StatusNotFound)

		return
	}

	history, queryErr := s.params.Database.GetDMsBefore(
		ctx, uid, peerID, beforeID, limit,
	)
	if queryErr != nil {
		s.internalError(w, r, "get dm history failed", queryErr)

		return
	}

	s.respondJSON(w, r, history, http.StatusOK)
}
// HandleServerInfo reports the configured server name and MOTD as
// {"name": ..., "motd": ...}. It does not require authentication.
func (s *Handlers) HandleServerInfo() http.HandlerFunc {
	type response struct {
		Name string `json:"name"`
		MOTD string `json:"motd"`
	}

	return func(w http.ResponseWriter, r *http.Request) {
		info := response{
			Name: s.params.Config.ServerName,
			MOTD: s.params.Config.MOTD,
		}

		s.respondJSON(w, r, &info, http.StatusOK)
	}
}