diff --git a/internal/rpc/conversation/conversation.go b/internal/rpc/conversation/conversation.go index a48c06051..016ee45e8 100644 --- a/internal/rpc/conversation/conversation.go +++ b/internal/rpc/conversation/conversation.go @@ -54,10 +54,21 @@ type conversationServer struct { webhookClient *webhook.Client userClient *rpcli.UserClient - msgClient *rpcli.MsgClient + msgClient messageClient groupClient *rpcli.GroupClient } +type messageClient interface { + GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) + GetMsgByConversationIDs(ctx context.Context, conversationIDs []string, maxSeqs map[string]int64) (map[string]*sdkws.MsgData, error) + GetHasReadSeqs(ctx context.Context, conversationIDs []string, userID string) (map[string]int64, error) + GetConversationsFullSyncSeqs(ctx context.Context, req *msg.GetConversationsFullSyncSeqsReq) (*msg.GetConversationsFullSyncSeqsResp, error) + SetUserConversationMaxSeq(ctx context.Context, conversationID string, ownerUserIDs []string, maxSeq int64) error + SetUserConversationMin(ctx context.Context, conversationID string, ownerUserIDs []string, minSeq int64) error + GetLastMessageSeqByTime(ctx context.Context, conversationID string, lastTime int64) (int64, error) + GetLastMessage(ctx context.Context, in *msg.GetLastMessageReq, opts ...grpc.CallOption) (*msg.GetLastMessageResp, error) +} + type Config struct { RpcConfig config.Conversation RedisConfig config.Redis diff --git a/internal/rpc/conversation/sync.go b/internal/rpc/conversation/sync.go index 85128f719..b363fd38a 100644 --- a/internal/rpc/conversation/sync.go +++ b/internal/rpc/conversation/sync.go @@ -2,12 +2,22 @@ package conversation import ( "context" + "time" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" "github.com/openimsdk/protocol/conversation" + pbmsg "github.com/openimsdk/protocol/msg" + "github.com/openimsdk/protocol/sdkws" + "github.com/openimsdk/tools/utils/datautil" +) + +var ( + readInactiveConversationFilterEnabled = true + readInactiveConversationCountThreshold = 5000 + readInactiveConversationDuration = int64((30 * 24 * time.Hour) / time.Millisecond) ) func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, req *conversation.GetFullOwnerConversationIDsReq) (*conversation.GetFullOwnerConversationIDsResp, error) { @@ -22,18 +32,82 @@ func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, re if err != nil { return nil, err } + if shouldExcludeReadInactiveConversations(len(conversationIDs)) { + conversationIDs, err = c.excludeReadInactiveConversations(ctx, req.UserID, conversationIDs, readInactiveConversationDuration) + if err != nil { + return nil, err + } + } + total := int64(len(conversationIDs)) idHash := hashutil.IdHash(conversationIDs) if req.IdHash == idHash { conversationIDs = nil + } else if validPagination(req.GetPagination()) { + conversationIDs = datautil.Paginate( + conversationIDs, + int(req.GetPagination().GetPageNumber()), + int(req.GetPagination().GetShowNumber()), + ) } return &conversation.GetFullOwnerConversationIDsResp{ Version: uint64(vl.Version), VersionID: vl.ID.Hex(), Equal: req.IdHash == idHash, ConversationIDs: conversationIDs, + Total: total, }, nil } +func validPagination(pagination *sdkws.RequestPagination) bool { + return pagination != nil && pagination.GetPageNumber() > 0 && pagination.GetShowNumber() > 0 +} + +func shouldExcludeReadInactiveConversations(conversationCount int) bool { + return readInactiveConversationFilterEnabled && + readInactiveConversationDuration > 0 && + conversationCount > readInactiveConversationCountThreshold +} + +func (c *conversationServer) excludeReadInactiveConversations(ctx context.Context, userID string, conversationIDs []string, inactiveDuration int64) ([]string, error) { + if len(conversationIDs) == 0 { + return nil, nil + } + pinnedConversationIDs, err := c.conversationDatabase.GetPinnedConversationIDs(ctx, userID) + if err != nil { + return nil, err + } + pinned := datautil.SliceSet(pinnedConversationIDs) + seqs, err := c.msgClient.GetConversationsFullSyncSeqs(ctx, &pbmsg.GetConversationsFullSyncSeqsReq{ + UserID: userID, + ConversationIDs: conversationIDs, + }) + if err != nil { + return nil, err + } + expireBefore := time.Now().UnixMilli() - inactiveDuration + filteredConversationIDs := make([]string, 0, len(conversationIDs)) + for _, conversationID := range conversationIDs { + if _, ok := pinned[conversationID]; ok { + filteredConversationIDs = append(filteredConversationIDs, conversationID) + continue + } + seq := seqs.GetSeqs()[conversationID] + if seq == nil || !isReadInactiveConversation(seq, expireBefore) { + filteredConversationIDs = append(filteredConversationIDs, conversationID) + } + } + return filteredConversationIDs, nil +} + +func isReadInactiveConversation(seq *pbmsg.FullSyncSeqs, expireBefore int64) bool { + if seq.GetMaxSeq() == 0 || seq.GetUserMinSeq() > seq.GetMaxSeq() { + return true + } + return seq.GetHasReadSeq() >= seq.GetMaxSeq() && + seq.GetMaxSeqTime() > 0 && + seq.GetMaxSeqTime() < expireBefore +} + func (c *conversationServer) GetIncrementalConversation(ctx context.Context, req *conversation.GetIncrementalConversationReq) (*conversation.GetIncrementalConversationResp, error) { if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err diff --git a/internal/rpc/conversation/sync_test.go b/internal/rpc/conversation/sync_test.go new file mode 100644 index 000000000..7baee65da --- /dev/null +++ b/internal/rpc/conversation/sync_test.go @@ -0,0 +1,693 @@ +package conversation + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" + pbconversation "github.com/openimsdk/protocol/conversation" + pbmsg "github.com/openimsdk/protocol/msg" + "github.com/openimsdk/protocol/sdkws" + "github.com/openimsdk/tools/db/pagination" + "github.com/openimsdk/tools/mcontext" + "go.mongodb.org/mongo-driver/bson/primitive" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +func testCtx(userID string) context.Context { + return mcontext.WithOpUserIDContext(context.Background(), userID) +} + +func TestGetFullOwnerConversationIDsFreshDeviceWithoutPaginationKeepsLegacyFullIDs(t *testing.T) { + withReadInactiveConversationFilterEnabled(t, false) + + const conversationCount = 50000 + ids := newConversationIDs(conversationCount) + + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + }) + if err != nil { + t.Fatal(err) + } + if resp.Equal { + t.Fatal("fresh device idHash=0 should not be equal to the server-side conversation ID hash") + } + if got := len(resp.ConversationIDs); got != conversationCount { + t.Fatalf("expected fresh device sync to return all %d conversation IDs, got %d", conversationCount, got) + } + if resp.Total != conversationCount { + t.Fatalf("expected total %d, got %d", conversationCount, resp.Total) + } +} + +func TestGetFullOwnerConversationIDsFreshDeviceWithPaginationReturnsPage(t *testing.T) { + withReadInactiveConversationFilterEnabled(t, false) + + const conversationCount = 50000 + ids := newConversationIDs(conversationCount) + + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + Pagination: &sdkws.RequestPagination{ + PageNumber: 2, + ShowNumber: 100, + }, + }) + if err != nil { + t.Fatal(err) + } + if resp.Equal { + t.Fatal("fresh device idHash=0 should not be equal to the server-side conversation ID hash") + } + if got := len(resp.ConversationIDs); got != 100 { + t.Fatalf("expected paged sync to return 100 conversation IDs, got %d", got) + } + if resp.Total != conversationCount { + t.Fatalf("expected total %d, got %d", conversationCount, resp.Total) + } + if resp.ConversationIDs[0] != ids[100] { + t.Fatalf("expected first ID on page 2 to be %q, got %q", ids[100], resp.ConversationIDs[0]) + } +} + +func TestGetFullOwnerConversationIDsInvalidPaginationKeepsLegacyFullIDs(t *testing.T) { + withReadInactiveConversationFilterEnabled(t, false) + + const conversationCount = 50000 + ids := newConversationIDs(conversationCount) + + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + Pagination: &sdkws.RequestPagination{ + PageNumber: 0, + ShowNumber: 100, + }, + }) + if err != nil { + t.Fatal(err) + } + if got := len(resp.ConversationIDs); got != conversationCount { + t.Fatalf("expected invalid pagination to keep legacy full IDs, got %d", got) + } +} + +func TestGetFullOwnerConversationIDsMatchingHashReturnsNoIDs(t *testing.T) { + ids := []string{"si_user_1", "si_user_2"} + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: hashutil.IdHash(ids), + }) + if err != nil { + t.Fatal(err) + } + if !resp.Equal { + t.Fatal("matching idHash should be reported as equal") + } + if len(resp.ConversationIDs) != 0 { + t.Fatalf("expected matching idHash to omit conversation IDs, got %d", len(resp.ConversationIDs)) + } + if resp.Total != int64(len(ids)) { + t.Fatalf("expected total %d, got %d", len(ids), resp.Total) + } +} + +func TestGetFullOwnerConversationIDsMatchingHashWithPaginationReturnsNoIDs(t *testing.T) { + ids := []string{"si_user_1", "si_user_2"} + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: hashutil.IdHash(ids), + Pagination: &sdkws.RequestPagination{ + PageNumber: 1, + ShowNumber: 1, + }, + }) + if err != nil { + t.Fatal(err) + } + if !resp.Equal { + t.Fatal("matching idHash should be reported as equal") + } + if len(resp.ConversationIDs) != 0 { + t.Fatalf("expected matching idHash to omit conversation IDs, got %d", len(resp.ConversationIDs)) + } + if resp.Total != int64(len(ids)) { + t.Fatalf("expected total %d, got %d", len(ids), resp.Total) + } +} + +func TestGetFullOwnerConversationIDsReadInactiveKeepsLegacyWhenFlagOff(t *testing.T) { + withReadInactiveConversationFilterEnabled(t, false) + + ids := []string{"si_user_1", "si_user_2", "si_user_3"} + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_2": readInactiveSeq(time.Now().Add(-2 * time.Hour)), + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + }) + if err != nil { + t.Fatal(err) + } + if got := len(resp.ConversationIDs); got != len(ids) { + t.Fatalf("expected legacy sync to return all %d conversation IDs, got %d", len(ids), got) + } + if resp.Total != int64(len(ids)) { + t.Fatalf("expected total %d, got %d", len(ids), resp.Total) + } +} + +func TestGetFullOwnerConversationIDsReadInactiveFiltersReadInactiveConversations(t *testing.T) { + withReadInactiveConversationCountThreshold(t, 0) + withReadInactiveConversationDuration(t, int64(time.Hour/time.Millisecond)) + + ids := []string{"si_user_1", "si_user_2", "si_user_3", "si_user_4", "si_user_5"} + filteredIDs := []string{"si_user_1", "si_user_3"} + now := time.Now() + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_1": unreadSeq(now.Add(-2 * time.Hour)), + "si_user_2": readInactiveSeq(now.Add(-2 * time.Hour)), + "si_user_3": readActiveSeq(now.Add(-30 * time.Minute)), + "si_user_4": readInactiveSeq(now.Add(-3 * time.Hour)), + "si_user_5": clearedSeq(now.Add(-2 * time.Hour)), + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + }) + if err != nil { + t.Fatal(err) + } + if !sameStrings(resp.ConversationIDs, filteredIDs) { + t.Fatalf("expected filtered conversation IDs %v, got %v", filteredIDs, resp.ConversationIDs) + } + if resp.Total != int64(len(filteredIDs)) { + t.Fatalf("expected filtered total %d, got %d", len(filteredIDs), resp.Total) + } +} + +func TestGetFullOwnerConversationIDsReadInactiveFiltersEmptyConversations(t *testing.T) { + withReadInactiveConversationCountThreshold(t, 0) + withReadInactiveConversationDuration(t, int64(time.Hour/time.Millisecond)) + + ids := []string{"si_user_1", "si_user_2", "si_user_3"} + now := time.Now() + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_1": {HasReadSeq: 0, MaxSeq: 0, MaxSeqTime: 0}, + "si_user_2": readInactiveSeq(now.Add(-2 * time.Hour)), + "si_user_3": nil, + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + }) + if err != nil { + t.Fatal(err) + } + if !sameStrings(resp.ConversationIDs, []string{"si_user_3"}) { + t.Fatalf("expected empty conversations to be filtered and missing seqs to be kept, got %v", resp.ConversationIDs) + } + if resp.Total != 1 { + t.Fatalf("expected filtered total 1, got %d", resp.Total) + } +} + +func TestGetFullOwnerConversationIDsReadInactiveMatchingFilteredHashReturnsNoIDs(t *testing.T) { + withReadInactiveConversationCountThreshold(t, 0) + withReadInactiveConversationDuration(t, int64(time.Hour/time.Millisecond)) + + ids := []string{"si_user_1", "si_user_2", "si_user_3"} + filteredIDs := []string{"si_user_1", "si_user_3"} + now := time.Now() + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_1": unreadSeq(now.Add(-2 * time.Hour)), + "si_user_2": readInactiveSeq(now.Add(-2 * time.Hour)), + "si_user_3": readActiveSeq(now.Add(-30 * time.Minute)), + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: hashutil.IdHash(filteredIDs), + }) + if err != nil { + t.Fatal(err) + } + if !resp.Equal { + t.Fatal("matching filtered idHash should be reported as equal") + } + if len(resp.ConversationIDs) != 0 { + t.Fatalf("expected matching filtered idHash to omit conversation IDs, got %d", len(resp.ConversationIDs)) + } + if resp.Total != int64(len(filteredIDs)) { + t.Fatalf("expected filtered total %d, got %d", len(filteredIDs), resp.Total) + } +} + +func TestGetFullOwnerConversationIDsReadInactiveKeepsConversationsWithUnknownMaxSeqTime(t *testing.T) { + withReadInactiveConversationCountThreshold(t, 0) + withReadInactiveConversationDuration(t, int64(time.Hour/time.Millisecond)) + + ids := []string{"si_user_1", "si_user_2"} + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_1": {HasReadSeq: 10, MaxSeq: 10, MaxSeqTime: 0}, + "si_user_2": readInactiveSeq(time.Now().Add(-2 * time.Hour)), + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + }) + if err != nil { + t.Fatal(err) + } + if !sameStrings(resp.ConversationIDs, []string{"si_user_1"}) { + t.Fatalf("expected unknown maxSeqTime conversation to be kept, got %v", resp.ConversationIDs) + } +} + +func TestGetFullOwnerConversationIDsReadInactiveKeepsPinnedConversations(t *testing.T) { + withReadInactiveConversationCountThreshold(t, 0) + withReadInactiveConversationDuration(t, int64(time.Hour/time.Millisecond)) + + ids := []string{"si_user_1", "si_user_2", "si_user_3"} + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{ + conversationIDs: ids, + pinnedConversationIDs: []string{"si_user_1", "si_user_2"}, + }, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_1": {HasReadSeq: 0, MaxSeq: 0, MaxSeqTime: 0}, + "si_user_2": readInactiveSeq(time.Now().Add(-2 * time.Hour)), + "si_user_3": readInactiveSeq(time.Now().Add(-2 * time.Hour)), + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + }) + if err != nil { + t.Fatal(err) + } + if !sameStrings(resp.ConversationIDs, []string{"si_user_1", "si_user_2"}) { + t.Fatalf("expected pinned conversations to be kept, got %v", resp.ConversationIDs) + } +} + +func TestGetFullOwnerConversationIDsReadInactivePaginatesFilteredIDs(t *testing.T) { + withReadInactiveConversationCountThreshold(t, 0) + withReadInactiveConversationDuration(t, int64(time.Hour/time.Millisecond)) + + ids := []string{"si_user_1", "si_user_2", "si_user_3", "si_user_4", "si_user_5"} + filteredIDs := []string{"si_user_1", "si_user_3", "si_user_5"} + now := time.Now() + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_1": unreadSeq(now.Add(-2 * time.Hour)), + "si_user_2": readInactiveSeq(now.Add(-2 * time.Hour)), + "si_user_3": readActiveSeq(now.Add(-30 * time.Minute)), + "si_user_4": readInactiveSeq(now.Add(-3 * time.Hour)), + "si_user_5": unreadSeq(now.Add(-3 * time.Hour)), + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + Pagination: &sdkws.RequestPagination{ + PageNumber: 2, + ShowNumber: 1, + }, + }) + if err != nil { + t.Fatal(err) + } + if !sameStrings(resp.ConversationIDs, []string{filteredIDs[1]}) { + t.Fatalf("expected second filtered page %v, got %v", []string{filteredIDs[1]}, resp.ConversationIDs) + } + if resp.Total != int64(len(filteredIDs)) { + t.Fatalf("expected filtered total %d, got %d", len(filteredIDs), resp.Total) + } +} + +func BenchmarkGetFullOwnerConversationIDsLegacyFullIDs(b *testing.B) { + readInactiveConversationFilterEnabled = false + b.Cleanup(func() { + readInactiveConversationFilterEnabled = true + }) + + ids := newConversationIDs(50000) + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + } + req := &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), req) + if err != nil { + b.Fatal(err) + } + b.ReportMetric(float64(len(resp.ConversationIDs)), "ids/op") + b.ReportMetric(float64(proto.Size(resp)), "proto_bytes/op") + } +} + +func BenchmarkGetFullOwnerConversationIDsReadInactiveFilteredIDs(b *testing.B) { + readInactiveConversationDuration = int64(time.Hour / time.Millisecond) + b.Cleanup(func() { + readInactiveConversationDuration = int64((30 * 24 * time.Hour) / time.Millisecond) + }) + + ids := newConversationIDs(50000) + seqs := make(map[string]*pbmsg.FullSyncSeqs, len(ids)) + now := time.Now() + for i, conversationID := range ids { + if i%10 == 0 { + seqs[conversationID] = unreadSeq(now.Add(-2 * time.Hour)) + } else { + seqs[conversationID] = readInactiveSeq(now.Add(-2 * time.Hour)) + } + } + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: seqs}, + } + req := &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), req) + if err != nil { + b.Fatal(err) + } + b.ReportMetric(float64(len(resp.ConversationIDs)), "ids/op") + b.ReportMetric(float64(proto.Size(resp)), "proto_bytes/op") + } +} + +func TestGetFullOwnerConversationIDsReadInactiveSkipsSmallConversationSetsByDefault(t *testing.T) { + ids := []string{"si_user_1", "si_user_2", "si_user_3"} + srv := &conversationServer{ + conversationDatabase: &fakeConversationDatabase{conversationIDs: ids}, + msgClient: fakeMessageClient{seqs: map[string]*pbmsg.FullSyncSeqs{ + "si_user_1": readInactiveSeq(time.Now().Add(-2 * time.Hour)), + "si_user_2": readInactiveSeq(time.Now().Add(-2 * time.Hour)), + "si_user_3": readInactiveSeq(time.Now().Add(-2 * time.Hour)), + }}, + } + + resp, err := srv.GetFullOwnerConversationIDs(testCtx("customer-service"), &pbconversation.GetFullOwnerConversationIDsReq{ + UserID: "customer-service", + IdHash: 0, + }) + if err != nil { + t.Fatal(err) + } + if !sameStrings(resp.ConversationIDs, ids) { + t.Fatalf("expected small conversation sets to keep legacy IDs, got %v", resp.ConversationIDs) + } +} + +func readInactiveSeq(maxSeqTime time.Time) *pbmsg.FullSyncSeqs { + return &pbmsg.FullSyncSeqs{ + HasReadSeq: 10, + MaxSeq: 10, + MaxSeqTime: maxSeqTime.UnixMilli(), + } +} + +func readActiveSeq(maxSeqTime time.Time) *pbmsg.FullSyncSeqs { + return &pbmsg.FullSyncSeqs{ + HasReadSeq: 10, + MaxSeq: 10, + MaxSeqTime: maxSeqTime.UnixMilli(), + } +} + +func clearedSeq(maxSeqTime time.Time) *pbmsg.FullSyncSeqs { + return &pbmsg.FullSyncSeqs{ + HasReadSeq: 10, + MaxSeq: 10, + MaxSeqTime: maxSeqTime.UnixMilli(), + UserMinSeq: 11, + } +} + +func unreadSeq(maxSeqTime time.Time) *pbmsg.FullSyncSeqs { + return &pbmsg.FullSyncSeqs{ + HasReadSeq: 9, + MaxSeq: 10, + MaxSeqTime: maxSeqTime.UnixMilli(), + } +} + +func newConversationIDs(count int) []string { + ids := make([]string, 0, count) + for i := 0; i < count; i++ { + ids = append(ids, fmt.Sprintf("si_user_%05d", i)) + } + return ids +} + +func sameStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func withReadInactiveConversationCountThreshold(t *testing.T, threshold int) { + t.Helper() + old := readInactiveConversationCountThreshold + readInactiveConversationCountThreshold = threshold + t.Cleanup(func() { + readInactiveConversationCountThreshold = old + }) +} + +func withReadInactiveConversationFilterEnabled(t *testing.T, enabled bool) { + t.Helper() + old := readInactiveConversationFilterEnabled + readInactiveConversationFilterEnabled = enabled + t.Cleanup(func() { + readInactiveConversationFilterEnabled = old + }) +} + +func withReadInactiveConversationDuration(t *testing.T, duration int64) { + t.Helper() + old := readInactiveConversationDuration + readInactiveConversationDuration = duration + t.Cleanup(func() { + readInactiveConversationDuration = old + }) +} + +type fakeConversationDatabase struct { + conversationIDs []string + pinnedConversationIDs []string +} + +func (f *fakeConversationDatabase) UpdateUsersConversationField(context.Context, []string, string, map[string]any) error { + return nil +} + +func (f *fakeConversationDatabase) CreateConversation(context.Context, []*model.Conversation) error { + return nil +} + +func (f *fakeConversationDatabase) SyncPeerUserPrivateConversationTx(context.Context, []*model.Conversation) error { + return nil +} + +func (f *fakeConversationDatabase) FindConversations(_ context.Context, _ string, conversationIDs []string) ([]*model.Conversation, error) { + conversations := make([]*model.Conversation, 0, len(conversationIDs)) + for _, conversationID := range conversationIDs { + conversations = append(conversations, &model.Conversation{ConversationID: conversationID}) + } + return conversations, nil +} + +func (f *fakeConversationDatabase) GetUserAllConversation(context.Context, string) ([]*model.Conversation, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) SetUserConversations(context.Context, string, []*model.Conversation) error { + return nil +} + +func (f *fakeConversationDatabase) SetUsersConversationFieldTx(context.Context, []string, *model.Conversation, map[string]any) error { + return nil +} + +func (f *fakeConversationDatabase) UpdateUserConversations(context.Context, string, map[string]any) error { + return nil +} + +func (f *fakeConversationDatabase) CreateGroupChatConversation(context.Context, string, []string, *model.Conversation) error { + return nil +} + +func (f *fakeConversationDatabase) GetConversationIDs(context.Context, string) ([]string, error) { + return f.conversationIDs, nil +} + +func (f *fakeConversationDatabase) GetUserConversationIDsHash(context.Context, string) (uint64, error) { + return hashutil.IdHash(f.conversationIDs), nil +} + +func (f *fakeConversationDatabase) GetAllConversationIDs(context.Context) ([]string, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) GetAllConversationIDsNumber(context.Context) (int64, error) { + return 0, nil +} + +func (f *fakeConversationDatabase) PageConversationIDs(context.Context, pagination.Pagination) ([]string, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) GetConversationsByConversationID(context.Context, []string) ([]*model.Conversation, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) GetConversationIDsNeedDestruct(context.Context) ([]*model.Conversation, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) GetConversationNotReceiveMessageUserIDs(context.Context, string) ([]string, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) FindConversationUserVersion(context.Context, string, uint, int) (*model.VersionLog, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) FindMaxConversationUserVersionCache(context.Context, string) (*model.VersionLog, error) { + return &model.VersionLog{ID: primitive.NewObjectID(), LastUpdate: time.Now()}, nil +} + +func (f *fakeConversationDatabase) GetOwnerConversation(context.Context, string, pagination.Pagination) (int64, []*model.Conversation, error) { + return int64(len(f.conversationIDs)), nil, nil +} + +func (f *fakeConversationDatabase) GetNotNotifyConversationIDs(context.Context, string) ([]string, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) GetPinnedConversationIDs(context.Context, string) ([]string, error) { + return f.pinnedConversationIDs, nil +} + +func (f *fakeConversationDatabase) FindRandConversation(context.Context, int64, int) ([]*model.Conversation, error) { + return nil, nil +} + +func (f *fakeConversationDatabase) DeleteUsersConversations(context.Context, string, []string) error { + return nil +} + +var _ controller.ConversationDatabase = (*fakeConversationDatabase)(nil) + +type fakeMessageClient struct { + seqs map[string]*pbmsg.FullSyncSeqs +} + +func (f fakeMessageClient) GetMaxSeqs(context.Context, []string) (map[string]int64, error) { + return nil, nil +} + +func (f fakeMessageClient) GetMsgByConversationIDs(context.Context, []string, map[string]int64) (map[string]*sdkws.MsgData, error) { + return nil, nil +} + +func (f fakeMessageClient) GetHasReadSeqs(context.Context, []string, string) (map[string]int64, error) { + return nil, nil +} + +func (f fakeMessageClient) GetConversationsFullSyncSeqs(_ context.Context, req *pbmsg.GetConversationsFullSyncSeqsReq) (*pbmsg.GetConversationsFullSyncSeqsResp, error) { + seqs := make(map[string]*pbmsg.FullSyncSeqs, len(req.ConversationIDs)) + for _, conversationID := range req.ConversationIDs { + if seq, ok := f.seqs[conversationID]; ok { + seqs[conversationID] = seq + } + } + return &pbmsg.GetConversationsFullSyncSeqsResp{Seqs: seqs}, nil +} + +func (f fakeMessageClient) SetUserConversationMaxSeq(context.Context, string, []string, int64) error { + return nil +} + +func (f fakeMessageClient) SetUserConversationMin(context.Context, string, []string, int64) error { + return nil +} + +func (f fakeMessageClient) GetLastMessageSeqByTime(context.Context, string, int64) (int64, error) { + return 0, nil +} + +func (f fakeMessageClient) GetLastMessage(context.Context, *pbmsg.GetLastMessageReq, ...grpc.CallOption) (*pbmsg.GetLastMessageResp, error) { + return &pbmsg.GetLastMessageResp{}, nil +} diff --git a/internal/rpc/msg/as_read.go b/internal/rpc/msg/as_read.go index c52ce9c07..300e38142 100644 --- a/internal/rpc/msg/as_read.go +++ b/internal/rpc/msg/as_read.go @@ -85,6 +85,59 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m return resp, nil } +func (m *msgServer) GetConversationsFullSyncSeqs(ctx context.Context, req *msg.GetConversationsFullSyncSeqsReq) (*msg.GetConversationsFullSyncSeqsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } + var conversationIDs []string + if len(req.ConversationIDs) == 0 { + var err error + conversationIDs, err = m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) + if err != nil { + return nil, err + } + } else { + conversationIDs = req.ConversationIDs + } + + hasReadSeqs, err := m.MsgDatabase.GetHasReadSeqs(ctx, req.UserID, conversationIDs) + if err != nil { + return nil, err + } + userMinSeqs, err := m.MsgDatabase.GetUserConversationsMinSeqs(ctx, req.UserID, conversationIDs) + if err != nil { + return nil, err + } + conversations, err := m.ConversationLocalCache.GetConversations(ctx, req.UserID, conversationIDs) + if err != nil { + return nil, err + } + + conversationMaxSeqMap := make(map[string]int64) + for _, conversation := range conversations { + if conversation.MaxSeq != 0 { + conversationMaxSeqMap[conversation.ConversationID] = conversation.MaxSeq + } + } + maxSeqs, err := m.MsgDatabase.GetMaxSeqsWithTime(ctx, conversationIDs) + if err != nil { + return nil, err + } + resp := &msg.GetConversationsFullSyncSeqsResp{Seqs: make(map[string]*msg.FullSyncSeqs)} + for conversationID, maxSeq := range maxSeqs { + resp.Seqs[conversationID] = &msg.FullSyncSeqs{ + HasReadSeq: hasReadSeqs[conversationID], + MaxSeq: maxSeq.Seq, + MaxSeqTime: maxSeq.Time, + UserMinSeq: userMinSeqs[conversationID], + } + if v, ok := conversationMaxSeqMap[conversationID]; ok { + resp.Seqs[conversationID].MaxSeq = v + } + } + return resp, nil +} + func (m *msgServer) SetConversationHasReadSeq(ctx context.Context, req *msg.SetConversationHasReadSeqReq) (*msg.SetConversationHasReadSeqResp, error) { if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err diff --git a/pkg/common/storage/cache/redis/seq_user.go b/pkg/common/storage/cache/redis/seq_user.go index af9cbef5a..fca5cba33 100644 --- a/pkg/common/storage/cache/redis/seq_user.go +++ b/pkg/common/storage/cache/redis/seq_user.go @@ -61,6 +61,32 @@ func (s *seqUserCacheRedis) GetUserMinSeq(ctx context.Context, conversationID st }) } +func (s *seqUserCacheRedis) GetUserMinSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) { + res, err := batchGetCache2(ctx, s.rocks, s.expireTime, conversationIDs, func(conversationID string) string { + return s.getSeqUserMinSeqKey(conversationID, userID) + }, func(v *readSeqModel) string { + return v.ConversationID + }, func(ctx context.Context, conversationIDs []string) ([]*readSeqModel, error) { + seqs, err := s.mgo.GetUserMinSeqs(ctx, userID, conversationIDs) + if err != nil { + return nil, err + } + res := make([]*readSeqModel, 0, len(seqs)) + for conversationID, seq := range seqs { + res = append(res, &readSeqModel{ConversationID: conversationID, Seq: seq}) + } + return res, nil + }) + if err != nil { + return nil, err + } + data := make(map[string]int64) + for _, v := range res { + data[v.ConversationID] = v.Seq + } + return data, nil +} + func (s *seqUserCacheRedis) SetUserMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error { return s.SetUserMinSeqs(ctx, userID, map[string]int64{conversationID: seq}) } diff --git a/pkg/common/storage/cache/seq_user.go b/pkg/common/storage/cache/seq_user.go index cef414e16..ee54fe719 100644 --- a/pkg/common/storage/cache/seq_user.go +++ b/pkg/common/storage/cache/seq_user.go @@ -11,6 +11,7 @@ type SeqUser interface { SetUserReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error SetUserReadSeqToDB(ctx context.Context, conversationID string, userID string, seq int64) error SetUserMinSeqs(ctx context.Context, userID string, seqs map[string]int64) error + GetUserMinSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) SetUserReadSeqs(ctx context.Context, userID string, seqs map[string]int64) error GetUserReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) } diff --git a/pkg/common/storage/controller/msg.go b/pkg/common/storage/controller/msg.go index f833008e8..e0b8a3d21 100644 --- a/pkg/common/storage/controller/msg.go +++ b/pkg/common/storage/controller/msg.go @@ -71,6 +71,7 @@ type CommonMsgDatabase interface { SetMinSeq(ctx context.Context, conversationID string, seq int64) error SetUserConversationsMinSeqs(ctx context.Context, userID string, seqs map[string]int64) (err error) + GetUserConversationsMinSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) SetHasReadSeq(ctx context.Context, userID string, conversationID string, hasReadSeq int64) error GetHasReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) GetHasReadSeq(ctx context.Context, userID string, conversationID string) (int64, error) @@ -573,6 +574,10 @@ func (db *commonMsgDatabase) SetUserConversationsMinSeqs(ctx context.Context, us return db.seqUser.SetUserMinSeqs(ctx, userID, seqs) } +func (db *commonMsgDatabase) GetUserConversationsMinSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) { + return db.seqUser.GetUserMinSeqs(ctx, userID, conversationIDs) +} + func (db *commonMsgDatabase) SetUserConversationsMaxSeq(ctx context.Context, conversationID string, userID string, seq int64) error { return db.seqUser.SetUserMaxSeq(ctx, conversationID, userID, seq) } diff --git a/pkg/common/storage/database/mgo/seq_user.go b/pkg/common/storage/database/mgo/seq_user.go index 244de3000..125d115ad 100644 --- a/pkg/common/storage/database/mgo/seq_user.go +++ b/pkg/common/storage/database/mgo/seq_user.go @@ -80,6 +80,24 @@ func (s *seqUserMongo) GetUserMinSeq(ctx context.Context, conversationID string, return s.getSeq(ctx, conversationID, userID, "min_seq") } +func (s *seqUserMongo) GetUserMinSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) { + if len(conversationIDs) == 0 { + return map[string]int64{}, nil + } + filter := bson.M{"user_id": userID, "conversation_id": bson.M{"$in": conversationIDs}} + opt := options.Find().SetProjection(bson.M{"_id": 0, "conversation_id": 1, "min_seq": 1}) + seqs, err := mongoutil.Find[*model.SeqUser](ctx, s.coll, filter, opt) + if err != nil { + return nil, err + } + res := make(map[string]int64) + for _, seq := range seqs { + res[seq.ConversationID] = seq.MinSeq + } + s.notFoundSet0(res, conversationIDs) + return res, nil +} + func (s *seqUserMongo) SetUserMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error { return s.setSeq(ctx, conversationID, userID, seq, "min_seq") } diff --git a/pkg/common/storage/database/seq_user.go b/pkg/common/storage/database/seq_user.go index 9f75c710b..7ba4e1e3b 100644 --- a/pkg/common/storage/database/seq_user.go +++ b/pkg/common/storage/database/seq_user.go @@ -9,5 +9,6 @@ type SeqUser interface { SetUserMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error GetUserReadSeq(ctx context.Context, conversationID string, userID string) (int64, error) SetUserReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error + GetUserMinSeqs(ctx context.Context, userID string, conversationID []string) (map[string]int64, error) GetUserReadSeqs(ctx context.Context, userID string, conversationID []string) (map[string]int64, error) } diff --git a/pkg/rpcli/msg.go b/pkg/rpcli/msg.go index e4d1ece6e..6b36327c7 100644 --- a/pkg/rpcli/msg.go +++ b/pkg/rpcli/msg.go @@ -41,6 +41,10 @@ func (x *MsgClient) GetHasReadSeqs(ctx context.Context, conversationIDs []string return extractField(ctx, x.MsgClient.GetHasReadSeqs, req, (*msg.SeqsInfoResp).GetMaxSeqs) } +func (x *MsgClient) GetConversationsFullSyncSeqs(ctx context.Context, req *msg.GetConversationsFullSyncSeqsReq) (*msg.GetConversationsFullSyncSeqsResp, error) { + return x.MsgClient.GetConversationsFullSyncSeqs(ctx, req) +} + func (x *MsgClient) SetUserConversationMaxSeq(ctx context.Context, conversationID string, ownerUserIDs []string, maxSeq int64) error { if len(ownerUserIDs) == 0 { return nil