forked from sneak/upaas
tryBearerAuth validated the bearer token but never looked up the associated user or attached it to the request context, so downstream handlers calling GetCurrentUser got nil even with a valid token.

Changes:
- Add ContextWithUser/UserFromContext helpers to the auth package.
- tryBearerAuth now looks up the user by the token's UserID and attaches it to the request context via auth.ContextWithUser.
- GetCurrentUser checks the context first before falling back to the session cookie.
- Add integration tests for bearer-auth user context.
300 lines
7.2 KiB
Go
300 lines
7.2 KiB
Go
// Package auth provides authentication services.
|
|
package auth
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"

	"github.com/gorilla/sessions"
	"go.uber.org/fx"
	"golang.org/x/crypto/argon2"

	"git.eeqj.de/sneak/upaas/internal/config"
	"git.eeqj.de/sneak/upaas/internal/database"
	"git.eeqj.de/sneak/upaas/internal/logger"
	"git.eeqj.de/sneak/upaas/internal/models"
)
|
|
|
|
// sessionName is the name under which the session is fetched from the
// cookie store (and therefore the name of the session cookie).
const sessionName = "upaas_session"

// sessionUserID is the session.Values key holding the logged-in user's ID.
const sessionUserID = "user_id"
|
|
|
|
// contextKeyUser is the context key for storing the authenticated user.
|
|
type contextKeyUser struct{}
|
|
|
|
// ContextWithUser returns a new context with the given user attached.
|
|
func ContextWithUser(ctx context.Context, user *models.User) context.Context {
|
|
return context.WithValue(ctx, contextKeyUser{}, user)
|
|
}
|
|
|
|
// UserFromContext retrieves the user from the context, if set.
|
|
func UserFromContext(ctx context.Context) *models.User {
|
|
user, _ := ctx.Value(contextKeyUser{}).(*models.User)
|
|
|
|
return user
|
|
}
|
|
|
|
// Argon2id cost parameters shared by HashPassword and VerifyPassword.
//
// Stored hashes record only the salt and the derived key, not these settings,
// so changing argonTime, argonMemory, argonThreads or argonKeyLen makes every
// existing hash fail verification. saltLen only affects newly created hashes.
const (
	argonTime    = 1        // number of passes over memory
	argonMemory  = 64 << 10 // memory cost in KiB, i.e. 64 MiB
	argonThreads = 4        // degree of parallelism
	argonKeyLen  = 32       // derived key length in bytes
	saltLen      = 16       // random salt length in bytes
)
|
|
|
|
// sessionMaxAgeDays is the session cookie lifetime, in days.
const sessionMaxAgeDays = 7

// sessionMaxAgeSeconds is sessionMaxAgeDays expressed in seconds, the unit
// used by sessions.Options.MaxAge.
const sessionMaxAgeSeconds = sessionMaxAgeDays * 24 * 60 * 60
|
|
|
|
// ErrInvalidCredentials is returned by Authenticate when the username is
// unknown or the password does not match.
var ErrInvalidCredentials = errors.New("invalid credentials")

