|
|
|
@ -1,13 +1,19 @@
|
|
|
|
|
package msggateway
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"errors"
|
|
|
|
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/config"
|
|
|
|
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
|
|
|
|
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
|
|
|
|
|
"net/http"
|
|
|
|
|
"strconv"
|
|
|
|
|
"sync"
|
|
|
|
|
"sync/atomic"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
|
|
|
|
|
redis "github.com/go-redis/redis/v8"
|
|
|
|
|
|
|
|
|
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
|
|
|
|
|
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
|
|
|
|
@ -22,7 +28,7 @@ type LongConnServer interface {
|
|
|
|
|
GetUserAllCons(userID string) ([]*Client, bool)
|
|
|
|
|
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
|
|
|
|
|
Validate(s interface{}) error
|
|
|
|
|
//SetMessageHandler(msgRpcClient *rpcclient.MsgClient)
|
|
|
|
|
SetCacheHandler(cache cache.MsgModel)
|
|
|
|
|
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
|
|
|
|
|
UnRegister(c *Client)
|
|
|
|
|
Compressor
|
|
|
|
@ -41,6 +47,7 @@ type WsServer struct {
|
|
|
|
|
wsMaxConnNum int64
|
|
|
|
|
registerChan chan *Client
|
|
|
|
|
unregisterChan chan *Client
|
|
|
|
|
kickHandlerChan chan *kickHandler
|
|
|
|
|
clients *UserMap
|
|
|
|
|
clientPool sync.Pool
|
|
|
|
|
onlineUserNum int64
|
|
|
|
@ -48,14 +55,23 @@ type WsServer struct {
|
|
|
|
|
handshakeTimeout time.Duration
|
|
|
|
|
hubServer *Server
|
|
|
|
|
validate *validator.Validate
|
|
|
|
|
cache cache.MsgModel
|
|
|
|
|
Compressor
|
|
|
|
|
Encoder
|
|
|
|
|
MessageHandler
|
|
|
|
|
}
|
|
|
|
|
type kickHandler struct {
|
|
|
|
|
clientOK bool
|
|
|
|
|
oldClients []*Client
|
|
|
|
|
newClient *Client
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) {
|
|
|
|
|
ws.MessageHandler = NewGrpcHandler(ws.validate, client)
|
|
|
|
|
}
|
|
|
|
|
func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) {
|
|
|
|
|
ws.cache = cache
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ws *WsServer) UnRegister(c *Client) {
|
|
|
|
|
ws.unregisterChan <- c
|
|
|
|
@ -92,12 +108,13 @@ func NewWsServer(opts ...Option) (*WsServer, error) {
|
|
|
|
|
return new(Client)
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
registerChan: make(chan *Client, 1000),
|
|
|
|
|
unregisterChan: make(chan *Client, 1000),
|
|
|
|
|
validate: v,
|
|
|
|
|
clients: newUserMap(),
|
|
|
|
|
Compressor: NewGzipCompressor(),
|
|
|
|
|
Encoder: NewGobEncoder(),
|
|
|
|
|
registerChan: make(chan *Client, 1000),
|
|
|
|
|
unregisterChan: make(chan *Client, 1000),
|
|
|
|
|
kickHandlerChan: make(chan *kickHandler, 1000),
|
|
|
|
|
validate: v,
|
|
|
|
|
clients: newUserMap(),
|
|
|
|
|
Compressor: NewGzipCompressor(),
|
|
|
|
|
Encoder: NewGobEncoder(),
|
|
|
|
|
}, nil
|
|
|
|
|
}
|
|
|
|
|
func (ws *WsServer) Run() error {
|
|
|
|
@ -109,6 +126,8 @@ func (ws *WsServer) Run() error {
|
|
|
|
|
ws.registerClient(client)
|
|
|
|
|
case client = <-ws.unregisterChan:
|
|
|
|
|
ws.unregisterClient(client)
|
|
|
|
|
case onlineInfo := <-ws.kickHandlerChan:
|
|
|
|
|
ws.multiTerminalLoginChecker(onlineInfo)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
@ -119,26 +138,29 @@ func (ws *WsServer) Run() error {
|
|
|
|
|
|
|
|
|
|
func (ws *WsServer) registerClient(client *Client) {
|
|
|
|
|
var (
|
|
|
|
|
userOK bool
|
|
|
|
|
clientOK bool
|
|
|
|
|
cli []*Client
|
|
|
|
|
userOK bool
|
|
|
|
|
clientOK bool
|
|
|
|
|
oldClients []*Client
|
|
|
|
|
)
|
|
|
|
|
cli, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
|
|
|
|
ws.clients.Set(client.UserID, client)
|
|
|
|
|
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
|
|
|
|
if !userOK {
|
|
|
|
|
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
|
|
|
|
|
ws.clients.Set(client.UserID, client)
|
|
|
|
|
atomic.AddInt64(&ws.onlineUserNum, 1)
|
|
|
|
|
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
i := &kickHandler{
|
|
|
|
|
clientOK: clientOK,
|
|
|
|
|
oldClients: oldClients,
|
|
|
|
|
newClient: client,
|
|
|
|
|
}
|
|
|
|
|
ws.kickHandlerChan <- i
|
|
|
|
|
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
|
|
|
|
|
if clientOK { //已经有同平台的连接存在
|
|
|
|
|
ws.clients.Set(client.UserID, client)
|
|
|
|
|
ws.multiTerminalLoginChecker(cli)
|
|
|
|
|
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(cli))
|
|
|
|
|
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
|
|
|
|
|
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
|
|
|
|
} else {
|
|
|
|
|
ws.clients.Set(client.UserID, client)
|
|
|
|
|
atomic.AddInt64(&ws.onlineUserConnNum, 1)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -156,7 +178,47 @@ func getRemoteAdders(client []*Client) string {
|
|
|
|
|
return ret
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ws *WsServer) multiTerminalLoginChecker(client []*Client) {
|
|
|
|
|
func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) {
|
|
|
|
|
switch config.Config.MultiLoginPolicy {
|
|
|
|
|
case constant.DefalutNotKick:
|
|
|
|
|
case constant.PCAndOther:
|
|
|
|
|
if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
fallthrough
|
|
|
|
|
case constant.AllLoginButSameTermKick:
|
|
|
|
|
if info.clientOK {
|
|
|
|
|
ws.clients.deleteClients(info.newClient.UserID, info.oldClients)
|
|
|
|
|
for _, c := range info.oldClients {
|
|
|
|
|
err := c.KickOnlineMessage()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.ZWarn(c.ctx, "KickOnlineMessage", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
m, err := ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID)
|
|
|
|
|
if err != nil && err != redis.Nil {
|
|
|
|
|
log.ZWarn(info.newClient.ctx, "get token from redis err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if m == nil {
|
|
|
|
|
log.ZWarn(info.newClient.ctx, "m is nil", errors.New("m is nil"), "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
log.ZDebug(info.newClient.ctx, "get token from redis", "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID, "tokenMap", m)
|
|
|
|
|
|
|
|
|
|
for k, _ := range m {
|
|
|
|
|
if k != info.newClient.ctx.GetToken() {
|
|
|
|
|
m[k] = constant.KickedToken
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID)
|
|
|
|
|
err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.ZWarn(info.newClient.ctx, "SetTokenMapByUidPid err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
func (ws *WsServer) unregisterClient(client *Client) {
|
|
|
|
@ -170,60 +232,83 @@ func (ws *WsServer) unregisterClient(client *Client) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
context := newContext(w, r)
|
|
|
|
|
defer log.ZInfo(context.Background(), "wsHandler", "remote addr", "url", r.URL.String())
|
|
|
|
|
connContext := newContext(w, r)
|
|
|
|
|
if ws.onlineUserConnNum >= ws.wsMaxConnNum {
|
|
|
|
|
httpError(context, errs.ErrConnOverMaxNumLimit)
|
|
|
|
|
httpError(connContext, errs.ErrConnOverMaxNumLimit)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
var (
|
|
|
|
|
token string
|
|
|
|
|
userID string
|
|
|
|
|
platformID string
|
|
|
|
|
exists bool
|
|
|
|
|
compression bool
|
|
|
|
|
token string
|
|
|
|
|
userID string
|
|
|
|
|
platformIDStr string
|
|
|
|
|
exists bool
|
|
|
|
|
compression bool
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
token, exists = context.Query(Token)
|
|
|
|
|
token, exists = connContext.Query(Token)
|
|
|
|
|
if !exists {
|
|
|
|
|
httpError(context, errs.ErrConnArgsErr)
|
|
|
|
|
httpError(connContext, errs.ErrConnArgsErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
userID, exists = context.Query(WsUserID)
|
|
|
|
|
userID, exists = connContext.Query(WsUserID)
|
|
|
|
|
if !exists {
|
|
|
|
|
httpError(context, errs.ErrConnArgsErr)
|
|
|
|
|
httpError(connContext, errs.ErrConnArgsErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
platformID, exists = context.Query(PlatformID)
|
|
|
|
|
if !exists || utils.StringToInt(platformID) == 0 {
|
|
|
|
|
httpError(context, errs.ErrConnArgsErr)
|
|
|
|
|
platformIDStr, exists = connContext.Query(PlatformID)
|
|
|
|
|
if !exists {
|
|
|
|
|
httpError(connContext, errs.ErrConnArgsErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
// log.ZDebug(context2.Background(), "conn", "platformID", platformID)
|
|
|
|
|
err := tokenverify.WsVerifyToken(token, userID, platformID)
|
|
|
|
|
platformID, err := strconv.Atoi(platformIDStr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
httpError(context, err)
|
|
|
|
|
httpError(connContext, errs.ErrConnArgsErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil {
|
|
|
|
|
httpError(connContext, err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
httpError(connContext, err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if v, ok := m[token]; ok {
|
|
|
|
|
switch v {
|
|
|
|
|
case constant.NormalToken:
|
|
|
|
|
case constant.KickedToken:
|
|
|
|
|
httpError(connContext, errs.ErrTokenKicked.Wrap())
|
|
|
|
|
return
|
|
|
|
|
default:
|
|
|
|
|
httpError(connContext, errs.ErrTokenUnknown.Wrap())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
httpError(connContext, errs.ErrTokenNotExist.Wrap())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
|
|
|
|
|
err = wsLongConn.GenerateLongConn(w, r)
|
|
|
|
|
if err != nil {
|
|
|
|
|
httpError(context, err)
|
|
|
|
|
httpError(connContext, err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
compressProtoc, exists := context.Query(Compression)
|
|
|
|
|
compressProtoc, exists := connContext.Query(Compression)
|
|
|
|
|
if exists {
|
|
|
|
|
if compressProtoc == GzipCompressionProtocol {
|
|
|
|
|
compression = true
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
compressProtoc, exists = context.GetHeader(Compression)
|
|
|
|
|
compressProtoc, exists = connContext.GetHeader(Compression)
|
|
|
|
|
if exists {
|
|
|
|
|
if compressProtoc == GzipCompressionProtocol {
|
|
|
|
|
compression = true
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
client := ws.clientPool.Get().(*Client)
|
|
|
|
|
client.ResetClient(context, wsLongConn, context.GetBackground(), compression, ws)
|
|
|
|
|
client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws)
|
|
|
|
|
ws.registerChan <- client
|
|
|
|
|
go client.readMessage()
|
|
|
|
|
}
|
|
|
|
|