115 lines
2.5 KiB
Go
115 lines
2.5 KiB
Go
package service
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
)
|
|
|
|
// UserTokenInfo carries the identity values read back out of an access
// token: the client the token was issued to and the user it represents.
type UserTokenInfo struct {
	// ClientID is the value of the token's "client_id" claim.
	ClientID string `json:"client_id"`
	// UserID is the value of the token's "user_id" claim.
	UserID string `json:"user_id"`
}
|
|
|
|
type JWTService interface {
|
|
GenerateAccessToken(clientId string, userId string) string
|
|
GenerateRefreshToken() (string, time.Time)
|
|
ValidateToken(token string) (*jwt.Token, error)
|
|
GetUserIDByToken(token string) (*UserTokenInfo, error)
|
|
}
|
|
|
|
type jwtCustomClaim struct {
|
|
ClientID string `json:"client_id"`
|
|
UserID string `json:"user_id"`
|
|
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
// jwtService is the JWTService implementation defined in this file.
type jwtService struct {
	// secretKey is the HMAC key used to sign and verify tokens;
	// issuer is written into the "iss" claim of access tokens.
	secretKey, issuer string

	// accessExpiry and refreshExpiry are the lifetimes applied to newly
	// generated access and refresh tokens respectively.
	accessExpiry, refreshExpiry time.Duration
}
|
|
|
|
func NewJWTService() JWTService {
|
|
return &jwtService{
|
|
secretKey: getSecretKey(),
|
|
issuer: "Template",
|
|
accessExpiry: time.Minute * 15,
|
|
refreshExpiry: time.Hour * 24 * 7,
|
|
}
|
|
}
|
|
|
|
// getSecretKey returns the token signing key.
//
// The key is read from the JWT_SECRET environment variable. When the
// variable is unset or empty the built-in default "Template" is returned,
// and a warning is logged because tokens signed with a publicly known key
// can be forged by anyone.
func getSecretKey() string {
	const fallbackSecret = "Template"

	if secret := os.Getenv("JWT_SECRET"); secret != "" {
		return secret
	}

	// Keep the historical fallback so existing setups still start, but make
	// the insecure configuration visible instead of silent.
	log.Println("JWT_SECRET is not set; falling back to the built-in default signing key, which is insecure outside local development")
	return fallbackSecret
}
|
|
|
|
func (j *jwtService) GenerateAccessToken(clientId string, userId string) string {
|
|
claims := jwtCustomClaim{
|
|
ClientID: clientId,
|
|
UserID: userId,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(j.accessExpiry)),
|
|
Issuer: j.issuer,
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
tx, err := token.SignedString([]byte(j.secretKey))
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
return tx
|
|
}
|
|
|
|
func (j *jwtService) GenerateRefreshToken() (string, time.Time) {
|
|
b := make([]byte, 32)
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return "", time.Time{}
|
|
}
|
|
|
|
refreshToken := base64.StdEncoding.EncodeToString(b)
|
|
expiresAt := time.Now().Add(j.refreshExpiry)
|
|
|
|
return refreshToken, expiresAt
|
|
}
|
|
|
|
func (j *jwtService) parseToken(t_ *jwt.Token) (any, error) {
|
|
if _, ok := t_.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method %v", t_.Header["alg"])
|
|
}
|
|
return []byte(j.secretKey), nil
|
|
}
|
|
|
|
func (j *jwtService) ValidateToken(token string) (*jwt.Token, error) {
|
|
return jwt.Parse(token, j.parseToken)
|
|
}
|
|
|
|
func (j *jwtService) GetUserIDByToken(token string) (*UserTokenInfo, error) {
|
|
tToken, err := j.ValidateToken(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
claims := tToken.Claims.(jwt.MapClaims)
|
|
userId := fmt.Sprintf("%v", claims["user_id"])
|
|
clientId := fmt.Sprintf("%v", claims["client_id"])
|
|
return &UserTokenInfo{
|
|
UserID: userId,
|
|
ClientID: clientId,
|
|
}, nil
|
|
}
|