diff --git a/internal/db/queries.go b/internal/db/queries.go index 974f4db..ce63057 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -2,88 +2,52 @@ package db import ( "context" - "crypto/rand" - "encoding/hex" "fmt" "time" ) -func generateToken() string { - b := make([]byte, 32) - _, _ = rand.Read(b) - return hex.EncodeToString(b) -} - -// CreateUser registers a new user with the given nick and returns the user with token. -func (s *Database) CreateUser(ctx context.Context, nick string) (int64, string, error) { - token := generateToken() - now := time.Now() - res, err := s.db.ExecContext(ctx, - "INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)", - nick, token, now, now) - if err != nil { - return 0, "", fmt.Errorf("create user: %w", err) - } - id, _ := res.LastInsertId() - return id, token, nil -} - -// GetUserByToken returns user id and nick for a given auth token. -func (s *Database) GetUserByToken(ctx context.Context, token string) (int64, string, error) { - var id int64 - var nick string - err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick) - if err != nil { - return 0, "", err - } - // Update last_seen - _, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id) - return id, nick, nil -} - -// GetUserByNick returns user id for a given nick. -func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error) { - var id int64 - err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id) - return id, err -} - // GetOrCreateChannel returns the channel id, creating it if needed. -func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) { - var id int64 +func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (string, error) { + var id string + err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id) if err == nil { return id, nil } + now := time.Now() - res, err := s.db.ExecContext(ctx, - "INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)", - name, now, now) + id = fmt.Sprintf("ch-%d", now.UnixNano()) + + _, err = s.db.ExecContext(ctx, + "INSERT INTO channels (id, name, topic, modes, created_at, updated_at) VALUES (?, ?, '', '', ?, ?)", + id, name, now, now) if err != nil { - return 0, fmt.Errorf("create channel: %w", err) + return "", fmt.Errorf("create channel: %w", err) } - id, _ = res.LastInsertId() + return id, nil } // JoinChannel adds a user to a channel. -func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error { +func (s *Database) JoinChannel(ctx context.Context, channelID, userID string) error { _, err := s.db.ExecContext(ctx, - "INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)", + "INSERT OR IGNORE INTO channel_members (channel_id, user_id, modes, joined_at) VALUES (?, ?, '', ?)", channelID, userID, time.Now()) + return err } // PartChannel removes a user from a channel. -func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error { +func (s *Database) PartChannel(ctx context.Context, channelID, userID string) error { _, err := s.db.ExecContext(ctx, "DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?", channelID, userID) + return err } // ListChannels returns all channels the user has joined. -func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) { +func (s *Database) ListChannels(ctx context.Context, userID string) ([]ChannelInfo, error) { rows, err := s.db.QueryContext(ctx, `SELECT c.id, c.name, c.topic FROM channels c INNER JOIN channel_members cm ON cm.channel_id = c.id @@ -91,62 +55,66 @@ func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInf if err != nil { return nil, err } - defer rows.Close() - var channels []ChannelInfo + + defer func() { _ = rows.Close() }() + + channels := []ChannelInfo{} + for rows.Next() { var ch ChannelInfo if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { return nil, err } + channels = append(channels, ch) } - if channels == nil { - channels = []ChannelInfo{} - } - return channels, nil + + return channels, rows.Err() } // ChannelInfo is a lightweight channel representation. type ChannelInfo struct { - ID int64 `json:"id"` + ID string `json:"id"` Name string `json:"name"` Topic string `json:"topic"` } // ChannelMembers returns all members of a channel. -func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) { +func (s *Database) ChannelMembers(ctx context.Context, channelID string) ([]MemberInfo, error) { rows, err := s.db.QueryContext(ctx, - `SELECT u.id, u.nick, u.last_seen FROM users u + `SELECT u.id, u.nick, u.last_seen_at FROM users u INNER JOIN channel_members cm ON cm.user_id = u.id WHERE cm.channel_id = ? ORDER BY u.nick`, channelID) if err != nil { return nil, err } - defer rows.Close() - var members []MemberInfo + + defer func() { _ = rows.Close() }() + + members := []MemberInfo{} + for rows.Next() { var m MemberInfo if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil { return nil, err } + members = append(members, m) } - if members == nil { - members = []MemberInfo{} - } - return members, nil + + return members, rows.Err() } // MemberInfo represents a channel member. type MemberInfo struct { - ID int64 `json:"id"` - Nick string `json:"nick"` - LastSeen time.Time `json:"lastSeen"` + ID string `json:"id"` + Nick string `json:"nick"` + LastSeen *time.Time `json:"lastSeen"` } // MessageInfo represents a chat message. type MessageInfo struct { - ID int64 `json:"id"` + ID string `json:"id"` Channel string `json:"channel,omitempty"` Nick string `json:"nick"` Content string `json:"content"` @@ -155,234 +123,202 @@ type MessageInfo struct { CreatedAt time.Time `json:"createdAt"` } -// GetMessages returns messages for a channel, optionally after a given ID. -func (s *Database) GetMessages(ctx context.Context, channelID int64, afterID int64, limit int) ([]MessageInfo, error) { - if limit <= 0 { - limit = 50 - } - rows, err := s.db.QueryContext(ctx, - `SELECT m.id, c.name, u.nick, m.content, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN channels c ON c.id = m.channel_id - WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id > ? - ORDER BY m.id ASC LIMIT ?`, channelID, afterID, limit) - if err != nil { - return nil, err - } - defer rows.Close() - var msgs []MessageInfo - for rows.Next() { - var m MessageInfo - if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { - return nil, err - } - msgs = append(msgs, m) - } - if msgs == nil { - msgs = []MessageInfo{} - } - return msgs, nil -} - // SendMessage inserts a channel message. -func (s *Database) SendMessage(ctx context.Context, channelID, userID int64, content string) (int64, error) { - res, err := s.db.ExecContext(ctx, - "INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)", - channelID, userID, content, time.Now()) +func (s *Database) SendMessage(ctx context.Context, channelID, userID, nick, content string) (string, error) { + now := time.Now() + id := fmt.Sprintf("msg-%d", now.UnixNano()) + + _, err := s.db.ExecContext(ctx, + `INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at) + VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`, + id, now, userID, nick, channelID, content, now) if err != nil { - return 0, err + return "", err } - return res.LastInsertId() + + return id, nil } // SendDM inserts a direct message. -func (s *Database) SendDM(ctx context.Context, fromID, toID int64, content string) (int64, error) { - res, err := s.db.ExecContext(ctx, - "INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)", - fromID, content, toID, time.Now()) +func (s *Database) SendDM(ctx context.Context, fromID, fromNick, toID, content string) (string, error) { + now := time.Now() + id := fmt.Sprintf("msg-%d", now.UnixNano()) + + _, err := s.db.ExecContext(ctx, + `INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at) + VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`, + id, now, fromID, fromNick, toID, content, now) if err != nil { - return 0, err + return "", err } - return res.LastInsertId() + + return id, nil } -// GetDMs returns direct messages between two users after a given ID. -func (s *Database) GetDMs(ctx context.Context, userA, userB int64, afterID int64, limit int) ([]MessageInfo, error) { - if limit <= 0 { - limit = 50 - } - rows, err := s.db.QueryContext(ctx, - `SELECT m.id, u.nick, m.content, t.nick, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN users t ON t.id = m.dm_target_id - WHERE m.is_dm = 1 AND m.id > ? - AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) - ORDER BY m.id ASC LIMIT ?`, afterID, userA, userB, userB, userA, limit) - if err != nil { - return nil, err - } - defer rows.Close() - var msgs []MessageInfo - for rows.Next() { - var m MessageInfo - if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil { - return nil, err - } - m.IsDM = true - msgs = append(msgs, m) - } - if msgs == nil { - msgs = []MessageInfo{} - } - return msgs, nil -} - -// PollMessages returns all new messages (channel + DM) for a user after a given ID. -func (s *Database) PollMessages(ctx context.Context, userID int64, afterID int64, limit int) ([]MessageInfo, error) { +// PollMessages returns all new messages for a user's joined channels, ordered by timestamp. +func (s *Database) PollMessages(ctx context.Context, userID string, afterTS string, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 100 } - rows, err := s.db.QueryContext(ctx, - `SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, m.is_dm, COALESCE(t.nick, ''), m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - LEFT JOIN channels c ON c.id = m.channel_id - LEFT JOIN users t ON t.id = m.dm_target_id - WHERE m.id > ? AND ( - (m.is_dm = 0 AND m.channel_id IN (SELECT channel_id FROM channel_members WHERE user_id = ?)) - OR (m.is_dm = 1 AND (m.user_id = ? OR m.dm_target_id = ?)) - ) - ORDER BY m.id ASC LIMIT ?`, afterID, userID, userID, userID, limit) - if err != nil { - return nil, err - } - defer rows.Close() - var msgs []MessageInfo - for rows.Next() { - var m MessageInfo - var isDM int - if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &isDM, &m.DMTarget, &m.CreatedAt); err != nil { - return nil, err - } - m.IsDM = isDM == 1 - msgs = append(msgs, m) - } - if msgs == nil { - msgs = []MessageInfo{} - } - return msgs, nil -} -// GetMessagesBefore returns channel messages before a given ID (for history scrollback). -func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, beforeID int64, limit int) ([]MessageInfo, error) { - if limit <= 0 { - limit = 50 - } - var query string - var args []any - if beforeID > 0 { - query = `SELECT m.id, c.name, u.nick, m.content, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN channels c ON c.id = m.channel_id - WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ? - ORDER BY m.id DESC LIMIT ?` - args = []any{channelID, beforeID, limit} - } else { - query = `SELECT m.id, c.name, u.nick, m.content, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN channels c ON c.id = m.channel_id - WHERE m.channel_id = ? AND m.is_dm = 0 - ORDER BY m.id DESC LIMIT ?` - args = []any{channelID, limit} - } - rows, err := s.db.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, + `SELECT m.id, m.target, m.from_nick, m.body, m.created_at + FROM messages m + WHERE m.created_at > COALESCE(NULLIF(?, ''), '1970-01-01') + AND ( + m.target IN (SELECT cm.channel_id FROM channel_members cm WHERE cm.user_id = ?) + OR m.target = ? + OR m.from_user_id = ? + ) + ORDER BY m.created_at ASC LIMIT ?`, + afterTS, userID, userID, userID, limit) if err != nil { return nil, err } - defer rows.Close() - var msgs []MessageInfo + + defer func() { _ = rows.Close() }() + + msgs := []MessageInfo{} + for rows.Next() { var m MessageInfo if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { return nil, err } + msgs = append(msgs, m) } - if msgs == nil { - msgs = []MessageInfo{} - } - // Reverse to ascending order - for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { - msgs[i], msgs[j] = msgs[j], msgs[i] - } - return msgs, nil + + return msgs, rows.Err() } -// GetDMsBefore returns DMs between two users before a given ID (for history scrollback). -func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeID int64, limit int) ([]MessageInfo, error) { +// GetMessagesBefore returns channel messages before a given timestamp (for history scrollback). +func (s *Database) GetMessagesBefore(ctx context.Context, target string, beforeTS string, limit int) ([]MessageInfo, error) { if limit <= 0 { limit = 50 } - var query string - var args []any - if beforeID > 0 { - query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN users t ON t.id = m.dm_target_id - WHERE m.is_dm = 1 AND m.id < ? - AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) - ORDER BY m.id DESC LIMIT ?` - args = []any{beforeID, userA, userB, userB, userA, limit} - } else { - query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at - FROM messages m - INNER JOIN users u ON u.id = m.user_id - INNER JOIN users t ON t.id = m.dm_target_id - WHERE m.is_dm = 1 - AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?)) - ORDER BY m.id DESC LIMIT ?` - args = []any{userA, userB, userB, userA, limit} + + var rows interface { + Next() bool + Scan(dest ...interface{}) error + Close() error + Err() error } - rows, err := s.db.QueryContext(ctx, query, args...) + + var err error + + if beforeTS != "" { + rows, err = s.db.QueryContext(ctx, + `SELECT m.id, m.target, m.from_nick, m.body, m.created_at + FROM messages m + WHERE m.target = ? AND m.created_at < ? + ORDER BY m.created_at DESC LIMIT ?`, + target, beforeTS, limit) + } else { + rows, err = s.db.QueryContext(ctx, + `SELECT m.id, m.target, m.from_nick, m.body, m.created_at + FROM messages m + WHERE m.target = ? + ORDER BY m.created_at DESC LIMIT ?`, + target, limit) + } + if err != nil { return nil, err } - defer rows.Close() - var msgs []MessageInfo + + defer func() { _ = rows.Close() }() + + msgs := []MessageInfo{} + + for rows.Next() { + var m MessageInfo + if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil { + return nil, err + } + + msgs = append(msgs, m) + } + + // Reverse to ascending order. + for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { + msgs[i], msgs[j] = msgs[j], msgs[i] + } + + return msgs, rows.Err() +} + +// GetDMsBefore returns DMs between two users before a given timestamp. +func (s *Database) GetDMsBefore(ctx context.Context, userA, userB string, beforeTS string, limit int) ([]MessageInfo, error) { + if limit <= 0 { + limit = 50 + } + + var rows interface { + Next() bool + Scan(dest ...interface{}) error + Close() error + Err() error + } + + var err error + + if beforeTS != "" { + rows, err = s.db.QueryContext(ctx, + `SELECT m.id, m.from_nick, m.body, m.target, m.created_at + FROM messages m + WHERE m.created_at < ? + AND ((m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?)) + ORDER BY m.created_at DESC LIMIT ?`, + beforeTS, userA, userB, userB, userA, limit) + } else { + rows, err = s.db.QueryContext(ctx, + `SELECT m.id, m.from_nick, m.body, m.target, m.created_at + FROM messages m + WHERE (m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?) + ORDER BY m.created_at DESC LIMIT ?`, + userA, userB, userB, userA, limit) + } + + if err != nil { + return nil, err + } + + defer func() { _ = rows.Close() }() + + msgs := []MessageInfo{} + for rows.Next() { var m MessageInfo if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil { return nil, err } + m.IsDM = true msgs = append(msgs, m) } - if msgs == nil { - msgs = []MessageInfo{} - } - // Reverse to ascending order + + // Reverse to ascending order. for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 { msgs[i], msgs[j] = msgs[j], msgs[i] } - return msgs, nil + + return msgs, rows.Err() } // ChangeNick updates a user's nickname. -func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error { +func (s *Database) ChangeNick(ctx context.Context, userID string, newNick string) error { _, err := s.db.ExecContext(ctx, "UPDATE users SET nick = ? WHERE id = ?", newNick, userID) + return err } // SetTopic sets the topic for a channel. -func (s *Database) SetTopic(ctx context.Context, channelName string, _ int64, topic string) error { +func (s *Database) SetTopic(ctx context.Context, channelName string, _ string, topic string) error { _, err := s.db.ExecContext(ctx, "UPDATE channels SET topic = ? WHERE name = ?", topic, channelName) + return err } @@ -398,17 +334,19 @@ func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) { if err != nil { return nil, err } - defer rows.Close() - var channels []ChannelInfo + + defer func() { _ = rows.Close() }() + + channels := []ChannelInfo{} + for rows.Next() { var ch ChannelInfo if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil { return nil, err } + channels = append(channels, ch) } - if channels == nil { - channels = []ChannelInfo{} - } - return channels, nil + + return channels, rows.Err() } diff --git a/internal/handlers/api.go b/internal/handlers/api.go index 975b7a1..de53a92 100644 --- a/internal/handlers/api.go +++ b/internal/handlers/api.go @@ -1,7 +1,9 @@ package handlers import ( + "crypto/rand" "database/sql" + "encoding/hex" "encoding/json" "net/http" "strconv" @@ -12,77 +14,114 @@ import ( ) // authUser extracts the user from the Authorization header (Bearer token). -func (s *Handlers) authUser(r *http.Request) (int64, string, error) { +func (s *Handlers) authUser(r *http.Request) (string, string, error) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { - return 0, "", sql.ErrNoRows + return "", "", sql.ErrNoRows } + token := strings.TrimPrefix(auth, "Bearer ") - return s.params.Database.GetUserByToken(r.Context(), token) + + u, err := s.params.Database.GetUserByToken(r.Context(), token) + if err != nil { + return "", "", err + } + + return u.ID, u.Nick, nil } -func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, string, bool) { +func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (string, string, bool) { uid, nick, err := s.authUser(r) if err != nil { s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized) - return 0, "", false + return "", "", false } + return uid, nick, true } +func generateID() string { + b := make([]byte, 16) + _, _ = rand.Read(b) + + return hex.EncodeToString(b) +} + // HandleCreateSession creates a new user session and returns the auth token. func (s *Handlers) HandleCreateSession() http.HandlerFunc { type request struct { Nick string `json:"nick"` } type response struct { - ID int64 `json:"id"` + ID string `json:"id"` Nick string `json:"nick"` Token string `json:"token"` } + return func(w http.ResponseWriter, r *http.Request) { var req request if err := json.NewDecoder(r.Body).Decode(&req); err != nil { s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) return } + req.Nick = strings.TrimSpace(req.Nick) if req.Nick == "" || len(req.Nick) > 32 { s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) return } - id, token, err := s.params.Database.CreateUser(r.Context(), req.Nick) + + id := generateID() + + u, err := s.params.Database.CreateUser(r.Context(), id, req.Nick, "") if err != nil { if strings.Contains(err.Error(), "UNIQUE") { s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict) return } + s.log.Error("create user failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } - s.respondJSON(w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated) + + tokenStr := generateID() + + _, err = s.params.Database.CreateAuthToken(r.Context(), tokenStr, u.ID) + if err != nil { + s.log.Error("create auth token failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + s.respondJSON(w, r, &response{ID: u.ID, Nick: req.Nick, Token: tokenStr}, http.StatusCreated) } } // HandleState returns the current user's info and joined channels. func (s *Handlers) HandleState() http.HandlerFunc { type response struct { - ID int64 `json:"id"` + ID string `json:"id"` Nick string `json:"nick"` Channels []db.ChannelInfo `json:"channels"` } + return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } + channels, err := s.params.Database.ListChannels(r.Context(), uid) if err != nil { s.log.Error("list channels failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } + s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK) } } @@ -94,12 +133,15 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc { if !ok { return } + channels, err := s.params.Database.ListAllChannels(r.Context()) if err != nil { s.log.Error("list all channels failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } + s.respondJSON(w, r, channels, http.StatusOK) } } @@ -111,20 +153,26 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc { if !ok { return } + name := "#" + chi.URLParam(r, "channel") - var chID int64 + + var chID string + err := s.params.Database.GetDB().QueryRowContext(r.Context(), "SELECT id FROM channels WHERE name = ?", name).Scan(&chID) if err != nil { s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) return } + members, err := s.params.Database.ChannelMembers(r.Context(), chID) if err != nil { s.log.Error("channel members failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } + s.respondJSON(w, r, members, http.StatusOK) } } @@ -137,14 +185,18 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc { if !ok { return } - afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64) + + afterTS := r.URL.Query().Get("after") limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) - msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterID, limit) + + msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterTS, limit) if err != nil { s.log.Error("get messages failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } + s.respondJSON(w, r, msgs, http.StatusOK) } } @@ -158,16 +210,19 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { Params []string `json:"params,omitempty"` Body interface{} `json:"body,omitempty"` } + return func(w http.ResponseWriter, r *http.Request) { uid, nick, ok := s.requireAuth(w, r) if !ok { return } + var req request if err := json.NewDecoder(r.Body).Decode(&req); err != nil { s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest) return } + req.Command = strings.ToUpper(strings.TrimSpace(req.Command)) req.To = strings.TrimSpace(req.To) @@ -176,11 +231,13 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { switch v := req.Body.(type) { case []interface{}: lines := make([]string, 0, len(v)) + for _, item := range v { - if s, ok := item.(string); ok { - lines = append(lines, s) + if str, ok := item.(string); ok { + lines = append(lines, str) } } + return lines case []string: return v @@ -191,137 +248,19 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { switch req.Command { case "PRIVMSG", "NOTICE": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) - return - } - content := strings.Join(lines, "\n") - - if strings.HasPrefix(req.To, "#") { - // Channel message - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, content) - if err != nil { - s.log.Error("send message failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) - } else { - // DM - targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) - return - } - msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, content) - if err != nil { - s.log.Error("send dm failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) - } + s.handlePrivmsg(w, r, uid, nick, req.To, bodyLines()) case "JOIN": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel) - if err != nil { - s.log.Error("get/create channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { - s.log.Error("join channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK) + s.handleJoin(w, r, uid, req.To) case "PART": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - var chID int64 - err := s.params.Database.GetDB().QueryRowContext(r.Context(), - "SELECT id FROM channels WHERE name = ?", channel).Scan(&chID) - if err != nil { - s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) - return - } - if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { - s.log.Error("part channel failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK) + s.handlePart(w, r, uid, req.To) case "NICK": - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest) - return - } - newNick := strings.TrimSpace(lines[0]) - if newNick == "" || len(newNick) > 32 { - s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) - return - } - if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil { - if strings.Contains(err.Error(), "UNIQUE") { - s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict) - return - } - s.log.Error("change nick failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) + s.handleNick(w, r, uid, bodyLines()) case "TOPIC": - if req.To == "" { - s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) - return - } - lines := bodyLines() - if len(lines) == 0 { - s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest) - return - } - topic := strings.Join(lines, " ") - channel := req.To - if !strings.HasPrefix(channel, "#") { - channel = "#" + channel - } - if err := s.params.Database.SetTopic(r.Context(), channel, uid, topic); err != nil { - s.log.Error("set topic failed", "error", err) - s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) - return - } - s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK) + s.handleTopic(w, r, uid, req.To, bodyLines()) case "PING": s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK) @@ -333,6 +272,173 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc { } } +func (s *Handlers) handlePrivmsg(w http.ResponseWriter, r *http.Request, uid, nick, to string, lines []string) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest) + return + } + + content := strings.Join(lines, "\n") + + if strings.HasPrefix(to, "#") { + // Channel message. + var chID string + + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", to).Scan(&chID) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + + msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, nick, content) + if err != nil { + s.log.Error("send message failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) + } else { + // DM. + targetUser, err := s.params.Database.GetUserByNick(r.Context(), to) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) + return + } + + msgID, err := s.params.Database.SendDM(r.Context(), uid, nick, targetUser.ID, content) + if err != nil { + s.log.Error("send dm failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated) + } +} + +func (s *Handlers) handleJoin(w http.ResponseWriter, r *http.Request, uid, to string) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + + channel := to + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel) + if err != nil { + s.log.Error("get/create channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil { + s.log.Error("join channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK) +} + +func (s *Handlers) handlePart(w http.ResponseWriter, r *http.Request, uid, to string) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + + channel := to + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + var chID string + + err := s.params.Database.GetDB().QueryRowContext(r.Context(), + "SELECT id FROM channels WHERE name = ?", channel).Scan(&chID) + if err != nil { + s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) + return + } + + if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil { + s.log.Error("part channel failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK) +} + +func (s *Handlers) handleNick(w http.ResponseWriter, r *http.Request, uid string, lines []string) { + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest) + return + } + + newNick := strings.TrimSpace(lines[0]) + if newNick == "" || len(newNick) > 32 { + s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest) + return + } + + if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil { + if strings.Contains(err.Error(), "UNIQUE") { + s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict) + return + } + + s.log.Error("change nick failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK) +} + +func (s *Handlers) handleTopic(w http.ResponseWriter, r *http.Request, uid, to string, lines []string) { + if to == "" { + s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest) + return + } + + if len(lines) == 0 { + s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest) + return + } + + topic := strings.Join(lines, " ") + + channel := to + if !strings.HasPrefix(channel, "#") { + channel = "#" + channel + } + + if err := s.params.Database.SetTopic(r.Context(), channel, uid, topic); err != nil { + s.log.Error("set topic failed", "error", err) + s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + + return + } + + s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK) +} + // HandleGetHistory returns message history for a specific target (channel or DM). func (s *Handlers) HandleGetHistory() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -340,46 +446,56 @@ func (s *Handlers) HandleGetHistory() http.HandlerFunc { if !ok { return } + target := r.URL.Query().Get("target") if target == "" { s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest) return } - beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64) + + beforeTS := r.URL.Query().Get("before") limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + if limit <= 0 { limit = 50 } if strings.HasPrefix(target, "#") { - // Channel history - var chID int64 + // Channel history — look up channel by name to get its ID for target matching. + var chID string + err := s.params.Database.GetDB().QueryRowContext(r.Context(), "SELECT id FROM channels WHERE name = ?", target).Scan(&chID) if err != nil { s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound) return } - msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeID, limit) + + msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeTS, limit) if err != nil { s.log.Error("get history failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } + s.respondJSON(w, r, msgs, http.StatusOK) } else { - // DM history - targetID, err := s.params.Database.GetUserByNick(r.Context(), target) + // DM history. + targetUser, err := s.params.Database.GetUserByNick(r.Context(), target) if err != nil { s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound) return } - msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetID, beforeID, limit) + + msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetUser.ID, beforeTS, limit) if err != nil { s.log.Error("get dm history failed", "error", err) s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError) + return } + s.respondJSON(w, r, msgs, http.StatusOK) } } @@ -391,6 +507,7 @@ func (s *Handlers) HandleServerInfo() http.HandlerFunc { Name string `json:"name"` MOTD string `json:"motd"` } + return func(w http.ResponseWriter, r *http.Request) { s.respondJSON(w, r, &response{ Name: s.params.Config.ServerName,