Files
upaas/internal/service/auth/auth.go
clawbot 96eea71c54 fix: set authenticated user on request context in bearer token auth
tryBearerAuth validated the bearer token but never looked up the
associated user or set it on the request context. This meant
downstream handlers calling GetCurrentUser would get nil even
with a valid token.

Changes:
- Add ContextWithUser/UserFromContext helpers in auth package
- tryBearerAuth now looks up the user by token's UserID and
  sets it on the request context via auth.ContextWithUser
- GetCurrentUser checks context first before falling back to
  session cookie
- Add integration tests for bearer auth user context
2026-02-19 23:43:22 -08:00

300 lines
7.2 KiB
Go

// Package auth provides authentication services.
package auth
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"

	"github.com/gorilla/sessions"
	"go.uber.org/fx"
	"golang.org/x/crypto/argon2"

	"git.eeqj.de/sneak/upaas/internal/config"
	"git.eeqj.de/sneak/upaas/internal/database"
	"git.eeqj.de/sneak/upaas/internal/logger"
	"git.eeqj.de/sneak/upaas/internal/models"
)
// sessionName is the name of the cookie that carries the gorilla session.
const sessionName = "upaas_session"

// sessionUserID is the session-values key holding the logged-in user's ID.
const sessionUserID = "user_id"

