You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Open-IM-Server/internal/msggateway/n_ws_server.go

459 lines
12 KiB

// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"errors"
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/go-playground/validator/v10"
"github.com/redis/go-redis/v9"
"golang.org/x/sync/errgroup"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/msggateway"
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
)
type LongConnServer interface {
Run() error
wsHandler(w http.ResponseWriter, r *http.Request)
GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s interface{}) error
SetCacheHandler(cache cache.MsgModel)
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
KickUserConn(client *Client) error
UnRegister(c *Client)
SetKickHandlerInfo(i *kickHandler)
Compressor
Encoder
MessageHandler
}
// WsServer is the WebSocket gateway server. It tracks connected clients,
// serialises register/unregister/kick events through channels drained by the
// goroutine started in Run, and embeds the Compressor, Encoder and
// MessageHandler it delegates to.
type WsServer struct {
	// port is the TCP port Run listens on.
	port int
	// wsMaxConnNum is the connection ceiling: wsHandler rejects a request once
	// onlineUserConnNum has reached it.
	wsMaxConnNum int64

	// Event queues consumed by the loop started in Run.
	registerChan, unregisterChan chan *Client
	kickHandlerChan              chan *kickHandler

	// clients holds the registered connections, looked up by user ID.
	clients *UserMap
	// clientPool recycles *Client values: wsHandler takes from it and
	// unregisterClient puts clients back.
	clientPool sync.Pool

	// Counters of users and connections currently registered.
	onlineUserNum, onlineUserConnNum atomic.Int64

	// handshakeTimeout and writeBufferSize are handed to newGWebSocket for
	// every new connection.
	handshakeTimeout time.Duration
	writeBufferSize  int

	// validate is passed to NewGrpcHandler by SetDiscoveryRegistry.
	validate *validator.Validate
	// cache is used to read and update per-user token maps.
	cache cache.MsgModel
	// userClient is used to report online/offline status changes.
	userClient *rpcclient.UserRpcClient
	// disCov is the service-discovery client used to reach other gateway nodes.
	disCov discoveryregistry.SvcDiscoveryRegistry

	Compressor
	Encoder
	MessageHandler
}
// kickHandler is the payload sent over WsServer.kickHandlerChan. Its fields
// are forwarded, in order, as the arguments of multiTerminalLoginChecker.
type kickHandler struct {
	// clientOK is forwarded as the clientOK argument.
	clientOK bool
	// oldClients are the existing connections that may be kicked.
	oldClients []*Client
	// newClient is the connection that triggered the check.
	newClient *Client
}
// SetDiscoveryRegistry wires the server to the given service-discovery client:
// it builds the gRPC-backed MessageHandler and the user RPC client from disCov
// and stores disCov itself for later node lookups.
func (ws *WsServer) SetDiscoveryRegistry(disCov discoveryregistry.SvcDiscoveryRegistry) {
	ws.MessageHandler = NewGrpcHandler(ws.validate, disCov)

	userRPC := rpcclient.NewUserRpcClient(disCov)
	ws.userClient, ws.disCov = &userRPC, disCov
}
// SetUserOnlineStatus reports client's new status through the user RPC client
// and then fires the matching online/offline callback. Every failure is
// logged as a warning; nothing is returned to the caller.
func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, status int32) {
	if err := ws.userClient.SetUserStatus(ctx, client.UserID, status, client.PlatformID); err != nil {
		log.ZWarn(ctx, "SetUserStatus err", err)
	}

	// Only the Online and Offline statuses have a callback; any other value
	// ends here.
	if status == constant.Online {
		if err := CallbackUserOnline(ctx, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID()); err != nil {
			log.ZWarn(ctx, "CallbackUserOnline err", err)
		}
		return
	}
	if status == constant.Offline {
		if err := CallbackUserOffline(ctx, client.UserID, client.PlatformID, client.ctx.GetConnID()); err != nil {
			log.ZWarn(ctx, "CallbackUserOffline err", err)
		}
	}
}
// SetCacheHandler stores the message cache the server uses for token lookups.
// Note that the parameter name shadows the imported cache package inside the
// body.
func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) {
	ws.cache = cache
}
// UnRegister queues c on unregisterChan so that the loop started in Run
// removes it. The send blocks when the channel buffer is full.
func (ws *WsServer) UnRegister(c *Client) {
	ws.unregisterChan <- c
}
// Validate currently accepts every value: it ignores s and always returns nil.
// NOTE(review): ws.validate is not consulted here — confirm that skipping
// validation is intentional.
func (ws *WsServer) Validate(s any) error {
	return nil
}
// GetUserAllCons returns whatever the user map's GetAll reports for userID:
// the user's connections and the accompanying boolean.
func (ws *WsServer) GetUserAllCons(userID string) ([]*Client, bool) {
	conns, ok := ws.clients.GetAll(userID)
	return conns, ok
}
// GetUserPlatformCons returns whatever the user map's Get reports for userID
// and platform: the matching connections followed by two booleans (used in
// registerClient as "user exists" and "platform connection exists").
func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) {
	conns, userOK, clientOK := ws.clients.Get(userID, platform)
	return conns, userOK, clientOK
}
// NewWsServer builds a WsServer from the supplied options. Options are applied
// in order to a zero-valued configs, so any setting that no option touches
// keeps its zero value. The returned error is always nil.
func NewWsServer(opts ...Option) (*WsServer, error) {
	var cfg configs
	for i := range opts {
		opts[i](&cfg)
	}

	structValidator := validator.New()
	server := &WsServer{
		port:             cfg.port,
		wsMaxConnNum:     cfg.maxConnNum,
		writeBufferSize:  cfg.writeBufferSize,
		handshakeTimeout: cfg.handshakeTimeout,
		clientPool: sync.Pool{
			New: func() any { return new(Client) },
		},
		// Each event queue buffers up to 1000 pending items.
		registerChan:    make(chan *Client, 1000),
		unregisterChan:  make(chan *Client, 1000),
		kickHandlerChan: make(chan *kickHandler, 1000),
		validate:        structValidator,
		clients:         newUserMap(),
		Compressor:      NewGzipCompressor(),
		Encoder:         NewGobEncoder(),
	}
	return server, nil
}
// Run starts the event loop and then serves HTTP on ws.port, blocking until
// the listener fails; that failure is returned.
//
// The event loop is a single goroutine that handles register, unregister and
// kick events one at a time, and it is never stopped. The handler is
// registered on http.DefaultServeMux, so Run must be called at most once per
// process (a second registration of "/" panics).
func (ws *WsServer) Run() error {
	go func() {
		for {
			select {
			case joined := <-ws.registerChan:
				ws.registerClient(joined)
			case left := <-ws.unregisterChan:
				ws.unregisterClient(left)
			case kick := <-ws.kickHandlerChan:
				ws.multiTerminalLoginChecker(kick.clientOK, kick.oldClients, kick.newClient)
			}
		}
	}()

	http.HandleFunc("/", ws.wsHandler)
	// http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {})

	listenAddr := ":" + utils.IntToString(ws.port)
	return http.ListenAndServe(listenAddr, nil) // Start listening
}
// concurrentRequest caps how many MultiTerminalLoginCheck RPCs
// sendUserOnlineInfoToOtherNode keeps in flight at once.
var concurrentRequest = 3

