fix: resolve typecheck errors by removing duplicate db methods and updating handlers to use models-based API
This commit is contained in:
@@ -2,88 +2,52 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func generateToken() string {
|
|
||||||
b := make([]byte, 32)
|
|
||||||
_, _ = rand.Read(b)
|
|
||||||
return hex.EncodeToString(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateUser registers a new user with the given nick and returns the user with token.
|
|
||||||
func (s *Database) CreateUser(ctx context.Context, nick string) (int64, string, error) {
|
|
||||||
token := generateToken()
|
|
||||||
now := time.Now()
|
|
||||||
res, err := s.db.ExecContext(ctx,
|
|
||||||
"INSERT INTO users (nick, token, created_at, last_seen) VALUES (?, ?, ?, ?)",
|
|
||||||
nick, token, now, now)
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", fmt.Errorf("create user: %w", err)
|
|
||||||
}
|
|
||||||
id, _ := res.LastInsertId()
|
|
||||||
return id, token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByToken returns user id and nick for a given auth token.
|
|
||||||
func (s *Database) GetUserByToken(ctx context.Context, token string) (int64, string, error) {
|
|
||||||
var id int64
|
|
||||||
var nick string
|
|
||||||
err := s.db.QueryRowContext(ctx, "SELECT id, nick FROM users WHERE token = ?", token).Scan(&id, &nick)
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", err
|
|
||||||
}
|
|
||||||
// Update last_seen
|
|
||||||
_, _ = s.db.ExecContext(ctx, "UPDATE users SET last_seen = ? WHERE id = ?", time.Now(), id)
|
|
||||||
return id, nick, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByNick returns user id for a given nick.
|
|
||||||
func (s *Database) GetUserByNick(ctx context.Context, nick string) (int64, error) {
|
|
||||||
var id int64
|
|
||||||
err := s.db.QueryRowContext(ctx, "SELECT id FROM users WHERE nick = ?", nick).Scan(&id)
|
|
||||||
return id, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOrCreateChannel returns the channel id, creating it if needed.
|
// GetOrCreateChannel returns the channel id, creating it if needed.
|
||||||
func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (int64, error) {
|
func (s *Database) GetOrCreateChannel(ctx context.Context, name string) (string, error) {
|
||||||
var id int64
|
var id string
|
||||||
|
|
||||||
err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id)
|
err := s.db.QueryRowContext(ctx, "SELECT id FROM channels WHERE name = ?", name).Scan(&id)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
res, err := s.db.ExecContext(ctx,
|
id = fmt.Sprintf("ch-%d", now.UnixNano())
|
||||||
"INSERT INTO channels (name, created_at, updated_at) VALUES (?, ?, ?)",
|
|
||||||
name, now, now)
|
_, err = s.db.ExecContext(ctx,
|
||||||
|
"INSERT INTO channels (id, name, topic, modes, created_at, updated_at) VALUES (?, ?, '', '', ?, ?)",
|
||||||
|
id, name, now, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("create channel: %w", err)
|
return "", fmt.Errorf("create channel: %w", err)
|
||||||
}
|
}
|
||||||
id, _ = res.LastInsertId()
|
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinChannel adds a user to a channel.
|
// JoinChannel adds a user to a channel.
|
||||||
func (s *Database) JoinChannel(ctx context.Context, channelID, userID int64) error {
|
func (s *Database) JoinChannel(ctx context.Context, channelID, userID string) error {
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
"INSERT OR IGNORE INTO channel_members (channel_id, user_id, joined_at) VALUES (?, ?, ?)",
|
"INSERT OR IGNORE INTO channel_members (channel_id, user_id, modes, joined_at) VALUES (?, ?, '', ?)",
|
||||||
channelID, userID, time.Now())
|
channelID, userID, time.Now())
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// PartChannel removes a user from a channel.
|
// PartChannel removes a user from a channel.
|
||||||
func (s *Database) PartChannel(ctx context.Context, channelID, userID int64) error {
|
func (s *Database) PartChannel(ctx context.Context, channelID, userID string) error {
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
"DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?",
|
"DELETE FROM channel_members WHERE channel_id = ? AND user_id = ?",
|
||||||
channelID, userID)
|
channelID, userID)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListChannels returns all channels the user has joined.
|
// ListChannels returns all channels the user has joined.
|
||||||
func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInfo, error) {
|
func (s *Database) ListChannels(ctx context.Context, userID string) ([]ChannelInfo, error) {
|
||||||
rows, err := s.db.QueryContext(ctx,
|
rows, err := s.db.QueryContext(ctx,
|
||||||
`SELECT c.id, c.name, c.topic FROM channels c
|
`SELECT c.id, c.name, c.topic FROM channels c
|
||||||
INNER JOIN channel_members cm ON cm.channel_id = c.id
|
INNER JOIN channel_members cm ON cm.channel_id = c.id
|
||||||
@@ -91,62 +55,66 @@ func (s *Database) ListChannels(ctx context.Context, userID int64) ([]ChannelInf
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
var channels []ChannelInfo
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
channels := []ChannelInfo{}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch ChannelInfo
|
var ch ChannelInfo
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
}
|
}
|
||||||
if channels == nil {
|
|
||||||
channels = []ChannelInfo{}
|
return channels, rows.Err()
|
||||||
}
|
|
||||||
return channels, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChannelInfo is a lightweight channel representation.
|
// ChannelInfo is a lightweight channel representation.
|
||||||
type ChannelInfo struct {
|
type ChannelInfo struct {
|
||||||
ID int64 `json:"id"`
|
ID string `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Topic string `json:"topic"`
|
Topic string `json:"topic"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChannelMembers returns all members of a channel.
|
// ChannelMembers returns all members of a channel.
|
||||||
func (s *Database) ChannelMembers(ctx context.Context, channelID int64) ([]MemberInfo, error) {
|
func (s *Database) ChannelMembers(ctx context.Context, channelID string) ([]MemberInfo, error) {
|
||||||
rows, err := s.db.QueryContext(ctx,
|
rows, err := s.db.QueryContext(ctx,
|
||||||
`SELECT u.id, u.nick, u.last_seen FROM users u
|
`SELECT u.id, u.nick, u.last_seen_at FROM users u
|
||||||
INNER JOIN channel_members cm ON cm.user_id = u.id
|
INNER JOIN channel_members cm ON cm.user_id = u.id
|
||||||
WHERE cm.channel_id = ? ORDER BY u.nick`, channelID)
|
WHERE cm.channel_id = ? ORDER BY u.nick`, channelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
var members []MemberInfo
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
members := []MemberInfo{}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var m MemberInfo
|
var m MemberInfo
|
||||||
if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil {
|
if err := rows.Scan(&m.ID, &m.Nick, &m.LastSeen); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
members = append(members, m)
|
members = append(members, m)
|
||||||
}
|
}
|
||||||
if members == nil {
|
|
||||||
members = []MemberInfo{}
|
return members, rows.Err()
|
||||||
}
|
|
||||||
return members, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MemberInfo represents a channel member.
|
// MemberInfo represents a channel member.
|
||||||
type MemberInfo struct {
|
type MemberInfo struct {
|
||||||
ID int64 `json:"id"`
|
ID string `json:"id"`
|
||||||
Nick string `json:"nick"`
|
Nick string `json:"nick"`
|
||||||
LastSeen time.Time `json:"lastSeen"`
|
LastSeen *time.Time `json:"lastSeen"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageInfo represents a chat message.
|
// MessageInfo represents a chat message.
|
||||||
type MessageInfo struct {
|
type MessageInfo struct {
|
||||||
ID int64 `json:"id"`
|
ID string `json:"id"`
|
||||||
Channel string `json:"channel,omitempty"`
|
Channel string `json:"channel,omitempty"`
|
||||||
Nick string `json:"nick"`
|
Nick string `json:"nick"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
@@ -155,234 +123,202 @@ type MessageInfo struct {
|
|||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessages returns messages for a channel, optionally after a given ID.
|
|
||||||
func (s *Database) GetMessages(ctx context.Context, channelID int64, afterID int64, limit int) ([]MessageInfo, error) {
|
|
||||||
if limit <= 0 {
|
|
||||||
limit = 50
|
|
||||||
}
|
|
||||||
rows, err := s.db.QueryContext(ctx,
|
|
||||||
`SELECT m.id, c.name, u.nick, m.content, m.created_at
|
|
||||||
FROM messages m
|
|
||||||
INNER JOIN users u ON u.id = m.user_id
|
|
||||||
INNER JOIN channels c ON c.id = m.channel_id
|
|
||||||
WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id > ?
|
|
||||||
ORDER BY m.id ASC LIMIT ?`, channelID, afterID, limit)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var msgs []MessageInfo
|
|
||||||
for rows.Next() {
|
|
||||||
var m MessageInfo
|
|
||||||
if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
msgs = append(msgs, m)
|
|
||||||
}
|
|
||||||
if msgs == nil {
|
|
||||||
msgs = []MessageInfo{}
|
|
||||||
}
|
|
||||||
return msgs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendMessage inserts a channel message.
|
// SendMessage inserts a channel message.
|
||||||
func (s *Database) SendMessage(ctx context.Context, channelID, userID int64, content string) (int64, error) {
|
func (s *Database) SendMessage(ctx context.Context, channelID, userID, nick, content string) (string, error) {
|
||||||
res, err := s.db.ExecContext(ctx,
|
now := time.Now()
|
||||||
"INSERT INTO messages (channel_id, user_id, content, is_dm, created_at) VALUES (?, ?, ?, 0, ?)",
|
id := fmt.Sprintf("msg-%d", now.UnixNano())
|
||||||
channelID, userID, content, time.Now())
|
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`,
|
||||||
|
id, now, userID, nick, channelID, content, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return "", err
|
||||||
}
|
}
|
||||||
return res.LastInsertId()
|
|
||||||
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendDM inserts a direct message.
|
// SendDM inserts a direct message.
|
||||||
func (s *Database) SendDM(ctx context.Context, fromID, toID int64, content string) (int64, error) {
|
func (s *Database) SendDM(ctx context.Context, fromID, fromNick, toID, content string) (string, error) {
|
||||||
res, err := s.db.ExecContext(ctx,
|
now := time.Now()
|
||||||
"INSERT INTO messages (user_id, content, is_dm, dm_target_id, created_at) VALUES (?, ?, 1, ?, ?)",
|
id := fmt.Sprintf("msg-%d", now.UnixNano())
|
||||||
fromID, content, toID, time.Now())
|
|
||||||
|
_, err := s.db.ExecContext(ctx,
|
||||||
|
`INSERT INTO messages (id, ts, from_user_id, from_nick, target, type, body, meta, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, 'message', ?, '{}', ?)`,
|
||||||
|
id, now, fromID, fromNick, toID, content, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return "", err
|
||||||
}
|
}
|
||||||
return res.LastInsertId()
|
|
||||||
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDMs returns direct messages between two users after a given ID.
|
// PollMessages returns all new messages for a user's joined channels, ordered by timestamp.
|
||||||
func (s *Database) GetDMs(ctx context.Context, userA, userB int64, afterID int64, limit int) ([]MessageInfo, error) {
|
func (s *Database) PollMessages(ctx context.Context, userID string, afterTS string, limit int) ([]MessageInfo, error) {
|
||||||
if limit <= 0 {
|
|
||||||
limit = 50
|
|
||||||
}
|
|
||||||
rows, err := s.db.QueryContext(ctx,
|
|
||||||
`SELECT m.id, u.nick, m.content, t.nick, m.created_at
|
|
||||||
FROM messages m
|
|
||||||
INNER JOIN users u ON u.id = m.user_id
|
|
||||||
INNER JOIN users t ON t.id = m.dm_target_id
|
|
||||||
WHERE m.is_dm = 1 AND m.id > ?
|
|
||||||
AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
|
|
||||||
ORDER BY m.id ASC LIMIT ?`, afterID, userA, userB, userB, userA, limit)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var msgs []MessageInfo
|
|
||||||
for rows.Next() {
|
|
||||||
var m MessageInfo
|
|
||||||
if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.IsDM = true
|
|
||||||
msgs = append(msgs, m)
|
|
||||||
}
|
|
||||||
if msgs == nil {
|
|
||||||
msgs = []MessageInfo{}
|
|
||||||
}
|
|
||||||
return msgs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PollMessages returns all new messages (channel + DM) for a user after a given ID.
|
|
||||||
func (s *Database) PollMessages(ctx context.Context, userID int64, afterID int64, limit int) ([]MessageInfo, error) {
|
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 100
|
limit = 100
|
||||||
}
|
}
|
||||||
rows, err := s.db.QueryContext(ctx,
|
|
||||||
`SELECT m.id, COALESCE(c.name, ''), u.nick, m.content, m.is_dm, COALESCE(t.nick, ''), m.created_at
|
|
||||||
FROM messages m
|
|
||||||
INNER JOIN users u ON u.id = m.user_id
|
|
||||||
LEFT JOIN channels c ON c.id = m.channel_id
|
|
||||||
LEFT JOIN users t ON t.id = m.dm_target_id
|
|
||||||
WHERE m.id > ? AND (
|
|
||||||
(m.is_dm = 0 AND m.channel_id IN (SELECT channel_id FROM channel_members WHERE user_id = ?))
|
|
||||||
OR (m.is_dm = 1 AND (m.user_id = ? OR m.dm_target_id = ?))
|
|
||||||
)
|
|
||||||
ORDER BY m.id ASC LIMIT ?`, afterID, userID, userID, userID, limit)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var msgs []MessageInfo
|
|
||||||
for rows.Next() {
|
|
||||||
var m MessageInfo
|
|
||||||
var isDM int
|
|
||||||
if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &isDM, &m.DMTarget, &m.CreatedAt); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m.IsDM = isDM == 1
|
|
||||||
msgs = append(msgs, m)
|
|
||||||
}
|
|
||||||
if msgs == nil {
|
|
||||||
msgs = []MessageInfo{}
|
|
||||||
}
|
|
||||||
return msgs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMessagesBefore returns channel messages before a given ID (for history scrollback).
|
rows, err := s.db.QueryContext(ctx,
|
||||||
func (s *Database) GetMessagesBefore(ctx context.Context, channelID int64, beforeID int64, limit int) ([]MessageInfo, error) {
|
`SELECT m.id, m.target, m.from_nick, m.body, m.created_at
|
||||||
if limit <= 0 {
|
FROM messages m
|
||||||
limit = 50
|
WHERE m.created_at > COALESCE(NULLIF(?, ''), '1970-01-01')
|
||||||
}
|
AND (
|
||||||
var query string
|
m.target IN (SELECT cm.channel_id FROM channel_members cm WHERE cm.user_id = ?)
|
||||||
var args []any
|
OR m.target = ?
|
||||||
if beforeID > 0 {
|
OR m.from_user_id = ?
|
||||||
query = `SELECT m.id, c.name, u.nick, m.content, m.created_at
|
)
|
||||||
FROM messages m
|
ORDER BY m.created_at ASC LIMIT ?`,
|
||||||
INNER JOIN users u ON u.id = m.user_id
|
afterTS, userID, userID, userID, limit)
|
||||||
INNER JOIN channels c ON c.id = m.channel_id
|
|
||||||
WHERE m.channel_id = ? AND m.is_dm = 0 AND m.id < ?
|
|
||||||
ORDER BY m.id DESC LIMIT ?`
|
|
||||||
args = []any{channelID, beforeID, limit}
|
|
||||||
} else {
|
|
||||||
query = `SELECT m.id, c.name, u.nick, m.content, m.created_at
|
|
||||||
FROM messages m
|
|
||||||
INNER JOIN users u ON u.id = m.user_id
|
|
||||||
INNER JOIN channels c ON c.id = m.channel_id
|
|
||||||
WHERE m.channel_id = ? AND m.is_dm = 0
|
|
||||||
ORDER BY m.id DESC LIMIT ?`
|
|
||||||
args = []any{channelID, limit}
|
|
||||||
}
|
|
||||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
var msgs []MessageInfo
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
msgs := []MessageInfo{}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var m MessageInfo
|
var m MessageInfo
|
||||||
if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil {
|
if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs = append(msgs, m)
|
msgs = append(msgs, m)
|
||||||
}
|
}
|
||||||
if msgs == nil {
|
|
||||||
msgs = []MessageInfo{}
|
return msgs, rows.Err()
|
||||||
}
|
|
||||||
// Reverse to ascending order
|
|
||||||
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
|
||||||
msgs[i], msgs[j] = msgs[j], msgs[i]
|
|
||||||
}
|
|
||||||
return msgs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDMsBefore returns DMs between two users before a given ID (for history scrollback).
|
// GetMessagesBefore returns channel messages before a given timestamp (for history scrollback).
|
||||||
func (s *Database) GetDMsBefore(ctx context.Context, userA, userB int64, beforeID int64, limit int) ([]MessageInfo, error) {
|
func (s *Database) GetMessagesBefore(ctx context.Context, target string, beforeTS string, limit int) ([]MessageInfo, error) {
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 50
|
limit = 50
|
||||||
}
|
}
|
||||||
var query string
|
|
||||||
var args []any
|
var rows interface {
|
||||||
if beforeID > 0 {
|
Next() bool
|
||||||
query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at
|
Scan(dest ...interface{}) error
|
||||||
FROM messages m
|
Close() error
|
||||||
INNER JOIN users u ON u.id = m.user_id
|
Err() error
|
||||||
INNER JOIN users t ON t.id = m.dm_target_id
|
|
||||||
WHERE m.is_dm = 1 AND m.id < ?
|
|
||||||
AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
|
|
||||||
ORDER BY m.id DESC LIMIT ?`
|
|
||||||
args = []any{beforeID, userA, userB, userB, userA, limit}
|
|
||||||
} else {
|
|
||||||
query = `SELECT m.id, u.nick, m.content, t.nick, m.created_at
|
|
||||||
FROM messages m
|
|
||||||
INNER JOIN users u ON u.id = m.user_id
|
|
||||||
INNER JOIN users t ON t.id = m.dm_target_id
|
|
||||||
WHERE m.is_dm = 1
|
|
||||||
AND ((m.user_id = ? AND m.dm_target_id = ?) OR (m.user_id = ? AND m.dm_target_id = ?))
|
|
||||||
ORDER BY m.id DESC LIMIT ?`
|
|
||||||
args = []any{userA, userB, userB, userA, limit}
|
|
||||||
}
|
}
|
||||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if beforeTS != "" {
|
||||||
|
rows, err = s.db.QueryContext(ctx,
|
||||||
|
`SELECT m.id, m.target, m.from_nick, m.body, m.created_at
|
||||||
|
FROM messages m
|
||||||
|
WHERE m.target = ? AND m.created_at < ?
|
||||||
|
ORDER BY m.created_at DESC LIMIT ?`,
|
||||||
|
target, beforeTS, limit)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx,
|
||||||
|
`SELECT m.id, m.target, m.from_nick, m.body, m.created_at
|
||||||
|
FROM messages m
|
||||||
|
WHERE m.target = ?
|
||||||
|
ORDER BY m.created_at DESC LIMIT ?`,
|
||||||
|
target, limit)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
var msgs []MessageInfo
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
msgs := []MessageInfo{}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var m MessageInfo
|
||||||
|
if err := rows.Scan(&m.ID, &m.Channel, &m.Nick, &m.Content, &m.CreatedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs = append(msgs, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse to ascending order.
|
||||||
|
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
msgs[i], msgs[j] = msgs[j], msgs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgs, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDMsBefore returns DMs between two users before a given timestamp.
|
||||||
|
func (s *Database) GetDMsBefore(ctx context.Context, userA, userB string, beforeTS string, limit int) ([]MessageInfo, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
var rows interface {
|
||||||
|
Next() bool
|
||||||
|
Scan(dest ...interface{}) error
|
||||||
|
Close() error
|
||||||
|
Err() error
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if beforeTS != "" {
|
||||||
|
rows, err = s.db.QueryContext(ctx,
|
||||||
|
`SELECT m.id, m.from_nick, m.body, m.target, m.created_at
|
||||||
|
FROM messages m
|
||||||
|
WHERE m.created_at < ?
|
||||||
|
AND ((m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?))
|
||||||
|
ORDER BY m.created_at DESC LIMIT ?`,
|
||||||
|
beforeTS, userA, userB, userB, userA, limit)
|
||||||
|
} else {
|
||||||
|
rows, err = s.db.QueryContext(ctx,
|
||||||
|
`SELECT m.id, m.from_nick, m.body, m.target, m.created_at
|
||||||
|
FROM messages m
|
||||||
|
WHERE (m.from_user_id = ? AND m.target = ?) OR (m.from_user_id = ? AND m.target = ?)
|
||||||
|
ORDER BY m.created_at DESC LIMIT ?`,
|
||||||
|
userA, userB, userB, userA, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
msgs := []MessageInfo{}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var m MessageInfo
|
var m MessageInfo
|
||||||
if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil {
|
if err := rows.Scan(&m.ID, &m.Nick, &m.Content, &m.DMTarget, &m.CreatedAt); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.IsDM = true
|
m.IsDM = true
|
||||||
msgs = append(msgs, m)
|
msgs = append(msgs, m)
|
||||||
}
|
}
|
||||||
if msgs == nil {
|
|
||||||
msgs = []MessageInfo{}
|
// Reverse to ascending order.
|
||||||
}
|
|
||||||
// Reverse to ascending order
|
|
||||||
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
||||||
msgs[i], msgs[j] = msgs[j], msgs[i]
|
msgs[i], msgs[j] = msgs[j], msgs[i]
|
||||||
}
|
}
|
||||||
return msgs, nil
|
|
||||||
|
return msgs, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChangeNick updates a user's nickname.
|
// ChangeNick updates a user's nickname.
|
||||||
func (s *Database) ChangeNick(ctx context.Context, userID int64, newNick string) error {
|
func (s *Database) ChangeNick(ctx context.Context, userID string, newNick string) error {
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
"UPDATE users SET nick = ? WHERE id = ?", newNick, userID)
|
"UPDATE users SET nick = ? WHERE id = ?", newNick, userID)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTopic sets the topic for a channel.
|
// SetTopic sets the topic for a channel.
|
||||||
func (s *Database) SetTopic(ctx context.Context, channelName string, _ int64, topic string) error {
|
func (s *Database) SetTopic(ctx context.Context, channelName string, _ string, topic string) error {
|
||||||
_, err := s.db.ExecContext(ctx,
|
_, err := s.db.ExecContext(ctx,
|
||||||
"UPDATE channels SET topic = ? WHERE name = ?", topic, channelName)
|
"UPDATE channels SET topic = ? WHERE name = ?", topic, channelName)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -398,17 +334,19 @@ func (s *Database) ListAllChannels(ctx context.Context) ([]ChannelInfo, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
var channels []ChannelInfo
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
channels := []ChannelInfo{}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch ChannelInfo
|
var ch ChannelInfo
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Topic); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
}
|
}
|
||||||
if channels == nil {
|
|
||||||
channels = []ChannelInfo{}
|
return channels, rows.Err()
|
||||||
}
|
|
||||||
return channels, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -12,77 +14,114 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// authUser extracts the user from the Authorization header (Bearer token).
|
// authUser extracts the user from the Authorization header (Bearer token).
|
||||||
func (s *Handlers) authUser(r *http.Request) (int64, string, error) {
|
func (s *Handlers) authUser(r *http.Request) (string, string, error) {
|
||||||
auth := r.Header.Get("Authorization")
|
auth := r.Header.Get("Authorization")
|
||||||
if !strings.HasPrefix(auth, "Bearer ") {
|
if !strings.HasPrefix(auth, "Bearer ") {
|
||||||
return 0, "", sql.ErrNoRows
|
return "", "", sql.ErrNoRows
|
||||||
}
|
}
|
||||||
|
|
||||||
token := strings.TrimPrefix(auth, "Bearer ")
|
token := strings.TrimPrefix(auth, "Bearer ")
|
||||||
return s.params.Database.GetUserByToken(r.Context(), token)
|
|
||||||
|
u, err := s.params.Database.GetUserByToken(r.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.ID, u.Nick, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (int64, string, bool) {
|
func (s *Handlers) requireAuth(w http.ResponseWriter, r *http.Request) (string, string, bool) {
|
||||||
uid, nick, err := s.authUser(r)
|
uid, nick, err := s.authUser(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized)
|
s.respondJSON(w, r, map[string]string{"error": "unauthorized"}, http.StatusUnauthorized)
|
||||||
return 0, "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
return uid, nick, true
|
return uid, nick, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateID() string {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
// HandleCreateSession creates a new user session and returns the auth token.
|
// HandleCreateSession creates a new user session and returns the auth token.
|
||||||
func (s *Handlers) HandleCreateSession() http.HandlerFunc {
|
func (s *Handlers) HandleCreateSession() http.HandlerFunc {
|
||||||
type request struct {
|
type request struct {
|
||||||
Nick string `json:"nick"`
|
Nick string `json:"nick"`
|
||||||
}
|
}
|
||||||
type response struct {
|
type response struct {
|
||||||
ID int64 `json:"id"`
|
ID string `json:"id"`
|
||||||
Nick string `json:"nick"`
|
Nick string `json:"nick"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
var req request
|
var req request
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
|
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Nick = strings.TrimSpace(req.Nick)
|
req.Nick = strings.TrimSpace(req.Nick)
|
||||||
if req.Nick == "" || len(req.Nick) > 32 {
|
if req.Nick == "" || len(req.Nick) > 32 {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
|
s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id, token, err := s.params.Database.CreateUser(r.Context(), req.Nick)
|
|
||||||
|
id := generateID()
|
||||||
|
|
||||||
|
u, err := s.params.Database.CreateUser(r.Context(), id, req.Nick, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "UNIQUE") {
|
if strings.Contains(err.Error(), "UNIQUE") {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict)
|
s.respondJSON(w, r, map[string]string{"error": "nick already taken"}, http.StatusConflict)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.log.Error("create user failed", "error", err)
|
s.log.Error("create user failed", "error", err)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.respondJSON(w, r, &response{ID: id, Nick: req.Nick, Token: token}, http.StatusCreated)
|
|
||||||
|
tokenStr := generateID()
|
||||||
|
|
||||||
|
_, err = s.params.Database.CreateAuthToken(r.Context(), tokenStr, u.ID)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error("create auth token failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.respondJSON(w, r, &response{ID: u.ID, Nick: req.Nick, Token: tokenStr}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleState returns the current user's info and joined channels.
|
// HandleState returns the current user's info and joined channels.
|
||||||
func (s *Handlers) HandleState() http.HandlerFunc {
|
func (s *Handlers) HandleState() http.HandlerFunc {
|
||||||
type response struct {
|
type response struct {
|
||||||
ID int64 `json:"id"`
|
ID string `json:"id"`
|
||||||
Nick string `json:"nick"`
|
Nick string `json:"nick"`
|
||||||
Channels []db.ChannelInfo `json:"channels"`
|
Channels []db.ChannelInfo `json:"channels"`
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
uid, nick, ok := s.requireAuth(w, r)
|
uid, nick, ok := s.requireAuth(w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
channels, err := s.params.Database.ListChannels(r.Context(), uid)
|
channels, err := s.params.Database.ListChannels(r.Context(), uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("list channels failed", "error", err)
|
s.log.Error("list channels failed", "error", err)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK)
|
s.respondJSON(w, r, &response{ID: uid, Nick: nick, Channels: channels}, http.StatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -94,12 +133,15 @@ func (s *Handlers) HandleListAllChannels() http.HandlerFunc {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
channels, err := s.params.Database.ListAllChannels(r.Context())
|
channels, err := s.params.Database.ListAllChannels(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("list all channels failed", "error", err)
|
s.log.Error("list all channels failed", "error", err)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.respondJSON(w, r, channels, http.StatusOK)
|
s.respondJSON(w, r, channels, http.StatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -111,20 +153,26 @@ func (s *Handlers) HandleChannelMembers() http.HandlerFunc {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name := "#" + chi.URLParam(r, "channel")
|
name := "#" + chi.URLParam(r, "channel")
|
||||||
var chID int64
|
|
||||||
|
var chID string
|
||||||
|
|
||||||
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
||||||
"SELECT id FROM channels WHERE name = ?", name).Scan(&chID)
|
"SELECT id FROM channels WHERE name = ?", name).Scan(&chID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
members, err := s.params.Database.ChannelMembers(r.Context(), chID)
|
members, err := s.params.Database.ChannelMembers(r.Context(), chID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("channel members failed", "error", err)
|
s.log.Error("channel members failed", "error", err)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.respondJSON(w, r, members, http.StatusOK)
|
s.respondJSON(w, r, members, http.StatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,14 +185,18 @@ func (s *Handlers) HandleGetMessages() http.HandlerFunc {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
afterID, _ := strconv.ParseInt(r.URL.Query().Get("after"), 10, 64)
|
|
||||||
|
afterTS := r.URL.Query().Get("after")
|
||||||
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
||||||
msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterID, limit)
|
|
||||||
|
msgs, err := s.params.Database.PollMessages(r.Context(), uid, afterTS, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("get messages failed", "error", err)
|
s.log.Error("get messages failed", "error", err)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.respondJSON(w, r, msgs, http.StatusOK)
|
s.respondJSON(w, r, msgs, http.StatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,16 +210,19 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc {
|
|||||||
Params []string `json:"params,omitempty"`
|
Params []string `json:"params,omitempty"`
|
||||||
Body interface{} `json:"body,omitempty"`
|
Body interface{} `json:"body,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
uid, nick, ok := s.requireAuth(w, r)
|
uid, nick, ok := s.requireAuth(w, r)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req request
|
var req request
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
|
s.respondJSON(w, r, map[string]string{"error": "invalid request"}, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Command = strings.ToUpper(strings.TrimSpace(req.Command))
|
req.Command = strings.ToUpper(strings.TrimSpace(req.Command))
|
||||||
req.To = strings.TrimSpace(req.To)
|
req.To = strings.TrimSpace(req.To)
|
||||||
|
|
||||||
@@ -176,11 +231,13 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc {
|
|||||||
switch v := req.Body.(type) {
|
switch v := req.Body.(type) {
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
lines := make([]string, 0, len(v))
|
lines := make([]string, 0, len(v))
|
||||||
|
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if s, ok := item.(string); ok {
|
if str, ok := item.(string); ok {
|
||||||
lines = append(lines, s)
|
lines = append(lines, str)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
case []string:
|
case []string:
|
||||||
return v
|
return v
|
||||||
@@ -191,137 +248,19 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc {
|
|||||||
|
|
||||||
switch req.Command {
|
switch req.Command {
|
||||||
case "PRIVMSG", "NOTICE":
|
case "PRIVMSG", "NOTICE":
|
||||||
if req.To == "" {
|
s.handlePrivmsg(w, r, uid, nick, req.To, bodyLines())
|
||||||
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lines := bodyLines()
|
|
||||||
if len(lines) == 0 {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
content := strings.Join(lines, "\n")
|
|
||||||
|
|
||||||
if strings.HasPrefix(req.To, "#") {
|
|
||||||
// Channel message
|
|
||||||
var chID int64
|
|
||||||
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
|
||||||
"SELECT id FROM channels WHERE name = ?", req.To).Scan(&chID)
|
|
||||||
if err != nil {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, content)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error("send message failed", "error", err)
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
|
|
||||||
} else {
|
|
||||||
// DM
|
|
||||||
targetID, err := s.params.Database.GetUserByNick(r.Context(), req.To)
|
|
||||||
if err != nil {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
msgID, err := s.params.Database.SendDM(r.Context(), uid, targetID, content)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error("send dm failed", "error", err)
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "JOIN":
|
case "JOIN":
|
||||||
if req.To == "" {
|
s.handleJoin(w, r, uid, req.To)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
channel := req.To
|
|
||||||
if !strings.HasPrefix(channel, "#") {
|
|
||||||
channel = "#" + channel
|
|
||||||
}
|
|
||||||
chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel)
|
|
||||||
if err != nil {
|
|
||||||
s.log.Error("get/create channel failed", "error", err)
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil {
|
|
||||||
s.log.Error("join channel failed", "error", err)
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK)
|
|
||||||
|
|
||||||
case "PART":
|
case "PART":
|
||||||
if req.To == "" {
|
s.handlePart(w, r, uid, req.To)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
channel := req.To
|
|
||||||
if !strings.HasPrefix(channel, "#") {
|
|
||||||
channel = "#" + channel
|
|
||||||
}
|
|
||||||
var chID int64
|
|
||||||
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
|
||||||
"SELECT id FROM channels WHERE name = ?", channel).Scan(&chID)
|
|
||||||
if err != nil {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil {
|
|
||||||
s.log.Error("part channel failed", "error", err)
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK)
|
|
||||||
|
|
||||||
case "NICK":
|
case "NICK":
|
||||||
lines := bodyLines()
|
s.handleNick(w, r, uid, bodyLines())
|
||||||
if len(lines) == 0 {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
newNick := strings.TrimSpace(lines[0])
|
|
||||||
if newNick == "" || len(newNick) > 32 {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil {
|
|
||||||
if strings.Contains(err.Error(), "UNIQUE") {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.log.Error("change nick failed", "error", err)
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK)
|
|
||||||
|
|
||||||
case "TOPIC":
|
case "TOPIC":
|
||||||
if req.To == "" {
|
s.handleTopic(w, r, uid, req.To, bodyLines())
|
||||||
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lines := bodyLines()
|
|
||||||
if len(lines) == 0 {
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
topic := strings.Join(lines, " ")
|
|
||||||
channel := req.To
|
|
||||||
if !strings.HasPrefix(channel, "#") {
|
|
||||||
channel = "#" + channel
|
|
||||||
}
|
|
||||||
if err := s.params.Database.SetTopic(r.Context(), channel, uid, topic); err != nil {
|
|
||||||
s.log.Error("set topic failed", "error", err)
|
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK)
|
|
||||||
|
|
||||||
case "PING":
|
case "PING":
|
||||||
s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK)
|
s.respondJSON(w, r, map[string]string{"command": "PONG", "from": s.params.Config.ServerName}, http.StatusOK)
|
||||||
@@ -333,6 +272,173 @@ func (s *Handlers) HandleSendCommand() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Handlers) handlePrivmsg(w http.ResponseWriter, r *http.Request, uid, nick, to string, lines []string) {
|
||||||
|
if to == "" {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(lines) == 0 {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "body required"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content := strings.Join(lines, "\n")
|
||||||
|
|
||||||
|
if strings.HasPrefix(to, "#") {
|
||||||
|
// Channel message.
|
||||||
|
var chID string
|
||||||
|
|
||||||
|
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
||||||
|
"SELECT id FROM channels WHERE name = ?", to).Scan(&chID)
|
||||||
|
if err != nil {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msgID, err := s.params.Database.SendMessage(r.Context(), chID, uid, nick, content)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error("send message failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
|
||||||
|
} else {
|
||||||
|
// DM.
|
||||||
|
targetUser, err := s.params.Database.GetUserByNick(r.Context(), to)
|
||||||
|
if err != nil {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msgID, err := s.params.Database.SendDM(r.Context(), uid, nick, targetUser.ID, content)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error("send dm failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.respondJSON(w, r, map[string]any{"id": msgID, "status": "sent"}, http.StatusCreated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Handlers) handleJoin(w http.ResponseWriter, r *http.Request, uid, to string) {
|
||||||
|
if to == "" {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channel := to
|
||||||
|
if !strings.HasPrefix(channel, "#") {
|
||||||
|
channel = "#" + channel
|
||||||
|
}
|
||||||
|
|
||||||
|
chID, err := s.params.Database.GetOrCreateChannel(r.Context(), channel)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Error("get/create channel failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.params.Database.JoinChannel(r.Context(), chID, uid); err != nil {
|
||||||
|
s.log.Error("join channel failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.respondJSON(w, r, map[string]string{"status": "joined", "channel": channel}, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Handlers) handlePart(w http.ResponseWriter, r *http.Request, uid, to string) {
|
||||||
|
if to == "" {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channel := to
|
||||||
|
if !strings.HasPrefix(channel, "#") {
|
||||||
|
channel = "#" + channel
|
||||||
|
}
|
||||||
|
|
||||||
|
var chID string
|
||||||
|
|
||||||
|
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
||||||
|
"SELECT id FROM channels WHERE name = ?", channel).Scan(&chID)
|
||||||
|
if err != nil {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.params.Database.PartChannel(r.Context(), chID, uid); err != nil {
|
||||||
|
s.log.Error("part channel failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.respondJSON(w, r, map[string]string{"status": "parted", "channel": channel}, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Handlers) handleNick(w http.ResponseWriter, r *http.Request, uid string, lines []string) {
|
||||||
|
if len(lines) == 0 {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "body required (new nick)"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newNick := strings.TrimSpace(lines[0])
|
||||||
|
if newNick == "" || len(newNick) > 32 {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "nick must be 1-32 characters"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.params.Database.ChangeNick(r.Context(), uid, newNick); err != nil {
|
||||||
|
if strings.Contains(err.Error(), "UNIQUE") {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "nick already in use"}, http.StatusConflict)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.Error("change nick failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.respondJSON(w, r, map[string]string{"status": "ok", "nick": newNick}, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Handlers) handleTopic(w http.ResponseWriter, r *http.Request, uid, to string, lines []string) {
|
||||||
|
if to == "" {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "to field required"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(lines) == 0 {
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "body required (topic text)"}, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topic := strings.Join(lines, " ")
|
||||||
|
|
||||||
|
channel := to
|
||||||
|
if !strings.HasPrefix(channel, "#") {
|
||||||
|
channel = "#" + channel
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.params.Database.SetTopic(r.Context(), channel, uid, topic); err != nil {
|
||||||
|
s.log.Error("set topic failed", "error", err)
|
||||||
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.respondJSON(w, r, map[string]string{"status": "ok", "topic": topic}, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
// HandleGetHistory returns message history for a specific target (channel or DM).
|
// HandleGetHistory returns message history for a specific target (channel or DM).
|
||||||
func (s *Handlers) HandleGetHistory() http.HandlerFunc {
|
func (s *Handlers) HandleGetHistory() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -340,46 +446,56 @@ func (s *Handlers) HandleGetHistory() http.HandlerFunc {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
target := r.URL.Query().Get("target")
|
target := r.URL.Query().Get("target")
|
||||||
if target == "" {
|
if target == "" {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest)
|
s.respondJSON(w, r, map[string]string{"error": "target required"}, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
beforeID, _ := strconv.ParseInt(r.URL.Query().Get("before"), 10, 64)
|
|
||||||
|
beforeTS := r.URL.Query().Get("before")
|
||||||
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
||||||
|
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 50
|
limit = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(target, "#") {
|
if strings.HasPrefix(target, "#") {
|
||||||
// Channel history
|
// Channel history — look up channel by name to get its ID for target matching.
|
||||||
var chID int64
|
var chID string
|
||||||
|
|
||||||
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
err := s.params.Database.GetDB().QueryRowContext(r.Context(),
|
||||||
"SELECT id FROM channels WHERE name = ?", target).Scan(&chID)
|
"SELECT id FROM channels WHERE name = ?", target).Scan(&chID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
s.respondJSON(w, r, map[string]string{"error": "channel not found"}, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeID, limit)
|
|
||||||
|
msgs, err := s.params.Database.GetMessagesBefore(r.Context(), chID, beforeTS, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("get history failed", "error", err)
|
s.log.Error("get history failed", "error", err)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.respondJSON(w, r, msgs, http.StatusOK)
|
s.respondJSON(w, r, msgs, http.StatusOK)
|
||||||
} else {
|
} else {
|
||||||
// DM history
|
// DM history.
|
||||||
targetID, err := s.params.Database.GetUserByNick(r.Context(), target)
|
targetUser, err := s.params.Database.GetUserByNick(r.Context(), target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
|
s.respondJSON(w, r, map[string]string{"error": "user not found"}, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetID, beforeID, limit)
|
|
||||||
|
msgs, err := s.params.Database.GetDMsBefore(r.Context(), uid, targetUser.ID, beforeTS, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("get dm history failed", "error", err)
|
s.log.Error("get dm history failed", "error", err)
|
||||||
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
s.respondJSON(w, r, map[string]string{"error": "internal error"}, http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.respondJSON(w, r, msgs, http.StatusOK)
|
s.respondJSON(w, r, msgs, http.StatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -391,6 +507,7 @@ func (s *Handlers) HandleServerInfo() http.HandlerFunc {
|
|||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
MOTD string `json:"motd"`
|
MOTD string `json:"motd"`
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
s.respondJSON(w, r, &response{
|
s.respondJSON(w, r, &response{
|
||||||
Name: s.params.Config.ServerName,
|
Name: s.params.Config.ServerName,
|
||||||
|
|||||||
Reference in New Issue
Block a user