diff --git a/internal/api/msg.go b/internal/api/msg.go index 14df448f8..da0b8910a 100644 --- a/internal/api/msg.go +++ b/internal/api/msg.go @@ -252,14 +252,24 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) { var recvIDs []string var err error if req.IsSendAll { - recvIDs, err = m.userRpcClient.GetAllUserIDs(c) - if err != nil { - log.ZError(c, "GetAllUserIDs failed", err) - apiresp.GinError(c, err) + pageNumber := 1 + showNumber := 100 + for { + recvIDsPart, err := m.userRpcClient.GetAllUserIDs(c, int32(pageNumber), int32(showNumber)) + if err != nil { + log.ZError(c, "GetAllUserIDs failed", err) + apiresp.GinError(c, err) + } + if len(recvIDsPart) < showNumber { + recvIDs = append(recvIDs, recvIDsPart...) + break + } + pageNumber++ } } else { recvIDs = req.RecvIDs } + log.ZDebug(c, "BatchSendMsg nums", "nums ", len(recvIDs)) sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg) if err != nil { log.ZError(c, "decodeData failed", err) diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index 9b1e691e5..3b0b67691 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -17,10 +17,11 @@ package user import ( "context" "errors" - "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "strings" "time" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/config" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/convert" @@ -233,7 +234,7 @@ func (s *userServer) GetGlobalRecvMessageOpt(ctx context.Context, req *pbuser.Ge } func (s *userServer) GetAllUserID(ctx context.Context, req *pbuser.GetAllUserIDReq) (resp *pbuser.GetAllUserIDResp, err error) { - userIDs, err := s.UserDatabase.GetAllUserID(ctx) + userIDs, err := s.UserDatabase.GetAllUserID(ctx, req.Pagination.PageNumber, req.Pagination.ShowNumber) if err != nil { return nil, err } diff --git a/pkg/common/db/controller/user.go b/pkg/common/db/controller/user.go index db725ae60..fe14f0a1c 100644 --- a/pkg/common/db/controller/user.go +++ b/pkg/common/db/controller/user.go @@ -41,7 +41,7 @@ type UserDatabase interface { //只要有一个存在就为true IsExist(ctx context.Context, userIDs []string) (exist bool, err error) //获取所有用户ID - GetAllUserID(ctx context.Context) ([]string, error) + GetAllUserID(ctx context.Context, pageNumber, showNumber int32) ([]string, error) //函数内部先查询db中是否存在,存在则什么都不做;不存在则插入 InitOnce(ctx context.Context, users []*relation.UserModel) (err error) // 获取用户总数 @@ -147,8 +147,8 @@ func (u *userDatabase) IsExist(ctx context.Context, userIDs []string) (exist boo return false, nil } -func (u *userDatabase) GetAllUserID(ctx context.Context) (userIDs []string, err error) { - return u.userDB.GetAllUserID(ctx) +func (u *userDatabase) GetAllUserID(ctx context.Context, pageNumber, showNumber int32) (userIDs []string, err error) { + return u.userDB.GetAllUserID(ctx, pageNumber, showNumber) } func (u *userDatabase) CountTotal(ctx context.Context, before *time.Time) (count int64, err error) { diff --git a/pkg/common/db/relation/user_model.go b/pkg/common/db/relation/user_model.go index a073c7b4d..3da877c19 100644 --- a/pkg/common/db/relation/user_model.go +++ b/pkg/common/db/relation/user_model.go @@ -84,8 +84,8 @@ func (u *UserGorm) Page( } // 获取所有用户ID -func (u *UserGorm) GetAllUserID(ctx context.Context) (userIDs []string, err error) { - return userIDs, errs.Wrap(u.db(ctx).Pluck("user_id", &userIDs).Error) +func (u *UserGorm) GetAllUserID(ctx context.Context, pageNumber, showNumber int32) (userIDs []string, err error) { + return userIDs, errs.Wrap(u.db(ctx).Limit(int(showNumber)).Offset(int((pageNumber-1)*showNumber)).Pluck("user_id", &userIDs).Error) } func (u *UserGorm) GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error) { diff --git a/pkg/common/db/table/relation/user.go b/pkg/common/db/table/relation/user.go index f44a610ad..c840f7070 100644 --- a/pkg/common/db/table/relation/user.go +++ b/pkg/common/db/table/relation/user.go @@ -63,7 +63,7 @@ type UserModelInterface interface { Take(ctx context.Context, userID string) (user *UserModel, err error) // 获取用户信息 不存在,不返回错误 Page(ctx context.Context, pageNumber, showNumber int32) (users []*UserModel, count int64, err error) - GetAllUserID(ctx context.Context) (userIDs []string, err error) + GetAllUserID(ctx context.Context, pageNumber, showNumber int32) (userIDs []string, err error) GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error) // 获取用户总数 CountTotal(ctx context.Context, before *time.Time) (count int64, err error) diff --git a/pkg/rpcclient/user.go b/pkg/rpcclient/user.go index a5eb1bed0..023c26d3c 100644 --- a/pkg/rpcclient/user.go +++ b/pkg/rpcclient/user.go @@ -147,8 +147,8 @@ func (u *UserRpcClient) Access(ctx context.Context, ownerUserID string) error { return tokenverify.CheckAccessV3(ctx, ownerUserID) } -func (u *UserRpcClient) GetAllUserIDs(ctx context.Context) ([]string, error) { - resp, err := u.Client.GetAllUserID(ctx, &user.GetAllUserIDReq{}) +func (u *UserRpcClient) GetAllUserIDs(ctx context.Context, pageNumber, showNumber int32) ([]string, error) { + resp, err := u.Client.GetAllUserID(ctx, &user.GetAllUserIDReq{Pagination: &sdkws.RequestPagination{PageNumber: pageNumber, ShowNumber: showNumber}}) if err != nil { return nil, err }