Add tests for register and login endpoints
Some checks failed
check / check (push) Failing after 1m39s
Some checks failed
check / check (push) Failing after 1m39s
This commit is contained in:
178
internal/db/auth_test.go
Normal file
178
internal/db/auth_test.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package db_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRegisterUser(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
sessionID, clientID, token, err :=
|
||||||
|
database.RegisterUser(ctx, "reguser", "password123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sessionID == 0 || clientID == 0 || token == "" {
|
||||||
|
t.Fatal("expected valid ids and token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session works via token lookup.
|
||||||
|
sid, cid, nick, err :=
|
||||||
|
database.GetSessionByToken(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sid != sessionID || cid != clientID {
|
||||||
|
t.Fatal("session/client id mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if nick != "reguser" {
|
||||||
|
t.Fatalf("expected reguser, got %s", nick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterUserDuplicateNick(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
regSID, regCID, regToken, err :=
|
||||||
|
database.RegisterUser(ctx, "dupnick", "password123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = regSID
|
||||||
|
_ = regCID
|
||||||
|
_ = regToken
|
||||||
|
|
||||||
|
dupSID, dupCID, dupToken, dupErr :=
|
||||||
|
database.RegisterUser(ctx, "dupnick", "other12345")
|
||||||
|
if dupErr == nil {
|
||||||
|
t.Fatal("expected error for duplicate nick")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = dupSID
|
||||||
|
_ = dupCID
|
||||||
|
_ = dupToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginUser(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
regSID, regCID, regToken, err :=
|
||||||
|
database.RegisterUser(ctx, "loginuser", "mypassword")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = regSID
|
||||||
|
_ = regCID
|
||||||
|
_ = regToken
|
||||||
|
|
||||||
|
sessionID, clientID, token, err :=
|
||||||
|
database.LoginUser(ctx, "loginuser", "mypassword")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sessionID == 0 || clientID == 0 || token == "" {
|
||||||
|
t.Fatal("expected valid ids and token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the new token works.
|
||||||
|
_, _, nick, err :=
|
||||||
|
database.GetSessionByToken(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nick != "loginuser" {
|
||||||
|
t.Fatalf("expected loginuser, got %s", nick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginUserWrongPassword(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
regSID, regCID, regToken, err :=
|
||||||
|
database.RegisterUser(ctx, "wrongpw", "correctpass")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = regSID
|
||||||
|
_ = regCID
|
||||||
|
_ = regToken
|
||||||
|
|
||||||
|
loginSID, loginCID, loginToken, loginErr :=
|
||||||
|
database.LoginUser(ctx, "wrongpw", "wrongpass12")
|
||||||
|
if loginErr == nil {
|
||||||
|
t.Fatal("expected error for wrong password")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = loginSID
|
||||||
|
_ = loginCID
|
||||||
|
_ = loginToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginUserNoPassword(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
// Create anonymous session (no password).
|
||||||
|
anonSID, anonCID, anonToken, err :=
|
||||||
|
database.CreateSession(ctx, "anon")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = anonSID
|
||||||
|
_ = anonCID
|
||||||
|
_ = anonToken
|
||||||
|
|
||||||
|
loginSID, loginCID, loginToken, loginErr :=
|
||||||
|
database.LoginUser(ctx, "anon", "anything1")
|
||||||
|
if loginErr == nil {
|
||||||
|
t.Fatal(
|
||||||
|
"expected error for login on passwordless account",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = loginSID
|
||||||
|
_ = loginCID
|
||||||
|
_ = loginToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginUserNonexistent(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
database := setupTestDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
|
||||||
|
loginSID, loginCID, loginToken, err :=
|
||||||
|
database.LoginUser(ctx, "ghost", "password123")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nonexistent user")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = loginSID
|
||||||
|
_ = loginCID
|
||||||
|
_ = loginToken
|
||||||
|
}
|
||||||
@@ -1469,6 +1469,310 @@ func TestHealthcheck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRegisterValid(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]string{
|
||||||
|
"nick": "reguser", "password": "password123",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url("/api/v1/register"),
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusCreated {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
t.Fatalf(
|
||||||
|
"expected 201, got %d: %s",
|
||||||
|
resp.StatusCode, respBody,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]any
|
||||||
|
|
||||||
|
_ = json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
if result["token"] == nil || result["token"] == "" {
|
||||||
|
t.Fatal("expected token in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["nick"] != "reguser" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected reguser, got %v", result["nick"],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterDuplicate(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]string{
|
||||||
|
"nick": "dupuser", "password": "password123",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url("/api/v1/register"),
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
resp2, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url("/api/v1/register"),
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = resp2.Body.Close() }()
|
||||||
|
|
||||||
|
if resp2.StatusCode != http.StatusConflict {
|
||||||
|
t.Fatalf("expected 409, got %d", resp2.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func postJSONExpectStatus(
|
||||||
|
t *testing.T,
|
||||||
|
tserver *testServer,
|
||||||
|
path string,
|
||||||
|
payload map[string]string,
|
||||||
|
expectedStatus int,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url(path),
|
||||||
|
bytes.NewReader(body),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != expectedStatus {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected %d, got %d",
|
||||||
|
expectedStatus, resp.StatusCode,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterShortPassword(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
postJSONExpectStatus(
|
||||||
|
t, tserver, "/api/v1/register",
|
||||||
|
map[string]string{
|
||||||
|
"nick": "shortpw", "password": "short",
|
||||||
|
},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInvalidNick(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
postJSONExpectStatus(
|
||||||
|
t, tserver, "/api/v1/register",
|
||||||
|
map[string]string{
|
||||||
|
"nick": "bad nick!",
|
||||||
|
"password": "password123",
|
||||||
|
},
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginValid(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
// Register first.
|
||||||
|
regBody, err := json.Marshal(map[string]string{
|
||||||
|
"nick": "loginuser", "password": "password123",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url("/api/v1/register"),
|
||||||
|
bytes.NewReader(regBody),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
// Login.
|
||||||
|
loginBody, err := json.Marshal(map[string]string{
|
||||||
|
"nick": "loginuser", "password": "password123",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp2, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url("/api/v1/login"),
|
||||||
|
bytes.NewReader(loginBody),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = resp2.Body.Close() }()
|
||||||
|
|
||||||
|
if resp2.StatusCode != http.StatusOK {
|
||||||
|
respBody, _ := io.ReadAll(resp2.Body)
|
||||||
|
t.Fatalf(
|
||||||
|
"expected 200, got %d: %s",
|
||||||
|
resp2.StatusCode, respBody,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]any
|
||||||
|
|
||||||
|
_ = json.NewDecoder(resp2.Body).Decode(&result)
|
||||||
|
|
||||||
|
if result["token"] == nil || result["token"] == "" {
|
||||||
|
t.Fatal("expected token in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify token works.
|
||||||
|
token, ok := result["token"].(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("token not a string")
|
||||||
|
}
|
||||||
|
|
||||||
|
status, state := tserver.getState(token)
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state["nick"] != "loginuser" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected loginuser, got %v",
|
||||||
|
state["nick"],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginWrongPassword(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
regBody, err := json.Marshal(map[string]string{
|
||||||
|
"nick": "wrongpwuser", "password": "correctpass1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url("/api/v1/register"),
|
||||||
|
bytes.NewReader(regBody),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
loginBody, err := json.Marshal(map[string]string{
|
||||||
|
"nick": "wrongpwuser", "password": "wrongpass12",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp2, err := doRequest(
|
||||||
|
t,
|
||||||
|
http.MethodPost,
|
||||||
|
tserver.url("/api/v1/login"),
|
||||||
|
bytes.NewReader(loginBody),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = resp2.Body.Close() }()
|
||||||
|
|
||||||
|
if resp2.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected 401, got %d", resp2.StatusCode,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginNonexistentUser(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
postJSONExpectStatus(
|
||||||
|
t, tserver, "/api/v1/login",
|
||||||
|
map[string]string{
|
||||||
|
"nick": "ghostuser",
|
||||||
|
"password": "password123",
|
||||||
|
},
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStillWorks(t *testing.T) {
|
||||||
|
tserver := newTestServer(t)
|
||||||
|
|
||||||
|
// Verify anonymous session creation still works.
|
||||||
|
token := tserver.createSession("anon_user")
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("expected token for anonymous session")
|
||||||
|
}
|
||||||
|
|
||||||
|
status, state := tserver.getState(token)
|
||||||
|
if status != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state["nick"] != "anon_user" {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected anon_user, got %v",
|
||||||
|
state["nick"],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNickBroadcastToChannels(t *testing.T) {
|
func TestNickBroadcastToChannels(t *testing.T) {
|
||||||
tserver := newTestServer(t)
|
tserver := newTestServer(t)
|
||||||
aliceToken := tserver.createSession("nick_a")
|
aliceToken := tserver.createSession("nick_a")
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func (hdlr *Handlers) handleRegister(
|
|||||||
) {
|
) {
|
||||||
type registerRequest struct {
|
type registerRequest struct {
|
||||||
Nick string `json:"nick"`
|
Nick string `json:"nick"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"` //nolint:gosec // not a hardcoded secret
|
||||||
}
|
}
|
||||||
|
|
||||||
var payload registerRequest
|
var payload registerRequest
|
||||||
@@ -134,7 +134,7 @@ func (hdlr *Handlers) handleLogin(
|
|||||||
) {
|
) {
|
||||||
type loginRequest struct {
|
type loginRequest struct {
|
||||||
Nick string `json:"nick"`
|
Nick string `json:"nick"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"` //nolint:gosec // not a hardcoded secret
|
||||||
}
|
}
|
||||||
|
|
||||||
var payload loginRequest
|
var payload loginRequest
|
||||||
|
|||||||
@@ -59,9 +59,13 @@ func (srv *Server) SetupRoutes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// API v1.
|
// API v1.
|
||||||
srv.router.Route(
|
srv.router.Route("/api/v1", srv.setupAPIv1)
|
||||||
"/api/v1",
|
|
||||||
func(router chi.Router) {
|
// Serve embedded SPA.
|
||||||
|
srv.setupSPA()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) setupAPIv1(router chi.Router) {
|
||||||
router.Get(
|
router.Get(
|
||||||
"/server",
|
"/server",
|
||||||
srv.handlers.HandleServerInfo(),
|
srv.handlers.HandleServerInfo(),
|
||||||
@@ -102,11 +106,6 @@ func (srv *Server) SetupRoutes() {
|
|||||||
"/channels/{channel}/members",
|
"/channels/{channel}/members",
|
||||||
srv.handlers.HandleChannelMembers(),
|
srv.handlers.HandleChannelMembers(),
|
||||||
)
|
)
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
// Serve embedded SPA.
|
|
||||||
srv.setupSPA()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) setupSPA() {
|
func (srv *Server) setupSPA() {
|
||||||
|
|||||||
Reference in New Issue
Block a user