This commit is contained in:
Jeffrey Paul 2024-06-01 15:51:29 -07:00
parent 578da90b55
commit 8b2d298488
3 changed files with 359 additions and 0 deletions

65
internal/store/store.go Normal file
View File

@ -0,0 +1,65 @@
package store
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"go.uber.org/fx"
"gorm.io/gorm"
"sneak.berlin/go/directory/internal/config"
"sneak.berlin/go/directory/internal/database"
"sneak.berlin/go/directory/internal/logger"
)
const (
UsersTableName string = "users"
UserAttributesTableName string = "userattributes"
)
type StoreParams struct {
fx.In
Logger *logger.Logger
Config *config.Config
Database *database.Database
}
type Store struct {
dbdir string
dbfn string
db *gorm.DB
logger *logger.Logger
log *zerolog.Logger
params *StoreParams
}
func New(lc fx.Lifecycle, params *StoreParams) (*Store, error) {
s := new(Store)
s.logger = params.Logger
s.log = params.Logger.Get()
var err error
s.db, err = params.Database.Get()
if err != nil {
log.Error().Err(err).Msg("failed to get database object")
return nil, err
}
// Auto migrate the schema
err = s.db.AutoMigrate(&User{}, &UserAttribute{})
if err != nil {
log.Error().Err(err).Msg("failed to migrate database schema")
return nil, err
}
return s, nil
}
func (s *Store) Close() {
log.Info().Msg("closing store db")
db, err := s.db.DB()
if err != nil {
log.Error().Err(err).Msg("failed to get generic database object from gorm")
return
}
db.Close()
}

207
internal/store/user.go Normal file
View File

@ -0,0 +1,207 @@
package store
import (
"crypto/rand"
"errors"
"net/http"
"time"
"github.com/oklog/ulid/v2"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID string `gorm:"primaryKey"`
Username string `gorm:"unique"`
Email *string
PasswordHash string
CreatedAt *time.Time
LastLogin *time.Time
LastSeen *time.Time
LastFailedAuth *time.Time
RecentFailedAuthCount int
LastSeenIP *string
Disabled bool
Attributes []UserAttribute `gorm:"foreignKey:UserID"`
db *gorm.DB `gorm:"-"`
}
func (User) TableName() string {
return UsersTableName
}
func (u *User) IsDisabled() bool {
return u.Disabled
}
func (s *Store) AddUser(username, password string) (*User, error) {
var hashedPassword string
if password != "" {
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
log.Error().Err(err).Msg("unable to hash password")
return nil, err
}
hashedPassword = string(hashedPasswordBytes)
}
userID := ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String()
createdAt := time.Now()
user := &User{
ID: userID,
Username: username,
PasswordHash: hashedPassword,
CreatedAt: &createdAt,
db: s.db,
}
result := s.db.Create(user)
if result.Error != nil {
log.Error().
Str("username", username).
Err(result.Error).
Msg("unable to insert user into db")
return nil, result.Error
}
log.Debug().
Str("username", username).
Msg("user added to db")
return user, nil
}
func (s *Store) FindUserByUsername(username string) (*User, error) {
if len(username) < 3 {
return nil, errors.New("usernames are at least 3 characters")
}
var user User
result := s.db.Where("username = ? AND disabled = ?", username, false).First(&user)
if result.Error != nil {
log.Error().Err(result.Error).Msg("unable to find user")
return nil, result.Error
}
user.db = s.db
return &user, nil
}
func (s *Store) FindUserByID(userID string) (*User, error) {
var user User
result := s.db.Where("id = ? and disabled = ?", userID, false).First(&user)
if result.Error != nil {
log.Error().Err(result.Error).Msg("unable to find user")
return nil, result.Error
}
user.db = s.db
return &user, nil
}
func (s *Store) LogFailedAuth(UserID string, r *http.Request) {
// FIXME implement
}
func (s *Store) LogSuccessfulAuth(UserID string, r *http.Request) {
now := time.Now()
u, err := s.FindUserByID(UserID)
if err != nil {
panic("unable to find user")
}
s.db.Model(u).Update("LastSeen", now)
//s.db.Model(u).Update("LastSeenIP", FIXME)
}
func (s *Store) AuthenticateUser(username, password, rootPasswordHash string, r *http.Request) (bool, *User, error) {
if username == "root" {
user, err := s.FindUserByUsername(username)
if err != nil {
// we probably need to create the root user in the db
// and set 'user' because it's returned
user, err = s.AddUser("root", "*")
if err != nil {
log.Error().Err(err).Msg("unable to create root user")
return false, nil, err
}
}
// we also ignore if the root user in the database is disabled,
// because root can always log in using the env var password
// even if the root user exists in the db, we still use the root
// password from the environment var, and ignore the root password
// hash in the database
err = bcrypt.CompareHashAndPassword([]byte(rootPasswordHash), []byte(password))
if err != nil {
s.LogFailedAuth(user.ID, r)
return false, nil, nil
}
s.LogSuccessfulAuth(user.ID, r)
return true, user, nil
}
// FIXME update user record when auth fails
user, err := s.FindUserByUsername(username)
if err != nil {
return false, nil, err
}
if user == nil {
return false, nil, nil
}
if len(password) < 8 {
s.LogFailedAuth(user.ID, r)
return false, nil, nil
}
// how long are bcrypt hashes anyway?
if len(user.PasswordHash) < 10 {
s.LogFailedAuth(user.ID, r)
return false, nil, nil
}
if user.Disabled {
s.LogFailedAuth(user.ID, r)
return false, nil, nil
}
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
if err != nil {
if err == bcrypt.ErrMismatchedHashAndPassword {
s.LogFailedAuth(user.ID, r)
return false, nil, nil
}
log.Error().Err(err).Msg("password comparison error")
return false, nil, err
}
s.LogSuccessfulAuth(user.ID, r)
return true, user, nil
}
func (u *User) Disable() error {
u.Disabled = true
result := u.db.Save(u)
if result.Error != nil {
log.Error().
Str("email", *u.Email).
Err(result.Error).
Msg("unable to disable user")
return result.Error
}
log.Debug().
Str("email", *u.Email).
Msg("user disabled")
return nil
}
func (u *User) IsAdmin() bool {
return u.HasAttribute("admin")
}
func (u *User) IsSuperAdmin() bool {
return u.HasAttribute("superadmin")
}

