// wms-be/modules/auth/service/jwt_service.go
// (source-viewer metadata: 121 lines, 2.7 KiB, Go)

package service
import (
"crypto/rand"
"encoding/base64"
"fmt"
"log"
"os"
"time"
"github.com/golang-jwt/jwt/v4"
)
// UserTokenInfo holds the identity claims that GetUserIDByToken extracts from
// a validated access token. Its JSON tags use the same names as the
// client_id, user_id and role_id claims written into access tokens.
type UserTokenInfo struct {
ClientID string `json:"client_id"`
UserID string `json:"user_id"`
RoleID string `json:"role_id"`
}
// JWTService issues and verifies the tokens used for authentication.
type JWTService interface {
	// GenerateAccessToken returns a signed access token that carries the given
	// client, user and role identifiers. Signing failures are not reported
	// through this method's return values.
	GenerateAccessToken(clientID, userID, roleID string) string

	// GenerateRefreshToken returns a random, opaque refresh token together
	// with the time at which it should expire.
	GenerateRefreshToken() (token string, expiresAt time.Time)

	// ValidateToken parses token, verifies its signature and lets the jwt
	// library validate its standard claims.
	ValidateToken(token string) (*jwt.Token, error)

	// GetUserIDByToken validates token and extracts its identity claims.
	GetUserIDByToken(token string) (*UserTokenInfo, error)
}
// jwtCustomClaim is the payload that GenerateAccessToken signs into access
// tokens: the application identity fields plus the standard registered claims
// (GenerateAccessToken sets exp, iss and iat).
type jwtCustomClaim struct {
ClientID string `json:"client_id"`
UserID string `json:"user_id"`
RoleID string `json:"role_id"`
// Embedded, so encoding/json promotes its fields (exp, iss, iat, ...) to the
// top level of the payload alongside the custom claims above.
jwt.RegisteredClaims
}
// jwtService is the HMAC-signing implementation of JWTService returned by
// NewJWTService.
type jwtService struct {
	secretKey, issuer           string        // HMAC signing secret; value placed in the "iss" claim
	accessExpiry, refreshExpiry time.Duration // lifetimes of access and refresh tokens
}
// NewJWTService builds a JWTService with the following settings:
//   - signing key: from getSecretKey (JWT_SECRET, or a built-in fallback);
//   - issuer: "WMS-Wareify";
//   - access tokens: valid for 8 hours;
//   - refresh tokens: valid for 7 days.
func NewJWTService() JWTService {
	const (
		accessLifetime  = 8 * time.Hour
		refreshLifetime = 7 * 24 * time.Hour
	)

	svc := new(jwtService)
	svc.secretKey = getSecretKey()
	svc.issuer = "WMS-Wareify"
	svc.accessExpiry = accessLifetime
	svc.refreshExpiry = refreshLifetime
	return svc
}
// getSecretKey returns the HMAC signing key from the JWT_SECRET environment
// variable.
//
// When JWT_SECRET is unset or empty, it falls back to a hard-coded key and
// logs a warning. That key is visible in source, so anyone who knows it can
// forge tokens this service will accept. JWT_SECRET must be set in any real
// deployment; the warning makes a misconfiguration visible instead of silent.
func getSecretKey() string {
	if secretKey := os.Getenv("JWT_SECRET"); secretKey != "" {
		return secretKey
	}
	log.Println("WARNING: JWT_SECRET is not set; using the built-in default signing key, which is insecure outside development")
	return "WMS-WareifySecretKey"
}
// GenerateAccessToken returns an HS256-signed JWT with the following claims:
//   - client_id, user_id, role_id: the given clientId, userId and roleId;
//   - iss: j.issuer;
//   - iat: now;
//   - exp: j.accessExpiry from now.
//
// The method has no error return, so a signing failure is logged and the
// empty string is returned. Callers should treat "" as failure.
func (j *jwtService) GenerateAccessToken(clientId string, userId string, roleId string) string {
	// Take a single timestamp so iat and exp are derived from the same instant.
	now := time.Now()
	claims := jwtCustomClaim{
		ClientID: clientId,
		UserID:   userId,
		RoleID:   roleId,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(j.accessExpiry)),
			Issuer:    j.issuer,
			IssuedAt:  jwt.NewNumericDate(now),
		},
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(j.secretKey))
	if err != nil {
		log.Printf("jwt: signing access token: %v", err)
		return ""
	}
	return signed
}
// GenerateRefreshToken returns a refresh token made of 32 bytes from
// crypto/rand, encoded with standard padded base64. It also returns the
// token's expiry time, j.refreshExpiry from now.
//
// If reading random bytes fails, the error is logged and ("", time.Time{}) is
// returned.
func (j *jwtService) GenerateRefreshToken() (string, time.Time) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		log.Println(err)
		return "", time.Time{}
	}
	return base64.StdEncoding.EncodeToString(raw[:]), time.Now().Add(j.refreshExpiry)
}
// parseToken is the jwt.Keyfunc used by ValidateToken. It returns the shared
// secret as the verification key, but only for HS256 tokens, the sole
// algorithm GenerateAccessToken signs with.
//
// Pinning the exact algorithm, instead of accepting any HMAC variant, follows
// RFC 8725 §3.1 and rejects algorithm-substitution attempts. Tokens using any
// other "alg" get an error.
func (j *jwtService) parseToken(tok *jwt.Token) (any, error) {
	if tok.Method == nil || tok.Method.Alg() != jwt.SigningMethodHS256.Alg() {
		return nil, fmt.Errorf("unexpected signing method %v", tok.Header["alg"])
	}
	return []byte(j.secretKey), nil
}
// ValidateToken parses token with jwt.Parse, using parseToken to check the
// signing algorithm and supply the verification key. The jwt library also
// validates the standard claims (such as exp). Any error from jwt.Parse is
// returned unchanged.
func (j *jwtService) ValidateToken(token string) (*jwt.Token, error) {
	keyFunc := j.parseToken
	parsed, err := jwt.Parse(token, keyFunc)
	return parsed, err
}
// GetUserIDByToken validates token and returns the client_id, user_id and
// role_id claims it carries.
//
// Claims are converted as follows:
//   - a string claim is returned as-is;
//   - an absent or JSON-null claim yields "", not the literal "<nil>" that
//     fmt's %v would produce;
//   - any other value is formatted with %v.
//
// An error is returned if validation fails or the parsed claims are not
// jwt.MapClaims.
func (j *jwtService) GetUserIDByToken(token string) (*UserTokenInfo, error) {
	parsed, err := j.ValidateToken(token)
	if err != nil {
		return nil, err
	}

	// jwt.Parse produces MapClaims, but use the checked form so an unexpected
	// claims type surfaces as an error rather than a panic.
	claims, ok := parsed.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("unexpected claims type %T", parsed.Claims)
	}

	return &UserTokenInfo{
		UserID:   claimString(claims, "user_id"),
		ClientID: claimString(claims, "client_id"),
		RoleID:   claimString(claims, "role_id"),
	}, nil
}

// claimString returns claims[key] as a string:
//   - strings are returned unchanged;
//   - a missing or null claim yields "";
//   - any other value is rendered with %v.
func claimString(claims map[string]any, key string) string {
	switch v := claims[key].(type) {
	case nil:
		return ""
	case string:
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}