package database import ( "crypto/rand" "crypto/subtle" "encoding/base64" "errors" "fmt" "math/big" "strings" "golang.org/x/crypto/argon2" ) // Argon2 parameters - these are up-to-date secure defaults const ( argon2Time = 1 argon2Memory = 64 * 1024 // 64 MB argon2Threads = 4 argon2KeyLen = 32 argon2SaltLen = 16 ) // hashParts is the expected number of $-separated segments // in an encoded Argon2id hash string. const hashParts = 6 // minPasswordComplexityLen is the minimum password length that // triggers per-character-class complexity enforcement. const minPasswordComplexityLen = 4 // Sentinel errors returned by decodeHash. var ( errInvalidHashFormat = errors.New("invalid hash format") errInvalidAlgorithm = errors.New("invalid algorithm") errIncompatibleVersion = errors.New("incompatible argon2 version") errSaltLengthOutOfRange = errors.New("salt length out of range") errHashLengthOutOfRange = errors.New("hash length out of range") ) // PasswordConfig holds Argon2 configuration type PasswordConfig struct { Time uint32 Memory uint32 Threads uint8 KeyLen uint32 SaltLen uint32 } // DefaultPasswordConfig returns secure default Argon2 parameters func DefaultPasswordConfig() *PasswordConfig { return &PasswordConfig{ Time: argon2Time, Memory: argon2Memory, Threads: argon2Threads, KeyLen: argon2KeyLen, SaltLen: argon2SaltLen, } } // HashPassword generates an Argon2id hash of the password func HashPassword(password string) (string, error) { config := DefaultPasswordConfig() // Generate a salt salt := make([]byte, config.SaltLen) _, err := rand.Read(salt) if err != nil { return "", err } // Generate the hash hash := argon2.IDKey( []byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen, ) // Encode the hash and parameters b64Salt := base64.RawStdEncoding.EncodeToString(salt) b64Hash := base64.RawStdEncoding.EncodeToString(hash) // Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash encoded := fmt.Sprintf( "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", 
argon2.Version, config.Memory, config.Time, config.Threads, b64Salt, b64Hash, ) return encoded, nil } // VerifyPassword checks if the provided password matches the hash func VerifyPassword( password, encodedHash string, ) (bool, error) { // Extract parameters and hash from encoded string config, salt, hash, err := decodeHash(encodedHash) if err != nil { return false, err } // Generate hash of the provided password otherHash := argon2.IDKey( []byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen, ) // Compare hashes using constant time comparison return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil } // decodeHash extracts parameters, salt, and hash from an // encoded hash string. func decodeHash( encodedHash string, ) (*PasswordConfig, []byte, []byte, error) { parts := strings.Split(encodedHash, "$") if len(parts) != hashParts { return nil, nil, nil, errInvalidHashFormat } if parts[1] != "argon2id" { return nil, nil, nil, errInvalidAlgorithm } version, err := parseVersion(parts[2]) if err != nil { return nil, nil, nil, err } if version != argon2.Version { return nil, nil, nil, errIncompatibleVersion } config, err := parseParams(parts[3]) if err != nil { return nil, nil, nil, err } salt, err := decodeSalt(parts[4]) if err != nil { return nil, nil, nil, err } config.SaltLen = uint32(len(salt)) //nolint:gosec // validated in decodeSalt hash, err := decodeHashBytes(parts[5]) if err != nil { return nil, nil, nil, err } config.KeyLen = uint32(len(hash)) //nolint:gosec // validated in decodeHashBytes return config, salt, hash, nil } func parseVersion(s string) (int, error) { var version int _, err := fmt.Sscanf(s, "v=%d", &version) if err != nil { return 0, fmt.Errorf("parsing version: %w", err) } return version, nil } func parseParams(s string) (*PasswordConfig, error) { config := &PasswordConfig{} _, err := fmt.Sscanf( s, "m=%d,t=%d,p=%d", &config.Memory, &config.Time, &config.Threads, ) if err != nil { return nil, fmt.Errorf("parsing 
params: %w", err) } return config, nil } func decodeSalt(s string) ([]byte, error) { salt, err := base64.RawStdEncoding.DecodeString(s) if err != nil { return nil, fmt.Errorf("decoding salt: %w", err) } saltLen := len(salt) if saltLen < 0 || saltLen > int(^uint32(0)) { return nil, errSaltLengthOutOfRange } return salt, nil } func decodeHashBytes(s string) ([]byte, error) { hash, err := base64.RawStdEncoding.DecodeString(s) if err != nil { return nil, fmt.Errorf("decoding hash: %w", err) } hashLen := len(hash) if hashLen < 0 || hashLen > int(^uint32(0)) { return nil, errHashLengthOutOfRange } return hash, nil } // GenerateRandomPassword generates a cryptographically secure // random password. func GenerateRandomPassword(length int) (string, error) { const ( uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" lowercase = "abcdefghijklmnopqrstuvwxyz" digits = "0123456789" special = "!@#$%^&*()_+-=[]{}|;:,.<>?" ) // Combine all character sets allChars := uppercase + lowercase + digits + special // Create password slice password := make([]byte, length) // Ensure at least one character from each set if length >= minPasswordComplexityLen { password[0] = uppercase[cryptoRandInt(len(uppercase))] password[1] = lowercase[cryptoRandInt(len(lowercase))] password[2] = digits[cryptoRandInt(len(digits))] password[3] = special[cryptoRandInt(len(special))] // Fill the rest randomly from all characters for i := minPasswordComplexityLen; i < length; i++ { password[i] = allChars[cryptoRandInt(len(allChars))] } // Shuffle the password to avoid predictable pattern for i := range len(password) - 1 { j := cryptoRandInt(len(password) - i) idx := len(password) - 1 - i password[idx], password[j] = password[j], password[idx] } } else { // For very short passwords, just use all characters for i := range length { password[i] = allChars[cryptoRandInt(len(allChars))] } } return string(password), nil } // cryptoRandInt generates a cryptographically secure random // integer in [0, upperBound). 
//
// The function has no error result, so it panics in two cases: when
// upperBound is not positive, which is a bug in the caller, and when the
// secure random source returns an error.
func cryptoRandInt(upperBound int) int {
	if upperBound < 1 {
		panic("upperBound must be positive")
	}

	limit := new(big.Int).SetInt64(int64(upperBound))
	n, err := rand.Int(rand.Reader, limit)
	if err == nil {
		return int(n.Int64())
	}
	panic(fmt.Sprintf("crypto/rand error: %v", err))
}