store
This commit is contained in:
parent
578da90b55
commit
8b2d298488
|
@ -0,0 +1,65 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go.uber.org/fx"
|
||||
"gorm.io/gorm"
|
||||
"sneak.berlin/go/directory/internal/config"
|
||||
"sneak.berlin/go/directory/internal/database"
|
||||
"sneak.berlin/go/directory/internal/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
UsersTableName string = "users"
|
||||
UserAttributesTableName string = "userattributes"
|
||||
)
|
||||
|
||||
type StoreParams struct {
|
||||
fx.In
|
||||
Logger *logger.Logger
|
||||
Config *config.Config
|
||||
Database *database.Database
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
dbdir string
|
||||
dbfn string
|
||||
db *gorm.DB
|
||||
logger *logger.Logger
|
||||
log *zerolog.Logger
|
||||
params *StoreParams
|
||||
}
|
||||
|
||||
func New(lc fx.Lifecycle, params *StoreParams) (*Store, error) {
|
||||
s := new(Store)
|
||||
|
||||
s.logger = params.Logger
|
||||
s.log = params.Logger.Get()
|
||||
|
||||
var err error
|
||||
s.db, err = params.Database.Get()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to get database object")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Auto migrate the schema
|
||||
err = s.db.AutoMigrate(&User{}, &UserAttribute{})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to migrate database schema")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() {
|
||||
log.Info().Msg("closing store db")
|
||||
db, err := s.db.DB()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to get generic database object from gorm")
|
||||
return
|
||||
}
|
||||
db.Close()
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
Username string `gorm:"unique"`
|
||||
Email *string
|
||||
PasswordHash string
|
||||
CreatedAt *time.Time
|
||||
LastLogin *time.Time
|
||||
LastSeen *time.Time
|
||||
LastFailedAuth *time.Time
|
||||
RecentFailedAuthCount int
|
||||
LastSeenIP *string
|
||||
Disabled bool
|
||||
Attributes []UserAttribute `gorm:"foreignKey:UserID"`
|
||||
db *gorm.DB `gorm:"-"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return UsersTableName
|
||||
}
|
||||
|
||||
func (u *User) IsDisabled() bool {
|
||||
return u.Disabled
|
||||
}
|
||||
|
||||
func (s *Store) AddUser(username, password string) (*User, error) {
|
||||
var hashedPassword string
|
||||
if password != "" {
|
||||
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to hash password")
|
||||
return nil, err
|
||||
}
|
||||
hashedPassword = string(hashedPasswordBytes)
|
||||
}
|
||||
|
||||
userID := ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String()
|
||||
createdAt := time.Now()
|
||||
user := &User{
|
||||
ID: userID,
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
CreatedAt: &createdAt,
|
||||
db: s.db,
|
||||
}
|
||||
|
||||
result := s.db.Create(user)
|
||||
if result.Error != nil {
|
||||
log.Error().
|
||||
Str("username", username).
|
||||
Err(result.Error).
|
||||
Msg("unable to insert user into db")
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("username", username).
|
||||
Msg("user added to db")
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Store) FindUserByUsername(username string) (*User, error) {
|
||||
if len(username) < 3 {
|
||||
return nil, errors.New("usernames are at least 3 characters")
|
||||
}
|
||||
|
||||
var user User
|
||||
result := s.db.Where("username = ? AND disabled = ?", username, false).First(&user)
|
||||
if result.Error != nil {
|
||||
log.Error().Err(result.Error).Msg("unable to find user")
|
||||
return nil, result.Error
|
||||
}
|
||||
user.db = s.db
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *Store) FindUserByID(userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.Where("id = ? and disabled = ?", userID, false).First(&user)
|
||||
if result.Error != nil {
|
||||
log.Error().Err(result.Error).Msg("unable to find user")
|
||||
return nil, result.Error
|
||||
}
|
||||
user.db = s.db
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *Store) LogFailedAuth(UserID string, r *http.Request) {
|
||||
// FIXME implement
|
||||
}
|
||||
|
||||
func (s *Store) LogSuccessfulAuth(UserID string, r *http.Request) {
|
||||
now := time.Now()
|
||||
u, err := s.FindUserByID(UserID)
|
||||
if err != nil {
|
||||
panic("unable to find user")
|
||||
}
|
||||
s.db.Model(u).Update("LastSeen", now)
|
||||
//s.db.Model(u).Update("LastSeenIP", FIXME)
|
||||
|
||||
}
|
||||
|
||||
func (s *Store) AuthenticateUser(username, password, rootPasswordHash string, r *http.Request) (bool, *User, error) {
|
||||
|
||||
if username == "root" {
|
||||
user, err := s.FindUserByUsername(username)
|
||||
if err != nil {
|
||||
// we probably need to create the root user in the db
|
||||
// and set 'user' because it's returned
|
||||
user, err = s.AddUser("root", "*")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to create root user")
|
||||
return false, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// we also ignore if the root user in the database is disabled,
|
||||
// because root can always log in using the env var password
|
||||
|
||||
// even if the root user exists in the db, we still use the root
|
||||
// password from the environment var, and ignore the root password
|
||||
// hash in the database
|
||||
err = bcrypt.CompareHashAndPassword([]byte(rootPasswordHash), []byte(password))
|
||||
if err != nil {
|
||||
s.LogFailedAuth(user.ID, r)
|
||||
return false, nil, nil
|
||||
}
|
||||
s.LogSuccessfulAuth(user.ID, r)
|
||||
return true, user, nil
|
||||
}
|
||||
|
||||
// FIXME update user record when auth fails
|
||||
user, err := s.FindUserByUsername(username)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
if len(password) < 8 {
|
||||
s.LogFailedAuth(user.ID, r)
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// how long are bcrypt hashes anyway?
|
||||
if len(user.PasswordHash) < 10 {
|
||||
s.LogFailedAuth(user.ID, r)
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
if user.Disabled {
|
||||
s.LogFailedAuth(user.ID, r)
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||||
if err != nil {
|
||||
if err == bcrypt.ErrMismatchedHashAndPassword {
|
||||
s.LogFailedAuth(user.ID, r)
|
||||
return false, nil, nil
|
||||
}
|
||||
log.Error().Err(err).Msg("password comparison error")
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
s.LogSuccessfulAuth(user.ID, r)
|
||||
return true, user, nil
|
||||
}
|
||||
|
||||
func (u *User) Disable() error {
|
||||
u.Disabled = true
|
||||
result := u.db.Save(u)
|
||||
if result.Error != nil {
|
||||
log.Error().
|
||||
Str("email", *u.Email).
|
||||
Err(result.Error).
|
||||
Msg("unable to disable user")
|
||||
return result.Error
|
||||
}
|
||||
log.Debug().
|
||||
Str("email", *u.Email).
|
||||
Msg("user disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) IsAdmin() bool {
|
||||
return u.HasAttribute("admin")
|
||||
}
|
||||
|
||||
func (u *User) IsSuperAdmin() bool {
|
||||
return u.HasAttribute("superadmin")
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
package store
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserAttribute struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
UserID string `gorm:"not null"`
|
||||
Value string `gorm:"not null"`
|
||||
db *gorm.DB `gorm:"-"`
|
||||
}
|
||||
|
||||
func (UserAttribute) TableName() string {
|
||||
return UserAttributesTableName
|
||||
}
|
||||
|
||||
func (u *User) AddAttribute(value string) error {
|
||||
value = strings.ToLower(value)
|
||||
if u.HasAttribute(value) {
|
||||
log.Debug().Str("attribute", value).Msg("attribute already exists")
|
||||
return nil
|
||||
}
|
||||
|
||||
attribute := &UserAttribute{
|
||||
UserID: u.ID,
|
||||
Value: value,
|
||||
db: u.db,
|
||||
}
|
||||
|
||||
result := u.db.Create(attribute)
|
||||
if result.Error != nil {
|
||||
log.Error().Err(result.Error).Msg("unable to add user attribute")
|
||||
return result.Error
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("email", *u.Email).
|
||||
Str("attribute", value).
|
||||
Msg("user attribute added")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) DeleteAttribute(value string) error {
|
||||
value = strings.ToLower(value)
|
||||
result := u.db.Where("user_id = ? AND value = ?", u.ID, value).Delete(&UserAttribute{})
|
||||
if result.Error != nil {
|
||||
log.Error().Err(result.Error).Msg("unable to delete user attribute")
|
||||
return result.Error
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("email", *u.Email).
|
||||
Str("attribute", value).
|
||||
Msg("user attribute deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) HasAttribute(value string) bool {
|
||||
value = strings.ToLower(value)
|
||||
var count int64
|
||||
result := u.db.Model(&UserAttribute{}).Where("user_id = ? AND value = ?", u.ID, value).Count(&count)
|
||||
if result.Error != nil {
|
||||
log.Error().Err(result.Error).Msg("unable to check user attribute")
|
||||
return false
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (u *User) GetAttributes() ([]string, error) {
|
||||
var attributes []UserAttribute
|
||||
result := u.db.Where("user_id = ?", u.ID).Find(&attributes)
|
||||
if result.Error != nil {
|
||||
log.Error().Err(result.Error).Msg("unable to get user attributes")
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
var attributeValues []string
|
||||
for _, attr := range attributes {
|
||||
attributeValues = append(attributeValues, attr.Value)
|
||||
}
|
||||
|
||||
return attributeValues, nil
|
||||
}
|
Loading…
Reference in New Issue