288 lines
6.5 KiB
Go
288 lines
6.5 KiB
Go
// Package auth provides authentication services.
|
|
package auth
|
|
|
|
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/sessions"
	"go.uber.org/fx"
	"golang.org/x/crypto/argon2"

	"git.eeqj.de/sneak/upaas/internal/config"
	"git.eeqj.de/sneak/upaas/internal/database"
	"git.eeqj.de/sneak/upaas/internal/logger"
	"git.eeqj.de/sneak/upaas/internal/models"
)
|
|
|
|
// Identifiers used for the login session.
const (
	// sessionName is the cookie name the session store reads and writes.
	sessionName = "upaas_session"

	// sessionUserID is the key in session.Values under which the logged-in
	// user's ID is stored.
	sessionUserID = "user_id"
)
|
|
|
|
// Argon2id parameters shared by HashPassword and VerifyPassword.
//
// The stored hash encodes only the salt and the derived key, not these
// settings. Changing argonTime, argonMemory, argonThreads or argonKeyLen
// therefore makes every previously stored hash fail verification.
const (
	// argonTime is the number of passes over memory.
	argonTime = 1
	// argonMemory is the memory cost in KiB (64 MiB).
	argonMemory = 64 << 10
	// argonThreads is the degree of parallelism.
	argonThreads = 4
	// argonKeyLen is the length of the derived key in bytes.
	argonKeyLen = 32
	// saltLen is the length of the random per-password salt in bytes.
	saltLen = 16
)
|
|
|
|
// Session lifetime, applied as the cookie MaxAge (which is in seconds).
const (
	// sessionMaxAgeDays is how many days a session cookie stays valid.
	sessionMaxAgeDays = 7

	// sessionMaxAgeSeconds is sessionMaxAgeDays expressed in seconds.
	sessionMaxAgeSeconds = sessionMaxAgeDays * 24 * 60 * 60
)
|
|
|
|
var (
|
|
// ErrInvalidCredentials is returned when username/password is incorrect.
|
|
ErrInvalidCredentials = errors.New("invalid credentials")
|
|
// ErrUserExists is returned when trying to create a user that already exists.
|
|
ErrUserExists = errors.New("user already exists")
|
|
)
|
|
|
|
// ServiceParams contains dependencies for Service.
//
// It embeds fx.In so the fx container fills in the exported fields when it
// calls New.
type ServiceParams struct {
	fx.In

	// Logger provides the *slog.Logger stored on Service (via Get).
	Logger *logger.Logger
	// Config supplies SessionSecret (cookie store key) and Debug (disables
	// the Secure cookie flag), both read by New.
	Config *config.Config
	// Database is the handle passed to the models.* user queries.
	Database *database.Database
}
|
|
|
|
// Service provides authentication functionality.
|
|
type Service struct {
|
|
log *slog.Logger
|
|
db *database.Database
|
|
store *sessions.CookieStore
|
|
params *ServiceParams
|
|
}
|
|
|
|
// New creates a new auth Service.
|
|
func New(_ fx.Lifecycle, params ServiceParams) (*Service, error) {
|
|
store := sessions.NewCookieStore([]byte(params.Config.SessionSecret))
|
|
store.Options = &sessions.Options{
|
|
Path: "/",
|
|
MaxAge: sessionMaxAgeSeconds,
|
|
HttpOnly: true,
|
|
Secure: !params.Config.Debug,
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
|
|
return &Service{
|
|
log: params.Logger.Get(),
|
|
db: params.Database,
|
|
store: store,
|
|
params: ¶ms,
|
|
}, nil
|
|
}
|
|
|
|
// HashPassword derives an Argon2id key from password using a fresh random
// salt of saltLen bytes.
//
// The result has the form "<salt>$<key>", with both parts in standard
// base64. This is the format VerifyPassword parses. An error is returned
// only if the random salt cannot be generated.
func (svc *Service) HashPassword(password string) (string, error) {
	salt := make([]byte, saltLen)
	if _, err := rand.Read(salt); err != nil {
		return "", fmt.Errorf("failed to generate salt: %w", err)
	}

	derived := argon2.IDKey(
		[]byte(password), salt,
		argonTime, argonMemory, argonThreads, argonKeyLen,
	)

	enc := base64.StdEncoding
	parts := []string{enc.EncodeToString(salt), enc.EncodeToString(derived)}

	return strings.Join(parts, "$"), nil
}
|
|
|
|
// VerifyPassword verifies a password against a hash.
|
|
func (svc *Service) VerifyPassword(hashedPassword, password string) bool {
|
|
// Parse salt$hash format using strings.Cut (more reliable than fmt.Sscanf)
|
|
saltB64, hashB64, found := strings.Cut(hashedPassword, "$")
|
|
if !found || saltB64 == "" || hashB64 == "" {
|
|
return false
|
|
}
|
|
|
|
salt, err := base64.StdEncoding.DecodeString(saltB64)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
expectedHash, err := base64.StdEncoding.DecodeString(hashB64)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Compute hash with same parameters
|
|
computedHash := argon2.IDKey(
|
|
[]byte(password),
|
|
salt,
|
|
argonTime,
|
|
argonMemory,
|
|
argonThreads,
|
|
argonKeyLen,
|
|
)
|
|
|
|
// Constant-time comparison
|
|
if len(computedHash) != len(expectedHash) {
|
|
return false
|
|
}
|
|
|
|
var result byte
|
|
|
|
for idx := range computedHash {
|
|
result |= computedHash[idx] ^ expectedHash[idx]
|
|
}
|
|
|
|
return result == 0
|
|
}
|
|
|
|
// IsSetupRequired reports whether the initial setup still has to run, which
// is the case while no user has been created.
//
// Database errors are wrapped and returned, together with false.
func (svc *Service) IsSetupRequired(ctx context.Context) (bool, error) {
	haveUser, err := models.UserExists(ctx, svc.db)
	if err == nil {
		return !haveUser, nil
	}

	return false, fmt.Errorf("failed to check if user exists: %w", err)
}
|
|
|
|
// CreateUser creates the initial admin user.
//
// It fails with ErrUserExists if any user is already present, so it can
// only bootstrap the first account. The password is stored as the
// HashPassword encoding. The new user is saved and returned, and the
// creation is logged with the username.
//
// NOTE(review): the existence check and the save are separate calls with no
// visible transaction. Two concurrent calls could both pass the check.
// Confirm whether the schema enforces a single user.
func (svc *Service) CreateUser(
	ctx context.Context,
	username, password string,
) (*models.User, error) {
	alreadySetUp, err := models.UserExists(ctx, svc.db)

	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to check if user exists: %w", err)
	case alreadySetUp:
		return nil, ErrUserExists
	}

	passwordHash, err := svc.HashPassword(password)
	if err != nil {
		return nil, fmt.Errorf("failed to hash password: %w", err)
	}

	user := models.NewUser(svc.db)
	user.Username = username
	user.PasswordHash = passwordHash

	if saveErr := user.Save(ctx); saveErr != nil {
		return nil, fmt.Errorf("failed to save user: %w", saveErr)
	}

	svc.log.Info("user created", "username", username)

	return user, nil
}
|
|
|
|
// Authenticate validates credentials and returns the user.
//
// It returns ErrInvalidCredentials both when the username is unknown and
// when the password does not match, so callers cannot tell the two apart.
// Lookup failures are wrapped and returned as-is.
//
// An unknown username still costs one Argon2id derivation. This keeps the
// response time from revealing which usernames exist.
func (svc *Service) Authenticate(
	ctx context.Context,
	username, password string,
) (*models.User, error) {
	user, err := models.FindUserByUsername(ctx, svc.db, username)
	if err != nil {
		return nil, fmt.Errorf("failed to find user: %w", err)
	}

	if user == nil {
		// HashPassword runs the same argon2.IDKey call VerifyPassword would.
		// Its result, and any salt-generation error, are deliberately
		// discarded: this call exists only to equalize timing.
		_, _ = svc.HashPassword(password)

		return nil, ErrInvalidCredentials
	}

	if !svc.VerifyPassword(user.PasswordHash, password) {
		return nil, ErrInvalidCredentials
	}

	return user, nil
}
|
|
|
|
// CreateSession creates a session for the user.
//
// It stores user.ID under sessionUserID and writes the session cookie to
// respWriter.
//
// An existing cookie may be undecodable, for example because it was tampered
// with or signed with an old secret. In that case the store still hands back
// a new session alongside the error. Login proceeds with that session, after
// resetting its values, instead of failing every time the client presents the
// stale cookie.
func (svc *Service) CreateSession(
	respWriter http.ResponseWriter,
	request *http.Request,
	user *models.User,
) error {
	session, err := svc.store.Get(request, sessionName)
	if err != nil {
		if session == nil {
			return fmt.Errorf("failed to get session: %w", err)
		}

		svc.log.Debug("replacing undecodable session cookie", "error", err)

		// Discard anything a failed decode may have partially filled in.
		session.Values = make(map[any]any)
	}

	session.Values[sessionUserID] = user.ID

	saveErr := session.Save(request, respWriter)
	if saveErr != nil {
		return fmt.Errorf("failed to save session: %w", saveErr)
	}

	return nil
}
|
|
|
|
// GetCurrentUser returns the currently logged-in user, or nil if not logged in.
//
// A session that cannot be loaded, or one without an int64 user ID, yields
// (nil, nil). Only a failed user lookup is reported as an error.
//
//nolint:nilerr // Session errors are not propagated - they indicate no user
func (svc *Service) GetCurrentUser(
	ctx context.Context,
	request *http.Request,
) (*models.User, error) {
	sess, loadErr := svc.store.Get(request, sessionName)
	if loadErr != nil {
		return nil, nil //nolint:nilnil // Expected behavior for no session
	}

	rawID := sess.Values[sessionUserID]

	userID, isInt64 := rawID.(int64)
	if !isInt64 {
		return nil, nil //nolint:nilnil // No user ID in session is valid
	}

	user, lookupErr := models.FindUser(ctx, svc.db, userID)
	if lookupErr == nil {
		return user, nil
	}

	return nil, fmt.Errorf("failed to find user: %w", lookupErr)
}
|
|
|
|
// DestroySession destroys the current session.
//
// It clears the session values and writes back an already-expired cookie so
// the client discards it.
//
// If the existing cookie cannot be decoded (tampered, or signed with an old
// secret), the store still returns a new session alongside the error. The
// cookie is expired anyway, so logout does not fail on a bad cookie.
func (svc *Service) DestroySession(
	respWriter http.ResponseWriter,
	request *http.Request,
) error {
	session, err := svc.store.Get(request, sessionName)
	if err != nil {
		if session == nil {
			return fmt.Errorf("failed to get session: %w", err)
		}

		svc.log.Debug("expiring undecodable session cookie", "error", err)
	}

	// Drop all values. The replacement cookie then carries no user ID, even
	// if a client ignores the expiry.
	session.Values = make(map[any]any)

	// gorilla/sessions treats any negative MaxAge as "delete the cookie now".
	// Only the sign matters; the magnitude is irrelevant.
	session.Options.MaxAge = -1 * int(time.Second)

	saveErr := session.Save(request, respWriter)
	if saveErr != nil {
		return fmt.Errorf("failed to save session: %w", saveErr)
	}

	return nil
}
|