diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 2563ab3ce..d7175ed36 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -1,6 +1,7 @@ package msggateway import ( + "context" "errors" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" @@ -188,7 +189,7 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { for _, c := range info.oldClients { err := c.KickOnlineMessage() if err != nil { - log.ZWarn(c.ctx, "kick online message error", err) + log.ZError(c.ctx, "KickOnlineMessage", err) } } ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID) @@ -207,9 +208,9 @@ func (ws *WsServer) unregisterClient(client *Client) { } func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { - context := newContext(w, r) + connContext := newContext(w, r) if ws.onlineUserConnNum >= ws.wsMaxConnNum { - httpError(context, errs.ErrConnOverMaxNumLimit) + httpError(connContext, errs.ErrConnOverMaxNumLimit) return } var ( @@ -220,46 +221,65 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { compression bool ) - token, exists = context.Query(Token) + token, exists = connContext.Query(Token) if !exists { - httpError(context, errs.ErrConnArgsErr) + httpError(connContext, errs.ErrConnArgsErr) return } - userID, exists = context.Query(WsUserID) + userID, exists = connContext.Query(WsUserID) if !exists { - httpError(context, errs.ErrConnArgsErr) + httpError(connContext, errs.ErrConnArgsErr) return } - platformID, exists = context.Query(PlatformID) + platformID, exists = connContext.Query(PlatformID) if !exists || utils.StringToInt(platformID) == 0 { - httpError(context, errs.ErrConnArgsErr) + httpError(connContext, errs.ErrConnArgsErr) return } err := tokenverify.WsVerifyToken(token, userID, platformID) if err != nil { - httpError(context, err) + httpError(connContext, err) + return + } + m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID) + if err != nil { + httpError(connContext, err) + return + } + if v, ok := m[token]; ok { + switch v { + case constant.NormalToken: + case constant.KickedToken: + httpError(connContext, errs.ErrTokenKicked.Wrap()) + return + default: + httpError(connContext, errs.ErrTokenUnknown.Wrap()) + return + } + } else { + httpError(connContext, errs.ErrTokenNotExist.Wrap()) return } wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout) err = wsLongConn.GenerateLongConn(w, r) if err != nil { - httpError(context, err) + httpError(connContext, err) return } - compressProtoc, exists := context.Query(Compression) + compressProtoc, exists := connContext.Query(Compression) if exists { if compressProtoc == GzipCompressionProtocol { compression = true } } - compressProtoc, exists = context.GetHeader(Compression) + compressProtoc, exists = connContext.GetHeader(Compression) if exists { if compressProtoc == GzipCompressionProtocol { compression = true } } client := ws.clientPool.Get().(*Client) - client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws) + client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws) ws.registerChan <- client go client.readMessage() } diff --git a/pkg/common/mw/gin.go b/pkg/common/mw/gin.go index 6c87fbb5b..449afee71 100644 --- a/pkg/common/mw/gin.go +++ b/pkg/common/mw/gin.go @@ -155,6 +155,9 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { c.Abort() return } + } else { + apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) + return } c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID)) c.Set(constant.OpUserID, claims.UserID)