initial
This commit is contained in:
25
internal/database/base_model.go
Normal file
25
internal/database/base_model.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseModel contains common fields for all models
|
||||
// This replaces gorm.Model but uses UUID instead of uint for ID
|
||||
type BaseModel struct {
|
||||
ID string `gorm:"type:uuid;primary_key" json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||
}
|
||||
|
||||
// BeforeCreate hook to set UUID before creating a record
|
||||
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
|
||||
if b.ID == "" {
|
||||
b.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
144
internal/database/database.go
Normal file
144
internal/database/database.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"go.uber.org/fx"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
_ "modernc.org/sqlite" // Pure Go SQLite driver
|
||||
)
|
||||
|
||||
// nolint:revive // DatabaseParams is a standard fx naming convention
|
||||
type DatabaseParams struct {
|
||||
fx.In
|
||||
Config *config.Config
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
db *gorm.DB
|
||||
log *slog.Logger
|
||||
params *DatabaseParams
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params DatabaseParams) (*Database, error) {
|
||||
d := &Database{
|
||||
params: ¶ms,
|
||||
log: params.Logger.Get(),
|
||||
}
|
||||
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
return d.connect()
|
||||
},
|
||||
OnStop: func(_ context.Context) error { // nolint:revive // ctx unused but required by fx
|
||||
return d.close()
|
||||
},
|
||||
})
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Database) connect() error {
|
||||
dbURL := d.params.Config.DBURL
|
||||
if dbURL == "" {
|
||||
// Default to SQLite for development
|
||||
dbURL = "file:webhooker.db?cache=shared&mode=rwc"
|
||||
}
|
||||
|
||||
// First, open the database with the pure Go driver
|
||||
sqlDB, err := sql.Open("sqlite", dbURL)
|
||||
if err != nil {
|
||||
d.log.Error("failed to open database", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Then use it with GORM
|
||||
db, err := gorm.Open(sqlite.Dialector{
|
||||
Conn: sqlDB,
|
||||
}, &gorm.Config{})
|
||||
if err != nil {
|
||||
d.log.Error("failed to connect to database", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
d.db = db
|
||||
d.log.Info("connected to database", "database", dbURL)
|
||||
|
||||
// Run migrations
|
||||
return d.migrate()
|
||||
}
|
||||
|
||||
func (d *Database) migrate() error {
|
||||
// Run GORM auto-migrations
|
||||
if err := d.Migrate(); err != nil {
|
||||
d.log.Error("failed to run database migrations", "error", err)
|
||||
return err
|
||||
}
|
||||
d.log.Info("database migrations completed")
|
||||
|
||||
// Check if admin user exists
|
||||
var userCount int64
|
||||
if err := d.db.Model(&User{}).Count(&userCount).Error; err != nil {
|
||||
d.log.Error("failed to count users", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if userCount == 0 {
|
||||
// Create admin user
|
||||
d.log.Info("no users found, creating admin user")
|
||||
|
||||
// Generate random password
|
||||
password, err := GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
d.log.Error("failed to generate random password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
hashedPassword, err := HashPassword(password)
|
||||
if err != nil {
|
||||
d.log.Error("failed to hash password", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create admin user
|
||||
adminUser := &User{
|
||||
Username: "admin",
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
if err := d.db.Create(adminUser).Error; err != nil {
|
||||
d.log.Error("failed to create admin user", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Log the password - this will only happen once on first startup
|
||||
d.log.Info("admin user created",
|
||||
"username", "admin",
|
||||
"password", password,
|
||||
"message", "SAVE THIS PASSWORD - it will not be shown again!")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) close() error {
|
||||
if d.db != nil {
|
||||
sqlDB, err := d.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) DB() *gorm.DB {
|
||||
return d.db
|
||||
}
|
||||
81
internal/database/database_test.go
Normal file
81
internal/database/database_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.eeqj.de/sneak/webhooker/internal/config"
|
||||
"git.eeqj.de/sneak/webhooker/internal/globals"
|
||||
"git.eeqj.de/sneak/webhooker/internal/logger"
|
||||
"go.uber.org/fx/fxtest"
|
||||
)
|
||||
|
||||
func TestDatabaseConnection(t *testing.T) {
|
||||
// Set up test dependencies
|
||||
lc := fxtest.NewLifecycle(t)
|
||||
|
||||
// Create globals
|
||||
globals.Appname = "webhooker-test"
|
||||
globals.Version = "test"
|
||||
globals.Buildarch = "test"
|
||||
|
||||
g, err := globals.New(lc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create globals: %v", err)
|
||||
}
|
||||
|
||||
// Create logger
|
||||
l, err := logger.New(lc, logger.LoggerParams{Globals: g})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create logger: %v", err)
|
||||
}
|
||||
|
||||
// Create config
|
||||
c, err := config.New(lc, config.ConfigParams{
|
||||
Globals: g,
|
||||
Logger: l,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create config: %v", err)
|
||||
}
|
||||
|
||||
// Set test database URL
|
||||
c.DBURL = "file:test.db?cache=shared&mode=rwc"
|
||||
|
||||
// Create database
|
||||
db, err := New(lc, DatabaseParams{
|
||||
Config: c,
|
||||
Logger: l,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
|
||||
// Start lifecycle (this will trigger the connection)
|
||||
ctx := context.Background()
|
||||
err = lc.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if stopErr := lc.Stop(ctx); stopErr != nil {
|
||||
t.Errorf("Failed to stop lifecycle: %v", stopErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Verify we can get the DB instance
|
||||
if db.DB() == nil {
|
||||
t.Error("Expected non-nil database connection")
|
||||
}
|
||||
|
||||
// Test that we can perform a simple query
|
||||
var result int
|
||||
err = db.DB().Raw("SELECT 1").Scan(&result).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute test query: %v", err)
|
||||
}
|
||||
|
||||
if result != 1 {
|
||||
t.Errorf("Expected query result to be 1, got %d", result)
|
||||
}
|
||||
}
|
||||
16
internal/database/model_apikey.go
Normal file
16
internal/database/model_apikey.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package database
|
||||
|
||||
import "time"
|
||||
|
||||
// APIKey represents an API key for a user
|
||||
type APIKey struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
Key string `gorm:"uniqueIndex;not null" json:"key"`
|
||||
Description string `json:"description"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitempty"`
|
||||
}
|
||||
25
internal/database/model_delivery.go
Normal file
25
internal/database/model_delivery.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package database
|
||||
|
||||
// DeliveryStatus represents the status of a delivery
|
||||
type DeliveryStatus string
|
||||
|
||||
const (
|
||||
DeliveryStatusPending DeliveryStatus = "pending"
|
||||
DeliveryStatusDelivered DeliveryStatus = "delivered"
|
||||
DeliveryStatusFailed DeliveryStatus = "failed"
|
||||
DeliveryStatusRetrying DeliveryStatus = "retrying"
|
||||
)
|
||||
|
||||
// Delivery represents a delivery attempt for an event to a target
|
||||
type Delivery struct {
|
||||
BaseModel
|
||||
|
||||
EventID string `gorm:"type:uuid;not null" json:"event_id"`
|
||||
TargetID string `gorm:"type:uuid;not null" json:"target_id"`
|
||||
Status DeliveryStatus `gorm:"not null;default:'pending'" json:"status"`
|
||||
|
||||
// Relations
|
||||
Event Event `json:"event,omitempty"`
|
||||
Target Target `json:"target,omitempty"`
|
||||
DeliveryResults []DeliveryResult `json:"delivery_results,omitempty"`
|
||||
}
|
||||
17
internal/database/model_delivery_result.go
Normal file
17
internal/database/model_delivery_result.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package database
|
||||
|
||||
// DeliveryResult represents the result of a delivery attempt
|
||||
type DeliveryResult struct {
|
||||
BaseModel
|
||||
|
||||
DeliveryID string `gorm:"type:uuid;not null" json:"delivery_id"`
|
||||
AttemptNum int `gorm:"not null" json:"attempt_num"`
|
||||
Success bool `json:"success"`
|
||||
StatusCode int `json:"status_code,omitempty"`
|
||||
ResponseBody string `gorm:"type:text" json:"response_body,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Duration int64 `json:"duration_ms"` // Duration in milliseconds
|
||||
|
||||
// Relations
|
||||
Delivery Delivery `json:"delivery,omitempty"`
|
||||
}
|
||||
20
internal/database/model_event.go
Normal file
20
internal/database/model_event.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package database
|
||||
|
||||
// Event represents a webhook event
|
||||
type Event struct {
|
||||
BaseModel
|
||||
|
||||
ProcessorID string `gorm:"type:uuid;not null" json:"processor_id"`
|
||||
WebhookID string `gorm:"type:uuid;not null" json:"webhook_id"`
|
||||
|
||||
// Request data
|
||||
Method string `gorm:"not null" json:"method"`
|
||||
Headers string `gorm:"type:text" json:"headers"` // JSON
|
||||
Body string `gorm:"type:text" json:"body"`
|
||||
ContentType string `json:"content_type"`
|
||||
|
||||
// Relations
|
||||
Processor Processor `json:"processor,omitempty"`
|
||||
Webhook Webhook `json:"webhook,omitempty"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
16
internal/database/model_processor.go
Normal file
16
internal/database/model_processor.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package database
|
||||
|
||||
// Processor represents an event processor
|
||||
type Processor struct {
|
||||
BaseModel
|
||||
|
||||
UserID string `gorm:"type:uuid;not null" json:"user_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Description string `json:"description"`
|
||||
RetentionDays int `gorm:"default:30" json:"retention_days"` // Days to retain events
|
||||
|
||||
// Relations
|
||||
User User `json:"user,omitempty"`
|
||||
Webhooks []Webhook `json:"webhooks,omitempty"`
|
||||
Targets []Target `json:"targets,omitempty"`
|
||||
}
|
||||
32
internal/database/model_target.go
Normal file
32
internal/database/model_target.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package database
|
||||
|
||||
// TargetType represents the type of delivery target
|
||||
type TargetType string
|
||||
|
||||
const (
|
||||
TargetTypeHTTP TargetType = "http"
|
||||
TargetTypeRetry TargetType = "retry"
|
||||
TargetTypeDatabase TargetType = "database"
|
||||
TargetTypeLog TargetType = "log"
|
||||
)
|
||||
|
||||
// Target represents a delivery target for a processor
|
||||
type Target struct {
|
||||
BaseModel
|
||||
|
||||
ProcessorID string `gorm:"type:uuid;not null" json:"processor_id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Type TargetType `gorm:"not null" json:"type"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
|
||||
// Configuration fields (JSON stored based on type)
|
||||
Config string `gorm:"type:text" json:"config"` // JSON configuration
|
||||
|
||||
// For retry targets
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
MaxQueueSize int `json:"max_queue_size,omitempty"`
|
||||
|
||||
// Relations
|
||||
Processor Processor `json:"processor,omitempty"`
|
||||
Deliveries []Delivery `json:"deliveries,omitempty"`
|
||||
}
|
||||
13
internal/database/model_user.go
Normal file
13
internal/database/model_user.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package database
|
||||
|
||||
// User represents a user of the webhooker service
|
||||
type User struct {
|
||||
BaseModel
|
||||
|
||||
Username string `gorm:"uniqueIndex;not null" json:"username"`
|
||||
Password string `gorm:"not null" json:"-"` // Argon2 hashed
|
||||
|
||||
// Relations
|
||||
Processors []Processor `json:"processors,omitempty"`
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
}
|
||||
14
internal/database/model_webhook.go
Normal file
14
internal/database/model_webhook.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package database
|
||||
|
||||
// Webhook represents a webhook endpoint that feeds into a processor
|
||||
type Webhook struct {
|
||||
BaseModel
|
||||
|
||||
ProcessorID string `gorm:"type:uuid;not null" json:"processor_id"`
|
||||
Path string `gorm:"uniqueIndex;not null" json:"path"` // URL path for this webhook
|
||||
Description string `json:"description"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
|
||||
// Relations
|
||||
Processor Processor `json:"processor,omitempty"`
|
||||
}
|
||||
15
internal/database/models.go
Normal file
15
internal/database/models.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package database
|
||||
|
||||
// Migrate runs database migrations for all models
|
||||
func (d *Database) Migrate() error {
|
||||
return d.db.AutoMigrate(
|
||||
&User{},
|
||||
&APIKey{},
|
||||
&Processor{},
|
||||
&Webhook{},
|
||||
&Target{},
|
||||
&Event{},
|
||||
&Delivery{},
|
||||
&DeliveryResult{},
|
||||
)
|
||||
}
|
||||
187
internal/database/password.go
Normal file
187
internal/database/password.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
// Argon2 parameters - these are up-to-date secure defaults
|
||||
const (
|
||||
argon2Time = 1
|
||||
argon2Memory = 64 * 1024 // 64 MB
|
||||
argon2Threads = 4
|
||||
argon2KeyLen = 32
|
||||
argon2SaltLen = 16
|
||||
)
|
||||
|
||||
// PasswordConfig holds Argon2 configuration
|
||||
type PasswordConfig struct {
|
||||
Time uint32
|
||||
Memory uint32
|
||||
Threads uint8
|
||||
KeyLen uint32
|
||||
SaltLen uint32
|
||||
}
|
||||
|
||||
// DefaultPasswordConfig returns secure default Argon2 parameters
|
||||
func DefaultPasswordConfig() *PasswordConfig {
|
||||
return &PasswordConfig{
|
||||
Time: argon2Time,
|
||||
Memory: argon2Memory,
|
||||
Threads: argon2Threads,
|
||||
KeyLen: argon2KeyLen,
|
||||
SaltLen: argon2SaltLen,
|
||||
}
|
||||
}
|
||||
|
||||
// HashPassword generates an Argon2id hash of the password
|
||||
func HashPassword(password string) (string, error) {
|
||||
config := DefaultPasswordConfig()
|
||||
|
||||
// Generate a salt
|
||||
salt := make([]byte, config.SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate the hash
|
||||
hash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
|
||||
// Encode the hash and parameters
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Format: $argon2id$v=19$m=65536,t=1,p=4$salt$hash
|
||||
encoded := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, config.Memory, config.Time, config.Threads, b64Salt, b64Hash)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// VerifyPassword checks if the provided password matches the hash
|
||||
func VerifyPassword(password, encodedHash string) (bool, error) {
|
||||
// Extract parameters and hash from encoded string
|
||||
config, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Generate hash of the provided password
|
||||
otherHash := argon2.IDKey([]byte(password), salt, config.Time, config.Memory, config.Threads, config.KeyLen)
|
||||
|
||||
// Compare hashes using constant time comparison
|
||||
return subtle.ConstantTimeCompare(hash, otherHash) == 1, nil
|
||||
}
|
||||
|
||||
// decodeHash extracts parameters, salt, and hash from an encoded hash string
|
||||
func decodeHash(encodedHash string) (*PasswordConfig, []byte, []byte, error) {
|
||||
parts := strings.Split(encodedHash, "$")
|
||||
if len(parts) != 6 {
|
||||
return nil, nil, nil, fmt.Errorf("invalid hash format")
|
||||
}
|
||||
|
||||
if parts[1] != "argon2id" {
|
||||
return nil, nil, nil, fmt.Errorf("invalid algorithm")
|
||||
}
|
||||
|
||||
var version int
|
||||
if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, fmt.Errorf("incompatible argon2 version")
|
||||
}
|
||||
|
||||
config := &PasswordConfig{}
|
||||
if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &config.Memory, &config.Time, &config.Threads); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
saltLen := len(salt)
|
||||
if saltLen < 0 || saltLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("salt length out of range")
|
||||
}
|
||||
config.SaltLen = uint32(saltLen) // nolint:gosec // checked above
|
||||
|
||||
hash, err := base64.RawStdEncoding.DecodeString(parts[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
hashLen := len(hash)
|
||||
if hashLen < 0 || hashLen > int(^uint32(0)) {
|
||||
return nil, nil, nil, fmt.Errorf("hash length out of range")
|
||||
}
|
||||
config.KeyLen = uint32(hashLen) // nolint:gosec // checked above
|
||||
|
||||
return config, salt, hash, nil
|
||||
}
|
||||
|
||||
// GenerateRandomPassword generates a cryptographically secure random password
|
||||
func GenerateRandomPassword(length int) (string, error) {
|
||||
const (
|
||||
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
lowercase = "abcdefghijklmnopqrstuvwxyz"
|
||||
digits = "0123456789"
|
||||
special = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
)
|
||||
|
||||
// Combine all character sets
|
||||
allChars := uppercase + lowercase + digits + special
|
||||
|
||||
// Create password slice
|
||||
password := make([]byte, length)
|
||||
|
||||
// Ensure at least one character from each set for password complexity
|
||||
if length >= 4 {
|
||||
// Get one character from each set
|
||||
password[0] = uppercase[cryptoRandInt(len(uppercase))]
|
||||
password[1] = lowercase[cryptoRandInt(len(lowercase))]
|
||||
password[2] = digits[cryptoRandInt(len(digits))]
|
||||
password[3] = special[cryptoRandInt(len(special))]
|
||||
|
||||
// Fill the rest randomly from all characters
|
||||
for i := 4; i < length; i++ {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
|
||||
// Shuffle the password to avoid predictable pattern
|
||||
for i := len(password) - 1; i > 0; i-- {
|
||||
j := cryptoRandInt(i + 1)
|
||||
password[i], password[j] = password[j], password[i]
|
||||
}
|
||||
} else {
|
||||
// For very short passwords, just use all characters
|
||||
for i := 0; i < length; i++ {
|
||||
password[i] = allChars[cryptoRandInt(len(allChars))]
|
||||
}
|
||||
}
|
||||
|
||||
return string(password), nil
|
||||
}
|
||||
|
||||
// cryptoRandInt generates a cryptographically secure random integer in [0, max)
|
||||
func cryptoRandInt(max int) int {
|
||||
if max <= 0 {
|
||||
panic("max must be positive")
|
||||
}
|
||||
|
||||
// Calculate the maximum valid value to avoid modulo bias
|
||||
// For example, if max=200 and we have 256 possible values,
|
||||
// we only accept values 0-199 (reject 200-255)
|
||||
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand error: %v", err))
|
||||
}
|
||||
|
||||
return int(nBig.Int64())
|
||||
}
|
||||
126
internal/database/password_test.go
Normal file
126
internal/database/password_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateRandomPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
}{
|
||||
{"Short password", 8},
|
||||
{"Medium password", 16},
|
||||
{"Long password", 32},
|
||||
{"Very short password", 3},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
password, err := GenerateRandomPassword(tt.length)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if len(password) != tt.length {
|
||||
t.Errorf("Password length = %v, want %v", len(password), tt.length)
|
||||
}
|
||||
|
||||
// For passwords >= 4 chars, check complexity
|
||||
if tt.length >= 4 {
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
hasSpecial := false
|
||||
|
||||
for _, char := range password {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
case strings.ContainsRune("!@#$%^&*()_+-=[]{}|;:,.<>?", char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasDigit || !hasSpecial {
|
||||
t.Errorf("Password lacks required complexity: upper=%v, lower=%v, digit=%v, special=%v",
|
||||
hasUpper, hasLower, hasDigit, hasSpecial)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomPasswordUniqueness(t *testing.T) {
|
||||
// Generate multiple passwords and ensure they're different
|
||||
passwords := make(map[string]bool)
|
||||
const numPasswords = 100
|
||||
|
||||
for i := 0; i < numPasswords; i++ {
|
||||
password, err := GenerateRandomPassword(16)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if passwords[password] {
|
||||
t.Errorf("Duplicate password generated: %s", password)
|
||||
}
|
||||
passwords[password] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
password := "testPassword123!"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
// Check that hash has correct format
|
||||
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||
t.Errorf("Hash doesn't have correct prefix: %s", hash)
|
||||
}
|
||||
|
||||
// Verify password
|
||||
valid, err := VerifyPassword(password, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
if !valid {
|
||||
t.Error("VerifyPassword() returned false for correct password")
|
||||
}
|
||||
|
||||
// Verify wrong password fails
|
||||
valid, err = VerifyPassword("wrongPassword", hash)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyPassword() error = %v", err)
|
||||
}
|
||||
if valid {
|
||||
t.Error("VerifyPassword() returned true for wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPasswordUniqueness(t *testing.T) {
|
||||
password := "testPassword123!"
|
||||
|
||||
// Same password should produce different hashes due to salt
|
||||
hash1, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
hash2, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() error = %v", err)
|
||||
}
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Error("Same password produced identical hashes (salt not working)")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user