Files
chat/internal/db/queries_test.go
clawbot f9c145ad09
All checks were successful
check / check (push) Successful in 58s
refactor: replace HTTP error codes with IRC numeric replies for all IRC commands
IRC commands (PRIVMSG, JOIN, PART, NICK, TOPIC, etc.) now respond with
proper IRC numeric replies delivered through the message queue instead of
HTTP status codes. HTTP error codes are now reserved exclusively for
transport-level concerns: auth failures (401), malformed requests (400),
and server errors (500).

Changes:
- Add params column to messages table for IRC-style parameters
- Add Params field to IRCMessage struct and update all queries
- Add respondIRCError helper for consistent IRC error delivery
- Add RPL_WELCOME (001) on session creation and login
- Add RPL_TOPIC/RPL_NOTOPIC (332/331), RPL_NAMREPLY (353),
  RPL_ENDOFNAMES (366) on JOIN
- Add RPL_TOPIC (332) on TOPIC set
- Replace HTTP 404 with ERR_NOSUCHCHANNEL (403) and ERR_NOSUCHNICK (401)
- Replace HTTP 409 with ERR_NICKNAMEINUSE (433)
- Replace HTTP 403 with ERR_NOTONCHANNEL (442)
- Replace HTTP 400 with ERR_NEEDMOREPARAMS (461), ERR_ERRONEUSNICKNAME (432),
  and ERR_UNKNOWNCOMMAND (421) where appropriate
- Change PRIVMSG/NOTICE success from HTTP 201 to HTTP 200
- Update all tests to verify IRC numerics in message queue
- Add new tests for RPL_WELCOME and JOIN numerics
- Update README to document new numeric reply behavior

closes #54
2026-03-08 01:32:02 -08:00

654 lines
11 KiB
Go