// ErrUserExists is returned by CreateUser when models.CreateFirstUser creates
// no user because one already exists.
var ErrUserExists = errors.New("user already exists")
|
|
|
|
// ServiceParams contains dependencies for Service.
|
|
type ServiceParams struct {
|
|
fx.In
|
|
|
|
Logger *logger.Logger
|
|
Config *config.Config
|
|
Database *database.Database
|
|
}
|
|
|
|
// Service provides authentication functionality.
|
|
type Service struct {
|
|
log *slog.Logger
|
|
db *database.Database
|
|
store *sessions.CookieStore
|
|
params *ServiceParams
|
|
}
|
|
|
|
// New creates a new auth Service.
|
|
func New(_ fx.Lifecycle, params ServiceParams) (*Service, error) {
|
|
store := sessions.NewCookieStore([]byte(params.Config.SessionSecret))
|
|
store.Options = &sessions.Options{
|
|
Path: "/",
|
|
MaxAge: sessionMaxAgeSeconds,
|
|
HttpOnly: true,
|
|
Secure: !params.Config.Debug,
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
|
|
return &Service{
|
|
log: params.Logger.Get(),
|
|
db: params.Database,
|
|
store: store,
|
|
params: ¶ms,
|
|
}, nil
|
|
}
|
|
|
|
// HashPassword hashes a password using Argon2id.
|
|
func (svc *Service) HashPassword(password string) (string, error) {
|
|
salt := make([]byte, saltLen)
|
|
|
|
_, err := rand.Read(salt)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to generate salt: %w", err)
|
|
}
|
|
|
|
hash := argon2.IDKey(
|
|
[]byte(password),
|
|
salt,
|
|
argonTime,
|
|
argonMemory,
|
|
argonThreads,
|
|
argonKeyLen,
|
|
)
|
|
|
|
// Encode as base64: salt$hash
|
|
saltB64 := base64.StdEncoding.EncodeToString(salt)
|
|
hashB64 := base64.StdEncoding.EncodeToString(hash)
|
|
|
|
return saltB64 + "$" + hashB64, nil
|
|
}
|
|
|
|
// VerifyPassword verifies a password against a hash.
|
|
func (svc *Service) VerifyPassword(hashedPassword, password string) bool {
|
|
// Parse salt$hash format using strings.Cut (more reliable than fmt.Sscanf)
|
|
saltB64, hashB64, found := strings.Cut(hashedPassword, "$")
|
|
if !found || saltB64 == "" || hashB64 == "" {
|
|
return false
|
|
}
|
|
|
|
salt, err := base64.StdEncoding.DecodeString(saltB64)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
expectedHash, err := base64.StdEncoding.DecodeString(hashB64)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Compute hash with same parameters
|
|
computedHash := argon2.IDKey(
|
|
[]byte(password),
|
|
salt,
|
|
argonTime,
|
|
argonMemory,
|
|
argonThreads,
|
|
argonKeyLen,
|
|
)
|
|
|
|
// Constant-time comparison
|
|
if len(computedHash) != len(expectedHash) {
|
|
return false
|
|
}
|
|
|
|
var result byte
|
|
|
|
for idx := range computedHash {
|
|
result |= computedHash[idx] ^ expectedHash[idx]
|
|
}
|
|
|
|
return result == 0
|
|
}
|
|
|
|
// IsSetupRequired checks if initial setup is needed (no users exist).
|
|
func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
|
|
exists, err := models.UserExists(ctx, svc.db)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check if user exists: %w", err)
|
|
}
|
|
|
|
return !exists, nil
|
|
}
|
|
|
|
// CreateUser creates the initial admin user.
|
|
// It uses a DB transaction to atomically check that no users exist and insert
|
|
// the new admin user, preventing race conditions from concurrent setup requests.
|
|
func (svc *Service) CreateUser(
|
|
ctx context.Context,
|
|
username, password string,
|
|
) (*models.User, error) {
|
|
// Hash password before starting transaction.
|
|
hash, err := svc.HashPassword(password)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
|
|
// Use a transaction so the "no users exist" check and the insert are atomic.
|
|
// SQLite serializes write transactions, so concurrent requests will block here.
|
|
user, err := models.CreateFirstUser(ctx, svc.db, username, hash)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
|
|
if user == nil {
|
|
return nil, ErrUserExists
|
|
}
|
|
|
|
svc.log.Info("user created", "username", username)
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// Authenticate validates credentials and returns the user.
|
|
func (svc *Service) Authenticate(
|
|
ctx context.Context,
|
|
username, password string,
|
|
) (*models.User, error) {
|
|
user, err := models.FindUserByUsername(ctx, svc.db, username)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to find user: %w", err)
|
|
}
|
|
|
|
if user == nil {
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
|
|
if !svc.VerifyPassword(user.PasswordHash, password) {
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// CreateSession creates a session for the user.
|
|
func (svc *Service) CreateSession(
|
|
respWriter http.ResponseWriter,
|
|
request *http.Request,
|
|
user *models.User,
|
|
) error {
|
|
session, err := svc.store.Get(request, sessionName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get session: %w", err)
|
|
}
|
|
|
|
session.Values[sessionUserID] = user.ID
|
|
|
|
saveErr := session.Save(request, respWriter)
|
|
if saveErr != nil {
|
|
return fmt.Errorf("failed to save session: %w", saveErr)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetCurrentUser returns the currently logged-in user, or nil if not logged in.
|
|
//
|
|
//nolint:nilerr // Session errors are not propagated - they indicate no user
|
|
func (svc *Service) GetCurrentUser(
|
|
ctx context.Context,
|
|
request *http.Request,
|
|
) (*models.User, error) {
|
|
// Check context first (set by bearer token auth).
|
|
if user := UserFromContext(ctx); user != nil {
|
|
return user, nil
|
|
}
|
|
|
|
session, sessionErr := svc.store.Get(request, sessionName)
|
|
if sessionErr != nil {
|
|
// Session error means no user - this is not an error condition
|
|
return nil, nil //nolint:nilnil // Expected behavior for no session
|
|
}
|
|
|
|
userID, ok := session.Values[sessionUserID].(int64)
|
|
if !ok {
|
|
return nil, nil //nolint:nilnil // No user ID in session is valid
|
|
}
|
|
|
|
user, err := models.FindUser(ctx, svc.db, userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to find user: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// DestroySession destroys the current session.
|
|
func (svc *Service) DestroySession(
|
|
respWriter http.ResponseWriter,
|
|
request *http.Request,
|
|
) error {
|
|
session, err := svc.store.Get(request, sessionName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get session: %w", err)
|
|
}
|
|
|
|
session.Options.MaxAge = -1
|
|
|
|
saveErr := session.Save(request, respWriter)
|
|
if saveErr != nil {
|
|
return fmt.Errorf("failed to save session: %w", saveErr)
|
|
}
|
|
|
|
return nil
|
|
}
|