// sendUserOnlineInfoToOtherNode tells every other message-gateway node that
// client has logged in, by calling MultiTerminalLoginCheck on each of them.
//
// An error is returned only when the node list cannot be fetched. Failures of
// individual RPCs are logged as warnings and otherwise ignored, so the
// notification is best-effort.
func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error {
	conns, err := ws.disCov.GetConns(ctx, config.Config.RpcRegisterName.OpenImMessageGatewayName)
	if err != nil {
		return err
	}

	var group errgroup.Group
	group.SetLimit(concurrentRequest)

	// Online push user online message to other node
	for _, conn := range conns {
		conn := conn // safe closure var
		if conn.Target() == ws.disCov.GetSelfConnTarget() {
			log.ZDebug(ctx, "Filter out this node", "node", conn.Target())
			continue
		}
		group.Go(func() error {
			req := &msggateway.MultiTerminalLoginCheckReq{
				UserID:     client.UserID,
				PlatformID: int32(client.PlatformID),
				Token:      client.token,
			}
			if _, err := msggateway.NewMsgGatewayClient(conn).MultiTerminalLoginCheck(ctx, req); err != nil {
				log.ZWarn(ctx, "MultiTerminalLoginCheck err", err, "node", conn.Target())
			}
			// Always nil: one failing node must not abort the others.
			return nil
		})
	}

	_ = group.Wait() // tasks never return an error
	return nil
}
// SetKickHandlerInfo queues i on kickHandlerChan so that the loop started in
// Run passes it to multiTerminalLoginChecker. The send blocks when the channel
// buffer is full.
func (ws *WsServer) SetKickHandlerInfo(i *kickHandler) {
	ws.kickHandlerChan <- i
}
// registerClient adds client to the user map and updates the online counters.
// When the user already has connections, the multi-terminal login policy is
// applied first. Afterwards the other gateway nodes and the user service are
// notified concurrently, and the function waits for both notifications before
// logging the new totals.
func (ws *WsServer) registerClient(client *Client) {
	oldClients, userOK, clientOK := ws.clients.Get(client.UserID, client.PlatformID)

	if userOK {
		ws.multiTerminalLoginChecker(clientOK, oldClients, client)
		log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
		ws.clients.Set(client.UserID, client)
		if clientOK {
			// A connection from the same platform already exists.
			log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
		}
		// NOTE(review): multiTerminalLoginChecker may decrement onlineUserNum
		// when it removes the user's last connection, yet this path re-adds
		// the user without incrementing it — confirm against
		// UserMap.deleteClients semantics.
		ws.onlineUserConnNum.Add(1)
	} else {
		// First connection for this user: count both the user and the connection.
		ws.clients.Set(client.UserID, client)
		log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
		prommetrics.OnlineUserGauge.Add(1)
		ws.onlineUserNum.Add(1)
		ws.onlineUserConnNum.Add(1)
	}

	var pending sync.WaitGroup
	pending.Add(2)
	go func() {
		defer pending.Done()
		_ = ws.sendUserOnlineInfoToOtherNode(client.ctx, client) // best-effort; failures are logged inside
	}()
	go func() {
		defer pending.Done()
		ws.SetUserOnlineStatus(client.ctx, client, constant.Online)
	}()
	pending.Wait()

	log.ZInfo(
		client.ctx,
		"user online",
		"online user Num", ws.onlineUserNum.Load(),
		"online user conn Num", ws.onlineUserConnNum.Load(),
	)
}
// getRemoteAdders joins the remote addresses of the given clients with "@",
// in slice order. It returns the empty string for an empty or nil slice.
func getRemoteAdders(client []*Client) string {
	var joined string
	for idx := range client {
		if idx > 0 {
			joined += "@"
		}
		joined += client[idx].ctx.GetRemoteAddr()
	}
	return joined
}
// KickUserConn removes client from the user map and then sends it the
// kick-online message, returning that send's error. The result of the map
// removal is ignored.
func (ws *WsServer) KickUserConn(client *Client) error {
	target := []*Client{client}
	ws.clients.deleteClients(client.UserID, target)
	return client.KickOnlineMessage()
}
// multiTerminalLoginChecker applies config.Config.MultiLoginPolicy when
// newClient logs in while oldClients are already connected for the same user.
//
//   - constant.DefalutNotKick: nothing is done.
//   - constant.PCAndOther: returns without kicking when newClient's platform
//     class is constant.TerminalPC; otherwise it falls through to the next
//     policy.
//   - constant.AllLoginButSameTermKick: when clientOK is true, oldClients are
//     removed from the user map and sent a kick message, and every cached
//     token for the user/platform other than newClient's own is marked
//     constant.KickedToken.
//
// clientOK is expected to say whether oldClients holds connections for
// newClient's platform (registerClient passes the third result of
// ws.clients.Get). Any other policy value is ignored. All failures are logged
// as warnings; nothing is returned.
func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) {
	switch config.Config.MultiLoginPolicy {
	case constant.DefalutNotKick:
		// Nothing is kicked under this policy.
	case constant.PCAndOther:
		if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC {
			return
		}
		fallthrough
	case constant.AllLoginButSameTermKick:
		if !clientOK {
			return
		}

		// Drop the old connections first, then tell each of them it was kicked.
		if ws.clients.deleteClients(newClient.UserID, oldClients) {
			ws.onlineUserNum.Add(-1)
		}
		for _, old := range oldClients {
			if err := old.KickOnlineMessage(); err != nil {
				log.ZWarn(old.ctx, "KickOnlineMessage", err)
			}
		}

		m, err := ws.cache.GetTokensWithoutError(
			newClient.ctx,
			newClient.UserID,
			newClient.PlatformID,
		)
		// errors.Is also recognises a wrapped redis.Nil, which a plain
		// comparison would treat as a real failure.
		if err != nil && !errors.Is(err, redis.Nil) {
			log.ZWarn(
				newClient.ctx,
				"get token from redis err",
				err,
				"userID", newClient.UserID,
				"platformID", newClient.PlatformID,
			)
			return
		}
		if m == nil {
			log.ZWarn(
				newClient.ctx,
				"m is nil",
				errors.New("m is nil"),
				"userID", newClient.UserID,
				"platformID", newClient.PlatformID,
			)
			return
		}
		log.ZDebug(
			newClient.ctx,
			"get token from redis",
			"userID", newClient.UserID,
			"platformID", newClient.PlatformID,
			"tokenMap", m,
		)

		// Invalidate every token except the one the new connection presented.
		currentToken := newClient.ctx.GetToken()
		for k := range m {
			if k != currentToken {
				m[k] = constant.KickedToken
			}
		}
		log.ZDebug(newClient.ctx, "set token map is ", "token map", m, "userID",
			newClient.UserID, "token", currentToken)

		if err = ws.cache.SetTokenMapByUidPid(newClient.ctx, newClient.UserID, newClient.PlatformID, m); err != nil {
			log.ZWarn(newClient.ctx, "SetTokenMapByUidPid err", err, "userID", newClient.UserID, "platformID", newClient.PlatformID)
			return
		}
	}
}
// unregisterClient removes client's connection (matched by user ID and remote
// address) from the user map, updates the online counters, reports the user as
// offline, and finally returns client to the pool for reuse.
func (ws *WsServer) unregisterClient(client *Client) {
	// Recycle the client only after everything below has finished using it.
	defer ws.clientPool.Put(client)

	// delete's result is true when the user count must drop as well.
	if ws.clients.delete(client.UserID, client.ctx.GetRemoteAddr()) {
		ws.onlineUserNum.Add(-1)
		prommetrics.OnlineUserGauge.Dec()
	}
	ws.onlineUserConnNum.Add(-1)

	ws.SetUserOnlineStatus(client.ctx, client, constant.Offline)
	log.ZInfo(
		client.ctx,
		"user offline",
		"close reason", client.closedErr,
		"online user Num", ws.onlineUserNum.Load(),
		"online user conn Num", ws.onlineUserConnNum.Load(),
	)
}
// wsHandler is the HTTP entry point for a new long connection. It:
//
//  1. rejects the request when the connection ceiling has been reached;
//  2. reads the token, user ID and platform ID query parameters;
//  3. verifies the token and checks its state in the cache;
//  4. upgrades the request to a long connection;
//  5. enables compression when the query string or the header asks for gzip;
//  6. takes a Client from the pool, queues it for registration and starts its
//     read loop in a new goroutine.
//
// Every failure is reported to the caller through httpError and ends the
// request.
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
	connContext := newContext(w, r)
	if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
		httpError(connContext, errs.ErrConnOverMaxNumLimit)
		return
	}

	// Required query parameters.
	token, ok := connContext.Query(Token)
	if !ok {
		httpError(connContext, errs.ErrConnArgsErr)
		return
	}
	userID, ok := connContext.Query(WsUserID)
	if !ok {
		httpError(connContext, errs.ErrConnArgsErr)
		return
	}
	platformIDStr, ok := connContext.Query(PlatformID)
	if !ok {
		httpError(connContext, errs.ErrConnArgsErr)
		return
	}
	platformID, err := strconv.Atoi(platformIDStr)
	if err != nil {
		httpError(connContext, errs.ErrConnArgsErr)
		return
	}

	if err = authverify.WsVerifyToken(token, userID, platformID); err != nil {
		httpError(connContext, err)
		return
	}

	// NOTE(review): this lookup uses context.Background() rather than the
	// request's context, so it is not cancelled with the request — confirm
	// this is intended.
	tokens, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID)
	if err != nil {
		httpError(connContext, err)
		return
	}
	state, known := tokens[token]
	if !known {
		httpError(connContext, errs.ErrTokenNotExist.Wrap())
		return
	}
	switch state {
	case constant.NormalToken:
		// Token is valid; carry on.
	case constant.KickedToken:
		httpError(connContext, errs.ErrTokenKicked.Wrap())
		return
	default:
		httpError(connContext, errs.ErrTokenUnknown.Wrap())
		return
	}

	wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
	if err = wsLongConn.GenerateLongConn(w, r); err != nil {
		httpError(connContext, err)
		return
	}

	// Gzip may be requested through the query string or the header.
	queryProto, inQuery := connContext.Query(Compression)
	headerProto, inHeader := connContext.GetHeader(Compression)
	compression := (inQuery && queryProto == GzipCompressionProtocol) ||
		(inHeader && headerProto == GzipCompressionProtocol)

	client := ws.clientPool.Get().(*Client)
	client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws, token)
	ws.registerChan <- client
	go client.readMessage()
}