package db_test

import (
	"context"
	"fmt"
	"path/filepath"
	"testing"
	"time"

	"git.eeqj.de/sneak/chat/internal/db"
)

// Nicks used as fixtures by the tests in this file.
const (
	nickAlice   = "alice"
	nickBob     = "bob"
	nickCharlie = "charlie"
)

// setupTestDB creates a fresh database in a temp directory with
// all migrations applied. The database file lives under t.TempDir(),
// so every test gets its own isolated store, and the underlying handle
// is closed when the test finishes.
func setupTestDB(t *testing.T) *db.Database {
	t.Helper()

	dir := t.TempDir()
	dsn := fmt.Sprintf(
		"file:%s?_journal_mode=WAL",
		filepath.Join(dir, "test.db"),
	)

	d, err := db.NewTest(dsn)
	if err != nil {
		t.Fatalf("failed to create test database: %v", err)
	}

	t.Cleanup(func() {
		// Best-effort close: the temp directory is discarded anyway,
		// so a close error is not worth failing the test over.
		_ = d.GetDB().Close()
	})

	return d
}

// TestCreateUser checks that CreateUser returns a positive ID and a
// non-empty token.
func TestCreateUser(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	id, token, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	if id <= 0 {
		t.Errorf("expected positive id, got %d", id)
	}

	if token == "" {
		t.Error("expected non-empty token")
	}
}

// TestGetUserByToken checks that the token returned by CreateUser
// resolves back to a user with the original nick.
func TestGetUserByToken(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	_, token, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	id, nick, err := d.GetUserByToken(ctx, token)
	if err != nil {
		t.Fatalf("GetUserByToken: %v", err)
	}

	if id <= 0 || nick != nickAlice {
		t.Errorf(
			"got id=%d nick=%s, want nick=%s",
			id, nick, nickAlice,
		)
	}
}

// TestGetUserByNick checks that looking a user up by nick yields the
// ID that CreateUser returned.
func TestGetUserByNick(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	origID, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	id, err := d.GetUserByNick(ctx, nickAlice)
	if err != nil {
		t.Fatalf("GetUserByNick: %v", err)
	}

	if id != origID {
		t.Errorf("got id %d, want %d", id, origID)
	}
}

// TestGetOrCreateChannel checks that a channel gets a positive ID and
// that asking for the same name again returns the same ID.
func TestGetOrCreateChannel(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	id1, err := d.GetOrCreateChannel(ctx, "#general")
	if err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	if id1 <= 0 {
		t.Errorf("expected positive id, got %d", id1)
	}

	// Same channel returns same ID.
	id2, err := d.GetOrCreateChannel(ctx, "#general")
	if err != nil {
		t.Fatalf("GetOrCreateChannel(2): %v", err)
	}

	if id1 != id2 {
		t.Errorf("got different ids: %d vs %d", id1, id2)
	}
}

// TestJoinAndListChannels checks that a user who joined two channels
// sees both of them in ListChannels.
func TestJoinAndListChannels(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	ch1, err := d.GetOrCreateChannel(ctx, "#alpha")
	if err != nil {
		t.Fatalf("GetOrCreateChannel(#alpha): %v", err)
	}

	ch2, err := d.GetOrCreateChannel(ctx, "#beta")
	if err != nil {
		t.Fatalf("GetOrCreateChannel(#beta): %v", err)
	}

	if err := d.JoinChannel(ctx, ch1, uid); err != nil {
		t.Fatalf("JoinChannel(#alpha): %v", err)
	}

	if err := d.JoinChannel(ctx, ch2, uid); err != nil {
		t.Fatalf("JoinChannel(#beta): %v", err)
	}

	channels, err := d.ListChannels(ctx, uid)
	if err != nil {
		t.Fatalf("ListChannels: %v", err)
	}

	if len(channels) != 2 {
		t.Fatalf("expected 2 channels, got %d", len(channels))
	}
}

// TestListChannelsEmpty checks that a user who has joined nothing gets
// an empty channel list.
func TestListChannelsEmpty(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	channels, err := d.ListChannels(ctx, uid)
	if err != nil {
		t.Fatalf("ListChannels: %v", err)
	}

	if len(channels) != 0 {
		t.Errorf("expected 0 channels, got %d", len(channels))
	}
}

