162 lines
3.1 KiB
Go
162 lines
3.1 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// bcryptCost is the bcrypt work factor passed to
// bcrypt.GenerateFromPassword when RegisterUser hashes a
// new password.
const bcryptCost = bcrypt.DefaultCost
|
|
|
|
// errNoPassword is the sentinel LoginUser wraps and returns
// when the session's stored password hash is empty.
var errNoPassword = errors.New("account has no password set")
|
|
|
|
// RegisterUser creates a session with a hashed password
|
|
// and returns session ID, client ID, and token.
|
|
func (database *Database) RegisterUser(
|
|
ctx context.Context,
|
|
nick, password string,
|
|
) (int64, int64, string, error) {
|
|
hash, err := bcrypt.GenerateFromPassword(
|
|
[]byte(password), bcryptCost,
|
|
)
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"hash password: %w", err,
|
|
)
|
|
}
|
|
|
|
sessionUUID := uuid.New().String()
|
|
clientUUID := uuid.New().String()
|
|
|
|
token, err := generateToken()
|
|
if err != nil {
|
|
return 0, 0, "", err
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
transaction, err := database.conn.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"begin tx: %w", err,
|
|
)
|
|
}
|
|
|
|
res, err := transaction.ExecContext(ctx,
|
|
`INSERT INTO sessions
|
|
(uuid, nick, password_hash,
|
|
created_at, last_seen)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
sessionUUID, nick, string(hash), now, now)
|
|
if err != nil {
|
|
_ = transaction.Rollback()
|
|
|
|
return 0, 0, "", fmt.Errorf(
|
|
"create session: %w", err,
|
|
)
|
|
}
|
|
|
|
sessionID, _ := res.LastInsertId()
|
|
|
|
clientRes, err := transaction.ExecContext(ctx,
|
|
`INSERT INTO clients
|
|
(uuid, session_id, token,
|
|
created_at, last_seen)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
clientUUID, sessionID, token, now, now)
|
|
if err != nil {
|
|
_ = transaction.Rollback()
|
|
|
|
return 0, 0, "", fmt.Errorf(
|
|
"create client: %w", err,
|
|
)
|
|
}
|
|
|
|
clientID, _ := clientRes.LastInsertId()
|
|
|
|
err = transaction.Commit()
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"commit registration: %w", err,
|
|
)
|
|
}
|
|
|
|
return sessionID, clientID, token, nil
|
|
}
|
|
|
|
// LoginUser verifies a nick/password and creates a new
|
|
// client token.
|
|
func (database *Database) LoginUser(
|
|
ctx context.Context,
|
|
nick, password string,
|
|
) (int64, int64, string, error) {
|
|
var (
|
|
sessionID int64
|
|
passwordHash string
|
|
)
|
|
|
|
err := database.conn.QueryRowContext(
|
|
ctx,
|
|
`SELECT id, password_hash
|
|
FROM sessions WHERE nick = ?`,
|
|
nick,
|
|
).Scan(&sessionID, &passwordHash)
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"get session for login: %w", err,
|
|
)
|
|
}
|
|
|
|
if passwordHash == "" {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"login: %w", errNoPassword,
|
|
)
|
|
}
|
|
|
|
err = bcrypt.CompareHashAndPassword(
|
|
[]byte(passwordHash), []byte(password),
|
|
)
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"verify password: %w", err,
|
|
)
|
|
}
|
|
|
|
clientUUID := uuid.New().String()
|
|
|
|
token, err := generateToken()
|
|
if err != nil {
|
|
return 0, 0, "", err
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
res, err := database.conn.ExecContext(ctx,
|
|
`INSERT INTO clients
|
|
(uuid, session_id, token,
|
|
created_at, last_seen)
|
|
VALUES (?, ?, ?, ?, ?)`,
|
|
clientUUID, sessionID, token, now, now)
|
|
if err != nil {
|
|
return 0, 0, "", fmt.Errorf(
|
|
"create login client: %w", err,
|
|
)
|
|
}
|
|
|
|
clientID, _ := res.LastInsertId()
|
|
|
|
_, _ = database.conn.ExecContext(
|
|
ctx,
|
|
"UPDATE sessions SET last_seen = ? WHERE id = ?",
|
|
now, sessionID,
|
|
)
|
|
|
|
return sessionID, clientID, token, nil
|
|
}
|