diff --git a/go.mod b/go.mod index 3863fde1f..34fd8d79d 100644 --- a/go.mod +++ b/go.mod @@ -219,3 +219,5 @@ require ( golang.org/x/crypto v0.27.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) + +replace github.com/openimsdk/tools => /Users/chao/Desktop/code/tools diff --git a/go.sum b/go.sum index c80690f80..b815ca8d6 100644 --- a/go.sum +++ b/go.sum @@ -349,8 +349,6 @@ github.com/openimsdk/gomake v0.0.15-alpha.5 h1:eEZCEHm+NsmcO3onXZPIUbGFCYPYbsX5b github.com/openimsdk/gomake v0.0.15-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= github.com/openimsdk/protocol v0.0.73-alpha.6 h1:sna9coWG7HN1zObBPtvG0Ki/vzqHXiB4qKbA5P3w7kc= github.com/openimsdk/protocol v0.0.73-alpha.6/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= -github.com/openimsdk/tools v0.0.50-alpha.79 h1:jxYEbrzaze4Z2r4NrKad816buZ690ix0L9MTOOOH3ik= -github.com/openimsdk/tools v0.0.50-alpha.79/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 455b77635..b63c17ecc 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -487,6 +487,11 @@ func (g *groupServer) GetGroupAllMember(ctx context.Context, req *pbgroup.GetGro } func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGroupMemberListReq) (*pbgroup.GetGroupMemberListResp, error) { + if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { + if _, err := g.db.TakeGroupMember(ctx, req.GroupID, opUserID); err != nil { + return nil, err + } + } var ( total int64 members []*model.GroupMember @@ -495,7 +500,7 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr if req.Keyword == "" { total, members, err = g.db.PageGetGroupMember(ctx, req.GroupID, req.Pagination) } else { - members, err = g.db.FindGroupMemberAll(ctx, req.GroupID) + total, members, err = g.db.SearchGroupMember(ctx, req.GroupID, req.Keyword, req.Pagination) } if err != nil { return nil, err @@ -503,27 +508,6 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr if err := g.PopulateGroupMember(ctx, members...); err != nil { return nil, err } - if req.Keyword != "" { - groupMembers := make([]*model.GroupMember, 0) - for _, member := range members { - if member.UserID == req.Keyword { - groupMembers = append(groupMembers, member) - total++ - continue - } - if member.Nickname == req.Keyword { - groupMembers = append(groupMembers, member) - total++ - continue - } - } - - members := datautil.Paginate(groupMembers, int(req.Pagination.GetPageNumber()), int(req.Pagination.GetShowNumber())) - return &pbgroup.GetGroupMemberListResp{ - Total: uint32(total), - Members: datautil.Batch(convert.Db2PbGroupMember, members), - }, nil - } return &pbgroup.GetGroupMemberListResp{ Total: uint32(total), Members: datautil.Batch(convert.Db2PbGroupMember, members), diff --git a/internal/rpc/group/sync.go b/internal/rpc/group/sync.go index d311ae076..ed4fabb7a 100644 --- a/internal/rpc/group/sync.go +++ b/internal/rpc/group/sync.go @@ -11,16 +11,24 @@ import ( "github.com/openimsdk/protocol/constant" pbgroup "github.com/openimsdk/protocol/group" "github.com/openimsdk/protocol/sdkws" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/mcontext" + "github.com/openimsdk/tools/utils/datautil" ) const versionSyncLimit = 500 func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { - vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) + userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID) if err != nil { return nil, err } - userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID) + if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { + if !datautil.Contain(opUserID, userIDs...) { + return nil, errs.ErrNoPermission.WrapMsg("user not in group") + } + } + vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) if err != nil { return nil, err } @@ -37,6 +45,9 @@ func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgrou } func (g *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) { + if err := authverify.CheckAccessV3(ctx, req.UserID, g.config.Share.IMAdminUserID); err != nil { + return nil, err + } vl, err := g.db.FindMaxJoinGroupVersionCache(ctx, req.UserID) if err != nil { return nil, err @@ -159,6 +170,7 @@ func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup. } func (g *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) { + var num int resp := make(map[string]*pbgroup.GetIncrementalGroupMemberResp) for _, memberReq := range req.ReqList { diff --git a/internal/rpc/relation/black.go b/internal/rpc/relation/black.go index 2108d7dc5..787418903 100644 --- a/internal/rpc/relation/black.go +++ b/internal/rpc/relation/black.go @@ -29,6 +29,8 @@ import ( ) func (s *friendServer) GetPaginationBlacks(ctx context.Context, req *relation.GetPaginationBlacksReq) (resp *relation.GetPaginationBlacksResp, err error) { + panic("test panic") + if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { return nil, err } diff --git a/internal/rpc/relation/sync.go b/internal/rpc/relation/sync.go index 0ad94fe82..02f20b573 100644 --- a/internal/rpc/relation/sync.go +++ b/internal/rpc/relation/sync.go @@ -2,10 +2,11 @@ package relation import ( "context" + "slices" + "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/log" - "slices" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/pkg/authverify" @@ -39,6 +40,12 @@ func (s *friendServer) NotificationUserInfoUpdate(ctx context.Context, req *rela } func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.GetFullFriendUserIDsReq) (*relation.GetFullFriendUserIDsResp, error) { + req.ProtoReflect() + + req.ProtoReflect().Type() + if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + return nil, err + } vl, err := s.db.FindMaxFriendVersionCache(ctx, req.UserID) if err != nil { return nil, err diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index 6ef61f773..d6da6cd22 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -283,9 +283,9 @@ func (s *userServer) UserRegister(ctx context.Context, req *pbuser.UserRegisterR return nil, errs.ErrArgs.WrapMsg("users is empty") } // check if secret is changed - if s.config.Share.Secret == defaultSecret { - return nil, servererrs.ErrSecretNotChanged.Wrap() - } + //if s.config.Share.Secret == defaultSecret { + // return nil, servererrs.ErrSecretNotChanged.Wrap() + //} if err = authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { return nil, err diff --git a/pkg/authverify/token.go b/pkg/authverify/token.go index 872feb1cf..c4367bad6 100644 --- a/pkg/authverify/token.go +++ b/pkg/authverify/token.go @@ -46,8 +46,8 @@ func IsAppManagerUid(ctx context.Context, imAdminUserID []string) bool { return datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...) } -func CheckAdmin(ctx context.Context, imAdminUserID []string) error { - if datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...) { +func CheckAdmin(ctx context.Context) error { + if IsAdmin(ctx) { return nil } return servererrs.ErrNoPermission.WrapMsg(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx))) @@ -60,3 +60,29 @@ func IsManagerUserID(opUserID string, imAdminUserID []string) bool { func CheckSystemAccount(ctx context.Context, level int32) bool { return level >= constant.AppAdmin } + +type ctxAuthKey struct{} + +func WithIMAdminUserIDs(ctx context.Context, imAdminUserID []string) context.Context { + return context.WithValue(ctx, ctxAuthKey{}, imAdminUserID) +} + +func GetIMAdminUserIDs(ctx context.Context) []string { + imAdminUserID, _ := ctx.Value(ctxAuthKey{}).([]string) + return imAdminUserID +} + +func IsAdmin(ctx context.Context) bool { + return datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) +} + +func CheckAccess(ctx context.Context, ownerUserID string) error { + opUserID := mcontext.GetOpUserID(ctx) + if opUserID == ownerUserID { + return nil + } + if !datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) { + return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID) + } + return nil +} diff --git a/pkg/common/startrpc/mw.go b/pkg/common/startrpc/mw.go new file mode 100644 index 000000000..c6cd55380 --- /dev/null +++ b/pkg/common/startrpc/mw.go @@ -0,0 +1,15 @@ +package startrpc + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "google.golang.org/grpc" +) + +func grpcServerIMAdminUserID(imAdminUserID []string) grpc.ServerOption { + return grpc.ChainUnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + ctx = authverify.WithIMAdminUserIDs(ctx, imAdminUserID) + return handler(ctx, req) + }) +} diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index af50c408d..4c7c51f6d 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -38,6 +38,7 @@ import ( "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" "github.com/openimsdk/tools/mw" + grpcsrv "github.com/openimsdk/tools/mw/grpc/server" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -76,6 +77,34 @@ func getConfigRpcMaxRequestBody(value reflect.Value) *conf.MaxRequestBody { return nil } +func getConfigShare(value reflect.Value) *conf.Share { + for value.Kind() == reflect.Pointer { + value = value.Elem() + } + if value.Kind() == reflect.Struct { + num := value.NumField() + for i := 0; i < num; i++ { + field := value.Field(i) + if !field.CanInterface() { + continue + } + for field.Kind() == reflect.Pointer { + field = field.Elem() + } + switch elem := field.Interface().(type) { + case conf.Share: + return &elem + } + if field.Kind() == reflect.Struct { + if elem := getConfigShare(field); elem != nil { + return elem + } + } + } + } + return nil +} + func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP, registerIP string, autoSetPorts bool, rpcPorts []int, index int, rpcRegisterName string, notification *conf.Notification, config T, watchConfigNames []string, watchServiceNames []string, @@ -87,12 +116,20 @@ func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *c } maxRequestBody := getConfigRpcMaxRequestBody(reflect.ValueOf(config)) + shareConfig := getConfigShare(reflect.ValueOf(config)) log.ZDebug(ctx, "rpc start", "rpcMaxRequestBody", maxRequestBody, "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "listenIP", listenIP) options = append(options, - mw.GrpcServer(), + grpcsrv.GrpcServerMetadataContext(), + grpcsrv.GrpcServerLogger(), + grpcsrv.GrpcServerErrorConvert(), + grpcsrv.GrpcServerRequestValidate(), + grpcsrv.GrpcServerPanicCapture(), ) + if shareConfig != nil && len(shareConfig.IMAdminUserID) > 0 { + options = append(options, grpcServerIMAdminUserID(shareConfig.IMAdminUserID)) + } var clientOptions []grpc.DialOption if maxRequestBody != nil { if maxRequestBody.RequestMaxBodySize > 0 {