diff --git a/go.mod b/go.mod index 228458b20..dd4034ca5 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/mitchellh/mapstructure v1.5.0 github.com/openimsdk/protocol v0.0.73-alpha.8 - github.com/openimsdk/tools v0.0.50-alpha.81 + github.com/openimsdk/tools v0.0.50-alpha.83 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.18.0 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index c2f3164d7..8278480c8 100644 --- a/go.sum +++ b/go.sum @@ -349,8 +349,8 @@ github.com/openimsdk/gomake v0.0.15-alpha.5 h1:eEZCEHm+NsmcO3onXZPIUbGFCYPYbsX5b github.com/openimsdk/gomake v0.0.15-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= github.com/openimsdk/protocol v0.0.73-alpha.8 h1:GqksOHXWZSqRQaGYvuVQ4IzA7kFhIXSk7NZk0LGk35A= github.com/openimsdk/protocol v0.0.73-alpha.8/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= -github.com/openimsdk/tools v0.0.50-alpha.81 h1:VbuJKtigNXLkCKB/Q6f2UHsqoSaTOAwS8F51c1nhOCA= -github.com/openimsdk/tools v0.0.50-alpha.81/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo= +github.com/openimsdk/tools v0.0.50-alpha.83 h1:7c1D40YGqIWUmGfCII5pduETGC/8c2DyS9SQ4LvoplU= +github.com/openimsdk/tools v0.0.50-alpha.83/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= diff --git a/internal/api/conversation.go b/internal/api/conversation.go index f6eadf15a..5a191c0ec 100644 --- a/internal/api/conversation.go +++ b/internal/api/conversation.go @@ -49,9 +49,9 @@ func (o *ConversationApi) SetConversations(c *gin.Context) { a2r.Call(c, conversation.ConversationClient.SetConversations, o.Client) } -func (o *ConversationApi) GetConversationOfflinePushUserIDs(c *gin.Context) { - a2r.Call(c, conversation.ConversationClient.GetConversationOfflinePushUserIDs, o.Client) -} +//func (o *ConversationApi) GetConversationOfflinePushUserIDs(c *gin.Context) { +// a2r.Call(c, conversation.ConversationClient.GetConversationOfflinePushUserIDs, o.Client) +//} func (o *ConversationApi) GetFullOwnerConversationIDs(c *gin.Context) { a2r.Call(c, conversation.ConversationClient.GetFullOwnerConversationIDs, o.Client) diff --git a/internal/api/router.go b/internal/api/router.go index 7c2c78c7d..700d8392e 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -9,12 +9,8 @@ import ( "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" - "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/tools/mcontext" - "github.com/openimsdk/tools/utils/datautil" - clientv3 "go.etcd.io/etcd/client/v3" - "github.com/openimsdk/open-im-server/v3/internal/api/jssdk" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" @@ -32,6 +28,7 @@ import ( "github.com/openimsdk/tools/discovery/etcd" "github.com/openimsdk/tools/log" "github.com/openimsdk/tools/mw" + clientv3 "go.etcd.io/etcd/client/v3" ) const ( @@ -265,13 +262,12 @@ func newGinRouter(ctx context.Context, client discovery.Conn, cfg *Config) (*gin conversationGroup.POST("/get_conversation", c.GetConversation) conversationGroup.POST("/get_conversations", c.GetConversations) conversationGroup.POST("/set_conversations", c.SetConversations) - conversationGroup.POST("/get_conversation_offline_push_user_ids", c.GetConversationOfflinePushUserIDs) + //conversationGroup.POST("/get_conversation_offline_push_user_ids", c.GetConversationOfflinePushUserIDs) conversationGroup.POST("/get_full_conversation_ids", c.GetFullOwnerConversationIDs) conversationGroup.POST("/get_incremental_conversations", c.GetIncrementalConversation) conversationGroup.POST("/get_owner_conversation", c.GetOwnerConversation) conversationGroup.POST("/get_not_notify_conversation_ids", c.GetNotNotifyConversationIDs) conversationGroup.POST("/get_pinned_conversation_ids", c.GetPinnedConversationIDs) - conversationGroup.POST("/update_conversations_by_user", c.UpdateConversationsByUser) } { @@ -315,7 +311,6 @@ func newGinRouter(ctx context.Context, client discovery.Conn, cfg *Config) (*gin configGroup.POST("/get_config_list", cm.GetConfigList) configGroup.POST("/get_config", cm.GetConfig) configGroup.POST("/set_config", cm.SetConfig) - configGroup.POST("/set_configs", cm.SetConfigs) configGroup.POST("/reset_config", cm.ResetConfig) configGroup.POST("/set_enable_config_manager", cm.SetEnableConfigManager) configGroup.POST("/get_enable_config_manager", cm.GetEnableConfigManager) @@ -359,9 +354,7 @@ func GinParseToken(authClient *rpcli.AuthClient) gin.HandlerFunc { func setGinIsAdmin(imAdminUserID []string) gin.HandlerFunc { return func(c *gin.Context) { - opUserID := mcontext.GetOpUserID(c) - admin := datautil.Contain(opUserID, imAdminUserID...) - c.Set(authverify.CtxIsAdminKey, admin) + c.Set(authverify.CtxAdminUserIDsKey, imAdminUserID) } } diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index d69062bda..211482cb1 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -49,7 +49,7 @@ type WsServer struct { unregisterChan chan *Client kickHandlerChan chan *kickHandler clients UserMap - online *rpccache.OnlineCache + online rpccache.OnlineCache subscription *Subscription clientPool sync.Pool onlineUserNum atomic.Int64 diff --git a/internal/msgtransfer/online_msg_to_mongo_handler.go b/internal/msgtransfer/online_msg_to_mongo_handler.go index ae14d02a1..a895bb9c4 100644 --- a/internal/msgtransfer/online_msg_to_mongo_handler.go +++ b/internal/msgtransfer/online_msg_to_mongo_handler.go @@ -48,15 +48,7 @@ func (mc *OnlineHistoryMongoConsumerHandler) HandleChatWs2Mongo(ctx context.Cont log.ZDebug(ctx, "mongo consumer recv msg", "msgs", msgFromMQ.String()) err = mc.msgTransferDatabase.BatchInsertChat2DB(ctx, msgFromMQ.ConversationID, msgFromMQ.MsgData, msgFromMQ.LastSeq) if err != nil { - log.ZError( - ctx, - "single data insert to mongo err", - err, - "msg", - msgFromMQ.MsgData, - "conversationID", - msgFromMQ.ConversationID, - ) + log.ZError(ctx, "single data insert to mongo err", err, "msg", msgFromMQ.MsgData, "conversationID", msgFromMQ.ConversationID) prommetrics.MsgInsertMongoFailedCounter.Inc() } else { prommetrics.MsgInsertMongoSuccessCounter.Inc() diff --git a/internal/push/push_handler.go b/internal/push/push_handler.go index 418c4c7f2..7b1efe3bf 100644 --- a/internal/push/push_handler.go +++ b/internal/push/push_handler.go @@ -33,7 +33,7 @@ type ConsumerHandler struct { offlinePusher offlinepush.OfflinePusher onlinePusher OnlinePusher pushDatabase controller.PushDatabase - onlineCache *rpccache.OnlineCache + onlineCache rpccache.OnlineCache groupLocalCache *rpccache.GroupLocalCache conversationLocalCache *rpccache.ConversationLocalCache webhookClient *webhook.Client @@ -120,11 +120,7 @@ func (c *ConsumerHandler) HandleMs2PsChat(ctx context.Context, msg []byte) { } func (c *ConsumerHandler) WaitCache() { - c.onlineCache.Lock.Lock() - for c.onlineCache.CurrentPhase.Load() < rpccache.DoSubscribeOver { - c.onlineCache.Cond.Wait() - } - c.onlineCache.Lock.Unlock() + c.onlineCache.WaitCache() } // Push2User Suitable for two types of conversations, one is SingleChatType and the other is NotificationChatType. diff --git a/internal/rpc/conversation/conversation.go b/internal/rpc/conversation/conversation.go index ca2c58878..b0b1053ed 100644 --- a/internal/rpc/conversation/conversation.go +++ b/internal/rpc/conversation/conversation.go @@ -19,6 +19,7 @@ import ( "sort" "time" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" @@ -117,6 +118,9 @@ func Start(ctx context.Context, config *Config, client discovery.Conn, server gr } func (c *conversationServer) GetConversation(ctx context.Context, req *pbconversation.GetConversationReq) (*pbconversation.GetConversationResp, error) { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } conversations, err := c.conversationDatabase.FindConversations(ctx, req.OwnerUserID, []string{req.ConversationID}) if err != nil { return nil, err @@ -130,7 +134,9 @@ func (c *conversationServer) GetConversation(ctx context.Context, req *pbconvers } func (c *conversationServer) GetSortedConversationList(ctx context.Context, req *pbconversation.GetSortedConversationListReq) (resp *pbconversation.GetSortedConversationListResp, err error) { - log.ZDebug(ctx, "GetSortedConversationList", "seqs", req, "userID", req.UserID) + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } var conversationIDs []string if len(req.ConversationIDs) == 0 { conversationIDs, err = c.conversationDatabase.GetConversationIDs(ctx, req.UserID) @@ -203,6 +209,9 @@ func (c *conversationServer) GetSortedConversationList(ctx context.Context, req } func (c *conversationServer) GetAllConversations(ctx context.Context, req *pbconversation.GetAllConversationsReq) (*pbconversation.GetAllConversationsResp, error) { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } conversations, err := c.conversationDatabase.GetUserAllConversation(ctx, req.OwnerUserID) if err != nil { return nil, err @@ -213,6 +222,9 @@ func (c *conversationServer) GetAllConversations(ctx context.Context, req *pbcon } func (c *conversationServer) GetConversations(ctx context.Context, req *pbconversation.GetConversationsReq) (*pbconversation.GetConversationsResp, error) { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } conversations, err := c.getConversations(ctx, req.OwnerUserID, req.ConversationIDs) if err != nil { return nil, err @@ -233,6 +245,9 @@ func (c *conversationServer) getConversations(ctx context.Context, ownerUserID s } func (c *conversationServer) SetConversation(ctx context.Context, req *pbconversation.SetConversationReq) (*pbconversation.SetConversationResp, error) { + if err := authverify.CheckAccess(ctx, req.GetConversation().GetUserID()); err != nil { + return nil, err + } var conversation dbModel.Conversation if err := datautil.CopyStructFields(&conversation, req.Conversation); err != nil { return nil, err @@ -247,10 +262,11 @@ func (c *conversationServer) SetConversation(ctx context.Context, req *pbconvers } func (c *conversationServer) SetConversations(ctx context.Context, req *pbconversation.SetConversationsReq) (*pbconversation.SetConversationsResp, error) { - if req.Conversation == nil { - return nil, errs.ErrArgs.WrapMsg("conversation must not be nil") + for _, userID := range req.UserIDs { + if err := authverify.CheckAccess(ctx, userID); err != nil { + return nil, err + } } - if req.Conversation.ConversationType == constant.WriteGroupChatType { groupInfo, err := c.groupClient.GetGroupInfo(ctx, req.Conversation.GroupID) if err != nil { @@ -331,6 +347,9 @@ func (c *conversationServer) SetConversations(ctx context.Context, req *pbconver } func (c *conversationServer) UpdateConversationsByUser(ctx context.Context, req *pbconversation.UpdateConversationsByUserReq) (*pbconversation.UpdateConversationsByUserResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } m := make(map[string]any) if req.Ex != nil { m["ex"] = req.Ex.Value @@ -343,15 +362,8 @@ func (c *conversationServer) UpdateConversationsByUser(ctx context.Context, req return &pbconversation.UpdateConversationsByUserResp{}, nil } -// Get user IDs with "Do Not Disturb" enabled in super large groups. -func (c *conversationServer) GetRecvMsgNotNotifyUserIDs(ctx context.Context, req *pbconversation.GetRecvMsgNotNotifyUserIDsReq) (*pbconversation.GetRecvMsgNotNotifyUserIDsResp, error) { - return nil, errs.New("deprecated") -} - // create conversation without notification for msg redis transfer. -func (c *conversationServer) CreateSingleChatConversations(ctx context.Context, - req *pbconversation.CreateSingleChatConversationsReq, -) (*pbconversation.CreateSingleChatConversationsResp, error) { +func (c *conversationServer) CreateSingleChatConversations(ctx context.Context, req *pbconversation.CreateSingleChatConversationsReq) (*pbconversation.CreateSingleChatConversationsResp, error) { var conversation dbModel.Conversation switch req.ConversationType { case constant.SingleChatType: @@ -454,6 +466,9 @@ func (c *conversationServer) SetConversationMinSeq(ctx context.Context, req *pbc } func (c *conversationServer) GetConversationIDs(ctx context.Context, req *pbconversation.GetConversationIDsReq) (*pbconversation.GetConversationIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } conversationIDs, err := c.conversationDatabase.GetConversationIDs(ctx, req.UserID) if err != nil { return nil, err @@ -462,6 +477,9 @@ func (c *conversationServer) GetConversationIDs(ctx context.Context, req *pbconv } func (c *conversationServer) GetUserConversationIDsHash(ctx context.Context, req *pbconversation.GetUserConversationIDsHashReq) (*pbconversation.GetUserConversationIDsHashResp, error) { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } hash, err := c.conversationDatabase.GetUserConversationIDsHash(ctx, req.OwnerUserID) if err != nil { return nil, err @@ -469,10 +487,7 @@ func (c *conversationServer) GetUserConversationIDsHash(ctx context.Context, req return &pbconversation.GetUserConversationIDsHashResp{Hash: hash}, nil } -func (c *conversationServer) GetConversationsByConversationID( - ctx context.Context, - req *pbconversation.GetConversationsByConversationIDReq, -) (*pbconversation.GetConversationsByConversationIDResp, error) { +func (c *conversationServer) GetConversationsByConversationID(ctx context.Context, req *pbconversation.GetConversationsByConversationIDReq) (*pbconversation.GetConversationsByConversationIDResp, error) { conversations, err := c.conversationDatabase.GetConversationsByConversationID(ctx, req.ConversationIDs) if err != nil { return nil, err @@ -526,10 +541,7 @@ func (c *conversationServer) conversationSort(conversations map[int64]string, re resp.ConversationElems = append(resp.ConversationElems, cons...) } -func (c *conversationServer) getConversationInfo( - ctx context.Context, - chatLogs map[string]*sdkws.MsgData, - userID string) (map[string]*pbconversation.ConversationElem, error) { +func (c *conversationServer) getConversationInfo(ctx context.Context, chatLogs map[string]*sdkws.MsgData, userID string) (map[string]*pbconversation.ConversationElem, error) { var ( sendIDs []string groupIDs []string @@ -615,6 +627,11 @@ func (c *conversationServer) GetConversationNotReceiveMessageUserIDs(ctx context } func (c *conversationServer) UpdateConversation(ctx context.Context, req *pbconversation.UpdateConversationReq) (*pbconversation.UpdateConversationResp, error) { + for _, userID := range req.UserIDs { + if err := authverify.CheckAccess(ctx, userID); err != nil { + return nil, err + } + } m := make(map[string]any) if req.RecvMsgOpt != nil { m["recv_msg_opt"] = req.RecvMsgOpt.Value @@ -661,6 +678,9 @@ func (c *conversationServer) UpdateConversation(ctx context.Context, req *pbconv } func (c *conversationServer) GetOwnerConversation(ctx context.Context, req *pbconversation.GetOwnerConversationReq) (*pbconversation.GetOwnerConversationResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } total, conversations, err := c.conversationDatabase.GetOwnerConversation(ctx, req.UserID, req.Pagination) if err != nil { return nil, err @@ -722,6 +742,9 @@ func (c *conversationServer) GetConversationsNeedClearMsg(ctx context.Context, _ } func (c *conversationServer) GetNotNotifyConversationIDs(ctx context.Context, req *pbconversation.GetNotNotifyConversationIDsReq) (*pbconversation.GetNotNotifyConversationIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } conversationIDs, err := c.conversationDatabase.GetNotNotifyConversationIDs(ctx, req.UserID) if err != nil { return nil, err @@ -730,6 +753,9 @@ func (c *conversationServer) GetNotNotifyConversationIDs(ctx context.Context, re } func (c *conversationServer) GetPinnedConversationIDs(ctx context.Context, req *pbconversation.GetPinnedConversationIDsReq) (*pbconversation.GetPinnedConversationIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } conversationIDs, err := c.conversationDatabase.GetPinnedConversationIDs(ctx, req.UserID) if err != nil { return nil, err diff --git a/internal/rpc/conversation/sync.go b/internal/rpc/conversation/sync.go index cee74b319..a24dd85c6 100644 --- a/internal/rpc/conversation/sync.go +++ b/internal/rpc/conversation/sync.go @@ -35,6 +35,9 @@ func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, re } func (c *conversationServer) GetIncrementalConversation(ctx context.Context, req *conversation.GetIncrementalConversationReq) (*conversation.GetIncrementalConversationResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } opt := incrversion.Option[*conversation.Conversation, conversation.GetIncrementalConversationResp]{ Ctx: ctx, VersionKey: req.UserID, diff --git a/internal/rpc/group/cache.go b/internal/rpc/group/cache.go index bcd246267..ec0e5b566 100644 --- a/internal/rpc/group/cache.go +++ b/internal/rpc/group/cache.go @@ -17,6 +17,7 @@ package group import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/convert" pbgroup "github.com/openimsdk/protocol/group" ) @@ -33,6 +34,9 @@ func (g *groupServer) GetGroupInfoCache(ctx context.Context, req *pbgroup.GetGro } func (g *groupServer) GetGroupMemberCache(ctx context.Context, req *pbgroup.GetGroupMemberCacheReq) (*pbgroup.GetGroupMemberCacheResp, error) { + if err := authverify.CheckAccess(ctx, req.GroupMemberID); err != nil { + return nil, err + } members, err := g.db.TakeGroupMember(ctx, req.GroupID, req.GroupMemberID) if err != nil { return nil, err diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 10cdc2546..2026ba71b 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -476,6 +476,19 @@ func (g *groupServer) GetGroupAllMember(ctx context.Context, req *pbgroup.GetGro if err != nil { return nil, err } + if !authverify.IsAdmin(ctx) { + var inGroup bool + opUserID := mcontext.GetOpUserID(ctx) + for _, member := range members { + if member.UserID == opUserID { + inGroup = true + break + } + } + if !inGroup { + return nil, errs.ErrNoPermission.WrapMsg("opuser not in group") + } + } if err := g.PopulateGroupMember(ctx, members...); err != nil { return nil, err } @@ -486,11 +499,24 @@ func (g *groupServer) GetGroupAllMember(ctx context.Context, req *pbgroup.GetGro return &resp, nil } +func (g *groupServer) checkAdminOrInGroup(ctx context.Context, groupID string) error { + if authverify.IsAdmin(ctx) { + return nil + } + opUserID := mcontext.GetOpUserID(ctx) + members, err := g.db.FindGroupMembers(ctx, groupID, []string{opUserID}) + if err != nil { + return err + } + if len(members) == 0 { + return errs.ErrNoPermission.WrapMsg("op user not in group") + } + return nil +} + func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGroupMemberListReq) (*pbgroup.GetGroupMemberListResp, error) { - if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { - if _, err := g.db.TakeGroupMember(ctx, req.GroupID, opUserID); err != nil { - return nil, err - } + if err := g.checkAdminOrInGroup(ctx, req.GroupID); err != nil { + return nil, err } var ( total int64 @@ -631,6 +657,9 @@ func (g *groupServer) GetGroupMembersInfo(ctx context.Context, req *pbgroup.GetG if req.GroupID == "" { return nil, errs.ErrArgs.WrapMsg("groupID empty") } + if err := g.checkAdminOrInGroup(ctx, req.GroupID); err != nil { + return nil, err + } members, err := g.getGroupMembersInfo(ctx, req.GroupID, req.UserIDs) if err != nil { return nil, err @@ -658,6 +687,9 @@ func (g *groupServer) getGroupMembersInfo(ctx context.Context, groupID string, u // GetGroupApplicationList handles functions that get a list of group requests. func (g *groupServer) GetGroupApplicationList(ctx context.Context, req *pbgroup.GetGroupApplicationListReq) (*pbgroup.GetGroupApplicationListResp, error) { + if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil { + return nil, err + } groupIDs, err := g.db.FindUserManagedGroupID(ctx, req.FromUserID) if err != nil { return nil, err @@ -1652,6 +1684,11 @@ func (g *groupServer) GetGroupAbstractInfo(ctx context.Context, req *pbgroup.Get if datautil.Duplicate(req.GroupIDs) { return nil, errs.ErrArgs.WrapMsg("groupIDs duplicate") } + for _, groupID := range req.GroupIDs { + if err := g.checkAdminOrInGroup(ctx, groupID); err != nil { + return nil, err + } + } groups, err := g.db.FindGroup(ctx, req.GroupIDs) if err != nil { return nil, err @@ -1699,6 +1736,9 @@ func (g *groupServer) GetGroupMemberUserIDs(ctx context.Context, req *pbgroup.Ge if err != nil { return nil, err } + if err := authverify.CheckAccessIn(ctx, userIDs...); err != nil { + return nil, err + } return &pbgroup.GetGroupMemberUserIDsResp{ UserIDs: userIDs, }, nil diff --git a/internal/rpc/group/statistics.go b/internal/rpc/group/statistics.go index 1c582fda1..4ee3396da 100644 --- a/internal/rpc/group/statistics.go +++ b/internal/rpc/group/statistics.go @@ -18,6 +18,7 @@ import ( "context" "time" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/protocol/group" "github.com/openimsdk/tools/errs" ) @@ -26,6 +27,9 @@ func (g *groupServer) GroupCreateCount(ctx context.Context, req *group.GroupCrea if req.Start > req.End { return nil, errs.ErrArgs.WrapMsg("start > end: %d > %d", req.Start, req.End) } + if err := authverify.CheckAdmin(ctx); err != nil { + return nil, err + } total, err := g.db.CountTotal(ctx, nil) if err != nil { return nil, err diff --git a/internal/rpc/group/sync.go b/internal/rpc/group/sync.go index 822b15307..baee2f2d4 100644 --- a/internal/rpc/group/sync.go +++ b/internal/rpc/group/sync.go @@ -11,9 +11,6 @@ import ( "github.com/openimsdk/protocol/constant" pbgroup "github.com/openimsdk/protocol/group" "github.com/openimsdk/protocol/sdkws" - "github.com/openimsdk/tools/errs" - "github.com/openimsdk/tools/mcontext" - "github.com/openimsdk/tools/utils/datautil" ) const versionSyncLimit = 500 @@ -23,10 +20,8 @@ func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgrou if err != nil { return nil, err } - if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { - if !datautil.Contain(opUserID, userIDs...) { - return nil, errs.ErrNoPermission.WrapMsg("user not in group") - } + if err := authverify.CheckAccessIn(ctx, userIDs...); err != nil { + return nil, err } vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) if err != nil { @@ -69,6 +64,9 @@ func (g *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetF } func (g *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgroup.GetIncrementalGroupMemberReq) (*pbgroup.GetIncrementalGroupMemberResp, error) { + if err := g.checkAdminOrInGroup(ctx, req.GroupID); err != nil { + return nil, err + } group, err := g.db.TakeGroup(ctx, req.GroupID) if err != nil { return nil, err @@ -76,9 +74,6 @@ func (g *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou if group.Status == constant.GroupStatusDismissed { return nil, servererrs.ErrDismissedAlready.Wrap() } - if _, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)); err != nil { - return nil, err - } var ( hasGroupUpdate bool sortVersion uint64 diff --git a/internal/rpc/msg/as_read.go b/internal/rpc/msg/as_read.go index b25eae6b1..c52ce9c07 100644 --- a/internal/rpc/msg/as_read.go +++ b/internal/rpc/msg/as_read.go @@ -18,6 +18,7 @@ import ( "context" "errors" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" cbapi "github.com/openimsdk/open-im-server/v3/pkg/callbackstruct" "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/msg" @@ -29,6 +30,9 @@ import ( ) func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *msg.GetConversationsHasReadAndMaxSeqReq) (*msg.GetConversationsHasReadAndMaxSeqResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } var conversationIDs []string if len(req.ConversationIDs) == 0 { var err error @@ -82,6 +86,9 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m } func (m *msgServer) SetConversationHasReadSeq(ctx context.Context, req *msg.SetConversationHasReadSeqReq) (*msg.SetConversationHasReadSeqResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } maxSeq, err := m.MsgDatabase.GetMaxSeq(ctx, req.ConversationID) if err != nil { return nil, err @@ -97,8 +104,8 @@ func (m *msgServer) SetConversationHasReadSeq(ctx context.Context, req *msg.SetC } func (m *msgServer) MarkMsgsAsRead(ctx context.Context, req *msg.MarkMsgsAsReadReq) (*msg.MarkMsgsAsReadResp, error) { - if len(req.Seqs) < 1 { - return nil, errs.ErrArgs.WrapMsg("seqs must not be empty") + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err } maxSeq, err := m.MsgDatabase.GetMaxSeq(ctx, req.ConversationID) if err != nil { @@ -139,6 +146,9 @@ func (m *msgServer) MarkMsgsAsRead(ctx context.Context, req *msg.MarkMsgsAsReadR } func (m *msgServer) MarkConversationAsRead(ctx context.Context, req *msg.MarkConversationAsReadReq) (*msg.MarkConversationAsReadResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } conversation, err := m.ConversationLocalCache.GetConversation(ctx, req.UserID, req.ConversationID) if err != nil { return nil, err @@ -216,5 +226,4 @@ func (m *msgServer) sendMarkAsReadNotification(ctx context.Context, conversation HasReadSeq: hasReadSeq, } m.notificationSender.NotificationWithSessionType(ctx, sendID, recvID, constant.HasReadReceipt, sessionType, tips) - } diff --git a/internal/rpc/msg/delete.go b/internal/rpc/msg/delete.go index 4590523d5..d24420ebb 100644 --- a/internal/rpc/msg/delete.go +++ b/internal/rpc/msg/delete.go @@ -94,6 +94,9 @@ func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*ms } func (m *msgServer) DeleteMsgPhysicalBySeq(ctx context.Context, req *msg.DeleteMsgPhysicalBySeqReq) (*msg.DeleteMsgPhysicalBySeqResp, error) { + if err := authverify.CheckAdmin(ctx); err != nil { + return nil, err + } err := m.MsgDatabase.DeleteMsgsPhysicalBySeqs(ctx, req.ConversationID, req.Seqs) if err != nil { return nil, err diff --git a/internal/rpc/msg/send.go b/internal/rpc/msg/send.go index 6b2ec30b5..0e3a9950b 100644 --- a/internal/rpc/msg/send.go +++ b/internal/rpc/msg/send.go @@ -17,6 +17,7 @@ package msg import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "google.golang.org/protobuf/proto" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" @@ -37,6 +38,9 @@ func (m *msgServer) SendMsg(ctx context.Context, req *pbmsg.SendMsgReq) (*pbmsg. if req.MsgData == nil { return nil, errs.ErrArgs.WrapMsg("msgData is nil") } + if err := authverify.CheckAccess(ctx, req.MsgData.SendID); err != nil { + return nil, err + } before := new(*sdkws.MsgData) resp, err := m.sendMsg(ctx, req, before) if err != nil { @@ -172,13 +176,7 @@ func (m *msgServer) sendMsgSingleChat(ctx context.Context, req *pbmsg.SendMsgReq isSend := true isNotification := msgprocessor.IsNotificationByMsg(req.MsgData) if !isNotification { - isSend, err = m.modifyMessageByUserMessageReceiveOpt( - ctx, - req.MsgData.RecvID, - conversationutil.GenConversationIDForSingle(req.MsgData.SendID, req.MsgData.RecvID), - constant.SingleChatType, - req, - ) + isSend, err = m.modifyMessageByUserMessageReceiveOpt(authverify.WithTempAdmin(ctx), req.MsgData.RecvID, conversationutil.GenConversationIDForSingle(req.MsgData.SendID, req.MsgData.RecvID), constant.SingleChatType, req) if err != nil { return nil, err } diff --git a/internal/rpc/msg/seq.go b/internal/rpc/msg/seq.go index bd68138fb..6a0461dc8 100644 --- a/internal/rpc/msg/seq.go +++ b/internal/rpc/msg/seq.go @@ -17,9 +17,10 @@ package msg import ( "context" "errors" + "sort" + pbmsg "github.com/openimsdk/protocol/msg" "github.com/redis/go-redis/v9" - "sort" ) func (m *msgServer) GetConversationMaxSeq(ctx context.Context, req *pbmsg.GetConversationMaxSeqReq) (*pbmsg.GetConversationMaxSeqResp, error) { diff --git a/internal/rpc/msg/statistics.go b/internal/rpc/msg/statistics.go index 01c0f1c46..b1f90cae4 100644 --- a/internal/rpc/msg/statistics.go +++ b/internal/rpc/msg/statistics.go @@ -16,15 +16,20 @@ package msg import ( "context" - "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "time" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/protocol/msg" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/utils/datautil" ) func (m *msgServer) GetActiveUser(ctx context.Context, req *msg.GetActiveUserReq) (*msg.GetActiveUserResp, error) { + if err := authverify.CheckAdmin(ctx); err != nil { + return nil, err + } msgCount, userCount, users, dateCount, err := m.MsgDatabase.RangeUserSendCount(ctx, time.UnixMilli(req.Start), time.UnixMilli(req.End), req.Group, req.Ase, req.Pagination.PageNumber, req.Pagination.ShowNumber) if err != nil { return nil, err @@ -60,6 +65,9 @@ func (m *msgServer) GetActiveUser(ctx context.Context, req *msg.GetActiveUserReq } func (m *msgServer) GetActiveGroup(ctx context.Context, req *msg.GetActiveGroupReq) (*msg.GetActiveGroupResp, error) { + if err := authverify.CheckAdmin(ctx); err != nil { + return nil, err + } msgCount, groupCount, groups, dateCount, err := m.MsgDatabase.RangeGroupSendCount(ctx, time.UnixMilli(req.Start), time.UnixMilli(req.End), req.Ase, req.Pagination.PageNumber, req.Pagination.ShowNumber) if err != nil { return nil, err diff --git a/internal/rpc/msg/sync_msg.go b/internal/rpc/msg/sync_msg.go index 38eed93bc..259c7f85d 100644 --- a/internal/rpc/msg/sync_msg.go +++ b/internal/rpc/msg/sync_msg.go @@ -29,6 +29,9 @@ import ( ) func (m *msgServer) PullMessageBySeqs(ctx context.Context, req *sdkws.PullMessageBySeqsReq) (*sdkws.PullMessageBySeqsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } resp := &sdkws.PullMessageBySeqsResp{} resp.Msgs = make(map[string]*sdkws.PullMsgs) resp.NotificationMsgs = make(map[string]*sdkws.PullMsgs) diff --git a/internal/rpc/relation/black.go b/internal/rpc/relation/black.go index 4a0e60ecd..07a1537b1 100644 --- a/internal/rpc/relation/black.go +++ b/internal/rpc/relation/black.go @@ -46,6 +46,9 @@ func (s *friendServer) GetPaginationBlacks(ctx context.Context, req *relation.Ge } func (s *friendServer) IsBlack(ctx context.Context, req *relation.IsBlackReq) (*relation.IsBlackResp, error) { + if err := authverify.CheckAccessIn(ctx, req.UserID1, req.UserID2); err != nil { + return nil, err + } in1, in2, err := s.blackDatabase.CheckIn(ctx, req.UserID1, req.UserID2) if err != nil { return nil, err diff --git a/internal/rpc/relation/friend.go b/internal/rpc/relation/friend.go index 050cd5ffe..06661c79d 100644 --- a/internal/rpc/relation/friend.go +++ b/internal/rpc/relation/friend.go @@ -280,6 +280,9 @@ func (s *friendServer) SetFriendRemark(ctx context.Context, req *relation.SetFri } func (s *friendServer) GetFriendInfo(ctx context.Context, req *relation.GetFriendInfoReq) (*relation.GetFriendInfoResp, error) { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } friends, err := s.db.FindFriendsWithError(ctx, req.OwnerUserID, req.FriendUserIDs) if err != nil { return nil, err @@ -288,6 +291,9 @@ func (s *friendServer) GetFriendInfo(ctx context.Context, req *relation.GetFrien } func (s *friendServer) GetDesignatedFriends(ctx context.Context, req *relation.GetDesignatedFriendsReq) (resp *relation.GetDesignatedFriendsResp, err error) { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } resp = &relation.GetDesignatedFriendsResp{} if datautil.Duplicate(req.FriendUserIDs) { return nil, errs.ErrArgs.WrapMsg("friend userID repeated") @@ -313,9 +319,10 @@ func (s *friendServer) getFriend(ctx context.Context, ownerUserID string, friend } // Get the list of friend requests sent out proactively. -func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context, - req *relation.GetDesignatedFriendsApplyReq, -) (resp *relation.GetDesignatedFriendsApplyResp, err error) { +func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context, req *relation.GetDesignatedFriendsApplyReq) (resp *relation.GetDesignatedFriendsApplyResp, err error) { + if err := authverify.CheckAccessIn(ctx, req.FromUserID, req.ToUserID); err != nil { + return nil, err + } friendRequests, err := s.db.FindBothFriendRequests(ctx, req.FromUserID, req.ToUserID) if err != nil { return nil, err @@ -374,6 +381,9 @@ func (s *friendServer) GetPaginationFriendsApplyFrom(ctx context.Context, req *r // ok. func (s *friendServer) IsFriend(ctx context.Context, req *relation.IsFriendReq) (resp *relation.IsFriendResp, err error) { + if err := authverify.CheckAccessIn(ctx, req.UserID1, req.UserID2); err != nil { + return nil, err + } resp = &relation.IsFriendResp{} resp.InUser1Friends, resp.InUser2Friends, err = s.db.CheckIn(ctx, req.UserID1, req.UserID2) if err != nil { @@ -426,6 +436,9 @@ func (s *friendServer) GetSpecifiedFriendsInfo(ctx context.Context, req *relatio return nil, errs.ErrArgs.WrapMsg("userIDList repeated") } + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } userMap, err := s.userClient.GetUsersInfoMap(ctx, req.UserIDList) if err != nil { return nil, err @@ -494,10 +507,7 @@ func (s *friendServer) GetSpecifiedFriendsInfo(ctx context.Context, req *relatio return resp, nil } -func (s *friendServer) UpdateFriends( - ctx context.Context, - req *relation.UpdateFriendsReq, -) (*relation.UpdateFriendsResp, error) { +func (s *friendServer) UpdateFriends(ctx context.Context, req *relation.UpdateFriendsReq) (*relation.UpdateFriendsResp, error) { if len(req.FriendUserIDs) == 0 { return nil, errs.ErrArgs.WrapMsg("friendIDList is empty") } @@ -505,6 +515,10 @@ func (s *friendServer) UpdateFriends( return nil, errs.ErrArgs.WrapMsg("friendIDList repeated") } + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { + return nil, err + } + _, err := s.db.FindFriendsWithError(ctx, req.OwnerUserID, req.FriendUserIDs) if err != nil { return nil, err diff --git a/internal/rpc/third/log.go b/internal/rpc/third/log.go index fba3ecb88..9a4995ace 100644 --- a/internal/rpc/third/log.go +++ b/internal/rpc/third/log.go @@ -25,6 +25,7 @@ import ( "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/third" "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/utils/datautil" ) @@ -45,7 +46,7 @@ func genLogID() string { func (t *thirdServer) UploadLogs(ctx context.Context, req *third.UploadLogsReq) (*third.UploadLogsResp, error) { var dbLogs []*relationtb.Log - userID := ctx.Value(constant.OpUserID).(string) + userID := mcontext.GetOpUserID(ctx) platform := constant.PlatformID2Name[int(req.Platform)] for _, fileURL := range req.FileURLs { log := relationtb.Log{ diff --git a/internal/rpc/third/third.go b/internal/rpc/third/third.go index 0afd54014..c6dcb2ea4 100644 --- a/internal/rpc/third/third.go +++ b/internal/rpc/third/third.go @@ -19,6 +19,7 @@ import ( "fmt" "time" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/mcache" "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" @@ -148,6 +149,9 @@ func (t *thirdServer) FcmUpdateToken(ctx context.Context, req *third.FcmUpdateTo } func (t *thirdServer) SetAppBadge(ctx context.Context, req *third.SetAppBadgeReq) (resp *third.SetAppBadgeResp, err error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } err = t.thirdDatabase.SetAppBadge(ctx, req.UserID, int(req.AppUnreadCount)) if err != nil { return nil, err diff --git a/pkg/authverify/token.go b/pkg/authverify/token.go index 75fb1448b..a82eba30a 100644 --- a/pkg/authverify/token.go +++ b/pkg/authverify/token.go @@ -51,29 +51,70 @@ func CheckSystemAccount(ctx context.Context, level int32) bool { } const ( - CtxIsAdminKey = "CtxIsAdminKey" + CtxAdminUserIDsKey = "CtxAdminUserIDsKey" ) func WithIMAdminUserIDs(ctx context.Context, imAdminUserID []string) context.Context { - return context.WithValue(ctx, CtxIsAdminKey, imAdminUserID) + return context.WithValue(ctx, CtxAdminUserIDsKey, imAdminUserID) } func GetIMAdminUserIDs(ctx context.Context) []string { - imAdminUserID, _ := ctx.Value(CtxIsAdminKey).([]string) + imAdminUserID, _ := ctx.Value(CtxAdminUserIDsKey).([]string) return imAdminUserID } func IsAdmin(ctx context.Context) bool { - return datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) + return IsTempAdmin(ctx) || IsSystemAdmin(ctx) } func CheckAccess(ctx context.Context, ownerUserID string) error { - opUserID := mcontext.GetOpUserID(ctx) - if opUserID == ownerUserID { + if mcontext.GetOpUserID(ctx) == ownerUserID { return nil } - if datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) { + if IsAdmin(ctx) { return nil } return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID) } + +func CheckAccessIn(ctx context.Context, ownerUserIDs ...string) error { + opUserID := mcontext.GetOpUserID(ctx) + for _, userID := range ownerUserIDs { + if opUserID == userID { + return nil + } + } + if IsAdmin(ctx) { + return nil + } + return servererrs.ErrNoPermission.WrapMsg("opUser in ownerUserIDs") +} + +var tempAdminValue = []string{"1"} + +const ctxTempAdminKey = "ctxImTempAdminKey" + +func WithTempAdmin(ctx context.Context) context.Context { + keys, _ := ctx.Value(constant.RpcCustomHeader).([]string) + if datautil.Contain(ctxTempAdminKey, keys...) { + return ctx + } + if len(keys) > 0 { + temp := make([]string, 0, len(keys)+1) + temp = append(temp, keys...) + keys = append(temp, ctxTempAdminKey) + } else { + keys = []string{ctxTempAdminKey} + } + ctx = context.WithValue(ctx, constant.RpcCustomHeader, keys) + return context.WithValue(ctx, ctxTempAdminKey, tempAdminValue) +} + +func IsTempAdmin(ctx context.Context) bool { + values, _ := ctx.Value(ctxTempAdminKey).([]string) + return datautil.Equal(tempAdminValue, values) +} + +func IsSystemAdmin(ctx context.Context) bool { + return datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) +} diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index 03621343b..b99d32db1 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -47,64 +47,6 @@ func init() { prommetrics.RegistryAll() } -func getConfigRpcMaxRequestBody(value reflect.Value) *conf.MaxRequestBody { - for value.Kind() == reflect.Pointer { - value = value.Elem() - } - if value.Kind() == reflect.Struct { - num := value.NumField() - for i := 0; i < num; i++ { - field := value.Field(i) - if !field.CanInterface() { - continue - } - for field.Kind() == reflect.Pointer { - field = field.Elem() - } - switch elem := field.Interface().(type) { - case conf.Share: - return &elem.RPCMaxBodySize - case conf.MaxRequestBody: - return &elem - } - if field.Kind() == reflect.Struct { - if elem := getConfigRpcMaxRequestBody(field); elem != nil { - return elem - } - } - } - } - return nil -} - -func getConfigShare(value reflect.Value) *conf.Share { - for value.Kind() == reflect.Pointer { - value = value.Elem() - } - if value.Kind() == reflect.Struct { - num := value.NumField() - for i := 0; i < num; i++ { - field := value.Field(i) - if !field.CanInterface() { - continue - } - for field.Kind() == reflect.Pointer { - field = field.Elem() - } - switch elem := field.Interface().(type) { - case conf.Share: - return &elem - } - if field.Kind() == reflect.Struct { - if elem := getConfigShare(field); elem != nil { - return elem - } - } - } - } - return nil -} - func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP, registerIP string, autoSetPorts bool, rpcPorts []int, index int, rpcRegisterName string, notification *conf.Notification, config T, watchConfigNames []string, watchServiceNames []string, @@ -122,8 +64,8 @@ func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *c options = append(options, grpcsrv.GrpcServerMetadataContext(), - grpcsrv.GrpcServerLogger(), grpcsrv.GrpcServerErrorConvert(), + grpcsrv.GrpcServerLogger(), grpcsrv.GrpcServerRequestValidate(), grpcsrv.GrpcServerPanicCapture(), ) diff --git a/pkg/common/startrpc/tools.go b/pkg/common/startrpc/tools.go new file mode 100644 index 000000000..54b518181 --- /dev/null +++ b/pkg/common/startrpc/tools.go @@ -0,0 +1,39 @@ +package startrpc + +import ( + "reflect" + + conf "github.com/openimsdk/open-im-server/v3/pkg/common/config" +) + +func getConfig[T any](value reflect.Value) *T { + for value.Kind() == reflect.Pointer { + value = value.Elem() + } + if value.Kind() == reflect.Struct { + num := value.NumField() + for i := 0; i < num; i++ { + field := value.Field(i) + for field.Kind() == reflect.Pointer { + field = field.Elem() + } + if field.Kind() == reflect.Struct { + if elem, ok := field.Interface().(T); ok { + return &elem + } + if elem := getConfig[T](field); elem != nil { + return elem + } + } + } + } + return nil +} + +func getConfigRpcMaxRequestBody(value reflect.Value) *conf.MaxRequestBody { + return getConfig[conf.MaxRequestBody](value) +} + +func getConfigShare(value reflect.Value) *conf.Share { + return getConfig[conf.Share](value) +} diff --git a/pkg/rpccache/online.go b/pkg/rpccache/online.go index 959b054a2..9b8c03101 100644 --- a/pkg/rpccache/online.go +++ b/pkg/rpccache/online.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/user" @@ -23,9 +24,25 @@ import ( "github.com/redis/go-redis/v9" ) -func NewOnlineCache(client *rpcli.UserClient, group *GroupLocalCache, rdb redis.UniversalClient, fullUserCache bool, fn func(ctx context.Context, userID string, platformIDs []int32)) (*OnlineCache, error) { +const ( + Begin uint32 = iota + DoOnlineStatusOver + DoSubscribeOver +) + +type OnlineCache interface { + GetUserOnlinePlatform(ctx context.Context, userID string) ([]int32, error) + GetUserOnline(ctx context.Context, userID string) (bool, error) + GetUsersOnline(ctx context.Context, userIDs []string) ([]string, []string, error) + WaitCache() +} + +func NewOnlineCache(client *rpcli.UserClient, group *GroupLocalCache, rdb redis.UniversalClient, fullUserCache bool, fn func(ctx context.Context, userID string, platformIDs []int32)) (OnlineCache, error) { + if config.Standalone() { + return disableOnlineCache{client: client}, nil + } l := &sync.Mutex{} - x := &OnlineCache{ + x := &defaultOnlineCache{ client: client, group: group, fullUserCache: fullUserCache, @@ -60,13 +77,7 @@ func NewOnlineCache(client *rpcli.UserClient, group *GroupLocalCache, rdb redis. return x, nil } -const ( - Begin uint32 = iota - DoOnlineStatusOver - DoSubscribeOver -) - -type OnlineCache struct { +type defaultOnlineCache struct { client *rpcli.UserClient group *GroupLocalCache @@ -82,7 +93,7 @@ type OnlineCache struct { CurrentPhase atomic.Uint32 } -func (o *OnlineCache) initUsersOnlineStatus(ctx context.Context) (err error) { +func (o *defaultOnlineCache) initUsersOnlineStatus(ctx context.Context) (err error) { log.ZDebug(ctx, "init users online status begin") var ( @@ -135,7 +146,7 @@ func (o *OnlineCache) initUsersOnlineStatus(ctx context.Context) (err error) { return nil } -func (o *OnlineCache) doSubscribe(ctx context.Context, rdb redis.UniversalClient, fn func(ctx context.Context, userID string, platformIDs []int32)) { +func (o *defaultOnlineCache) doSubscribe(ctx context.Context, rdb redis.UniversalClient, fn func(ctx context.Context, userID string, platformIDs []int32)) { o.Lock.Lock() ch := rdb.Subscribe(ctx, cachekey.OnlineChannel).Channel() for o.CurrentPhase.Load() < DoOnlineStatusOver { @@ -186,7 +197,7 @@ func (o *OnlineCache) doSubscribe(ctx context.Context, rdb redis.UniversalClient } } -func (o *OnlineCache) getUserOnlinePlatform(ctx context.Context, userID string) ([]int32, error) { +func (o *defaultOnlineCache) getUserOnlinePlatform(ctx context.Context, userID string) ([]int32, error) { platformIDs, err := o.lruCache.Get(userID, func() ([]int32, error) { return o.client.GetUserOnlinePlatform(ctx, userID) }) @@ -198,7 +209,7 @@ func (o *OnlineCache) getUserOnlinePlatform(ctx context.Context, userID string) return platformIDs, nil } -func (o *OnlineCache) GetUserOnlinePlatform(ctx context.Context, userID string) ([]int32, error) { +func (o *defaultOnlineCache) GetUserOnlinePlatform(ctx context.Context, userID string) ([]int32, error) { platformIDs, err := o.getUserOnlinePlatform(ctx, userID) if err != nil { return nil, err @@ -208,7 +219,7 @@ func (o *OnlineCache) GetUserOnlinePlatform(ctx context.Context, userID string) return platformIDs, nil } -func (o *OnlineCache) GetUserOnline(ctx context.Context, userID string) (bool, error) { +func (o *defaultOnlineCache) GetUserOnline(ctx context.Context, userID string) (bool, error) { platformIDs, err := o.getUserOnlinePlatform(ctx, userID) if err != nil { return false, err @@ -216,7 +227,7 @@ func (o *OnlineCache) GetUserOnline(ctx context.Context, userID string) (bool, e return len(platformIDs) > 0, nil } -func (o *OnlineCache) getUserOnlinePlatformBatch(ctx context.Context, userIDs []string) (map[string][]int32, error) { +func (o *defaultOnlineCache) getUserOnlinePlatformBatch(ctx context.Context, userIDs []string) (map[string][]int32, error) { if len(userIDs) == 0 { return nil, nil } @@ -240,7 +251,7 @@ func (o *OnlineCache) getUserOnlinePlatformBatch(ctx context.Context, userIDs [] return platformIDsMap, nil } -func (o *OnlineCache) GetUsersOnline(ctx context.Context, userIDs []string) ([]string, []string, error) { +func (o *defaultOnlineCache) GetUsersOnline(ctx context.Context, userIDs []string) ([]string, []string, error) { t := time.Now() var ( @@ -276,7 +287,7 @@ func (o *OnlineCache) GetUsersOnline(ctx context.Context, userIDs []string) ([]s return onlineUserIDs, offlineUserIDs, nil } -func (o *OnlineCache) setUserOnline(userID string, platformIDs []int32) { +func (o *defaultOnlineCache) setUserOnline(userID string, platformIDs []int32) { switch o.fullUserCache { case true: o.mapCache.Store(userID, platformIDs) @@ -285,6 +296,51 @@ func (o *OnlineCache) setUserOnline(userID string, platformIDs []int32) { } } -func (o *OnlineCache) setHasUserOnline(userID string, platformIDs []int32) bool { +func (o *defaultOnlineCache) setHasUserOnline(userID string, platformIDs []int32) bool { return o.lruCache.SetHas(userID, platformIDs) } + +func (o *defaultOnlineCache) WaitCache() { + o.Lock.Lock() + for o.CurrentPhase.Load() < DoSubscribeOver { + o.Cond.Wait() + } + o.Lock.Unlock() +} + +type disableOnlineCache struct { + client *rpcli.UserClient +} + +func (o disableOnlineCache) GetUserOnlinePlatform(ctx context.Context, userID string) ([]int32, error) { + return o.client.GetUserOnlinePlatform(ctx, userID) +} + +func (o disableOnlineCache) GetUserOnline(ctx context.Context, userID string) (bool, error) { + onlinePlatform, err := o.client.GetUserOnlinePlatform(ctx, userID) + if err != nil { + return false, err + } + return len(onlinePlatform) > 0, err +} + +func (o disableOnlineCache) GetUsersOnline(ctx context.Context, userIDs []string) ([]string, []string, error) { + var ( + onlineUserIDs = make([]string, 0, len(userIDs)) + offlineUserIDs = make([]string, 0, len(userIDs)) + ) + for _, userID := range userIDs { + online, err := o.GetUserOnline(ctx, userID) + if err != nil { + return nil, nil, err + } + if online { + onlineUserIDs = append(onlineUserIDs, userID) + } else { + offlineUserIDs = append(offlineUserIDs, userID) + } + } + return onlineUserIDs, offlineUserIDs, nil +} + +func (o disableOnlineCache) WaitCache() {} diff --git a/pkg/tools/batcher/batcher.go b/pkg/tools/batcher/batcher.go index dcf5d07ad..93a31ed8f 100644 --- a/pkg/tools/batcher/batcher.go +++ b/pkg/tools/batcher/batcher.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/utils/idutil" ) @@ -253,13 +254,14 @@ func (b *Batcher[T]) distributeMessage(messages map[string][]*T, totalCount int, func (b *Batcher[T]) run(channelID int, ch <-chan *Msg[T]) { defer b.wait.Done() + ctx := authverify.WithTempAdmin(context.Background()) for { select { case messages, ok := <-ch: if !ok { return } - b.Do(context.Background(), channelID, messages) + b.Do(ctx, channelID, messages) if b.config.syncWait { b.counter.Done() }