package db import ( "context" "errors" "fmt" "time" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" ) const bcryptCost = bcrypt.DefaultCost var errNoPassword = errors.New( "account has no password set", ) // RegisterUser creates a session with a hashed password // and returns session ID, client ID, and token. func (database *Database) RegisterUser( ctx context.Context, nick, password string, ) (int64, int64, string, error) { hash, err := bcrypt.GenerateFromPassword( []byte(password), bcryptCost, ) if err != nil { return 0, 0, "", fmt.Errorf( "hash password: %w", err, ) } sessionUUID := uuid.New().String() clientUUID := uuid.New().String() token, err := generateToken() if err != nil { return 0, 0, "", err } now := time.Now() transaction, err := database.conn.BeginTx(ctx, nil) if err != nil { return 0, 0, "", fmt.Errorf( "begin tx: %w", err, ) } res, err := transaction.ExecContext(ctx, `INSERT INTO sessions (uuid, nick, password_hash, created_at, last_seen) VALUES (?, ?, ?, ?, ?)`, sessionUUID, nick, string(hash), now, now) if err != nil { _ = transaction.Rollback() return 0, 0, "", fmt.Errorf( "create session: %w", err, ) } sessionID, _ := res.LastInsertId() clientRes, err := transaction.ExecContext(ctx, `INSERT INTO clients (uuid, session_id, token, created_at, last_seen) VALUES (?, ?, ?, ?, ?)`, clientUUID, sessionID, token, now, now) if err != nil { _ = transaction.Rollback() return 0, 0, "", fmt.Errorf( "create client: %w", err, ) } clientID, _ := clientRes.LastInsertId() err = transaction.Commit() if err != nil { return 0, 0, "", fmt.Errorf( "commit registration: %w", err, ) } return sessionID, clientID, token, nil } // LoginUser verifies a nick/password and creates a new // client token. func (database *Database) LoginUser( ctx context.Context, nick, password string, ) (int64, int64, string, error) { var ( sessionID int64 passwordHash string ) err := database.conn.QueryRowContext( ctx, `SELECT id, password_hash FROM sessions WHERE nick = ?`, nick, ).Scan(&sessionID, &passwordHash) if err != nil { return 0, 0, "", fmt.Errorf( "get session for login: %w", err, ) } if passwordHash == "" { return 0, 0, "", fmt.Errorf( "login: %w", errNoPassword, ) } err = bcrypt.CompareHashAndPassword( []byte(passwordHash), []byte(password), ) if err != nil { return 0, 0, "", fmt.Errorf( "verify password: %w", err, ) } clientUUID := uuid.New().String() token, err := generateToken() if err != nil { return 0, 0, "", err } now := time.Now() res, err := database.conn.ExecContext(ctx, `INSERT INTO clients (uuid, session_id, token, created_at, last_seen) VALUES (?, ?, ?, ?, ?)`, clientUUID, sessionID, token, now, now) if err != nil { return 0, 0, "", fmt.Errorf( "create login client: %w", err, ) } clientID, _ := res.LastInsertId() _, _ = database.conn.ExecContext( ctx, "UPDATE sessions SET last_seen = ? WHERE id = ?", now, sessionID, ) return sessionID, clientID, token, nil }