diff --git a/pkg/common/db/cache/msg.go b/pkg/common/db/cache/msg.go index c8346a1d4..6d0ee8c67 100644 --- a/pkg/common/db/cache/msg.go +++ b/pkg/common/db/cache/msg.go @@ -20,11 +20,8 @@ import ( "strconv" "time" - "github.com/dtm-labs/rockscache" "golang.org/x/sync/errgroup" - unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation" - "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "github.com/OpenIMSDK/tools/errs" @@ -136,10 +133,7 @@ func NewMsgCacheModel(client redis.UniversalClient) MsgModel { type msgCache struct { metaCache - rdb redis.UniversalClient - expireTime time.Duration - rcClient *rockscache.Client - msgDocDatabase unrelationtb.MsgDocModelInterface + rdb redis.UniversalClient } func (c *msgCache) getMaxSeqKey(conversationID string) string { @@ -176,29 +170,6 @@ func (c *msgCache) getSeqs(ctx context.Context, items []string, getkey func(s st } return m, nil - - //pipe := c.rdb.Pipeline() - //for _, v := range items { - // if err := pipe.Get(ctx, getkey(v)).Err(); err != nil && err != redis.Nil { - // return nil, errs.Wrap(err) - // } - //} - //result, err := pipe.Exec(ctx) - //if err != nil && err != redis.Nil { - // return nil, errs.Wrap(err) - //} - //m = make(map[string]int64, len(items)) - //for i, v := range result { - // seq := v.(*redis.StringCmd) - // if seq.Err() != nil && seq.Err() != redis.Nil { - // return nil, errs.Wrap(v.Err()) - // } - // val := utils.StringToInt64(seq.Val()) - // if val != 0 { - // m[items[i]] = val - // } - //} - //return m, nil } func (c *msgCache) SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error { @@ -224,15 +195,6 @@ func (c *msgCache) setSeqs(ctx context.Context, seqs map[string]int64, getkey fu } } return nil - //pipe := c.rdb.Pipeline() - //for k, seq := range seqs { - // err := pipe.Set(ctx, getkey(k), seq, 0).Err() - // if err != nil { - // return errs.Wrap(err) - // } - //} - //_, err := pipe.Exec(ctx) - //return err } func (c *msgCache) SetMinSeqs(ctx context.Context, seqs map[string]int64) error { @@ -637,20 +599,49 @@ func (c *msgCache) DelUserDeleteMsgsList(ctx context.Context, conversationID str } func (c *msgCache) DeleteMessages(ctx context.Context, conversationID string, seqs []int64) error { + if config.Config.Redis.EnablePipeline { + return c.PipeDeleteMessages(ctx, conversationID, seqs) + } + + return c.ParallelDeleteMessages(ctx, conversationID, seqs) +} + +func (c *msgCache) ParallelDeleteMessages(ctx context.Context, conversationID string, seqs []int64) error { + wg := errgroup.Group{} + wg.SetLimit(concurrentLimit) + for _, seq := range seqs { - if err := c.rdb.Del(ctx, c.getMessageCacheKey(conversationID, seq)).Err(); err != nil { + seq := seq + wg.Go(func() error { + err := c.rdb.Del(ctx, c.getMessageCacheKey(conversationID, seq)).Err() + if err != nil { + return errs.Wrap(err) + } + return nil + }) + } + + return wg.Wait() +} + +func (c *msgCache) PipeDeleteMessages(ctx context.Context, conversationID string, seqs []int64) error { + pipe := c.rdb.Pipeline() + for _, seq := range seqs { + _ = pipe.Del(ctx, c.getMessageCacheKey(conversationID, seq)) + } + + results, err := pipe.Exec(ctx) + if err != nil { + return errs.Wrap(err, "pipe.del") + } + + for _, res := range results { + if res.Err() != nil { return errs.Wrap(err) } } + return nil - //pipe := c.rdb.Pipeline() - //for _, seq := range seqs { - // if err := pipe.Del(ctx, c.getMessageCacheKey(conversationID, seq)).Err(); err != nil { - // return errs.Wrap(err) - // } - //} - //_, err := pipe.Exec(ctx) - //return errs.Wrap(err) } func (c *msgCache) CleanUpOneConversationAllMsg(ctx context.Context, conversationID string) error { @@ -667,14 +658,6 @@ func (c *msgCache) CleanUpOneConversationAllMsg(ctx context.Context, conversatio } } return nil - //pipe := c.rdb.Pipeline() - //for _, v := range vals { - // if err := pipe.Del(ctx, v).Err(); err != nil { - // return errs.Wrap(err) - // } - //} - //_, err = pipe.Exec(ctx) - //return errs.Wrap(err) } func (c *msgCache) DelMsgFromCache(ctx context.Context, userID string, seqs []int64) error { diff --git a/pkg/common/db/cache/msg_test.go b/pkg/common/db/cache/msg_test.go index c5a4fb870..3fddf5965 100644 --- a/pkg/common/db/cache/msg_test.go +++ b/pkg/common/db/cache/msg_test.go @@ -249,3 +249,139 @@ func testPipeGetMessagesBySeqWithLostHalfSeqs(t *testing.T, cid string, seqs []i assert.Equal(t, msg.Seq, seqs[idx]) } } + +func TestPipeDeleteMessages(t *testing.T) { + var ( + cid = fmt.Sprintf("cid-%v", rand.Int63()) + seqFirst = rand.Int63() + msgs = []*sdkws.MsgData{} + ) + + var seqs []int64 + for i := 0; i < 100; i++ { + msgs = append(msgs, &sdkws.MsgData{ + Seq: seqFirst + int64(i), + }) + seqs = append(seqs, msgs[i].Seq) + } + + testPipeSetMessageToCache(t, cid, msgs) + testPipeDeleteMessagesOK(t, cid, seqs, msgs) + + // set again + testPipeSetMessageToCache(t, cid, msgs) + testPipeDeleteMessagesMix(t, cid, seqs[:90], msgs) +} + +func testPipeDeleteMessagesOK(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) + assert.Nil(t, err) + + // validate + for _, msg := range inputMsgs { + key := cacher.getMessageCacheKey(cid, msg.Seq) + val := rdb.Exists(context.Background(), key).Val() + assert.EqualValues(t, 0, val) + } +} + +func testPipeDeleteMessagesMix(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) + assert.Nil(t, err) + + // validate + for idx, msg := range inputMsgs { + key := cacher.getMessageCacheKey(cid, msg.Seq) + val, err := rdb.Exists(context.Background(), key).Result() + assert.Nil(t, err) + if idx < 90 { + assert.EqualValues(t, 0, val) // not exists + continue + } + + assert.EqualValues(t, 1, val) // exists + } +} + +func TestParallelDeleteMessages(t *testing.T) { + var ( + cid = fmt.Sprintf("cid-%v", rand.Int63()) + seqFirst = rand.Int63() + msgs = []*sdkws.MsgData{} + ) + + var seqs []int64 + for i := 0; i < 100; i++ { + msgs = append(msgs, &sdkws.MsgData{ + Seq: seqFirst + int64(i), + }) + seqs = append(seqs, msgs[i].Seq) + } + + randSeqs := []int64{} + for i := seqFirst + 100; i < seqFirst+200; i++ { + randSeqs = append(randSeqs, i) + } + + testParallelSetMessageToCache(t, cid, msgs) + testParallelDeleteMessagesOK(t, cid, seqs, msgs) + + // set again + testParallelSetMessageToCache(t, cid, msgs) + testParallelDeleteMessagesMix(t, cid, seqs[:90], msgs, 90) + testParallelDeleteMessagesOK(t, cid, seqs[90:], msgs[:90]) + + // set again + testParallelSetMessageToCache(t, cid, msgs) + testParallelDeleteMessagesMix(t, cid, randSeqs, msgs, 0) +} + +func testParallelDeleteMessagesOK(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) + assert.Nil(t, err) + + // validate + for _, msg := range inputMsgs { + key := cacher.getMessageCacheKey(cid, msg.Seq) + val := rdb.Exists(context.Background(), key).Val() + assert.EqualValues(t, 0, val) + } +} + +func testParallelDeleteMessagesMix(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData, lessValNonExists int) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) + assert.Nil(t, err) + + // validate + for idx, msg := range inputMsgs { + key := cacher.getMessageCacheKey(cid, msg.Seq) + val, err := rdb.Exists(context.Background(), key).Result() + assert.Nil(t, err) + if idx < lessValNonExists { + assert.EqualValues(t, 0, val) // not exists + continue + } + + assert.EqualValues(t, 1, val) // exists + } +}