package db import ( "context" "database/sql" "encoding/json" "log/slog" "testing" _ "modernc.org/sqlite" ) func setupTestDB(t *testing.T) *Database { t.Helper() d, err := sql.Open("sqlite", "file::memory:?cache=shared&_pragma=foreign_keys(1)") if err != nil { t.Fatal(err) } t.Cleanup(func() { d.Close() }) db := &Database{db: d, log: slog.Default()} if err := db.runMigrations(context.Background()); err != nil { t.Fatal(err) } return db } func TestCreateUser(t *testing.T) { db := setupTestDB(t) ctx := context.Background() id, token, err := db.CreateUser(ctx, "alice") if err != nil { t.Fatal(err) } if id == 0 || token == "" { t.Fatal("expected valid id and token") } // Duplicate nick _, _, err = db.CreateUser(ctx, "alice") if err == nil { t.Fatal("expected error for duplicate nick") } } func TestGetUserByToken(t *testing.T) { db := setupTestDB(t) ctx := context.Background() _, token, _ := db.CreateUser(ctx, "bob") id, nick, err := db.GetUserByToken(ctx, token) if err != nil { t.Fatal(err) } if nick != "bob" || id == 0 { t.Fatalf("expected bob, got %s", nick) } // Invalid token _, _, err = db.GetUserByToken(ctx, "badtoken") if err == nil { t.Fatal("expected error for bad token") } } func TestGetUserByNick(t *testing.T) { db := setupTestDB(t) ctx := context.Background() db.CreateUser(ctx, "charlie") id, err := db.GetUserByNick(ctx, "charlie") if err != nil || id == 0 { t.Fatal("expected to find charlie") } _, err = db.GetUserByNick(ctx, "nobody") if err == nil { t.Fatal("expected error for unknown nick") } } func TestChannelOperations(t *testing.T) { db := setupTestDB(t) ctx := context.Background() // Create channel chID, err := db.GetOrCreateChannel(ctx, "#test") if err != nil || chID == 0 { t.Fatal("expected channel id") } // Get same channel chID2, err := db.GetOrCreateChannel(ctx, "#test") if err != nil || chID2 != chID { t.Fatal("expected same channel id") } // GetChannelByName chID3, err := db.GetChannelByName(ctx, "#test") if err != nil || chID3 != chID { t.Fatal("expected same channel id from GetChannelByName") } // Nonexistent channel _, err = db.GetChannelByName(ctx, "#nope") if err == nil { t.Fatal("expected error for nonexistent channel") } } func TestJoinAndPart(t *testing.T) { db := setupTestDB(t) ctx := context.Background() uid, _, _ := db.CreateUser(ctx, "user1") chID, _ := db.GetOrCreateChannel(ctx, "#chan") // Join if err := db.JoinChannel(ctx, chID, uid); err != nil { t.Fatal(err) } // Verify membership ids, err := db.GetChannelMemberIDs(ctx, chID) if err != nil || len(ids) != 1 || ids[0] != uid { t.Fatal("expected user in channel") } // Double join (should be ignored) if err := db.JoinChannel(ctx, chID, uid); err != nil { t.Fatal(err) } // Part if err := db.PartChannel(ctx, chID, uid); err != nil { t.Fatal(err) } ids, _ = db.GetChannelMemberIDs(ctx, chID) if len(ids) != 0 { t.Fatal("expected empty channel") } } func TestDeleteChannelIfEmpty(t *testing.T) { db := setupTestDB(t) ctx := context.Background() chID, _ := db.GetOrCreateChannel(ctx, "#empty") uid, _, _ := db.CreateUser(ctx, "temp") db.JoinChannel(ctx, chID, uid) db.PartChannel(ctx, chID, uid) if err := db.DeleteChannelIfEmpty(ctx, chID); err != nil { t.Fatal(err) } _, err := db.GetChannelByName(ctx, "#empty") if err == nil { t.Fatal("expected channel to be deleted") } } func TestListChannels(t *testing.T) { db := setupTestDB(t) ctx := context.Background() uid, _, _ := db.CreateUser(ctx, "lister") ch1, _ := db.GetOrCreateChannel(ctx, "#a") ch2, _ := db.GetOrCreateChannel(ctx, "#b") db.JoinChannel(ctx, ch1, uid) db.JoinChannel(ctx, ch2, uid) channels, err := db.ListChannels(ctx, uid) if err != nil || len(channels) != 2 { t.Fatalf("expected 2 channels, got %d", len(channels)) } } func TestListAllChannels(t *testing.T) { db := setupTestDB(t) ctx := context.Background() db.GetOrCreateChannel(ctx, "#x") db.GetOrCreateChannel(ctx, "#y") channels, err := db.ListAllChannels(ctx) if err != nil || len(channels) < 2 { t.Fatalf("expected >= 2 channels, got %d", len(channels)) } } func TestChangeNick(t *testing.T) { db := setupTestDB(t) ctx := context.Background() uid, token, _ := db.CreateUser(ctx, "old") if err := db.ChangeNick(ctx, uid, "new"); err != nil { t.Fatal(err) } _, nick, _ := db.GetUserByToken(ctx, token) if nick != "new" { t.Fatalf("expected new, got %s", nick) } } func TestSetTopic(t *testing.T) { db := setupTestDB(t) ctx := context.Background() db.GetOrCreateChannel(ctx, "#topictest") if err := db.SetTopic(ctx, "#topictest", "Hello"); err != nil { t.Fatal(err) } channels, _ := db.ListAllChannels(ctx) for _, ch := range channels { if ch.Name == "#topictest" && ch.Topic != "Hello" { t.Fatalf("expected topic Hello, got %s", ch.Topic) } } } func TestInsertAndPollMessages(t *testing.T) { db := setupTestDB(t) ctx := context.Background() uid, _, _ := db.CreateUser(ctx, "poller") body := json.RawMessage(`["hello"]`) dbID, uuid, err := db.InsertMessage(ctx, "PRIVMSG", "poller", "#test", body, nil) if err != nil || dbID == 0 || uuid == "" { t.Fatal("insert failed") } if err := db.EnqueueMessage(ctx, uid, dbID); err != nil { t.Fatal(err) } msgs, lastQID, err := db.PollMessages(ctx, uid, 0, 10) if err != nil { t.Fatal(err) } if len(msgs) != 1 { t.Fatalf("expected 1 message, got %d", len(msgs)) } if msgs[0].Command != "PRIVMSG" { t.Fatalf("expected PRIVMSG, got %s", msgs[0].Command) } if lastQID == 0 { t.Fatal("expected nonzero lastQID") } // Poll again with lastQID - should be empty msgs, _, _ = db.PollMessages(ctx, uid, lastQID, 10) if len(msgs) != 0 { t.Fatalf("expected 0 messages, got %d", len(msgs)) } } func TestGetHistory(t *testing.T) { db := setupTestDB(t) ctx := context.Background() for i := 0; i < 10; i++ { db.InsertMessage(ctx, "PRIVMSG", "user", "#hist", json.RawMessage(`["msg"]`), nil) } msgs, err := db.GetHistory(ctx, "#hist", 0, 5) if err != nil { t.Fatal(err) } if len(msgs) != 5 { t.Fatalf("expected 5, got %d", len(msgs)) } // Should be ascending order if msgs[0].DBID > msgs[4].DBID { t.Fatal("expected ascending order") } } func TestDeleteUser(t *testing.T) { db := setupTestDB(t) ctx := context.Background() uid, _, _ := db.CreateUser(ctx, "deleteme") chID, _ := db.GetOrCreateChannel(ctx, "#delchan") db.JoinChannel(ctx, chID, uid) if err := db.DeleteUser(ctx, uid); err != nil { t.Fatal(err) } _, err := db.GetUserByNick(ctx, "deleteme") if err == nil { t.Fatal("user should be deleted") } // Channel membership should be cleaned up via CASCADE ids, _ := db.GetChannelMemberIDs(ctx, chID) if len(ids) != 0 { t.Fatal("expected no members after user deletion") } } func TestChannelMembers(t *testing.T) { db := setupTestDB(t) ctx := context.Background() uid1, _, _ := db.CreateUser(ctx, "m1") uid2, _, _ := db.CreateUser(ctx, "m2") chID, _ := db.GetOrCreateChannel(ctx, "#members") db.JoinChannel(ctx, chID, uid1) db.JoinChannel(ctx, chID, uid2) members, err := db.ChannelMembers(ctx, chID) if err != nil || len(members) != 2 { t.Fatalf("expected 2 members, got %d", len(members)) } } func TestGetAllChannelMembershipsForUser(t *testing.T) { db := setupTestDB(t) ctx := context.Background() uid, _, _ := db.CreateUser(ctx, "multi") ch1, _ := db.GetOrCreateChannel(ctx, "#m1") ch2, _ := db.GetOrCreateChannel(ctx, "#m2") db.JoinChannel(ctx, ch1, uid) db.JoinChannel(ctx, ch2, uid) channels, err := db.GetAllChannelMembershipsForUser(ctx, uid) if err != nil || len(channels) != 2 { t.Fatalf("expected 2 channels, got %d", len(channels)) } }