perf: add concurrency and pipeline for redis cache (#1338)

* perf: add concurrency and pipeline mode for redis cache

Signed-off-by: rfyiamcool <rfyiamcool@163.com>

* perf: add concurrency and pipeline mode for redis cache

Signed-off-by: rfyiamcool <rfyiamcool@163.com>

* perf: unit test for redis cache

Signed-off-by: rfyiamcool <rfyiamcool@163.com>

---------

Signed-off-by: rfyiamcool <rfyiamcool@163.com>
fengyun.rui (committed via GitHub)
parent d1af343b13
commit 2496a16a88

@@ -78,10 +78,11 @@ type configStruct struct {
} `yaml:"mongo"`
Redis struct {
ClusterMode bool `yaml:"clusterMode"`
Address []string `yaml:"address"`
Username string `yaml:"username"`
Password string `yaml:"password"`
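// EnablePipeline selects Redis pipelining for the message cache; when false, the cache falls back to bounded concurrent requests.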
EnablePipeline bool `yaml:"enablePipeline"`
} `yaml:"redis"`
Kafka struct {

@@ -21,6 +21,7 @@ import (
"time"
"github.com/dtm-labs/rockscache"
"golang.org/x/sync/errgroup"
unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
@@ -62,6 +63,8 @@ const (
uidPidToken = "UID_PID_TOKEN_STATUS:"
)
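// concurrentLimit caps the number of goroutines used by the Parallel* message cache helpers below.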
var concurrentLimit = 3
type SeqCache interface {
SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error
GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error)
@@ -345,85 +348,165 @@ func (c *msgCache) allMessageCacheKey(conversationID string) string {
}
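// GetMessagesBySeq reads cached messages either over a single Redis pipeline or with bounded concurrent GETs, depending on config.Config.Redis.EnablePipeline.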
func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) {
if config.Config.Redis.EnablePipeline {
return c.PipeGetMessagesBySeq(ctx, conversationID, seqs)
}
return c.ParallelGetMessagesBySeq(ctx, conversationID, seqs)
}
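// PipeGetMessagesBySeq queues one GET per seq on a pipeline and executes them in a single round trip; seqs that miss, fail to unmarshal, or are marked deleted are reported in failedSeqs.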
func (c *msgCache) PipeGetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) {
pipe := c.rdb.Pipeline()
results := []*redis.StringCmd{}
for _, seq := range seqs {
results = append(results, pipe.Get(ctx, c.getMessageCacheKey(conversationID, seq)))
}
_, err = pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return seqMsgs, failedSeqs, errs.Wrap(err, "pipe.get")
}
for idx, res := range results {
seq := seqs[idx]
if res.Err() != nil {
log.ZError(ctx, "GetMessagesBySeq failed", err, "conversationID", conversationID, "seq", seq, "err", res.Err())
failedSeqs = append(failedSeqs, seq)
continue
}
msg := sdkws.MsgData{}
if err = msgprocessor.String2Pb(res.Val(), &msg); err != nil {
log.ZError(ctx, "GetMessagesBySeq Unmarshal failed", err, "res", res, "conversationID", conversationID, "seq", seq)
failedSeqs = append(failedSeqs, seq)
continue
}
if msg.Status == constant.MsgDeleted {
failedSeqs = append(failedSeqs, seq)
continue
}
seqMsgs = append(seqMsgs, &msg)
}
return
//pipe := c.rdb.Pipeline()
//for _, v := range seqs {
// // MESSAGE_CACHE:169.254.225.224_reliability1653387820_0_1
// key := c.getMessageCacheKey(conversationID, v)
// if err := pipe.Get(ctx, key).Err(); err != nil && err != redis.Nil {
// return nil, nil, err
// }
//}
//result, err := pipe.Exec(ctx)
//for i, v := range result {
// cmd := v.(*redis.StringCmd)
// if cmd.Err() != nil {
// failedSeqs = append(failedSeqs, seqs[i])
// } else {
// msg := sdkws.MsgData{}
// err = msgprocessor.String2Pb(cmd.Val(), &msg)
// if err == nil {
// if msg.Status != constant.MsgDeleted {
// seqMsgs = append(seqMsgs, &msg)
// continue
// }
// } else {
// log.ZWarn(ctx, "UnmarshalString failed", err, "conversationID", conversationID, "seq", seqs[i], "msg", cmd.Val())
// }
// failedSeqs = append(failedSeqs, seqs[i])
// }
//}
//return seqMsgs, failedSeqs, err
}
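// ParallelGetMessagesBySeq issues the GETs concurrently, capped at concurrentLimit goroutines, and collects per-seq results in order; failed seqs are reported in failedSeqs.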
func (c *msgCache) ParallelGetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) {
type entry struct {
err error
msg *sdkws.MsgData
}
wg := errgroup.Group{}
wg.SetLimit(concurrentLimit)
results := make([]entry, len(seqs)) // set slice len/cap to length of seqs.
for idx, seq := range seqs {
// closure safe var
idx := idx
seq := seq
wg.Go(func() error {
res, err := c.rdb.Get(ctx, c.getMessageCacheKey(conversationID, seq)).Result()
if err != nil {
log.ZError(ctx, "GetMessagesBySeq failed", err, "conversationID", conversationID, "seq", seq)
results[idx] = entry{err: err}
return nil
}
msg := sdkws.MsgData{}
if err = msgprocessor.String2Pb(res, &msg); err != nil {
log.ZError(ctx, "GetMessagesBySeq Unmarshal failed", err, "res", res, "conversationID", conversationID, "seq", seq)
results[idx] = entry{err: err}
return nil
}
if msg.Status == constant.MsgDeleted {
results[idx] = entry{err: err}
return nil
}
results[idx] = entry{msg: &msg}
return nil
})
}
_ = wg.Wait()
for idx, res := range results {
if res.err != nil || res.msg == nil { // deleted messages leave msg nil; report them in failedSeqs
failedSeqs = append(failedSeqs, seqs[idx])
continue
}
seqMsgs = append(seqMsgs, res.msg)
}
return
}
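// SetMessageToCache writes messages through the pipeline or bounded-concurrency path, mirroring GetMessagesBySeq.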
func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) {
if config.Config.Redis.EnablePipeline {
return c.PipeSetMessageToCache(ctx, conversationID, msgs)
}
return c.ParallelSetMessageToCache(ctx, conversationID, msgs)
}
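// PipeSetMessageToCache marshals every message and SETs it on a pipeline, executing all writes in one round trip; it returns the number of messages written.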
func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) {
pipe := c.rdb.Pipeline()
for _, msg := range msgs {
s, err := msgprocessor.Pb2String(msg)
if err != nil {
return 0, errs.Wrap(err, "pb.marshal")
}
key := c.getMessageCacheKey(conversationID, msg.Seq)
_ = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second)
}
results, err := pipe.Exec(ctx)
if err != nil {
return 0, errs.Wrap(err, "pipe.set")
}
for _, res := range results {
if res.Err() != nil {
return 0, errs.Wrap(res.Err(), "pipe.set")
}
}
return len(msgs), nil
}
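// ParallelSetMessageToCache issues one SET per message, limited to concurrentLimit concurrent goroutines.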
func (c *msgCache) ParallelSetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) {
wg := errgroup.Group{}
wg.SetLimit(concurrentLimit)
for _, msg := range msgs {
msg := msg // closure safe var
wg.Go(func() error {
s, err := msgprocessor.Pb2String(msg)
if err != nil {
return errs.Wrap(err)
}
key := c.getMessageCacheKey(conversationID, msg.Seq)
if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
return errs.Wrap(err)
}
return nil
})
}
err := wg.Wait()
if err != nil {
return 0, err
}
return len(msgs), nil
//pipe := c.rdb.Pipeline()
//var failedMsgs []*sdkws.MsgData
//for _, msg := range msgs {
// key := c.getMessageCacheKey(conversationID, msg.Seq)
// s, err := msgprocessor.Pb2String(msg)
// if err != nil {
// return 0, errs.Wrap(err)
// }
// err = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err()
// if err != nil {
// failedMsgs = append(failedMsgs, msg)
// log.ZWarn(ctx, "set msg 2 cache failed", err, "msg", failedMsgs)
// }
//}
//_, err := pipe.Exec(ctx)
//return len(failedMsgs), err
}
func (c *msgCache) getMessageDelUserListKey(conversationID string, seq int64) string {

@@ -0,0 +1,251 @@
package cache
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
)
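// Note: these tests require a reachable Redis; redis.NewClient(&redis.Options{}) uses the go-redis defaults (localhost:6379, no auth).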
func TestParallelSetMessageToCache(t *testing.T) {
var (
cid = fmt.Sprintf("cid-%v", rand.Int63())
seqFirst = rand.Int63()
msgs = []*sdkws.MsgData{}
)
for i := 0; i < 100; i++ {
msgs = append(msgs, &sdkws.MsgData{
Seq: seqFirst + int64(i),
})
}
testParallelSetMessageToCache(t, cid, msgs)
}
func testParallelSetMessageToCache(t *testing.T, cid string, msgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
ret, err := cacher.ParallelSetMessageToCache(context.Background(), cid, msgs)
assert.Nil(t, err)
assert.Equal(t, len(msgs), ret)
// validate
for _, msg := range msgs {
key := cacher.getMessageCacheKey(cid, msg.Seq)
val, err := rdb.Exists(context.Background(), key).Result()
assert.Nil(t, err)
assert.EqualValues(t, 1, val)
}
}
func TestPipeSetMessageToCache(t *testing.T) {
var (
cid = fmt.Sprintf("cid-%v", rand.Int63())
seqFirst = rand.Int63()
msgs = []*sdkws.MsgData{}
)
for i := 0; i < 100; i++ {
msgs = append(msgs, &sdkws.MsgData{
Seq: seqFirst + int64(i),
})
}
testPipeSetMessageToCache(t, cid, msgs)
}
func testPipeSetMessageToCache(t *testing.T, cid string, msgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
ret, err := cacher.PipeSetMessageToCache(context.Background(), cid, msgs)
assert.Nil(t, err)
assert.Equal(t, len(msgs), ret)
// validate
for _, msg := range msgs {
key := cacher.getMessageCacheKey(cid, msg.Seq)
val, err := rdb.Exists(context.Background(), key).Result()
assert.Nil(t, err)
assert.EqualValues(t, 1, val)
}
}
func TestGetMessagesBySeq(t *testing.T) {
var (
cid = fmt.Sprintf("cid-%v", rand.Int63())
seqFirst = rand.Int63()
msgs = []*sdkws.MsgData{}
)
seqs := []int64{}
for i := 0; i < 100; i++ {
msgs = append(msgs, &sdkws.MsgData{
Seq: seqFirst + int64(i),
SendID: fmt.Sprintf("fake-sendid-%v", i),
})
seqs = append(seqs, seqFirst+int64(i))
}
// set data to cache
testPipeSetMessageToCache(t, cid, msgs)
// get data from cache with parallel mode
testParallelGetMessagesBySeq(t, cid, seqs, msgs)
// get data from cache with pipeline mode
testPipeGetMessagesBySeq(t, cid, seqs, msgs)
}
func testParallelGetMessagesBySeq(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs)
assert.Nil(t, err)
assert.Equal(t, 0, len(failedSeqs))
assert.Equal(t, len(respMsgs), len(seqs))
// validate
for idx, msg := range respMsgs {
assert.Equal(t, msg.Seq, inputMsgs[idx].Seq)
assert.Equal(t, msg.SendID, inputMsgs[idx].SendID)
}
}
func testPipeGetMessagesBySeq(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs)
assert.Nil(t, err)
assert.Equal(t, 0, len(failedSeqs))
assert.Equal(t, len(respMsgs), len(seqs))
// validate
for idx, msg := range respMsgs {
assert.Equal(t, msg.Seq, inputMsgs[idx].Seq)
assert.Equal(t, msg.SendID, inputMsgs[idx].SendID)
}
}
func TestGetMessagesBySeqWithEmptySeqs(t *testing.T) {
var (
cid = fmt.Sprintf("cid-%v", rand.Int63())
seqFirst int64 = 0
msgs = []*sdkws.MsgData{}
)
seqs := []int64{}
for i := 0; i < 100; i++ {
msgs = append(msgs, &sdkws.MsgData{
Seq: seqFirst + int64(i),
SendID: fmt.Sprintf("fake-sendid-%v", i),
})
seqs = append(seqs, seqFirst+int64(i))
}
// don't set cache, only get data from cache.
// get data from cache with parallel mode
testParallelGetMessagesBySeqWithEmpty(t, cid, seqs, msgs)
// get data from cache with pipeline mode
testPipeGetMessagesBySeqWithEmpty(t, cid, seqs, msgs)
}
func testParallelGetMessagesBySeqWithEmpty(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs)
assert.Nil(t, err)
assert.Equal(t, len(seqs), len(failedSeqs))
assert.Equal(t, 0, len(respMsgs))
}
func testPipeGetMessagesBySeqWithEmpty(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs)
assert.Equal(t, err, redis.Nil)
assert.Equal(t, len(seqs), len(failedSeqs))
assert.Equal(t, 0, len(respMsgs))
}
func TestGetMessagesBySeqWithLostHalfSeqs(t *testing.T) {
var (
cid = fmt.Sprintf("cid-%v", rand.Int63())
seqFirst int64 = 0
msgs = []*sdkws.MsgData{}
)
seqs := []int64{}
for i := 0; i < 100; i++ {
msgs = append(msgs, &sdkws.MsgData{
Seq: seqFirst + int64(i),
SendID: fmt.Sprintf("fake-sendid-%v", i),
})
seqs = append(seqs, seqFirst+int64(i))
}
// Only set half the number of messages.
testParallelSetMessageToCache(t, cid, msgs[:50])
// get data from cache with parallel mode
testParallelGetMessagesBySeqWithLostHalfSeqs(t, cid, seqs, msgs)
// get data from cache with pipeline mode
testPipeGetMessagesBySeqWithLostHalfSeqs(t, cid, seqs, msgs)
}
func testParallelGetMessagesBySeqWithLostHalfSeqs(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs)
assert.Nil(t, err)
assert.Equal(t, len(seqs)/2, len(failedSeqs))
assert.Equal(t, len(seqs)/2, len(respMsgs))
for idx, msg := range respMsgs {
assert.Equal(t, msg.Seq, seqs[idx])
}
}
func testPipeGetMessagesBySeqWithLostHalfSeqs(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) {
rdb := redis.NewClient(&redis.Options{})
defer rdb.Close()
cacher := msgCache{rdb: rdb}
respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs)
assert.Nil(t, err)
assert.Equal(t, len(seqs)/2, len(failedSeqs))
assert.Equal(t, len(seqs)/2, len(respMsgs))
for idx, msg := range respMsgs {
assert.Equal(t, msg.Seq, seqs[idx])
}
}
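A micro-benchmark contrasting the two read paths could sit alongside these tests. The sketch below is not part of this change: it reuses the cache methods above and, like the tests, assumes a reachable Redis at the go-redis default address.
func BenchmarkGetMessagesBySeq(b *testing.B) {
    rdb := redis.NewClient(&redis.Options{})
    defer rdb.Close()
    cacher := msgCache{rdb: rdb}
    var (
        cid      = fmt.Sprintf("cid-%v", rand.Int63())
        seqFirst = rand.Int63()
        msgs     = []*sdkws.MsgData{}
        seqs     = []int64{}
    )
    for i := 0; i < 100; i++ {
        msgs = append(msgs, &sdkws.MsgData{Seq: seqFirst + int64(i)})
        seqs = append(seqs, seqFirst+int64(i))
    }
    // seed the cache once, then benchmark both read paths against the same keys.
    if _, err := cacher.PipeSetMessageToCache(context.Background(), cid, msgs); err != nil {
        b.Fatal(err)
    }
    b.Run("pipeline", func(b *testing.B) {
        for i := 0; i < b.N; i++ {
            if _, _, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs); err != nil {
                b.Fatal(err)
            }
        }
    })
    b.Run("parallel", func(b *testing.B) {
        for i := 0; i < b.N; i++ {
            if _, _, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs); err != nil {
                b.Fatal(err)
            }
        }
    })
}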