diff --git a/go.mod b/go.mod index b1b1166..53d7d54 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,18 @@ go 1.24.0 require ( github.com/99designs/basicauth-go v0.0.0-20230316000542-bf6f9cbbf0f8 + github.com/gdamore/tcell/v2 v2.13.8 github.com/getsentry/sentry-go v0.42.0 github.com/go-chi/chi v1.5.5 github.com/go-chi/cors v1.2.2 + github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/prometheus/client_golang v1.23.2 + github.com/rivo/tview v0.42.0 github.com/slok/go-http-metrics v0.13.0 github.com/spf13/viper v1.21.0 go.uber.org/fx v1.24.0 + golang.org/x/crypto v0.48.0 modernc.org/sqlite v1.45.0 ) @@ -21,9 +25,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect - github.com/gdamore/tcell/v2 v2.13.8 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect @@ -33,7 +35,6 @@ require ( github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/tview v0.42.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect @@ -47,9 +48,9 @@ require ( go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/sys v0.38.0 // indirect - golang.org/x/term v0.37.0 // indirect - golang.org/x/text v0.31.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/term v0.40.0 // indirect + golang.org/x/text v0.34.0 // indirect google.golang.org/protobuf v1.36.8 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index 7d3b111..764fe68 100644 --- a/go.sum +++ b/go.sum @@ -113,12 +113,14 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= -golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -126,8 +128,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -135,30 +137,26 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= -golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= -golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= diff --git a/internal/db/auth.go b/internal/db/auth.go new file mode 100644 index 0000000..b27eed9 --- /dev/null +++ b/internal/db/auth.go @@ -0,0 +1,161 @@ +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" +) + +const bcryptCost = bcrypt.DefaultCost + +var errNoPassword = errors.New( + "account has no password set", +) + +// RegisterUser creates a session with a hashed password +// and returns session ID, client ID, and token. +func (database *Database) RegisterUser( + ctx context.Context, + nick, password string, +) (int64, int64, string, error) { + hash, err := bcrypt.GenerateFromPassword( + []byte(password), bcryptCost, + ) + if err != nil { + return 0, 0, "", fmt.Errorf( + "hash password: %w", err, + ) + } + + sessionUUID := uuid.New().String() + clientUUID := uuid.New().String() + + token, err := generateToken() + if err != nil { + return 0, 0, "", err + } + + now := time.Now() + + transaction, err := database.conn.BeginTx(ctx, nil) + if err != nil { + return 0, 0, "", fmt.Errorf( + "begin tx: %w", err, + ) + } + + res, err := transaction.ExecContext(ctx, + `INSERT INTO sessions + (uuid, nick, password_hash, + created_at, last_seen) + VALUES (?, ?, ?, ?, ?)`, + sessionUUID, nick, string(hash), now, now) + if err != nil { + _ = transaction.Rollback() + + return 0, 0, "", fmt.Errorf( + "create session: %w", err, + ) + } + + sessionID, _ := res.LastInsertId() + + clientRes, err := transaction.ExecContext(ctx, + `INSERT INTO clients + (uuid, session_id, token, + created_at, last_seen) + VALUES (?, ?, ?, ?, ?)`, + clientUUID, sessionID, token, now, now) + if err != nil { + _ = transaction.Rollback() + + return 0, 0, "", fmt.Errorf( + "create client: %w", err, + ) + } + + clientID, _ := clientRes.LastInsertId() + + err = transaction.Commit() + if err != nil { + return 0, 0, "", fmt.Errorf( + "commit registration: %w", err, + ) + } + + return sessionID, clientID, token, nil +} + +// LoginUser verifies a nick/password and creates a new +// client token. +func (database *Database) LoginUser( + ctx context.Context, + nick, password string, +) (int64, int64, string, error) { + var ( + sessionID int64 + passwordHash string + ) + + err := database.conn.QueryRowContext( + ctx, + `SELECT id, password_hash + FROM sessions WHERE nick = ?`, + nick, + ).Scan(&sessionID, &passwordHash) + if err != nil { + return 0, 0, "", fmt.Errorf( + "get session for login: %w", err, + ) + } + + if passwordHash == "" { + return 0, 0, "", fmt.Errorf( + "login: %w", errNoPassword, + ) + } + + err = bcrypt.CompareHashAndPassword( + []byte(passwordHash), []byte(password), + ) + if err != nil { + return 0, 0, "", fmt.Errorf( + "verify password: %w", err, + ) + } + + clientUUID := uuid.New().String() + + token, err := generateToken() + if err != nil { + return 0, 0, "", err + } + + now := time.Now() + + res, err := database.conn.ExecContext(ctx, + `INSERT INTO clients + (uuid, session_id, token, + created_at, last_seen) + VALUES (?, ?, ?, ?, ?)`, + clientUUID, sessionID, token, now, now) + if err != nil { + return 0, 0, "", fmt.Errorf( + "create login client: %w", err, + ) + } + + clientID, _ := res.LastInsertId() + + _, _ = database.conn.ExecContext( + ctx, + "UPDATE sessions SET last_seen = ? WHERE id = ?", + now, sessionID, + ) + + return sessionID, clientID, token, nil +} diff --git a/internal/db/auth_test.go b/internal/db/auth_test.go new file mode 100644 index 0000000..5188925 --- /dev/null +++ b/internal/db/auth_test.go @@ -0,0 +1,178 @@ +package db_test + +import ( + "testing" + + _ "modernc.org/sqlite" +) + +func TestRegisterUser(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + sessionID, clientID, token, err := + database.RegisterUser(ctx, "reguser", "password123") + if err != nil { + t.Fatal(err) + } + + if sessionID == 0 || clientID == 0 || token == "" { + t.Fatal("expected valid ids and token") + } + + // Verify session works via token lookup. + sid, cid, nick, err := + database.GetSessionByToken(ctx, token) + if err != nil { + t.Fatal(err) + } + + if sid != sessionID || cid != clientID { + t.Fatal("session/client id mismatch") + } + + if nick != "reguser" { + t.Fatalf("expected reguser, got %s", nick) + } +} + +func TestRegisterUserDuplicateNick(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + regSID, regCID, regToken, err := + database.RegisterUser(ctx, "dupnick", "password123") + if err != nil { + t.Fatal(err) + } + + _ = regSID + _ = regCID + _ = regToken + + dupSID, dupCID, dupToken, dupErr := + database.RegisterUser(ctx, "dupnick", "other12345") + if dupErr == nil { + t.Fatal("expected error for duplicate nick") + } + + _ = dupSID + _ = dupCID + _ = dupToken +} + +func TestLoginUser(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + regSID, regCID, regToken, err := + database.RegisterUser(ctx, "loginuser", "mypassword") + if err != nil { + t.Fatal(err) + } + + _ = regSID + _ = regCID + _ = regToken + + sessionID, clientID, token, err := + database.LoginUser(ctx, "loginuser", "mypassword") + if err != nil { + t.Fatal(err) + } + + if sessionID == 0 || clientID == 0 || token == "" { + t.Fatal("expected valid ids and token") + } + + // Verify the new token works. + _, _, nick, err := + database.GetSessionByToken(ctx, token) + if err != nil { + t.Fatal(err) + } + + if nick != "loginuser" { + t.Fatalf("expected loginuser, got %s", nick) + } +} + +func TestLoginUserWrongPassword(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + regSID, regCID, regToken, err := + database.RegisterUser(ctx, "wrongpw", "correctpass") + if err != nil { + t.Fatal(err) + } + + _ = regSID + _ = regCID + _ = regToken + + loginSID, loginCID, loginToken, loginErr := + database.LoginUser(ctx, "wrongpw", "wrongpass12") + if loginErr == nil { + t.Fatal("expected error for wrong password") + } + + _ = loginSID + _ = loginCID + _ = loginToken +} + +func TestLoginUserNoPassword(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + // Create anonymous session (no password). + anonSID, anonCID, anonToken, err := + database.CreateSession(ctx, "anon") + if err != nil { + t.Fatal(err) + } + + _ = anonSID + _ = anonCID + _ = anonToken + + loginSID, loginCID, loginToken, loginErr := + database.LoginUser(ctx, "anon", "anything1") + if loginErr == nil { + t.Fatal( + "expected error for login on passwordless account", + ) + } + + _ = loginSID + _ = loginCID + _ = loginToken +} + +func TestLoginUserNonexistent(t *testing.T) { + t.Parallel() + + database := setupTestDB(t) + ctx := t.Context() + + loginSID, loginCID, loginToken, err := + database.LoginUser(ctx, "ghost", "password123") + if err == nil { + t.Fatal("expected error for nonexistent user") + } + + _ = loginSID + _ = loginCID + _ = loginToken +} diff --git a/internal/db/schema/001_initial.sql b/internal/db/schema/001_initial.sql index 14ebcf5..67ccfa6 100644 --- a/internal/db/schema/001_initial.sql +++ b/internal/db/schema/001_initial.sql @@ -1,11 +1,12 @@ -- Chat server schema (pre-1.0 consolidated) PRAGMA foreign_keys = ON; --- Sessions: IRC-style sessions (no passwords, nick + optional signing key) +-- Sessions: each session is a user identity (nick + optional password + signing key) CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, uuid TEXT NOT NULL UNIQUE, nick TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL DEFAULT '', signing_key TEXT NOT NULL DEFAULT '', created_at DATETIME DEFAULT CURRENT_TIMESTAMP, last_seen DATETIME DEFAULT CURRENT_TIMESTAMP diff --git a/internal/handlers/api_test.go b/internal/handlers/api_test.go index fdd4e75..1a5ce9e 100644 --- a/internal/handlers/api_test.go +++ b/internal/handlers/api_test.go @@ -1469,6 +1469,310 @@ func TestHealthcheck(t *testing.T) { } } +func TestRegisterValid(t *testing.T) { + tserver := newTestServer(t) + + body, err := json.Marshal(map[string]string{ + "nick": "reguser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf( + "expected 201, got %d: %s", + resp.StatusCode, respBody, + ) + } + + var result map[string]any + + _ = json.NewDecoder(resp.Body).Decode(&result) + + if result["token"] == nil || result["token"] == "" { + t.Fatal("expected token in response") + } + + if result["nick"] != "reguser" { + t.Fatalf( + "expected reguser, got %v", result["nick"], + ) + } +} + +func TestRegisterDuplicate(t *testing.T) { + tserver := newTestServer(t) + + body, err := json.Marshal(map[string]string{ + "nick": "dupuser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + _ = resp.Body.Close() + + resp2, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp2.Body.Close() }() + + if resp2.StatusCode != http.StatusConflict { + t.Fatalf("expected 409, got %d", resp2.StatusCode) + } +} + +func postJSONExpectStatus( + t *testing.T, + tserver *testServer, + path string, + payload map[string]string, + expectedStatus int, +) { + t.Helper() + + body, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url(path), + bytes.NewReader(body), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != expectedStatus { + t.Fatalf( + "expected %d, got %d", + expectedStatus, resp.StatusCode, + ) + } +} + +func TestRegisterShortPassword(t *testing.T) { + tserver := newTestServer(t) + + postJSONExpectStatus( + t, tserver, "/api/v1/register", + map[string]string{ + "nick": "shortpw", "password": "short", + }, + http.StatusBadRequest, + ) +} + +func TestRegisterInvalidNick(t *testing.T) { + tserver := newTestServer(t) + + postJSONExpectStatus( + t, tserver, "/api/v1/register", + map[string]string{ + "nick": "bad nick!", + "password": "password123", + }, + http.StatusBadRequest, + ) +} + +func TestLoginValid(t *testing.T) { + tserver := newTestServer(t) + + // Register first. + regBody, err := json.Marshal(map[string]string{ + "nick": "loginuser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(regBody), + ) + if err != nil { + t.Fatal(err) + } + + _ = resp.Body.Close() + + // Login. + loginBody, err := json.Marshal(map[string]string{ + "nick": "loginuser", "password": "password123", + }) + if err != nil { + t.Fatal(err) + } + + resp2, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/login"), + bytes.NewReader(loginBody), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp2.Body.Close() }() + + if resp2.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp2.Body) + t.Fatalf( + "expected 200, got %d: %s", + resp2.StatusCode, respBody, + ) + } + + var result map[string]any + + _ = json.NewDecoder(resp2.Body).Decode(&result) + + if result["token"] == nil || result["token"] == "" { + t.Fatal("expected token in response") + } + + // Verify token works. + token, ok := result["token"].(string) + if !ok { + t.Fatal("token not a string") + } + + status, state := tserver.getState(token) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) + } + + if state["nick"] != "loginuser" { + t.Fatalf( + "expected loginuser, got %v", + state["nick"], + ) + } +} + +func TestLoginWrongPassword(t *testing.T) { + tserver := newTestServer(t) + + regBody, err := json.Marshal(map[string]string{ + "nick": "wrongpwuser", "password": "correctpass1", + }) + if err != nil { + t.Fatal(err) + } + + resp, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/register"), + bytes.NewReader(regBody), + ) + if err != nil { + t.Fatal(err) + } + + _ = resp.Body.Close() + + loginBody, err := json.Marshal(map[string]string{ + "nick": "wrongpwuser", "password": "wrongpass12", + }) + if err != nil { + t.Fatal(err) + } + + resp2, err := doRequest( + t, + http.MethodPost, + tserver.url("/api/v1/login"), + bytes.NewReader(loginBody), + ) + if err != nil { + t.Fatal(err) + } + + defer func() { _ = resp2.Body.Close() }() + + if resp2.StatusCode != http.StatusUnauthorized { + t.Fatalf( + "expected 401, got %d", resp2.StatusCode, + ) + } +} + +func TestLoginNonexistentUser(t *testing.T) { + tserver := newTestServer(t) + + postJSONExpectStatus( + t, tserver, "/api/v1/login", + map[string]string{ + "nick": "ghostuser", + "password": "password123", + }, + http.StatusUnauthorized, + ) +} + +func TestSessionStillWorks(t *testing.T) { + tserver := newTestServer(t) + + // Verify anonymous session creation still works. + token := tserver.createSession("anon_user") + if token == "" { + t.Fatal("expected token for anonymous session") + } + + status, state := tserver.getState(token) + if status != http.StatusOK { + t.Fatalf("expected 200, got %d", status) + } + + if state["nick"] != "anon_user" { + t.Fatalf( + "expected anon_user, got %v", + state["nick"], + ) + } +} + func TestNickBroadcastToChannels(t *testing.T) { tserver := newTestServer(t) aliceToken := tserver.createSession("nick_a") diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go new file mode 100644 index 0000000..bc44866 --- /dev/null +++ b/internal/handlers/auth.go @@ -0,0 +1,186 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "strings" +) + +const minPasswordLength = 8 + +// HandleRegister creates a new user with a password. +func (hdlr *Handlers) HandleRegister() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + request.Body = http.MaxBytesReader( + writer, request.Body, hdlr.maxBodySize(), + ) + + hdlr.handleRegister(writer, request) + } +} + +func (hdlr *Handlers) handleRegister( + writer http.ResponseWriter, + request *http.Request, +) { + type registerRequest struct { + Nick string `json:"nick"` + Password string `json:"password"` + } + + var payload registerRequest + + err := json.NewDecoder(request.Body).Decode(&payload) + if err != nil { + hdlr.respondError( + writer, request, + "invalid request body", + http.StatusBadRequest, + ) + + return + } + + payload.Nick = strings.TrimSpace(payload.Nick) + + if !validNickRe.MatchString(payload.Nick) { + hdlr.respondError( + writer, request, + "invalid nick format", + http.StatusBadRequest, + ) + + return + } + + if len(payload.Password) < minPasswordLength { + hdlr.respondError( + writer, request, + "password must be at least 8 characters", + http.StatusBadRequest, + ) + + return + } + + sessionID, clientID, token, err := + hdlr.params.Database.RegisterUser( + request.Context(), + payload.Nick, + payload.Password, + ) + if err != nil { + hdlr.handleRegisterError( + writer, request, err, + ) + + return + } + + hdlr.deliverMOTD(request, clientID, sessionID) + + hdlr.respondJSON(writer, request, map[string]any{ + "id": sessionID, + "nick": payload.Nick, + "token": token, + }, http.StatusCreated) +} + +func (hdlr *Handlers) handleRegisterError( + writer http.ResponseWriter, + request *http.Request, + err error, +) { + if strings.Contains(err.Error(), "UNIQUE") { + hdlr.respondError( + writer, request, + "nick already taken", + http.StatusConflict, + ) + + return + } + + hdlr.log.Error( + "register user failed", "error", err, + ) + hdlr.respondError( + writer, request, + "internal error", + http.StatusInternalServerError, + ) +} + +// HandleLogin authenticates a user with nick and password. +func (hdlr *Handlers) HandleLogin() http.HandlerFunc { + return func( + writer http.ResponseWriter, + request *http.Request, + ) { + request.Body = http.MaxBytesReader( + writer, request.Body, hdlr.maxBodySize(), + ) + + hdlr.handleLogin(writer, request) + } +} + +func (hdlr *Handlers) handleLogin( + writer http.ResponseWriter, + request *http.Request, +) { + type loginRequest struct { + Nick string `json:"nick"` + Password string `json:"password"` + } + + var payload loginRequest + + err := json.NewDecoder(request.Body).Decode(&payload) + if err != nil { + hdlr.respondError( + writer, request, + "invalid request body", + http.StatusBadRequest, + ) + + return + } + + payload.Nick = strings.TrimSpace(payload.Nick) + + if payload.Nick == "" || payload.Password == "" { + hdlr.respondError( + writer, request, + "nick and password required", + http.StatusBadRequest, + ) + + return + } + + sessionID, _, token, err := + hdlr.params.Database.LoginUser( + request.Context(), + payload.Nick, + payload.Password, + ) + if err != nil { + hdlr.respondError( + writer, request, + "invalid credentials", + http.StatusUnauthorized, + ) + + return + } + + hdlr.respondJSON(writer, request, map[string]any{ + "id": sessionID, + "nick": payload.Nick, + "token": token, + }, http.StatusOK) +} diff --git a/internal/server/routes.go b/internal/server/routes.go index ba49ad9..e7b632e 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -59,48 +59,55 @@ func (srv *Server) SetupRoutes() { } // API v1. - srv.router.Route( - "/api/v1", - func(router chi.Router) { - router.Get( - "/server", - srv.handlers.HandleServerInfo(), - ) - router.Post( - "/session", - srv.handlers.HandleCreateSession(), - ) - router.Get( - "/state", - srv.handlers.HandleState(), - ) - router.Get( - "/messages", - srv.handlers.HandleGetMessages(), - ) - router.Post( - "/messages", - srv.handlers.HandleSendCommand(), - ) - router.Get( - "/history", - srv.handlers.HandleGetHistory(), - ) - router.Get( - "/channels", - srv.handlers.HandleListAllChannels(), - ) - router.Get( - "/channels/{channel}/members", - srv.handlers.HandleChannelMembers(), - ) - }, - ) + srv.router.Route("/api/v1", srv.setupAPIv1) // Serve embedded SPA. srv.setupSPA() } +func (srv *Server) setupAPIv1(router chi.Router) { + router.Get( + "/server", + srv.handlers.HandleServerInfo(), + ) + router.Post( + "/session", + srv.handlers.HandleCreateSession(), + ) + router.Post( + "/register", + srv.handlers.HandleRegister(), + ) + router.Post( + "/login", + srv.handlers.HandleLogin(), + ) + router.Get( + "/state", + srv.handlers.HandleState(), + ) + router.Get( + "/messages", + srv.handlers.HandleGetMessages(), + ) + router.Post( + "/messages", + srv.handlers.HandleSendCommand(), + ) + router.Get( + "/history", + srv.handlers.HandleGetHistory(), + ) + router.Get( + "/channels", + srv.handlers.HandleListAllChannels(), + ) + router.Get( + "/channels/{channel}/members", + srv.handlers.HandleChannelMembers(), + ) +} + func (srv *Server) setupSPA() { distFS, err := fs.Sub(web.Dist, "dist") if err != nil {