From bc47a6a8b6d9e3f03b1da9e37cf0cfec718e290c Mon Sep 17 00:00:00 2001 From: Monet Lee Date: Tue, 13 Aug 2024 17:01:32 +0800 Subject: [PATCH] update get docIDs logic. --- pkg/common/storage/controller/msg.go | 22 +++++++++++----------- pkg/common/storage/database/mgo/msg.go | 9 ++++++--- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pkg/common/storage/controller/msg.go b/pkg/common/storage/controller/msg.go index 19b082f2a..ba98f8f24 100644 --- a/pkg/common/storage/controller/msg.go +++ b/pkg/common/storage/controller/msg.go @@ -918,11 +918,20 @@ func (db *commonMsgDatabase) ConvertMsgsDocLen(ctx context.Context, conversation func (db *commonMsgDatabase) GetBeforeMsg(ctx context.Context, ts int64, docIDs []string, limit int) ([]*model.MsgDocModel, error) { var msgs []*model.MsgDocModel for i := 0; i < len(docIDs); i += 1000 { - res, err := db.msgDocDatabase.GetBeforeMsg(ctx, ts, docIDs[i:i+1000], limit) + end := i + 1000 + if end > len(docIDs) { + end = len(docIDs) + } + + res, err := db.msgDocDatabase.GetBeforeMsg(ctx, ts, docIDs[i:end], limit) if err != nil { return nil, err } msgs = append(msgs, res...) + + if len(msgs) >= limit { + return msgs[:limit], nil + } } return msgs, nil } @@ -968,14 +977,5 @@ func (db *commonMsgDatabase) setMinSeq(ctx context.Context, conversationID strin } func (db *commonMsgDatabase) GetDocIDs(ctx context.Context) ([]string, error) { - var docIDsList []string - - docIDs, err := db.msgDocDatabase.GetDocIDs(ctx) - if err != nil { - return nil, errs.Wrap(err) - } - - docIDsList = append(docIDsList, docIDs...) - - return docIDsList, nil + return db.msgDocDatabase.GetDocIDs(ctx) } diff --git a/pkg/common/storage/database/mgo/msg.go b/pkg/common/storage/database/mgo/msg.go index ad77c5b7d..3d3ab46e8 100644 --- a/pkg/common/storage/database/mgo/msg.go +++ b/pkg/common/storage/database/mgo/msg.go @@ -1238,10 +1238,10 @@ func (m *MsgMgo) GetDocIDs(ctx context.Context) ([]string, error) { } if count < int64(limit) { - skip = int(count) + skip = 0 } else { rand.Seed(uint64(time.Now().UnixMilli())) - skip = rand.Intn(int(count - int64(limit))) + skip = rand.Intn(int(count)) } res, err := mongoutil.Aggregate[*model.MsgDocModel](ctx, m.coll, []bson.M{ @@ -1251,7 +1251,10 @@ func (m *MsgMgo) GetDocIDs(ctx context.Context) ([]string, error) { }, }, { - "$limit": skip, + "$skip": skip, + }, + { + "$limit": limit, }, })