// TestPartChannel checks that a channel no longer appears in the
// user's channel list after the user parts it.
func TestPartChannel(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	chID, err := d.GetOrCreateChannel(ctx, "#general")
	if err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	// The join must succeed, otherwise the final "0 channels" check
	// would pass without PartChannel having done anything.
	if err := d.JoinChannel(ctx, chID, uid); err != nil {
		t.Fatalf("JoinChannel: %v", err)
	}

	if err := d.PartChannel(ctx, chID, uid); err != nil {
		t.Fatalf("PartChannel: %v", err)
	}

	channels, err := d.ListChannels(ctx, uid)
	if err != nil {
		t.Fatalf("ListChannels: %v", err)
	}

	if len(channels) != 0 {
		t.Errorf("expected 0 after part, got %d", len(channels))
	}
}

// TestSendAndGetMessages checks that a message sent to a channel is
// returned by GetMessages with its content intact.
func TestSendAndGetMessages(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	chID, err := d.GetOrCreateChannel(ctx, "#general")
	if err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	if err := d.JoinChannel(ctx, chID, uid); err != nil {
		t.Fatalf("JoinChannel: %v", err)
	}

	if _, err := d.SendMessage(ctx, chID, uid, "hello world"); err != nil {
		t.Fatalf("SendMessage: %v", err)
	}

	msgs, err := d.GetMessages(ctx, chID, 0, 0)
	if err != nil {
		t.Fatalf("GetMessages: %v", err)
	}

	if len(msgs) != 1 {
		t.Fatalf("expected 1 message, got %d", len(msgs))
	}

	if msgs[0].Content != "hello world" {
		t.Errorf(
			"got content %q, want %q",
			msgs[0].Content, "hello world",
		)
	}
}

// TestChannelMembers checks that ChannelMembers reports every user who
// joined the channel.
func TestChannelMembers(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	chID, err := d.GetOrCreateChannel(ctx, "#general")
	if err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	for _, nick := range []string{nickAlice, nickBob, nickCharlie} {
		uid, _, err := d.CreateUser(ctx, nick)
		if err != nil {
			t.Fatalf("CreateUser(%q): %v", nick, err)
		}

		if err := d.JoinChannel(ctx, chID, uid); err != nil {
			t.Fatalf("JoinChannel(%q): %v", nick, err)
		}
	}

	members, err := d.ChannelMembers(ctx, chID)
	if err != nil {
		t.Fatalf("ChannelMembers: %v", err)
	}

	if len(members) != 3 {
		t.Fatalf("expected 3 members, got %d", len(members))
	}
}

// TestChannelMembersEmpty checks that a channel nobody joined has no
// members.
func TestChannelMembersEmpty(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	chID, err := d.GetOrCreateChannel(ctx, "#empty")
	if err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	members, err := d.ChannelMembers(ctx, chID)
	if err != nil {
		t.Fatalf("ChannelMembers: %v", err)
	}

	if len(members) != 0 {
		t.Errorf("expected 0, got %d", len(members))
	}
}

// TestSendDM checks that a direct message gets a positive ID and is
// returned by GetDMs for the same pair of users.
func TestSendDM(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid1, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser(%q): %v", nickAlice, err)
	}

	uid2, _, err := d.CreateUser(ctx, nickBob)
	if err != nil {
		t.Fatalf("CreateUser(%q): %v", nickBob, err)
	}

	msgID, err := d.SendDM(ctx, uid1, uid2, "hey bob")
	if err != nil {
		t.Fatalf("SendDM: %v", err)
	}

	if msgID <= 0 {
		t.Errorf("expected positive msgID, got %d", msgID)
	}

	msgs, err := d.GetDMs(ctx, uid1, uid2, 0, 0)
	if err != nil {
		t.Fatalf("GetDMs: %v", err)
	}

	if len(msgs) != 1 {
		t.Fatalf("expected 1 DM, got %d", len(msgs))
	}

	if msgs[0].Content != "hey bob" {
		t.Errorf("got %q, want %q", msgs[0].Content, "hey bob")
	}
}