View File

@ -0,0 +1,87 @@
package store
import (
"strings"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
type UserAttribute struct {
ID uint `gorm:"primaryKey"`
UserID string `gorm:"not null"`
Value string `gorm:"not null"`
db *gorm.DB `gorm:"-"`
}
func (UserAttribute) TableName() string {
return UserAttributesTableName
}
func (u *User) AddAttribute(value string) error {
value = strings.ToLower(value)
if u.HasAttribute(value) {
log.Debug().Str("attribute", value).Msg("attribute already exists")
return nil
}
attribute := &UserAttribute{
UserID: u.ID,
Value: value,
db: u.db,
}
result := u.db.Create(attribute)
if result.Error != nil {
log.Error().Err(result.Error).Msg("unable to add user attribute")
return result.Error
}
log.Debug().
Str("email", *u.Email).
Str("attribute", value).
Msg("user attribute added")
return nil
}
func (u *User) DeleteAttribute(value string) error {
value = strings.ToLower(value)
result := u.db.Where("user_id = ? AND value = ?", u.ID, value).Delete(&UserAttribute{})
if result.Error != nil {
log.Error().Err(result.Error).Msg("unable to delete user attribute")
return result.Error
}
log.Debug().
Str("email", *u.Email).
Str("attribute", value).
Msg("user attribute deleted")
return nil
}
func (u *User) HasAttribute(value string) bool {
value = strings.ToLower(value)
var count int64
result := u.db.Model(&UserAttribute{}).Where("user_id = ? AND value = ?", u.ID, value).Count(&count)
if result.Error != nil {
log.Error().Err(result.Error).Msg("unable to check user attribute")
return false
}
return count > 0
}
func (u *User) GetAttributes() ([]string, error) {
var attributes []UserAttribute
result := u.db.Where("user_id = ?", u.ID).Find(&attributes)
if result.Error != nil {
log.Error().Err(result.Error).Msg("unable to get user attributes")
return nil, result.Error
}
var attributeValues []string
for _, attr := range attributes {
attributeValues = append(attributeValues, attr.Value)
}
return attributeValues, nil
}