package ircserver_test import ( "bufio" "database/sql" "fmt" "log/slog" "net" "os" "strings" "testing" "time" "git.eeqj.de/sneak/neoirc/internal/broker" "git.eeqj.de/sneak/neoirc/internal/config" "git.eeqj.de/sneak/neoirc/internal/db" "git.eeqj.de/sneak/neoirc/internal/ircserver" _ "modernc.org/sqlite" ) const testTimeout = 5 * time.Second func TestMain(m *testing.M) { db.SetBcryptCost(4) os.Exit(m.Run()) } // testEnv holds the shared test infrastructure. type testEnv struct { database *db.Database brk *broker.Broker cfg *config.Config srv *ircserver.Server } func newTestEnv(t *testing.T) *testEnv { t.Helper() dsn := fmt.Sprintf( "file:%s?mode=memory&cache=shared&_journal_mode=WAL", t.Name(), ) conn, err := sql.Open("sqlite", dsn) if err != nil { t.Fatalf("open db: %v", err) } conn.SetMaxOpenConns(1) _, err = conn.ExecContext( t.Context(), "PRAGMA foreign_keys = ON", ) if err != nil { t.Fatalf("pragma: %v", err) } database := db.NewTestDatabaseFromConn(conn) err = database.RunMigrations(t.Context()) if err != nil { t.Fatalf("migrate: %v", err) } brk := broker.New() cfg := &config.Config{ //nolint:exhaustruct ServerName: "test.irc", MOTD: "Welcome to test IRC", } listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } addr := listener.Addr().String() err = listener.Close() if err != nil { t.Fatalf("close listener: %v", err) } log := slog.New(slog.NewTextHandler( os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}, //nolint:exhaustruct )) srv := ircserver.NewTestServer(log, cfg, database, brk) err = srv.Start(addr) if err != nil { t.Fatalf("start irc server: %v", err) } t.Cleanup(func() { srv.Stop() err := conn.Close() if err != nil { t.Logf("close db: %v", err) } }) return &testEnv{ database: database, brk: brk, cfg: cfg, srv: srv, } } // dial connects to the test server. func (env *testEnv) dial(t *testing.T) *testClient { t.Helper() conn, err := net.DialTimeout( "tcp", env.srv.Listener().Addr().String(), testTimeout, ) if err != nil { t.Fatalf("dial: %v", err) } t.Cleanup(func() { err := conn.Close() if err != nil { t.Logf("close conn: %v", err) } }) return &testClient{ t: t, conn: conn, scanner: bufio.NewScanner(conn), } } // testClient wraps a raw TCP connection with helpers. type testClient struct { t *testing.T conn net.Conn scanner *bufio.Scanner } func (tc *testClient) send(line string) { tc.t.Helper() _ = tc.conn.SetWriteDeadline( time.Now().Add(testTimeout), ) _, err := fmt.Fprintf(tc.conn, "%s\r\n", line) if err != nil { tc.t.Fatalf("send: %v", err) } } func (tc *testClient) readLine() string { tc.t.Helper() _ = tc.conn.SetReadDeadline( time.Now().Add(testTimeout), ) if !tc.scanner.Scan() { err := tc.scanner.Err() if err != nil { tc.t.Fatalf("read: %v", err) } tc.t.Fatal("connection closed unexpectedly") } return tc.scanner.Text() } // readUntil reads lines until one matches the predicate. func (tc *testClient) readUntil( pred func(string) bool, ) []string { tc.t.Helper() var lines []string for { line := tc.readLine() lines = append(lines, line) if pred(line) { return lines } } } // register sends NICK + USER and reads through the welcome // burst. func (tc *testClient) register(nick string) []string { tc.t.Helper() tc.send("NICK " + nick) tc.send("USER " + nick + " 0 * :Test User") return tc.readUntil(func(line string) bool { return strings.Contains(line, " 376 ") || strings.Contains(line, " 422 ") }) } // assertContains checks that at least one line matches the // given substring. func assertContains( t *testing.T, lines []string, substr, description string, ) { t.Helper() for _, line := range lines { if strings.Contains(line, substr) { return } } t.Errorf("did not find %q in output: %s", substr, description) } // joinAndDrain joins a channel and reads until // RPL_ENDOFNAMES. func (tc *testClient) joinAndDrain(channel string) { tc.t.Helper() tc.send("JOIN " + channel) tc.readUntil(func(line string) bool { return strings.Contains(line, " 366 ") }) } // sendAndExpect sends a command and reads until a line // containing the expected substring is found. func (tc *testClient) sendAndExpect( cmd, expect string, ) []string { tc.t.Helper() tc.send(cmd) return tc.readUntil(func(line string) bool { return strings.Contains(line, expect) }) } func TestRegistration(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) lines := client.register("alice") assertContains(t, lines, " 001 ", "RPL_WELCOME") } func TestWelcomeContainsNick(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) lines := client.register("bob") for _, line := range lines { if strings.Contains(line, " 001 ") && !strings.Contains(line, "bob") { t.Errorf("001 should contain nick: %s", line) } } } func TestPingPong(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("pingtest") lines := client.sendAndExpect("PING :hello", "PONG") assertContains(t, lines, "PONG", "PONG response") } func TestJoinChannel(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("joiner") client.send("JOIN #test") lines := client.readUntil(func(line string) bool { return strings.Contains(line, " 366 ") }) assertContains(t, lines, "JOIN", "JOIN echo") assertContains(t, lines, " 366 ", "RPL_ENDOFNAMES") } func TestPrivmsgBetweenClients(t *testing.T) { t.Parallel() env := newTestEnv(t) alice := env.dial(t) alice.register("alice_pm") bob := env.dial(t) bob.register("bob_pm") alice.joinAndDrain("#chat") bob.joinAndDrain("#chat") alice.send("PRIVMSG #chat :hello bob!") lines := bob.sendAndExpect("PING :sync", "hello bob!") assertContains(t, lines, "hello bob!", "channel PRIVMSG") } func TestNickChange(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("oldnick") lines := client.sendAndExpect("NICK newnick", "newnick") assertContains(t, lines, "NICK", "NICK change echo") } func TestDuplicateNick(t *testing.T) { t.Parallel() env := newTestEnv(t) first := env.dial(t) first.register("taken") second := env.dial(t) second.send("NICK taken") second.send("USER taken 0 * :Test") lines := second.readUntil(func(line string) bool { return strings.Contains(line, " 433 ") }) assertContains(t, lines, " 433 ", "ERR_NICKNAMEINUSE") } func TestListChannels(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("lister") client.joinAndDrain("#listtest") lines := client.sendAndExpect("LIST", " 323 ") assertContains(t, lines, " 323 ", "RPL_LISTEND") //nolint:misspell // IRC term } func TestWhois(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("whoistest") lines := client.sendAndExpect( "WHOIS whoistest", " 318 ", ) assertContains(t, lines, " 311 ", "RPL_WHOISUSER") assertContains(t, lines, " 318 ", "RPL_ENDOFWHOIS") } func TestQuit(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("quitter") lines := client.sendAndExpect( "QUIT :goodbye", "ERROR", ) assertContains(t, lines, "goodbye", "QUIT reason") } func TestTopicSetAndGet(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("topicuser") client.joinAndDrain("#topictest") lines := client.sendAndExpect( "TOPIC #topictest :New topic here", "New topic here", ) assertContains( t, lines, "New topic here", "TOPIC echo", ) } func TestUnknownCommand(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("unknowncmd") lines := client.sendAndExpect("FOOBAR", " 421 ") assertContains(t, lines, " 421 ", "ERR_UNKNOWNCOMMAND") } func TestDirectMessage(t *testing.T) { t.Parallel() env := newTestEnv(t) sender := env.dial(t) sender.register("dmsender") receiver := env.dial(t) receiver.register("dmreceiver") // Give relay goroutines time to start. time.Sleep(100 * time.Millisecond) sender.send("PRIVMSG dmreceiver :hello privately") lines := receiver.readUntil(func(line string) bool { return strings.Contains(line, "hello privately") }) assertContains( t, lines, "hello privately", "direct PRIVMSG", ) } func TestCAPNegotiation(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.send("CAP LS 302") line := client.readLine() if !strings.Contains(line, "CAP") { t.Errorf("expected CAP response, got: %s", line) } client.send("CAP END") lines := client.register("capuser") assertContains(t, lines, " 001 ", "registration after CAP") } func TestPartChannel(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("parter") client.joinAndDrain("#parttest") lines := client.sendAndExpect( "PART #parttest :leaving", "PART", ) assertContains(t, lines, "#parttest", "PART echo") } func TestModeQuery(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("modeuser") client.joinAndDrain("#modetest") lines := client.sendAndExpect( "MODE #modetest", " 324 ", ) assertContains( t, lines, " 324 ", "RPL_CHANNELMODEIS", ) } func TestWhoChannel(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("whouser") client.joinAndDrain("#whotest") lines := client.sendAndExpect("WHO #whotest", " 315 ") assertContains(t, lines, " 352 ", "RPL_WHOREPLY") } func TestLusers(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("luseruser") lines := client.sendAndExpect("LUSERS", " 255 ") assertContains(t, lines, " 251 ", "RPL_LUSERCLIENT") } func TestMotd(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("motduser") lines := client.sendAndExpect("MOTD", " 376 ") assertContains(t, lines, " 376 ", "RPL_ENDOFMOTD") } func TestAwaySetAndClear(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("awayuser") setLines := client.sendAndExpect( "AWAY :brb lunch", " 306 ", ) assertContains(t, setLines, " 306 ", "RPL_NOWAWAY") clearLines := client.sendAndExpect("AWAY", " 305 ") assertContains(t, clearLines, " 305 ", "RPL_UNAWAY") } func TestHandlePassPostRegistration(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("passuser") lines := client.sendAndExpect( "PASS :mypassword123", "Password set", ) assertContains( t, lines, "Password set", "password confirmation", ) } func TestPreRegistrationNotRegistered(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.send("PRIVMSG #test :hello") line := client.readLine() if !strings.Contains(line, " 451 ") { t.Errorf( "expected ERR_NOTREGISTERED (451), got: %s", line, ) } } func TestNamesNonExistentChannel(t *testing.T) { t.Parallel() env := newTestEnv(t) client := env.dial(t) client.register("namesuser") lines := client.sendAndExpect( "NAMES #doesnotexist", " 366 ", ) assertContains( t, lines, " 366 ", "RPL_ENDOFNAMES for non-existent channel", ) } func BenchmarkParseMessage(b *testing.B) { line := ":nick!user@host PRIVMSG #channel :Hello, world!" b.ResetTimer() for range b.N { _ = ircserver.ParseMessage(line) } } func BenchmarkFormatMessage(b *testing.B) { b.ResetTimer() for range b.N { _ = ircserver.FormatMessage( "nick!user@host", "PRIVMSG", "#channel", "Hello, world!", ) } }