aboutsummaryrefslogtreecommitdiffstats
path: root/internal/user/repo
diff options
context:
space:
mode:
Diffstat (limited to 'internal/user/repo')
-rw-r--r--internal/user/repo/repository.go160
1 files changed, 160 insertions, 0 deletions
diff --git a/internal/user/repo/repository.go b/internal/user/repo/repository.go
new file mode 100644
index 0000000..fa53753
--- /dev/null
+++ b/internal/user/repo/repository.go
@@ -0,0 +1,160 @@
+package user
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "donetick.com/core/config"
+ uModel "donetick.com/core/internal/user/model"
+ "donetick.com/core/logging"
+ "gorm.io/gorm"
+)
+
+type IUserRepository interface {
+ GetUserByUsername(username string) (*uModel.User, error)
+ GetUser(id int) (*uModel.User, error)
+ GetAllUsers() ([]*uModel.User, error)
+ CreateUser(user *uModel.User) error
+ UpdateUser(user *uModel.User) error
+ UpdateUserCircle(userID, circleID int) error
+ FindByEmail(email string) (*uModel.User, error)
+}
+
+type UserRepository struct {
+ db *gorm.DB
+ isDonetickDotCom bool
+}
+
+func NewUserRepository(db *gorm.DB, cfg *config.Config) *UserRepository {
+ return &UserRepository{db, cfg.IsDoneTickDotCom}
+}
+
+func (r *UserRepository) GetAllUsers(c context.Context, circleID int) ([]*uModel.User, error) {
+ var users []*uModel.User
+ if err := r.db.WithContext(c).Where("circle_id = ?", circleID).Find(&users).Error; err != nil {
+ return nil, err
+ }
+ return users, nil
+}
+
+func (r *UserRepository) GetAllUsersForSystemOnly(c context.Context) ([]*uModel.User, error) {
+ var users []*uModel.User
+ if err := r.db.WithContext(c).Find(&users).Error; err != nil {
+ return nil, err
+ }
+ return users, nil
+}
+func (r *UserRepository) CreateUser(c context.Context, user *uModel.User) (*uModel.User, error) {
+ if err := r.db.WithContext(c).Save(user).Error; err != nil {
+ return nil, err
+ }
+ return user, nil
+}
+func (r *UserRepository) GetUserByUsername(c context.Context, username string) (*uModel.User, error) {
+ var user *uModel.User
+ if r.isDonetickDotCom {
+ if err := r.db.WithContext(c).Table("users u").Select("u.*, ss.status as subscription, ss.expired_at as expiration").Joins("left join stripe_customers sc on sc.user_id = u.id ").Joins("left join stripe_subscriptions ss on sc.customer_id = ss.customer_id").Where("username = ?", username).First(&user).Error; err != nil {
+ return nil, err
+ }
+ } else {
+ if err := r.db.WithContext(c).Where("username = ?", username).First(&user).Error; err != nil {
+ return nil, err
+ }
+ }
+
+ return user, nil
+}
+
+func (r *UserRepository) UpdateUser(c context.Context, user *uModel.User) error {
+ return r.db.WithContext(c).Save(user).Error
+}
+
+func (r *UserRepository) UpdateUserCircle(c context.Context, userID, circleID int) error {
+ return r.db.WithContext(c).Model(&uModel.User{}).Where("id = ?", userID).Update("circle_id", circleID).Error
+}
+
+func (r *UserRepository) FindByEmail(c context.Context, email string) (*uModel.User, error) {
+ var user *uModel.User
+ if err := r.db.WithContext(c).Where("email = ?", email).First(&user).Error; err != nil {
+ return nil, err
+ }
+ return user, nil
+}
+
+func (r *UserRepository) SetPasswordResetToken(c context.Context, email, token string) error {
+ // confirm user exists with email:
+ user, err := r.FindByEmail(c, email)
+ if err != nil {
+ return err
+ }
+ // save new token:
+ if err := r.db.WithContext(c).Model(&uModel.UserPasswordReset{}).Save(&uModel.UserPasswordReset{
+ UserID: user.ID,
+ Token: token,
+ Email: email,
+ ExpirationDate: time.Now().UTC().Add(time.Hour * 24),
+ }).Error; err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r *UserRepository) UpdatePasswordByToken(ctx context.Context, email string, token string, password string) error {
+ logger := logging.FromContext(ctx)
+
+ logger.Debugw("account.db.UpdatePasswordByToken", "email", email)
+ upr := &uModel.UserPasswordReset{
+ Email: email,
+ Token: token,
+ }
+ result := r.db.WithContext(ctx).Where("email = ?", email).Where("token = ?", token).Delete(upr)
+ if result.RowsAffected <= 0 {
+ return fmt.Errorf("invalid token")
+ }
+ // find account by email and update password:
+ chain := r.db.WithContext(ctx).Model(&uModel.User{}).Where("email = ?", email).UpdateColumns(map[string]interface{}{"password": password})
+ if chain.Error != nil {
+ return chain.Error
+ }
+ if chain.RowsAffected == 0 {
+ return fmt.Errorf("account not found")
+ }
+
+ return nil
+}
+
+func (r *UserRepository) StoreAPIToken(c context.Context, userID int, name string, tokenCode string) (*uModel.APIToken, error) {
+ token := &uModel.APIToken{
+ UserID: userID,
+ Name: name,
+ Token: tokenCode,
+ CreatedAt: time.Now().UTC(),
+ }
+ if err := r.db.WithContext(c).Model(&uModel.APIToken{}).Save(
+ token).Error; err != nil {
+ return nil, err
+
+ }
+ return token, nil
+}
+
+func (r *UserRepository) GetUserByToken(c context.Context, token string) (*uModel.User, error) {
+ var user *uModel.User
+ if err := r.db.WithContext(c).Table("users u").Select("u.*").Joins("left join api_tokens at on at.user_id = u.id").Where("at.token = ?", token).First(&user).Error; err != nil {
+ return nil, err
+ }
+ return user, nil
+}
+
+func (r *UserRepository) GetAllUserTokens(c context.Context, userID int) ([]*uModel.APIToken, error) {
+ var tokens []*uModel.APIToken
+ if err := r.db.WithContext(c).Where("user_id = ?", userID).Find(&tokens).Error; err != nil {
+ return nil, err
+ }
+ return tokens, nil
+}
+
+func (r *UserRepository) DeleteAPIToken(c context.Context, userID int, tokenID string) error {
+ return r.db.WithContext(c).Where("id = ? AND user_id = ?", tokenID, userID).Delete(&uModel.APIToken{}).Error
+}