// TestPollMessages checks that polling for a user returns at least the
// channel message and the direct message that were sent to them.
func TestPollMessages(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid1, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser(%q): %v", nickAlice, err)
	}

	uid2, _, err := d.CreateUser(ctx, nickBob)
	if err != nil {
		t.Fatalf("CreateUser(%q): %v", nickBob, err)
	}

	chID, err := d.GetOrCreateChannel(ctx, "#general")
	if err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	if err := d.JoinChannel(ctx, chID, uid1); err != nil {
		t.Fatalf("JoinChannel(%q): %v", nickAlice, err)
	}

	if err := d.JoinChannel(ctx, chID, uid2); err != nil {
		t.Fatalf("JoinChannel(%q): %v", nickBob, err)
	}

	if _, err := d.SendMessage(ctx, chID, uid2, "hello"); err != nil {
		t.Fatalf("SendMessage: %v", err)
	}

	if _, err := d.SendDM(ctx, uid2, uid1, "private"); err != nil {
		t.Fatalf("SendDM: %v", err)
	}

	// NOTE(review): the pause presumably lets the writes settle before
	// polling; confirm against the db package whether it is still needed.
	time.Sleep(10 * time.Millisecond)

	msgs, err := d.PollMessages(ctx, uid1, 0, 0)
	if err != nil {
		t.Fatalf("PollMessages: %v", err)
	}

	if len(msgs) < 2 {
		t.Fatalf("expected >=2 messages, got %d", len(msgs))
	}
}

// TestChangeNick checks that after ChangeNick the user's token resolves
// to the new nick.
func TestChangeNick(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid, token, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	// Use the ID CreateUser returned rather than assuming the first
	// user in a fresh database is always assigned ID 1.
	if err := d.ChangeNick(ctx, uid, "alice2"); err != nil {
		t.Fatalf("ChangeNick: %v", err)
	}

	_, nick, err := d.GetUserByToken(ctx, token)
	if err != nil {
		t.Fatalf("GetUserByToken: %v", err)
	}

	if nick != "alice2" {
		t.Errorf("got nick %q, want alice2", nick)
	}
}

// TestSetTopic checks that a topic set on a channel is visible in the
// ListAllChannels output.
func TestSetTopic(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	if _, err := d.GetOrCreateChannel(ctx, "#general"); err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	if err := d.SetTopic(ctx, "#general", uid, "new topic"); err != nil {
		t.Fatalf("SetTopic: %v", err)
	}

	channels, err := d.ListAllChannels(ctx)
	if err != nil {
		t.Fatalf("ListAllChannels: %v", err)
	}

	found := false

	for _, ch := range channels {
		if ch.Name == "#general" && ch.Topic == "new topic" {
			found = true

			break
		}
	}

	if !found {
		t.Error("topic was not updated")
	}
}

// TestGetMessagesBefore checks that GetMessagesBefore honours its limit
// argument: with five messages stored and a limit of 3, exactly three
// are returned.
func TestGetMessagesBefore(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	uid, _, err := d.CreateUser(ctx, nickAlice)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}

	chID, err := d.GetOrCreateChannel(ctx, "#general")
	if err != nil {
		t.Fatalf("GetOrCreateChannel: %v", err)
	}

	if err := d.JoinChannel(ctx, chID, uid); err != nil {
		t.Fatalf("JoinChannel: %v", err)
	}

	for i := range 5 {
		_, err := d.SendMessage(
			ctx, chID, uid,
			fmt.Sprintf("msg%d", i),
		)
		if err != nil {
			t.Fatalf("SendMessage(%d): %v", i, err)
		}

		// NOTE(review): presumably spaces the messages out so they get
		// distinct timestamps; confirm against the db package.
		time.Sleep(10 * time.Millisecond)
	}

	msgs, err := d.GetMessagesBefore(ctx, chID, 0, 3)
	if err != nil {
		t.Fatalf("GetMessagesBefore: %v", err)
	}

	if len(msgs) != 3 {
		t.Fatalf("expected 3, got %d", len(msgs))
	}
}

// TestListAllChannels checks that every created channel is returned by
// ListAllChannels, regardless of membership.
func TestListAllChannels(t *testing.T) {
	t.Parallel()

	d := setupTestDB(t)
	ctx := context.Background()

	for _, name := range []string{"#alpha", "#beta"} {
		if _, err := d.GetOrCreateChannel(ctx, name); err != nil {
			t.Fatalf("GetOrCreateChannel(%s): %v", name, err)
		}
	}

	channels, err := d.ListAllChannels(ctx)
	if err != nil {
		t.Fatalf("ListAllChannels: %v", err)
	}

	if len(channels) != 2 {
		t.Errorf("expected 2, got %d", len(channels))
	}
}