fix:create auth token can add expire time

pull/2352/head
icey-yu 1 year ago
parent e32d30f287
commit b2824a92c3

@ -21,6 +21,7 @@ import (
"github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/stringutil" "github.com/openimsdk/tools/utils/stringutil"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"time"
) )
type tokenCache struct { type tokenCache struct {
@ -33,10 +34,22 @@ func NewTokenCacheModel(rdb redis.UniversalClient) cache.TokenModel {
} }
} }
func (c *tokenCache) AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error { func (c *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
return errs.Wrap(c.rdb.HSet(ctx, cachekey.GetTokenKey(userID, platformID), token, flag).Err()) return errs.Wrap(c.rdb.HSet(ctx, cachekey.GetTokenKey(userID, platformID), token, flag).Err())
} }
// SetTokenFlagEx set token and flag with expire time
func (c *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int, expire time.Duration) error {
key := cachekey.GetTokenKey(userID, platformID)
if err := c.rdb.HSet(ctx, key, token, flag).Err(); err != nil {
return errs.Wrap(err)
}
if err := c.rdb.Expire(ctx, key, expire).Err(); err != nil {
return errs.Wrap(err)
}
return nil
}
func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
m, err := c.rdb.HGetAll(ctx, cachekey.GetTokenKey(userID, platformID)).Result() m, err := c.rdb.HGetAll(ctx, cachekey.GetTokenKey(userID, platformID)).Result()
if err != nil { if err != nil {

@ -2,10 +2,13 @@ package cache
import ( import (
"context" "context"
"time"
) )
type TokenModel interface { type TokenModel interface {
AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
// SetTokenFlagEx set token and flag with expire time
SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int, expire time.Duration) error
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error

@ -16,6 +16,7 @@ package controller
import ( import (
"context" "context"
"time"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
@ -55,6 +56,7 @@ func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, p
// Create Token. // Create Token.
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
isCreate := true // flag is create or update
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID) tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
if err != nil { if err != nil {
return "", err return "", err
@ -65,6 +67,9 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
if err != nil || v != constant.NormalToken { if err != nil || v != constant.NormalToken {
deleteTokenKey = append(deleteTokenKey, k) deleteTokenKey = append(deleteTokenKey, k)
} }
if v == constant.NormalToken {
isCreate = false
}
} }
if len(deleteTokenKey) != 0 { if len(deleteTokenKey) != 0 {
err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
@ -79,5 +84,21 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
if err != nil { if err != nil {
return "", errs.WrapMsg(err, "token.SignedString") return "", errs.WrapMsg(err, "token.SignedString")
} }
return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
if isCreate {
// should create,should specify expiration time
if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken, a.getExpireTime()); err != nil {
return "", err
}
} else {
// should update
if err = a.cache.SetTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil {
return "", err
}
}
return tokenString, nil
}
func (a *authDatabase) getExpireTime() time.Duration {
return time.Hour * 24 * time.Duration(a.accessExpire)
} }

Loading…
Cancel
Save