From 326dc3836cd969bceddc1d56aeebb607aa5f2204 Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Thu, 11 Jul 2024 18:47:25 +0800 Subject: [PATCH] rockscache batch get --- pkg/common/storage/cache/group.go | 1 - pkg/common/storage/cache/redis/batch.go | 87 +++++++++++++++++ .../storage/cache/redis/batch_handler.go | 55 ++++++----- pkg/common/storage/cache/redis/batch_test.go | 79 +++++++++++++++ .../storage/cache/redis/conversation.go | 8 +- pkg/common/storage/cache/redis/friend.go | 25 ----- pkg/common/storage/cache/redis/group.go | 90 +++-------------- pkg/common/storage/cache/redis/msg.go | 1 - pkg/common/storage/cache/redis/user.go | 96 +------------------ pkg/common/storage/controller/user.go | 2 +- pkg/common/storage/database/group_member.go | 2 + .../storage/database/mgo/group_member.go | 16 ++++ 12 files changed, 236 insertions(+), 226 deletions(-) create mode 100644 pkg/common/storage/cache/redis/batch.go create mode 100644 pkg/common/storage/cache/redis/batch_test.go diff --git a/pkg/common/storage/cache/group.go b/pkg/common/storage/cache/group.go index 73479bb1b..91953d9f9 100644 --- a/pkg/common/storage/cache/group.go +++ b/pkg/common/storage/cache/group.go @@ -36,7 +36,6 @@ type GroupCache interface { DelGroupMembersHash(groupID string) GroupCache GetGroupMemberIDs(ctx context.Context, groupID string) (groupMemberIDs []string, err error) - GetGroupsMemberIDs(ctx context.Context, groupIDs []string) (groupMemberIDs map[string][]string, err error) DelGroupMemberIDs(groupID string) GroupCache diff --git a/pkg/common/storage/cache/redis/batch.go b/pkg/common/storage/cache/redis/batch.go new file mode 100644 index 000000000..8434474cb --- /dev/null +++ b/pkg/common/storage/cache/redis/batch.go @@ -0,0 +1,87 @@ +package redis + +import ( + "context" + "encoding/json" + "github.com/dtm-labs/rockscache" + "github.com/redis/go-redis/v9" + "golang.org/x/sync/singleflight" + "time" + "unsafe" +) + +func getRocksCacheRedisClient(cli *rockscache.Client) redis.UniversalClient { + type Client struct { + rdb redis.UniversalClient + _ rockscache.Options + _ singleflight.Group + } + return (*Client)(unsafe.Pointer(cli)).rdb +} + +func batchGetCache2[K comparable, V any](ctx context.Context, rcClient *rockscache.Client, expire time.Duration, ids []K, idKey func(id K) string, vId func(v V) K, fn func(ctx context.Context, ids []K) ([]V, error)) ([]V, error) { + if len(ids) == 0 { + return nil, nil + } + findKeys := make([]string, 0, len(ids)) + keyId := make(map[string]K) + for _, id := range ids { + key := idKey(id) + if _, ok := keyId[key]; ok { + continue + } + keyId[key] = id + findKeys = append(findKeys, key) + } + slotKeys, err := groupKeysBySlot(ctx, getRocksCacheRedisClient(rcClient), findKeys) + if err != nil { + return nil, err + } + result := make([]V, 0, len(findKeys)) + for _, keys := range slotKeys { + indexCache, err := rcClient.FetchBatch2(ctx, keys, expire, func(idx []int) (map[int]string, error) { + queryIds := make([]K, 0, len(idx)) + idIndex := make(map[K]int) + for _, index := range idx { + id := keyId[keys[index]] + idIndex[id] = index + queryIds = append(queryIds, id) + } + values, err := fn(ctx, queryIds) + if err != nil { + return nil, err + } + if len(values) == 0 { + return map[int]string{}, nil + } + cacheIndex := make(map[int]string) + for _, value := range values { + id := vId(value) + index, ok := idIndex[id] + if !ok { + continue + } + bs, err := json.Marshal(value) + if err != nil { + return nil, err + } + cacheIndex[index] = string(bs) + } + return cacheIndex, nil + }) + if err != nil { + return nil, err + } + for _, data := range indexCache { + if data == "" { + continue + } + var value V + if err := json.Unmarshal([]byte(data), &value); err != nil { + return nil, err + } + result = append(result, value) + } + } + return result, nil +} diff --git a/pkg/common/storage/cache/redis/batch_handler.go b/pkg/common/storage/cache/redis/batch_handler.go index 95f669904..52e046a40 100644 --- a/pkg/common/storage/cache/redis/batch_handler.go +++ b/pkg/common/storage/cache/redis/batch_handler.go @@ -23,7 +23,6 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/localcache" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" - "github.com/openimsdk/tools/mw/specialerror" "github.com/openimsdk/tools/utils/datautil" "github.com/redis/go-redis/v9" "time" @@ -147,30 +146,30 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin return t, nil } -func batchGetCache[T any, K comparable]( - ctx context.Context, - rcClient *rockscache.Client, - expire time.Duration, - keys []K, - keyFn func(key K) string, - fns func(ctx context.Context, key K) (T, error), -) ([]T, error) { - if len(keys) == 0 { - return nil, nil - } - res := make([]T, 0, len(keys)) - for _, key := range keys { - val, err := getCache(ctx, rcClient, keyFn(key), expire, func(ctx context.Context) (T, error) { - return fns(ctx, key) - }) - if err != nil { - if errs.ErrRecordNotFound.Is(specialerror.ErrCode(errs.Unwrap(err))) { - continue - } - return nil, errs.Wrap(err) - } - res = append(res, val) - } - - return res, nil -} +//func batchGetCache[T any, K comparable]( +// ctx context.Context, +// rcClient *rockscache.Client, +// expire time.Duration, +// keys []K, +// keyFn func(key K) string, +// fns func(ctx context.Context, key K) (T, error), +//) ([]T, error) { +// if len(keys) == 0 { +// return nil, nil +// } +// res := make([]T, 0, len(keys)) +// for _, key := range keys { +// val, err := getCache(ctx, rcClient, keyFn(key), expire, func(ctx context.Context) (T, error) { +// return fns(ctx, key) +// }) +// if err != nil { +// if errs.ErrRecordNotFound.Is(specialerror.ErrCode(errs.Unwrap(err))) { +// continue +// } +// return nil, errs.Wrap(err) +// } +// res = append(res, val) +// } +// +// return res, nil +//} diff --git a/pkg/common/storage/cache/redis/batch_test.go b/pkg/common/storage/cache/redis/batch_test.go new file mode 100644 index 000000000..b9174abbf --- /dev/null +++ b/pkg/common/storage/cache/redis/batch_test.go @@ -0,0 +1,79 @@ +package redis + +import ( + "context" + "github.com/dtm-labs/rockscache" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/db/mongoutil" + "github.com/openimsdk/tools/db/redisutil" + "testing" + "time" +) + +func TestName(t *testing.T) { + //var rocks rockscache.Client + //rdb := getRocksCacheRedisClient(&rocks) + //t.Log(rdb == nil) + + ctx := context.Background() + rdb, err := redisutil.NewRedisClient(ctx, (&config.Redis{ + Address: []string{"172.16.8.48:16379"}, + Password: "openIM123", + DB: 3, + }).Build()) + if err != nil { + panic(err) + } + mgocli, err := mongoutil.NewMongoDB(ctx, (&config.Mongo{ + Address: []string{"172.16.8.48:37017"}, + Database: "openim_v3", + Username: "openIM", + Password: "openIM123", + MaxPoolSize: 100, + MaxRetry: 1, + }).Build()) + if err != nil { + panic(err) + } + userMgo, err := mgo.NewUserMongo(mgocli.GetDB()) + if err != nil { + panic(err) + } + rock := rockscache.NewClient(rdb, rockscache.NewDefaultOptions()) + //var keys []string + //for i := 1; i <= 10; i++ { + // keys = append(keys, fmt.Sprintf("test%d", i)) + //} + //res, err := cli.FetchBatch2(ctx, keys, time.Hour, func(idx []int) (map[int]string, error) { + // t.Log("FetchBatch2=>", idx) + // time.Sleep(time.Second * 1) + // res := make(map[int]string) + // for _, i := range idx { + // res[i] = fmt.Sprintf("hello_%d", i) + // } + // t.Log("FetchBatch2=>", res) + // return res, nil + //}) + //if err != nil { + // t.Log(err) + // return + //} + //t.Log(res) + + userIDs := []string{"1814217053", "2110910952", "1234567890"} + + res, err := batchGetCache2(ctx, rock, time.Hour, userIDs, func(id string) string { + return "TEST_USER:" + id + }, func(v *model.User) string { + return v.UserID + }, func(ctx context.Context, ids []string) ([]*model.User, error) { + t.Log("find mongo", ids) + return userMgo.Find(ctx, ids) + }) + if err != nil { + panic(err) + } + t.Log("==>", res) +} diff --git a/pkg/common/storage/cache/redis/conversation.go b/pkg/common/storage/cache/redis/conversation.go index c491d1b94..95e680afb 100644 --- a/pkg/common/storage/cache/redis/conversation.go +++ b/pkg/common/storage/cache/redis/conversation.go @@ -164,10 +164,12 @@ func (c *ConversationRedisCache) DelConversations(ownerUserID string, conversati } func (c *ConversationRedisCache) GetConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*model.Conversation, error) { - return batchGetCache(ctx, c.rcClient, c.expireTime, conversationIDs, func(conversationID string) string { + return batchGetCache2(ctx, c.rcClient, c.expireTime, conversationIDs, func(conversationID string) string { return c.getConversationKey(ownerUserID, conversationID) - }, func(ctx context.Context, conversationID string) (*model.Conversation, error) { - return c.conversationDB.Take(ctx, ownerUserID, conversationID) + }, func(conversation *model.Conversation) string { + return conversation.ConversationID + }, func(ctx context.Context, conversationIDs []string) ([]*model.Conversation, error) { + return c.conversationDB.Find(ctx, ownerUserID, conversationIDs) }) } diff --git a/pkg/common/storage/cache/redis/friend.go b/pkg/common/storage/cache/redis/friend.go index 01988310c..be4687794 100644 --- a/pkg/common/storage/cache/redis/friend.go +++ b/pkg/common/storage/cache/redis/friend.go @@ -70,10 +70,6 @@ func (f *FriendCacheRedis) getFriendIDsKey(ownerUserID string) string { return cachekey.GetFriendIDsKey(ownerUserID) } -//func (f *FriendCacheRedis) getFriendSyncSortUserIDsKey(ownerUserID string) string { -// return cachekey.GetFriendSyncSortUserIDsKey(ownerUserID, f.syncCount) -//} - func (f *FriendCacheRedis) getFriendMaxVersionKey(ownerUserID string) string { return cachekey.GetFriendMaxVersionKey(ownerUserID) } @@ -107,16 +103,6 @@ func (f *FriendCacheRedis) DelFriendIDs(ownerUserIDs ...string) cache.FriendCach return newFriendCache } -//func (f *FriendCacheRedis) DelSortFriendUserIDs(ownerUserIDs ...string) cache.FriendCache { -// newGroupCache := f.CloneFriendCache() -// keys := make([]string, 0, len(ownerUserIDs)) -// for _, userID := range ownerUserIDs { -// keys = append(keys, f.getFriendSyncSortUserIDsKey(userID)) -// } -// newGroupCache.AddKeys(keys...) -// return newGroupCache -//} - // GetTwoWayFriendIDs retrieves two-way friend IDs from the cache. func (f *FriendCacheRedis) GetTwoWayFriendIDs(ctx context.Context, ownerUserID string) (twoWayFriendIDs []string, err error) { friendIDs, err := f.GetFriendIDs(ctx, ownerUserID) @@ -193,17 +179,6 @@ func (f *FriendCacheRedis) DelMaxFriendVersion(ownerUserIDs ...string) cache.Fri return newFriendCache } -//func (f *FriendCacheRedis) FindSortFriendUserIDs(ctx context.Context, ownerUserID string) ([]string, error) { -// userIDs, err := f.GetFriendIDs(ctx, ownerUserID) -// if err != nil { -// return nil, err -// } -// if len(userIDs) > f.syncCount { -// userIDs = userIDs[:f.syncCount] -// } -// return userIDs, nil -//} - func (f *FriendCacheRedis) FindMaxFriendVersion(ctx context.Context, ownerUserID string) (*model.VersionLog, error) { return getCache(ctx, f.rcClient, f.getFriendMaxVersionKey(ownerUserID), f.expireTime, func(ctx context.Context) (*model.VersionLog, error) { return f.friendDB.FindIncrVersion(ctx, ownerUserID, 0, 0) diff --git a/pkg/common/storage/cache/redis/group.go b/pkg/common/storage/cache/redis/group.go index 589678c50..d327c218f 100644 --- a/pkg/common/storage/cache/redis/group.go +++ b/pkg/common/storage/cache/redis/group.go @@ -118,34 +118,12 @@ func (g *GroupCacheRedis) getJoinGroupMaxVersionKey(userID string) string { return cachekey.GetJoinGroupMaxVersionKey(userID) } -func (g *GroupCacheRedis) GetGroupIndex(group *model.Group, keys []string) (int, error) { - key := g.getGroupInfoKey(group.GroupID) - for i, _key := range keys { - if _key == key { - return i, nil - } - } - - return 0, errIndex -} - -func (g *GroupCacheRedis) GetGroupMemberIndex(groupMember *model.GroupMember, keys []string) (int, error) { - key := g.getGroupMemberInfoKey(groupMember.GroupID, groupMember.UserID) - for i, _key := range keys { - if _key == key { - return i, nil - } - } - - return 0, errIndex +func (g *GroupCacheRedis) getGroupID(group *model.Group) string { + return group.GroupID } func (g *GroupCacheRedis) GetGroupsInfo(ctx context.Context, groupIDs []string) (groups []*model.Group, err error) { - return batchGetCache(ctx, g.rcClient, g.expireTime, groupIDs, func(groupID string) string { - return g.getGroupInfoKey(groupID) - }, func(ctx context.Context, groupID string) (*model.Group, error) { - return g.groupDB.Take(ctx, groupID) - }) + return batchGetCache2(ctx, g.rcClient, g.expireTime, groupIDs, g.getGroupInfoKey, g.getGroupID, g.groupDB.Find) } func (g *GroupCacheRedis) GetGroupInfo(ctx context.Context, groupID string) (group *model.Group, err error) { @@ -233,19 +211,6 @@ func (g *GroupCacheRedis) GetGroupMemberIDs(ctx context.Context, groupID string) }) } -func (g *GroupCacheRedis) GetGroupsMemberIDs(ctx context.Context, groupIDs []string) (map[string][]string, error) { - m := make(map[string][]string) - for _, groupID := range groupIDs { - userIDs, err := g.GetGroupMemberIDs(ctx, groupID) - if err != nil { - return nil, err - } - m[groupID] = userIDs - } - - return m, nil -} - func (g *GroupCacheRedis) DelGroupMemberIDs(groupID string) cache.GroupCache { cache := g.CloneGroupCache() cache.AddKeys(g.getGroupMemberIDsKey(groupID)) @@ -285,10 +250,12 @@ func (g *GroupCacheRedis) GetGroupMemberInfo(ctx context.Context, groupID, userI } func (g *GroupCacheRedis) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupMember, error) { - return batchGetCache(ctx, g.rcClient, g.expireTime, userIDs, func(userID string) string { + return batchGetCache2(ctx, g.rcClient, g.expireTime, userIDs, func(userID string) string { return g.getGroupMemberInfoKey(groupID, userID) - }, func(ctx context.Context, userID string) (*model.GroupMember, error) { - return g.groupMemberDB.Take(ctx, groupID, userID) + }, func(member *model.GroupMember) string { + return member.UserID + }, func(ctx context.Context, userIDs []string) ([]*model.GroupMember, error) { + return g.groupMemberDB.Find(ctx, groupID, userIDs) }) } @@ -301,14 +268,6 @@ func (g *GroupCacheRedis) GetAllGroupMembersInfo(ctx context.Context, groupID st return g.GetGroupMembersInfo(ctx, groupID, groupMemberIDs) } -func (g *GroupCacheRedis) GetAllGroupMemberInfo(ctx context.Context, groupID string) ([]*model.GroupMember, error) { - groupMemberIDs, err := g.GetGroupMemberIDs(ctx, groupID) - if err != nil { - return nil, err - } - return g.GetGroupMembersInfo(ctx, groupID, groupMemberIDs) -} - func (g *GroupCacheRedis) DelGroupMembersInfo(groupID string, userIDs ...string) cache.GroupCache { keys := make([]string, 0, len(userIDs)) for _, userID := range userIDs { @@ -388,42 +347,23 @@ func (g *GroupCacheRedis) GetGroupRolesLevelMemberInfo(ctx context.Context, grou return g.GetGroupMembersInfo(ctx, groupID, userIDs) } -func (g *GroupCacheRedis) FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (_ []*model.GroupMember, err error) { +func (g *GroupCacheRedis) FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) ([]*model.GroupMember, error) { if len(groupIDs) == 0 { + var err error groupIDs, err = g.GetJoinedGroupIDs(ctx, userID) if err != nil { return nil, err } } - return batchGetCache(ctx, g.rcClient, g.expireTime, groupIDs, func(groupID string) string { + return batchGetCache2(ctx, g.rcClient, g.expireTime, groupIDs, func(groupID string) string { return g.getGroupMemberInfoKey(groupID, userID) - }, func(ctx context.Context, groupID string) (*model.GroupMember, error) { - return g.groupMemberDB.Take(ctx, groupID, userID) + }, func(member *model.GroupMember) string { + return member.GroupID + }, func(ctx context.Context, groupIDs []string) ([]*model.GroupMember, error) { + return g.groupMemberDB.FindInGroup(ctx, userID, groupIDs) }) } -//func (g *GroupCacheRedis) FindSortGroupMemberUserIDs(ctx context.Context, groupID string) ([]string, error) { -// userIDs, err := g.GetGroupMemberIDs(ctx, groupID) -// if err != nil { -// return nil, err -// } -// if len(userIDs) > g.syncCount { -// userIDs = userIDs[:g.syncCount] -// } -// return userIDs, nil -//} -// -//func (g *GroupCacheRedis) FindSortJoinGroupIDs(ctx context.Context, userID string) ([]string, error) { -// groupIDs, err := g.GetJoinedGroupIDs(ctx, userID) -// if err != nil { -// return nil, err -// } -// if len(groupIDs) > g.syncCount { -// groupIDs = groupIDs[:g.syncCount] -// } -// return groupIDs, nil -//} - func (g *GroupCacheRedis) DelMaxGroupMemberVersion(groupIDs ...string) cache.GroupCache { keys := make([]string, 0, len(groupIDs)) for _, groupID := range groupIDs { diff --git a/pkg/common/storage/cache/redis/msg.go b/pkg/common/storage/cache/redis/msg.go index 2d21cfe13..30f367bb7 100644 --- a/pkg/common/storage/cache/redis/msg.go +++ b/pkg/common/storage/cache/redis/msg.go @@ -183,5 +183,4 @@ func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, return nil, nil, err } return seqMsgs, failedSeqs, nil - } diff --git a/pkg/common/storage/cache/redis/user.go b/pkg/common/storage/cache/redis/user.go index c3accd2c3..f6b490730 100644 --- a/pkg/common/storage/cache/redis/user.go +++ b/pkg/common/storage/cache/redis/user.go @@ -16,16 +16,12 @@ package redis import ( "context" - "encoding/json" "github.com/dtm-labs/rockscache" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" - "github.com/openimsdk/protocol/constant" - "github.com/openimsdk/protocol/user" - "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" "github.com/redis/go-redis/v9" "time" @@ -58,9 +54,9 @@ func NewUserCacheRedis(rdb redis.UniversalClient, localCache *config.LocalCache, } } -//func (u *UserCacheRedis) getOnlineStatusKey(modKey string) string { -// return cachekey.GetOnlineStatusKey(modKey) -//} +func (u *UserCacheRedis) getUserID(user *model.User) string { + return user.UserID +} func (u *UserCacheRedis) CloneUserCache() cache.UserCache { return &UserCacheRedis{ @@ -87,11 +83,7 @@ func (u *UserCacheRedis) GetUserInfo(ctx context.Context, userID string) (userIn } func (u *UserCacheRedis) GetUsersInfo(ctx context.Context, userIDs []string) ([]*model.User, error) { - return batchGetCache(ctx, u.rcClient, u.expireTime, userIDs, func(userID string) string { - return u.getUserInfoKey(userID) - }, func(ctx context.Context, userID string) (*model.User, error) { - return u.userDB.Take(ctx, userID) - }) + return batchGetCache2(ctx, u.rcClient, u.expireTime, userIDs, u.getUserInfoKey, u.getUserID, u.userDB.Find) } func (u *UserCacheRedis) DelUsersInfo(userIDs ...string) cache.UserCache { @@ -127,83 +119,3 @@ func (u *UserCacheRedis) DelUsersGlobalRecvMsgOpt(userIDs ...string) cache.UserC return cache } - -func (u *UserCacheRedis) refreshStatusOffline(ctx context.Context, userID string, status, platformID int32, isNil bool, err error, result, key string) error { - if isNil { - log.ZWarn(ctx, "this user not online,maybe trigger order not right", - err, "userStatus", status) - - return nil - } - var onlineStatus user.OnlineStatus - err = json.Unmarshal([]byte(result), &onlineStatus) - if err != nil { - return errs.Wrap(err) - } - var newPlatformIDs []int32 - for _, val := range onlineStatus.PlatformIDs { - if val != platformID { - newPlatformIDs = append(newPlatformIDs, val) - } - } - if newPlatformIDs == nil { - _, err = u.rdb.HDel(ctx, key, userID).Result() - if err != nil { - return errs.Wrap(err) - } - } else { - onlineStatus.PlatformIDs = newPlatformIDs - newjsonData, err := json.Marshal(&onlineStatus) - if err != nil { - return errs.Wrap(err) - } - _, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result() - if err != nil { - return errs.Wrap(err) - } - } - - return nil -} - -func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string, platformID int32, isNil bool, err error, result, key string) error { - var onlineStatus user.OnlineStatus - if !isNil { - err := json.Unmarshal([]byte(result), &onlineStatus) - if err != nil { - return errs.Wrap(err) - } - onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID)) - } else { - onlineStatus.PlatformIDs = append(onlineStatus.PlatformIDs, platformID) - } - onlineStatus.Status = constant.Online - onlineStatus.UserID = userID - newjsonData, err := json.Marshal(&onlineStatus) - if err != nil { - return errs.WrapMsg(err, "json.Marshal failed") - } - _, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result() - if err != nil { - return errs.Wrap(err) - } - - return nil -} - -type Comparable interface { - ~int | ~string | ~float64 | ~int32 -} - -func RemoveRepeatedElementsInList[T Comparable](slc []T) []T { - var result []T - tempMap := map[T]struct{}{} - for _, e := range slc { - if _, found := tempMap[e]; !found { - tempMap[e] = struct{}{} - result = append(result, e) - } - } - - return result -} diff --git a/pkg/common/storage/controller/user.go b/pkg/common/storage/controller/user.go index 321eff03c..59559537b 100644 --- a/pkg/common/storage/controller/user.go +++ b/pkg/common/storage/controller/user.go @@ -195,7 +195,7 @@ func (u *userDatabase) GetAllUserID(ctx context.Context, pagination pagination.P } func (u *userDatabase) GetUserByID(ctx context.Context, userID string) (user *model.User, err error) { - return u.userDB.Take(ctx, userID) + return u.cache.GetUserInfo(ctx, userID) } // CountTotal Get the total number of users. diff --git a/pkg/common/storage/database/group_member.go b/pkg/common/storage/database/group_member.go index c272b6ef6..43a7e6095 100644 --- a/pkg/common/storage/database/group_member.go +++ b/pkg/common/storage/database/group_member.go @@ -28,6 +28,8 @@ type GroupMember interface { UpdateUserRoleLevels(ctx context.Context, groupID string, firstUserID string, firstUserRoleLevel int32, secondUserID string, secondUserRoleLevel int32) error FindMemberUserID(ctx context.Context, groupID string) (userIDs []string, err error) Take(ctx context.Context, groupID string, userID string) (groupMember *model.GroupMember, err error) + Find(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupMember, error) + FindInGroup(ctx context.Context, userID string, groupIDs []string) ([]*model.GroupMember, error) TakeOwner(ctx context.Context, groupID string) (groupMember *model.GroupMember, err error) SearchMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (total int64, groupList []*model.GroupMember, err error) FindRoleLevelUserIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error) diff --git a/pkg/common/storage/database/mgo/group_member.go b/pkg/common/storage/database/mgo/group_member.go index 3eb93a10e..f89822d3c 100644 --- a/pkg/common/storage/database/mgo/group_member.go +++ b/pkg/common/storage/database/mgo/group_member.go @@ -153,6 +153,22 @@ func (g *GroupMemberMgo) FindMemberUserID(ctx context.Context, groupID string) ( return mongoutil.Find[string](ctx, g.coll, bson.M{"group_id": groupID}, options.Find().SetProjection(bson.M{"_id": 0, "user_id": 1}).SetSort(g.memberSort())) } +func (g *GroupMemberMgo) Find(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupMember, error) { + filter := bson.M{"group_id": groupID} + if len(userIDs) > 0 { + filter["user_id"] = bson.M{"$in": userIDs} + } + return mongoutil.Find[*model.GroupMember](ctx, g.coll, filter) +} + +func (g *GroupMemberMgo) FindInGroup(ctx context.Context, userID string, groupIDs []string) ([]*model.GroupMember, error) { + filter := bson.M{"user_id": userID} + if len(groupIDs) > 0 { + filter["group_id"] = bson.M{"$in": groupIDs} + } + return mongoutil.Find[*model.GroupMember](ctx, g.coll, filter) +} + func (g *GroupMemberMgo) Take(ctx context.Context, groupID string, userID string) (groupMember *model.GroupMember, err error) { return mongoutil.FindOne[*model.GroupMember](ctx, g.coll, bson.M{"group_id": groupID, "user_id": userID}) }