package service

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"log"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/spf13/viper"
)

// UserTokenInfo holds the identity fields extracted from an access token.
type UserTokenInfo struct {
	ClientID    string `json:"client_id"`
	UserID      string `json:"user_id"`
	RoleID      string `json:"role_id"`
	WarehouseID string `json:"warehouse_id"`
}

// JWTService issues and validates the tokens used for authentication.
type JWTService interface {
	// GenerateAccessToken returns a signed JWT carrying the given identity
	// fields, or the empty string if signing fails.
	GenerateAccessToken(clientId string, userId string, roleId string, warehouseId string) string
	// GenerateRefreshToken returns a random opaque token and its expiry time,
	// or ("", zero time) if random bytes could not be read.
	GenerateRefreshToken() (string, time.Time)
	// ValidateToken parses token and verifies it against the service's key.
	ValidateToken(token string) (*jwt.Token, error)
	// GetUserIDByToken validates token and returns the identity fields stored
	// in its claims.
	GetUserIDByToken(token string) (*UserTokenInfo, error)
}

// jwtCustomClaim is the claim set written into access tokens: the identity
// fields plus the registered claims (expiry, issuer, issued-at).
type jwtCustomClaim struct {
	ClientID    string `json:"client_id"`
	UserID      string `json:"user_id"`
	RoleID      string `json:"role_id"`
	WarehouseID string `json:"warehouse_id"`
	jwt.RegisteredClaims
}

// jwtService is the HMAC-based implementation of JWTService.
type jwtService struct {
	secretKey     string
	issuer        string
	accessExpiry  time.Duration
	refreshExpiry time.Duration
}

// NewJWTService returns a JWTService whose access tokens live for 8 hours and
// whose refresh tokens live for 7 days.
func NewJWTService() JWTService {
	return &jwtService{
		secretKey:     getSecretKey(),
		issuer:        "WMS-Wareify",
		accessExpiry:  time.Hour * 8,
		refreshExpiry: time.Hour * 24 * 7,
	}
}

// getSecretKey returns the signing key configured under JWT_SECRET. When the
// setting is empty it falls back to a built-in default and logs a warning.
//
// NOTE(review): the fallback key is embedded in source, so any token signed
// with it can be forged by anyone who can read this file. The fallback is kept
// for backward compatibility; deployments should always set JWT_SECRET.
func getSecretKey() string {
	if secretKey := viper.GetString("JWT_SECRET"); secretKey != "" {
		return secretKey
	}
	log.Println("JWT_SECRET is not set; falling back to the built-in default signing key, which is not safe for production")
	return "WMS-WareifySecretKey"
}

// GenerateAccessToken builds an HS256-signed JWT containing the given identity
// fields together with issuer, issued-at and expiry claims. If signing fails
// the error is logged and the empty string is returned.
func (j *jwtService) GenerateAccessToken(clientId string, userId string, roleId string, warehouseId string) string {
	// Take a single timestamp so that exp - iat is exactly accessExpiry.
	now := time.Now()

	claims := jwtCustomClaim{
		ClientID:    clientId,
		UserID:      userId,
		RoleID:      roleId,
		WarehouseID: warehouseId,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(j.accessExpiry)),
			Issuer:    j.issuer,
			IssuedAt:  jwt.NewNumericDate(now),
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	signed, err := token.SignedString([]byte(j.secretKey))
	if err != nil {
		log.Printf("signing access token: %v", err)
		return ""
	}
	return signed
}

// GenerateRefreshToken returns 32 bytes from crypto/rand encoded with standard
// base64, along with the time at which the token expires. If random bytes
// cannot be read, the error is logged and ("", zero time) is returned.
func (j *jwtService) GenerateRefreshToken() (string, time.Time) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		log.Println(err)
		return "", time.Time{}
	}

	refreshToken := base64.StdEncoding.EncodeToString(b)
	expiresAt := time.Now().Add(j.refreshExpiry)
	return refreshToken, expiresAt
}

// parseToken is the key function handed to jwt.Parse. It rejects any token
// whose signing method is not HMAC and otherwise supplies the secret key.
func (j *jwtService) parseToken(t *jwt.Token) (any, error) {
	if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
		return nil, fmt.Errorf("unexpected signing method %v", t.Header["alg"])
	}
	return []byte(j.secretKey), nil
}

// ValidateToken parses token with jwt.Parse, using parseToken to select the
// verification key, and returns the result unchanged.
func (j *jwtService) ValidateToken(token string) (*jwt.Token, error) {
	return jwt.Parse(token, j.parseToken)
}

// GetUserIDByToken validates token and extracts the user_id, client_id,
// role_id and warehouse_id claims. A claim that is absent or null yields an
// empty string in the result. It returns an error if validation fails or if
// the parsed claims are not jwt.MapClaims.
func (j *jwtService) GetUserIDByToken(token string) (*UserTokenInfo, error) {
	parsed, err := j.ValidateToken(token)
	if err != nil {
		return nil, err
	}

	// Checked assertion: return an error instead of panicking if the claims
	// are ever of a different type.
	claims, ok := parsed.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("unexpected claims type %T", parsed.Claims)
	}

	// claim renders one claim as a string. Formatting a nil value with %v
	// would produce the literal "<nil>", which callers could mistake for a
	// real identifier, so missing or null claims map to "".
	claim := func(key string) string {
		v := claims[key]
		if v == nil {
			return ""
		}
		return fmt.Sprintf("%v", v)
	}

	return &UserTokenInfo{
		UserID:      claim("user_id"),
		ClientID:    claim("client_id"),
		RoleID:      claim("role_id"),
		WarehouseID: claim("warehouse_id"),
	}, nil
}