package db_test
import (
"encoding/json"
"testing"
"git.eeqj.de/sneak/neoirc/internal/db"
_ "modernc.org/sqlite"
)
// setupTestDB opens a fresh database via db.NewTestDatabase and registers a
// cleanup hook that closes it when the test (and its subtests) finish.
// Failing to open aborts the test; failing to close is only logged.
func setupTestDB(t *testing.T) *db.Database {
	t.Helper()

	testDB, openErr := db.NewTestDatabase()
	if openErr != nil {
		t.Fatal(openErr)
	}

	t.Cleanup(func() {
		if err := testDB.Close(); err != nil {
			t.Logf("close db: %v", err)
		}
	})

	return testDB
}
// TestCreateSession verifies that CreateSession returns non-zero session and
// client IDs plus a non-empty token for a new nick, and that a second
// CreateSession with the same nick fails.
func TestCreateSession(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	sessionID, clientID, token, err := database.CreateSession(
		ctx, "alice",
	)
	if err != nil {
		t.Fatal(err)
	}

	if sessionID == 0 || clientID == 0 || token == "" {
		t.Fatalf(
			"expected valid ids and token, got session=%d client=%d token=%q",
			sessionID, clientID, token,
		)
	}

	// Only the error matters for the duplicate attempt.
	_, _, _, dupErr := database.CreateSession(ctx, "alice")
	if dupErr == nil {
		t.Fatal("expected error for duplicate nick")
	}
}
// TestGetSessionByToken checks that a valid token resolves to the session's
// nick with non-zero session and client IDs, and that an unknown token yields
// an error together with zero values for every other result.
func TestGetSessionByToken(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	_, _, token, err := database.CreateSession(ctx, "bob")
	if err != nil {
		t.Fatal(err)
	}

	sessionID, clientID, nick, err :=
		database.GetSessionByToken(ctx, token)
	if err != nil {
		t.Fatal(err)
	}

	if nick != "bob" || sessionID == 0 || clientID == 0 {
		// Report every value: nick may be right while an ID is zero.
		t.Fatalf(
			"expected bob with non-zero ids, got nick=%q session=%d client=%d",
			nick, sessionID, clientID,
		)
	}

	badSID, badCID, badNick, badErr :=
		database.GetSessionByToken(ctx, "badtoken")
	if badErr == nil {
		t.Fatal("expected error for bad token")
	}

	if badSID != 0 || badCID != 0 || badNick != "" {
		t.Fatalf(
			"expected zero values on error, got session=%d client=%d nick=%q",
			badSID, badCID, badNick,
		)
	}
}
// TestGetSessionByNick creates a session for "charlie", checks that
// GetSessionByNick finds it with a non-zero ID, and checks that an unknown
// nick produces an error.
func TestGetSessionByNick(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	charlieID, charlieClientID, charlieToken, err :=
		database.CreateSession(ctx, "charlie")
	if err != nil {
		t.Fatal(err)
	}

	if charlieID == 0 || charlieClientID == 0 {
		t.Fatal("expected valid session/client IDs")
	}

	if charlieToken == "" {
		t.Fatal("expected non-empty token")
	}

	id, err := database.GetSessionByNick(ctx, "charlie")
	if err != nil {
		// Surface the underlying error instead of a generic message.
		t.Fatalf("expected to find charlie: %v", err)
	}

	if id == 0 {
		t.Fatal("expected non-zero id for charlie")
	}

	_, err = database.GetSessionByNick(ctx, "nobody")
	if err == nil {
		t.Fatal("expected error for unknown nick")
	}
}
// TestChannelOperations checks that GetOrCreateChannel returns a stable
// non-zero ID for the same name, that GetChannelByName resolves to that same
// ID, and that looking up an unknown channel fails.
func TestChannelOperations(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	chID, err := database.GetOrCreateChannel(ctx, "#test")
	if err != nil {
		t.Fatalf("create #test: %v", err)
	}

	if chID == 0 {
		t.Fatal("expected channel id")
	}

	chID2, err := database.GetOrCreateChannel(ctx, "#test")
	if err != nil {
		t.Fatalf("get-or-create existing #test: %v", err)
	}

	if chID2 != chID {
		t.Fatalf(
			"GetOrCreateChannel: expected same channel id %d, got %d",
			chID, chID2,
		)
	}

	chID3, err := database.GetChannelByName(ctx, "#test")
	if err != nil {
		t.Fatalf("get #test by name: %v", err)
	}

	if chID3 != chID {
		t.Fatalf(
			"GetChannelByName: expected same channel id %d, got %d",
			chID, chID3,
		)
	}

	_, err = database.GetChannelByName(ctx, "#nope")
	if err == nil {
		t.Fatal("expected error for nonexistent channel")
	}
}
// TestJoinAndPart checks that joining puts the session in the channel's
// member list, that a repeated join returns no error, and that parting leaves
// the channel with no members.
func TestJoinAndPart(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	sid, _, _, err := database.CreateSession(ctx, "user1")
	if err != nil {
		t.Fatal(err)
	}

	chID, err := database.GetOrCreateChannel(ctx, "#chan")
	if err != nil {
		t.Fatal(err)
	}

	err = database.JoinChannel(ctx, chID, sid)
	if err != nil {
		t.Fatal(err)
	}

	ids, err := database.GetChannelMemberIDs(ctx, chID)
	if err != nil {
		t.Fatalf("get members after join: %v", err)
	}

	if len(ids) != 1 || ids[0] != sid {
		t.Fatalf("expected only session %d in channel, got %v", sid, ids)
	}

	// Joining again must not return an error.
	err = database.JoinChannel(ctx, chID, sid)
	if err != nil {
		t.Fatal(err)
	}

	err = database.PartChannel(ctx, chID, sid)
	if err != nil {
		t.Fatal(err)
	}

	// Check the error: a failed query would otherwise return an empty slice
	// and make the emptiness assertion below pass vacuously.
	ids, err = database.GetChannelMemberIDs(ctx, chID)
	if err != nil {
		t.Fatalf("get members after part: %v", err)
	}

	if len(ids) != 0 {
		t.Fatalf("expected empty channel, got %v", ids)
	}
}
// TestDeleteChannelIfEmpty has a session join and then part "#empty", calls
// DeleteChannelIfEmpty, and expects a later name lookup for the channel to
// fail.
func TestDeleteChannelIfEmpty(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	chID, err := database.GetOrCreateChannel(ctx, "#empty")
	if err != nil {
		t.Fatal(err)
	}

	sid, _, _, err := database.CreateSession(ctx, "temp")
	if err != nil {
		t.Fatal(err)
	}

	if joinErr := database.JoinChannel(ctx, chID, sid); joinErr != nil {
		t.Fatal(joinErr)
	}

	if partErr := database.PartChannel(ctx, chID, sid); partErr != nil {
		t.Fatal(partErr)
	}

	if delErr := database.DeleteChannelIfEmpty(ctx, chID); delErr != nil {
		t.Fatal(delErr)
	}

	_, lookupErr := database.GetChannelByName(ctx, "#empty")
	if lookupErr == nil {
		t.Fatal("expected channel to be deleted")
	}
}
// createSessionWithChannels creates a session for nick, gets or creates the
// channels ch1Name and ch2Name (in that order), and joins the session to
// each. It returns the session ID followed by the two channel IDs, and fails
// the test on any database error.
func createSessionWithChannels(
	t *testing.T,
	database *db.Database,
	nick, ch1Name, ch2Name string,
) (int64, int64, int64) {
	t.Helper()

	ctx := t.Context()

	sessionID, _, _, err := database.CreateSession(ctx, nick)
	if err != nil {
		t.Fatal(err)
	}

	names := [...]string{ch1Name, ch2Name}

	var chIDs [len(names)]int64

	// Create every channel before joining any of them.
	for i, name := range names {
		id, chErr := database.GetOrCreateChannel(ctx, name)
		if chErr != nil {
			t.Fatal(chErr)
		}

		chIDs[i] = id
	}

	for _, id := range chIDs {
		joinErr := database.JoinChannel(ctx, id, sessionID)
		if joinErr != nil {
			t.Fatal(joinErr)
		}
	}

	return sessionID, chIDs[0], chIDs[1]
}
// TestListChannels joins a session to two channels and checks that
// ListChannels for that session returns two entries.
func TestListChannels(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)

	sid, _, _ := createSessionWithChannels(
		t, database, "lister", "#a", "#b",
	)

	channels, err := database.ListChannels(t.Context(), sid)
	if err != nil {
		// Report the error itself rather than a misleading "got 0".
		t.Fatalf("list channels: %v", err)
	}

	if len(channels) != 2 {
		t.Fatalf(
			"expected 2 channels, got %d",
			len(channels),
		)
	}
}
// TestListAllChannels creates two channels and checks that ListAllChannels
// returns at least two entries.
func TestListAllChannels(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	_, err := database.GetOrCreateChannel(ctx, "#x")
	if err != nil {
		t.Fatal(err)
	}

	_, err = database.GetOrCreateChannel(ctx, "#y")
	if err != nil {
		t.Fatal(err)
	}

	channels, err := database.ListAllChannels(ctx)
	if err != nil {
		// Report the error itself rather than a misleading "got 0".
		t.Fatalf("list all channels: %v", err)
	}

	if len(channels) < 2 {
		t.Fatalf(
			"expected >= 2 channels, got %d",
			len(channels),
		)
	}
}
// TestChangeNick renames a session from "old" to "new" and confirms that a
// subsequent lookup by the session's token reports the new nick.
func TestChangeNick(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	sid, _, token, err := database.CreateSession(ctx, "old")
	if err != nil {
		t.Fatal(err)
	}

	renameErr := database.ChangeNick(ctx, sid, "new")
	if renameErr != nil {
		t.Fatal(renameErr)
	}

	_, _, gotNick, err := database.GetSessionByToken(ctx, token)

	switch {
	case err != nil:
		t.Fatal(err)
	case gotNick != "new":
		t.Fatalf("expected new, got %s", gotNick)
	}
}
// TestSetTopic sets the topic of "#topictest" and checks, via
// ListAllChannels, that the channel is listed and carries the new topic.
func TestSetTopic(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	_, err := database.GetOrCreateChannel(
		ctx, "#topictest",
	)
	if err != nil {
		t.Fatal(err)
	}

	err = database.SetTopic(ctx, "#topictest", "Hello")
	if err != nil {
		t.Fatal(err)
	}

	channels, err := database.ListAllChannels(ctx)
	if err != nil {
		t.Fatal(err)
	}

	// Track whether the channel was seen at all; without this the test
	// would pass vacuously if the channel were missing from the listing.
	found := false

	for _, ch := range channels {
		if ch.Name != "#topictest" {
			continue
		}

		found = true

		if ch.Topic != "Hello" {
			t.Fatalf(
				"expected topic Hello, got %s",
				ch.Topic,
			)
		}
	}

	if !found {
		t.Fatal("channel #topictest not returned by ListAllChannels")
	}
}
// TestInsertMessage stores a PRIVMSG from "poller" to "#test" with a JSON
// body (the remaining optional arguments left nil) and checks that a non-zero
// database ID and a non-empty UUID come back.
func TestInsertMessage(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)

	dbID, msgUUID, err := database.InsertMessage(
		t.Context(),
		"PRIVMSG", "poller", "#test",
		nil, json.RawMessage(`["hello"]`), nil,
	)
	if err != nil {
		t.Fatal(err)
	}

	if dbID == 0 || msgUUID == "" {
		t.Fatal("expected valid id and uuid")
	}
}
// TestPollMessages enqueues one PRIVMSG to a session and checks that polling
// the session's client from queue ID 0 returns it with a non-zero last queue
// ID. It then checks that polling again from that ID returns nothing.
func TestPollMessages(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	sid, _, token, err := database.CreateSession(
		ctx, "poller",
	)
	if err != nil {
		t.Fatal(err)
	}

	_, clientID, _, err := database.GetSessionByToken(
		ctx, token,
	)
	if err != nil {
		t.Fatal(err)
	}

	body := json.RawMessage(`["hello"]`)

	dbID, _, err := database.InsertMessage(
		ctx, "PRIVMSG", "poller", "#test", nil, body, nil,
	)
	if err != nil {
		t.Fatal(err)
	}

	err = database.EnqueueToSession(ctx, sid, dbID)
	if err != nil {
		t.Fatal(err)
	}

	const batchSize = 10

	msgs, lastQID, err := database.PollMessages(
		ctx, clientID, 0, batchSize,
	)
	if err != nil {
		t.Fatal(err)
	}

	if len(msgs) != 1 {
		t.Fatalf(
			"expected 1 message, got %d", len(msgs),
		)
	}

	if msgs[0].Command != "PRIVMSG" {
		t.Fatalf(
			"expected PRIVMSG, got %s", msgs[0].Command,
		)
	}

	if lastQID == 0 {
		t.Fatal("expected nonzero lastQID")
	}

	// Check the error: a failing poll would otherwise return no messages
	// and make the "0 messages" assertion pass vacuously.
	msgs, _, err = database.PollMessages(
		ctx, clientID, lastQID, batchSize,
	)
	if err != nil {
		t.Fatalf("poll after lastQID %d: %v", lastQID, err)
	}

	if len(msgs) != 0 {
		t.Fatalf(
			"expected 0 messages, got %d", len(msgs),
		)
	}
}
// TestGetHistory inserts ten messages into "#hist", requests history with a
// limit of five, and checks that exactly five messages come back in strictly
// ascending DBID order.
func TestGetHistory(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	const msgCount = 10

	for range msgCount {
		_, _, err := database.InsertMessage(
			ctx, "PRIVMSG", "user", "#hist",
			nil, json.RawMessage(`["msg"]`), nil,
		)
		if err != nil {
			t.Fatal(err)
		}
	}

	const histLimit = 5

	msgs, err := database.GetHistory(
		ctx, "#hist", 0, histLimit,
	)
	if err != nil {
		t.Fatal(err)
	}

	if len(msgs) != histLimit {
		t.Fatalf("expected %d, got %d",
			histLimit, len(msgs))
	}

	// Check every adjacent pair, not just the endpoints, so that an
	// out-of-order middle element is caught too.
	for i := 1; i < len(msgs); i++ {
		if msgs[i].DBID <= msgs[i-1].DBID {
			t.Fatalf(
				"expected ascending order at index %d: %d then %d",
				i, msgs[i-1].DBID, msgs[i].DBID,
			)
		}
	}
}
// TestDeleteSession joins a session to a channel, deletes the session, and
// checks that the nick no longer resolves and the channel has no members.
func TestDeleteSession(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	sid, _, _, err := database.CreateSession(
		ctx, "deleteme",
	)
	if err != nil {
		t.Fatal(err)
	}

	chID, err := database.GetOrCreateChannel(
		ctx, "#delchan",
	)
	if err != nil {
		t.Fatal(err)
	}

	err = database.JoinChannel(ctx, chID, sid)
	if err != nil {
		t.Fatal(err)
	}

	err = database.DeleteSession(ctx, sid)
	if err != nil {
		t.Fatal(err)
	}

	_, err = database.GetSessionByNick(ctx, "deleteme")
	if err == nil {
		t.Fatal("session should be deleted")
	}

	// Check the error: a failed query would otherwise return an empty slice
	// and make the "no members" assertion pass vacuously.
	ids, err := database.GetChannelMemberIDs(ctx, chID)
	if err != nil {
		t.Fatalf("get members after delete: %v", err)
	}

	if len(ids) != 0 {
		t.Fatalf("expected no members after deletion, got %v", ids)
	}
}
// TestChannelMembers joins two sessions to one channel and checks that
// ChannelMembers returns two entries.
func TestChannelMembers(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	sid1, _, _, err := database.CreateSession(ctx, "m1")
	if err != nil {
		t.Fatal(err)
	}

	sid2, _, _, err := database.CreateSession(ctx, "m2")
	if err != nil {
		t.Fatal(err)
	}

	chID, err := database.GetOrCreateChannel(
		ctx, "#members",
	)
	if err != nil {
		t.Fatal(err)
	}

	err = database.JoinChannel(ctx, chID, sid1)
	if err != nil {
		t.Fatal(err)
	}

	err = database.JoinChannel(ctx, chID, sid2)
	if err != nil {
		t.Fatal(err)
	}

	members, err := database.ChannelMembers(ctx, chID)
	if err != nil {
		// Report the error itself rather than a misleading "got 0".
		t.Fatalf("channel members: %v", err)
	}

	if len(members) != 2 {
		t.Fatalf(
			"expected 2 members, got %d",
			len(members),
		)
	}
}
// TestGetSessionChannels joins a session to two channels and checks that
// GetSessionChannels returns two entries for it.
func TestGetSessionChannels(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)

	sid, _, _ := createSessionWithChannels(
		t, database, "multi", "#m1", "#m2",
	)

	channels, err :=
		database.GetSessionChannels(
			t.Context(), sid,
		)
	if err != nil {
		// Report the error itself rather than a misleading "got 0".
		t.Fatalf("get session channels: %v", err)
	}

	if len(channels) != 2 {
		t.Fatalf(
			"expected 2 channels, got %d",
			len(channels),
		)
	}
}
// TestEnqueueToClient enqueues a stored message directly to a client ID
// (rather than a session) and confirms that polling that client from queue
// ID 0 yields exactly one message.
func TestEnqueueToClient(t *testing.T) {
	t.Parallel()

	database := setupTestDB(t)
	ctx := t.Context()

	_, _, token, err := database.CreateSession(ctx, "enqclient")
	if err != nil {
		t.Fatal(err)
	}

	_, clientID, _, err := database.GetSessionByToken(ctx, token)
	if err != nil {
		t.Fatal(err)
	}

	msgID, _, err := database.InsertMessage(
		ctx, "PRIVMSG", "sender", "#ch",
		nil, json.RawMessage(`["test"]`), nil,
	)
	if err != nil {
		t.Fatal(err)
	}

	enqErr := database.EnqueueToClient(ctx, clientID, msgID)
	if enqErr != nil {
		t.Fatal(enqErr)
	}

	const batchSize = 10

	polled, _, err := database.PollMessages(ctx, clientID, 0, batchSize)
	if err != nil {
		t.Fatal(err)
	}

	if got := len(polled); got != 1 {
		t.Fatalf("expected 1, got %d", got)
	}
}