store
This commit is contained in:
parent
578da90b55
commit
8b2d298488
65
internal/store/store.go
Normal file
65
internal/store/store.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"go.uber.org/fx"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"sneak.berlin/go/directory/internal/config"
|
||||||
|
"sneak.berlin/go/directory/internal/database"
|
||||||
|
"sneak.berlin/go/directory/internal/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
UsersTableName string = "users"
|
||||||
|
UserAttributesTableName string = "userattributes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StoreParams struct {
|
||||||
|
fx.In
|
||||||
|
Logger *logger.Logger
|
||||||
|
Config *config.Config
|
||||||
|
Database *database.Database
|
||||||
|
}
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
dbdir string
|
||||||
|
dbfn string
|
||||||
|
db *gorm.DB
|
||||||
|
logger *logger.Logger
|
||||||
|
log *zerolog.Logger
|
||||||
|
params *StoreParams
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(lc fx.Lifecycle, params *StoreParams) (*Store, error) {
|
||||||
|
s := new(Store)
|
||||||
|
|
||||||
|
s.logger = params.Logger
|
||||||
|
s.log = params.Logger.Get()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.db, err = params.Database.Get()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to get database object")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto migrate the schema
|
||||||
|
err = s.db.AutoMigrate(&User{}, &UserAttribute{})
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to migrate database schema")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Close() {
|
||||||
|
log.Info().Msg("closing store db")
|
||||||
|
db, err := s.db.DB()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to get generic database object from gorm")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
db.Close()
|
||||||
|
}
|
207
internal/store/user.go
Normal file
207
internal/store/user.go
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oklog/ulid/v2"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
Username string `gorm:"unique"`
|
||||||
|
Email *string
|
||||||
|
PasswordHash string
|
||||||
|
CreatedAt *time.Time
|
||||||
|
LastLogin *time.Time
|
||||||
|
LastSeen *time.Time
|
||||||
|
LastFailedAuth *time.Time
|
||||||
|
RecentFailedAuthCount int
|
||||||
|
LastSeenIP *string
|
||||||
|
Disabled bool
|
||||||
|
Attributes []UserAttribute `gorm:"foreignKey:UserID"`
|
||||||
|
db *gorm.DB `gorm:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string {
|
||||||
|
return UsersTableName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) IsDisabled() bool {
|
||||||
|
return u.Disabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AddUser(username, password string) (*User, error) {
|
||||||
|
var hashedPassword string
|
||||||
|
if password != "" {
|
||||||
|
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("unable to hash password")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hashedPassword = string(hashedPasswordBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String()
|
||||||
|
createdAt := time.Now()
|
||||||
|
user := &User{
|
||||||
|
ID: userID,
|
||||||
|
Username: username,
|
||||||
|
PasswordHash: hashedPassword,
|
||||||
|
CreatedAt: &createdAt,
|
||||||
|
db: s.db,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.db.Create(user)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("username", username).
|
||||||
|
Err(result.Error).
|
||||||
|
Msg("unable to insert user into db")
|
||||||
|
return nil, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Str("username", username).
|
||||||
|
Msg("user added to db")
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) FindUserByUsername(username string) (*User, error) {
|
||||||
|
if len(username) < 3 {
|
||||||
|
return nil, errors.New("usernames are at least 3 characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
result := s.db.Where("username = ? AND disabled = ?", username, false).First(&user)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().Err(result.Error).Msg("unable to find user")
|
||||||
|
return nil, result.Error
|
||||||
|
}
|
||||||
|
user.db = s.db
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) FindUserByID(userID string) (*User, error) {
|
||||||
|
var user User
|
||||||
|
result := s.db.Where("id = ? and disabled = ?", userID, false).First(&user)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().Err(result.Error).Msg("unable to find user")
|
||||||
|
return nil, result.Error
|
||||||
|
}
|
||||||
|
user.db = s.db
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) LogFailedAuth(UserID string, r *http.Request) {
|
||||||
|
// FIXME implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) LogSuccessfulAuth(UserID string, r *http.Request) {
|
||||||
|
now := time.Now()
|
||||||
|
u, err := s.FindUserByID(UserID)
|
||||||
|
if err != nil {
|
||||||
|
panic("unable to find user")
|
||||||
|
}
|
||||||
|
s.db.Model(u).Update("LastSeen", now)
|
||||||
|
//s.db.Model(u).Update("LastSeenIP", FIXME)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) AuthenticateUser(username, password, rootPasswordHash string, r *http.Request) (bool, *User, error) {
|
||||||
|
|
||||||
|
if username == "root" {
|
||||||
|
user, err := s.FindUserByUsername(username)
|
||||||
|
if err != nil {
|
||||||
|
// we probably need to create the root user in the db
|
||||||
|
// and set 'user' because it's returned
|
||||||
|
user, err = s.AddUser("root", "*")
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("unable to create root user")
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we also ignore if the root user in the database is disabled,
|
||||||
|
// because root can always log in using the env var password
|
||||||
|
|
||||||
|
// even if the root user exists in the db, we still use the root
|
||||||
|
// password from the environment var, and ignore the root password
|
||||||
|
// hash in the database
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(rootPasswordHash), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
s.LogFailedAuth(user.ID, r)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
s.LogSuccessfulAuth(user.ID, r)
|
||||||
|
return true, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME update user record when auth fails
|
||||||
|
user, err := s.FindUserByUsername(username)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(password) < 8 {
|
||||||
|
s.LogFailedAuth(user.ID, r)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// how long are bcrypt hashes anyway?
|
||||||
|
if len(user.PasswordHash) < 10 {
|
||||||
|
s.LogFailedAuth(user.ID, r)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Disabled {
|
||||||
|
s.LogFailedAuth(user.ID, r)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
if err == bcrypt.ErrMismatchedHashAndPassword {
|
||||||
|
s.LogFailedAuth(user.ID, r)
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
log.Error().Err(err).Msg("password comparison error")
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.LogSuccessfulAuth(user.ID, r)
|
||||||
|
return true, user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Disable() error {
|
||||||
|
u.Disabled = true
|
||||||
|
result := u.db.Save(u)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("email", *u.Email).
|
||||||
|
Err(result.Error).
|
||||||
|
Msg("unable to disable user")
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
log.Debug().
|
||||||
|
Str("email", *u.Email).
|
||||||
|
Msg("user disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) IsAdmin() bool {
|
||||||
|
return u.HasAttribute("admin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) IsSuperAdmin() bool {
|
||||||
|
return u.HasAttribute("superadmin")
|
||||||
|
}
|
87
internal/store/user_attribute.go
Normal file
87
internal/store/user_attribute.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserAttribute struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
UserID string `gorm:"not null"`
|
||||||
|
Value string `gorm:"not null"`
|
||||||
|
db *gorm.DB `gorm:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserAttribute) TableName() string {
|
||||||
|
return UserAttributesTableName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) AddAttribute(value string) error {
|
||||||
|
value = strings.ToLower(value)
|
||||||
|
if u.HasAttribute(value) {
|
||||||
|
log.Debug().Str("attribute", value).Msg("attribute already exists")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
attribute := &UserAttribute{
|
||||||
|
UserID: u.ID,
|
||||||
|
Value: value,
|
||||||
|
db: u.db,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := u.db.Create(attribute)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().Err(result.Error).Msg("unable to add user attribute")
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Str("email", *u.Email).
|
||||||
|
Str("attribute", value).
|
||||||
|
Msg("user attribute added")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) DeleteAttribute(value string) error {
|
||||||
|
value = strings.ToLower(value)
|
||||||
|
result := u.db.Where("user_id = ? AND value = ?", u.ID, value).Delete(&UserAttribute{})
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().Err(result.Error).Msg("unable to delete user attribute")
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Str("email", *u.Email).
|
||||||
|
Str("attribute", value).
|
||||||
|
Msg("user attribute deleted")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) HasAttribute(value string) bool {
|
||||||
|
value = strings.ToLower(value)
|
||||||
|
var count int64
|
||||||
|
result := u.db.Model(&UserAttribute{}).Where("user_id = ? AND value = ?", u.ID, value).Count(&count)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().Err(result.Error).Msg("unable to check user attribute")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) GetAttributes() ([]string, error) {
|
||||||
|
var attributes []UserAttribute
|
||||||
|
result := u.db.Where("user_id = ?", u.ID).Find(&attributes)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Error().Err(result.Error).Msg("unable to get user attributes")
|
||||||
|
return nil, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
var attributeValues []string
|
||||||
|
for _, attr := range attributes {
|
||||||
|
attributeValues = append(attributeValues, attr.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return attributeValues, nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user