// contextKeyUser is the context key under which the authenticated user is
// stored. It is an unexported empty struct type, so no other package can
// construct the same key and collide with it.
type contextKeyUser struct{}
// ContextWithUser derives a child of ctx that carries user, so that a later
// UserFromContext on that child (or anything derived from it) returns user.
// The parent ctx is not modified.
func ContextWithUser(ctx context.Context, user *models.User) context.Context {
	var key contextKeyUser

	return context.WithValue(ctx, key, user)
}
// UserFromContext returns the user previously attached with ContextWithUser,
// or nil when ctx carries no user.
func UserFromContext(ctx context.Context) *models.User {
	if found, ok := ctx.Value(contextKeyUser{}).(*models.User); ok {
		return found
	}

	return nil
}
// Argon2id parameters shared by HashPassword and VerifyPassword.
//
// The stored hash format records only the salt and the derived key, not these
// parameters. Changing argonTime, argonMemory, argonThreads or argonKeyLen
// therefore makes every previously stored password hash unverifiable.
const (
	argonTime    = 1       // number of passes over memory
	argonMemory  = 1 << 16 // memory cost in KiB (64 MiB)
	argonThreads = 4       // degree of parallelism
	argonKeyLen  = 32      // derived key length in bytes
	saltLen      = 16      // random salt length in bytes for new hashes
)
// Session cookie lifetime: one week, expressed in seconds because that is the
// unit sessions.Options.MaxAge expects.
const (
	sessionMaxAgeDays    = 7
	sessionMaxAgeSeconds = sessionMaxAgeDays * 24 * 60 * 60
)
// Sentinel errors returned by Service; callers should match them with errors.Is.
var (
	// ErrUserExists is returned when trying to create a user that already exists.
	ErrUserExists = errors.New("user already exists")

	// ErrInvalidCredentials is returned when username/password is incorrect.
	ErrInvalidCredentials = errors.New("invalid credentials")
)
// ServiceParams contains dependencies for Service.
//
// Embedding fx.In marks this struct as an fx parameter object, so fx supplies
// the exported fields below when it builds the argument passed to New.
type ServiceParams struct {
	fx.In

	Logger   *logger.Logger     // source of the service's *slog.Logger (via Get)
	Config   *config.Config     // provides SessionSecret and Debug for the cookie store
	Database *database.Database // handle passed to the models.* user queries
}
// Service provides authentication functionality: Argon2id password hashing
// and verification, first-user setup, credential checks, and cookie-backed
// login sessions. Construct it with New.
type Service struct {
	log    *slog.Logger
	db     *database.Database
	store  *sessions.CookieStore // session cookie store; options are set in New
	params *ServiceParams        // points at New's copy of the injected params
}
// New builds the auth Service.
//
// Session cookies are signed with Config.SessionSecret. They are scoped to
// "/", live for sessionMaxAgeSeconds, and are HttpOnly and SameSite=Lax. They
// are marked Secure unless Config.Debug is set. The error result is currently
// always nil. The fx.Lifecycle argument is accepted for fx wiring but unused.
//
// NOTE(review): assumes Config.SessionSecret is a non-empty, high-entropy
// value — confirm the config layer enforces this.
func New(_ fx.Lifecycle, params ServiceParams) (*Service, error) {
	cfg := params.Config

	cookieStore := sessions.NewCookieStore([]byte(cfg.SessionSecret))
	cookieStore.Options = &sessions.Options{
		Path:     "/",
		MaxAge:   sessionMaxAgeSeconds,
		HttpOnly: true,
		Secure:   !cfg.Debug,
		SameSite: http.SameSiteLaxMode,
	}

	svc := &Service{
		log:    params.Logger.Get(),
		db:     params.Database,
		store:  cookieStore,
		params: &params,
	}

	return svc, nil
}
// HashPassword derives an Argon2id key from password with a fresh random salt
// of saltLen bytes. It returns the result as "base64(salt)$base64(key)" using
// standard base64, the format VerifyPassword parses. It fails only if the
// system random source fails.
func (svc *Service) HashPassword(password string) (string, error) {
	salt := make([]byte, saltLen)
	if _, readErr := rand.Read(salt); readErr != nil {
		return "", fmt.Errorf("failed to generate salt: %w", readErr)
	}

	derived := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)

	enc := base64.StdEncoding

	return enc.EncodeToString(salt) + "$" + enc.EncodeToString(derived), nil
}
// VerifyPassword reports whether password matches hashedPassword.
//
// hashedPassword must be in the "base64(salt)$base64(key)" format produced by
// HashPassword. Malformed input (missing "$", empty parts, invalid base64)
// yields false rather than an error.
func (svc *Service) VerifyPassword(hashedPassword, password string) bool {
	// strings.Cut splits on the first "$" only; base64 never contains "$".
	saltB64, hashB64, found := strings.Cut(hashedPassword, "$")
	if !found || saltB64 == "" || hashB64 == "" {
		return false
	}

	salt, err := base64.StdEncoding.DecodeString(saltB64)
	if err != nil {
		return false
	}

	expectedHash, err := base64.StdEncoding.DecodeString(hashB64)
	if err != nil {
		return false
	}

	// Re-derive with the same fixed parameters HashPassword uses; the stored
	// format does not record them.
	computedHash := argon2.IDKey(
		[]byte(password),
		salt,
		argonTime,
		argonMemory,
		argonThreads,
		argonKeyLen,
	)

	// ConstantTimeCompare returns 1 only when both slices have equal length and
	// contents, and its running time does not depend on where they differ.
	return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1
}
// IsSetupRequired reports whether initial setup still has to run. That is the
// case when no user exists yet. Database errors are wrapped and returned with
// a false result.
func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
	hasUser, lookupErr := models.UserExists(ctx, svc.db)
	if lookupErr != nil {
		return false, fmt.Errorf("failed to check if user exists: %w", lookupErr)
	}

	setupNeeded := !hasUser

	return setupNeeded, nil
}
// CreateUser creates the initial admin user.
//
// The password is hashed before touching the database. The "no users yet"
// check and the insert are delegated to models.CreateFirstUser, which is
// expected to perform them atomically. A nil user from it means a user already
// exists, and is reported as ErrUserExists. Other failures are wrapped.
func (svc *Service) CreateUser(
	ctx context.Context,
	username, password string,
) (*models.User, error) {
	passwordHash, hashErr := svc.HashPassword(password)
	if hashErr != nil {
		return nil, fmt.Errorf("failed to hash password: %w", hashErr)
	}

	created, createErr := models.CreateFirstUser(ctx, svc.db, username, passwordHash)

	switch {
	case createErr != nil:
		return nil, fmt.Errorf("failed to create user: %w", createErr)
	case created == nil:
		return nil, ErrUserExists
	}

	svc.log.Info("user created", "username", username)

	return created, nil
}
// Authenticate validates credentials and returns the user.
//
// It returns ErrInvalidCredentials for both an unknown username and a wrong
// password, and wraps any database lookup error. For an unknown username a
// throwaway Argon2 verification is still performed. Without it, that path
// would return much faster than a wrong-password attempt, and response timing
// would reveal which usernames exist.
func (svc *Service) Authenticate(
	ctx context.Context,
	username, password string,
) (*models.User, error) {
	user, err := models.FindUserByUsername(ctx, svc.db, username)
	if err != nil {
		return nil, fmt.Errorf("failed to find user: %w", err)
	}

	if user == nil {
		// Well-formed placeholder (zero salt and zero key of the configured
		// lengths) so VerifyPassword gets past parsing and performs the full
		// Argon2 derivation. The result is deliberately discarded.
		placeholder := base64.StdEncoding.EncodeToString(make([]byte, saltLen)) +
			"$" + base64.StdEncoding.EncodeToString(make([]byte, argonKeyLen))
		_ = svc.VerifyPassword(placeholder, password)

		return nil, ErrInvalidCredentials
	}

	if !svc.VerifyPassword(user.PasswordHash, password) {
		return nil, ErrInvalidCredentials
	}

	return user, nil
}
// CreateSession records user.ID in the session and writes the session cookie.
//
// If the request carries a session cookie that cannot be decoded (for example
// one signed with a previous SessionSecret), gorilla/sessions still returns a
// new session alongside the error. That new session is used, so the stale
// cookie is overwritten. Treating the error as fatal would make login fail
// every time for that browser. Only a failure to obtain any session, or to
// save it, is returned.
func (svc *Service) CreateSession(
	respWriter http.ResponseWriter,
	request *http.Request,
	user *models.User,
) error {
	session, err := svc.store.Get(request, sessionName)
	if err != nil {
		if session == nil {
			return fmt.Errorf("failed to get session: %w", err)
		}

		svc.log.Warn("replacing undecodable session cookie", "error", err)
	}

	session.Values[sessionUserID] = user.ID

	if saveErr := session.Save(request, respWriter); saveErr != nil {
		return fmt.Errorf("failed to save session: %w", saveErr)
	}

	return nil
}
// GetCurrentUser returns the currently logged-in user, or nil if not logged in.
//
// A user attached by bearer token auth via ContextWithUser takes precedence
// over the session cookie. Both the ctx argument and the request's own context
// are checked, so the bearer user is found even when the caller passes a ctx
// not derived from request.Context(). Failing that, the user ID stored in the
// session is loaded with models.FindUser. A missing or undecodable session, or
// a session without a user ID, yields (nil, nil). Only a database lookup
// failure is returned as an error.
//
//nolint:nilerr // Session errors are not propagated - they indicate no user
func (svc *Service) GetCurrentUser(
	ctx context.Context,
	request *http.Request,
) (*models.User, error) {
	// Check context first (set by bearer token auth).
	if user := UserFromContext(ctx); user != nil {
		return user, nil
	}

	if user := UserFromContext(request.Context()); user != nil {
		return user, nil
	}

	session, sessionErr := svc.store.Get(request, sessionName)
	if sessionErr != nil {
		// Session error means no user - this is not an error condition
		return nil, nil //nolint:nilnil // Expected behavior for no session
	}

	userID, ok := session.Values[sessionUserID].(int64)
	if !ok {
		return nil, nil //nolint:nilnil // No user ID in session is valid
	}

	user, err := models.FindUser(ctx, svc.db, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to find user: %w", err)
	}

	return user, nil
}
// DestroySession logs the user out. It saves the session with MaxAge -1, which
// instructs the browser to delete the session cookie.
//
// A cookie that cannot be decoded (e.g. signed with a previous SessionSecret)
// is not fatal. gorilla/sessions returns a new session alongside the decode
// error, and expiring that session still removes the browser's stale cookie,
// so logout succeeds. Only a failure to obtain any session, or to save it, is
// returned.
func (svc *Service) DestroySession(
	respWriter http.ResponseWriter,
	request *http.Request,
) error {
	session, err := svc.store.Get(request, sessionName)
	if err != nil {
		if session == nil {
			return fmt.Errorf("failed to get session: %w", err)
		}

		svc.log.Warn("expiring undecodable session cookie", "error", err)
	}

	// gorilla/sessions keeps this session in a per-request registry. Drop the
	// user ID so a later lookup during the same request no longer finds a user.
	delete(session.Values, sessionUserID)
	session.Options.MaxAge = -1

	if saveErr := session.Save(request, respWriter); saveErr != nil {
		return fmt.Errorf("failed to save session: %w", saveErr)
	}

	return nil
}