diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 26c9b1da4..9ef0b1667 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -12,6 +12,7 @@ import ( "Open_IM/pkg/common/middleware" promePkg "Open_IM/pkg/common/prometheus" "Open_IM/pkg/common/token_verify" + "Open_IM/pkg/common/tools" "Open_IM/pkg/common/trace_log" cp "Open_IM/pkg/common/utils" "Open_IM/pkg/getcdv3" @@ -270,10 +271,11 @@ func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbGroup.GetJo func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.InviteUserToGroupReq) (*pbGroup.InviteUserToGroupResp, error) { resp := &pbGroup.InviteUserToGroupResp{} - - if !imdb.IsExistGroupMember(req.GroupID, req.OpUserID) && !token_verify.IsManagerUserID(req.OpUserID) { - constant.SetErrorForResp(constant.ErrIdentity, resp.CommonResp) - return nil, utils.Wrap(constant.ErrIdentity, "") + opUserID := tools.OpUserID(ctx) + if err := token_verify.CheckManagerUserID(ctx, opUserID); err != nil { + if err := imdb.CheckIsExistGroupMember(ctx, req.GroupID, opUserID); err != nil { + return nil, err + } } groupInfo, err := (*imdb.Group)(nil).Take(ctx, req.GroupID) if err != nil { diff --git a/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go b/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go index 03ef2b741..b695119ad 100644 --- a/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go +++ b/pkg/common/db/mysql_model/im_mysql_model/group_member_model_k.go @@ -211,6 +211,18 @@ func IsExistGroupMember(groupID, userID string) bool { return true } +func CheckIsExistGroupMember(ctx context.Context, groupID, userID string) error { + var number int64 + err := GroupMemberDB.Table("group_members").Where("group_id = ? and user_id = ?", groupID, userID).Count(&number).Error + if err != nil { + return constant.ErrDB.Wrap() + } + if number != 1 { + return constant.ErrData.Wrap() + } + return nil +} + func GetGroupMemberByGroupID(groupID string, filter int32, begin int32, maxNumber int32) ([]GroupMember, error) { var memberList []GroupMember var err error diff --git a/pkg/common/token_verify/jwt_token.go b/pkg/common/token_verify/jwt_token.go index 3270aa9f3..438a972ca 100644 --- a/pkg/common/token_verify/jwt_token.go +++ b/pkg/common/token_verify/jwt_token.go @@ -143,11 +143,11 @@ func IsManagerUserID(OpUserID string) bool { } } -func CheckManagerUserID(ctx context.Context) error { - if utils.IsContain(tools.OpUserID(ctx), config.Config.Manager.AppManagerUid) { +func CheckManagerUserID(ctx context.Context, userID string) error { + if utils.IsContain(userID, config.Config.Manager.AppManagerUid) { return nil } - return constant.ErrIdentity.Wrap(utils.GetSelfFuncName()) + return constant.ErrNoPermission.Wrap() } func CheckAccess(ctx context.Context, OpUserID string, OwnerUserID string) bool { @@ -184,7 +184,7 @@ func CheckAccessV3(ctx context.Context, OwnerUserID string) (err error) { if opUserID == OwnerUserID { return nil } - return utils.Wrap(constant.ErrIdentity, open_utils.GetSelfFuncName()) + return constant.ErrIdentity.Wrap(utils.GetSelfFuncName()) } func GetUserIDFromToken(token string, operationID string) (bool, string, string) {