From 9c19fd43fa0bff33ee39c5b6bdba2c7833260940 Mon Sep 17 00:00:00 2001 From: Gordon <46924906+FGadvancer@users.noreply.github.com> Date: Sun, 28 Apr 2024 18:09:39 +0800 Subject: [PATCH] refactor: msg transfer refactor. --- internal/msgtransfer/init.go | 37 ++ .../msgtransfer/online_history_msg_handler.go | 348 +++++++----------- internal/push/push_handler.go | 19 +- internal/rpc/group/group.go | 33 +- internal/rpc/group/notification.go | 4 - pkg/util/batcher/batcher.go | 269 ++++++++++++++ pkg/util/batcher/batcher_test.go | 66 ++++ 7 files changed, 536 insertions(+), 240 deletions(-) create mode 100644 pkg/util/batcher/batcher.go create mode 100644 pkg/util/batcher/batcher_test.go diff --git a/internal/msgtransfer/init.go b/internal/msgtransfer/init.go index 68d953e90..0cd0239eb 100644 --- a/internal/msgtransfer/init.go +++ b/internal/msgtransfer/init.go @@ -20,6 +20,7 @@ import ( "github.com/openimsdk/tools/db/mongoutil" "github.com/openimsdk/tools/db/redisutil" "github.com/openimsdk/tools/utils/datautil" + "net" "net/http" "os" "os/signal" @@ -167,4 +168,40 @@ func (m *MsgTransfer) Start(index int, config *Config) error { close(netDone) return netErr } + + if config.MsgTransfer.Prometheus.Enable { + go func() { + proreg := prometheus.NewRegistry() + proreg.MustRegister( + collectors.NewGoCollector(), + ) + proreg.MustRegister(prommetrics.GetGrpcCusMetrics("Transfer", &config.Share)...) + + http.Handle("/metrics", promhttp.HandlerFor(proreg, promhttp.HandlerOpts{Registry: proreg})) + + lc := net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + if err != nil { + return fmt.Errorf("setsockopt failed: %w", err) + } + }) + }, + } + + listener, err := lc.Listen(context.Background(), "tcp", fmt.Sprintf(":%d", prometheusPort)) + if err != nil { + netErr = errs.WrapMsg(err, "prometheus start error", "prometheusPort", prometheusPort) + netDone <- struct{}{} + return + } + + err = http.Serve(listener, nil) + if err != nil && err != http.ErrServerClosed { + netErr = errs.WrapMsg(err, "HTTP server start error", "prometheusPort", prometheusPort) + netDone <- struct{}{} + } + }() + } } diff --git a/internal/msgtransfer/online_history_msg_handler.go b/internal/msgtransfer/online_history_msg_handler.go index 8691e92ab..0b24a6607 100644 --- a/internal/msgtransfer/online_history_msg_handler.go +++ b/internal/msgtransfer/online_history_msg_handler.go @@ -16,6 +16,7 @@ package msgtransfer import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/util/batcher" "strconv" "strings" "sync" @@ -71,10 +72,7 @@ type OnlineHistoryRedisConsumerHandler struct { chArrays [ChannelNum]chan Cmd2Value msgDistributionCh chan Cmd2Value - // singleMsgSuccessCount uint64 - // singleMsgFailedCount uint64 - // singleMsgSuccessCountMutex sync.Mutex - // singleMsgFailedCountMutex sync.Mutex + redisMessageBatches *batcher.Batcher[sarama.ConsumerMessage] msgDatabase controller.CommonMsgDatabase conversationRpcClient *rpcclient.ConversationRpcClient @@ -89,83 +87,79 @@ func NewOnlineHistoryRedisConsumerHandler(kafkaConf *config.Kafka, database cont } var och OnlineHistoryRedisConsumerHandler och.msgDatabase = database - och.msgDistributionCh = make(chan Cmd2Value) // no buffer channel - go och.MessagesDistributionHandle() - for i := 0; i < ChannelNum; i++ { - och.chArrays[i] = make(chan Cmd2Value, 50) - go och.Run(i) + + b := batcher.New[sarama.ConsumerMessage]() + b.Sharding = func(key string) int { + hashCode := stringutil.GetHashCode(key) + return int(hashCode) % och.redisMessageBatches.Worker() + } + b.Key = func(consumerMessage *sarama.ConsumerMessage) string { + return string(consumerMessage.Key) } + och.redisMessageBatches = b + + err = b.Start() + if err != nil { + return nil, err + } + //och.msgDistributionCh = make(chan Cmd2Value) // no buffer channel + //go och.MessagesDistributionHandle() + //for i := 0; i < ChannelNum; i++ { + // och.chArrays[i] = make(chan Cmd2Value, 50) + // go och.Run(i) + //} och.conversationRpcClient = conversationRpcClient och.groupRpcClient = groupRpcClient och.historyConsumerGroup = historyConsumerGroup return &och, err } +func (och *OnlineHistoryRedisConsumerHandler) do(ctx context.Context, channelID int, val *batcher.Msg[sarama.ConsumerMessage]) { + ctx = mcontext.WithTriggerIDContext(ctx, val.TriggerID()) + ctxMessages := och.parseConsumerMessages(ctx, val.Val()) + ctx = withAggregationCtx(ctx, ctxMessages) + log.ZInfo(ctx, "msg arrived channel", "channel id", channelID, "msgList length", len(ctxMessages), + "key", val.Key()) + + storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList := + och.categorizeMessageLists(ctxMessages) + log.ZDebug(ctx, "number of categorized messages", "storageMsgList", len(storageMsgList), "notStorageMsgList", + len(notStorageMsgList), "storageNotificationList", len(storageNotificationList), "notStorageNotificationList", + len(notStorageNotificationList)) + + conversationIDMsg := msgprocessor.GetChatConversationIDByMsg(ctxMessages[0].message) + conversationIDNotification := msgprocessor.GetNotificationConversationIDByMsg(ctxMessages[0].message) + och.handleMsg(ctx, val.Key(), conversationIDMsg, storageMsgList, notStorageMsgList) + och.handleNotification(ctx, val.Key(), conversationIDNotification, storageNotificationList, notStorageNotificationList) +} -func (och *OnlineHistoryRedisConsumerHandler) Run(channelID int) { - for cmd := range och.chArrays[channelID] { - switch cmd.Cmd { - case SourceMessages: - msgChannelValue := cmd.Value.(MsgChannelValue) - ctxMsgList := msgChannelValue.ctxMsgList - ctx := msgChannelValue.ctx - log.ZDebug( - ctx, - "msg arrived channel", - "channel id", - channelID, - "msgList length", - len(ctxMsgList), - "uniqueKey", - msgChannelValue.uniqueKey, - ) - storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList, modifyMsgList := och.getPushStorageMsgList( - ctxMsgList, - ) - log.ZDebug( - ctx, - "msg lens", - "storageMsgList", - len(storageMsgList), - "notStorageMsgList", - len(notStorageMsgList), - "storageNotificationList", - len(storageNotificationList), - "notStorageNotificationList", - len(notStorageNotificationList), - "modifyMsgList", - len(modifyMsgList), - ) - conversationIDMsg := msgprocessor.GetChatConversationIDByMsg(ctxMsgList[0].message) - conversationIDNotification := msgprocessor.GetNotificationConversationIDByMsg(ctxMsgList[0].message) - och.handleMsg(ctx, msgChannelValue.uniqueKey, conversationIDMsg, storageMsgList, notStorageMsgList) - och.handleNotification( - ctx, - msgChannelValue.uniqueKey, - conversationIDNotification, - storageNotificationList, - notStorageNotificationList, - ) - if err := och.msgDatabase.MsgToModifyMQ(ctx, msgChannelValue.uniqueKey, conversationIDNotification, modifyMsgList); err != nil { - log.ZError(ctx, "msg to modify mq error", err, "uniqueKey", msgChannelValue.uniqueKey, "modifyMsgList", modifyMsgList) - } +func (och *OnlineHistoryRedisConsumerHandler) parseConsumerMessages(ctx context.Context, consumerMessages []*sarama.ConsumerMessage) []*ContextMsg { + var ctxMessages []*ContextMsg + for i := 0; i < len(consumerMessages); i++ { + ctxMsg := &ContextMsg{} + msgFromMQ := &sdkws.MsgData{} + err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ) + if err != nil { + log.ZWarn(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value)) + continue } + var arr []string + for i, header := range consumerMessages[i].Headers { + arr = append(arr, strconv.Itoa(i), string(header.Key), string(header.Value)) + } + log.ZDebug(ctx, "consumer.kafka.GetContextWithMQHeader", "len", len(consumerMessages[i].Headers), + "header", strings.Join(arr, ", ")) + ctxMsg.ctx = kafka.GetContextWithMQHeader(consumerMessages[i].Headers) + ctxMsg.message = msgFromMQ + log.ZDebug(ctx, "message parse finish", "message", msgFromMQ, "key", + string(consumerMessages[i].Key)) + ctxMessages = append(ctxMessages, ctxMsg) } + return ctxMessages } // Get messages/notifications stored message list, not stored and pushed message list. -func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList( - totalMsgs []*ContextMsg, -) (storageMsgList, notStorageMsgList, storageNotificatoinList, notStorageNotificationList, modifyMsgList []*sdkws.MsgData) { - isStorage := func(msg *sdkws.MsgData) bool { - options2 := msgprocessor.Options(msg.Options) - if options2.IsHistory() { - return true - } - // if !(!options2.IsSenderSync() && conversationID == msg.MsgData.SendID) { - // return false - // } - return false - } +func (och *OnlineHistoryRedisConsumerHandler) categorizeMessageLists(totalMsgs []*ContextMsg) (storageMsgList, + notStorageMsgList, storageNotificationList, notStorageNotificationList []*ContextMsg) { for _, v := range totalMsgs { options := msgprocessor.Options(v.message.Options) if !options.IsNotNotification() { @@ -190,101 +184,71 @@ func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList( ) msg.Options = msgprocessor.WithOptions(msg.Options, msgprocessor.WithUnreadCount(true)) } - storageMsgList = append(storageMsgList, msg) + ctxMsg := &ContextMsg{ + message: msg, + ctx: v.ctx, + } + storageMsgList = append(storageMsgList, ctxMsg) } - if isStorage(v.message) { - storageNotificatoinList = append(storageNotificatoinList, v.message) + if options.IsHistory() { + storageNotificationList = append(storageNotificationList, v) } else { - notStorageNotificationList = append(notStorageNotificationList, v.message) + notStorageNotificationList = append(notStorageNotificationList, v) } } else { - if isStorage(v.message) { - storageMsgList = append(storageMsgList, v.message) + if options.IsHistory() { + storageMsgList = append(storageMsgList, v) } else { - notStorageMsgList = append(notStorageMsgList, v.message) + notStorageMsgList = append(notStorageMsgList, v) } } - if v.message.ContentType == constant.ReactionMessageModifier || - v.message.ContentType == constant.ReactionMessageDeleter { - modifyMsgList = append(modifyMsgList, v.message) - } } return } -func (och *OnlineHistoryRedisConsumerHandler) handleNotification( - ctx context.Context, - key, conversationID string, - storageList, notStorageList []*sdkws.MsgData, -) { +func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*ContextMsg) { och.toPushTopic(ctx, key, conversationID, notStorageList) - if len(storageList) > 0 { - lastSeq, _, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList) - if err != nil { - log.ZError( - ctx, - "notification batch insert to redis error", - err, - "conversationID", - conversationID, - "storageList", - storageList, - ) - return - } - log.ZDebug(ctx, "success to next topic", "conversationID", conversationID) - err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq) - if err != nil { - log.ZError(ctx, "MsgToMongoMQ error", err) - } - och.toPushTopic(ctx, key, conversationID, storageList) - } -} - -func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(ctx context.Context, key, conversationID string, msgs []*sdkws.MsgData) { - for _, v := range msgs { - och.msgDatabase.MsgToPushMQ(ctx, key, conversationID, v) // nolint: errcheck + var storageMessageList []*sdkws.MsgData + for _, msg := range storageList { + storageMessageList = append(storageMessageList, msg.message) } -} - -func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*sdkws.MsgData) { - och.toPushTopic(ctx, key, conversationID, notStorageList) - if len(storageList) > 0 { - lastSeq, isNewConversation, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList) + if len(storageMessageList) > 0 { + msg := storageMessageList[0] + lastSeq, isNewConversation, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList) if err != nil && errs.Unwrap(err) != redis.Nil { - log.ZError(ctx, "batch data insert to redis err", err, "storageMsgList", storageList) + log.ZError(ctx, "batch data insert to redis err", err, "storageMsgList", storageMessageList) return } if isNewConversation { - switch storageList[0].SessionType { + switch msg.SessionType { case constant.ReadGroupChatType: log.ZInfo(ctx, "group chat first create conversation", "conversationID", conversationID) - userIDs, err := och.groupRpcClient.GetGroupMemberIDs(ctx, storageList[0].GroupID) + userIDs, err := och.groupRpcClient.GetGroupMemberIDs(ctx, msg.GroupID) if err != nil { log.ZWarn(ctx, "get group member ids error", err, "conversationID", conversationID) } else { if err := och.conversationRpcClient.GroupChatFirstCreateConversation(ctx, - storageList[0].GroupID, userIDs); err != nil { + msg.GroupID, userIDs); err != nil { log.ZWarn(ctx, "single chat first create conversation error", err, "conversationID", conversationID) } } case constant.SingleChatType, constant.NotificationChatType: - if err := och.conversationRpcClient.SingleChatFirstCreateConversation(ctx, storageList[0].RecvID, - storageList[0].SendID, conversationID, storageList[0].SessionType); err != nil { + if err := och.conversationRpcClient.SingleChatFirstCreateConversation(ctx, msg.RecvID, + msg.SendID, conversationID, msg.SessionType); err != nil { log.ZWarn(ctx, "single chat or notification first create conversation error", err, - "conversationID", conversationID, "sessionType", storageList[0].SessionType) + "conversationID", conversationID, "sessionType", msg.SessionType) } default: log.ZWarn(ctx, "unknown session type", nil, "sessionType", - storageList[0].SessionType) + msg.SessionType) } } log.ZDebug(ctx, "success incr to next topic") - err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq) + err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq) if err != nil { log.ZError(ctx, "MsgToMongoMQ error", err) } @@ -292,74 +256,32 @@ func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key } } -func (och *OnlineHistoryRedisConsumerHandler) MessagesDistributionHandle() { - for { - aggregationMsgs := make(map[string][]*ContextMsg, ChannelNum) - select { - case cmd := <-och.msgDistributionCh: - switch cmd.Cmd { - case ConsumerMsgs: - triggerChannelValue := cmd.Value.(TriggerChannelValue) - ctx := triggerChannelValue.ctx - consumerMessages := triggerChannelValue.cMsgList - // Aggregation map[userid]message list - log.ZDebug(ctx, "batch messages come to distribution center", "length", len(consumerMessages)) - for i := 0; i < len(consumerMessages); i++ { - ctxMsg := &ContextMsg{} - msgFromMQ := &sdkws.MsgData{} - err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ) - if err != nil { - log.ZError(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value)) - continue - } - var arr []string - for i, header := range consumerMessages[i].Headers { - arr = append(arr, strconv.Itoa(i), string(header.Key), string(header.Value)) - } - log.ZInfo(ctx, "consumer.kafka.GetContextWithMQHeader", "len", len(consumerMessages[i].Headers), - "header", strings.Join(arr, ", ")) - ctxMsg.ctx = kafka.GetContextWithMQHeader(consumerMessages[i].Headers) - ctxMsg.message = msgFromMQ - log.ZDebug( - ctx, - "single msg come to distribution center", - "message", - msgFromMQ, - "key", - string(consumerMessages[i].Key), - ) - // aggregationMsgs[string(consumerMessages[i].Key)] = - // append(aggregationMsgs[string(consumerMessages[i].Key)], ctxMsg) - if oldM, ok := aggregationMsgs[string(consumerMessages[i].Key)]; ok { - oldM = append(oldM, ctxMsg) - aggregationMsgs[string(consumerMessages[i].Key)] = oldM - } else { - m := make([]*ContextMsg, 0, 100) - m = append(m, ctxMsg) - aggregationMsgs[string(consumerMessages[i].Key)] = m - } - } - log.ZDebug(ctx, "generate map list users len", "length", len(aggregationMsgs)) - for uniqueKey, v := range aggregationMsgs { - if len(v) >= 0 { - hashCode := stringutil.GetHashCode(uniqueKey) - channelID := hashCode % ChannelNum - newCtx := withAggregationCtx(ctx, v) - log.ZDebug( - newCtx, - "generate channelID", - "hashCode", - hashCode, - "channelID", - channelID, - "uniqueKey", - uniqueKey, - ) - och.chArrays[channelID] <- Cmd2Value{Cmd: SourceMessages, Value: MsgChannelValue{uniqueKey: uniqueKey, ctxMsgList: v, ctx: newCtx}} - } - } - } +func (och *OnlineHistoryRedisConsumerHandler) handleNotification(ctx context.Context, key, conversationID string, + storageList, notStorageList []*ContextMsg) { + och.toPushTopic(ctx, key, conversationID, notStorageList) + var storageMessageList []*sdkws.MsgData + for _, msg := range storageList { + storageMessageList = append(storageMessageList, msg.message) + } + if len(storageMessageList) > 0 { + lastSeq, _, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList) + if err != nil { + log.ZError(ctx, "notification batch insert to redis error", err, "conversationID", conversationID, + "storageList", storageMessageList) + return + } + log.ZDebug(ctx, "success to next topic", "conversationID", conversationID) + err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq) + if err != nil { + log.ZError(ctx, "MsgToMongoMQ error", err) } + och.toPushTopic(ctx, key, conversationID, storageList) + } +} + +func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(_ context.Context, key, conversationID string, msgs []*ContextMsg) { + for _, v := range msgs { + och.msgDatabase.MsgToPushMQ(v.ctx, key, conversationID, v.message) } } @@ -382,20 +304,34 @@ func (och *OnlineHistoryRedisConsumerHandler) Cleanup(_ sarama.ConsumerGroupSess return nil } -func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim( - sess sarama.ConsumerGroupSession, - claim sarama.ConsumerGroupClaim, -) error { // a instance in the consumer group +func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(session sarama.ConsumerGroupSession, + claim sarama.ConsumerGroupClaim) error { // a instance in the consumer group + log.ZInfo(context.Background(), "online new session msg come", "highWaterMarkOffset", + claim.HighWaterMarkOffset(), "topic", claim.Topic(), "partition", claim.Partition()) + och.redisMessageBatches.OnComplete = func(lastMessage *sarama.ConsumerMessage, totalCount int) { + session.MarkMessage(lastMessage, "") + } for { - if sess == nil { - log.ZWarn(context.Background(), "sess == nil, waiting", nil) - time.Sleep(100 * time.Millisecond) - } else { - break + select { + case msg, ok := <-claim.Messages(): + if !ok { + return nil + } + + if len(msg.Value) == 0 { + continue + } + err := och.redisMessageBatches.Put(context.Background(), msg) + if err != nil { + log.ZWarn(context.Background(), "put msg to error", err, "msg", msg) + } + + session.MarkMessage(msg, "") + + case <-session.Context().Done(): + return nil } } - log.ZInfo(context.Background(), "online new session msg come", "highWaterMarkOffset", - claim.HighWaterMarkOffset(), "topic", claim.Topic(), "partition", claim.Partition()) var ( split = 1000 @@ -473,9 +409,9 @@ func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim( messages = append(messages, msg) rwLock.Unlock() - sess.MarkMessage(msg, "") + session.MarkMessage(msg, "") - case <-sess.Context().Done(): + case <-session.Context().Done(): running.Store(false) return } diff --git a/internal/push/push_handler.go b/internal/push/push_handler.go index 3a9a696f6..2e236195b 100644 --- a/internal/push/push_handler.go +++ b/internal/push/push_handler.go @@ -17,6 +17,7 @@ package push import ( "context" "encoding/json" + "github.com/IBM/sarama" "github.com/openimsdk/open-im-server/v3/internal/push/offlinepush" "github.com/openimsdk/open-im-server/v3/internal/push/offlinepush/options" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" @@ -25,19 +26,16 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/rpccache" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/util/conversationutil" - "github.com/openimsdk/protocol/sdkws" - "github.com/openimsdk/tools/discovery" - "github.com/openimsdk/tools/mcontext" - "github.com/openimsdk/tools/utils/jsonutil" - "github.com/redis/go-redis/v9" - - "github.com/IBM/sarama" "github.com/openimsdk/protocol/constant" pbchat "github.com/openimsdk/protocol/msg" pbpush "github.com/openimsdk/protocol/push" + "github.com/openimsdk/protocol/sdkws" + "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mq/kafka" "github.com/openimsdk/tools/utils/datautil" + "github.com/openimsdk/tools/utils/jsonutil" "github.com/openimsdk/tools/utils/timeutil" "google.golang.org/protobuf/proto" ) @@ -162,7 +160,8 @@ func (c *ConsumerHandler) Push2User(ctx context.Context, userIDs []string, msg * err = c.offlinePushMsg(ctx, msg, offlinePUshUserID) if err != nil { - return err + log.ZWarn(ctx, "offlinePushMsg failed", err, "offlinePUshUserID", offlinePUshUserID, "msg", msg) + return nil } return nil @@ -223,8 +222,8 @@ func (c *ConsumerHandler) Push2Group(ctx context.Context, groupID string, msg *s err = c.offlinePushMsg(ctx, msg, needOfflinePushUserIDs) if err != nil { - log.ZError(ctx, "offlinePushMsg failed", err, "groupID", groupID, "msg", msg) - return err + log.ZWarn(ctx, "offlinePushMsg failed", err, "groupID", groupID, "msg", msg) + return nil } } diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 13bd7f9be..2f1df656f 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -291,28 +291,21 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR break } } - if req.GroupInfo.GroupType == constant.SuperGroup { - go func() { - for _, userID := range userIDs { - s.notification.SuperGroupNotification(ctx, userID, userID) - } - }() - } else { - tips := &sdkws.GroupCreatedTips{ - Group: resp.GroupInfo, - OperationTime: group.CreateTime.UnixMilli(), - GroupOwnerUser: s.groupMemberDB2PB(groupMembers[0], userMap[groupMembers[0].UserID].AppMangerLevel), - } - for _, member := range groupMembers { - member.Nickname = userMap[member.UserID].Nickname - tips.MemberList = append(tips.MemberList, s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel)) - if member.UserID == opUserID { - tips.OpUser = s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel) - break - } + + tips := &sdkws.GroupCreatedTips{ + Group: resp.GroupInfo, + OperationTime: group.CreateTime.UnixMilli(), + GroupOwnerUser: s.groupMemberDB2PB(groupMembers[0], userMap[groupMembers[0].UserID].AppMangerLevel), + } + for _, member := range groupMembers { + member.Nickname = userMap[member.UserID].Nickname + tips.MemberList = append(tips.MemberList, s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel)) + if member.UserID == opUserID { + tips.OpUser = s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel) + break } - s.notification.GroupCreatedNotification(ctx, tips) } + s.notification.GroupCreatedNotification(ctx, tips) reqCallBackAfter := &pbgroup.CreateGroupReq{ MemberUserIDs: userIDs, diff --git a/internal/rpc/group/notification.go b/internal/rpc/group/notification.go index 6d7cebcbc..0690ef991 100644 --- a/internal/rpc/group/notification.go +++ b/internal/rpc/group/notification.go @@ -715,7 +715,3 @@ func (g *GroupNotificationSender) GroupMemberSetToOrdinaryUserNotification(ctx c } g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberSetToOrdinaryUserNotification, tips) } - -func (g *GroupNotificationSender) SuperGroupNotification(ctx context.Context, sendID, recvID string) { - g.Notification(ctx, sendID, recvID, constant.SuperGroupUpdateNotification, nil) -} diff --git a/pkg/util/batcher/batcher.go b/pkg/util/batcher/batcher.go new file mode 100644 index 000000000..87d3128d5 --- /dev/null +++ b/pkg/util/batcher/batcher.go @@ -0,0 +1,269 @@ +package batcher + +import ( + "context" + "fmt" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/utils/idutil" + "strings" + "sync" + "time" +) + +var ( + DefaultDataChanSize = 1000 + DefaultSize = 100 + DefaultBuffer = 100 + DefaultWorker = 5 + DefaultInterval = time.Second +) + +type Config struct { + size int // Number of message aggregations + buffer int // The number of caches running in a single coroutine + worker int // Number of coroutines processed in parallel + interval time.Duration // Time of message aggregations + syncWait bool // Whether to wait synchronously after distributing messages +} + +type Option func(c *Config) + +func WithSize(s int) Option { + return func(c *Config) { + c.size = s + } +} + +func WithBuffer(b int) Option { + return func(c *Config) { + c.buffer = b + } +} + +func WithWorker(w int) Option { + return func(c *Config) { + c.worker = w + } +} + +func WithInterval(i time.Duration) Option { + return func(c *Config) { + c.interval = i + } +} + +func WithSyncWait(wait bool) Option { + return func(c *Config) { + c.syncWait = wait + } +} + +type Batcher[T any] struct { + config *Config + + globalCtx context.Context + cancel context.CancelFunc + Do func(ctx context.Context, channelID int, val *Msg[T]) + OnComplete func(lastMessage *T, totalCount int) + Sharding func(key string) int + Key func(data *T) string + HookFunc func(triggerID string, messages map[string][]*T, totalCount int, lastMessage *T) + data chan *T + chArrays []chan *Msg[T] + wait sync.WaitGroup + counter sync.WaitGroup +} + +func emptyOnComplete[T any](*T, int) {} +func emptyHookFunc[T any](string, map[string][]*T, int, *T) { +} + +func New[T any](opts ...Option) *Batcher[T] { + b := &Batcher[T]{ + OnComplete: emptyOnComplete[T], + HookFunc: emptyHookFunc[T], + } + config := &Config{ + size: DefaultSize, + buffer: DefaultBuffer, + worker: DefaultWorker, + interval: DefaultInterval, + } + for _, opt := range opts { + opt(config) + } + b.config = config + b.data = make(chan *T, DefaultDataChanSize) + b.globalCtx, b.cancel = context.WithCancel(context.Background()) + + b.chArrays = make([]chan *Msg[T], b.config.worker) + for i := 0; i < b.config.worker; i++ { + b.chArrays[i] = make(chan *Msg[T], b.config.buffer) + } + return b +} + +func (b *Batcher[T]) Worker() int { + return b.config.worker +} + +func (b *Batcher[T]) Start() error { + if b.Sharding == nil { + return errs.New("Sharding function is required").Wrap() + } + if b.Do == nil { + return errs.New("Do function is required").Wrap() + } + if b.Key == nil { + return errs.New("Key function is required").Wrap() + } + b.wait.Add(b.config.worker) + for i := 0; i < b.config.worker; i++ { + go b.run(i, b.chArrays[i]) + } + b.wait.Add(1) + go b.scheduler() + return nil +} + +func (b *Batcher[T]) Put(ctx context.Context, data *T) error { + if data == nil { + return errs.New("data can not be nil").Wrap() + } + select { + case <-b.globalCtx.Done(): + return errs.New("data channel is closed").Wrap() + case <-ctx.Done(): + return ctx.Err() + case b.data <- data: + return nil + } +} + +func (b *Batcher[T]) scheduler() { + ticker := time.NewTicker(b.config.interval) + defer func() { + ticker.Stop() + for _, ch := range b.chArrays { + close(ch) // 发送关闭信号到每个worker + } + close(b.data) + b.wait.Done() + }() + + vals := make(map[string][]*T) + count := 0 + var lastAny *T + + for { + select { + case data, ok := <-b.data: + if !ok { + // 如果data channel意外关闭 + return + } + if data == nil { + // 接收到nil作为结束信号 + fmt.Println("Batcher Closing1", count) + if count > 0 { + fmt.Println("Batcher Closing2", count) + b.distributeMessage(vals, count, lastAny) + } + return + } + // 正常数据处理 + key := b.Key(data) + vals[key] = append(vals[key], data) + lastAny = data + + count++ + if count >= b.config.size { + + fmt.Printf("counter to %d, %v\n", count, lastAny) + b.distributeMessage(vals, count, lastAny) + vals = make(map[string][]*T) + count = 0 + } + + case <-ticker.C: + if count > 0 { + fmt.Printf("ticker to %v , %d, %v\n", b.config.interval, count, lastAny) + b.distributeMessage(vals, count, lastAny) + vals = make(map[string][]*T) + count = 0 + } + } + } +} + +type Msg[T any] struct { + key string + triggerID string + val []*T +} + +func (m Msg[T]) Key() string { + return m.key +} + +func (m Msg[T]) TriggerID() string { + return m.triggerID +} + +func (m Msg[T]) Val() []*T { + return m.val +} + +func (m Msg[T]) String() string { + var sb strings.Builder + sb.WriteString("Key: ") + sb.WriteString(m.key) + sb.WriteString(", Values: [") + for i, v := range m.val { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(fmt.Sprintf("%v", *v)) + } + sb.WriteString("]") + return sb.String() +} + +func (b *Batcher[T]) distributeMessage(messages map[string][]*T, totalCount int, lastMessage *T) { + triggerID := idutil.OperationIDGenerator() + b.HookFunc(triggerID, messages, totalCount, lastMessage) + for key, data := range messages { + if b.config.syncWait { + b.counter.Add(1) + } + channelID := b.Sharding(key) + b.chArrays[channelID] <- &Msg[T]{key: key, triggerID: triggerID, val: data} + } + if b.config.syncWait { + b.counter.Wait() + } + b.OnComplete(lastMessage, totalCount) +} + +func (b *Batcher[T]) run(channelID int, ch <-chan *Msg[T]) { + defer b.wait.Done() + for { + select { + case messages, ok := <-ch: + if !ok { + return + } + b.Do(context.Background(), channelID, messages) + if b.config.syncWait { + b.counter.Done() + } + } + } +} + +func (b *Batcher[T]) Close() { + b.cancel() // Signal to stop put data + b.data <- nil + //wait all goroutines exit + b.wait.Wait() +} diff --git a/pkg/util/batcher/batcher_test.go b/pkg/util/batcher/batcher_test.go new file mode 100644 index 000000000..90e028449 --- /dev/null +++ b/pkg/util/batcher/batcher_test.go @@ -0,0 +1,66 @@ +package batcher + +import ( + "context" + "fmt" + "github.com/openimsdk/tools/utils/stringutil" + "testing" + "time" +) + +func TestBatcher(t *testing.T) { + config := Config{ + size: 1000, + buffer: 10, + worker: 10, + interval: 5 * time.Millisecond, + } + + b := New[string]( + WithSize(config.size), + WithBuffer(config.buffer), + WithWorker(config.worker), + WithInterval(config.interval), + WithSyncWait(true), + ) + + // Mock Do function to simply print values for demonstration + b.Do = func(ctx context.Context, channelID int, vals *Msg[string]) { + t.Logf("Channel %d Processed batch: %v", channelID, vals) + } + b.OnComplete = func(lastMessage *string, totalCount int) { + t.Logf("Completed processing with last message: %v, total count: %d", *lastMessage, totalCount) + } + b.Sharding = func(key string) int { + hashCode := stringutil.GetHashCode(key) + return int(hashCode) % config.worker + } + b.Key = func(data *string) string { + return *data + } + + err := b.Start() + if err != nil { + t.Fatal(err) + } + + // Test normal data processing + for i := 0; i < 10000; i++ { + data := "data" + fmt.Sprintf("%d", i) + if err := b.Put(context.Background(), &data); err != nil { + t.Fatal(err) + } + } + + time.Sleep(time.Duration(1) * time.Second) + start := time.Now() + // Wait for all processing to finish + b.Close() + + elapsed := time.Since(start) + t.Logf("Close took %s", elapsed) + + if len(b.data) != 0 { + t.Error("Data channel should be empty after closing") + } +}