Merge remote-tracking branch 'origin/errcode' into errcode

# Conflicts:
#	internal/msggateway/n_ws_server.go
test-errcode
Gordon 1 year ago
commit b3c60aaf10

@ -1,6 +1,7 @@
package msggateway
import (
"context"
"errors"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
@ -188,7 +189,7 @@ func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
for _, c := range info.oldClients {
err := c.KickOnlineMessage()
if err != nil {
log.ZWarn(c.ctx, "kick online message error", err)
log.ZError(c.ctx, "KickOnlineMessage", err)
}
}
ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
@ -207,9 +208,9 @@ func (ws *WsServer) unregisterClient(client *Client) {
}
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
context := newContext(w, r)
connContext := newContext(w, r)
if ws.onlineUserConnNum >= ws.wsMaxConnNum {
httpError(context, errs.ErrConnOverMaxNumLimit)
httpError(connContext, errs.ErrConnOverMaxNumLimit)
return
}
var (
@ -220,46 +221,65 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
compression bool
)
token, exists = context.Query(Token)
token, exists = connContext.Query(Token)
if !exists {
httpError(context, errs.ErrConnArgsErr)
httpError(connContext, errs.ErrConnArgsErr)
return
}
userID, exists = context.Query(WsUserID)
userID, exists = connContext.Query(WsUserID)
if !exists {
httpError(context, errs.ErrConnArgsErr)
httpError(connContext, errs.ErrConnArgsErr)
return
}
platformID, exists = context.Query(PlatformID)
platformID, exists = connContext.Query(PlatformID)
if !exists || utils.StringToInt(platformID) == 0 {
httpError(context, errs.ErrConnArgsErr)
httpError(connContext, errs.ErrConnArgsErr)
return
}
err := tokenverify.WsVerifyToken(token, userID, platformID)
if err != nil {
httpError(context, err)
httpError(connContext, err)
return
}
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
if err != nil {
httpError(connContext, err)
return
}
if v, ok := m[token]; ok {
switch v {
case constant.NormalToken:
case constant.KickedToken:
httpError(connContext, errs.ErrTokenKicked.Wrap())
return
default:
httpError(connContext, errs.ErrTokenUnknown.Wrap())
return
}
} else {
httpError(connContext, errs.ErrTokenNotExist.Wrap())
return
}
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
err = wsLongConn.GenerateLongConn(w, r)
if err != nil {
httpError(context, err)
httpError(connContext, err)
return
}
compressProtoc, exists := context.Query(Compression)
compressProtoc, exists := connContext.Query(Compression)
if exists {
if compressProtoc == GzipCompressionProtocol {
compression = true
}
}
compressProtoc, exists = context.GetHeader(Compression)
compressProtoc, exists = connContext.GetHeader(Compression)
if exists {
if compressProtoc == GzipCompressionProtocol {
compression = true
}
}
client := ws.clientPool.Get().(*Client)
client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws)
client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
ws.registerChan <- client
go client.readMessage()
}

@ -155,6 +155,9 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
c.Abort()
return
}
} else {
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
return
}
c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID))
c.Set(constant.OpUserID, claims.UserID)

Loading…
Cancel
Save