diff --git a/middlewares/authentication.go b/middlewares/authentication.go index effbf5d..ea55b41 100644 --- a/middlewares/authentication.go +++ b/middlewares/authentication.go @@ -40,7 +40,7 @@ func Authenticate(jwtService service.JWTService) gin.HandlerFunc { return } - userId, err := jwtService.GetUserIDByToken(authHeader) + tokenInfo, err := jwtService.GetUserIDByToken(authHeader) if err != nil { response := utils.BuildResponseFailed(dto.MESSAGE_FAILED_PROSES_REQUEST, err.Error(), nil) ctx.AbortWithStatusJSON(http.StatusUnauthorized, response) @@ -48,7 +48,8 @@ func Authenticate(jwtService service.JWTService) gin.HandlerFunc { } ctx.Set("token", authHeader) - ctx.Set("user_id", userId) + ctx.Set("tenant_id", tokenInfo.TenantID) + ctx.Set("user_id", tokenInfo.UserID) ctx.Next() } } diff --git a/modules/auth/service/jwt_service.go b/modules/auth/service/jwt_service.go index 91669c4..38fe3bf 100644 --- a/modules/auth/service/jwt_service.go +++ b/modules/auth/service/jwt_service.go @@ -11,16 +11,23 @@ import ( "github.com/golang-jwt/jwt/v4" ) +type UserTokenInfo struct { + TenantID string `json:"tenant_id"` + UserID string `json:"user_id"` +} + type JWTService interface { - GenerateAccessToken(userId string, role string) string + GenerateAccessToken(tenantId string, userId string, role string) string GenerateRefreshToken() (string, time.Time) ValidateToken(token string) (*jwt.Token, error) - GetUserIDByToken(token string) (string, error) + GetUserIDByToken(token string) (*UserTokenInfo, error) } type jwtCustomClaim struct { - UserID string `json:"user_id"` - Role string `json:"role"` + TenantID string `json:"tenant_id"` + UserID string `json:"user_id"` + Role string `json:"role"` + jwt.RegisteredClaims } @@ -48,11 +55,12 @@ func getSecretKey() string { return secretKey } -func (j *jwtService) GenerateAccessToken(userId string, role string) string { +func (j *jwtService) GenerateAccessToken(tenantId string, userId string, role string) string { claims := jwtCustomClaim{ - userId, - role, - jwt.RegisteredClaims{ + TenantID: tenantId, + UserID: userId, + Role: role, + RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.accessExpiry)), Issuer: j.issuer, IssuedAt: jwt.NewNumericDate(time.Now()), @@ -92,13 +100,17 @@ func (j *jwtService) ValidateToken(token string) (*jwt.Token, error) { return jwt.Parse(token, j.parseToken) } -func (j *jwtService) GetUserIDByToken(token string) (string, error) { +func (j *jwtService) GetUserIDByToken(token string) (*UserTokenInfo, error) { tToken, err := j.ValidateToken(token) if err != nil { - return "", err + return nil, err } claims := tToken.Claims.(jwt.MapClaims) - id := fmt.Sprintf("%v", claims["user_id"]) - return id, nil + userId := fmt.Sprintf("%v", claims["user_id"]) + tenantId := fmt.Sprintf("%v", claims["tenant_id"]) + return &UserTokenInfo{ + UserID: userId, + TenantID: tenantId, + }, nil } diff --git a/modules/user/service/user_service.go b/modules/user/service/user_service.go index 66f197f..d2be3f6 100644 --- a/modules/user/service/user_service.go +++ b/modules/user/service/user_service.go @@ -111,12 +111,13 @@ func (s *userService) Verify(ctx context.Context, req dto.UserLoginRequest) (aut return authDto.TokenResponse{}, dto.ErrUserNotFound } - accessToken := s.jwtService.GenerateAccessToken(user.ID.String(), user.Role) + accessToken := s.jwtService.GenerateAccessToken(user.TenantID.String(), user.ID.String(), user.Role) refreshTokenString, expiresAt := s.jwtService.GenerateRefreshToken() refreshToken := entities.RefreshToken{ ID: uuid.New(), UserID: user.ID, + TenantID: user.TenantID, Token: refreshTokenString, ExpiresAt: expiresAt, } @@ -143,7 +144,7 @@ func (s *userService) SendVerificationEmail(ctx context.Context, req dto.SendVer return dto.ErrAccountAlreadyVerified } - verificationToken := s.jwtService.GenerateAccessToken(user.ID.String(), "verification") + verificationToken := s.jwtService.GenerateAccessToken(user.TenantID.String(), user.ID.String(), "verification") subject := "Email Verification" body := "Please verify your email using this token: " + verificationToken @@ -157,12 +158,12 @@ func (s *userService) VerifyEmail(ctx context.Context, req dto.VerifyEmailReques return dto.VerifyEmailResponse{}, dto.ErrTokenInvalid } - userId, err := s.jwtService.GetUserIDByToken(req.Token) + userTokenInfo, err := s.jwtService.GetUserIDByToken(req.Token) if err != nil { return dto.VerifyEmailResponse{}, dto.ErrTokenInvalid } - user, err := s.userRepository.GetUserById(ctx, s.db, userId) + user, err := s.userRepository.GetUserById(ctx, s.db, userTokenInfo.UserID) if err != nil { return dto.VerifyEmailResponse{}, dto.ErrUserNotFound } @@ -220,7 +221,7 @@ func (s *userService) RefreshToken(ctx context.Context, req authDto.RefreshToken return authDto.TokenResponse{}, err } - accessToken := s.jwtService.GenerateAccessToken(refreshToken.UserID.String(), refreshToken.User.Role) + accessToken := s.jwtService.GenerateAccessToken(refreshToken.TenantID.String(), refreshToken.UserID.String(), refreshToken.User.Role) newRefreshTokenString, expiresAt := s.jwtService.GenerateRefreshToken() err = s.refreshTokenRepository.DeleteByToken(ctx, s.db, req.RefreshToken)