From ad42eaed11070d6cb81abcf05510dc9c74193754 Mon Sep 17 00:00:00 2001 From: Gordon <1432970085@qq.com> Date: Wed, 14 Jun 2023 09:58:10 +0800 Subject: [PATCH 1/6] feat: kick user when same terminal login --- internal/msggateway/client.go | 7 ++- internal/msggateway/hub_server.go | 9 ++- internal/msggateway/n_ws_server.go | 71 ++++++++++++++++------ pkg/common/constant/constant.go | 1 + pkg/common/constant/platform_id_to_name.go | 16 ++++- 5 files changed, 81 insertions(+), 23 deletions(-) diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index e7d794324..6bce68c85 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -224,8 +224,11 @@ func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error return c.writeBinaryMsg(resp) } -func (c *Client) KickOnlineMessage(ctx context.Context) error { - return nil +func (c *Client) KickOnlineMessage() error { + resp := Resp{ + ReqIdentifier: WSKickOnlineMsg, + } + return c.writeBinaryMsg(resp) } func (c *Client) writeBinaryMsg(resp Resp) error { diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index e93497de4..786d3eeff 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -2,6 +2,7 @@ package msggateway import ( "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" @@ -17,7 +18,13 @@ import ( ) func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis() + if err != nil { + return err + } + msgModel := cache.NewMsgCacheModel(rdb) s.LongConnServer.SetDiscoveryRegistry(client) + s.LongConnServer.SetCacheHandler(msgModel) msggateway.RegisterMsgGatewayServer(server, s) return nil } @@ -131,7 +138,7 @@ func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOf for _, v := range req.KickUserIDList { if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok { for _, client := range clients { - err := client.KickOnlineMessage(ctx) + err := client.KickOnlineMessage() if err != nil { return nil, err } diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 749287e7f..7318b2cf1 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -2,6 +2,9 @@ package msggateway import ( "errors" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "net/http" "sync" "sync/atomic" @@ -22,7 +25,7 @@ type LongConnServer interface { GetUserAllCons(userID string) ([]*Client, bool) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) Validate(s interface{}) error - //SetMessageHandler(msgRpcClient *rpcclient.MsgClient) + SetCacheHandler(cache cache.MsgModel) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) UnRegister(c *Client) Compressor @@ -41,6 +44,7 @@ type WsServer struct { wsMaxConnNum int64 registerChan chan *Client unregisterChan chan *Client + kickHandlerChan chan *kickHandler clients *UserMap clientPool sync.Pool onlineUserNum int64 @@ -48,14 +52,23 @@ type WsServer struct { handshakeTimeout time.Duration hubServer *Server validate *validator.Validate + cache cache.MsgModel Compressor Encoder MessageHandler } +type kickHandler struct { + clientOK bool + oldClients []*Client + newClient *Client +} func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) { ws.MessageHandler = NewGrpcHandler(ws.validate, client) } +func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) { + ws.cache = cache +} func (ws *WsServer) UnRegister(c *Client) { ws.unregisterChan <- c @@ -92,12 +105,13 @@ func NewWsServer(opts ...Option) (*WsServer, error) { return new(Client) }, }, - registerChan: make(chan *Client, 1000), - unregisterChan: make(chan *Client, 1000), - validate: v, - clients: newUserMap(), - Compressor: NewGzipCompressor(), - Encoder: NewGobEncoder(), + registerChan: make(chan *Client, 1000), + unregisterChan: make(chan *Client, 1000), + kickHandlerChan: make(chan *kickHandler, 1000), + validate: v, + clients: newUserMap(), + Compressor: NewGzipCompressor(), + Encoder: NewGobEncoder(), }, nil } func (ws *WsServer) Run() error { @@ -109,6 +123,8 @@ func (ws *WsServer) Run() error { ws.registerClient(client) case client = <-ws.unregisterChan: ws.unregisterClient(client) + case onlineInfo := <-ws.kickHandlerChan: + ws.multiTerminalLoginChecker(onlineInfo) } } }() @@ -119,26 +135,29 @@ func (ws *WsServer) Run() error { func (ws *WsServer) registerClient(client *Client) { var ( - userOK bool - clientOK bool - cli []*Client + userOK bool + clientOK bool + oldClients []*Client ) - cli, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) + ws.clients.Set(client.UserID, client) + oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) if !userOK { log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID) - ws.clients.Set(client.UserID, client) atomic.AddInt64(&ws.onlineUserNum, 1) atomic.AddInt64(&ws.onlineUserConnNum, 1) } else { + i := &kickHandler{ + clientOK: clientOK, + oldClients: oldClients, + newClient: client, + } + ws.kickHandlerChan <- i log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID) if clientOK { //已经有同平台的连接存在 - ws.clients.Set(client.UserID, client) - ws.multiTerminalLoginChecker(cli) - log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(cli)) + log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients)) atomic.AddInt64(&ws.onlineUserConnNum, 1) } else { - ws.clients.Set(client.UserID, client) atomic.AddInt64(&ws.onlineUserConnNum, 1) } } @@ -156,7 +175,24 @@ func getRemoteAdders(client []*Client) string { return ret } -func (ws *WsServer) multiTerminalLoginChecker(client []*Client) { +func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { + switch config.Config.MultiLoginPolicy { + case constant.DefalutNotKick: + case constant.PCAndOther: + if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC { + return + } + fallthrough + case constant.AllLoginButSameTermKick: + if info.clientOK { + for _, c := range info.oldClients { + err := c.KickOnlineMessage() + if err != nil { + log.ZWarn() + } + } + } + } } func (ws *WsServer) unregisterClient(client *Client) { @@ -198,7 +234,6 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(context, errs.ErrConnArgsErr) return } - // log.ZDebug(context2.Background(), "conn", "platformID", platformID) err := tokenverify.WsVerifyToken(token, userID, platformID) if err != nil { httpError(context, err) diff --git a/pkg/common/constant/constant.go b/pkg/common/constant/constant.go index 697694e87..8be23c10f 100644 --- a/pkg/common/constant/constant.go +++ b/pkg/common/constant/constant.go @@ -118,6 +118,7 @@ const ( ExpiredToken = 3 //MultiTerminalLogin + DefalutNotKick = 0 //Full-end login, but the same end is mutually exclusive AllLoginButSameTermKick = 1 //Only one of the endpoints can log in diff --git a/pkg/common/constant/platform_id_to_name.go b/pkg/common/constant/platform_id_to_name.go index 3d5ab059b..e8bb129eb 100644 --- a/pkg/common/constant/platform_id_to_name.go +++ b/pkg/common/constant/platform_id_to_name.go @@ -57,7 +57,7 @@ var PlatformName2ID = map[string]int{ IPadPlatformStr: IPadPlatformID, AdminPlatformStr: AdminPlatformID, } -var Platform2class = map[string]string{ +var PlatformName2class = map[string]string{ IOSPlatformStr: TerminalMobile, AndroidPlatformStr: TerminalMobile, MiniWebPlatformStr: WebPlatformStr, @@ -66,6 +66,15 @@ var Platform2class = map[string]string{ OSXPlatformStr: TerminalPC, LinuxPlatformStr: TerminalPC, } +var PlatformID2class = map[int]string{ + IOSPlatformID: TerminalMobile, + AndroidPlatformID: TerminalMobile, + MiniWebPlatformID: WebPlatformStr, + WebPlatformID: WebPlatformStr, + WindowsPlatformID: TerminalPC, + OSXPlatformID: TerminalPC, + LinuxPlatformID: TerminalPC, +} func PlatformIDToName(num int) string { return PlatformID2Name[num] @@ -74,5 +83,8 @@ func PlatformNameToID(name string) int { return PlatformName2ID[name] } func PlatformNameToClass(name string) string { - return Platform2class[name] + return PlatformName2class[name] +} +func PlatformIDToClass(num int) string { + return PlatformID2class[num] } From 82e85c708375e18f2e2a48bdcbe6effc14645c6a Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Wed, 14 Jun 2023 10:15:58 +0800 Subject: [PATCH 2/6] WsVerifyToken --- internal/msggateway/n_ws_server.go | 50 +++++++++++++++++++++--------- pkg/common/mw/gin.go | 3 ++ 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 7318b2cf1..359934532 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -1,6 +1,7 @@ package msggateway import ( + "context" "errors" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" @@ -188,13 +189,13 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { for _, c := range info.oldClients { err := c.KickOnlineMessage() if err != nil { - log.ZWarn() + log.ZError(c.ctx, "KickOnlineMessage", err) } } } } - } + func (ws *WsServer) unregisterClient(client *Client) { defer ws.clientPool.Put(client) isDeleteUser := ws.clients.delete(client.UserID, client.ctx.GetRemoteAddr()) @@ -206,9 +207,9 @@ func (ws *WsServer) unregisterClient(client *Client) { } func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { - context := newContext(w, r) + connContext := newContext(w, r) if ws.onlineUserConnNum >= ws.wsMaxConnNum { - httpError(context, errs.ErrConnOverMaxNumLimit) + httpError(connContext, errs.ErrConnOverMaxNumLimit) return } var ( @@ -219,46 +220,65 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { compression bool ) - token, exists = context.Query(Token) + token, exists = connContext.Query(Token) if !exists { - httpError(context, errs.ErrConnArgsErr) + httpError(connContext, errs.ErrConnArgsErr) return } - userID, exists = context.Query(WsUserID) + userID, exists = connContext.Query(WsUserID) if !exists { - httpError(context, errs.ErrConnArgsErr) + httpError(connContext, errs.ErrConnArgsErr) return } - platformID, exists = context.Query(PlatformID) + platformID, exists = connContext.Query(PlatformID) if !exists || utils.StringToInt(platformID) == 0 { - httpError(context, errs.ErrConnArgsErr) + httpError(connContext, errs.ErrConnArgsErr) return } err := tokenverify.WsVerifyToken(token, userID, platformID) if err != nil { - httpError(context, err) + httpError(connContext, err) + return + } + m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID) + if err != nil { + httpError(connContext, err) + return + } + if v, ok := m[token]; ok { + switch v { + case constant.NormalToken: + case constant.KickedToken: + httpError(connContext, errs.ErrTokenKicked.Wrap()) + return + default: + httpError(connContext, errs.ErrTokenUnknown.Wrap()) + return + } + } else { + httpError(connContext, errs.ErrTokenNotExist.Wrap()) return } wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout) err = wsLongConn.GenerateLongConn(w, r) if err != nil { - httpError(context, err) + httpError(connContext, err) return } - compressProtoc, exists := context.Query(Compression) + compressProtoc, exists := connContext.Query(Compression) if exists { if compressProtoc == GzipCompressionProtocol { compression = true } } - compressProtoc, exists = context.GetHeader(Compression) + compressProtoc, exists = connContext.GetHeader(Compression) if exists { if compressProtoc == GzipCompressionProtocol { compression = true } } client := ws.clientPool.Get().(*Client) - client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws) + client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws) ws.registerChan <- client go client.readMessage() } diff --git a/pkg/common/mw/gin.go b/pkg/common/mw/gin.go index 4b12a8244..4343048fc 100644 --- a/pkg/common/mw/gin.go +++ b/pkg/common/mw/gin.go @@ -155,6 +155,9 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { c.Abort() return } + } else { + apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) + return } c.Set(constant.OpUserPlatform, claims.Platform) c.Set(constant.OpUserID, claims.UID) From 0124a5c05d027bc97524a501fbc5ed85fe2b1229 Mon Sep 17 00:00:00 2001 From: Gordon <1432970085@qq.com> Date: Wed, 14 Jun 2023 10:47:18 +0800 Subject: [PATCH 3/6] refactor: token platformID update --- internal/msggateway/n_ws_server.go | 3 ++- internal/rpc/auth/auth.go | 8 ++++---- pkg/common/db/cache/msg.go | 18 +++++++++--------- pkg/common/db/controller/auth.go | 18 +++++++++--------- pkg/common/mw/gin.go | 6 +++--- pkg/common/tokenverify/jwt_token.go | 23 +++++++++-------------- 6 files changed, 36 insertions(+), 40 deletions(-) diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 7318b2cf1..2563ab3ce 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -188,9 +188,10 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { for _, c := range info.oldClients { err := c.KickOnlineMessage() if err != nil { - log.ZWarn() + log.ZWarn(c.ctx, "kick online message error", err) } } + ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID) } } diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 8298c9a87..fe5dfd21e 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -42,7 +42,7 @@ func (s *authServer) UserToken(ctx context.Context, req *pbAuth.UserTokenReq) (* if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil { return nil, err } - token, err := s.authDatabase.CreateToken(ctx, req.UserID, constant.PlatformIDToName(int(req.PlatformID))) + token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID)) if err != nil { return nil, err } @@ -56,7 +56,7 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim if err != nil { return nil, utils.Wrap(err, "") } - m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UID, claims.Platform) + m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) if err != nil { return nil, err } @@ -82,8 +82,8 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbAuth.ParseTokenReq) if err != nil { return nil, err } - resp.UserID = claims.UID - resp.Platform = claims.Platform + resp.UserID = claims.UserID + resp.Platform = constant.PlatformIDToName(claims.PlatformID) resp.ExpireTimeSeconds = claims.ExpiresAt.Unix() return resp, nil } diff --git a/pkg/common/db/cache/msg.go b/pkg/common/db/cache/msg.go index d3bb47b91..30d5f1ffc 100644 --- a/pkg/common/db/cache/msg.go +++ b/pkg/common/db/cache/msg.go @@ -88,9 +88,9 @@ type MsgModel interface { SeqCache thirdCache AddTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error - GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) - SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error - DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error + GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) + SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error + DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error @@ -260,8 +260,8 @@ func (c *msgCache) AddTokenFlag(ctx context.Context, userID string, platformID i return errs.Wrap(c.rdb.HSet(ctx, key, token, flag).Err()) } -func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID string) (map[string]int, error) { - key := uidPidToken + userID + ":" + platformID +func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { + key := uidPidToken + userID + ":" + constant.PlatformIDToName(platformID) m, err := c.rdb.HGetAll(ctx, key).Result() if err != nil { return nil, errs.Wrap(err) @@ -273,8 +273,8 @@ func (c *msgCache) GetTokensWithoutError(ctx context.Context, userID, platformID return mm, nil } -func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform string, m map[string]int) error { - key := uidPidToken + userID + ":" + platform +func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platform int, m map[string]int) error { + key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform) mm := make(map[string]interface{}) for k, v := range m { mm[k] = v @@ -282,8 +282,8 @@ func (c *msgCache) SetTokenMapByUidPid(ctx context.Context, userID string, platf return errs.Wrap(c.rdb.HSet(ctx, key, mm).Err()) } -func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform string, fields []string) error { - key := uidPidToken + userID + ":" + platform +func (c *msgCache) DeleteTokenByUidPid(ctx context.Context, userID string, platform int, fields []string) error { + key := uidPidToken + userID + ":" + constant.PlatformIDToName(platform) return errs.Wrap(c.rdb.HDel(ctx, key, fields...).Err()) } diff --git a/pkg/common/db/controller/auth.go b/pkg/common/db/controller/auth.go index 148ef6c96..6d6add902 100644 --- a/pkg/common/db/controller/auth.go +++ b/pkg/common/db/controller/auth.go @@ -12,9 +12,9 @@ import ( type AuthDatabase interface { //结果为空 不返回错误 - GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) + GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) //创建token - CreateToken(ctx context.Context, userID string, platform string) (string, error) + CreateToken(ctx context.Context, userID string, platformID int) (string, error) } type authDatabase struct { @@ -29,13 +29,13 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int } // 结果为空 不返回错误 -func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID, platform string) (map[string]int, error) { - return a.cache.GetTokensWithoutError(ctx, userID, platform) +func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { + return a.cache.GetTokensWithoutError(ctx, userID, platformID) } // 创建token -func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform string) (string, error) { - tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platform) +func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { + tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID) if err != nil { return "", err } @@ -47,16 +47,16 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platform } } if len(deleteTokenKey) != 0 { - err := a.cache.DeleteTokenByUidPid(ctx, userID, platform, deleteTokenKey) + err := a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) if err != nil { return "", err } } - claims := tokenverify.BuildClaims(userID, platform, a.accessExpire) + claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(a.accessSecret)) if err != nil { return "", utils.Wrap(err, "") } - return tokenString, a.cache.AddTokenFlag(ctx, userID, constant.PlatformNameToID(platform), tokenString, constant.NormalToken) + return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken) } diff --git a/pkg/common/mw/gin.go b/pkg/common/mw/gin.go index 4b12a8244..6c87fbb5b 100644 --- a/pkg/common/mw/gin.go +++ b/pkg/common/mw/gin.go @@ -128,7 +128,7 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { c.Abort() return } - m, err := dataBase.GetTokensWithoutError(c, claims.UID, claims.Platform) + m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID) if err != nil { log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap()) apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) @@ -156,8 +156,8 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { return } } - c.Set(constant.OpUserPlatform, claims.Platform) - c.Set(constant.OpUserID, claims.UID) + c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID)) + c.Set(constant.OpUserID, claims.UserID) c.Next() } } diff --git a/pkg/common/tokenverify/jwt_token.go b/pkg/common/tokenverify/jwt_token.go index 1f2f0797c..65a31545e 100644 --- a/pkg/common/tokenverify/jwt_token.go +++ b/pkg/common/tokenverify/jwt_token.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" @@ -14,17 +13,17 @@ import ( ) type Claims struct { - UID string - Platform string //login platform + UserID string + PlatformID int //login platform jwt.RegisteredClaims } -func BuildClaims(uid, platform string, ttl int64) Claims { +func BuildClaims(uid string, platformID int, ttl int64) Claims { now := time.Now() before := now.Add(-time.Minute * 5) return Claims{ - UID: uid, - Platform: platform, + UserID: uid, + PlatformID: platformID, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time IssuedAt: jwt.NewNumericDate(now), //Issuing time @@ -95,19 +94,15 @@ func WsVerifyToken(token, userID, platformID string) error { if err != nil { return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID)) } - platform := constant.PlatformIDToName(platformIDInt) - if platform == "" { - return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not exist", platformID)) - } claim, err := GetClaimFromToken(token) if err != nil { return err } - if claim.UID != userID { - return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UID, userID)) + if claim.UserID != userID { + return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID)) } - if claim.Platform != platform { - return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %s != %s", claim.Platform, platform)) + if claim.PlatformID != platformIDInt { + return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformIDInt)) } return nil } From aa8c250c613991544dcb169deef3aac6eb0d6ea2 Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Wed, 14 Jun 2023 11:00:44 +0800 Subject: [PATCH 4/6] WsVerifyToken --- internal/msggateway/n_ws_server.go | 21 +++++++++++++-------- pkg/common/tokenverify/jwt_token.go | 11 +++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index d7175ed36..c665c5556 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -7,6 +7,7 @@ import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "net/http" + "strconv" "sync" "sync/atomic" "time" @@ -214,11 +215,11 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { return } var ( - token string - userID string - platformID string - exists bool - compression bool + token string + userID string + platformIDStr string + exists bool + compression bool ) token, exists = connContext.Query(Token) @@ -231,13 +232,17 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { httpError(connContext, errs.ErrConnArgsErr) return } - platformID, exists = connContext.Query(PlatformID) - if !exists || utils.StringToInt(platformID) == 0 { + platformIDStr, exists = connContext.Query(PlatformID) + if !exists { httpError(connContext, errs.ErrConnArgsErr) return } - err := tokenverify.WsVerifyToken(token, userID, platformID) + platformID, err := strconv.Atoi(platformIDStr) if err != nil { + httpError(connContext, errs.ErrConnArgsErr) + return + } + if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil { httpError(connContext, err) return } diff --git a/pkg/common/tokenverify/jwt_token.go b/pkg/common/tokenverify/jwt_token.go index 65a31545e..bc7ca62e6 100644 --- a/pkg/common/tokenverify/jwt_token.go +++ b/pkg/common/tokenverify/jwt_token.go @@ -8,7 +8,6 @@ import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "github.com/golang-jwt/jwt/v4" - "strconv" "time" ) @@ -89,11 +88,7 @@ func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) { func IsManagerUserID(opUserID string) bool { return utils.IsContain(opUserID, config.Config.Manager.AppManagerUid) } -func WsVerifyToken(token, userID, platformID string) error { - platformIDInt, err := strconv.Atoi(platformID) - if err != nil { - return errs.ErrArgs.Wrap(fmt.Sprintf("platformID %s is not int", platformID)) - } +func WsVerifyToken(token, userID string, platformID int) error { claim, err := GetClaimFromToken(token) if err != nil { return err @@ -101,8 +96,8 @@ func WsVerifyToken(token, userID, platformID string) error { if claim.UserID != userID { return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token uid %s != userID %s", claim.UserID, userID)) } - if claim.PlatformID != platformIDInt { - return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformIDInt)) + if claim.PlatformID != platformID { + return errs.ErrTokenInvalid.Wrap(fmt.Sprintf("token platform %d != %d", claim.PlatformID, platformID)) } return nil } From 638d2d4282a742c52c0b066080b201c348974d1a Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Wed, 14 Jun 2023 11:26:20 +0800 Subject: [PATCH 5/6] WsVerifyToken --- internal/msggateway/n_ws_server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index c665c5556..b093395e8 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -209,6 +209,7 @@ func (ws *WsServer) unregisterClient(client *Client) { } func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { + defer log.ZInfo(context.Background(), "wsHandler", "remote addr", "url", r.URL.String()) connContext := newContext(w, r) if ws.onlineUserConnNum >= ws.wsMaxConnNum { httpError(connContext, errs.ErrConnOverMaxNumLimit) From d3a92132a1416ef18cb8c0f2ba6df1f9aea65b9b Mon Sep 17 00:00:00 2001 From: Gordon <1432970085@qq.com> Date: Wed, 14 Jun 2023 11:53:12 +0800 Subject: [PATCH 6/6] refactor: kick user --- internal/msggateway/context.go | 3 +++ internal/msggateway/n_ws_server.go | 27 +++++++++++++++++++++++++-- internal/msggateway/user_map.go | 24 ++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/internal/msggateway/context.go b/internal/msggateway/context.go index 5baa11fdd..cd395e7e0 100644 --- a/internal/msggateway/context.go +++ b/internal/msggateway/context.go @@ -91,6 +91,9 @@ func (c *UserConnContext) GetPlatformID() string { func (c *UserConnContext) GetOperationID() string { return c.Req.URL.Query().Get(OperationID) } +func (c *UserConnContext) GetToken() string { + return c.Req.URL.Query().Get(Token) +} func (c *UserConnContext) GetBackground() bool { b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus)) if err != nil { diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index b093395e8..de2788cb0 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -13,6 +13,7 @@ import ( "time" "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" + redis "github.com/go-redis/redis/v8" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify" @@ -187,13 +188,35 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { fallthrough case constant.AllLoginButSameTermKick: if info.clientOK { + ws.clients.deleteClients(info.newClient.UserID, info.oldClients) for _, c := range info.oldClients { err := c.KickOnlineMessage() if err != nil { - log.ZError(c.ctx, "KickOnlineMessage", err) + log.ZWarn(c.ctx, "KickOnlineMessage", err) } } - ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID) + m, err := ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID) + if err != nil && err != redis.Nil { + log.ZWarn(info.newClient.ctx, "get token from redis err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID) + return + } + if m == nil { + log.ZWarn(info.newClient.ctx, "m is nil", errors.New("m is nil"), "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID) + return + } + log.ZDebug(info.newClient.ctx, "get token from redis", "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID, "tokenMap", m) + + for k, _ := range m { + if k != info.newClient.ctx.GetToken() { + m[k] = constant.KickedToken + } + } + log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID) + err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m) + if err != nil { + log.ZWarn(info.newClient.ctx, "SetTokenMapByUidPid err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID) + return + } } } diff --git a/internal/msggateway/user_map.go b/internal/msggateway/user_map.go index e482f7cea..63881bc1a 100644 --- a/internal/msggateway/user_map.go +++ b/internal/msggateway/user_map.go @@ -3,6 +3,7 @@ package msggateway import ( "context" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "sync" ) @@ -71,6 +72,29 @@ func (u *UserMap) delete(key string, connRemoteAddr string) (isDeleteUser bool) } return existed } +func (u *UserMap) deleteClients(key string, clients []*Client) (isDeleteUser bool) { + m := utils.SliceToMapAny(clients, func(c *Client) (string, struct{}) { + return c.ctx.GetRemoteAddr(), struct{}{} + }) + allClients, existed := u.m.Load(key) + if existed { + oldClients := allClients.([]*Client) + var a []*Client + for _, client := range oldClients { + if _, ok := m[client.ctx.GetRemoteAddr()]; !ok { + a = append(a, client) + } + } + if len(a) == 0 { + u.m.Delete(key) + return true + } else { + u.m.Store(key, a) + return false + } + } + return existed +} func (u *UserMap) DeleteAll(key string) { u.m.Delete(key) }