diff --git a/internal/msggateway/callback.go b/internal/msggateway/callback.go new file mode 100644 index 000000000..2d30b52b4 --- /dev/null +++ b/internal/msggateway/callback.go @@ -0,0 +1,155 @@ +package msggateway + +import ( + "context" + cbapi "github.com/OpenIMSDK/Open-IM-Server/pkg/callbackstruct" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/http" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" + "time" +) + +func url() string { + return config.Config.Callback.CallbackUrl +} + +func CallbackUserOnline(ctx context.Context, userID string, platformID int, isAppBackground bool, connID string) error { + if !config.Config.Callback.CallbackUserOnline.Enable { + return nil + } + req := cbapi.CallbackUserOnlineReq{ + UserStatusCallbackReq: cbapi.UserStatusCallbackReq{ + UserStatusBaseCallback: cbapi.UserStatusBaseCallback{ + CallbackCommand: constant.CallbackUserOnlineCommand, + OperationID: mcontext.GetOperationID(ctx), + PlatformID: platformID, + Platform: constant.PlatformIDToName(platformID), + }, + UserID: userID, + }, + Seq: time.Now().UnixMilli(), + IsAppBackground: isAppBackground, + ConnID: connID, + } + resp := cbapi.CommonCallbackResp{} + return http.CallBackPostReturn(ctx, url(), &req, &resp, config.Config.Callback.CallbackUserOnline) +} + +func CallbackUserOffline(ctx context.Context, userID string, platformID int, connID string) error { + if !config.Config.Callback.CallbackUserOffline.Enable { + return nil + } + req := &cbapi.CallbackUserOfflineReq{ + UserStatusCallbackReq: cbapi.UserStatusCallbackReq{ + UserStatusBaseCallback: cbapi.UserStatusBaseCallback{ + CallbackCommand: constant.CallbackUserOfflineCommand, + OperationID: mcontext.GetOperationID(ctx), + PlatformID: platformID, + Platform: constant.PlatformIDToName(platformID), + }, + UserID: userID, + }, + Seq: time.Now().UnixMilli(), + ConnID: connID, + } + resp := &cbapi.CallbackUserOfflineResp{} + return http.CallBackPostReturn(ctx, url(), req, resp, config.Config.Callback.CallbackUserOffline) +} + +func CallbackUserKickOff(ctx context.Context, userID string, platformID int) error { + if !config.Config.Callback.CallbackUserKickOff.Enable { + return nil + } + req := &cbapi.CallbackUserKickOffReq{ + UserStatusCallbackReq: cbapi.UserStatusCallbackReq{ + UserStatusBaseCallback: cbapi.UserStatusBaseCallback{ + CallbackCommand: constant.CallbackUserKickOffCommand, + OperationID: mcontext.GetOperationID(ctx), + PlatformID: platformID, + Platform: constant.PlatformIDToName(platformID), + }, + UserID: userID, + }, + Seq: time.Now().UnixMilli(), + } + resp := &cbapi.CommonCallbackResp{} + return http.CallBackPostReturn(ctx, url(), req, resp, config.Config.Callback.CallbackUserOffline) +} + +//func callbackUserOnline(operationID, userID string, platformID int, token string, isAppBackground bool, connID string) cbApi.CommonCallbackResp { +// callbackResp := cbApi.CommonCallbackResp{OperationID: operationID} +// if !config.Config.Callback.CallbackUserOnline.Enable { +// return callbackResp +// } +// callbackUserOnlineReq := cbApi.CallbackUserOnlineReq{ +// Token: token, +// UserStatusCallbackReq: cbApi.UserStatusCallbackReq{ +// UserStatusBaseCallback: cbApi.UserStatusBaseCallback{ +// CallbackCommand: constant.CallbackUserOnlineCommand, +// OperationID: operationID, +// PlatformID: int32(platformID), +// Platform: constant.PlatformIDToName(platformID), +// }, +// UserID: userID, +// }, +// Seq: int(time.Now().UnixNano() / 1e6), +// IsAppBackground: isAppBackground, +// ConnID: connID, +// } +// callbackUserOnlineResp := &cbApi.CallbackUserOnlineResp{CommonCallbackResp: &callbackResp} +// if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, constant.CallbackUserOnlineCommand, callbackUserOnlineReq, callbackUserOnlineResp, config.Config.Callback.CallbackUserOnline.CallbackTimeOut); err != nil { +// callbackResp.ErrCode = http2.StatusInternalServerError +// callbackResp.ErrMsg = err.Error() +// } +// return callbackResp +//} +//func callbackUserOffline(operationID, userID string, platformID int, connID string) cbApi.CommonCallbackResp { +// callbackResp := cbApi.CommonCallbackResp{OperationID: operationID} +// if !config.Config.Callback.CallbackUserOffline.Enable { +// return callbackResp +// } +// callbackOfflineReq := cbApi.CallbackUserOfflineReq{ +// UserStatusCallbackReq: cbApi.UserStatusCallbackReq{ +// UserStatusBaseCallback: cbApi.UserStatusBaseCallback{ +// CallbackCommand: constant.CallbackUserOfflineCommand, +// OperationID: operationID, +// PlatformID: int32(platformID), +// Platform: constant.PlatformIDToName(platformID), +// }, +// UserID: userID, +// }, +// Seq: int(time.Now().UnixNano() / 1e6), +// ConnID: connID, +// } +// callbackUserOfflineResp := &cbApi.CallbackUserOfflineResp{CommonCallbackResp: &callbackResp} +// if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, constant.CallbackUserOfflineCommand, callbackOfflineReq, callbackUserOfflineResp, config.Config.Callback.CallbackUserOffline.CallbackTimeOut); err != nil { +// callbackResp.ErrCode = http2.StatusInternalServerError +// callbackResp.ErrMsg = err.Error() +// } +// return callbackResp +//} +//func callbackUserKickOff(operationID string, userID string, platformID int) cbApi.CommonCallbackResp { +// callbackResp := cbApi.CommonCallbackResp{OperationID: operationID} +// if !config.Config.Callback.CallbackUserKickOff.Enable { +// return callbackResp +// } +// callbackUserKickOffReq := cbApi.CallbackUserKickOffReq{ +// UserStatusCallbackReq: cbApi.UserStatusCallbackReq{ +// UserStatusBaseCallback: cbApi.UserStatusBaseCallback{ +// CallbackCommand: constant.CallbackUserKickOffCommand, +// OperationID: operationID, +// PlatformID: int32(platformID), +// Platform: constant.PlatformIDToName(platformID), +// }, +// UserID: userID, +// }, +// Seq: int(time.Now().UnixNano() / 1e6), +// } +// callbackUserKickOffResp := &cbApi.CallbackUserKickOffResp{CommonCallbackResp: &callbackResp} +// if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, constant.CallbackUserKickOffCommand, callbackUserKickOffReq, callbackUserKickOffResp, config.Config.Callback.CallbackUserOffline.CallbackTimeOut); err != nil { +// callbackResp.ErrCode = http2.StatusInternalServerError +// callbackResp.ErrMsg = err.Error() +// } +// return callbackResp +//} diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go new file mode 100644 index 000000000..b38761b88 --- /dev/null +++ b/internal/msggateway/client.go @@ -0,0 +1,269 @@ +package msggateway + +import ( + "context" + "errors" + "fmt" + "runtime/debug" + "sync" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/apiresp" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "google.golang.org/protobuf/proto" +) + +var ErrConnClosed = errors.New("conn has closed") +var ErrNotSupportMessageProtocol = errors.New("not support message protocol") +var ErrClientClosed = errors.New("client actively close the connection") +var ErrPanic = errors.New("panic error") + +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a pong control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +type PongHandler func(string) error + +type Client struct { + w *sync.Mutex + conn LongConn + PlatformID int `json:"platformID"` + IsCompress bool `json:"isCompress"` + UserID string `json:"userID"` + IsBackground bool `json:"isBackground"` + ctx *UserConnContext + longConnServer LongConnServer + closed bool + closedErr error +} + +func newClient(ctx *UserConnContext, conn LongConn, isCompress bool) *Client { + return &Client{ + w: new(sync.Mutex), + conn: conn, + PlatformID: utils.StringToInt(ctx.GetPlatformID()), + IsCompress: isCompress, + UserID: ctx.GetUserID(), + ctx: ctx, + } +} +func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isBackground, isCompress bool, longConnServer LongConnServer) { + c.w = new(sync.Mutex) + c.conn = conn + c.PlatformID = utils.StringToInt(ctx.GetPlatformID()) + c.IsCompress = isCompress + c.IsBackground = isBackground + c.UserID = ctx.GetUserID() + c.ctx = ctx + c.longConnServer = longConnServer + c.IsBackground = false + c.closed = false + c.closedErr = nil +} +func (c *Client) pongHandler(_ string) error { + c.conn.SetReadDeadline(pongWait) + return nil +} +func (c *Client) readMessage() { + defer func() { + if r := recover(); r != nil { + c.closedErr = ErrPanic + fmt.Println("socket have panic err:", r, string(debug.Stack())) + } + c.close() + }() + c.conn.SetReadLimit(maxMessageSize) + _ = c.conn.SetReadDeadline(pongWait) + c.conn.SetPongHandler(c.pongHandler) + for { + messageType, message, returnErr := c.conn.ReadMessage() + if returnErr != nil { + c.closedErr = returnErr + return + } + log.ZDebug(c.ctx, "readMessage", "messageType", messageType) + if c.closed == true { //连接刚置位已经关闭,但是协程还没退出的场景 + c.closedErr = ErrConnClosed + return + } + switch messageType { + case MessageBinary: + _ = c.conn.SetReadDeadline(pongWait) + parseDataErr := c.handleMessage(message) + if parseDataErr != nil { + c.closedErr = parseDataErr + return + } + case MessageText: + c.closedErr = ErrNotSupportMessageProtocol + return + case PingMessage: + err := c.writePongMsg() + log.ZError(c.ctx, "writePongMsg", err) + case CloseMessage: + c.closedErr = ErrClientClosed + return + default: + } + } + +} +func (c *Client) handleMessage(message []byte) error { + if c.IsCompress { + var decompressErr error + message, decompressErr = c.longConnServer.DeCompress(message) + if decompressErr != nil { + return utils.Wrap(decompressErr, "") + } + } + var binaryReq Req + err := c.longConnServer.Decode(message, &binaryReq) + if err != nil { + return utils.Wrap(err, "") + } + if err := c.longConnServer.Validate(binaryReq); err != nil { + return utils.Wrap(err, "") + } + if binaryReq.SendID != c.UserID { + return utils.Wrap(errors.New("exception conn userID not same to req userID"), binaryReq.String()) + } + ctx := mcontext.WithMustInfoCtx([]string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.PlatformID), c.ctx.GetConnID()}) + log.ZDebug(ctx, "gateway req message", "req", binaryReq.String()) + var messageErr error + var resp []byte + switch binaryReq.ReqIdentifier { + case WSGetNewestSeq: + resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq) + case WSSendMsg: + resp, messageErr = c.longConnServer.SendMessage(ctx, binaryReq) + case WSSendSignalMsg: + resp, messageErr = c.longConnServer.SendSignalMessage(ctx, binaryReq) + case WSPullMsgBySeqList: + resp, messageErr = c.longConnServer.PullMessageBySeqList(ctx, binaryReq) + case WsLogoutMsg: + resp, messageErr = c.longConnServer.UserLogout(ctx, binaryReq) + case WsSetBackgroundStatus: + resp, messageErr = c.setAppBackgroundStatus(ctx, binaryReq) + default: + return fmt.Errorf("ReqIdentifier failed,sendID:%s,msgIncr:%s,reqIdentifier:%d", binaryReq.SendID, binaryReq.MsgIncr, binaryReq.ReqIdentifier) + } + c.replyMessage(ctx, &binaryReq, messageErr, resp) + return nil + +} +func (c *Client) setAppBackgroundStatus(ctx context.Context, req Req) ([]byte, error) { + resp, isBackground, messageErr := c.longConnServer.SetUserDeviceBackground(ctx, req) + if messageErr != nil { + return nil, messageErr + } + c.IsBackground = isBackground + //todo callback + return resp, nil + +} +func (c *Client) close() { + c.w.Lock() + defer c.w.Unlock() + c.closed = true + c.conn.Close() + c.longConnServer.UnRegister(c) + +} +func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, resp []byte) { + errResp := apiresp.ParseError(err) + mReply := Resp{ + ReqIdentifier: binaryReq.ReqIdentifier, + MsgIncr: binaryReq.MsgIncr, + OperationID: binaryReq.OperationID, + ErrCode: errResp.ErrCode, + ErrMsg: errResp.ErrMsg, + Data: resp, + } + log.ZDebug(ctx, "gateway reply message", "resp", mReply.String()) + err = c.writeBinaryMsg(mReply) + if err != nil { + log.ZWarn(ctx, "wireBinaryMsg replyMessage", err, "resp", mReply.String()) + } +} +func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error { + var msg sdkws.PushMessages + conversationID := utils.GetConversationIDByMsg(msgData) + m := map[string]*sdkws.PullMsgs{conversationID: {Msgs: []*sdkws.MsgData{msgData}}} + if utils.IsNotification(conversationID) { + msg.NotificationMsgs = m + } else { + msg.Msgs = m + } + log.ZDebug(ctx, "PushMessage", "msg", msg) + data, err := proto.Marshal(&msg) + if err != nil { + return err + } + resp := Resp{ + ReqIdentifier: WSPushMsg, + OperationID: mcontext.GetOperationID(ctx), + Data: data, + } + return c.writeBinaryMsg(resp) +} + +func (c *Client) KickOnlineMessage() error { + resp := Resp{ + ReqIdentifier: WSKickOnlineMsg, + } + return c.writeBinaryMsg(resp) +} + +func (c *Client) writeBinaryMsg(resp Resp) error { + c.w.Lock() + defer c.w.Unlock() + if c.closed == true { + return nil + } + encodedBuf := bufferPool.Get().([]byte) + resultBuf := bufferPool.Get().([]byte) + encodedBuf, err := c.longConnServer.Encode(resp) + if err != nil { + return utils.Wrap(err, "") + } + _ = c.conn.SetWriteDeadline(writeWait) + if c.IsCompress { + var compressErr error + resultBuf, compressErr = c.longConnServer.Compress(encodedBuf) + if compressErr != nil { + return utils.Wrap(compressErr, "") + } + return c.conn.WriteMessage(MessageBinary, resultBuf) + } else { + return c.conn.WriteMessage(MessageBinary, encodedBuf) + } +} + +func (c *Client) writePongMsg() error { + c.w.Lock() + defer c.w.Unlock() + if c.closed == true { + return nil + } + _ = c.conn.SetWriteDeadline(writeWait) + return c.conn.WriteMessage(PongMessage, nil) + +} diff --git a/internal/msggateway/compressor.go b/internal/msggateway/compressor.go new file mode 100644 index 000000000..a37c74ccd --- /dev/null +++ b/internal/msggateway/compressor.go @@ -0,0 +1,46 @@ +package msggateway + +import ( + "bytes" + "compress/gzip" + "io" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" +) + +type Compressor interface { + Compress(rawData []byte) ([]byte, error) + DeCompress(compressedData []byte) ([]byte, error) +} +type GzipCompressor struct { + compressProtocol string +} + +func NewGzipCompressor() *GzipCompressor { + return &GzipCompressor{compressProtocol: "gzip"} +} +func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) { + gzipBuffer := bytes.Buffer{} + gz := gzip.NewWriter(&gzipBuffer) + if _, err := gz.Write(rawData); err != nil { + return nil, utils.Wrap(err, "") + } + if err := gz.Close(); err != nil { + return nil, utils.Wrap(err, "") + } + return gzipBuffer.Bytes(), nil +} + +func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) { + buff := bytes.NewBuffer(compressedData) + reader, err := gzip.NewReader(buff) + if err != nil { + return nil, utils.Wrap(err, "NewReader failed") + } + compressedData, err = io.ReadAll(reader) + if err != nil { + return nil, utils.Wrap(err, "ReadAll failed") + } + _ = reader.Close() + return compressedData, nil +} diff --git a/internal/msggateway/constant.go b/internal/msggateway/constant.go new file mode 100644 index 000000000..58ee6e940 --- /dev/null +++ b/internal/msggateway/constant.go @@ -0,0 +1,41 @@ +package msggateway + +import "time" + +const ( + WsUserID = "sendID" + CommonUserID = "userID" + PlatformID = "platformID" + ConnID = "connID" + Token = "token" + OperationID = "operationID" + Compression = "compression" + GzipCompressionProtocol = "gzip" + BackgroundStatus = "isBackground" +) +const ( + WebSocket = iota + 1 +) +const ( + //Websocket Protocol + WSGetNewestSeq = 1001 + WSPullMsgBySeqList = 1002 + WSSendMsg = 1003 + WSSendSignalMsg = 1004 + WSPushMsg = 2001 + WSKickOnlineMsg = 2002 + WsLogoutMsg = 2003 + WsSetBackgroundStatus = 2004 + WSDataError = 3001 +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 30 * time.Second + + // Maximum message size allowed from peer. + maxMessageSize = 51200 +) diff --git a/internal/msggateway/context.go b/internal/msggateway/context.go new file mode 100644 index 000000000..cd395e7e0 --- /dev/null +++ b/internal/msggateway/context.go @@ -0,0 +1,104 @@ +package msggateway + +import ( + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "net/http" + "strconv" + "time" +) + +type UserConnContext struct { + RespWriter http.ResponseWriter + Req *http.Request + Path string + Method string + RemoteAddr string + ConnID string +} + +func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) { + return +} + +func (c *UserConnContext) Done() <-chan struct{} { + return nil +} + +func (c *UserConnContext) Err() error { + return nil +} + +func (c *UserConnContext) Value(key any) any { + switch key { + case constant.OpUserID: + return c.GetUserID() + case constant.OperationID: + return c.GetOperationID() + case constant.ConnID: + return c.GetConnID() + case constant.OpUserPlatform: + return constant.PlatformIDToName(utils.StringToInt(c.GetPlatformID())) + case constant.RemoteAddr: + return c.RemoteAddr + default: + return "" + } +} + +func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnContext { + return &UserConnContext{ + RespWriter: respWriter, + Req: req, + Path: req.URL.Path, + Method: req.Method, + RemoteAddr: req.RemoteAddr, + ConnID: utils.Md5(req.RemoteAddr + "_" + strconv.Itoa(int(utils.GetCurrentTimestampByMill()))), + } +} +func (c *UserConnContext) GetRemoteAddr() string { + return c.RemoteAddr +} +func (c *UserConnContext) Query(key string) (string, bool) { + var value string + if value = c.Req.URL.Query().Get(key); value == "" { + return value, false + } + return value, true +} +func (c *UserConnContext) GetHeader(key string) (string, bool) { + var value string + if value = c.Req.Header.Get(key); value == "" { + return value, false + } + return value, true +} +func (c *UserConnContext) SetHeader(key, value string) { + c.RespWriter.Header().Set(key, value) +} +func (c *UserConnContext) ErrReturn(error string, code int) { + http.Error(c.RespWriter, error, code) +} +func (c *UserConnContext) GetConnID() string { + return c.ConnID +} +func (c *UserConnContext) GetUserID() string { + return c.Req.URL.Query().Get(WsUserID) +} +func (c *UserConnContext) GetPlatformID() string { + return c.Req.URL.Query().Get(PlatformID) +} +func (c *UserConnContext) GetOperationID() string { + return c.Req.URL.Query().Get(OperationID) +} +func (c *UserConnContext) GetToken() string { + return c.Req.URL.Query().Get(Token) +} +func (c *UserConnContext) GetBackground() bool { + b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus)) + if err != nil { + return false + } else { + return b + } +} diff --git a/internal/msggateway/encoder.go b/internal/msggateway/encoder.go new file mode 100644 index 000000000..6a4104ff4 --- /dev/null +++ b/internal/msggateway/encoder.go @@ -0,0 +1,37 @@ +package msggateway + +import ( + "bytes" + "encoding/gob" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" +) + +type Encoder interface { + Encode(data interface{}) ([]byte, error) + Decode(encodeData []byte, decodeData interface{}) error +} + +type GobEncoder struct { +} + +func NewGobEncoder() *GobEncoder { + return &GobEncoder{} +} +func (g *GobEncoder) Encode(data interface{}) ([]byte, error) { + buff := bytes.Buffer{} + enc := gob.NewEncoder(&buff) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buff.Bytes(), nil +} +func (g *GobEncoder) Decode(encodeData []byte, decodeData interface{}) error { + buff := bytes.NewBuffer(encodeData) + dec := gob.NewDecoder(buff) + err := dec.Decode(decodeData) + if err != nil { + return utils.Wrap(err, "") + } + return nil +} diff --git a/internal/msggateway/http_error.go b/internal/msggateway/http_error.go new file mode 100644 index 000000000..fd00276fb --- /dev/null +++ b/internal/msggateway/http_error.go @@ -0,0 +1,7 @@ +package msggateway + +import "github.com/OpenIMSDK/Open-IM-Server/pkg/apiresp" + +func httpError(ctx *UserConnContext, err error) { + apiresp.HttpError(ctx.RespWriter, err) +} diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go new file mode 100644 index 000000000..633b7c479 --- /dev/null +++ b/internal/msggateway/hub_server.go @@ -0,0 +1,154 @@ +package msggateway + +import ( + "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/prome" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify" + "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" + "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msggateway" + "github.com/OpenIMSDK/Open-IM-Server/pkg/startrpc" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "google.golang.org/grpc" +) + +func (s *Server) InitServer(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis() + if err != nil { + return err + } + msgModel := cache.NewMsgCacheModel(rdb) + s.LongConnServer.SetDiscoveryRegistry(client) + s.LongConnServer.SetCacheHandler(msgModel) + msggateway.RegisterMsgGatewayServer(server, s) + return nil +} + +func (s *Server) Start() error { + return startrpc.Start(s.rpcPort, config.Config.RpcRegisterName.OpenImMessageGatewayName, s.prometheusPort, s.InitServer) +} + +type Server struct { + rpcPort int + prometheusPort int + LongConnServer LongConnServer + pushTerminal []int +} + +func (s *Server) SetLongConnServer(LongConnServer LongConnServer) { + s.LongConnServer = LongConnServer +} + +func NewServer(rpcPort int, longConnServer LongConnServer) *Server { + return &Server{rpcPort: rpcPort, LongConnServer: longConnServer, pushTerminal: []int{constant.IOSPlatformID, constant.AndroidPlatformID}} +} + +func (s *Server) OnlinePushMsg(context context.Context, req *msggateway.OnlinePushMsgReq) (*msggateway.OnlinePushMsgResp, error) { + panic("implement me") +} + +func (s *Server) GetUsersOnlineStatus(ctx context.Context, req *msggateway.GetUsersOnlineStatusReq) (*msggateway.GetUsersOnlineStatusResp, error) { + if !tokenverify.IsAppManagerUid(ctx) { + return nil, errs.ErrNoPermission.Wrap("only app manager") + } + var resp msggateway.GetUsersOnlineStatusResp + for _, userID := range req.UserIDs { + clients, ok := s.LongConnServer.GetUserAllCons(userID) + if !ok { + continue + } + temp := new(msggateway.GetUsersOnlineStatusResp_SuccessResult) + temp.UserID = userID + for _, client := range clients { + if client != nil { + ps := new(msggateway.GetUsersOnlineStatusResp_SuccessDetail) + ps.Platform = constant.PlatformIDToName(client.PlatformID) + ps.Status = constant.OnlineStatus + ps.ConnID = client.ctx.GetConnID() + ps.IsBackground = client.IsBackground + temp.Status = constant.OnlineStatus + temp.DetailPlatformStatus = append(temp.DetailPlatformStatus, ps) + } + } + if temp.Status == constant.OnlineStatus { + resp.SuccessResult = append(resp.SuccessResult, temp) + } + } + return &resp, nil +} + +func (s *Server) OnlineBatchPushOneMsg(ctx context.Context, req *msggateway.OnlineBatchPushOneMsgReq) (*msggateway.OnlineBatchPushOneMsgResp, error) { + panic("implement me") +} + +func (s *Server) SuperGroupOnlineBatchPushOneMsg(ctx context.Context, req *msggateway.OnlineBatchPushOneMsgReq) (*msggateway.OnlineBatchPushOneMsgResp, error) { + var singleUserResult []*msggateway.SingleMsgToUserResults + for _, v := range req.PushToUserIDs { + var resp []*msggateway.SingleMsgToUserPlatform + tempT := &msggateway.SingleMsgToUserResults{ + UserID: v, + } + clients, ok := s.LongConnServer.GetUserAllCons(v) + if !ok { + log.ZDebug(ctx, "push user not online", "userID", v) + tempT.Resp = resp + singleUserResult = append(singleUserResult, tempT) + continue + } + log.ZDebug(ctx, "push user online", "clients", clients, "userID", v) + for _, client := range clients { + if client != nil { + temp := &msggateway.SingleMsgToUserPlatform{ + RecvID: v, + RecvPlatFormID: int32(client.PlatformID), + } + if !client.IsBackground || (client.IsBackground == true && client.PlatformID != constant.IOSPlatformID) { + err := client.PushMessage(ctx, req.MsgData) + if err != nil { + temp.ResultCode = -2 + resp = append(resp, temp) + } else { + if utils.IsContainInt(client.PlatformID, s.pushTerminal) { + tempT.OnlinePush = true + prome.Inc(prome.MsgOnlinePushSuccessCounter) + resp = append(resp, temp) + } + } + } else { + temp.ResultCode = -3 + resp = append(resp, temp) + } + } + } + tempT.Resp = resp + singleUserResult = append(singleUserResult, tempT) + } + + return &msggateway.OnlineBatchPushOneMsgResp{ + SinglePushResult: singleUserResult, + }, nil +} + +func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOfflineReq) (*msggateway.KickUserOfflineResp, error) { + for _, v := range req.KickUserIDList { + if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok { + for _, client := range clients { + err := client.KickOnlineMessage() + if err != nil { + return nil, err + } + } + } + } + return &msggateway.KickUserOfflineResp{}, nil +} + +func (s *Server) MultiTerminalLoginCheck(ctx context.Context, req *msggateway.MultiTerminalLoginCheckReq) (*msggateway.MultiTerminalLoginCheckResp, error) { + //TODO implement me + panic("implement me") +} diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go new file mode 100644 index 000000000..2e001667a --- /dev/null +++ b/internal/msggateway/init.go @@ -0,0 +1,24 @@ +package msggateway + +import ( + "fmt" + "time" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" +) + +func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { + fmt.Println("start rpc/msg_gateway server, port: ", rpcPort, wsPort, prometheusPort, ", OpenIM version: ", config.Version) + longServer, err := NewWsServer( + WithPort(wsPort), + WithMaxConnNum(int64(config.Config.LongConnSvr.WebsocketMaxConnNum)), + WithHandshakeTimeout(time.Duration(config.Config.LongConnSvr.WebsocketTimeout)*time.Second), + WithMessageMaxMsgLength(config.Config.LongConnSvr.WebsocketMaxMsgLen)) + if err != nil { + return err + } + hubServer := NewServer(rpcPort, longServer) + go hubServer.Start() + hubServer.LongConnServer.Run() + return nil +} diff --git a/internal/msggateway/long_conn.go b/internal/msggateway/long_conn.go new file mode 100644 index 000000000..ed31dbda7 --- /dev/null +++ b/internal/msggateway/long_conn.go @@ -0,0 +1,107 @@ +package msggateway + +import ( + "github.com/gorilla/websocket" + "net/http" + "time" +) + +type LongConn interface { + //Close this connection + Close() error + // WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1). + WriteMessage(messageType int, message []byte) error + // ReadMessage Read message from connection. + ReadMessage() (int, []byte, error) + // SetReadDeadline sets the read deadline on the underlying network connection, + //after a read has timed out, will return an error. + SetReadDeadline(timeout time.Duration) error + // SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error. + SetWriteDeadline(timeout time.Duration) error + // Dial Try to dial a connection,url must set auth args,header can control compress data + Dial(urlStr string, requestHeader http.Header) (*http.Response, error) + // IsNil Whether the connection of the current long connection is nil + IsNil() bool + // SetConnNil Set the connection of the current long connection to nil + SetConnNil() + // SetReadLimit sets the maximum size for a message read from the peer.bytes + SetReadLimit(limit int64) + SetPongHandler(handler PongHandler) + // GenerateLongConn Check the connection of the current and when it was sent are the same + GenerateLongConn(w http.ResponseWriter, r *http.Request) error +} +type GWebSocket struct { + protocolType int + conn *websocket.Conn + handshakeTimeout time.Duration +} + +func newGWebSocket(protocolType int, handshakeTimeout time.Duration) *GWebSocket { + return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout} +} + +func (d *GWebSocket) Close() error { + return d.conn.Close() +} +func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error { + upgrader := &websocket.Upgrader{ + HandshakeTimeout: d.handshakeTimeout, + CheckOrigin: func(r *http.Request) bool { return true }, + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return err + } + d.conn = conn + return nil + +} +func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { + //d.setSendConn(d.conn) + return d.conn.WriteMessage(messageType, message) +} + +//func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) { +// d.sendConn = sendConn +//} + +func (d *GWebSocket) ReadMessage() (int, []byte, error) { + return d.conn.ReadMessage() +} +func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error { + return d.conn.SetReadDeadline(time.Now().Add(timeout)) +} + +func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error { + return d.conn.SetWriteDeadline(time.Now().Add(timeout)) +} + +func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { + conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader) + if err == nil { + d.conn = conn + } + return httpResp, err + +} + +func (d *GWebSocket) IsNil() bool { + if d.conn != nil { + return false + } + return true +} + +func (d *GWebSocket) SetConnNil() { + d.conn = nil +} +func (d *GWebSocket) SetReadLimit(limit int64) { + d.conn.SetReadLimit(limit) +} +func (d *GWebSocket) SetPongHandler(handler PongHandler) { + d.conn.SetPongHandler(handler) +} + +//func (d *GWebSocket) CheckSendConnDiffNow() bool { +// return d.conn == d.sendConn +//} diff --git a/internal/msggateway/message_handler.go b/internal/msggateway/message_handler.go new file mode 100644 index 000000000..5430c8175 --- /dev/null +++ b/internal/msggateway/message_handler.go @@ -0,0 +1,181 @@ +package msggateway + +import ( + "context" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/push" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" + "github.com/OpenIMSDK/Open-IM-Server/pkg/rpcclient" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "github.com/go-playground/validator/v10" + "google.golang.org/protobuf/proto" +) + +type Req struct { + ReqIdentifier int32 `json:"reqIdentifier" validate:"required"` + Token string `json:"token" ` + SendID string `json:"sendID" validate:"required"` + OperationID string `json:"operationID" validate:"required"` + MsgIncr string `json:"msgIncr" validate:"required"` + Data []byte `json:"data"` +} + +func (r *Req) String() string { + return utils.StructToJsonString(r) +} + +type Resp struct { + ReqIdentifier int32 `json:"reqIdentifier"` + MsgIncr string `json:"msgIncr"` + OperationID string `json:"operationID"` + ErrCode int `json:"errCode"` + ErrMsg string `json:"errMsg"` + Data []byte `json:"data"` +} + +func (r *Resp) String() string { + return utils.StructToJsonString(r) +} + +type MessageHandler interface { + GetSeq(context context.Context, data Req) ([]byte, error) + SendMessage(context context.Context, data Req) ([]byte, error) + SendSignalMessage(context context.Context, data Req) ([]byte, error) + PullMessageBySeqList(context context.Context, data Req) ([]byte, error) + UserLogout(context context.Context, data Req) ([]byte, error) + SetUserDeviceBackground(context context.Context, data Req) ([]byte, bool, error) +} + +var _ MessageHandler = (*GrpcHandler)(nil) + +type GrpcHandler struct { + msgRpcClient *rpcclient.MessageRpcClient + pushClient *rpcclient.PushRpcClient + validate *validator.Validate +} + +func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDiscoveryRegistry) *GrpcHandler { + msgRpcClient := rpcclient.NewMessageRpcClient(client) + pushRpcClient := rpcclient.NewPushRpcClient(client) + return &GrpcHandler{msgRpcClient: &msgRpcClient, + pushClient: &pushRpcClient, validate: validate} +} + +func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { + req := sdkws.GetMaxSeqReq{} + if err := proto.Unmarshal(data.Data, &req); err != nil { + return nil, err + } + if err := g.validate.Struct(&req); err != nil { + return nil, err + } + resp, err := g.msgRpcClient.GetMaxSeq(context, &req) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil +} + +func (g GrpcHandler) SendMessage(context context.Context, data Req) ([]byte, error) { + msgData := sdkws.MsgData{} + if err := proto.Unmarshal(data.Data, &msgData); err != nil { + return nil, err + } + if err := g.validate.Struct(&msgData); err != nil { + return nil, err + } + req := msg.SendMsgReq{MsgData: &msgData} + resp, err := g.msgRpcClient.SendMsg(context, &req) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil +} + +func (g GrpcHandler) SendSignalMessage(context context.Context, data Req) ([]byte, error) { + resp, err := g.msgRpcClient.SendMsg(context, nil) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil +} + +func (g GrpcHandler) PullMessageBySeqList(context context.Context, data Req) ([]byte, error) { + req := sdkws.PullMessageBySeqsReq{} + if err := proto.Unmarshal(data.Data, &req); err != nil { + return nil, err + } + if err := g.validate.Struct(data); err != nil { + return nil, err + } + resp, err := g.msgRpcClient.PullMessageBySeqList(context, &req) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil +} + +func (g GrpcHandler) UserLogout(context context.Context, data Req) ([]byte, error) { + req := push.DelUserPushTokenReq{} + if err := proto.Unmarshal(data.Data, &req); err != nil { + return nil, err + } + resp, err := g.pushClient.DelUserPushToken(context, &req) + if err != nil { + return nil, err + } + c, err := proto.Marshal(resp) + if err != nil { + return nil, err + } + return c, nil +} +func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data Req) ([]byte, bool, error) { + req := sdkws.SetAppBackgroundStatusReq{} + if err := proto.Unmarshal(data.Data, &req); err != nil { + return nil, false, err + } + if err := g.validate.Struct(data); err != nil { + return nil, false, err + } + return nil, req.IsBackground, nil +} + +//func (g GrpcHandler) call[T any](ctx context.Context, data Req, m proto.Message, rpc func(ctx context.Context, req proto.Message)) ([]byte, error) { +// if err := proto.Unmarshal(data.Data, m); err != nil { +// return nil, err +// } +// if err := g.validate.Struct(m); err != nil { +// return nil, err +// } +// rpc(ctx, m) +// req := msg.SendMsgReq{MsgData: &msgData} +// resp, err := g.notification.Msg.SendMsg(context, &req) +// if err != nil { +// return nil, err +// } +// c, err := proto.Marshal(resp) +// if err != nil { +// return nil, err +// } +// return c, nil +//} diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go new file mode 100644 index 000000000..dcf183015 --- /dev/null +++ b/internal/msggateway/n_ws_server.go @@ -0,0 +1,318 @@ +package msggateway + +import ( + "context" + "errors" + "net/http" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" + "github.com/redis/go-redis/v9" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify" + "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "github.com/go-playground/validator/v10" +) + +type LongConnServer interface { + Run() error + wsHandler(w http.ResponseWriter, r *http.Request) + GetUserAllCons(userID string) ([]*Client, bool) + GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) + Validate(s interface{}) error + SetCacheHandler(cache cache.MsgModel) + SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) + UnRegister(c *Client) + Compressor + Encoder + MessageHandler +} + +var bufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 1024) + }, +} + +type WsServer struct { + port int + wsMaxConnNum int64 + registerChan chan *Client + unregisterChan chan *Client + kickHandlerChan chan *kickHandler + clients *UserMap + clientPool sync.Pool + onlineUserNum int64 + onlineUserConnNum int64 + handshakeTimeout time.Duration + hubServer *Server + validate *validator.Validate + cache cache.MsgModel + Compressor + Encoder + MessageHandler +} +type kickHandler struct { + clientOK bool + oldClients []*Client + newClient *Client +} + +func (ws *WsServer) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) { + ws.MessageHandler = NewGrpcHandler(ws.validate, client) +} +func (ws *WsServer) SetCacheHandler(cache cache.MsgModel) { + ws.cache = cache +} + +func (ws *WsServer) UnRegister(c *Client) { + ws.unregisterChan <- c +} + +func (ws *WsServer) Validate(s interface{}) error { + return nil +} + +func (ws *WsServer) GetUserAllCons(userID string) ([]*Client, bool) { + return ws.clients.GetAll(userID) +} + +func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) { + return ws.clients.Get(userID, platform) +} + +func NewWsServer(opts ...Option) (*WsServer, error) { + var config configs + for _, o := range opts { + o(&config) + } + if config.port < 1024 { + return nil, errors.New("port not allow to listen") + + } + v := validator.New() + return &WsServer{ + port: config.port, + wsMaxConnNum: config.maxConnNum, + handshakeTimeout: config.handshakeTimeout, + clientPool: sync.Pool{ + New: func() interface{} { + return new(Client) + }, + }, + registerChan: make(chan *Client, 1000), + unregisterChan: make(chan *Client, 1000), + kickHandlerChan: make(chan *kickHandler, 1000), + validate: v, + clients: newUserMap(), + Compressor: NewGzipCompressor(), + Encoder: NewGobEncoder(), + }, nil +} +func (ws *WsServer) Run() error { + var client *Client + go func() { + for { + select { + case client = <-ws.registerChan: + ws.registerClient(client) + case client = <-ws.unregisterChan: + ws.unregisterClient(client) + case onlineInfo := <-ws.kickHandlerChan: + ws.multiTerminalLoginChecker(onlineInfo) + } + } + }() + http.HandleFunc("/", ws.wsHandler) + // http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {}) + return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) //Start listening +} + +func (ws *WsServer) registerClient(client *Client) { + var ( + userOK bool + clientOK bool + oldClients []*Client + ) + oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) + if !userOK { + ws.clients.Set(client.UserID, client) + log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID) + atomic.AddInt64(&ws.onlineUserNum, 1) + atomic.AddInt64(&ws.onlineUserConnNum, 1) + + } else { + i := &kickHandler{ + clientOK: clientOK, + oldClients: oldClients, + newClient: client, + } + ws.kickHandlerChan <- i + log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID) + if clientOK { + ws.clients.Set(client.UserID, client) + //已经有同平台的连接存在 + log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients)) + atomic.AddInt64(&ws.onlineUserConnNum, 1) + } else { + ws.clients.Set(client.UserID, client) + + atomic.AddInt64(&ws.onlineUserConnNum, 1) + } + } + log.ZInfo(client.ctx, "user online", "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum) +} +func getRemoteAdders(client []*Client) string { + var ret string + for i, c := range client { + if i == 0 { + ret = c.ctx.GetRemoteAddr() + } else { + ret += "@" + c.ctx.GetRemoteAddr() + } + } + return ret +} + +func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { + switch config.Config.MultiLoginPolicy { + case constant.DefalutNotKick: + case constant.PCAndOther: + if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC { + return + } + fallthrough + case constant.AllLoginButSameTermKick: + if info.clientOK { + ws.clients.deleteClients(info.newClient.UserID, info.oldClients) + for _, c := range info.oldClients { + err := c.KickOnlineMessage() + if err != nil { + log.ZWarn(c.ctx, "KickOnlineMessage", err) + } + } + m, err := ws.cache.GetTokensWithoutError(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID) + if err != nil && err != redis.Nil { + log.ZWarn(info.newClient.ctx, "get token from redis err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID) + return + } + if m == nil { + log.ZWarn(info.newClient.ctx, "m is nil", errors.New("m is nil"), "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID) + return + } + log.ZDebug(info.newClient.ctx, "get token from redis", "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID, "tokenMap", m) + + for k, _ := range m { + if k != info.newClient.ctx.GetToken() { + m[k] = constant.KickedToken + } + } + log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID) + err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m) + if err != nil { + log.ZWarn(info.newClient.ctx, "SetTokenMapByUidPid err", err, "userID", info.newClient.UserID, "platformID", info.newClient.PlatformID) + return + } + } + } + +} +func (ws *WsServer) unregisterClient(client *Client) { + defer ws.clientPool.Put(client) + isDeleteUser := ws.clients.delete(client.UserID, client.ctx.GetRemoteAddr()) + if isDeleteUser { + atomic.AddInt64(&ws.onlineUserNum, -1) + } + atomic.AddInt64(&ws.onlineUserConnNum, -1) + log.ZInfo(client.ctx, "user offline", "close reason", client.closedErr, "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum) +} + +func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { + connContext := newContext(w, r) + if ws.onlineUserConnNum >= ws.wsMaxConnNum { + httpError(connContext, errs.ErrConnOverMaxNumLimit) + return + } + var ( + token string + userID string + platformIDStr string + exists bool + compression bool + ) + + token, exists = connContext.Query(Token) + if !exists { + httpError(connContext, errs.ErrConnArgsErr) + return + } + userID, exists = connContext.Query(WsUserID) + if !exists { + httpError(connContext, errs.ErrConnArgsErr) + return + } + platformIDStr, exists = connContext.Query(PlatformID) + if !exists { + httpError(connContext, errs.ErrConnArgsErr) + return + } + platformID, err := strconv.Atoi(platformIDStr) + if err != nil { + httpError(connContext, errs.ErrConnArgsErr) + return + } + if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil { + httpError(connContext, err) + return + } + m, err := ws.cache.GetTokensWithoutError(context.Background(), userID, platformID) + if err != nil { + httpError(connContext, err) + return + } + if v, ok := m[token]; ok { + switch v { + case constant.NormalToken: + case constant.KickedToken: + httpError(connContext, errs.ErrTokenKicked.Wrap()) + return + default: + httpError(connContext, errs.ErrTokenUnknown.Wrap()) + return + } + } else { + httpError(connContext, errs.ErrTokenNotExist.Wrap()) + return + } + wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout) + err = wsLongConn.GenerateLongConn(w, r) + if err != nil { + httpError(connContext, err) + return + } + compressProtoc, exists := connContext.Query(Compression) + if exists { + if compressProtoc == GzipCompressionProtocol { + compression = true + } + } + compressProtoc, exists = connContext.GetHeader(Compression) + if exists { + if compressProtoc == GzipCompressionProtocol { + compression = true + } + } + client := ws.clientPool.Get().(*Client) + client.ResetClient(connContext, wsLongConn, connContext.GetBackground(), compression, ws) + ws.registerChan <- client + go client.readMessage() +} diff --git a/internal/msggateway/options.go b/internal/msggateway/options.go new file mode 100644 index 000000000..a54ffe880 --- /dev/null +++ b/internal/msggateway/options.go @@ -0,0 +1,36 @@ +package msggateway + +import "time" + +type Option func(opt *configs) +type configs struct { + //长连接监听端口 + port int + //长连接允许最大链接数 + maxConnNum int64 + //连接握手超时时间 + handshakeTimeout time.Duration + //允许消息最大长度 + messageMaxMsgLength int +} + +func WithPort(port int) Option { + return func(opt *configs) { + opt.port = port + } +} +func WithMaxConnNum(num int64) Option { + return func(opt *configs) { + opt.maxConnNum = num + } +} +func WithHandshakeTimeout(t time.Duration) Option { + return func(opt *configs) { + opt.handshakeTimeout = t + } +} +func WithMessageMaxMsgLength(length int) Option { + return func(opt *configs) { + opt.messageMaxMsgLength = length + } +} diff --git a/internal/msggateway/user_map.go b/internal/msggateway/user_map.go new file mode 100644 index 000000000..63881bc1a --- /dev/null +++ b/internal/msggateway/user_map.go @@ -0,0 +1,100 @@ +package msggateway + +import ( + "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "sync" +) + +type UserMap struct { + m sync.Map +} + +func newUserMap() *UserMap { + return &UserMap{} +} +func (u *UserMap) GetAll(key string) ([]*Client, bool) { + allClients, ok := u.m.Load(key) + if ok { + return allClients.([]*Client), ok + } + return nil, ok +} +func (u *UserMap) Get(key string, platformID int) ([]*Client, bool, bool) { + allClients, userExisted := u.m.Load(key) + if userExisted { + var clients []*Client + for _, client := range allClients.([]*Client) { + if client.PlatformID == platformID { + clients = append(clients, client) + } + } + if len(clients) > 0 { + return clients, userExisted, true + + } + return clients, userExisted, false + } + return nil, userExisted, false +} +func (u *UserMap) Set(key string, v *Client) { + allClients, existed := u.m.Load(key) + if existed { + log.ZDebug(context.Background(), "Set existed", "user_id", key, "client", *v) + oldClients := allClients.([]*Client) + oldClients = append(oldClients, v) + u.m.Store(key, oldClients) + } else { + log.ZDebug(context.Background(), "Set not existed", "user_id", key, "client", *v) + var clients []*Client + clients = append(clients, v) + u.m.Store(key, clients) + } +} +func (u *UserMap) delete(key string, connRemoteAddr string) (isDeleteUser bool) { + allClients, existed := u.m.Load(key) + if existed { + oldClients := allClients.([]*Client) + var a []*Client + for _, client := range oldClients { + if client.ctx.GetRemoteAddr() != connRemoteAddr { + a = append(a, client) + } + } + if len(a) == 0 { + u.m.Delete(key) + return true + } else { + u.m.Store(key, a) + return false + } + } + return existed +} +func (u *UserMap) deleteClients(key string, clients []*Client) (isDeleteUser bool) { + m := utils.SliceToMapAny(clients, func(c *Client) (string, struct{}) { + return c.ctx.GetRemoteAddr(), struct{}{} + }) + allClients, existed := u.m.Load(key) + if existed { + oldClients := allClients.([]*Client) + var a []*Client + for _, client := range oldClients { + if _, ok := m[client.ctx.GetRemoteAddr()]; !ok { + a = append(a, client) + } + } + if len(a) == 0 { + u.m.Delete(key) + return true + } else { + u.m.Store(key, a) + return false + } + } + return existed +} +func (u *UserMap) DeleteAll(key string) { + u.m.Delete(key) +} diff --git a/internal/msgtransfer/init.go b/internal/msgtransfer/init.go new file mode 100644 index 000000000..589e81726 --- /dev/null +++ b/internal/msgtransfer/init.go @@ -0,0 +1,110 @@ +package msgtransfer + +import ( + "fmt" + "sync" + "time" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/controller" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/relation" + relationTb "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/relation" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/tx" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/unrelation" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mw" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/prome" + openKeeper "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry/zookeeper" + "github.com/OpenIMSDK/Open-IM-Server/pkg/rpcclient" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type MsgTransfer struct { + persistentCH *PersistentConsumerHandler // 聊天记录持久化到mysql的消费者 订阅的topic: ws2ms_chat + historyCH *OnlineHistoryRedisConsumerHandler // 这个消费者聚合消息, 订阅的topic:ws2ms_chat, 修改通知发往msg_to_modify topic, 消息存入redis后Incr Redis, 再发消息到ms2pschat topic推送, 发消息到msg_to_mongo topic持久化 + historyMongoCH *OnlineHistoryMongoConsumerHandler // mongoDB批量插入, 成功后删除redis中消息,以及处理删除通知消息删除的 订阅的topic: msg_to_mongo + modifyCH *ModifyMsgConsumerHandler // 负责消费修改消息通知的consumer, 订阅的topic: msg_to_modify +} + +func StartTransfer(prometheusPort int) error { + db, err := relation.NewGormDB() + if err != nil { + return err + } + if err := db.AutoMigrate(&relationTb.ChatLogModel{}); err != nil { + return err + } + rdb, err := cache.NewRedis() + if err != nil { + return err + } + mongo, err := unrelation.NewMongo() + if err != nil { + return err + } + if err := mongo.CreateMsgIndex(); err != nil { + return err + } + client, err := openKeeper.NewClient(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, + openKeeper.WithFreq(time.Hour), openKeeper.WithRoundRobin(), openKeeper.WithUserNameAndPassword(config.Config.Zookeeper.Username, + config.Config.Zookeeper.Password), openKeeper.WithTimeout(10), openKeeper.WithLogger(log.NewZkLogger())) + if err != nil { + return err + } + if client.CreateRpcRootNodes(config.GetServiceNames()); err != nil { + return err + } + client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials())) + msgModel := cache.NewMsgCacheModel(rdb) + msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase()) + extendMsgModel := unrelation.NewExtendMsgSetMongoDriver(mongo.GetDatabase()) + extendMsgCache := cache.NewExtendMsgSetCacheRedis(rdb, extendMsgModel, cache.GetDefaultOpt()) + chatLogDatabase := controller.NewChatLogDatabase(relation.NewChatLogGorm(db)) + extendMsgDatabase := controller.NewExtendMsgDatabase(extendMsgModel, extendMsgCache, tx.NewMongo(mongo.GetClient())) + msgDatabase := controller.NewCommonMsgDatabase(msgDocModel, msgModel) + conversationRpcClient := rpcclient.NewConversationRpcClient(client) + groupRpcClient := rpcclient.NewGroupRpcClient(client) + msgTransfer := NewMsgTransfer(chatLogDatabase, extendMsgDatabase, msgDatabase, &conversationRpcClient, &groupRpcClient) + msgTransfer.initPrometheus() + return msgTransfer.Start(prometheusPort) +} + +func NewMsgTransfer(chatLogDatabase controller.ChatLogDatabase, + extendMsgDatabase controller.ExtendMsgDatabase, msgDatabase controller.CommonMsgDatabase, + conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient) *MsgTransfer { + return &MsgTransfer{persistentCH: NewPersistentConsumerHandler(chatLogDatabase), historyCH: NewOnlineHistoryRedisConsumerHandler(msgDatabase, conversationRpcClient, groupRpcClient), + historyMongoCH: NewOnlineHistoryMongoConsumerHandler(msgDatabase), modifyCH: NewModifyMsgConsumerHandler(extendMsgDatabase)} +} + +func (m *MsgTransfer) initPrometheus() { + prome.NewSeqGetSuccessCounter() + prome.NewSeqGetFailedCounter() + prome.NewSeqSetSuccessCounter() + prome.NewSeqSetFailedCounter() + prome.NewMsgInsertRedisSuccessCounter() + prome.NewMsgInsertRedisFailedCounter() + prome.NewMsgInsertMongoSuccessCounter() + prome.NewMsgInsertMongoFailedCounter() +} + +func (m *MsgTransfer) Start(prometheusPort int) error { + var wg sync.WaitGroup + wg.Add(1) + fmt.Println("start msg transfer", "prometheusPort:", prometheusPort) + if config.Config.ChatPersistenceMysql { + // go m.persistentCH.persistentConsumerGroup.RegisterHandleAndConsumer(m.persistentCH) + } else { + fmt.Println("msg transfer not start mysql consumer") + } + go m.historyCH.historyConsumerGroup.RegisterHandleAndConsumer(m.historyCH) + go m.historyMongoCH.historyConsumerGroup.RegisterHandleAndConsumer(m.historyMongoCH) + go m.modifyCH.modifyMsgConsumerGroup.RegisterHandleAndConsumer(m.modifyCH) + err := prome.StartPrometheusSrv(prometheusPort) + if err != nil { + return err + } + wg.Wait() + return nil +} diff --git a/internal/msgtransfer/modify_msg_handler.go b/internal/msgtransfer/modify_msg_handler.go new file mode 100644 index 000000000..0b9ad2aff --- /dev/null +++ b/internal/msgtransfer/modify_msg_handler.go @@ -0,0 +1,113 @@ +package msgtransfer + +import ( + "context" + "encoding/json" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/controller" + unRelationTb "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/unrelation" + kfk "github.com/OpenIMSDK/Open-IM-Server/pkg/common/kafka" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" + pbMsg "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "github.com/Shopify/sarama" + + "google.golang.org/protobuf/proto" +) + +type ModifyMsgConsumerHandler struct { + modifyMsgConsumerGroup *kfk.MConsumerGroup + + extendMsgDatabase controller.ExtendMsgDatabase + extendSetMsgModel unRelationTb.ExtendMsgSetModel +} + +func NewModifyMsgConsumerHandler(database controller.ExtendMsgDatabase) *ModifyMsgConsumerHandler { + return &ModifyMsgConsumerHandler{ + modifyMsgConsumerGroup: kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{KafkaVersion: sarama.V2_0_0_0, + OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false}, []string{config.Config.Kafka.MsgToModify.Topic}, + config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToModify), + extendMsgDatabase: database, + } +} + +func (ModifyMsgConsumerHandler) Setup(_ sarama.ConsumerGroupSession) error { return nil } +func (ModifyMsgConsumerHandler) Cleanup(_ sarama.ConsumerGroupSession) error { return nil } +func (mmc *ModifyMsgConsumerHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, + claim sarama.ConsumerGroupClaim) error { + for msg := range claim.Messages() { + ctx := mmc.modifyMsgConsumerGroup.GetContextFromMsg(msg) + log.ZDebug(ctx, "kafka get info to mysql", "ModifyMsgConsumerHandler", msg.Topic, "msgPartition", msg.Partition, "msg", string(msg.Value), "key", string(msg.Key)) + if len(msg.Value) != 0 { + mmc.ModifyMsg(ctx, msg, string(msg.Key), sess) + } else { + log.ZError(ctx, "msg get from kafka but is nil", nil, "key", msg.Key) + } + sess.MarkMessage(msg, "") + } + return nil +} + +func (mmc *ModifyMsgConsumerHandler) ModifyMsg(ctx context.Context, cMsg *sarama.ConsumerMessage, msgKey string, _ sarama.ConsumerGroupSession) { + msgFromMQ := pbMsg.MsgDataToModifyByMQ{} + operationID := mcontext.GetOperationID(ctx) + err := proto.Unmarshal(cMsg.Value, &msgFromMQ) + if err != nil { + log.ZError(ctx, "msg_transfer Unmarshal msg err", err, "msg", string(cMsg.Value)) + return + } + log.ZDebug(ctx, "proto.Unmarshal MsgDataToMQ", "msgs", msgFromMQ.String()) + for _, msg := range msgFromMQ.Messages { + isReactionFromCache := utils.GetSwitchFromOptions(msg.Options, constant.IsReactionFromCache) + if !isReactionFromCache { + continue + } + ctx = mcontext.SetOperationID(ctx, operationID) + if msg.ContentType == constant.ReactionMessageModifier { + notification := &sdkws.ReactionMessageModifierNotification{} + if err := json.Unmarshal(msg.Content, notification); err != nil { + continue + } + if notification.IsExternalExtensions { + continue + } + if !notification.IsReact { + // first time to modify + var reactionExtensionList = make(map[string]unRelationTb.KeyValueModel) + extendMsg := unRelationTb.ExtendMsgModel{ + ReactionExtensionList: reactionExtensionList, + ClientMsgID: notification.ClientMsgID, + MsgFirstModifyTime: notification.MsgFirstModifyTime, + } + for _, v := range notification.SuccessReactionExtensions { + reactionExtensionList[v.TypeKey] = unRelationTb.KeyValueModel{ + TypeKey: v.TypeKey, + Value: v.Value, + LatestUpdateTime: v.LatestUpdateTime, + } + } + + if err := mmc.extendMsgDatabase.InsertExtendMsg(ctx, notification.ConversationID, notification.SessionType, &extendMsg); err != nil { + // log.ZError(ctx, "MsgFirstModify InsertExtendMsg failed", notification.ConversationID, notification.SessionType, extendMsg, err.Error()) + continue + } + } else { + if err := mmc.extendMsgDatabase.InsertOrUpdateReactionExtendMsgSet(ctx, notification.ConversationID, notification.SessionType, notification.ClientMsgID, notification.MsgFirstModifyTime, mmc.extendSetMsgModel.Pb2Model(notification.SuccessReactionExtensions)); err != nil { + // log.NewError(operationID, "InsertOrUpdateReactionExtendMsgSet failed") + } + } + } else if msg.ContentType == constant.ReactionMessageDeleter { + notification := &sdkws.ReactionMessageDeleteNotification{} + if err := json.Unmarshal(msg.Content, notification); err != nil { + continue + } + if err := mmc.extendMsgDatabase.DeleteReactionExtendMsgSet(ctx, notification.ConversationID, notification.SessionType, notification.ClientMsgID, notification.MsgFirstModifyTime, mmc.extendSetMsgModel.Pb2Model(notification.SuccessReactionExtensions)); err != nil { + // log.NewError(operationID, "InsertOrUpdateReactionExtendMsgSet failed") + } + } + } +} diff --git a/internal/msgtransfer/online_history_msg_handler.go b/internal/msgtransfer/online_history_msg_handler.go new file mode 100644 index 000000000..7f776d5ee --- /dev/null +++ b/internal/msgtransfer/online_history_msg_handler.go @@ -0,0 +1,343 @@ +package msgtransfer + +import ( + "context" + "strconv" + "strings" + "sync" + "time" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/controller" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/kafka" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" + "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/sdkws" + "github.com/OpenIMSDK/Open-IM-Server/pkg/rpcclient" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + "github.com/Shopify/sarama" + "github.com/go-redis/redis" + "google.golang.org/protobuf/proto" +) + +const ConsumerMsgs = 3 +const SourceMessages = 4 +const MongoMessages = 5 +const ChannelNum = 100 + +type MsgChannelValue struct { + uniqueKey string + ctx context.Context + ctxMsgList []*ContextMsg +} + +type TriggerChannelValue struct { + ctx context.Context + cMsgList []*sarama.ConsumerMessage +} + +type Cmd2Value struct { + Cmd int + Value interface{} +} +type ContextMsg struct { + message *sdkws.MsgData + ctx context.Context +} + +type OnlineHistoryRedisConsumerHandler struct { + historyConsumerGroup *kafka.MConsumerGroup + chArrays [ChannelNum]chan Cmd2Value + msgDistributionCh chan Cmd2Value + + singleMsgSuccessCount uint64 + singleMsgFailedCount uint64 + singleMsgSuccessCountMutex sync.Mutex + singleMsgFailedCountMutex sync.Mutex + + msgDatabase controller.CommonMsgDatabase + conversationRpcClient *rpcclient.ConversationRpcClient + groupRpcClient *rpcclient.GroupRpcClient +} + +func NewOnlineHistoryRedisConsumerHandler(database controller.CommonMsgDatabase, conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient) *OnlineHistoryRedisConsumerHandler { + var och OnlineHistoryRedisConsumerHandler + och.msgDatabase = database + och.msgDistributionCh = make(chan Cmd2Value) //no buffer channel + go och.MessagesDistributionHandle() + for i := 0; i < ChannelNum; i++ { + och.chArrays[i] = make(chan Cmd2Value, 50) + go och.Run(i) + } + och.conversationRpcClient = conversationRpcClient + och.groupRpcClient = groupRpcClient + och.historyConsumerGroup = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{KafkaVersion: sarama.V2_0_0_0, + OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false}, []string{config.Config.Kafka.LatestMsgToRedis.Topic}, + config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToRedis) + //statistics.NewStatistics(&och.singleMsgSuccessCount, config.Config.ModuleName.MsgTransferName, fmt.Sprintf("%d second singleMsgCount insert to mongo", constant.StatisticsTimeInterval), constant.StatisticsTimeInterval) + return &och +} + +func (och *OnlineHistoryRedisConsumerHandler) Run(channelID int) { + for { + select { + case cmd := <-och.chArrays[channelID]: + switch cmd.Cmd { + case SourceMessages: + msgChannelValue := cmd.Value.(MsgChannelValue) + ctxMsgList := msgChannelValue.ctxMsgList + ctx := msgChannelValue.ctx + log.ZDebug(ctx, "msg arrived channel", "channel id", channelID, "msgList length", len(ctxMsgList), "uniqueKey", msgChannelValue.uniqueKey) + storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList, modifyMsgList := och.getPushStorageMsgList(ctxMsgList) + log.ZDebug(ctx, "msg lens", "storageMsgList", len(storageMsgList), "notStorageMsgList", len(notStorageMsgList), + "storageNotificationList", len(storageNotificationList), "notStorageNotificationList", len(notStorageNotificationList), "modifyMsgList", len(modifyMsgList)) + conversationIDMsg := utils.GetChatConversationIDByMsg(ctxMsgList[0].message) + conversationIDNotification := utils.GetNotificationConversationID(ctxMsgList[0].message) + och.handleMsg(ctx, msgChannelValue.uniqueKey, conversationIDMsg, storageMsgList, notStorageMsgList) + och.handleNotification(ctx, msgChannelValue.uniqueKey, conversationIDNotification, storageNotificationList, notStorageNotificationList) + if err := och.msgDatabase.MsgToModifyMQ(ctx, msgChannelValue.uniqueKey, conversationIDNotification, modifyMsgList); err != nil { + log.ZError(ctx, "msg to modify mq error", err, "uniqueKey", msgChannelValue.uniqueKey, "modifyMsgList", modifyMsgList) + } + } + } + } +} + +// 获取消息/通知 存储的消息列表, 不存储并且推送的消息列表, +func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList(totalMsgs []*ContextMsg) (storageMsgList, notStorageMsgList, storageNotificatoinList, notStorageNotificationList, modifyMsgList []*sdkws.MsgData) { + isStorage := func(msg *sdkws.MsgData) bool { + options2 := utils.Options(msg.Options) + if options2.IsHistory() { + return true + } else { + // if !(!options2.IsSenderSync() && conversationID == msg.MsgData.SendID) { + // return false + // } + return false + } + } + for _, v := range totalMsgs { + options := utils.Options(v.message.Options) + if !options.IsNotNotification() { + // clone msg from notificationMsg + if options.IsSendMsg() { + msg := proto.Clone(v.message).(*sdkws.MsgData) + // 消息 + if v.message.Options != nil { + msg.Options = utils.NewMsgOptions() + } + if options.IsOfflinePush() { + v.message.Options = utils.WithOptions(utils.Options(v.message.Options), utils.WithOfflinePush(false)) + msg.Options = utils.WithOptions(utils.Options(msg.Options), utils.WithOfflinePush(true)) + } + if options.IsUnreadCount() { + v.message.Options = utils.WithOptions(utils.Options(v.message.Options), utils.WithUnreadCount(false)) + msg.Options = utils.WithOptions(utils.Options(msg.Options), utils.WithUnreadCount(true)) + } + storageMsgList = append(storageMsgList, msg) + } + if isStorage(v.message) { + storageNotificatoinList = append(storageNotificatoinList, v.message) + } else { + notStorageNotificationList = append(notStorageNotificationList, v.message) + } + } else { + if isStorage(v.message) { + storageMsgList = append(storageMsgList, v.message) + } else { + notStorageMsgList = append(notStorageMsgList, v.message) + } + } + if v.message.ContentType == constant.ReactionMessageModifier || v.message.ContentType == constant.ReactionMessageDeleter { + modifyMsgList = append(modifyMsgList, v.message) + } + } + return +} + +func (och *OnlineHistoryRedisConsumerHandler) handleNotification(ctx context.Context, key, conversationID string, storageList, notStorageList []*sdkws.MsgData) { + och.toPushTopic(ctx, key, conversationID, notStorageList) + if len(storageList) > 0 { + lastSeq, _, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList) + if err != nil { + log.ZError(ctx, "notification batch insert to redis error", err, "conversationID", conversationID, "storageList", storageList) + return + } + log.ZDebug(ctx, "success to next topic", "conversationID", conversationID) + och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq) + och.toPushTopic(ctx, key, conversationID, storageList) + } +} + +func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(ctx context.Context, key, conversationID string, msgs []*sdkws.MsgData) { + for _, v := range msgs { + och.msgDatabase.MsgToPushMQ(ctx, key, conversationID, v) + } +} + +func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*sdkws.MsgData) { + och.toPushTopic(ctx, key, conversationID, notStorageList) + if len(storageList) > 0 { + lastSeq, isNewConversation, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList) + if err != nil && errs.Unwrap(err) != redis.Nil { + log.ZError(ctx, "batch data insert to redis err", err, "storageMsgList", storageList) + och.singleMsgFailedCountMutex.Lock() + och.singleMsgFailedCount += uint64(len(storageList)) + och.singleMsgFailedCountMutex.Unlock() + return + } + if isNewConversation { + if storageList[0].SessionType == constant.SuperGroupChatType { + log.ZInfo(ctx, "group chat first create conversation", "conversationID", conversationID) + userIDs, err := och.groupRpcClient.GetGroupMemberIDs(ctx, storageList[0].GroupID) + if err != nil { + log.ZWarn(ctx, "get group member ids error", err, "conversationID", conversationID) + } else { + if err := och.conversationRpcClient.GroupChatFirstCreateConversation(ctx, storageList[0].GroupID, userIDs); err != nil { + log.ZWarn(ctx, "single chat first create conversation error", err, "conversationID", conversationID) + } + } + } else { + if err := och.conversationRpcClient.SingleChatFirstCreateConversation(ctx, storageList[0].RecvID, storageList[0].SendID); err != nil { + log.ZWarn(ctx, "single chat first create conversation error", err, "conversationID", conversationID) + } + } + } + + log.ZDebug(ctx, "success incr to next topic") + och.singleMsgSuccessCountMutex.Lock() + och.singleMsgSuccessCount += uint64(len(storageList)) + och.singleMsgSuccessCountMutex.Unlock() + och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq) + och.toPushTopic(ctx, key, conversationID, storageList) + } +} + +func (och *OnlineHistoryRedisConsumerHandler) MessagesDistributionHandle() { + for { + aggregationMsgs := make(map[string][]*ContextMsg, ChannelNum) + select { + case cmd := <-och.msgDistributionCh: + switch cmd.Cmd { + case ConsumerMsgs: + triggerChannelValue := cmd.Value.(TriggerChannelValue) + ctx := triggerChannelValue.ctx + consumerMessages := triggerChannelValue.cMsgList + //Aggregation map[userid]message list + log.ZDebug(ctx, "batch messages come to distribution center", "length", len(consumerMessages)) + for i := 0; i < len(consumerMessages); i++ { + ctxMsg := &ContextMsg{} + msgFromMQ := &sdkws.MsgData{} + err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ) + if err != nil { + log.ZError(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value)) + continue + } + var arr []string + for i, header := range consumerMessages[i].Headers { + arr = append(arr, strconv.Itoa(i), string(header.Key), string(header.Value)) + } + log.ZInfo(ctx, "consumer.kafka.GetContextWithMQHeader", "len", len(consumerMessages[i].Headers), "header", strings.Join(arr, ", ")) + ctxMsg.ctx = kafka.GetContextWithMQHeader(consumerMessages[i].Headers) + ctxMsg.message = msgFromMQ + log.ZDebug(ctx, "single msg come to distribution center", "message", msgFromMQ, "key", string(consumerMessages[i].Key)) + //aggregationMsgs[string(consumerMessages[i].Key)] = append(aggregationMsgs[string(consumerMessages[i].Key)], ctxMsg) + if oldM, ok := aggregationMsgs[string(consumerMessages[i].Key)]; ok { + oldM = append(oldM, ctxMsg) + aggregationMsgs[string(consumerMessages[i].Key)] = oldM + } else { + m := make([]*ContextMsg, 0, 100) + m = append(m, ctxMsg) + aggregationMsgs[string(consumerMessages[i].Key)] = m + } + } + log.ZDebug(ctx, "generate map list users len", "length", len(aggregationMsgs)) + for uniqueKey, v := range aggregationMsgs { + if len(v) >= 0 { + hashCode := utils.GetHashCode(uniqueKey) + channelID := hashCode % ChannelNum + newCtx := withAggregationCtx(ctx, v) + log.ZDebug(newCtx, "generate channelID", "hashCode", hashCode, "channelID", channelID, "uniqueKey", uniqueKey) + och.chArrays[channelID] <- Cmd2Value{Cmd: SourceMessages, Value: MsgChannelValue{uniqueKey: uniqueKey, ctxMsgList: v, ctx: newCtx}} + } + } + } + } + } +} +func withAggregationCtx(ctx context.Context, values []*ContextMsg) context.Context { + var allMessageOperationID string + for i, v := range values { + if opid := mcontext.GetOperationID(v.ctx); opid != "" { + if i == 0 { + allMessageOperationID += opid + + } else { + allMessageOperationID += "$" + opid + } + } + } + return mcontext.SetOperationID(ctx, allMessageOperationID) +} + +func (och *OnlineHistoryRedisConsumerHandler) Setup(_ sarama.ConsumerGroupSession) error { return nil } +func (och *OnlineHistoryRedisConsumerHandler) Cleanup(_ sarama.ConsumerGroupSession) error { + return nil +} + +func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { // a instance in the consumer group + for { + if sess == nil { + log.ZWarn(context.Background(), "sess == nil, waiting", nil) + time.Sleep(100 * time.Millisecond) + } else { + break + } + } + rwLock := new(sync.RWMutex) + log.ZDebug(context.Background(), "online new session msg come", "highWaterMarkOffset", + claim.HighWaterMarkOffset(), "topic", claim.Topic(), "partition", claim.Partition()) + cMsg := make([]*sarama.ConsumerMessage, 0, 1000) + t := time.NewTicker(time.Millisecond * 100) + go func() { + for { + select { + case <-t.C: + if len(cMsg) > 0 { + rwLock.Lock() + ccMsg := make([]*sarama.ConsumerMessage, 0, 1000) + for _, v := range cMsg { + ccMsg = append(ccMsg, v) + } + cMsg = make([]*sarama.ConsumerMessage, 0, 1000) + rwLock.Unlock() + split := 1000 + ctx := mcontext.WithTriggerIDContext(context.Background(), utils.OperationIDGenerator()) + log.ZDebug(ctx, "timer trigger msg consumer start", "length", len(ccMsg)) + for i := 0; i < len(ccMsg)/split; i++ { + //log.Debug() + och.msgDistributionCh <- Cmd2Value{Cmd: ConsumerMsgs, Value: TriggerChannelValue{ + ctx: ctx, cMsgList: ccMsg[i*split : (i+1)*split]}} + } + if (len(ccMsg) % split) > 0 { + och.msgDistributionCh <- Cmd2Value{Cmd: ConsumerMsgs, Value: TriggerChannelValue{ + ctx: ctx, cMsgList: ccMsg[split*(len(ccMsg)/split):]}} + } + log.ZDebug(ctx, "timer trigger msg consumer end", "length", len(ccMsg)) + } + } + } + }() + for msg := range claim.Messages() { + rwLock.Lock() + if len(msg.Value) != 0 { + cMsg = append(cMsg, msg) + } + rwLock.Unlock() + sess.MarkMessage(msg, "") + } + return nil +} diff --git a/internal/msgtransfer/online_msg_to_mongo_handler.go b/internal/msgtransfer/online_msg_to_mongo_handler.go new file mode 100644 index 000000000..77c5e9a6b --- /dev/null +++ b/internal/msgtransfer/online_msg_to_mongo_handler.go @@ -0,0 +1,73 @@ +package msgtransfer + +import ( + "context" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/controller" + kfk "github.com/OpenIMSDK/Open-IM-Server/pkg/common/kafka" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + pbMsg "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" + "github.com/Shopify/sarama" + "google.golang.org/protobuf/proto" +) + +type OnlineHistoryMongoConsumerHandler struct { + historyConsumerGroup *kfk.MConsumerGroup + msgDatabase controller.CommonMsgDatabase +} + +func NewOnlineHistoryMongoConsumerHandler(database controller.CommonMsgDatabase) *OnlineHistoryMongoConsumerHandler { + mc := &OnlineHistoryMongoConsumerHandler{ + historyConsumerGroup: kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{KafkaVersion: sarama.V2_0_0_0, + OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false}, []string{config.Config.Kafka.MsgToMongo.Topic}, + config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToMongo), + msgDatabase: database, + } + return mc +} + +func (mc *OnlineHistoryMongoConsumerHandler) handleChatWs2Mongo(ctx context.Context, cMsg *sarama.ConsumerMessage, key string, session sarama.ConsumerGroupSession) { + msg := cMsg.Value + msgFromMQ := pbMsg.MsgDataToMongoByMQ{} + err := proto.Unmarshal(msg, &msgFromMQ) + if err != nil { + log.ZError(ctx, "unmarshall failed", err, "key", key, "len", len(msg)) + return + } + if len(msgFromMQ.MsgData) == 0 { + log.ZError(ctx, "msgFromMQ.MsgData is empty", nil, "cMsg", cMsg) + return + } + log.ZInfo(ctx, "mongo consumer recv msg", "msgs", msgFromMQ.MsgData) + err = mc.msgDatabase.BatchInsertChat2DB(ctx, msgFromMQ.ConversationID, msgFromMQ.MsgData, msgFromMQ.LastSeq) + if err != nil { + log.ZError(ctx, "single data insert to mongo err", err, "msg", msgFromMQ.MsgData, "conversationID", msgFromMQ.ConversationID) + } + var seqs []int64 + for _, msg := range msgFromMQ.MsgData { + seqs = append(seqs, msg.Seq) + } + err = mc.msgDatabase.DeleteMessagesFromCache(ctx, msgFromMQ.ConversationID, seqs) + if err != nil { + log.ZError(ctx, "remove cache msg from redis err", err, "msg", msgFromMQ.MsgData, "conversationID", msgFromMQ.ConversationID) + } + mc.msgDatabase.DelUserDeleteMsgsList(ctx, msgFromMQ.ConversationID, seqs) +} + +func (OnlineHistoryMongoConsumerHandler) Setup(_ sarama.ConsumerGroupSession) error { return nil } +func (OnlineHistoryMongoConsumerHandler) Cleanup(_ sarama.ConsumerGroupSession) error { return nil } +func (mc *OnlineHistoryMongoConsumerHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { // a instance in the consumer group + log.ZDebug(context.Background(), "online new session msg come", "highWaterMarkOffset", + claim.HighWaterMarkOffset(), "topic", claim.Topic(), "partition", claim.Partition()) + for msg := range claim.Messages() { + ctx := mc.historyConsumerGroup.GetContextFromMsg(msg) + if len(msg.Value) != 0 { + mc.handleChatWs2Mongo(ctx, msg, string(msg.Key), sess) + } else { + log.ZError(ctx, "mongo msg get from kafka but is nil", nil, "conversationID", msg.Key) + } + sess.MarkMessage(msg, "") + } + return nil +} diff --git a/internal/msgtransfer/persistent_msg_handler.go b/internal/msgtransfer/persistent_msg_handler.go new file mode 100644 index 000000000..026f5f5a0 --- /dev/null +++ b/internal/msgtransfer/persistent_msg_handler.go @@ -0,0 +1,88 @@ +/* +** description(""). +** copyright('tuoyun,www.tuoyun.net'). +** author("fg,Gordon@tuoyun.net"). +** time(2021/5/11 15:37). + */ +package msgtransfer + +import ( + "context" + + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/controller" + kfk "github.com/OpenIMSDK/Open-IM-Server/pkg/common/kafka" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + pbMsg "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msg" + "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" + + "github.com/Shopify/sarama" + "google.golang.org/protobuf/proto" +) + +type PersistentConsumerHandler struct { + persistentConsumerGroup *kfk.MConsumerGroup + chatLogDatabase controller.ChatLogDatabase +} + +func NewPersistentConsumerHandler(database controller.ChatLogDatabase) *PersistentConsumerHandler { + return &PersistentConsumerHandler{ + persistentConsumerGroup: kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{KafkaVersion: sarama.V2_0_0_0, + OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false}, []string{config.Config.Kafka.LatestMsgToRedis.Topic}, + config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToMySql), + chatLogDatabase: database, + } +} + +func (pc *PersistentConsumerHandler) handleChatWs2Mysql(ctx context.Context, cMsg *sarama.ConsumerMessage, msgKey string, _ sarama.ConsumerGroupSession) { + msg := cMsg.Value + var tag bool + msgFromMQ := pbMsg.MsgDataToMQ{} + err := proto.Unmarshal(msg, &msgFromMQ) + if err != nil { + log.ZError(ctx, "msg_transfer Unmarshal msg err", err) + return + } + return + log.ZDebug(ctx, "handleChatWs2Mysql", "msg", msgFromMQ.MsgData) + //Control whether to store history messages (mysql) + isPersist := utils.GetSwitchFromOptions(msgFromMQ.MsgData.Options, constant.IsPersistent) + //Only process receiver data + if isPersist { + switch msgFromMQ.MsgData.SessionType { + case constant.SingleChatType, constant.NotificationChatType: + if msgKey == msgFromMQ.MsgData.RecvID { + tag = true + } + case constant.GroupChatType: + if msgKey == msgFromMQ.MsgData.SendID { + tag = true + } + case constant.SuperGroupChatType: + tag = true + } + if tag { + log.ZInfo(ctx, "msg_transfer msg persisting", "msg", string(msg)) + if err = pc.chatLogDatabase.CreateChatLog(&msgFromMQ); err != nil { + log.ZError(ctx, "Message insert failed", err, "msg", msgFromMQ.String()) + return + } + } + } +} +func (PersistentConsumerHandler) Setup(_ sarama.ConsumerGroupSession) error { return nil } +func (PersistentConsumerHandler) Cleanup(_ sarama.ConsumerGroupSession) error { return nil } +func (pc *PersistentConsumerHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { + for msg := range claim.Messages() { + ctx := pc.persistentConsumerGroup.GetContextFromMsg(msg) + log.ZDebug(ctx, "kafka get info to mysql", "msgTopic", msg.Topic, "msgPartition", msg.Partition, "msg", string(msg.Value), "key", string(msg.Key)) + if len(msg.Value) != 0 { + pc.handleChatWs2Mysql(ctx, msg, string(msg.Key), sess) + } else { + log.ZError(ctx, "msg get from kafka but is nil", nil, "key", msg.Key) + } + sess.MarkMessage(msg, "") + } + return nil +}