Merge remote-tracking branch 'origin/errcode' into errcode

# Conflicts:
#	internal/msggateway/n_ws_server.go
test-errcode
Gordon 1 year ago
commit b3c60aaf10

@ -1,6 +1,7 @@
package msggateway package msggateway
import ( import (
"context"
"errors" "errors"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
@ -188,7 +189,7 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
for _, c := range info.oldClients { for _, c := range info.oldClients {
err := c.KickOnlineMessage() err := c.KickOnlineMessage()
if err != nil { if err != nil {
log.ZWarn(c.ctx, "kick online message error", err) log.ZError(c.ctx, "KickOnlineMessage", err)
} }
} }
ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID) ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
@ -207,9 +208,9 @@ func (ws *WsServer) unregisterClient(client *Client) {
} }
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
context := newContext(w, r) connContext := newContext(w, r)
if ws.onlineUserConnNum >= ws.wsMaxConnNum { if ws.onlineUserConnNum >= ws.wsMaxConnNum {
httpError(context, errs.ErrConnOverMaxNumLimit) httpError(connContext, errs.ErrConnOverMaxNumLimit)
return return
} }
var ( var (
@ -220,46 +221,65 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
compression bool compression bool
) )
token, exists = context.Query(Token) token, exists = connContext.Query(Token)
if !exists { if !exists {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
userID, exists = context.Query(WsUserID) userID, exists = connContext.Query(WsUserID)
if !exists { if !exists {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
platformID, exists = context.Query(PlatformID) platformID, exists = connContext.Query(PlatformID)
if !exists || utils.StringToInt(platformID) == 0 { if !exists || utils.StringToInt(platformID) == 0 {
httpError(context, errs.ErrConnArgsErr) httpError(connContext, errs.ErrConnArgsErr)
return return
} }
err := tokenverify.WsVerifyToken(token, userID, platformID) err := tokenverify.WsVerifyToken(token, userID, platformID)
if err != nil { if err != nil {
httpError(context, err) httpError(connContext, err)
return
}
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
if err != nil {
httpError(connContext, err)
return
}
if v, ok := m[token]; ok {
switch v {
case constant.NormalToken:
case constant.KickedToken:
httpError(connContext, errs.ErrTokenKicked.Wrap())
return
default:
httpError(connContext, errs.ErrTokenUnknown.Wrap())
return
}
} else {
httpError(connContext, errs.ErrTokenNotExist.Wrap())
return return
} }
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout) wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
err = wsLongConn.GenerateLongConn(w, r) err = wsLongConn.GenerateLongConn(w, r)
if err != nil { if err != nil {
httpError(context, err) httpError(connContext, err)
return return
} }
compressProtoc, exists := context.Query(Compression) compressProtoc, exists := connContext.Query(Compression)
if exists { if exists {
if compressProtoc == GzipCompressionProtocol { if compressProtoc == GzipCompressionProtocol {
compression = true compression = true
} }
} }
compressProtoc, exists = context.GetHeader(Compression) compressProtoc, exists = connContext.GetHeader(Compression)
if exists { if exists {
if compressProtoc == GzipCompressionProtocol { if compressProtoc == GzipCompressionProtocol {
compression = true compression = true
} }
} }
client := ws.clientPool.Get().(*Client) client := ws.clientPool.Get().(*Client)
client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws) client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
ws.registerChan <- client ws.registerChan <- client
go client.readMessage() go client.readMessage()
} }

@ -155,6 +155,9 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
c.Abort() c.Abort()
return return
} }
} else {
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
return
} }
c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID)) c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID))
c.Set(constant.OpUserID, claims.UserID) c.Set(constant.OpUserID, claims.UserID)

Loading…
Cancel
Save