diff options
Diffstat (limited to '')
-rw-r--r-- | internal/user/handler.go | 511 | ||||
-rw-r--r-- | internal/user/model/model.go | 38 | ||||
-rw-r--r-- | internal/user/repo/repository.go | 160 |
3 files changed, 709 insertions, 0 deletions
diff --git a/internal/user/handler.go b/internal/user/handler.go new file mode 100644 index 0000000..0eee6f2 --- /dev/null +++ b/internal/user/handler.go @@ -0,0 +1,511 @@ +package user + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "html" + "net/http" + "time" + + auth "donetick.com/core/internal/authorization" + cModel "donetick.com/core/internal/circle/model" + cRepo "donetick.com/core/internal/circle/repo" + "donetick.com/core/internal/email" + uModel "donetick.com/core/internal/user/model" + uRepo "donetick.com/core/internal/user/repo" + "donetick.com/core/internal/utils" + "donetick.com/core/logging" + jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" + limiter "github.com/ulule/limiter/v3" + "google.golang.org/api/googleapi" + "google.golang.org/api/oauth2/v1" +) + +type Handler struct { + userRepo *uRepo.UserRepository + circleRepo *cRepo.CircleRepository + jwtAuth *jwt.GinJWTMiddleware + email *email.EmailSender +} + +func NewHandler(ur *uRepo.UserRepository, cr *cRepo.CircleRepository, jwtAuth *jwt.GinJWTMiddleware, email *email.EmailSender) *Handler { + return &Handler{ + userRepo: ur, + circleRepo: cr, + jwtAuth: jwtAuth, + email: email, + } +} + +func (h *Handler) GetAllUsers() gin.HandlerFunc { + return func(c *gin.Context) { + currentUser, ok := auth.CurrentUser(c) + if !ok { + c.JSON(500, gin.H{ + "error": "Error getting current user", + }) + return + } + + users, err := h.userRepo.GetAllUsers(c, currentUser.CircleID) + if err != nil { + c.JSON(500, gin.H{ + "error": "Error getting users", + }) + return + } + + c.JSON(200, gin.H{ + "res": users, + }) + } +} + +func (h *Handler) signUp(c *gin.Context) { + + type SignUpReq struct { + Username string `json:"username" binding:"required,min=4,max=20"` + Password string `json:"password" binding:"required,min=8,max=45"` + DisplayName string `json:"displayName"` + } + var signupReq SignUpReq + if err := c.BindJSON(&signupReq); err != nil { + c.JSON(400, gin.H{ + "error": "Invalid request", + }) + return + } + if signupReq.DisplayName == "" { + signupReq.DisplayName = signupReq.Username + } + password, err := auth.EncodePassword(signupReq.Password) + signupReq.Username = html.EscapeString(signupReq.Username) + signupReq.DisplayName = html.EscapeString(signupReq.DisplayName) + + if err != nil { + c.JSON(500, gin.H{ + "error": "Error encoding password", + }) + return + } + var insertedUser *uModel.User + if insertedUser, err = h.userRepo.CreateUser(c, &uModel.User{ + Username: signupReq.Username, + Password: password, + DisplayName: signupReq.DisplayName, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }); err != nil { + c.JSON(500, gin.H{ + "error": "Error creating user", + }) + return + } + // var userCircle *circle.Circle + // var userRole string + userCircle, err := h.circleRepo.CreateCircle(c, &cModel.Circle{ + Name: signupReq.DisplayName + "'s circle", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + InviteCode: utils.GenerateInviteCode(c), + }) + + if err != nil { + c.JSON(500, gin.H{ + "error": "Error creating circle", + }) + return + } + + if err := h.circleRepo.AddUserToCircle(c, &cModel.UserCircle{ + UserID: insertedUser.ID, + CircleID: userCircle.ID, + Role: "admin", + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }); err != nil { + c.JSON(500, gin.H{ + "error": "Error adding user to circle", + }) + return + } + insertedUser.CircleID = userCircle.ID + if err := h.userRepo.UpdateUser(c, insertedUser); err != nil { + c.JSON(500, gin.H{ + "error": "Error updating user", + }) + return + } + + c.JSON(201, gin.H{}) +} + +func (h *Handler) GetUserProfile(c *gin.Context) { + user, ok := auth.CurrentUser(c) + if !ok { + c.JSON(500, gin.H{ + "error": "Error getting user", + }) + return + } + c.JSON(200, gin.H{ + "res": user, + }) +} + +func (h *Handler) thirdPartyAuthCallback(c *gin.Context) { + + // read :provider from path param, if param is google check the token with google if it's valid and fetch the user details: + logger := logging.FromContext(c) + provider := c.Param("provider") + logger.Infow("account.handler.thirdPartyAuthCallback", "provider", provider) + + if provider == "google" { + c.Set("auth_provider", "3rdPartyAuth") + type OAuthRequest struct { + Token string `json:"token" binding:"required"` + Provider string `json:"provider" binding:"required"` + } + var body OAuthRequest + if err := c.ShouldBindJSON(&body); err != nil { + logger.Errorw("account.handler.thirdPartyAuthCallback failed to bind", "err", err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Invalid request", + }) + return + } + + // logger.Infow("account.handler.thirdPartyAuthCallback", "token", token) + service, err := oauth2.New(http.DefaultClient) + + // tokenInfo, err := service.Tokeninfo().AccessToken(token).Do() + userinfo, err := service.Userinfo.Get().Do(googleapi.QueryParameter("access_token", body.Token)) + logger.Infow("account.handler.thirdPartyAuthCallback", "tokenInfo", userinfo) + if err != nil { + logger.Errorw("account.handler.thirdPartyAuthCallback failed to get token info", "err", err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Invalid token", + }) + return + } + + acc, err := h.userRepo.FindByEmail(c, userinfo.Email) + + if err != nil { + // create a random password for the user using crypto/rand: + password := auth.GenerateRandomPassword(12) + encodedPassword, err := auth.EncodePassword(password) + acc = &uModel.User{ + Username: userinfo.Id, + Email: userinfo.Email, + Image: userinfo.Picture, + Password: encodedPassword, + DisplayName: userinfo.GivenName, + Provider: 2, + } + createdUser, err := h.userRepo.CreateUser(c, acc) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to create user", + }) + return + + } + // Create Circle for the user: + userCircle, err := h.circleRepo.CreateCircle(c, &cModel.Circle{ + Name: userinfo.GivenName + "'s circle", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + InviteCode: utils.GenerateInviteCode(c), + }) + + if err != nil { + c.JSON(500, gin.H{ + "error": "Error creating circle", + }) + return + } + + if err := h.circleRepo.AddUserToCircle(c, &cModel.UserCircle{ + UserID: createdUser.ID, + CircleID: userCircle.ID, + Role: "admin", + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }); err != nil { + c.JSON(500, gin.H{ + "error": "Error adding user to circle", + }) + return + } + createdUser.CircleID = userCircle.ID + if err := h.userRepo.UpdateUser(c, createdUser); err != nil { + c.JSON(500, gin.H{ + "error": "Error updating user", + }) + return + } + } + // use auth to generate a token for the user: + c.Set("user_account", acc) + h.jwtAuth.Authenticator(c) + tokenString, expire, err := h.jwtAuth.TokenGenerator(acc) + if err != nil { + logger.Errorw("Unable to Generate a Token") + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to Generate a Token", + }) + return + } + c.JSON(http.StatusOK, gin.H{"token": tokenString, "expire": expire}) + return + } +} + +func (h *Handler) resetPassword(c *gin.Context) { + log := logging.FromContext(c) + type ResetPasswordReq struct { + Email string `json:"email" binding:"required,email"` + } + var req ResetPasswordReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Invalid request", + }) + return + } + user, err := h.userRepo.FindByEmail(c, req.Email) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{ + "error": "User not found", + }) + return + } + if user.Provider != 0 { + // user create account thought login with Gmail. they can reset the password they just need to login with google again + c.JSON( + http.StatusForbidden, + gin.H{ + "error": "User account created with google login. Please login with google", + }, + ) + return + } + // generate a random password: + token, err := auth.GenerateEmailResetToken(c) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to generate token", + }) + return + } + + err = h.userRepo.SetPasswordResetToken(c, req.Email, token) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to generate password", + }) + return + } + // send an email to the user with the new password: + err = h.email.SendResetPasswordEmail(c, req.Email, token) + if err != nil { + log.Errorw("account.handler.resetPassword failed to send email", "err", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to send email", + }) + return + } + + // send an email to the user with the new password: + c.JSON(http.StatusOK, gin.H{}) +} + +func (h *Handler) updateUserPassword(c *gin.Context) { + logger := logging.FromContext(c) + // read the code from query param: + code := c.Query("c") + email, code, err := email.DecodeEmailAndCode(code) + if err != nil { + logger.Errorw("account.handler.verify failed to decode email and code", "err", err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Invalid code", + }) + return + + } + // read password from body: + type RequestBody struct { + Password string `json:"password" binding:"required,min=8,max=32"` + } + var body RequestBody + if err := c.ShouldBindJSON(&body); err != nil { + logger.Errorw("user.handler.resetAccountPassword failed to bind", "err", err) + c.JSON(http.StatusBadRequest, gin.H{ + "error": "Invalid request", + }) + return + + } + password, err := auth.EncodePassword(body.Password) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to process password", + }) + return + } + + err = h.userRepo.UpdatePasswordByToken(c.Request.Context(), email, code, password) + if err != nil { + logger.Errorw("account.handler.resetAccountPassword failed to reset password", "err", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Unable to reset password", + }) + return + } + c.JSON(http.StatusOK, gin.H{}) + +} + +func (h *Handler) UpdateUserDetails(c *gin.Context) { + type UpdateUserReq struct { + DisplayName *string `json:"displayName" binding:"omitempty"` + ChatID *int64 `json:"chatID" binding:"omitempty"` + Image *string `json:"image" binding:"omitempty"` + } + user, ok := auth.CurrentUser(c) + if !ok { + c.JSON(500, gin.H{ + "error": "Error getting user", + }) + return + } + var req UpdateUserReq + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(400, gin.H{ + "error": "Invalid request", + }) + return + } + // update non-nil fields: + if req.DisplayName != nil { + user.DisplayName = *req.DisplayName + } + if req.ChatID != nil { + user.ChatID = *req.ChatID + } + if req.Image != nil { + user.Image = *req.Image + } + + if err := h.userRepo.UpdateUser(c, user); err != nil { + c.JSON(500, gin.H{ + "error": "Error updating user", + }) + return + } + c.JSON(200, user) +} + +func (h *Handler) CreateLongLivedToken(c *gin.Context) { + currentUser, ok := auth.CurrentUser(c) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get current user"}) + return + } + type TokenRequest struct { + Name string `json:"name" binding:"required"` + } + var req TokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"}) + return + } + + // Step 1: Generate a secure random number + randomBytes := make([]byte, 16) // 128 bits are enough for strong randomness + _, err := rand.Read(randomBytes) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate random part of the token"}) + return + } + + timestamp := time.Now().Unix() + hashInput := fmt.Sprintf("%s:%d:%x", currentUser.Username, timestamp, randomBytes) + hash := sha256.Sum256([]byte(hashInput)) + + token := hex.EncodeToString(hash[:]) + + tokenModel, err := h.userRepo.StoreAPIToken(c, currentUser.ID, req.Name, token) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to store the token"}) + return + } + + c.JSON(http.StatusOK, gin.H{"res": tokenModel}) +} + +func (h *Handler) GetAllUserToken(c *gin.Context) { + currentUser, ok := auth.CurrentUser(c) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get current user"}) + return + } + + tokens, err := h.userRepo.GetAllUserTokens(c, currentUser.ID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user tokens"}) + return + } + + c.JSON(http.StatusOK, gin.H{"res": tokens}) + +} + +func (h *Handler) DeleteUserToken(c *gin.Context) { + currentUser, ok := auth.CurrentUser(c) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get current user"}) + return + } + + tokenID := c.Param("id") + + err := h.userRepo.DeleteAPIToken(c, currentUser.ID, tokenID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete the token"}) + return + } + + c.JSON(http.StatusOK, gin.H{}) +} + +func Routes(router *gin.Engine, h *Handler, auth *jwt.GinJWTMiddleware, limiter *limiter.Limiter) { + + userRoutes := router.Group("users") + userRoutes.Use(auth.MiddlewareFunc(), utils.RateLimitMiddleware(limiter)) + { + userRoutes.GET("/", h.GetAllUsers()) + userRoutes.GET("/profile", h.GetUserProfile) + userRoutes.PUT("", h.UpdateUserDetails) + userRoutes.POST("/tokens", h.CreateLongLivedToken) + userRoutes.GET("/tokens", h.GetAllUserToken) + userRoutes.DELETE("/tokens/:id", h.DeleteUserToken) + } + + authRoutes := router.Group("auth") + authRoutes.Use(utils.RateLimitMiddleware(limiter)) + { + authRoutes.POST("/", h.signUp) + authRoutes.POST("login", auth.LoginHandler) + authRoutes.GET("refresh", auth.RefreshHandler) + authRoutes.POST("/:provider/callback", h.thirdPartyAuthCallback) + authRoutes.POST("reset", h.resetPassword) + authRoutes.POST("password", h.updateUserPassword) + } +} diff --git a/internal/user/model/model.go b/internal/user/model/model.go new file mode 100644 index 0000000..4874ac1 --- /dev/null +++ b/internal/user/model/model.go @@ -0,0 +1,38 @@ +package user + +import "time" + +type User struct { + ID int `json:"id" gorm:"primary_key"` // Unique identifier + DisplayName string `json:"displayName" gorm:"column:display_name"` // Display name + Username string `json:"username" gorm:"column:username;unique"` // Username (unique) + Email string `json:"email" gorm:"column:email;unique"` // Email (unique) + Provider int `json:"provider" gorm:"column:provider"` // Provider + Password string `json:"-" gorm:"column:password"` // Password + CircleID int `json:"circleID" gorm:"column:circle_id"` // Circle ID + ChatID int64 `json:"chatID" gorm:"column:chat_id"` // Telegram chat ID + Image string `json:"image" gorm:"column:image"` // Image + CreatedAt time.Time `json:"created_at" gorm:"column:created_at"` // Created at + UpdatedAt time.Time `json:"updated_at" gorm:"column:updated_at"` // Updated at + Disabled bool `json:"disabled" gorm:"column:disabled"` // Disabled + // Email string `json:"email" gorm:"column:email"` // Email + CustomerID *string `gorm:"column:customer_id;<-:false"` // read one column + Subscription *string `json:"subscription" gorm:"column:subscription;<-:false"` // read one column + Expiration *string `json:"expiration" gorm:"column:expiration;<-:false"` // read one column +} + +type UserPasswordReset struct { + ID int `gorm:"column:id"` + UserID int `gorm:"column:user_id"` + Email string `gorm:"column:email"` + Token string `gorm:"column:token"` + ExpirationDate time.Time `gorm:"column:expiration_date"` +} + +type APIToken struct { + ID int `json:"id" gorm:"primary_key"` // Unique identifier + Name string `json:"name" gorm:"column:name;unique"` // Name (unique) + UserID int `json:"userId" gorm:"column:user_id;index"` // Index on userID + Token string `json:"token" gorm:"column:token;index"` // Index on token + CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"` +} diff --git a/internal/user/repo/repository.go b/internal/user/repo/repository.go new file mode 100644 index 0000000..fa53753 --- /dev/null +++ b/internal/user/repo/repository.go @@ -0,0 +1,160 @@ +package user + +import ( + "context" + "fmt" + "time" + + "donetick.com/core/config" + uModel "donetick.com/core/internal/user/model" + "donetick.com/core/logging" + "gorm.io/gorm" +) + +type IUserRepository interface { + GetUserByUsername(username string) (*uModel.User, error) + GetUser(id int) (*uModel.User, error) + GetAllUsers() ([]*uModel.User, error) + CreateUser(user *uModel.User) error + UpdateUser(user *uModel.User) error + UpdateUserCircle(userID, circleID int) error + FindByEmail(email string) (*uModel.User, error) +} + +type UserRepository struct { + db *gorm.DB + isDonetickDotCom bool +} + +func NewUserRepository(db *gorm.DB, cfg *config.Config) *UserRepository { + return &UserRepository{db, cfg.IsDoneTickDotCom} +} + +func (r *UserRepository) GetAllUsers(c context.Context, circleID int) ([]*uModel.User, error) { + var users []*uModel.User + if err := r.db.WithContext(c).Where("circle_id = ?", circleID).Find(&users).Error; err != nil { + return nil, err + } + return users, nil +} + +func (r *UserRepository) GetAllUsersForSystemOnly(c context.Context) ([]*uModel.User, error) { + var users []*uModel.User + if err := r.db.WithContext(c).Find(&users).Error; err != nil { + return nil, err + } + return users, nil +} +func (r *UserRepository) CreateUser(c context.Context, user *uModel.User) (*uModel.User, error) { + if err := r.db.WithContext(c).Save(user).Error; err != nil { + return nil, err + } + return user, nil +} +func (r *UserRepository) GetUserByUsername(c context.Context, username string) (*uModel.User, error) { + var user *uModel.User + if r.isDonetickDotCom { + if err := r.db.WithContext(c).Table("users u").Select("u.*, ss.status as subscription, ss.expired_at as expiration").Joins("left join stripe_customers sc on sc.user_id = u.id ").Joins("left join stripe_subscriptions ss on sc.customer_id = ss.customer_id").Where("username = ?", username).First(&user).Error; err != nil { + return nil, err + } + } else { + if err := r.db.WithContext(c).Where("username = ?", username).First(&user).Error; err != nil { + return nil, err + } + } + + return user, nil +} + +func (r *UserRepository) UpdateUser(c context.Context, user *uModel.User) error { + return r.db.WithContext(c).Save(user).Error +} + +func (r *UserRepository) UpdateUserCircle(c context.Context, userID, circleID int) error { + return r.db.WithContext(c).Model(&uModel.User{}).Where("id = ?", userID).Update("circle_id", circleID).Error +} + +func (r *UserRepository) FindByEmail(c context.Context, email string) (*uModel.User, error) { + var user *uModel.User + if err := r.db.WithContext(c).Where("email = ?", email).First(&user).Error; err != nil { + return nil, err + } + return user, nil +} + +func (r *UserRepository) SetPasswordResetToken(c context.Context, email, token string) error { + // confirm user exists with email: + user, err := r.FindByEmail(c, email) + if err != nil { + return err + } + // save new token: + if err := r.db.WithContext(c).Model(&uModel.UserPasswordReset{}).Save(&uModel.UserPasswordReset{ + UserID: user.ID, + Token: token, + Email: email, + ExpirationDate: time.Now().UTC().Add(time.Hour * 24), + }).Error; err != nil { + return err + } + return nil +} + +func (r *UserRepository) UpdatePasswordByToken(ctx context.Context, email string, token string, password string) error { + logger := logging.FromContext(ctx) + + logger.Debugw("account.db.UpdatePasswordByToken", "email", email) + upr := &uModel.UserPasswordReset{ + Email: email, + Token: token, + } + result := r.db.WithContext(ctx).Where("email = ?", email).Where("token = ?", token).Delete(upr) + if result.RowsAffected <= 0 { + return fmt.Errorf("invalid token") + } + // find account by email and update password: + chain := r.db.WithContext(ctx).Model(&uModel.User{}).Where("email = ?", email).UpdateColumns(map[string]interface{}{"password": password}) + if chain.Error != nil { + return chain.Error + } + if chain.RowsAffected == 0 { + return fmt.Errorf("account not found") + } + + return nil +} + +func (r *UserRepository) StoreAPIToken(c context.Context, userID int, name string, tokenCode string) (*uModel.APIToken, error) { + token := &uModel.APIToken{ + UserID: userID, + Name: name, + Token: tokenCode, + CreatedAt: time.Now().UTC(), + } + if err := r.db.WithContext(c).Model(&uModel.APIToken{}).Save( + token).Error; err != nil { + return nil, err + + } + return token, nil +} + +func (r *UserRepository) GetUserByToken(c context.Context, token string) (*uModel.User, error) { + var user *uModel.User + if err := r.db.WithContext(c).Table("users u").Select("u.*").Joins("left join api_tokens at on at.user_id = u.id").Where("at.token = ?", token).First(&user).Error; err != nil { + return nil, err + } + return user, nil +} + +func (r *UserRepository) GetAllUserTokens(c context.Context, userID int) ([]*uModel.APIToken, error) { + var tokens []*uModel.APIToken + if err := r.db.WithContext(c).Where("user_id = ?", userID).Find(&tokens).Error; err != nil { + return nil, err + } + return tokens, nil +} + +func (r *UserRepository) DeleteAPIToken(c context.Context, userID int, tokenID string) error { + return r.db.WithContext(c).Where("id = ? AND user_id = ?", tokenID, userID).Delete(&uModel.APIToken{}).Error +} |