refactor: msg transfer refactor.

pull/2325/head
Gordon 1 year ago
parent 31aba9b9ff
commit 9c19fd43fa

@@ -20,6 +20,7 @@ import (
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/redisutil"
"github.com/openimsdk/tools/utils/datautil"
"net"
"net/http"
"os"
"os/signal"
@@ -167,4 +168,40 @@ func (m *MsgTransfer) Start(index int, config *Config) error {
close(netDone)
return netErr
}
if config.MsgTransfer.Prometheus.Enable {
go func() {
proreg := prometheus.NewRegistry()
proreg.MustRegister(
collectors.NewGoCollector(),
)
proreg.MustRegister(prommetrics.GetGrpcCusMetrics("Transfer", &config.Share)...)
http.Handle("/metrics", promhttp.HandlerFor(proreg, promhttp.HandlerOpts{Registry: proreg}))
lc := net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
var sockErr error
if err := c.Control(func(fd uintptr) {
// SO_REUSEADDR lets a restarted msgtransfer rebind the metrics port immediately.
sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
}); err != nil {
return err
}
if sockErr != nil {
return fmt.Errorf("setsockopt failed: %w", sockErr)
}
return nil
},
}
listener, err := lc.Listen(context.Background(), "tcp", fmt.Sprintf(":%d", prometheusPort))
if err != nil {
netErr = errs.WrapMsg(err, "prometheus start error", "prometheusPort", prometheusPort)
netDone <- struct{}{}
return
}
err = http.Serve(listener, nil)
if err != nil && err != http.ErrServerClosed {
netErr = errs.WrapMsg(err, "HTTP server start error", "prometheusPort", prometheusPort)
netDone <- struct{}{}
}
}()
}
}
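For reference, a minimal standalone sketch of the SO_REUSEADDR listener pattern used above: it lets a restarted process rebind the metrics port without waiting for sockets stuck in TIME_WAIT. The port and handler below are placeholders, not values taken from this commit.

package main

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"syscall"
)

// newReusableListener opens a TCP listener with SO_REUSEADDR set on the socket.
func newReusableListener(port int) (net.Listener, error) {
	lc := net.ListenConfig{
		Control: func(network, address string, c syscall.RawConn) error {
			var sockErr error
			if err := c.Control(func(fd uintptr) {
				sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
			}); err != nil {
				return err
			}
			return sockErr
		},
	}
	return lc.Listen(context.Background(), "tcp", fmt.Sprintf(":%d", port))
}

func main() {
	ln, err := newReusableListener(12020) // placeholder port
	if err != nil {
		panic(err)
	}
	http.Handle("/metrics", http.NotFoundHandler()) // stand-in for the promhttp handler
	if err := http.Serve(ln, nil); err != nil && err != http.ErrServerClosed {
		panic(err)
	}
}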

@@ -16,6 +16,7 @@ package msgtransfer
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/util/batcher"
"strconv"
"strings"
"sync"
@@ -71,10 +72,7 @@ type OnlineHistoryRedisConsumerHandler struct {
chArrays [ChannelNum]chan Cmd2Value
msgDistributionCh chan Cmd2Value
// singleMsgSuccessCount uint64
// singleMsgFailedCount uint64
// singleMsgSuccessCountMutex sync.Mutex
// singleMsgFailedCountMutex sync.Mutex
redisMessageBatches *batcher.Batcher[sarama.ConsumerMessage]
msgDatabase controller.CommonMsgDatabase
conversationRpcClient *rpcclient.ConversationRpcClient
@@ -89,83 +87,79 @@ func NewOnlineHistoryRedisConsumerHandler(kafkaConf *config.Kafka, database cont
}
var och OnlineHistoryRedisConsumerHandler
och.msgDatabase = database
och.msgDistributionCh = make(chan Cmd2Value) // no buffer channel
go och.MessagesDistributionHandle()
for i := 0; i < ChannelNum; i++ {
och.chArrays[i] = make(chan Cmd2Value, 50)
go och.Run(i)
b := batcher.New[sarama.ConsumerMessage]()
b.Sharding = func(key string) int {
hashCode := stringutil.GetHashCode(key)
return int(hashCode) % och.redisMessageBatches.Worker()
}
b.Key = func(consumerMessage *sarama.ConsumerMessage) string {
return string(consumerMessage.Key)
}
b.Do = och.do // Start returns an error if Do is unset; batches are handled by och.do
och.redisMessageBatches = b
err = b.Start()
if err != nil {
return nil, err
}
//och.msgDistributionCh = make(chan Cmd2Value) // no buffer channel
//go och.MessagesDistributionHandle()
//for i := 0; i < ChannelNum; i++ {
// och.chArrays[i] = make(chan Cmd2Value, 50)
// go och.Run(i)
//}
och.conversationRpcClient = conversationRpcClient
och.groupRpcClient = groupRpcClient
och.historyConsumerGroup = historyConsumerGroup
return &och, err
}
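// do is the Batcher callback: it restores the trigger ID into the context, parses the raw Kafka messages,
// splits them into storage/non-storage message and notification lists, and dispatches them via handleMsg and handleNotification.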
func (och *OnlineHistoryRedisConsumerHandler) do(ctx context.Context, channelID int, val *batcher.Msg[sarama.ConsumerMessage]) {
ctx = mcontext.WithTriggerIDContext(ctx, val.TriggerID())
ctxMessages := och.parseConsumerMessages(ctx, val.Val())
ctx = withAggregationCtx(ctx, ctxMessages)
log.ZInfo(ctx, "msg arrived channel", "channel id", channelID, "msgList length", len(ctxMessages),
"key", val.Key())
storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList :=
och.categorizeMessageLists(ctxMessages)
log.ZDebug(ctx, "number of categorized messages", "storageMsgList", len(storageMsgList), "notStorageMsgList",
len(notStorageMsgList), "storageNotificationList", len(storageNotificationList), "notStorageNotificationList",
len(notStorageNotificationList))
conversationIDMsg := msgprocessor.GetChatConversationIDByMsg(ctxMessages[0].message)
conversationIDNotification := msgprocessor.GetNotificationConversationIDByMsg(ctxMessages[0].message)
och.handleMsg(ctx, val.Key(), conversationIDMsg, storageMsgList, notStorageMsgList)
och.handleNotification(ctx, val.Key(), conversationIDNotification, storageNotificationList, notStorageNotificationList)
}
func (och *OnlineHistoryRedisConsumerHandler) Run(channelID int) {
for cmd := range och.chArrays[channelID] {
switch cmd.Cmd {
case SourceMessages:
msgChannelValue := cmd.Value.(MsgChannelValue)
ctxMsgList := msgChannelValue.ctxMsgList
ctx := msgChannelValue.ctx
log.ZDebug(
ctx,
"msg arrived channel",
"channel id",
channelID,
"msgList length",
len(ctxMsgList),
"uniqueKey",
msgChannelValue.uniqueKey,
)
storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList, modifyMsgList := och.getPushStorageMsgList(
ctxMsgList,
)
log.ZDebug(
ctx,
"msg lens",
"storageMsgList",
len(storageMsgList),
"notStorageMsgList",
len(notStorageMsgList),
"storageNotificationList",
len(storageNotificationList),
"notStorageNotificationList",
len(notStorageNotificationList),
"modifyMsgList",
len(modifyMsgList),
)
conversationIDMsg := msgprocessor.GetChatConversationIDByMsg(ctxMsgList[0].message)
conversationIDNotification := msgprocessor.GetNotificationConversationIDByMsg(ctxMsgList[0].message)
och.handleMsg(ctx, msgChannelValue.uniqueKey, conversationIDMsg, storageMsgList, notStorageMsgList)
och.handleNotification(
ctx,
msgChannelValue.uniqueKey,
conversationIDNotification,
storageNotificationList,
notStorageNotificationList,
)
if err := och.msgDatabase.MsgToModifyMQ(ctx, msgChannelValue.uniqueKey, conversationIDNotification, modifyMsgList); err != nil {
log.ZError(ctx, "msg to modify mq error", err, "uniqueKey", msgChannelValue.uniqueKey, "modifyMsgList", modifyMsgList)
}
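// parseConsumerMessages unmarshals each Kafka message into sdkws.MsgData and rebuilds the per-message context from its MQ headers.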
func (och *OnlineHistoryRedisConsumerHandler) parseConsumerMessages(ctx context.Context, consumerMessages []*sarama.ConsumerMessage) []*ContextMsg {
var ctxMessages []*ContextMsg
for i := 0; i < len(consumerMessages); i++ {
ctxMsg := &ContextMsg{}
msgFromMQ := &sdkws.MsgData{}
err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ)
if err != nil {
log.ZWarn(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value))
continue
}
var arr []string
for i, header := range consumerMessages[i].Headers {
arr = append(arr, strconv.Itoa(i), string(header.Key), string(header.Value))
}
log.ZDebug(ctx, "consumer.kafka.GetContextWithMQHeader", "len", len(consumerMessages[i].Headers),
"header", strings.Join(arr, ", "))
ctxMsg.ctx = kafka.GetContextWithMQHeader(consumerMessages[i].Headers)
ctxMsg.message = msgFromMQ
log.ZDebug(ctx, "message parse finish", "message", msgFromMQ, "key",
string(consumerMessages[i].Key))
ctxMessages = append(ctxMessages, ctxMsg)
}
return ctxMessages
}
// Get messages/notifications stored message list, not stored and pushed message list.
func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList(
totalMsgs []*ContextMsg,
) (storageMsgList, notStorageMsgList, storageNotificatoinList, notStorageNotificationList, modifyMsgList []*sdkws.MsgData) {
isStorage := func(msg *sdkws.MsgData) bool {
options2 := msgprocessor.Options(msg.Options)
if options2.IsHistory() {
return true
}
// if !(!options2.IsSenderSync() && conversationID == msg.MsgData.SendID) {
// return false
// }
return false
}
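// categorizeMessageLists separates notifications from ordinary messages and, within each group,
// splits the ones that must be stored (IsHistory) from push-only ones.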
func (och *OnlineHistoryRedisConsumerHandler) categorizeMessageLists(totalMsgs []*ContextMsg) (storageMsgList,
notStorageMsgList, storageNotificationList, notStorageNotificationList []*ContextMsg) {
for _, v := range totalMsgs {
options := msgprocessor.Options(v.message.Options)
if !options.IsNotNotification() {
@@ -190,101 +184,71 @@ func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList(
)
msg.Options = msgprocessor.WithOptions(msg.Options, msgprocessor.WithUnreadCount(true))
}
storageMsgList = append(storageMsgList, msg)
ctxMsg := &ContextMsg{
message: msg,
ctx: v.ctx,
}
storageMsgList = append(storageMsgList, ctxMsg)
}
if isStorage(v.message) {
storageNotificatoinList = append(storageNotificatoinList, v.message)
if options.IsHistory() {
storageNotificationList = append(storageNotificationList, v)
} else {
notStorageNotificationList = append(notStorageNotificationList, v.message)
notStorageNotificationList = append(notStorageNotificationList, v)
}
} else {
if isStorage(v.message) {
storageMsgList = append(storageMsgList, v.message)
if options.IsHistory() {
storageMsgList = append(storageMsgList, v)
} else {
notStorageMsgList = append(notStorageMsgList, v.message)
notStorageMsgList = append(notStorageMsgList, v)
}
}
if v.message.ContentType == constant.ReactionMessageModifier ||
v.message.ContentType == constant.ReactionMessageDeleter {
modifyMsgList = append(modifyMsgList, v.message)
}
}
return
}
func (och *OnlineHistoryRedisConsumerHandler) handleNotification(
ctx context.Context,
key, conversationID string,
storageList, notStorageList []*sdkws.MsgData,
) {
func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*ContextMsg) {
och.toPushTopic(ctx, key, conversationID, notStorageList)
if len(storageList) > 0 {
lastSeq, _, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList)
if err != nil {
log.ZError(
ctx,
"notification batch insert to redis error",
err,
"conversationID",
conversationID,
"storageList",
storageList,
)
return
}
log.ZDebug(ctx, "success to next topic", "conversationID", conversationID)
err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq)
if err != nil {
log.ZError(ctx, "MsgToMongoMQ error", err)
}
och.toPushTopic(ctx, key, conversationID, storageList)
}
}
func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(ctx context.Context, key, conversationID string, msgs []*sdkws.MsgData) {
for _, v := range msgs {
och.msgDatabase.MsgToPushMQ(ctx, key, conversationID, v) // nolint: errcheck
var storageMessageList []*sdkws.MsgData
for _, msg := range storageList {
storageMessageList = append(storageMessageList, msg.message)
}
}
func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*sdkws.MsgData) {
och.toPushTopic(ctx, key, conversationID, notStorageList)
if len(storageList) > 0 {
lastSeq, isNewConversation, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList)
if len(storageMessageList) > 0 {
msg := storageMessageList[0]
lastSeq, isNewConversation, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList)
if err != nil && errs.Unwrap(err) != redis.Nil {
log.ZError(ctx, "batch data insert to redis err", err, "storageMsgList", storageList)
log.ZError(ctx, "batch data insert to redis err", err, "storageMsgList", storageMessageList)
return
}
if isNewConversation {
switch storageList[0].SessionType {
switch msg.SessionType {
case constant.ReadGroupChatType:
log.ZInfo(ctx, "group chat first create conversation", "conversationID",
conversationID)
userIDs, err := och.groupRpcClient.GetGroupMemberIDs(ctx, storageList[0].GroupID)
userIDs, err := och.groupRpcClient.GetGroupMemberIDs(ctx, msg.GroupID)
if err != nil {
log.ZWarn(ctx, "get group member ids error", err, "conversationID",
conversationID)
} else {
if err := och.conversationRpcClient.GroupChatFirstCreateConversation(ctx,
storageList[0].GroupID, userIDs); err != nil {
msg.GroupID, userIDs); err != nil {
log.ZWarn(ctx, "single chat first create conversation error", err,
"conversationID", conversationID)
}
}
case constant.SingleChatType, constant.NotificationChatType:
if err := och.conversationRpcClient.SingleChatFirstCreateConversation(ctx, storageList[0].RecvID,
storageList[0].SendID, conversationID, storageList[0].SessionType); err != nil {
if err := och.conversationRpcClient.SingleChatFirstCreateConversation(ctx, msg.RecvID,
msg.SendID, conversationID, msg.SessionType); err != nil {
log.ZWarn(ctx, "single chat or notification first create conversation error", err,
"conversationID", conversationID, "sessionType", storageList[0].SessionType)
"conversationID", conversationID, "sessionType", msg.SessionType)
}
default:
log.ZWarn(ctx, "unknown session type", nil, "sessionType",
storageList[0].SessionType)
msg.SessionType)
}
}
log.ZDebug(ctx, "success incr to next topic")
err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq)
err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq)
if err != nil {
log.ZError(ctx, "MsgToMongoMQ error", err)
}
@@ -292,74 +256,32 @@ func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key
}
}
func (och *OnlineHistoryRedisConsumerHandler) MessagesDistributionHandle() {
for {
aggregationMsgs := make(map[string][]*ContextMsg, ChannelNum)
select {
case cmd := <-och.msgDistributionCh:
switch cmd.Cmd {
case ConsumerMsgs:
triggerChannelValue := cmd.Value.(TriggerChannelValue)
ctx := triggerChannelValue.ctx
consumerMessages := triggerChannelValue.cMsgList
// Aggregation map[userid]message list
log.ZDebug(ctx, "batch messages come to distribution center", "length", len(consumerMessages))
for i := 0; i < len(consumerMessages); i++ {
ctxMsg := &ContextMsg{}
msgFromMQ := &sdkws.MsgData{}
err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ)
if err != nil {
log.ZError(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value))
continue
}
var arr []string
for i, header := range consumerMessages[i].Headers {
arr = append(arr, strconv.Itoa(i), string(header.Key), string(header.Value))
}
log.ZInfo(ctx, "consumer.kafka.GetContextWithMQHeader", "len", len(consumerMessages[i].Headers),
"header", strings.Join(arr, ", "))
ctxMsg.ctx = kafka.GetContextWithMQHeader(consumerMessages[i].Headers)
ctxMsg.message = msgFromMQ
log.ZDebug(
ctx,
"single msg come to distribution center",
"message",
msgFromMQ,
"key",
string(consumerMessages[i].Key),
)
// aggregationMsgs[string(consumerMessages[i].Key)] =
// append(aggregationMsgs[string(consumerMessages[i].Key)], ctxMsg)
if oldM, ok := aggregationMsgs[string(consumerMessages[i].Key)]; ok {
oldM = append(oldM, ctxMsg)
aggregationMsgs[string(consumerMessages[i].Key)] = oldM
} else {
m := make([]*ContextMsg, 0, 100)
m = append(m, ctxMsg)
aggregationMsgs[string(consumerMessages[i].Key)] = m
}
}
log.ZDebug(ctx, "generate map list users len", "length", len(aggregationMsgs))
for uniqueKey, v := range aggregationMsgs {
if len(v) >= 0 {
hashCode := stringutil.GetHashCode(uniqueKey)
channelID := hashCode % ChannelNum
newCtx := withAggregationCtx(ctx, v)
log.ZDebug(
newCtx,
"generate channelID",
"hashCode",
hashCode,
"channelID",
channelID,
"uniqueKey",
uniqueKey,
)
och.chArrays[channelID] <- Cmd2Value{Cmd: SourceMessages, Value: MsgChannelValue{uniqueKey: uniqueKey, ctxMsgList: v, ctx: newCtx}}
}
}
}
func (och *OnlineHistoryRedisConsumerHandler) handleNotification(ctx context.Context, key, conversationID string,
storageList, notStorageList []*ContextMsg) {
och.toPushTopic(ctx, key, conversationID, notStorageList)
var storageMessageList []*sdkws.MsgData
for _, msg := range storageList {
storageMessageList = append(storageMessageList, msg.message)
}
if len(storageMessageList) > 0 {
lastSeq, _, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList)
if err != nil {
log.ZError(ctx, "notification batch insert to redis error", err, "conversationID", conversationID,
"storageList", storageMessageList)
return
}
log.ZDebug(ctx, "success to next topic", "conversationID", conversationID)
err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq)
if err != nil {
log.ZError(ctx, "MsgToMongoMQ error", err)
}
och.toPushTopic(ctx, key, conversationID, storageList)
}
}
func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(_ context.Context, key, conversationID string, msgs []*ContextMsg) {
for _, v := range msgs {
och.msgDatabase.MsgToPushMQ(v.ctx, key, conversationID, v.message)
}
}
@@ -382,20 +304,34 @@ func (och *OnlineHistoryRedisConsumerHandler) Cleanup(_ sarama.ConsumerGroupSess
return nil
}
func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(
sess sarama.ConsumerGroupSession,
claim sarama.ConsumerGroupClaim,
) error { // a instance in the consumer group
func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(session sarama.ConsumerGroupSession,
claim sarama.ConsumerGroupClaim) error { // an instance in the consumer group
log.ZInfo(context.Background(), "online new session msg come", "highWaterMarkOffset",
claim.HighWaterMarkOffset(), "topic", claim.Topic(), "partition", claim.Partition())
och.redisMessageBatches.OnComplete = func(lastMessage *sarama.ConsumerMessage, totalCount int) {
session.MarkMessage(lastMessage, "")
}
for {
if sess == nil {
log.ZWarn(context.Background(), "sess == nil, waiting", nil)
time.Sleep(100 * time.Millisecond)
} else {
break
select {
case msg, ok := <-claim.Messages():
if !ok {
return nil
}
if len(msg.Value) == 0 {
continue
}
err := och.redisMessageBatches.Put(context.Background(), msg)
if err != nil {
log.ZWarn(context.Background(), "put msg to error", err, "msg", msg)
}
session.MarkMessage(msg, "")
case <-session.Context().Done():
return nil
}
}
log.ZInfo(context.Background(), "online new session msg come", "highWaterMarkOffset",
claim.HighWaterMarkOffset(), "topic", claim.Topic(), "partition", claim.Partition())
var (
split = 1000
@@ -473,9 +409,9 @@ func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(
messages = append(messages, msg)
rwLock.Unlock()
sess.MarkMessage(msg, "")
session.MarkMessage(msg, "")
case <-sess.Context().Done():
case <-session.Context().Done():
running.Store(false)
return
}

@@ -17,6 +17,7 @@ package push
import (
"context"
"encoding/json"
"github.com/IBM/sarama"
"github.com/openimsdk/open-im-server/v3/internal/push/offlinepush"
"github.com/openimsdk/open-im-server/v3/internal/push/offlinepush/options"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
@@ -25,19 +26,16 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/rpccache"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
"github.com/openimsdk/open-im-server/v3/pkg/util/conversationutil"
"github.com/openimsdk/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/redis/go-redis/v9"
"github.com/IBM/sarama"
"github.com/openimsdk/protocol/constant"
pbchat "github.com/openimsdk/protocol/msg"
pbpush "github.com/openimsdk/protocol/push"
"github.com/openimsdk/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/mq/kafka"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/openimsdk/tools/utils/timeutil"
"google.golang.org/protobuf/proto"
)
@@ -162,7 +160,8 @@ func (c *ConsumerHandler) Push2User(ctx context.Context, userIDs []string, msg *
err = c.offlinePushMsg(ctx, msg, offlinePUshUserID)
if err != nil {
return err
log.ZWarn(ctx, "offlinePushMsg failed", err, "offlinePUshUserID", offlinePUshUserID, "msg", msg)
return nil
}
return nil
@@ -223,8 +222,8 @@ func (c *ConsumerHandler) Push2Group(ctx context.Context, groupID string, msg *s
err = c.offlinePushMsg(ctx, msg, needOfflinePushUserIDs)
if err != nil {
log.ZError(ctx, "offlinePushMsg failed", err, "groupID", groupID, "msg", msg)
return err
log.ZWarn(ctx, "offlinePushMsg failed", err, "groupID", groupID, "msg", msg)
return nil
}
}

@@ -291,28 +291,21 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR
break
}
}
if req.GroupInfo.GroupType == constant.SuperGroup {
go func() {
for _, userID := range userIDs {
s.notification.SuperGroupNotification(ctx, userID, userID)
}
}()
} else {
tips := &sdkws.GroupCreatedTips{
Group: resp.GroupInfo,
OperationTime: group.CreateTime.UnixMilli(),
GroupOwnerUser: s.groupMemberDB2PB(groupMembers[0], userMap[groupMembers[0].UserID].AppMangerLevel),
}
for _, member := range groupMembers {
member.Nickname = userMap[member.UserID].Nickname
tips.MemberList = append(tips.MemberList, s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel))
if member.UserID == opUserID {
tips.OpUser = s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel)
break
}
tips := &sdkws.GroupCreatedTips{
Group: resp.GroupInfo,
OperationTime: group.CreateTime.UnixMilli(),
GroupOwnerUser: s.groupMemberDB2PB(groupMembers[0], userMap[groupMembers[0].UserID].AppMangerLevel),
}
for _, member := range groupMembers {
member.Nickname = userMap[member.UserID].Nickname
tips.MemberList = append(tips.MemberList, s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel))
if member.UserID == opUserID {
tips.OpUser = s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel)
break
}
s.notification.GroupCreatedNotification(ctx, tips)
}
s.notification.GroupCreatedNotification(ctx, tips)
reqCallBackAfter := &pbgroup.CreateGroupReq{
MemberUserIDs: userIDs,

@@ -715,7 +715,3 @@ func (g *GroupNotificationSender) GroupMemberSetToOrdinaryUserNotification(ctx c
}
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberSetToOrdinaryUserNotification, tips)
}
func (g *GroupNotificationSender) SuperGroupNotification(ctx context.Context, sendID, recvID string) {
g.Notification(ctx, sendID, recvID, constant.SuperGroupUpdateNotification, nil)
}

@@ -0,0 +1,269 @@
package batcher
import (
"context"
"fmt"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/idutil"
"strings"
"sync"
"time"
)
var (
DefaultDataChanSize = 1000
DefaultSize = 100
DefaultBuffer = 100
DefaultWorker = 5
DefaultInterval = time.Second
)
type Config struct {
size int // Number of message aggregations
buffer int // The number of caches running in a single coroutine
worker int // Number of coroutines processed in parallel
interval time.Duration // Time of message aggregations
syncWait bool // Whether to wait synchronously after distributing messages
}
type Option func(c *Config)
func WithSize(s int) Option {
return func(c *Config) {
c.size = s
}
}
func WithBuffer(b int) Option {
return func(c *Config) {
c.buffer = b
}
}
func WithWorker(w int) Option {
return func(c *Config) {
c.worker = w
}
}
func WithInterval(i time.Duration) Option {
return func(c *Config) {
c.interval = i
}
}
func WithSyncWait(wait bool) Option {
return func(c *Config) {
c.syncWait = wait
}
}
type Batcher[T any] struct {
config *Config
globalCtx context.Context
cancel context.CancelFunc
Do func(ctx context.Context, channelID int, val *Msg[T])
OnComplete func(lastMessage *T, totalCount int)
Sharding func(key string) int
Key func(data *T) string
HookFunc func(triggerID string, messages map[string][]*T, totalCount int, lastMessage *T)
data chan *T
chArrays []chan *Msg[T]
wait sync.WaitGroup
counter sync.WaitGroup
}
func emptyOnComplete[T any](*T, int) {}
func emptyHookFunc[T any](string, map[string][]*T, int, *T) {
}
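// New builds a Batcher with the given options; Sharding, Key and Do must be assigned before Start is called.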
func New[T any](opts ...Option) *Batcher[T] {
b := &Batcher[T]{
OnComplete: emptyOnComplete[T],
HookFunc: emptyHookFunc[T],
}
config := &Config{
size: DefaultSize,
buffer: DefaultBuffer,
worker: DefaultWorker,
interval: DefaultInterval,
}
for _, opt := range opts {
opt(config)
}
b.config = config
b.data = make(chan *T, DefaultDataChanSize)
b.globalCtx, b.cancel = context.WithCancel(context.Background())
b.chArrays = make([]chan *Msg[T], b.config.worker)
for i := 0; i < b.config.worker; i++ {
b.chArrays[i] = make(chan *Msg[T], b.config.buffer)
}
return b
}
func (b *Batcher[T]) Worker() int {
return b.config.worker
}
func (b *Batcher[T]) Start() error {
if b.Sharding == nil {
return errs.New("Sharding function is required").Wrap()
}
if b.Do == nil {
return errs.New("Do function is required").Wrap()
}
if b.Key == nil {
return errs.New("Key function is required").Wrap()
}
b.wait.Add(b.config.worker)
for i := 0; i < b.config.worker; i++ {
go b.run(i, b.chArrays[i])
}
b.wait.Add(1)
go b.scheduler()
return nil
}
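// Put queues one item for batching; it fails if the item is nil, the batcher has been closed, or the caller's context is done.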
func (b *Batcher[T]) Put(ctx context.Context, data *T) error {
if data == nil {
return errs.New("data can not be nil").Wrap()
}
select {
case <-b.globalCtx.Done():
return errs.New("data channel is closed").Wrap()
case <-ctx.Done():
return ctx.Err()
case b.data <- data:
return nil
}
}
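// scheduler drains b.data, groups items by Key, and flushes a batch to the workers when either the configured
// size is reached or the interval ticker fires. A nil item is the shutdown signal from Close; on exit the worker channels are closed.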
func (b *Batcher[T]) scheduler() {
ticker := time.NewTicker(b.config.interval)
defer func() {
ticker.Stop()
for _, ch := range b.chArrays {
close(ch) // send the close signal to each worker
}
close(b.data)
b.wait.Done()
}()
vals := make(map[string][]*T)
count := 0
var lastAny *T
for {
select {
case data, ok := <-b.data:
if !ok {
// the data channel was closed unexpectedly
return
}
if data == nil {
// a nil item is the shutdown signal
fmt.Println("Batcher Closing1", count)
if count > 0 {
fmt.Println("Batcher Closing2", count)
b.distributeMessage(vals, count, lastAny)
}
return
}
// normal data handling
key := b.Key(data)
vals[key] = append(vals[key], data)
lastAny = data
count++
if count >= b.config.size {
fmt.Printf("counter to %d, %v\n", count, lastAny)
b.distributeMessage(vals, count, lastAny)
vals = make(map[string][]*T)
count = 0
}
case <-ticker.C:
if count > 0 {
fmt.Printf("ticker to %v , %d, %v\n", b.config.interval, count, lastAny)
b.distributeMessage(vals, count, lastAny)
vals = make(map[string][]*T)
count = 0
}
}
}
}
type Msg[T any] struct {
key string
triggerID string
val []*T
}
func (m Msg[T]) Key() string {
return m.key
}
func (m Msg[T]) TriggerID() string {
return m.triggerID
}
func (m Msg[T]) Val() []*T {
return m.val
}
func (m Msg[T]) String() string {
var sb strings.Builder
sb.WriteString("Key: ")
sb.WriteString(m.key)
sb.WriteString(", Values: [")
for i, v := range m.val {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", *v))
}
sb.WriteString("]")
return sb.String()
}
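// distributeMessage shards each keyed group onto its worker channel; with syncWait enabled it blocks until the
// whole flush has been processed, then calls OnComplete with the last item and the total count.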
func (b *Batcher[T]) distributeMessage(messages map[string][]*T, totalCount int, lastMessage *T) {
triggerID := idutil.OperationIDGenerator()
b.HookFunc(triggerID, messages, totalCount, lastMessage)
for key, data := range messages {
if b.config.syncWait {
b.counter.Add(1)
}
channelID := b.Sharding(key)
b.chArrays[channelID] <- &Msg[T]{key: key, triggerID: triggerID, val: data}
}
if b.config.syncWait {
b.counter.Wait()
}
b.OnComplete(lastMessage, totalCount)
}
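// run is a worker goroutine: it calls Do for every batch received on its channel until the channel is closed by scheduler.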
func (b *Batcher[T]) run(channelID int, ch <-chan *Msg[T]) {
defer b.wait.Done()
for {
select {
case messages, ok := <-ch:
if !ok {
return
}
b.Do(context.Background(), channelID, messages)
if b.config.syncWait {
b.counter.Done()
}
}
}
}
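// Close rejects further Put calls, sends the nil shutdown signal, and waits for the scheduler and all workers to exit.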
func (b *Batcher[T]) Close() {
b.cancel() // Signal to stop put data
b.data <- nil
//wait all goroutines exit
b.wait.Wait()
}
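A brief orientation on the lifecycle, separate from the unit test that follows: Sharding, Key and Do must be set before Start, and with WithSyncWait(true) the OnComplete hook only fires after every group of a flush has been processed, which makes it a safe place for a Kafka consumer to mark its offset (the handler above wires OnComplete that way). A minimal sketch with illustrative values and a toy shard function, assuming it lives alongside the batcher package:

func batcherSketch() error {
	b := New[string](
		WithSize(100),
		WithWorker(4),
		WithInterval(200*time.Millisecond),
		WithSyncWait(true),
	)
	b.Sharding = func(key string) int { return len(key) % b.Worker() } // toy shard function
	b.Key = func(s *string) string { return *s }
	b.Do = func(ctx context.Context, worker int, m *Msg[string]) {
		_ = m.Val() // items sharing a key within one flush arrive together, in order
	}
	b.OnComplete = func(last *string, total int) {
		// runs once per flush, after all workers finished (syncWait); a Kafka consumer can mark its offset here
	}
	if err := b.Start(); err != nil {
		return err
	}
	v := "conversation:si_a_b" // illustrative key
	if err := b.Put(context.Background(), &v); err != nil {
		return err
	}
	b.Close()
	return nil
}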

@@ -0,0 +1,66 @@
package batcher
import (
"context"
"fmt"
"github.com/openimsdk/tools/utils/stringutil"
"testing"
"time"
)
func TestBatcher(t *testing.T) {
config := Config{
size: 1000,
buffer: 10,
worker: 10,
interval: 5 * time.Millisecond,
}
b := New[string](
WithSize(config.size),
WithBuffer(config.buffer),
WithWorker(config.worker),
WithInterval(config.interval),
WithSyncWait(true),
)
// Mock Do function to simply print values for demonstration
b.Do = func(ctx context.Context, channelID int, vals *Msg[string]) {
t.Logf("Channel %d Processed batch: %v", channelID, vals)
}
b.OnComplete = func(lastMessage *string, totalCount int) {
t.Logf("Completed processing with last message: %v, total count: %d", *lastMessage, totalCount)
}
b.Sharding = func(key string) int {
hashCode := stringutil.GetHashCode(key)
return int(hashCode) % config.worker
}
b.Key = func(data *string) string {
return *data
}
err := b.Start()
if err != nil {
t.Fatal(err)
}
// Test normal data processing
for i := 0; i < 10000; i++ {
data := "data" + fmt.Sprintf("%d", i)
if err := b.Put(context.Background(), &data); err != nil {
t.Fatal(err)
}
}
time.Sleep(time.Duration(1) * time.Second)
start := time.Now()
// Wait for all processing to finish
b.Close()
elapsed := time.Since(start)
t.Logf("Close took %s", elapsed)
if len(b.data) != 0 {
t.Error("Data channel should be empty after closing")
}
}