diff --git a/pkg/common/storage/cache/redis/batch.go b/pkg/common/storage/cache/redis/batch.go index 8434474cb..5f9a8c82d 100644 --- a/pkg/common/storage/cache/redis/batch.go +++ b/pkg/common/storage/cache/redis/batch.go @@ -19,7 +19,7 @@ func getRocksCacheRedisClient(cli *rockscache.Client) redis.UniversalClient { return (*Client)(unsafe.Pointer(cli)).rdb } -func batchGetCache2[K comparable, V any](ctx context.Context, rcClient *rockscache.Client, expire time.Duration, ids []K, idKey func(id K) string, vId func(v V) K, fn func(ctx context.Context, ids []K) ([]V, error)) ([]V, error) { +func batchGetCache2[K comparable, V any](ctx context.Context, rcClient *rockscache.Client, expire time.Duration, ids []K, idKey func(id K) string, vId func(v *V) K, fn func(ctx context.Context, ids []K) ([]*V, error)) ([]*V, error) { if len(ids) == 0 { return nil, nil } @@ -37,7 +37,7 @@ func batchGetCache2[K comparable, V any](ctx context.Context, rcClient *rockscac if err != nil { return nil, err } - result := make([]V, 0, len(findKeys)) + result := make([]*V, 0, len(findKeys)) for _, keys := range slotKeys { indexCache, err := rcClient.FetchBatch2(ctx, keys, expire, func(idx []int) (map[int]string, error) { queryIds := make([]K, 0, len(idx)) @@ -72,7 +72,7 @@ func batchGetCache2[K comparable, V any](ctx context.Context, rcClient *rockscac if err != nil { return nil, err } - for _, data := range indexCache { + for index, data := range indexCache { if data == "" { continue } @@ -80,8 +80,15 @@ func batchGetCache2[K comparable, V any](ctx context.Context, rcClient *rockscac if err := json.Unmarshal([]byte(data), &value); err != nil { return nil, err } - result = append(result, value) + if cb, ok := any(&value).(BatchCacheCallback[K]); ok { + cb.BatchCache(keyId[keys[index]]) + } + result = append(result, &value) } } return result, nil } + +type BatchCacheCallback[K comparable] interface { + BatchCache(id K) +} diff --git a/pkg/common/storage/cache/redis/batch_test.go b/pkg/common/storage/cache/redis/batch_test.go index b9174abbf..e4caa2a21 100644 --- a/pkg/common/storage/cache/redis/batch_test.go +++ b/pkg/common/storage/cache/redis/batch_test.go @@ -2,14 +2,11 @@ package redis import ( "context" - "github.com/dtm-labs/rockscache" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo" - "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/tools/db/mongoutil" "github.com/openimsdk/tools/db/redisutil" "testing" - "time" ) func TestName(t *testing.T) { @@ -37,43 +34,22 @@ func TestName(t *testing.T) { if err != nil { panic(err) } - userMgo, err := mgo.NewUserMongo(mgocli.GetDB()) + //userMgo, err := mgo.NewUserMongo(mgocli.GetDB()) + //if err != nil { + // panic(err) + //} + //rock := rockscache.NewClient(rdb, rockscache.NewDefaultOptions()) + mgoSeqUser, err := mgo.NewSeqUserMongo(mgocli.GetDB()) if err != nil { panic(err) } - rock := rockscache.NewClient(rdb, rockscache.NewDefaultOptions()) - //var keys []string - //for i := 1; i <= 10; i++ { - // keys = append(keys, fmt.Sprintf("test%d", i)) - //} - //res, err := cli.FetchBatch2(ctx, keys, time.Hour, func(idx []int) (map[int]string, error) { - // t.Log("FetchBatch2=>", idx) - // time.Sleep(time.Second * 1) - // res := make(map[int]string) - // for _, i := range idx { - // res[i] = fmt.Sprintf("hello_%d", i) - // } - // t.Log("FetchBatch2=>", res) - // return res, nil - //}) - //if err != nil { - // t.Log(err) - // return - //} - //t.Log(res) - - userIDs := []string{"1814217053", "2110910952", "1234567890"} + seqUser := NewSeqUserCacheRedis(rdb, mgoSeqUser) - res, err := batchGetCache2(ctx, rock, time.Hour, userIDs, func(id string) string { - return "TEST_USER:" + id - }, func(v *model.User) string { - return v.UserID - }, func(ctx context.Context, ids []string) ([]*model.User, error) { - t.Log("find mongo", ids) - return userMgo.Find(ctx, ids) - }) + res, err := seqUser.GetReadSeqs(ctx, "2110910952", []string{"sg_2920732023", "sg_345762580"}) if err != nil { panic(err) } - t.Log("==>", res) + + t.Log(res) + } diff --git a/pkg/common/storage/cache/redis/redis_shard_manager.go b/pkg/common/storage/cache/redis/redis_shard_manager.go index 98d70dabf..17e5fecf6 100644 --- a/pkg/common/storage/cache/redis/redis_shard_manager.go +++ b/pkg/common/storage/cache/redis/redis_shard_manager.go @@ -2,6 +2,7 @@ package redis import ( "context" + "github.com/dtm-labs/rockscache" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" "github.com/redis/go-redis/v9" @@ -109,7 +110,7 @@ func (rsm *RedisShardManager) ProcessKeysBySlot( func groupKeysBySlot(ctx context.Context, redisClient redis.UniversalClient, keys []string) (map[int64][]string, error) { slots := make(map[int64][]string) clusterClient, isCluster := redisClient.(*redis.ClusterClient) - if isCluster { + if isCluster && len(keys) > 1 { pipe := clusterClient.Pipeline() cmds := make([]*redis.IntCmd, len(keys)) for i, key := range keys { @@ -195,3 +196,16 @@ func ProcessKeysBySlot( } return nil } + +func DeleteCacheBySlot(ctx context.Context, rcClient *rockscache.Client, keys []string) error { + switch len(keys) { + case 0: + return nil + case 1: + return rcClient.TagAsDeletedBatch2(ctx, keys) + default: + return ProcessKeysBySlot(ctx, getRocksCacheRedisClient(rcClient), keys, func(ctx context.Context, slot int64, keys []string) error { + return rcClient.TagAsDeletedBatch2(ctx, keys) + }) + } +} diff --git a/pkg/common/storage/cache/redis/seq_conversation.go b/pkg/common/storage/cache/redis/seq_conversation.go index 034462fd1..76cac2b02 100644 --- a/pkg/common/storage/cache/redis/seq_conversation.go +++ b/pkg/common/storage/cache/redis/seq_conversation.go @@ -2,6 +2,7 @@ package redis import ( "context" + "errors" "fmt" "github.com/dtm-labs/rockscache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" @@ -39,13 +40,7 @@ func (s *seqConversationCacheRedis) getMinSeqKey(conversationID string) string { } func (s *seqConversationCacheRedis) SetMinSeq(ctx context.Context, conversationID string, seq int64) error { - if err := s.mgo.SetMinSeq(ctx, conversationID, seq); err != nil { - return err - } - if err := s.rocks.TagAsDeleted2(ctx, s.getMinSeqKey(conversationID)); err != nil { - return errs.Wrap(err) - } - return nil + return s.SetMinSeqs(ctx, map[string]int64{conversationID: seq}) } func (s *seqConversationCacheRedis) GetMinSeq(ctx context.Context, conversationID string) (int64, error) { @@ -54,6 +49,78 @@ func (s *seqConversationCacheRedis) GetMinSeq(ctx context.Context, conversationI }) } +func (s *seqConversationCacheRedis) getSingleMaxSeq(ctx context.Context, conversationID string) (map[string]int64, error) { + seq, err := s.GetMaxSeq(ctx, conversationID) + if err != nil { + return nil, err + } + return map[string]int64{conversationID: seq}, nil +} + +func (s *seqConversationCacheRedis) batchGetMaxSeq(ctx context.Context, keys []string, keyConversationID map[string]string, seqs map[string]int64) error { + result := make([]*redis.StringCmd, len(keys)) + pipe := s.rdb.Pipeline() + for i, key := range keys { + result[i] = pipe.HGet(ctx, key, "CURR") + } + if _, err := pipe.Exec(ctx); err != nil { + return errs.Wrap(err) + } + var notFoundKey []string + for i, r := range result { + req, err := r.Int64() + if err == nil { + seqs[keyConversationID[keys[i]]] = req + } else if errors.Is(err, redis.Nil) { + notFoundKey = append(notFoundKey, keys[i]) + } else { + return errs.Wrap(err) + } + } + if len(notFoundKey) > 0 { + conversationID := keyConversationID[notFoundKey[0]] + seq, err := s.GetMaxSeq(ctx, conversationID) + if err != nil { + return err + } + seqs[conversationID] = seq + } + return nil +} + +func (s *seqConversationCacheRedis) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) { + switch len(conversationIDs) { + case 0: + return map[string]int64{}, nil + case 1: + return s.getSingleMaxSeq(ctx, conversationIDs[0]) + } + keys := make([]string, 0, len(conversationIDs)) + keyConversationID := make(map[string]string, len(conversationIDs)) + for _, conversationID := range conversationIDs { + key := s.getSeqMallocKey(conversationID) + if _, ok := keyConversationID[key]; ok { + continue + } + keys = append(keys, key) + keyConversationID[key] = conversationID + } + if len(keys) == 1 { + return s.getSingleMaxSeq(ctx, conversationIDs[0]) + } + slotKeys, err := groupKeysBySlot(ctx, s.rdb, keys) + if err != nil { + return nil, err + } + seqs := make(map[string]int64, len(conversationIDs)) + for _, keys := range slotKeys { + if err := s.batchGetMaxSeq(ctx, keys, keyConversationID, seqs); err != nil { + return nil, err + } + } + return seqs, nil +} + func (s *seqConversationCacheRedis) getSeqMallocKey(conversationID string) string { return cachekey.GetMallocSeqKey(conversationID) } @@ -253,3 +320,14 @@ func (s *seqConversationCacheRedis) Malloc(ctx context.Context, conversationID s func (s *seqConversationCacheRedis) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) { return s.Malloc(ctx, conversationID, 0) } + +func (s *seqConversationCacheRedis) SetMinSeqs(ctx context.Context, seqs map[string]int64) error { + keys := make([]string, 0, len(seqs)) + for conversationID, seq := range seqs { + keys = append(keys, s.getMinSeqKey(conversationID)) + if err := s.mgo.SetMinSeq(ctx, conversationID, seq); err != nil { + return err + } + } + return DeleteCacheBySlot(ctx, s.rocks, keys) +} diff --git a/pkg/common/storage/cache/redis/seq_user.go b/pkg/common/storage/cache/redis/seq_user.go index 3c533cdd8..2ad43eebd 100644 --- a/pkg/common/storage/cache/redis/seq_user.go +++ b/pkg/common/storage/cache/redis/seq_user.go @@ -64,10 +64,7 @@ func (s *seqUserCacheRedis) GetMinSeq(ctx context.Context, conversationID string } func (s *seqUserCacheRedis) SetMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error { - if err := s.mgo.SetMinSeq(ctx, conversationID, userID, seq); err != nil { - return err - } - return s.rocks.TagAsDeleted2(ctx, s.getSeqUserMinSeqKey(conversationID, userID)) + return s.SetMinSeqs(ctx, userID, map[string]int64{conversationID: seq}) } func (s *seqUserCacheRedis) GetReadSeq(ctx context.Context, conversationID string, userID string) (int64, error) { @@ -87,3 +84,102 @@ func (s *seqUserCacheRedis) SetReadSeq(ctx context.Context, conversationID strin } return nil } + +func (s *seqUserCacheRedis) SetMinSeqs(ctx context.Context, userID string, seqs map[string]int64) error { + keys := make([]string, 0, len(seqs)) + for conversationID, seq := range seqs { + if err := s.mgo.SetMinSeq(ctx, conversationID, userID, seq); err != nil { + return err + } + keys = append(keys, s.getSeqUserMinSeqKey(conversationID, userID)) + } + return DeleteCacheBySlot(ctx, s.rocks, keys) +} + +func (s *seqUserCacheRedis) setRedisReadSeqs(ctx context.Context, userID string, seqs map[string]int64) error { + keys := make([]string, 0, len(seqs)) + keySeq := make(map[string]int64) + for conversationID, seq := range seqs { + key := s.getSeqUserReadSeqKey(conversationID, userID) + keys = append(keys, key) + keySeq[key] = seq + } + slotKeys, err := groupKeysBySlot(ctx, s.rdb, keys) + if err != nil { + return err + } + for _, keys := range slotKeys { + pipe := s.rdb.Pipeline() + for _, key := range keys { + pipe.HSet(ctx, key, "value", strconv.FormatInt(keySeq[key], 10)) + pipe.Expire(ctx, key, s.readExpireTime) + } + if _, err := pipe.Exec(ctx); err != nil { + return err + } + } + return nil +} + +func (s *seqUserCacheRedis) SetReadSeqs(ctx context.Context, userID string, seqs map[string]int64) error { + if len(seqs) == 0 { + return nil + } + if err := s.setRedisReadSeqs(ctx, userID, seqs); err != nil { + return err + } + for conversationID, seq := range seqs { + if seq%s.readSeqWriteRatio == 0 { + if err := s.mgo.SetReadSeq(ctx, conversationID, userID, seq); err != nil { + return err + } + } + } + return nil +} + +func (s *seqUserCacheRedis) GetReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) { + res, err := batchGetCache2(ctx, s.rocks, s.readExpireTime, conversationIDs, func(conversationID string) string { + return s.getSeqUserReadSeqKey(conversationID, userID) + }, func(v *readSeqModel) string { + return v.ConversationID + }, func(ctx context.Context, conversationIDs []string) ([]*readSeqModel, error) { + seqs, err := s.mgo.GetReadSeqs(ctx, userID, conversationIDs) + if err != nil { + return nil, err + } + res := make([]*readSeqModel, 0, len(seqs)) + for conversationID, seq := range seqs { + res = append(res, &readSeqModel{ConversationID: conversationID, Seq: seq}) + } + return res, nil + }) + if err != nil { + return nil, err + } + data := make(map[string]int64) + for _, v := range res { + data[v.ConversationID] = v.Seq + } + return data, nil +} + +var _ BatchCacheCallback[string] = (*readSeqModel)(nil) + +type readSeqModel struct { + ConversationID string + Seq int64 +} + +func (r *readSeqModel) BatchCache(conversationID string) { + r.ConversationID = conversationID +} + +func (r *readSeqModel) UnmarshalJSON(bytes []byte) (err error) { + r.Seq, err = strconv.ParseInt(string(bytes), 10, 64) + return +} + +func (r *readSeqModel) MarshalJSON() ([]byte, error) { + return []byte(strconv.FormatInt(r.Seq, 10)), nil +} diff --git a/pkg/common/storage/cache/seq_conversation.go b/pkg/common/storage/cache/seq_conversation.go index 5d38537a9..2c893a5e8 100644 --- a/pkg/common/storage/cache/seq_conversation.go +++ b/pkg/common/storage/cache/seq_conversation.go @@ -7,4 +7,6 @@ type SeqConversationCache interface { GetMaxSeq(ctx context.Context, conversationID string) (int64, error) SetMinSeq(ctx context.Context, conversationID string, seq int64) error GetMinSeq(ctx context.Context, conversationID string) (int64, error) + GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) + SetMinSeqs(ctx context.Context, seqs map[string]int64) error } diff --git a/pkg/common/storage/cache/seq_user.go b/pkg/common/storage/cache/seq_user.go index 9e68399a7..4d0bb4ffa 100644 --- a/pkg/common/storage/cache/seq_user.go +++ b/pkg/common/storage/cache/seq_user.go @@ -9,4 +9,7 @@ type SeqUser interface { SetMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error GetReadSeq(ctx context.Context, conversationID string, userID string) (int64, error) SetReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error + SetMinSeqs(ctx context.Context, userID string, seqs map[string]int64) error + SetReadSeqs(ctx context.Context, userID string, seqs map[string]int64) error + GetReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) } diff --git a/pkg/common/storage/controller/msg.go b/pkg/common/storage/controller/msg.go index 1e7b5c229..32202ac9e 100644 --- a/pkg/common/storage/controller/msg.go +++ b/pkg/common/storage/controller/msg.go @@ -72,15 +72,8 @@ type CommonMsgDatabase interface { //SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) - //SetMinSeq(ctx context.Context, conversationID string, minSeq int64) error SetMinSeqs(ctx context.Context, seqs map[string]int64) error - //GetMinSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) - //GetMinSeq(ctx context.Context, conversationID string) (int64, error) - //GetConversationUserMinSeq(ctx context.Context, conversationID string, userID string) (int64, error) - //GetConversationUserMinSeqs(ctx context.Context, conversationID string, userIDs []string) (map[string]int64, error) - //SetConversationUserMinSeq(ctx context.Context, conversationID string, userID string, minSeq int64) error - //SetConversationUserMinSeqs(ctx context.Context, conversationID string, seqs map[string]int64) (err error) SetUserConversationsMinSeqs(ctx context.Context, userID string, seqs map[string]int64) (err error) SetHasReadSeq(ctx context.Context, userID string, conversationID string, hasReadSeq int64) error GetHasReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) @@ -784,44 +777,8 @@ func (db *commonMsgDatabase) DeleteUserMsgsBySeqs(ctx context.Context, userID st return nil } -func (db *commonMsgDatabase) DeleteMsgsBySeqs(ctx context.Context, conversationID string, seqs []int64) error { - return nil -} - -func (db *commonMsgDatabase) CleanUpUserConversationsMsgs(ctx context.Context, user string, conversationIDs []string) { - for _, conversationID := range conversationIDs { - maxSeq, err := db.seqConversation.GetMaxSeq(ctx, conversationID) - if err != nil { - if err == redis.Nil { - log.ZDebug(ctx, "max seq is nil", "conversationID", conversationID) - } else { - log.ZError(ctx, "get max seq failed", err, "conversationID", conversationID) - } - continue - } - if err := db.seqConversation.SetMinSeq(ctx, conversationID, maxSeq+1); err != nil { - log.ZError(ctx, "set min seq failed", err, "conversationID", conversationID, "minSeq", maxSeq+1) - } - } -} - -//func (db *commonMsgDatabase) SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error { -// return db.seq.SetMaxSeq(ctx, conversationID, maxSeq) -//} - func (db *commonMsgDatabase) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) { - result := make(map[string]int64) - for _, conversationID := range conversationIDs { - if result[conversationID] != 0 { - continue - } - seq, err := db.seqConversation.GetMaxSeq(ctx, conversationID) - if err != nil { - return nil, err - } - result[conversationID] = seq - } - return result, nil + return db.seqConversation.GetMaxSeqs(ctx, conversationIDs) } func (db *commonMsgDatabase) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) { @@ -833,30 +790,15 @@ func (db *commonMsgDatabase) SetMinSeq(ctx context.Context, conversationID strin } func (db *commonMsgDatabase) SetMinSeqs(ctx context.Context, seqs map[string]int64) error { - for conversationID, seq := range seqs { - if err := db.seqConversation.SetMinSeq(ctx, conversationID, seq); err != nil { - return err - } - } - return nil + return db.seqConversation.SetMinSeqs(ctx, seqs) } func (db *commonMsgDatabase) SetUserConversationsMinSeqs(ctx context.Context, userID string, seqs map[string]int64) error { - for conversationID, seq := range seqs { - if err := db.seqUser.SetMinSeq(ctx, conversationID, userID, seq); err != nil { - return err - } - } - return nil + return db.seqUser.SetMinSeqs(ctx, userID, seqs) } func (db *commonMsgDatabase) UserSetHasReadSeqs(ctx context.Context, userID string, hasReadSeqs map[string]int64) error { - for conversationID, seq := range hasReadSeqs { - if err := db.seqUser.SetReadSeq(ctx, conversationID, userID, seq); err != nil { - return err - } - } - return nil + return db.seqUser.SetReadSeqs(ctx, userID, hasReadSeqs) } func (db *commonMsgDatabase) SetHasReadSeq(ctx context.Context, userID string, conversationID string, hasReadSeq int64) error { @@ -864,18 +806,7 @@ func (db *commonMsgDatabase) SetHasReadSeq(ctx context.Context, userID string, c } func (db *commonMsgDatabase) GetHasReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) { - cSeq := make(map[string]int64) - for _, conversationID := range conversationIDs { - if _, ok := cSeq[conversationID]; ok { - continue - } - seq, err := db.seqUser.GetReadSeq(ctx, conversationID, userID) - if err != nil { - return nil, err - } - cSeq[conversationID] = seq - } - return cSeq, nil + return db.seqUser.GetReadSeqs(ctx, userID, conversationIDs) } func (db *commonMsgDatabase) GetHasReadSeq(ctx context.Context, userID string, conversationID string) (int64, error) { diff --git a/pkg/common/storage/database/mgo/seq_user.go b/pkg/common/storage/database/mgo/seq_user.go index 5e0eb2022..e0cbb08d9 100644 --- a/pkg/common/storage/database/mgo/seq_user.go +++ b/pkg/common/storage/database/mgo/seq_user.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/tools/db/mongoutil" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -87,6 +88,23 @@ func (s *seqUserMongo) GetReadSeq(ctx context.Context, conversationID string, us return s.getSeq(ctx, conversationID, userID, "read_seq") } +func (s *seqUserMongo) GetReadSeqs(ctx context.Context, userID string, conversationID []string) (map[string]int64, error) { + if len(conversationID) == 0 { + return map[string]int64{}, nil + } + filter := bson.M{"user_id": userID, "conversation_id": bson.M{"$in": conversationID}} + opt := options.Find().SetProjection(bson.M{"_id": 0, "conversation_id": 1, "read_seq": 1}) + seqs, err := mongoutil.Find[*model.SeqUser](ctx, s.coll, filter, opt) + if err != nil { + return nil, err + } + res := make(map[string]int64) + for _, seq := range seqs { + res[seq.ConversationID] = seq.ReadSeq + } + return res, nil +} + func (s *seqUserMongo) SetReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error { return s.setSeq(ctx, conversationID, userID, seq, "read_seq") } diff --git a/pkg/common/storage/database/seq_user.go b/pkg/common/storage/database/seq_user.go index d32d614dd..edd3910d0 100644 --- a/pkg/common/storage/database/seq_user.go +++ b/pkg/common/storage/database/seq_user.go @@ -9,4 +9,5 @@ type SeqUser interface { SetMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error GetReadSeq(ctx context.Context, conversationID string, userID string) (int64, error) SetReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error + GetReadSeqs(ctx context.Context, userID string, conversationID []string) (map[string]int64, error) }