From 7047167dc88ecb612d4bb07a193396573146ed6e Mon Sep 17 00:00:00 2001 From: clawbot Date: Fri, 27 Feb 2026 04:56:51 -0800 Subject: [PATCH] Add tests for register and login endpoints --- internal/db/auth_test.go | 178 ++++++++++++++++++++ internal/handlers/api_test.go | 304 ++++++++++++++++++++++++++++++++++ internal/handlers/auth.go | 4 +- internal/server/routes.go | 89 +++++----- 4 files changed, 528 insertions(+), 47 deletions(-) create mode 100644 internal/db/auth_test.go diff --git a/internal/db/auth_test.go b/internal/db/auth_test.go new file mode 100644 index 0000000..5188925 --- /dev/null +++ b/internal/db/auth_test.go @@ -0,0 +1,178 @@ +package db_test + +import ( + "testing" + + _ "modernc.org/sqlite" +) + +func TestRegisterUser(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sessionID, clientID, token, err := + database.RegisterUser(ctx, "reguser", "password123") + if err != nil { + t.Fatal(err) + } + + if sessionID == 0 || clientID == 0 || token == "" { + t.Fatal("expected valid ids and token") + } + + // Verify session works via token lookup. + sid, cid, nick, err := + database.GetSessionByToken(ctx, token) + if err != nil { + t.Fatal(err) + } + + if sid != sessionID || cid != clientID { + t.Fatal("session/client id mismatch") + } + + if nick != "reguser" { + t.Fatalf("expected reguser, got %s", nick) + } +} + +func TestRegisterUserDuplicateNick(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + regSID, regCID, regToken, err := + database.RegisterUser(ctx, "dupnick", "password123") + if err != nil { + t.Fatal(err) + } + + _ = regSID + _ = regCID + _ = regToken + + dupSID, dupCID, dupToken, dupErr := + database.RegisterUser(ctx, "dupnick", "other12345") + if dupErr == nil { + t.Fatal("expected error for duplicate nick") + } + + _ = dupSID + _ = dupCID + _ = dupToken +} + +func TestLoginUser(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + regSID, regCID, regToken, err := + database.RegisterUser(ctx, "loginuser", "mypassword") + if err != nil { + t.Fatal(err) + } + + _ = regSID + _ = regCID + _ = regToken + + sessionID, clientID, token, err := + database.LoginUser(ctx, "loginuser", "mypassword") + if err != nil { + t.Fatal(err) + } + + if sessionID == 0 || clientID == 0 || token == "" { + t.Fatal("expected valid ids and token") + } + + // Verify the new token works. + _, _, nick, err := + database.GetSessionByToken(ctx, token) + if err != nil { + t.Fatal(err) + } + + if nick != "loginuser" { + t.Fatalf("expected loginuser, got %s", nick) + } +} + +func TestLoginUserWrongPassword(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + regSID, regCID, regToken, err := + database.RegisterUser(ctx, "wrongpw", "correctpass") + if err != nil { + t.Fatal(err) + } + + _ = regSID + _ = regCID + _ = regToken + + loginSID, loginCID, loginToken, loginErr := + database.LoginUser(ctx, "wrongpw", "wrongpass12") + if loginErr == nil { + t.Fatal("expected error for wrong password") + } + + _ = loginSID + _ = loginCID + _ = loginToken +} + +func TestLoginUserNoPassword(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + // Create anonymous session (no password). + anonSID, anonCID, anonToken, err := + database.CreateSession(ctx, "anon") + if err != nil { + t.Fatal(err) + } + + _ = anonSID + _ = anonCID + _ = anonToken + + loginSID, loginCID, loginToken, loginErr := + database.LoginUser(ctx, "anon", "anything1") + if loginErr == nil { + t.Fatal( + "expected error for login on passwordless account", + ) + } + + _ = loginSID + _ = loginCID + _ = loginToken +} + +func TestLoginUserNonexistent(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + loginSID, loginCID, loginToken, err := + database.LoginUser(ctx, "ghost", "password123") + if err == nil { + t.Fatal("expected error for nonexistent user") + } + + _ = loginSID + _ = loginCID + _ = loginToken +} diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index fdd4e75..1a5ce9e 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -1469,6 +1469,310 @@ func TestHealthcheck(t *testing.T) { } } +func TestRegisterValid(t *testing.T) { + tserver := newTestServer(t) + + body, err := json.Marshal(map[string]string{ + "nick": "reguser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf( + "expected 201, got %d: %s", + resp.StatusCode, respBody, + ) + } + + var result map[string]any + + _ = json.NewDecoder(resp.Body).Decode(&result) + + if result["token"] == nil || result["token"] == "" { + t.Fatal("expected token in response") + } + + if result["nick"] != "reguser" { + t.Fatalf( + "expected reguser, got %v", result["nick"], + ) + } +} + +func TestRegisterDuplicate(t *testing.T) { + tserver := newTestServer(t) + + body, err := json.Marshal(map[string]string{ + "nick": "dupuser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + _ = resp.Body.Close() + + resp2, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp2.Body.Close() }() + + if resp2.StatusCode != http.StatusConflict { + t.Fatalf("expected 409, got %d", resp2.StatusCode) + } +} + +func postJSONExpectStatus( + t *testing.T, + tserver *testServer, + path string, + payload map[string]string, + expectedStatus int, +) { + t.Helper() + + body, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url(path), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != expectedStatus { + t.Fatalf( + "expected %d, got %d", + expectedStatus, resp.StatusCode, + ) + } +} + +func TestRegisterShortPassword(t *testing.T) { + tserver := newTestServer(t) + + postJSONExpectStatus( + t, tserver, "/api/v1/register", + map[string]string{ + "nick": "shortpw", "password": "short", + }, + http.StatusBadRequest, + ) +} + +func TestRegisterInvalidNick(t *testing.T) { + tserver := newTestServer(t) + + postJSONExpectStatus( + t, tserver, "/api/v1/register", + map[string]string{ + "nick": "bad nick!", + "password": "password123", + }, + http.StatusBadRequest, + ) +} + +func TestLoginValid(t *testing.T) { + tserver := newTestServer(t) + + // Register first. + regBody, err := json.Marshal(map[string]string{ + "nick": "loginuser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(regBody), + ) + if err != nil { + t.Fatal(err) + } + + _ = resp.Body.Close() + + // Login. + loginBody, err := json.Marshal(map[string]string{ + "nick": "loginuser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp2, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/login"), + bytes.NewReader(loginBody), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp2.Body.Close() }() + + if resp2.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp2.Body) + t.Fatalf( + "expected 200, got %d: %s", + resp2.StatusCode, respBody, + ) + } + + var result map[string]any + + _ = json.NewDecoder(resp2.Body).Decode(&result) + + if result["token"] == nil || result["token"] == "" { + t.Fatal("expected token in response") + } + + // Verify token works. + token, ok := result["token"].(string) + if !ok { + t.Fatal("token not a string") + } + + status, state := tserver.getState(token) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) + } + + if state["nick"] != "loginuser" { + t.Fatalf( + "expected loginuser, got %v", + state["nick"], + ) + } +} + +func TestLoginWrongPassword(t *testing.T) { + tserver := newTestServer(t) + + regBody, err := json.Marshal(map[string]string{ + "nick": "wrongpwuser", "password": "correctpass1", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(regBody), + ) + if err != nil { + t.Fatal(err) + } + + _ = resp.Body.Close() + + loginBody, err := json.Marshal(map[string]string{ + "nick": "wrongpwuser", "password": "wrongpass12", + }) + if err != nil { + t.Fatal(err) + } + + resp2, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/login"), + bytes.NewReader(loginBody), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp2.Body.Close() }() + + if resp2.StatusCode != http.StatusUnauthorized { + t.Fatalf( + "expected 401, got %d", resp2.StatusCode, + ) + } +} + +func TestLoginNonexistentUser(t *testing.T) { + tserver := newTestServer(t) + + postJSONExpectStatus( + t, tserver, "/api/v1/login", + map[string]string{ + "nick": "ghostuser", + "password": "password123", + }, + http.StatusUnauthorized, + ) +} + +func TestSessionStillWorks(t *testing.T) { + tserver := newTestServer(t) + + // Verify anonymous session creation still works. + token := tserver.createSession("anon_user") + if token == "" { + t.Fatal("expected token for anonymous session") + } + + status, state := tserver.getState(token) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) + } + + if state["nick"] != "anon_user" { + t.Fatalf( + "expected anon_user, got %v", + state["nick"], + ) + } +} + func TestNickBroadcastToChannels(t *testing.T) { tserver := newTestServer(t) aliceToken := tserver.createSession("nick_a") diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index bc44866..e13fc3b 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -28,7 +28,7 @@ func (hdlr *Handlers) handleRegister( ) { type registerRequest struct { Nick string `json:"nick"` - Password string `json:"password"` + Password string `json:"password"` //nolint:gosec // not a hardcoded secret } var payload registerRequest @@ -134,7 +134,7 @@ func (hdlr *Handlers) handleLogin( ) { type loginRequest struct { Nick string `json:"nick"` - Password string `json:"password"` + Password string `json:"password"` //nolint:gosec // not a hardcoded secret } var payload loginRequest diff --git a/internal/server/routes.go b/internal/server/routes.go index ebc8fc4..e7b632e 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -59,56 +59,55 @@ func (srv *Server) SetupRoutes() { } // API v1. - srv.router.Route( - "/api/v1", - func(router chi.Router) { - router.Get( - "/server", - srv.handlers.HandleServerInfo(), - ) - router.Post( - "/session", - srv.handlers.HandleCreateSession(), - ) - router.Post( - "/register", - srv.handlers.HandleRegister(), - ) - router.Post( - "/login", - srv.handlers.HandleLogin(), - ) - router.Get( - "/state", - srv.handlers.HandleState(), - ) - router.Get( - "/messages", - srv.handlers.HandleGetMessages(), - ) - router.Post( - "/messages", - srv.handlers.HandleSendCommand(), - ) - router.Get( - "/history", - srv.handlers.HandleGetHistory(), - ) - router.Get( - "/channels", - srv.handlers.HandleListAllChannels(), - ) - router.Get( - "/channels/{channel}/members", - srv.handlers.HandleChannelMembers(), - ) - }, - ) + srv.router.Route("/api/v1", srv.setupAPIv1) // Serve embedded SPA. srv.setupSPA() } +func (srv *Server) setupAPIv1(router chi.Router) { + router.Get( + "/server", + srv.handlers.HandleServerInfo(), + ) + router.Post( + "/session", + srv.handlers.HandleCreateSession(), + ) + router.Post( + "/register", + srv.handlers.HandleRegister(), + ) + router.Post( + "/login", + srv.handlers.HandleLogin(), + ) + router.Get( + "/state", + srv.handlers.HandleState(), + ) + router.Get( + "/messages", + srv.handlers.HandleGetMessages(), + ) + router.Post( + "/messages", + srv.handlers.HandleSendCommand(), + ) + router.Get( + "/history", + srv.handlers.HandleGetHistory(), + ) + router.Get( + "/channels", + srv.handlers.HandleListAllChannels(), + ) + router.Get( + "/channels/{channel}/members", + srv.handlers.HandleChannelMembers(), + ) +} + func (srv *Server) setupSPA() { distFS, err := fs.Sub(web.Dist, "dist") if err != nil {