diff --git a/internal/db/db_test.go b/internal/db/db_test.go index cccad85..b3cf841 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -43,213 +43,119 @@ func TestCreateUser(t *testing.T) { d := setupTestDB(t) ctx := context.Background() - u, err := d.CreateUserModel(ctx, "u1", nickAlice, "hash1") + id, token, err := d.CreateUser(ctx, nickAlice) if err != nil { t.Fatalf("CreateUser: %v", err) } - if u.ID != "u1" || u.Nick != nickAlice { - t.Errorf("got user %+v", u) + if id <= 0 { + t.Errorf("expected positive id, got %d", id) + } + + if token == "" { + t.Error("expected non-empty token") } } -func TestCreateAuthToken(t *testing.T) { +func TestGetUserByToken(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - _, err := d.CreateUserModel(ctx, "u1", nickAlice, "h") + _, token, _ := d.CreateUser(ctx, nickAlice) + + id, nick, err := d.GetUserByToken(ctx, token) if err != nil { - t.Fatalf("CreateUser: %v", err) + t.Fatalf("GetUserByToken: %v", err) } - tok, err := d.CreateAuthToken(ctx, "tok1", "u1") - if err != nil { - t.Fatalf("CreateAuthToken: %v", err) - } - - if tok.Token != "tok1" || tok.UserID != "u1" { - t.Errorf("unexpected token: %+v", tok) - } - - u, err := tok.User(ctx) - if err != nil { - t.Fatalf("AuthToken.User: %v", err) - } - - if u.ID != "u1" || u.Nick != nickAlice { - t.Errorf("AuthToken.User got %+v", u) + if id <= 0 || nick != nickAlice { + t.Errorf( + "got id=%d nick=%s, want nick=%s", + id, nick, nickAlice, + ) } } -func TestCreateChannel(t *testing.T) { +func TestGetUserByNick(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - ch, err := d.CreateChannel( - ctx, "c1", "#general", "welcome", "+n", - ) + origID, _, _ := d.CreateUser(ctx, nickAlice) + + id, err := d.GetUserByNick(ctx, nickAlice) if err != nil { - t.Fatalf("CreateChannel: %v", err) + t.Fatalf("GetUserByNick: %v", err) } - if ch.ID != "c1" || ch.Name != "#general" { - t.Errorf("unexpected channel: %+v", ch) + if id != origID { + t.Errorf("got id %d, want %d", id, origID) } } -func TestAddChannelMember(t *testing.T) { +func TestGetOrCreateChannel(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateChannel(ctx, "c1", "#general", "", "") - - cm, err := d.AddChannelMember(ctx, "c1", "u1", "+o") + id1, err := d.GetOrCreateChannel(ctx, "#general") if err != nil { - t.Fatalf("AddChannelMember: %v", err) + t.Fatalf("GetOrCreateChannel: %v", err) } - if cm.ChannelID != "c1" || cm.Modes != "+o" { - t.Errorf("unexpected member: %+v", cm) + if id1 <= 0 { + t.Errorf("expected positive id, got %d", id1) + } + + // Same channel returns same ID. + id2, err := d.GetOrCreateChannel(ctx, "#general") + if err != nil { + t.Fatalf("GetOrCreateChannel(2): %v", err) + } + + if id1 != id2 { + t.Errorf("got different ids: %d vs %d", id1, id2) } } -func TestCreateMessage(t *testing.T) { +func TestJoinAndListChannels(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") + uid, _, _ := d.CreateUser(ctx, nickAlice) + ch1, _ := d.GetOrCreateChannel(ctx, "#alpha") + ch2, _ := d.GetOrCreateChannel(ctx, "#beta") - msg, err := d.CreateMessage( - ctx, "m1", "u1", nickAlice, - "#general", "message", "hello", - ) + _ = d.JoinChannel(ctx, ch1, uid) + _ = d.JoinChannel(ctx, ch2, uid) + + channels, err := d.ListChannels(ctx, uid) if err != nil { - t.Fatalf("CreateMessage: %v", err) - } - - if msg.ID != "m1" || msg.Body != "hello" { - t.Errorf("unexpected message: %+v", msg) - } -} - -func TestQueueMessage(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateUserModel(ctx, "u2", nickBob, "h") - _, _ = d.CreateMessage( - ctx, "m1", "u1", nickAlice, "u2", "message", "hi", - ) - - mq, err := d.QueueMessage(ctx, "u2", "m1") - if err != nil { - t.Fatalf("QueueMessage: %v", err) - } - - if mq.UserID != "u2" || mq.MessageID != "m1" { - t.Errorf("unexpected queue entry: %+v", mq) - } -} - -func TestCreateSession(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - - sess, err := d.CreateSession(ctx, "s1", "u1") - if err != nil { - t.Fatalf("CreateSession: %v", err) - } - - if sess.ID != "s1" || sess.UserID != "u1" { - t.Errorf("unexpected session: %+v", sess) - } - - u, err := sess.User(ctx) - if err != nil { - t.Fatalf("Session.User: %v", err) - } - - if u.ID != "u1" { - t.Errorf("Session.User got %v, want u1", u.ID) - } -} - -func TestCreateServerLink(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - sl, err := d.CreateServerLink( - ctx, "sl1", "peer1", - "https://peer.example.com", "keyhash", true, - ) - if err != nil { - t.Fatalf("CreateServerLink: %v", err) - } - - if sl.ID != "sl1" || !sl.IsActive { - t.Errorf("unexpected server link: %+v", sl) - } -} - -func TestUserChannels(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - u, _ := d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateChannel(ctx, "c1", "#alpha", "", "") - _, _ = d.CreateChannel(ctx, "c2", "#beta", "", "") - _, _ = d.AddChannelMember(ctx, "c1", "u1", "") - _, _ = d.AddChannelMember(ctx, "c2", "u1", "") - - channels, err := u.Channels(ctx) - if err != nil { - t.Fatalf("User.Channels: %v", err) + t.Fatalf("ListChannels: %v", err) } if len(channels) != 2 { t.Fatalf("expected 2 channels, got %d", len(channels)) } - - if channels[0].Name != "#alpha" { - t.Errorf("first channel: got %s", channels[0].Name) - } - - if channels[1].Name != "#beta" { - t.Errorf("second channel: got %s", channels[1].Name) - } } -func TestUserChannelsEmpty(t *testing.T) { +func TestListChannelsEmpty(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - u, _ := d.CreateUserModel(ctx, "u1", nickAlice, "h") + uid, _, _ := d.CreateUser(ctx, nickAlice) - channels, err := u.Channels(ctx) + channels, err := d.ListChannels(ctx, uid) if err != nil { - t.Fatalf("User.Channels: %v", err) + t.Fatalf("ListChannels: %v", err) } if len(channels) != 0 { @@ -257,60 +163,57 @@ func TestUserChannelsEmpty(t *testing.T) { } } -func TestUserQueuedMessages(t *testing.T) { +func TestPartChannel(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - u, _ := d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateUserModel(ctx, "u2", nickBob, "h") + uid, _, _ := d.CreateUser(ctx, nickAlice) + chID, _ := d.GetOrCreateChannel(ctx, "#general") - for i := range 3 { - id := fmt.Sprintf("m%d", i) + _ = d.JoinChannel(ctx, chID, uid) + _ = d.PartChannel(ctx, chID, uid) - _, _ = d.CreateMessage( - ctx, id, "u2", nickBob, "u1", - "message", fmt.Sprintf("msg%d", i), - ) - - time.Sleep(10 * time.Millisecond) - - _, _ = d.QueueMessage(ctx, "u1", id) - } - - msgs, err := u.QueuedMessages(ctx) + channels, err := d.ListChannels(ctx, uid) if err != nil { - t.Fatalf("User.QueuedMessages: %v", err) + t.Fatalf("ListChannels: %v", err) } - if len(msgs) != 3 { - t.Fatalf("expected 3 messages, got %d", len(msgs)) - } - - for i, msg := range msgs { - want := fmt.Sprintf("msg%d", i) - if msg.Body != want { - t.Errorf("msg %d: got %q, want %q", i, msg.Body, want) - } + if len(channels) != 0 { + t.Errorf("expected 0 after part, got %d", len(channels)) } } -func TestUserQueuedMessagesEmpty(t *testing.T) { +func TestSendAndGetMessages(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - u, _ := d.CreateUserModel(ctx, "u1", nickAlice, "h") + uid, _, _ := d.CreateUser(ctx, nickAlice) + chID, _ := d.GetOrCreateChannel(ctx, "#general") + _ = d.JoinChannel(ctx, chID, uid) - msgs, err := u.QueuedMessages(ctx) + _, err := d.SendMessage(ctx, chID, uid, "hello world") if err != nil { - t.Fatalf("User.QueuedMessages: %v", err) + t.Fatalf("SendMessage: %v", err) } - if len(msgs) != 0 { - t.Errorf("expected 0 messages, got %d", len(msgs)) + msgs, err := d.GetMessages(ctx, chID, 0, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + + if len(msgs) != 1 { + t.Fatalf("expected 1 message, got %d", len(msgs)) + } + + if msgs[0].Content != "hello world" { + t.Errorf( + "got content %q, want %q", + msgs[0].Content, "hello world", + ) } } @@ -320,35 +223,23 @@ func TestChannelMembers(t *testing.T) { d := setupTestDB(t) ctx := context.Background() - ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateUserModel(ctx, "u2", nickBob, "h") - _, _ = d.CreateUserModel(ctx, "u3", nickCharlie, "h") - _, _ = d.AddChannelMember(ctx, "c1", "u1", "+o") - _, _ = d.AddChannelMember(ctx, "c1", "u2", "+v") - _, _ = d.AddChannelMember(ctx, "c1", "u3", "") + uid1, _, _ := d.CreateUser(ctx, nickAlice) + uid2, _, _ := d.CreateUser(ctx, nickBob) + uid3, _, _ := d.CreateUser(ctx, nickCharlie) + chID, _ := d.GetOrCreateChannel(ctx, "#general") - members, err := ch.Members(ctx) + _ = d.JoinChannel(ctx, chID, uid1) + _ = d.JoinChannel(ctx, chID, uid2) + _ = d.JoinChannel(ctx, chID, uid3) + + members, err := d.ChannelMembers(ctx, chID) if err != nil { - t.Fatalf("Channel.Members: %v", err) + t.Fatalf("ChannelMembers: %v", err) } if len(members) != 3 { t.Fatalf("expected 3 members, got %d", len(members)) } - - nicks := map[string]bool{} - for _, m := range members { - nicks[m.Nick] = true - } - - for _, want := range []string{ - nickAlice, nickBob, nickCharlie, - } { - if !nicks[want] { - t.Errorf("missing nick %q", want) - } - } } func TestChannelMembersEmpty(t *testing.T) { @@ -357,11 +248,11 @@ func TestChannelMembersEmpty(t *testing.T) { d := setupTestDB(t) ctx := context.Background() - ch, _ := d.CreateChannel(ctx, "c1", "#empty", "", "") + chID, _ := d.GetOrCreateChannel(ctx, "#empty") - members, err := ch.Members(ctx) + members, err := d.ChannelMembers(ctx, chID) if err != nil { - t.Fatalf("Channel.Members: %v", err) + t.Fatalf("ChannelMembers: %v", err) } if len(members) != 0 { @@ -369,126 +260,166 @@ func TestChannelMembersEmpty(t *testing.T) { } } -func TestChannelRecentMessages(t *testing.T) { +func TestSendDM(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") + uid1, _, _ := d.CreateUser(ctx, nickAlice) + uid2, _, _ := d.CreateUser(ctx, nickBob) + + msgID, err := d.SendDM(ctx, uid1, uid2, "hey bob") + if err != nil { + t.Fatalf("SendDM: %v", err) + } + + if msgID <= 0 { + t.Errorf("expected positive msgID, got %d", msgID) + } + + msgs, err := d.GetDMs(ctx, uid1, uid2, 0, 0) + if err != nil { + t.Fatalf("GetDMs: %v", err) + } + + if len(msgs) != 1 { + t.Fatalf("expected 1 DM, got %d", len(msgs)) + } + + if msgs[0].Content != "hey bob" { + t.Errorf("got %q, want %q", msgs[0].Content, "hey bob") + } +} + +func TestPollMessages(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + uid1, _, _ := d.CreateUser(ctx, nickAlice) + uid2, _, _ := d.CreateUser(ctx, nickBob) + chID, _ := d.GetOrCreateChannel(ctx, "#general") + + _ = d.JoinChannel(ctx, chID, uid1) + _ = d.JoinChannel(ctx, chID, uid2) + + _, _ = d.SendMessage(ctx, chID, uid2, "hello") + _, _ = d.SendDM(ctx, uid2, uid1, "private") + + time.Sleep(10 * time.Millisecond) + + msgs, err := d.PollMessages(ctx, uid1, 0, 0) + if err != nil { + t.Fatalf("PollMessages: %v", err) + } + + if len(msgs) < 2 { + t.Fatalf("expected >=2 messages, got %d", len(msgs)) + } +} + +func TestChangeNick(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + _, token, _ := d.CreateUser(ctx, nickAlice) + + err := d.ChangeNick(ctx, 1, "alice2") + if err != nil { + t.Fatalf("ChangeNick: %v", err) + } + + _, nick, err := d.GetUserByToken(ctx, token) + if err != nil { + t.Fatalf("GetUserByToken: %v", err) + } + + if nick != "alice2" { + t.Errorf("got nick %q, want alice2", nick) + } +} + +func TestSetTopic(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + uid, _, _ := d.CreateUser(ctx, nickAlice) + _, _ = d.GetOrCreateChannel(ctx, "#general") + + err := d.SetTopic(ctx, "#general", uid, "new topic") + if err != nil { + t.Fatalf("SetTopic: %v", err) + } + + channels, err := d.ListAllChannels(ctx) + if err != nil { + t.Fatalf("ListAllChannels: %v", err) + } + + found := false + + for _, ch := range channels { + if ch.Name == "#general" && ch.Topic == "new topic" { + found = true + } + } + + if !found { + t.Error("topic was not updated") + } +} + +func TestGetMessagesBefore(t *testing.T) { + t.Parallel() + + d := setupTestDB(t) + ctx := context.Background() + + uid, _, _ := d.CreateUser(ctx, nickAlice) + chID, _ := d.GetOrCreateChannel(ctx, "#general") + + _ = d.JoinChannel(ctx, chID, uid) for i := range 5 { - id := fmt.Sprintf("m%d", i) - - _, _ = d.CreateMessage( - ctx, id, "u1", nickAlice, "#general", - "message", fmt.Sprintf("msg%d", i), + _, _ = d.SendMessage( + ctx, chID, uid, + fmt.Sprintf("msg%d", i), ) time.Sleep(10 * time.Millisecond) } - msgs, err := ch.RecentMessages(ctx, 3) + msgs, err := d.GetMessagesBefore(ctx, chID, 0, 3) if err != nil { - t.Fatalf("RecentMessages: %v", err) + t.Fatalf("GetMessagesBefore: %v", err) } if len(msgs) != 3 { t.Fatalf("expected 3, got %d", len(msgs)) } - - if msgs[0].Body != "msg4" { - t.Errorf("first: got %q, want msg4", msgs[0].Body) - } - - if msgs[2].Body != "msg2" { - t.Errorf("last: got %q, want msg2", msgs[2].Body) - } } -func TestChannelRecentMessagesLargeLimit(t *testing.T) { +func TestListAllChannels(t *testing.T) { t.Parallel() d := setupTestDB(t) ctx := context.Background() - ch, _ := d.CreateChannel(ctx, "c1", "#general", "", "") - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateMessage( - ctx, "m1", "u1", nickAlice, - "#general", "message", "only", - ) + _, _ = d.GetOrCreateChannel(ctx, "#alpha") + _, _ = d.GetOrCreateChannel(ctx, "#beta") - msgs, err := ch.RecentMessages(ctx, 100) + channels, err := d.ListAllChannels(ctx) if err != nil { - t.Fatalf("RecentMessages: %v", err) + t.Fatalf("ListAllChannels: %v", err) } - if len(msgs) != 1 { - t.Errorf("expected 1, got %d", len(msgs)) - } -} - -func TestChannelMemberUser(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateChannel(ctx, "c1", "#general", "", "") - - cm, _ := d.AddChannelMember(ctx, "c1", "u1", "+o") - - u, err := cm.User(ctx) - if err != nil { - t.Fatalf("ChannelMember.User: %v", err) - } - - if u.ID != "u1" || u.Nick != nickAlice { - t.Errorf("got %+v", u) - } -} - -func TestChannelMemberChannel(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateChannel(ctx, "c1", "#general", "topic", "+n") - - cm, _ := d.AddChannelMember(ctx, "c1", "u1", "") - - ch, err := cm.Channel(ctx) - if err != nil { - t.Fatalf("ChannelMember.Channel: %v", err) - } - - if ch.ID != "c1" || ch.Topic != "topic" { - t.Errorf("got %+v", ch) - } -} - -func TestDMMessage(t *testing.T) { - t.Parallel() - - d := setupTestDB(t) - ctx := context.Background() - - _, _ = d.CreateUserModel(ctx, "u1", nickAlice, "h") - _, _ = d.CreateUserModel(ctx, "u2", nickBob, "h") - - msg, err := d.CreateMessage( - ctx, "m1", "u1", nickAlice, "u2", "message", "hey", - ) - if err != nil { - t.Fatalf("CreateMessage DM: %v", err) - } - - if msg.Target != "u2" { - t.Errorf("target: got %q, want u2", msg.Target) + if len(channels) != 2 { + t.Errorf("expected 2, got %d", len(channels)) } } diff --git a/internal/db/schema/003_users.sql b/internal/db/schema/003_users.sql index f305aa0..a89aad8 100644 --- a/internal/db/schema/003_users.sql +++ b/internal/db/schema/003_users.sql @@ -1,6 +1,18 @@ -PRAGMA foreign_keys = ON; +-- Migration 003: Replace UUID-based tables with simple integer-keyed +-- tables for the HTTP API. Drops the 002 tables and recreates them. -CREATE TABLE IF NOT EXISTS users ( +PRAGMA foreign_keys = OFF; + +DROP TABLE IF EXISTS message_queue; +DROP TABLE IF EXISTS sessions; +DROP TABLE IF EXISTS server_links; +DROP TABLE IF EXISTS messages; +DROP TABLE IF EXISTS channel_members; +DROP TABLE IF EXISTS auth_tokens; +DROP TABLE IF EXISTS channels; +DROP TABLE IF EXISTS users; + +CREATE TABLE users ( id INTEGER PRIMARY KEY AUTOINCREMENT, nick TEXT NOT NULL UNIQUE, token TEXT NOT NULL UNIQUE, @@ -8,7 +20,15 @@ CREATE TABLE IF NOT EXISTS users ( last_seen DATETIME DEFAULT CURRENT_TIMESTAMP ); -CREATE TABLE IF NOT EXISTS channel_members ( +CREATE TABLE channels ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + topic TEXT NOT NULL DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE channel_members ( id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE, user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, @@ -16,7 +36,7 @@ CREATE TABLE IF NOT EXISTS channel_members ( UNIQUE(channel_id, user_id) ); -CREATE TABLE IF NOT EXISTS messages ( +CREATE TABLE messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER REFERENCES channels(id) ON DELETE CASCADE, user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, @@ -26,6 +46,8 @@ CREATE TABLE IF NOT EXISTS messages ( created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); -CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel_id, created_at); -CREATE INDEX IF NOT EXISTS idx_messages_dm ON messages(user_id, dm_target_id, created_at); -CREATE INDEX IF NOT EXISTS idx_users_token ON users(token); +CREATE INDEX idx_messages_channel ON messages(channel_id, created_at); +CREATE INDEX idx_messages_dm ON messages(user_id, dm_target_id, created_at); +CREATE INDEX idx_users_token ON users(token); + +PRAGMA foreign_keys = ON;