From 8b2d29848811733310177a704c54a0925dd32092 Mon Sep 17 00:00:00 2001 From: sneak Date: Sat, 1 Jun 2024 15:51:29 -0700 Subject: [PATCH] store --- internal/store/store.go | 65 ++++++++++ internal/store/user.go | 207 +++++++++++++++++++++++++++++++ internal/store/user_attribute.go | 87 +++++++++++++ 3 files changed, 359 insertions(+) create mode 100644 internal/store/store.go create mode 100644 internal/store/user.go create mode 100644 internal/store/user_attribute.go diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..bb82de4 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,65 @@ +package store + +import ( + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "go.uber.org/fx" + "gorm.io/gorm" + "sneak.berlin/go/directory/internal/config" + "sneak.berlin/go/directory/internal/database" + "sneak.berlin/go/directory/internal/logger" +) + +const ( + UsersTableName string = "users" + UserAttributesTableName string = "userattributes" +) + +type StoreParams struct { + fx.In + Logger *logger.Logger + Config *config.Config + Database *database.Database +} + +type Store struct { + dbdir string + dbfn string + db *gorm.DB + logger *logger.Logger + log *zerolog.Logger + params *StoreParams +} + +func New(lc fx.Lifecycle, params *StoreParams) (*Store, error) { + s := new(Store) + + s.logger = params.Logger + s.log = params.Logger.Get() + + var err error + s.db, err = params.Database.Get() + if err != nil { + log.Error().Err(err).Msg("failed to get database object") + return nil, err + } + + // Auto migrate the schema + err = s.db.AutoMigrate(&User{}, &UserAttribute{}) + if err != nil { + log.Error().Err(err).Msg("failed to migrate database schema") + return nil, err + } + + return s, nil +} + +func (s *Store) Close() { + log.Info().Msg("closing store db") + db, err := s.db.DB() + if err != nil { + log.Error().Err(err).Msg("failed to get generic database object from gorm") + return + } + db.Close() +} diff --git a/internal/store/user.go b/internal/store/user.go new file mode 100644 index 0000000..e6e9905 --- /dev/null +++ b/internal/store/user.go @@ -0,0 +1,207 @@ +package store + +import ( + "crypto/rand" + "errors" + "net/http" + "time" + + "github.com/oklog/ulid/v2" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +type User struct { + ID string `gorm:"primaryKey"` + Username string `gorm:"unique"` + Email *string + PasswordHash string + CreatedAt *time.Time + LastLogin *time.Time + LastSeen *time.Time + LastFailedAuth *time.Time + RecentFailedAuthCount int + LastSeenIP *string + Disabled bool + Attributes []UserAttribute `gorm:"foreignKey:UserID"` + db *gorm.DB `gorm:"-"` +} + +func (User) TableName() string { + return UsersTableName +} + +func (u *User) IsDisabled() bool { + return u.Disabled +} + +func (s *Store) AddUser(username, password string) (*User, error) { + var hashedPassword string + if password != "" { + hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + log.Error().Err(err).Msg("unable to hash password") + return nil, err + } + hashedPassword = string(hashedPasswordBytes) + } + + userID := ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String() + createdAt := time.Now() + user := &User{ + ID: userID, + Username: username, + PasswordHash: hashedPassword, + CreatedAt: &createdAt, + db: s.db, + } + + result := s.db.Create(user) + if result.Error != nil { + log.Error(). + Str("username", username). + Err(result.Error). + Msg("unable to insert user into db") + return nil, result.Error + } + + log.Debug(). + Str("username", username). + Msg("user added to db") + return user, nil +} + +func (s *Store) FindUserByUsername(username string) (*User, error) { + if len(username) < 3 { + return nil, errors.New("usernames are at least 3 characters") + } + + var user User + result := s.db.Where("username = ? AND disabled = ?", username, false).First(&user) + if result.Error != nil { + log.Error().Err(result.Error).Msg("unable to find user") + return nil, result.Error + } + user.db = s.db + return &user, nil +} + +func (s *Store) FindUserByID(userID string) (*User, error) { + var user User + result := s.db.Where("id = ? and disabled = ?", userID, false).First(&user) + if result.Error != nil { + log.Error().Err(result.Error).Msg("unable to find user") + return nil, result.Error + } + user.db = s.db + return &user, nil +} + +func (s *Store) LogFailedAuth(UserID string, r *http.Request) { + // FIXME implement +} + +func (s *Store) LogSuccessfulAuth(UserID string, r *http.Request) { + now := time.Now() + u, err := s.FindUserByID(UserID) + if err != nil { + panic("unable to find user") + } + s.db.Model(u).Update("LastSeen", now) + //s.db.Model(u).Update("LastSeenIP", FIXME) + +} + +func (s *Store) AuthenticateUser(username, password, rootPasswordHash string, r *http.Request) (bool, *User, error) { + + if username == "root" { + user, err := s.FindUserByUsername(username) + if err != nil { + // we probably need to create the root user in the db + // and set 'user' because it's returned + user, err = s.AddUser("root", "*") + if err != nil { + log.Error().Err(err).Msg("unable to create root user") + return false, nil, err + } + } + + // we also ignore if the root user in the database is disabled, + // because root can always log in using the env var password + + // even if the root user exists in the db, we still use the root + // password from the environment var, and ignore the root password + // hash in the database + err = bcrypt.CompareHashAndPassword([]byte(rootPasswordHash), []byte(password)) + if err != nil { + s.LogFailedAuth(user.ID, r) + return false, nil, nil + } + s.LogSuccessfulAuth(user.ID, r) + return true, user, nil + } + + // FIXME update user record when auth fails + user, err := s.FindUserByUsername(username) + if err != nil { + return false, nil, err + } + + if user == nil { + return false, nil, nil + } + + if len(password) < 8 { + s.LogFailedAuth(user.ID, r) + return false, nil, nil + } + + // how long are bcrypt hashes anyway? + if len(user.PasswordHash) < 10 { + s.LogFailedAuth(user.ID, r) + return false, nil, nil + } + + if user.Disabled { + s.LogFailedAuth(user.ID, r) + return false, nil, nil + } + + err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) + if err != nil { + if err == bcrypt.ErrMismatchedHashAndPassword { + s.LogFailedAuth(user.ID, r) + return false, nil, nil + } + log.Error().Err(err).Msg("password comparison error") + return false, nil, err + } + + s.LogSuccessfulAuth(user.ID, r) + return true, user, nil +} + +func (u *User) Disable() error { + u.Disabled = true + result := u.db.Save(u) + if result.Error != nil { + log.Error(). + Str("email", *u.Email). + Err(result.Error). + Msg("unable to disable user") + return result.Error + } + log.Debug(). + Str("email", *u.Email). + Msg("user disabled") + return nil +} + +func (u *User) IsAdmin() bool { + return u.HasAttribute("admin") +} + +func (u *User) IsSuperAdmin() bool { + return u.HasAttribute("superadmin") +} diff --git a/internal/store/user_attribute.go b/internal/store/user_attribute.go new file mode 100644 index 0000000..773d0b2 --- /dev/null +++ b/internal/store/user_attribute.go @@ -0,0 +1,87 @@ +package store + +import ( + "strings" + + "github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +type UserAttribute struct { + ID uint `gorm:"primaryKey"` + UserID string `gorm:"not null"` + Value string `gorm:"not null"` + db *gorm.DB `gorm:"-"` +} + +func (UserAttribute) TableName() string { + return UserAttributesTableName +} + +func (u *User) AddAttribute(value string) error { + value = strings.ToLower(value) + if u.HasAttribute(value) { + log.Debug().Str("attribute", value).Msg("attribute already exists") + return nil + } + + attribute := &UserAttribute{ + UserID: u.ID, + Value: value, + db: u.db, + } + + result := u.db.Create(attribute) + if result.Error != nil { + log.Error().Err(result.Error).Msg("unable to add user attribute") + return result.Error + } + + log.Debug(). + Str("email", *u.Email). + Str("attribute", value). + Msg("user attribute added") + return nil +} + +func (u *User) DeleteAttribute(value string) error { + value = strings.ToLower(value) + result := u.db.Where("user_id = ? AND value = ?", u.ID, value).Delete(&UserAttribute{}) + if result.Error != nil { + log.Error().Err(result.Error).Msg("unable to delete user attribute") + return result.Error + } + + log.Debug(). + Str("email", *u.Email). + Str("attribute", value). + Msg("user attribute deleted") + return nil +} + +func (u *User) HasAttribute(value string) bool { + value = strings.ToLower(value) + var count int64 + result := u.db.Model(&UserAttribute{}).Where("user_id = ? AND value = ?", u.ID, value).Count(&count) + if result.Error != nil { + log.Error().Err(result.Error).Msg("unable to check user attribute") + return false + } + return count > 0 +} + +func (u *User) GetAttributes() ([]string, error) { + var attributes []UserAttribute + result := u.db.Where("user_id = ?", u.ID).Find(&attributes) + if result.Error != nil { + log.Error().Err(result.Error).Msg("unable to get user attributes") + return nil, result.Error + } + + var attributeValues []string + for _, attr := range attributes { + attributeValues = append(attributeValues, attr.Value) + } + + return attributeValues, nil +}