diff --git a/.github/workflows/go-build-test.yml b/.github/workflows/go-build-test.yml index 4033603e6..9e2aa3f1c 100644 --- a/.github/workflows/go-build-test.yml +++ b/.github/workflows/go-build-test.yml @@ -12,6 +12,10 @@ jobs: go-build: name: Test with go ${{ matrix.go_version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + + env: + SHARE_CONFIG_PATH: config/share.yml + permissions: contents: write pull-requests: write @@ -40,6 +44,10 @@ jobs: with: compose-file: "./docker-compose.yml" + - name: Modify Server Configuration + run: | + yq e '.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }} + # - name: Get Internal IP Address # id: get-ip # run: | @@ -71,6 +79,11 @@ jobs: go mod download go install github.com/magefile/mage@latest + - name: Modify Chat Configuration + run: | + cd ${{ github.workspace }}/chat-repo + yq e '.openIM.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }} + - name: Build and test Chat Services run: | cd ${{ github.workspace }}/chat-repo @@ -132,7 +145,7 @@ jobs: # Test get admin token get_admin_token_response=$(curl -X POST -H "Content-Type: application/json" -H "operationID: imAdmin" -d '{ - "secret": "openIM123", + "secret": "123456", "platformID": 2, "userID": "imAdmin" }' http://127.0.0.1:10002/auth/get_admin_token) @@ -169,7 +182,8 @@ jobs: contents: write env: SDK_DIR: openim-sdk-core - CONFIG_PATH: config/notification.yml + NOTIFICATION_CONFIG_PATH: config/notification.yml + SHARE_CONFIG_PATH: config/share.yml strategy: matrix: @@ -184,7 +198,7 @@ jobs: uses: actions/checkout@v4 with: repository: "openimsdk/openim-sdk-core" - ref: "release-v3.8" + ref: "main" path: ${{ env.SDK_DIR }} - name: Set up Go ${{ matrix.go_version }} @@ -199,8 +213,9 @@ jobs: - name: Modify Server Configuration run: | - yq e '.groupCreated.isSendMsg = true' -i ${{ env.CONFIG_PATH }} - yq e '.friendApplicationApproved.isSendMsg = true' -i ${{ env.CONFIG_PATH }} + yq e '.groupCreated.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }} + yq e '.friendApplicationApproved.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }} + yq e '.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }} - name: Start Server Services run: | diff --git a/.golangci.yml b/.golangci.yml index a95e980f8..7d6c6b596 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -28,6 +28,8 @@ run: # - util # - .*~ # - api/swagger/docs + + # - server/docs # - components/mnt/config/certs # - logs diff --git a/README.md b/README.md index a99559cdb..4745b9a37 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,15 @@ Thank you for contributing to building a powerful instant messaging solution! ## :closed_book: License -OpenIMSDK is available under the Apache License 2.0. See the [LICENSE file](https://github.com/openimsdk/open-im-server/blob/main/LICENSE) for more information. +This software is licensed under a dual-license model: + +- The GNU Affero General Public License (AGPL), Version 3 or later; **OR** +- Commercial license terms from OpenIMSDK. + +If you wish to use this software under commercial terms, please contact us at: contact@openim.io + +For more information, see: https://www.openim.io/en/licensing + diff --git a/README_zh_CN.md b/README_zh_CN.md index 59198eafb..2340ad09a 100644 --- a/README_zh_CN.md +++ b/README_zh_CN.md @@ -131,9 +131,17 @@ 感谢您的贡献,一起来打造强大的即时通讯解决方案! -## :closed_book: 许可证 +## :closed_book: 开源许可证 License + +本软件采用双重授权模型: + +GNU Affero 通用公共许可证(AGPL)第 3 版或更高版本;或 + +来自 OpenIMSDK 的商业授权条款。 + +如需商用,请联系:contact@openim.io +详见:https://www.openim.io/en/licensing - OpenIMSDK 在 Apache License 2.0 许可下可用。查看[LICENSE 文件](https://github.com/openimsdk/open-im-server/blob/main/LICENSE)了解更多信息。 ## 🔮 Thanks to our contributors! diff --git a/config/share.yml b/config/share.yml index a5fbeac75..2e9821436 100644 --- a/config/share.yml +++ b/config/share.yml @@ -1,9 +1,13 @@ secret: openIM123 -imAdminUserID: [ imAdmin ] +imAdminUserID: [imAdmin] # 1: For Android, iOS, Windows, Mac, and web platforms, only one instance can be online at a time multiLogin: policy: 1 # max num of tokens in one end - maxNumOneEnd: 30 \ No newline at end of file + maxNumOneEnd: 30 + +rpcMaxBodySize: + requestMaxBodySize: 8388608 + responseMaxBodySize: 8388608 diff --git a/go.mod b/go.mod index 0a9de4010..0e3a13904 100644 --- a/go.mod +++ b/go.mod @@ -12,8 +12,8 @@ require ( github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/openimsdk/protocol v0.0.72-alpha.79 - github.com/openimsdk/tools v0.0.50-alpha.74 + github.com/openimsdk/protocol v0.0.73-alpha.6 + github.com/openimsdk/tools v0.0.50-alpha.81 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.18.0 github.com/stretchr/testify v1.9.0 @@ -219,3 +219,5 @@ require ( golang.org/x/crypto v0.27.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) + +//replace github.com/openimsdk/protocol => /Users/chao/Desktop/code/protocol diff --git a/go.sum b/go.sum index 66af77379..6bc410a2d 100644 --- a/go.sum +++ b/go.sum @@ -345,12 +345,12 @@ github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= github.com/onsi/gomega v1.25.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM= -github.com/openimsdk/gomake v0.0.14-alpha.5 h1:VY9c5x515lTfmdhhPjMvR3BBRrRquAUCFsz7t7vbv7Y= -github.com/openimsdk/gomake v0.0.14-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= -github.com/openimsdk/protocol v0.0.72-alpha.79 h1:e46no8WVAsmTzyy405klrdoUiG7u+1ohDsXvQuFng4s= -github.com/openimsdk/protocol v0.0.72-alpha.79/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= -github.com/openimsdk/tools v0.0.50-alpha.74 h1:yh10SiMiivMEjicEQg+QAsH4pvaO+4noMPdlw+ew0Kc= -github.com/openimsdk/tools v0.0.50-alpha.74/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo= +github.com/openimsdk/gomake v0.0.15-alpha.5 h1:eEZCEHm+NsmcO3onXZPIUbGFCYPYbsX5beV3ZyOsGhY= +github.com/openimsdk/gomake v0.0.15-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= +github.com/openimsdk/protocol v0.0.73-alpha.6 h1:sna9coWG7HN1zObBPtvG0Ki/vzqHXiB4qKbA5P3w7kc= +github.com/openimsdk/protocol v0.0.73-alpha.6/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= +github.com/openimsdk/tools v0.0.50-alpha.81 h1:VbuJKtigNXLkCKB/Q6f2UHsqoSaTOAwS8F51c1nhOCA= +github.com/openimsdk/tools v0.0.50-alpha.81/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= diff --git a/internal/api/config_manager.go b/internal/api/config_manager.go index 4d846dd9b..35ab060a7 100644 --- a/internal/api/config_manager.go +++ b/internal/api/config_manager.go @@ -45,7 +45,7 @@ func NewConfigManager(IMAdminUserID []string, cfg *config.AllConfig, client *cli } func (cm *ConfigManager) CheckAdmin(c *gin.Context) { - if err := authverify.CheckAdmin(c, cm.imAdminUserID); err != nil { + if err := authverify.CheckAdmin(c); err != nil { apiresp.GinError(c, err) c.Abort() } diff --git a/internal/api/init.go b/internal/api/init.go index 4bd29c9e0..1e0f1075f 100644 --- a/internal/api/init.go +++ b/internal/api/init.go @@ -144,24 +144,23 @@ func Start(ctx context.Context, index int, config *Config) error { } }() - if config.Discovery.Enable == conf.ETCD { - cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames()) - cm.Watch(ctx) - } - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) - - shutdown := func() error { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - err := server.Shutdown(ctx) - if err != nil { - return errs.WrapMsg(err, "shutdown err") - } - return nil - } - disetcd.RegisterShutDown(shutdown) + //if config.Discovery.Enable == conf.ETCD { + // cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames()) + // cm.Watch(ctx) + //} + //sigs := make(chan os.Signal, 1) + //signal.Notify(sigs, syscall.SIGTERM) + //select { + //case val := <-sigs: + // log.ZDebug(ctx, "recv exit", "signal", val.String()) + // cancel(fmt.Errorf("signal %s", val.String())) + //case <-ctx.Done(): + //} + <-apiCtx.Done() + exitCause := context.Cause(apiCtx) + log.ZWarn(ctx, "api server exit", exitCause) + timer := time.NewTimer(time.Second * 15) + defer timer.Stop() select { case <-sigs: program.SIGTERMExit() diff --git a/internal/api/msg.go b/internal/api/msg.go index 1d53cbc48..8be4832e6 100644 --- a/internal/api/msg.go +++ b/internal/api/msg.go @@ -281,7 +281,7 @@ func (m *MessageApi) SendMessage(c *gin.Context) { } // Check if the user has the app manager role. - if !authverify.IsAppManagerUid(c, m.imAdminUserID) { + if !authverify.IsAdmin(c) { // Respond with a permission error if the user is not an app manager. apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message")) return @@ -355,7 +355,7 @@ func (m *MessageApi) SendBusinessNotification(c *gin.Context) { if req.ReliabilityLevel == nil { req.ReliabilityLevel = datautil.ToPtr(1) } - if !authverify.IsAppManagerUid(c, m.imAdminUserID) { + if !authverify.IsAdmin(c) { apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message")) return } @@ -399,7 +399,7 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) { apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap()) return } - if err := authverify.CheckAdmin(c, m.imAdminUserID); err != nil { + if err := authverify.CheckAdmin(c); err != nil { apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message")) return } diff --git a/internal/api/router.go b/internal/api/router.go index 216a43363..e9e5f6d5f 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -9,6 +9,11 @@ import ( "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/openimsdk/tools/mcontext" + "github.com/openimsdk/tools/utils/datautil" + clientv3 "go.etcd.io/etcd/client/v3" + "github.com/openimsdk/open-im-server/v3/internal/api/jssdk" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" @@ -96,7 +101,7 @@ func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cf r.Use(gzip.Gzip(gzip.BestSpeed)) } r.Use(prommetricsGin(), gin.RecoveryWithWriter(gin.DefaultErrorWriter, mw.GinPanicErr), mw.CorsHandler(), - mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn))) + mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn)), setGinIsAdmin(cfg.Share.IMAdminUserID)) u := NewUserApi(user.NewUserClient(userConn), client, cfg.Discovery.RpcService) { @@ -124,6 +129,11 @@ func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cf userRouterGroup.POST("/add_notification_account", u.AddNotificationAccount) userRouterGroup.POST("/update_notification_account", u.UpdateNotificationAccountInfo) userRouterGroup.POST("/search_notification_account", u.SearchNotificationAccount) + + userRouterGroup.POST("/get_user_client_config", u.GetUserClientConfig) + userRouterGroup.POST("/set_user_client_config", u.SetUserClientConfig) + userRouterGroup.POST("/del_user_client_config", u.DelUserClientConfig) + userRouterGroup.POST("/page_user_client_config", u.PageUserClientConfig) } // friend routing group { @@ -347,6 +357,14 @@ func GinParseToken(authClient *rpcli.AuthClient) gin.HandlerFunc { } } +func setGinIsAdmin(imAdminUserID []string) gin.HandlerFunc { + return func(c *gin.Context) { + opUserID := mcontext.GetOpUserID(c) + admin := datautil.Contain(opUserID, imAdminUserID...) + c.Set(authverify.CtxIsAdminKey, admin) + } +} + // Whitelist api not parse token var Whitelist = []string{ "/auth/get_admin_token", diff --git a/internal/api/user.go b/internal/api/user.go index a88f8f65a..7f256f5dd 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -242,3 +242,19 @@ func (u *UserApi) UpdateNotificationAccountInfo(c *gin.Context) { func (u *UserApi) SearchNotificationAccount(c *gin.Context) { a2r.Call(c, user.UserClient.SearchNotificationAccount, u.Client) } + +func (u *UserApi) GetUserClientConfig(c *gin.Context) { + a2r.Call(c, user.UserClient.GetUserClientConfig, u.Client) +} + +func (u *UserApi) SetUserClientConfig(c *gin.Context) { + a2r.Call(c, user.UserClient.SetUserClientConfig, u.Client) +} + +func (u *UserApi) DelUserClientConfig(c *gin.Context) { + a2r.Call(c, user.UserClient.DelUserClientConfig, u.Client) +} + +func (u *UserApi) PageUserClientConfig(c *gin.Context) { + a2r.Call(c, user.UserClient.PageUserClientConfig, u.Client) +} diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index 887a90d7a..8c744b7d1 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -101,7 +101,7 @@ func NewServer(longConnServer LongConnServer, conf *Config, ready func(srv *Serv } func (s *Server) GetUsersOnlineStatus(ctx context.Context, req *msggateway.GetUsersOnlineStatusReq) (*msggateway.GetUsersOnlineStatusResp, error) { - if !authverify.IsAppManagerUid(ctx, s.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { return nil, errs.ErrNoPermission.WrapMsg("only app manager") } var resp msggateway.GetUsersOnlineStatusResp diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 2e64c365c..2c2691d1d 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -18,11 +18,14 @@ import ( "context" "errors" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/mcache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo" + "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" "github.com/openimsdk/open-im-server/v3/pkg/common/config" redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" - "github.com/openimsdk/tools/db/redisutil" "github.com/openimsdk/tools/utils/datautil" "github.com/redis/go-redis/v9" @@ -43,7 +46,7 @@ import ( type authServer struct { pbauth.UnimplementedAuthServer authDatabase controller.AuthDatabase - RegisterCenter discovery.SvcDiscoveryRegistry + RegisterCenter discovery.Conn config *Config userClient *rpcli.UserClient } @@ -51,15 +54,31 @@ type authServer struct { type Config struct { RpcConfig config.Auth RedisConfig config.Redis + MongoConfig config.Mongo Share config.Share Discovery config.Discovery } -func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build()) +func Start(ctx context.Context, config *Config, client discovery.Conn, server grpc.ServiceRegistrar) error { + dbb := dbbuild.NewBuilder(&config.MongoConfig, &config.RedisConfig) + rdb, err := dbb.Redis(ctx) if err != nil { return err } + var token cache.TokenModel + if rdb == nil { + mdb, err := dbb.Mongo(ctx) + if err != nil { + return err + } + mc, err := mgo.NewCacheMgo(mdb.GetDB()) + if err != nil { + return err + } + token = mcache.NewTokenCacheModel(mc, config.RpcConfig.TokenPolicy.Expire) + } else { + token = redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire) + } userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User) if err != nil { return err @@ -67,7 +86,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg pbauth.RegisterAuthServer(server, &authServer{ RegisterCenter: client, authDatabase: controller.NewAuthDatabase( - redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire), + token, config.Share.Secret, config.RpcConfig.TokenPolicy.Expire, config.Share.MultiLogin, @@ -106,7 +125,7 @@ func (s *authServer) GetAdminToken(ctx context.Context, req *pbauth.GetAdminToke } func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -116,7 +135,7 @@ func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenR resp := pbauth.GetUserTokenResp{} - if authverify.IsManagerUserID(req.UserID, s.config.Share.IMAdminUserID) { + if authverify.CheckUserIsAdmin(ctx, req.UserID) { return nil, errs.ErrNoPermission.WrapMsg("don't get Admin token") } user, err := s.userClient.GetUserInfo(ctx, req.UserID) @@ -140,15 +159,17 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim if err != nil { return nil, err } - isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) - if isAdmin { - return claims, nil - } m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) if err != nil { return nil, err } if len(m) == 0 { + isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID) + if isAdmin { + if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { + return claims, nil + } + } return nil, servererrs.ErrTokenNotExist.Wrap() } if v, ok := m[tokensString]; ok { @@ -160,6 +181,13 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim default: return nil, errs.Wrap(errs.ErrTokenUnknown) } + } else { + isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID) + if isAdmin { + if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil { + return claims, nil + } + } } return nil, servererrs.ErrTokenNotExist.Wrap() } @@ -177,7 +205,7 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbauth.ParseTokenReq) } func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } if err := s.forceKickOff(ctx, req.UserID, req.PlatformID); err != nil { diff --git a/internal/rpc/conversation/sync.go b/internal/rpc/conversation/sync.go index ad88b2bbd..cee74b319 100644 --- a/internal/rpc/conversation/sync.go +++ b/internal/rpc/conversation/sync.go @@ -4,12 +4,16 @@ import ( "context" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" "github.com/openimsdk/protocol/conversation" ) func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, req *conversation.GetFullOwnerConversationIDsReq) (*conversation.GetFullOwnerConversationIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } vl, err := c.conversationDatabase.FindMaxConversationUserVersionCache(ctx, req.UserID) if err != nil { return nil, err diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 602c4f3ee..1ed3ce799 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -156,7 +156,7 @@ func (g *groupServer) NotificationUserInfoUpdate(ctx context.Context, req *pbgro } func (g *groupServer) CheckGroupAdmin(ctx context.Context, groupID string) error { - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { groupMember, err := g.db.TakeGroupMember(ctx, groupID, mcontext.GetOpUserID(ctx)) if err != nil { return err @@ -208,7 +208,7 @@ func (g *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR if req.OwnerUserID == "" { return nil, errs.ErrArgs.WrapMsg("no group owner") } - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, g.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } userIDs := append(append(req.MemberUserIDs, req.AdminUserIDs...), req.OwnerUserID) @@ -311,7 +311,7 @@ func (g *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR } func (g *groupServer) GetJoinedGroupList(ctx context.Context, req *pbgroup.GetJoinedGroupListReq) (*pbgroup.GetJoinedGroupListResp, error) { - if err := authverify.CheckAccessV3(ctx, req.FromUserID, g.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil { return nil, err } total, members, err := g.db.PageGetJoinGroup(ctx, req.FromUserID, req.Pagination) @@ -383,7 +383,7 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite var groupMember *model.GroupMember var opUserID string - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { opUserID = mcontext.GetOpUserID(ctx) var err error groupMember, err = g.db.TakeGroupMember(ctx, req.GroupID, opUserID) @@ -402,7 +402,7 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite } if group.NeedVerification == constant.AllNeedVerification { - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { if !(groupMember.RoleLevel == constant.GroupOwner || groupMember.RoleLevel == constant.GroupAdmin) { var requests []*model.GroupRequest for _, userID := range req.InvitedUserIDs { @@ -451,12 +451,25 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite return nil, err } - if err := g.db.CreateGroup(ctx, nil, groupMembers); err != nil { - return nil, err - } + const singleQuantity = 50 + for start := 0; start < len(groupMembers); start += singleQuantity { + end := start + singleQuantity + if end > len(groupMembers) { + end = len(groupMembers) + } + currentMembers := groupMembers[start:end] - if err = g.notification.GroupApplicationAgreeMemberEnterNotification(ctx, req.GroupID, opUserID, req.InvitedUserIDs...); err != nil { - return nil, err + if err := g.db.CreateGroup(ctx, nil, currentMembers); err != nil { + return nil, err + } + + userIDs := datautil.Slice(currentMembers, func(e *model.GroupMember) string { + return e.UserID + }) + + if err = g.notification.GroupApplicationAgreeMemberEnterNotification(ctx, req.GroupID, req.SendMessage, opUserID, userIDs...); err != nil { + return nil, err + } } return &pbgroup.InviteUserToGroupResp{}, nil } @@ -477,6 +490,11 @@ func (g *groupServer) GetGroupAllMember(ctx context.Context, req *pbgroup.GetGro } func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGroupMemberListReq) (*pbgroup.GetGroupMemberListResp, error) { + if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { + if _, err := g.db.TakeGroupMember(ctx, req.GroupID, opUserID); err != nil { + return nil, err + } + } var ( total int64 members []*model.GroupMember @@ -485,7 +503,7 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr if req.Keyword == "" { total, members, err = g.db.PageGetGroupMember(ctx, req.GroupID, req.Pagination) } else { - members, err = g.db.FindGroupMemberAll(ctx, req.GroupID) + total, members, err = g.db.SearchGroupMember(ctx, req.GroupID, req.Keyword, req.Pagination) } if err != nil { return nil, err @@ -493,27 +511,6 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr if err := g.PopulateGroupMember(ctx, members...); err != nil { return nil, err } - if req.Keyword != "" { - groupMembers := make([]*model.GroupMember, 0) - for _, member := range members { - if member.UserID == req.Keyword { - groupMembers = append(groupMembers, member) - total++ - continue - } - if member.Nickname == req.Keyword { - groupMembers = append(groupMembers, member) - total++ - continue - } - } - - members := datautil.Paginate(groupMembers, int(req.Pagination.GetPageNumber()), int(req.Pagination.GetShowNumber())) - return &pbgroup.GetGroupMemberListResp{ - Total: uint32(total), - Members: datautil.Batch(convert.Db2PbGroupMember, members), - }, nil - } return &pbgroup.GetGroupMemberListResp{ Total: uint32(total), Members: datautil.Batch(convert.Db2PbGroupMember, members), @@ -554,7 +551,7 @@ func (g *groupServer) KickGroupMember(ctx context.Context, req *pbgroup.KickGrou for i, member := range members { memberMap[member.UserID] = members[i] } - isAppManagerUid := authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) + isAppManagerUid := authverify.IsAdmin(ctx) opMember := memberMap[opUserID] for _, userID := range req.KickedUserIDs { member, ok := memberMap[userID] @@ -772,7 +769,7 @@ func (g *groupServer) GroupApplicationResponse(ctx context.Context, req *pbgroup if !datautil.Contain(req.HandleResult, constant.GroupResponseAgree, constant.GroupResponseRefuse) { return nil, errs.ErrArgs.WrapMsg("HandleResult unknown") } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { groupMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -920,7 +917,7 @@ func (g *groupServer) QuitGroup(ctx context.Context, req *pbgroup.QuitGroupReq) if req.UserID == "" { req.UserID = mcontext.GetOpUserID(ctx) } else { - if err := authverify.CheckAccessV3(ctx, req.UserID, g.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } } @@ -958,7 +955,7 @@ func (g *groupServer) deleteMemberAndSetConversationSeq(ctx context.Context, gro func (g *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInfoReq) (*pbgroup.SetGroupInfoResp, error) { var opMember *model.GroupMember - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { var err error opMember, err = g.db.TakeGroupMember(ctx, req.GroupInfoForSet.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { @@ -1051,7 +1048,7 @@ func (g *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInf func (g *groupServer) SetGroupInfoEx(ctx context.Context, req *pbgroup.SetGroupInfoExReq) (*pbgroup.SetGroupInfoExResp, error) { var opMember *model.GroupMember - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { var err error opMember, err = g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) @@ -1203,7 +1200,7 @@ func (g *groupServer) TransferGroupOwner(ctx context.Context, req *pbgroup.Trans return nil, errs.ErrArgs.WrapMsg("NewOwnerUser not in group " + req.NewOwnerUserID) } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { if !(mcontext.GetOpUserID(ctx) == oldOwner.UserID && oldOwner.RoleLevel == constant.GroupOwner) { return nil, errs.ErrNoPermission.WrapMsg("no permission transfer group owner") } @@ -1346,7 +1343,7 @@ func (g *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou if err != nil { return nil, err } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { if owner.UserID != mcontext.GetOpUserID(ctx) { return nil, errs.ErrNoPermission.WrapMsg("not group owner") } @@ -1369,6 +1366,7 @@ func (g *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou if err != nil { return nil, err } + group.Status = constant.GroupStatusDismissed tips := &sdkws.GroupDismissedTips{ Group: g.groupDB2PB(group, owner.UserID, num), OpUser: &sdkws.GroupMemberFullInfo{}, @@ -1402,7 +1400,7 @@ func (g *groupServer) MuteGroupMember(ctx context.Context, req *pbgroup.MuteGrou if err := g.PopulateGroupMember(ctx, member); err != nil { return nil, err } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { opMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -1438,7 +1436,7 @@ func (g *groupServer) CancelMuteGroupMember(ctx context.Context, req *pbgroup.Ca return nil, err } - if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { opMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -1498,7 +1496,7 @@ func (g *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr if opUserID == "" { return nil, errs.ErrNoPermission.WrapMsg("no op user id") } - isAppManagerUid := authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) + isAppManagerUid := authverify.IsAdmin(ctx) groupMembers := make(map[string][]*pbgroup.SetGroupMemberInfo) for i, member := range req.Members { if member.RoleLevel != nil { diff --git a/internal/rpc/group/notification.go b/internal/rpc/group/notification.go index 1aa5333b4..a6c715735 100644 --- a/internal/rpc/group/notification.go +++ b/internal/rpc/group/notification.go @@ -242,8 +242,8 @@ func (g *NotificationSender) fillOpUserByUserID(ctx context.Context, userID stri return errs.ErrInternalServer.WrapMsg("**sdkws.GroupMemberFullInfo is nil") } if groupID != "" { - if authverify.IsManagerUserID(userID, g.config.Share.IMAdminUserID) { - *opUser = &sdkws.GroupMemberFullInfo{ + if authverify.CheckUserIsAdmin(ctx, userID) { + *targetUser = &sdkws.GroupMemberFullInfo{ GroupID: groupID, UserID: userID, RoleLevel: constant.GroupAdmin, @@ -283,7 +283,8 @@ func (g *NotificationSender) fillOpUserByUserID(ctx context.Context, userID stri func (g *NotificationSender) setVersion(ctx context.Context, version *uint64, versionID *string, collName string, id string) { versions := versionctx.GetVersionLog(ctx).Get() - for _, coll := range versions { + for i := len(versions) - 1; i >= 0; i-- { + coll := versions[i] if coll.Name == collName && coll.Doc.DID == id { *version = uint64(coll.Doc.Version) *versionID = coll.Doc.ID.Hex() @@ -519,7 +520,11 @@ func (g *NotificationSender) MemberKickedNotification(ctx context.Context, tips g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.MemberKickedNotification, tips) } -func (g *NotificationSender) GroupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, invitedOpUserID string, entrantUserID ...string) error { +func (g *NotificationSender) GroupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, SendMessage *bool, invitedOpUserID string, entrantUserID ...string) error { + return g.groupApplicationAgreeMemberEnterNotification(ctx, groupID, SendMessage, invitedOpUserID, entrantUserID...) +} + +func (g *NotificationSender) groupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, SendMessage *bool, invitedOpUserID string, entrantUserID ...string) error { var err error defer func() { if err != nil { diff --git a/internal/rpc/group/sync.go b/internal/rpc/group/sync.go index 0592aa811..ed608dea3 100644 --- a/internal/rpc/group/sync.go +++ b/internal/rpc/group/sync.go @@ -12,15 +12,23 @@ import ( pbgroup "github.com/openimsdk/protocol/group" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/errs" - "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/mcontext" + "github.com/openimsdk/tools/utils/datautil" ) -func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { - vl, err := s.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) +const versionSyncLimit = 500 + +func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) { + userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID) if err != nil { return nil, err } - userIDs, err := s.db.FindGroupMemberUserID(ctx, req.GroupID) + if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) { + if !datautil.Contain(opUserID, userIDs...) { + return nil, errs.ErrNoPermission.WrapMsg("user not in group") + } + } + vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID) if err != nil { return nil, err } @@ -36,8 +44,11 @@ func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgrou }, nil } -func (s *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) { - vl, err := s.db.FindMaxJoinGroupVersionCache(ctx, req.UserID) +func (g *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } + vl, err := g.db.FindMaxJoinGroupVersionCache(ctx, req.UserID) if err != nil { return nil, err } @@ -65,6 +76,9 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou if group.Status == constant.GroupStatusDismissed { return nil, servererrs.ErrDismissedAlready.Wrap() } + if _, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)); err != nil { + return nil, err + } var ( hasGroupUpdate bool sortVersion uint64 @@ -132,152 +146,8 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou return resp, nil } -func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (resp *pbgroup.BatchGetIncrementalGroupMemberResp, err error) { - type VersionInfo struct { - GroupID string - VersionID string - VersionNumber uint64 - } - - var groupIDs []string - - groupsVersionMap := make(map[string]*VersionInfo) - groupsMap := make(map[string]*model.Group) - hasGroupUpdateMap := make(map[string]bool) - sortVersionMap := make(map[string]uint64) - - var targetKeys, versionIDs []string - var versionNumbers []uint64 - - var requestBodyLen int - - for _, group := range req.ReqList { - groupsVersionMap[group.GroupID] = &VersionInfo{ - GroupID: group.GroupID, - VersionID: group.VersionID, - VersionNumber: group.Version, - } - - groupIDs = append(groupIDs, group.GroupID) - } - - groups, err := s.db.FindGroup(ctx, groupIDs) - if err != nil { - return nil, errs.Wrap(err) - } - - for _, group := range groups { - if group.Status == constant.GroupStatusDismissed { - err = servererrs.ErrDismissedAlready.Wrap() - log.ZError(ctx, "This group is Dismissed Already", err, "group is", group.GroupID) - - delete(groupsVersionMap, group.GroupID) - } else { - groupsMap[group.GroupID] = group - } - } - - for groupID, vInfo := range groupsVersionMap { - targetKeys = append(targetKeys, groupID) - versionIDs = append(versionIDs, vInfo.VersionID) - versionNumbers = append(versionNumbers, vInfo.VersionNumber) - } - - opt := incrversion.BatchOption[[]*sdkws.GroupMemberFullInfo, pbgroup.BatchGetIncrementalGroupMemberResp]{ - Ctx: ctx, - TargetKeys: targetKeys, - VersionIDs: versionIDs, - VersionNumbers: versionNumbers, - Versions: func(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) { - vLogs, err := s.db.BatchFindMemberIncrVersion(ctx, groupIDs, versions, limits) - if err != nil { - return nil, errs.Wrap(err) - } - - for groupID, vlog := range vLogs { - vlogElems := make([]model.VersionLogElem, 0, len(vlog.Logs)) - for i, log := range vlog.Logs { - switch log.EID { - case model.VersionGroupChangeID: - vlog.LogLen-- - hasGroupUpdateMap[groupID] = true - case model.VersionSortChangeID: - vlog.LogLen-- - sortVersionMap[groupID] = uint64(log.Version) - default: - vlogElems = append(vlogElems, vlog.Logs[i]) - } - } - vlog.Logs = vlogElems - if vlog.LogLen > 0 { - hasGroupUpdateMap[groupID] = true - } - } - - return vLogs, nil - }, - CacheMaxVersions: s.db.BatchFindMaxGroupMemberVersionCache, - Find: func(ctx context.Context, groupID string, ids []string) ([]*sdkws.GroupMemberFullInfo, error) { - memberInfo, err := s.getGroupMembersInfo(ctx, groupID, ids) - if err != nil { - return nil, err - } - - return memberInfo, err - }, - Resp: func(versions map[string]*model.VersionLog, deleteIdsMap map[string][]string, insertListMap, updateListMap map[string][]*sdkws.GroupMemberFullInfo, fullMap map[string]bool) *pbgroup.BatchGetIncrementalGroupMemberResp { - resList := make(map[string]*pbgroup.GetIncrementalGroupMemberResp) - - for groupID, versionLog := range versions { - resList[groupID] = &pbgroup.GetIncrementalGroupMemberResp{ - VersionID: versionLog.ID.Hex(), - Version: uint64(versionLog.Version), - Full: fullMap[groupID], - Delete: deleteIdsMap[groupID], - Insert: insertListMap[groupID], - Update: updateListMap[groupID], - SortVersion: sortVersionMap[groupID], - } - - requestBodyLen += len(insertListMap[groupID]) + len(updateListMap[groupID]) + len(deleteIdsMap[groupID]) - if requestBodyLen > 200 { - break - } - } - - return &pbgroup.BatchGetIncrementalGroupMemberResp{ - RespList: resList, - } - }, - } - - resp, err = opt.Build() - if err != nil { - return nil, errs.Wrap(err) - } - - for groupID, val := range resp.RespList { - if val.Full || hasGroupUpdateMap[groupID] { - count, err := s.db.FindGroupMemberNum(ctx, groupID) - if err != nil { - return nil, err - } - - owner, err := s.db.TakeGroupOwner(ctx, groupID) - if err != nil { - return nil, err - } - - resp.RespList[groupID].Group = s.groupDB2PB(groupsMap[groupID], owner.UserID, count) - } - } - - return resp, nil - -} - -func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { +func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{ diff --git a/internal/rpc/msg/as_read.go b/internal/rpc/msg/as_read.go index de1879438..b25eae6b1 100644 --- a/internal/rpc/msg/as_read.go +++ b/internal/rpc/msg/as_read.go @@ -61,6 +61,13 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m return nil, err } resp := &msg.GetConversationsHasReadAndMaxSeqResp{Seqs: make(map[string]*msg.Seqs)} + if req.ReturnPinned { + pinnedConversationIDs, err := m.ConversationLocalCache.GetPinnedConversationIDs(ctx, req.UserID) + if err != nil { + return nil, err + } + resp.PinnedConversationIDs = pinnedConversationIDs + } for conversationID, maxSeq := range maxSeqs { resp.Seqs[conversationID] = &msg.Seqs{ HasReadSeq: hasReadSeqs[conversationID], diff --git a/internal/rpc/msg/clear.go b/internal/rpc/msg/clear.go index 8e14b281e..96eb99aed 100644 --- a/internal/rpc/msg/clear.go +++ b/internal/rpc/msg/clear.go @@ -2,15 +2,16 @@ package msg import ( "context" + "strings" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/protocol/msg" "github.com/openimsdk/tools/log" - "strings" ) // DestructMsgs hard delete in Database. func (m *msgServer) DestructMsgs(ctx context.Context, req *msg.DestructMsgsReq) (*msg.DestructMsgsResp, error) { - if err := authverify.CheckAdmin(ctx, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } docs, err := m.MsgDatabase.GetRandBeforeMsg(ctx, req.Timestamp, int(req.Limit)) diff --git a/internal/rpc/msg/delete.go b/internal/rpc/msg/delete.go index d3485faaa..4590523d5 100644 --- a/internal/rpc/msg/delete.go +++ b/internal/rpc/msg/delete.go @@ -42,7 +42,7 @@ func (m *msgServer) validateDeleteSyncOpt(opt *msg.DeleteSyncOpt) (isSyncSelf, i } func (m *msgServer) ClearConversationsMsg(ctx context.Context, req *msg.ClearConversationsMsgReq) (*msg.ClearConversationsMsgResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } if err := m.clearConversation(ctx, req.ConversationIDs, req.UserID, req.DeleteSyncOpt); err != nil { @@ -52,7 +52,7 @@ func (m *msgServer) ClearConversationsMsg(ctx context.Context, req *msg.ClearCon } func (m *msgServer) UserClearAllMsg(ctx context.Context, req *msg.UserClearAllMsgReq) (*msg.UserClearAllMsgResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) @@ -66,7 +66,7 @@ func (m *msgServer) UserClearAllMsg(ctx context.Context, req *msg.UserClearAllMs } func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*msg.DeleteMsgsResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(req.DeleteSyncOpt) @@ -102,7 +102,7 @@ func (m *msgServer) DeleteMsgPhysicalBySeq(ctx context.Context, req *msg.DeleteM } func (m *msgServer) DeleteMsgPhysical(ctx context.Context, req *msg.DeleteMsgPhysicalReq) (*msg.DeleteMsgPhysicalResp, error) { - if err := authverify.CheckAdmin(ctx, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } remainTime := timeutil.GetCurrentTimestampBySecond() - req.Timestamp diff --git a/internal/rpc/msg/revoke.go b/internal/rpc/msg/revoke.go index c2fb5833f..bd1d66ba1 100644 --- a/internal/rpc/msg/revoke.go +++ b/internal/rpc/msg/revoke.go @@ -42,7 +42,7 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. if req.Seq < 0 { return nil, errs.ErrArgs.WrapMsg("seq is invalid") } - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } user, err := m.UserLocalCache.GetUserInfo(ctx, req.UserID) @@ -63,11 +63,11 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. data, _ := json.Marshal(msgs[0]) log.ZDebug(ctx, "GetMsgBySeqs", "conversationID", req.ConversationID, "seq", req.Seq, "msg", string(data)) var role int32 - if !authverify.IsAppManagerUid(ctx, m.config.Share.IMAdminUserID) { + if !authverify.IsAdmin(ctx) { sessionType := msgs[0].SessionType switch sessionType { case constant.SingleChatType: - if err := authverify.CheckAccessV3(ctx, msgs[0].SendID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, msgs[0].SendID); err != nil { return nil, err } role = user.AppMangerLevel diff --git a/internal/rpc/msg/sync_msg.go b/internal/rpc/msg/sync_msg.go index 6cf1c21d3..38eed93bc 100644 --- a/internal/rpc/msg/sync_msg.go +++ b/internal/rpc/msg/sync_msg.go @@ -118,7 +118,7 @@ func (m *msgServer) GetSeqMessage(ctx context.Context, req *msg.GetSeqMessageReq } func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) diff --git a/internal/rpc/relation/black.go b/internal/rpc/relation/black.go index b795d6248..381a56273 100644 --- a/internal/rpc/relation/black.go +++ b/internal/rpc/relation/black.go @@ -30,10 +30,9 @@ import ( ) func (s *friendServer) GetPaginationBlacks(ctx context.Context, req *relation.GetPaginationBlacksReq) (resp *relation.GetPaginationBlacksResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } - total, blacks, err := s.blackDatabase.FindOwnerBlacks(ctx, req.UserID, req.Pagination) if err != nil { return nil, err @@ -59,7 +58,7 @@ func (s *friendServer) IsBlack(ctx context.Context, req *relation.IsBlackReq) (* } func (s *friendServer) RemoveBlack(ctx context.Context, req *relation.RemoveBlackReq) (*relation.RemoveBlackResp, error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -74,7 +73,7 @@ func (s *friendServer) RemoveBlack(ctx context.Context, req *relation.RemoveBlac } func (s *friendServer) AddBlack(ctx context.Context, req *relation.AddBlackReq) (*relation.AddBlackResp, error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -100,7 +99,7 @@ func (s *friendServer) AddBlack(ctx context.Context, req *relation.AddBlackReq) } func (s *friendServer) GetSpecifiedBlacks(ctx context.Context, req *relation.GetSpecifiedBlacksReq) (*relation.GetSpecifiedBlacksResp, error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } diff --git a/internal/rpc/relation/friend.go b/internal/rpc/relation/friend.go index 8172b8681..50a8667ea 100644 --- a/internal/rpc/relation/friend.go +++ b/internal/rpc/relation/friend.go @@ -135,7 +135,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg // ok. func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *relation.ApplyToAddFriendReq) (resp *relation.ApplyToAddFriendResp, err error) { resp = &relation.ApplyToAddFriendResp{} - if err := authverify.CheckAccessV3(ctx, req.FromUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil { return nil, err } if req.ToUserID == req.FromUserID { @@ -165,7 +165,7 @@ func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *relation.Apply // ok. func (s *friendServer) ImportFriends(ctx context.Context, req *relation.ImportFriendReq) (resp *relation.ImportFriendResp, err error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -201,7 +201,7 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *relation.ImportFr // ok. func (s *friendServer) RespondFriendApply(ctx context.Context, req *relation.RespondFriendApplyReq) (resp *relation.RespondFriendApplyResp, err error) { resp = &relation.RespondFriendApplyResp{} - if err := authverify.CheckAccessV3(ctx, req.ToUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.ToUserID); err != nil { return nil, err } @@ -236,7 +236,7 @@ func (s *friendServer) RespondFriendApply(ctx context.Context, req *relation.Res // ok. func (s *friendServer) DeleteFriend(ctx context.Context, req *relation.DeleteFriendReq) (resp *relation.DeleteFriendResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -261,7 +261,7 @@ func (s *friendServer) SetFriendRemark(ctx context.Context, req *relation.SetFri return nil, err } - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil { return nil, err } @@ -331,7 +331,7 @@ func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context, // Get received friend requests (i.e., those initiated by others). func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *relation.GetPaginationFriendsApplyToReq) (resp *relation.GetPaginationFriendsApplyToResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } @@ -354,7 +354,7 @@ func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *rel func (s *friendServer) GetPaginationFriendsApplyFrom(ctx context.Context, req *relation.GetPaginationFriendsApplyFromReq) (resp *relation.GetPaginationFriendsApplyFromResp, err error) { resp = &relation.GetPaginationFriendsApplyFromResp{} - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } @@ -384,7 +384,7 @@ func (s *friendServer) IsFriend(ctx context.Context, req *relation.IsFriendReq) } func (s *friendServer) GetPaginationFriends(ctx context.Context, req *relation.GetPaginationFriendsReq) (resp *relation.GetPaginationFriendsResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } @@ -405,7 +405,7 @@ func (s *friendServer) GetPaginationFriends(ctx context.Context, req *relation.G } func (s *friendServer) GetFriendIDs(ctx context.Context, req *relation.GetFriendIDsReq) (resp *relation.GetFriendIDsResp, err error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } diff --git a/internal/rpc/relation/sync.go b/internal/rpc/relation/sync.go index 0ad94fe82..79fa0858c 100644 --- a/internal/rpc/relation/sync.go +++ b/internal/rpc/relation/sync.go @@ -2,10 +2,11 @@ package relation import ( "context" + "slices" + "github.com/openimsdk/open-im-server/v3/pkg/util/hashutil" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/log" - "slices" "github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion" "github.com/openimsdk/open-im-server/v3/pkg/authverify" @@ -39,6 +40,9 @@ func (s *friendServer) NotificationUserInfoUpdate(ctx context.Context, req *rela } func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.GetFullFriendUserIDsReq) (*relation.GetFullFriendUserIDsResp, error) { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } vl, err := s.db.FindMaxFriendVersionCache(ctx, req.UserID) if err != nil { return nil, err @@ -60,7 +64,7 @@ func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.G } func (s *friendServer) GetIncrementalFriends(ctx context.Context, req *relation.GetIncrementalFriendsReq) (*relation.GetIncrementalFriendsResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { return nil, err } var sortVersion uint64 diff --git a/internal/rpc/third/log.go b/internal/rpc/third/log.go index 4d8cbc0bb..fba3ecb88 100644 --- a/internal/rpc/third/log.go +++ b/internal/rpc/third/log.go @@ -82,7 +82,7 @@ func (t *thirdServer) UploadLogs(ctx context.Context, req *third.UploadLogsReq) } func (t *thirdServer) DeleteLogs(ctx context.Context, req *third.DeleteLogsReq) (*third.DeleteLogsResp, error) { - if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } userID := "" @@ -123,7 +123,7 @@ func dbToPbLogInfos(logs []*relationtb.Log) []*third.LogInfo { } func (t *thirdServer) SearchLogs(ctx context.Context, req *third.SearchLogsReq) (*third.SearchLogsResp, error) { - if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } var ( diff --git a/internal/rpc/third/s3.go b/internal/rpc/third/s3.go index 97206dd6d..bfdd9f1b8 100644 --- a/internal/rpc/third/s3.go +++ b/internal/rpc/third/s3.go @@ -62,7 +62,7 @@ func (t *thirdServer) InitiateMultipartUpload(ctx context.Context, req *third.In return nil, err } expireTime := time.Now().Add(t.defaultExpire) - result, err := t.s3dataBase.InitiateMultipartUpload(ctx, req.Hash, req.Size, t.defaultExpire, int(req.MaxParts)) + result, err := t.s3dataBase.InitiateMultipartUpload(ctx, req.Hash, req.Size, t.defaultExpire, int(req.MaxParts), req.ContentType) if err != nil { if haErr, ok := errs.Unwrap(err).(*cont.HashAlreadyExistsError); ok { obj := &model.Object{ @@ -198,7 +198,7 @@ func (t *thirdServer) InitiateFormData(ctx context.Context, req *third.InitiateF var duration time.Duration opUserID := mcontext.GetOpUserID(ctx) var key string - if t.IsManagerUserID(opUserID) { + if authverify.CheckUserIsAdmin(ctx, opUserID) { if req.Millisecond <= 0 { duration = time.Minute * 10 } else { @@ -289,7 +289,7 @@ func (t *thirdServer) apiAddress(prefix, name string) string { } func (t *thirdServer) DeleteOutdatedData(ctx context.Context, req *third.DeleteOutdatedDataReq) (*third.DeleteOutdatedDataResp, error) { - if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } engine := t.config.RpcConfig.Object.Enable diff --git a/internal/rpc/third/tool.go b/internal/rpc/third/tool.go index 4e22ffbf9..a063fa654 100644 --- a/internal/rpc/third/tool.go +++ b/internal/rpc/third/tool.go @@ -54,7 +54,7 @@ func (t *thirdServer) checkUploadName(ctx context.Context, name string) error { if opUserID == "" { return errs.ErrNoPermission.WrapMsg("opUserID is empty") } - if !authverify.IsManagerUserID(opUserID, t.config.Share.IMAdminUserID) { + if !authverify.CheckUserIsAdmin(ctx, opUserID) { if !strings.HasPrefix(name, opUserID+"/") { return errs.ErrNoPermission.WrapMsg(fmt.Sprintf("name must start with `%s/`", opUserID)) } @@ -79,10 +79,6 @@ func checkValidObjectName(objectName string) error { return checkValidObjectNamePrefix(objectName) } -func (t *thirdServer) IsManagerUserID(opUserID string) bool { - return authverify.IsManagerUserID(opUserID, t.config.Share.IMAdminUserID) -} - func putUpdate[T any](update map[string]any, name string, val interface{ GetValuePtr() *T }) { ptrVal := val.GetValuePtr() if ptrVal == nil { diff --git a/internal/rpc/user/config.go b/internal/rpc/user/config.go new file mode 100644 index 000000000..f3f5a7a96 --- /dev/null +++ b/internal/rpc/user/config.go @@ -0,0 +1,71 @@ +package user + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + pbuser "github.com/openimsdk/protocol/user" + "github.com/openimsdk/tools/utils/datautil" +) + +func (s *userServer) GetUserClientConfig(ctx context.Context, req *pbuser.GetUserClientConfigReq) (*pbuser.GetUserClientConfigResp, error) { + if req.UserID != "" { + if err := authverify.CheckAccess(ctx, req.UserID); err != nil { + return nil, err + } + if _, err := s.db.GetUserByID(ctx, req.UserID); err != nil { + return nil, err + } + } + res, err := s.clientConfig.GetUserConfig(ctx, req.UserID) + if err != nil { + return nil, err + } + return &pbuser.GetUserClientConfigResp{Configs: res}, nil +} + +func (s *userServer) SetUserClientConfig(ctx context.Context, req *pbuser.SetUserClientConfigReq) (*pbuser.SetUserClientConfigResp, error) { + if err := authverify.CheckAdmin(ctx); err != nil { + return nil, err + } + if req.UserID != "" { + if _, err := s.db.GetUserByID(ctx, req.UserID); err != nil { + return nil, err + } + } + if err := s.clientConfig.SetUserConfig(ctx, req.UserID, req.Configs); err != nil { + return nil, err + } + return &pbuser.SetUserClientConfigResp{}, nil +} + +func (s *userServer) DelUserClientConfig(ctx context.Context, req *pbuser.DelUserClientConfigReq) (*pbuser.DelUserClientConfigResp, error) { + if err := authverify.CheckAdmin(ctx); err != nil { + return nil, err + } + if err := s.clientConfig.DelUserConfig(ctx, req.UserID, req.Keys); err != nil { + return nil, err + } + return &pbuser.DelUserClientConfigResp{}, nil +} + +func (s *userServer) PageUserClientConfig(ctx context.Context, req *pbuser.PageUserClientConfigReq) (*pbuser.PageUserClientConfigResp, error) { + if err := authverify.CheckAdmin(ctx); err != nil { + return nil, err + } + total, res, err := s.clientConfig.GetUserConfigPage(ctx, req.UserID, req.Key, req.Pagination) + if err != nil { + return nil, err + } + return &pbuser.PageUserClientConfigResp{ + Total: total, + Configs: datautil.Slice(res, func(e *model.ClientConfig) *pbuser.ClientConfig { + return &pbuser.ClientConfig{ + UserID: e.UserID, + Key: e.Key, + Value: e.Value, + } + }), + }, nil +} diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index d4fe7ecc4..7f082f784 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -23,45 +23,48 @@ import ( "time" "github.com/openimsdk/open-im-server/v3/internal/rpc/relation" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/convert" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo" tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/open-im-server/v3/pkg/common/webhook" + "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" "github.com/openimsdk/open-im-server/v3/pkg/localcache" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" + "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/group" friendpb "github.com/openimsdk/protocol/relation" - "github.com/openimsdk/tools/db/redisutil" - - "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/common/convert" - "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" - "github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller" - "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/sdkws" pbuser "github.com/openimsdk/protocol/user" - "github.com/openimsdk/tools/db/mongoutil" "github.com/openimsdk/tools/db/pagination" - registry "github.com/openimsdk/tools/discovery" + "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/utils/datautil" "google.golang.org/grpc" ) +const ( + defaultSecret = "openIM123" +) + type userServer struct { pbuser.UnimplementedUserServer online cache.OnlineCache db controller.UserDatabase friendNotificationSender *relation.FriendNotificationSender userNotificationSender *UserNotificationSender - RegisterCenter registry.SvcDiscoveryRegistry + RegisterCenter discovery.Conn config *Config webhookClient *webhook.Client groupClient *rpcli.GroupClient relationClient *rpcli.RelationClient + clientConfig controller.ClientConfigDatabase } type Config struct { @@ -76,24 +79,30 @@ type Config struct { Discovery config.Discovery } -func Start(ctx context.Context, config *Config, client registry.SvcDiscoveryRegistry, server *grpc.Server) error { - mgocli, err := mongoutil.NewMongoDB(ctx, config.MongodbConfig.Build()) +func Start(ctx context.Context, config *Config, client discovery.Conn, server grpc.ServiceRegistrar) error { + dbb := dbbuild.NewBuilder(&config.MongodbConfig, &config.RedisConfig) + mgocli, err := dbb.Mongo(ctx) if err != nil { return err } - rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build()) + rdb, err := dbb.Redis(ctx) if err != nil { return err } + users := make([]*tablerelation.User, 0) for _, v := range config.Share.IMAdminUserID { - users = append(users, &tablerelation.User{UserID: v, Nickname: v, AppMangerLevel: constant.AppNotificationAdmin}) + users = append(users, &tablerelation.User{UserID: v, Nickname: v, AppMangerLevel: constant.AppAdmin}) } userDB, err := mgo.NewUserMongo(mgocli.GetDB()) if err != nil { return err } + clientConfigDB, err := mgo.NewClientConfig(mgocli.GetDB()) + if err != nil { + return err + } msgConn, err := client.GetConn(ctx, config.Discovery.RpcService.Msg) if err != nil { return err @@ -118,9 +127,9 @@ func Start(ctx context.Context, config *Config, client registry.SvcDiscoveryRegi userNotificationSender: NewUserNotificationSender(config, msgClient, WithUserFunc(database.FindWithError)), config: config, webhookClient: webhook.NewWebhookClient(config.WebhooksConfig.URL), - - groupClient: rpcli.NewGroupClient(groupConn), - relationClient: rpcli.NewRelationClient(friendConn), + clientConfig: controller.NewClientConfigDatabase(clientConfigDB, redis.NewClientConfigCache(rdb, clientConfigDB), mgocli.GetTx()), + groupClient: rpcli.NewGroupClient(groupConn), + relationClient: rpcli.NewRelationClient(friendConn), } pbuser.RegisterUserServer(server, u) return u.db.InitOnce(context.Background(), users) @@ -141,7 +150,7 @@ func (s *userServer) GetDesignateUsers(ctx context.Context, req *pbuser.GetDesig // UpdateUserInfo func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) (resp *pbuser.UpdateUserInfoResp, err error) { resp = &pbuser.UpdateUserInfoResp{} - err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config.Share.IMAdminUserID) + err = authverify.CheckAccess(ctx, req.UserInfo.UserID) if err != nil { return nil, err } @@ -168,7 +177,7 @@ func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserI func (s *userServer) UpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) (resp *pbuser.UpdateUserInfoExResp, err error) { resp = &pbuser.UpdateUserInfoExResp{} - err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config.Share.IMAdminUserID) + err = authverify.CheckAccess(ctx, req.UserInfo.UserID) if err != nil { return nil, err } @@ -226,8 +235,7 @@ func (s *userServer) AccountCheck(ctx context.Context, req *pbuser.AccountCheckR if datautil.Duplicate(req.CheckUserIDs) { return nil, errs.ErrArgs.WrapMsg("userID repeated") } - err = authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID) - if err != nil { + if err = authverify.CheckAdmin(ctx); err != nil { return nil, err } users, err := s.db.Find(ctx, req.CheckUserIDs) @@ -273,11 +281,13 @@ func (s *userServer) UserRegister(ctx context.Context, req *pbuser.UserRegisterR if len(req.Users) == 0 { return nil, errs.ErrArgs.WrapMsg("users is empty") } - - if err = authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + // check if secret is changed + //if s.config.Share.Secret == defaultSecret { + // return nil, servererrs.ErrSecretNotChanged.Wrap() + //} + if err = authverify.CheckAdmin(ctx); err != nil { return nil, err } - if datautil.DuplicateAny(req.Users, func(e *sdkws.UserInfo) string { return e.UserID }) { return nil, errs.ErrArgs.WrapMsg("userID repeated") } @@ -343,7 +353,7 @@ func (s *userServer) GetAllUserID(ctx context.Context, req *pbuser.GetAllUserIDR // ProcessUserCommandAdd user general function add. func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.ProcessUserCommandAddReq) (*pbuser.ProcessUserCommandAddResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -371,7 +381,7 @@ func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.Proc // ProcessUserCommandDelete user general function delete. func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.ProcessUserCommandDeleteReq) (*pbuser.ProcessUserCommandDeleteResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -390,7 +400,7 @@ func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.P // ProcessUserCommandUpdate user general function update. func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.ProcessUserCommandUpdateReq) (*pbuser.ProcessUserCommandUpdateResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -419,7 +429,7 @@ func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.P func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.ProcessUserCommandGetReq) (*pbuser.ProcessUserCommandGetResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -448,7 +458,7 @@ func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.Proc } func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.ProcessUserCommandGetAllReq) (*pbuser.ProcessUserCommandGetAllResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID) + err := authverify.CheckAccess(ctx, req.UserID) if err != nil { return nil, err } @@ -477,7 +487,7 @@ func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.P } func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.AddNotificationAccountReq) (*pbuser.AddNotificationAccountResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } if req.AppMangerLevel < constant.AppNotificationAdmin { @@ -523,7 +533,7 @@ func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.Add } func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbuser.UpdateNotificationAccountInfoReq) (*pbuser.UpdateNotificationAccountInfoResp, error) { - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -550,7 +560,7 @@ func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbu func (s *userServer) SearchNotificationAccount(ctx context.Context, req *pbuser.SearchNotificationAccountReq) (*pbuser.SearchNotificationAccountResp, error) { // Check if user is an admin - if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil { + if err := authverify.CheckAdmin(ctx); err != nil { return nil, err } @@ -605,7 +615,7 @@ func (s *userServer) GetNotificationAccount(ctx context.Context, req *pbuser.Get if err != nil { return nil, servererrs.ErrUserIDNotFound.Wrap() } - if user.AppMangerLevel == constant.AppAdmin || user.AppMangerLevel >= constant.AppNotificationAdmin { + if user.AppMangerLevel >= constant.AppAdmin { return &pbuser.GetNotificationAccountResp{Account: &pbuser.NotificationAccountInfo{ UserID: user.UserID, FaceURL: user.FaceURL, diff --git a/pkg/authverify/token.go b/pkg/authverify/token.go index 872feb1cf..75fb1448b 100644 --- a/pkg/authverify/token.go +++ b/pkg/authverify/token.go @@ -31,32 +31,49 @@ func Secret(secret string) jwt.Keyfunc { } } -func CheckAccessV3(ctx context.Context, ownerUserID string, imAdminUserID []string) (err error) { - opUserID := mcontext.GetOpUserID(ctx) - if datautil.Contain(opUserID, imAdminUserID...) { - return nil - } - if opUserID == ownerUserID { +func CheckAdmin(ctx context.Context) error { + if IsAdmin(ctx) { return nil } - return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID) + return servererrs.ErrNoPermission.WrapMsg(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx))) } -func IsAppManagerUid(ctx context.Context, imAdminUserID []string) bool { - return datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...) +//func IsManagerUserID(opUserID string, imAdminUserID []string) bool { +// return datautil.Contain(opUserID, imAdminUserID...) +//} + +func CheckUserIsAdmin(ctx context.Context, userID string) bool { + return datautil.Contain(userID, GetIMAdminUserIDs(ctx)...) } -func CheckAdmin(ctx context.Context, imAdminUserID []string) error { - if datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...) { - return nil - } - return servererrs.ErrNoPermission.WrapMsg(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx))) +func CheckSystemAccount(ctx context.Context, level int32) bool { + return level >= constant.AppAdmin } -func IsManagerUserID(opUserID string, imAdminUserID []string) bool { - return datautil.Contain(opUserID, imAdminUserID...) +const ( + CtxIsAdminKey = "CtxIsAdminKey" +) + +func WithIMAdminUserIDs(ctx context.Context, imAdminUserID []string) context.Context { + return context.WithValue(ctx, CtxIsAdminKey, imAdminUserID) } -func CheckSystemAccount(ctx context.Context, level int32) bool { - return level >= constant.AppAdmin +func GetIMAdminUserIDs(ctx context.Context) []string { + imAdminUserID, _ := ctx.Value(CtxIsAdminKey).([]string) + return imAdminUserID +} + +func IsAdmin(ctx context.Context) bool { + return datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) +} + +func CheckAccess(ctx context.Context, ownerUserID string) error { + opUserID := mcontext.GetOpUserID(ctx) + if opUserID == ownerUserID { + return nil + } + if datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...) { + return nil + } + return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID) } diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index ca448083c..6b3bff30f 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -378,9 +378,15 @@ type AfterConfig struct { } type Share struct { - Secret string `mapstructure:"secret"` - IMAdminUserID []string `mapstructure:"imAdminUserID"` - MultiLogin MultiLogin `mapstructure:"multiLogin"` + Secret string `yaml:"secret"` + IMAdminUserID []string `yaml:"imAdminUserID"` + MultiLogin MultiLogin `yaml:"multiLogin"` + RPCMaxBodySize MaxRequestBody `yaml:"rpcMaxBodySize"` +} + +type MaxRequestBody struct { + RequestMaxBodySize int `yaml:"requestMaxBodySize"` + ResponseMaxBodySize int `yaml:"responseMaxBodySize"` } type MultiLogin struct { diff --git a/pkg/common/servererrs/code.go b/pkg/common/servererrs/code.go index 3d0aa4a71..906f890a5 100644 --- a/pkg/common/servererrs/code.go +++ b/pkg/common/servererrs/code.go @@ -37,7 +37,8 @@ const ( // General error codes. const ( - NoError = 0 // No error + NoError = 0 // No error + DatabaseError = 90002 // Database error (redis/mysql, etc.) NetworkError = 90004 // Network error DataError = 90007 // Data error @@ -45,11 +46,12 @@ const ( CallbackError = 80000 // General error codes. - ServerInternalError = 500 // Server internal error - ArgsError = 1001 // Input parameter error - NoPermissionError = 1002 // Insufficient permission - DuplicateKeyError = 1003 - RecordNotFoundError = 1004 // Record does not exist + ServerInternalError = 500 // Server internal error + ArgsError = 1001 // Input parameter error + NoPermissionError = 1002 // Insufficient permission + DuplicateKeyError = 1003 + RecordNotFoundError = 1004 // Record does not exist + SecretNotChangedError = 1050 // secret not changed // Account error codes. UserIDNotFoundError = 1101 // UserID does not exist or is not registered diff --git a/pkg/common/servererrs/predefine.go b/pkg/common/servererrs/predefine.go index ab09aa512..b1d6b06a9 100644 --- a/pkg/common/servererrs/predefine.go +++ b/pkg/common/servererrs/predefine.go @@ -17,6 +17,8 @@ package servererrs import "github.com/openimsdk/tools/errs" var ( + ErrSecretNotChanged = errs.NewCodeError(SecretNotChangedError, "secret not changed, please change secret in config/share.yml for security reasons") + ErrDatabase = errs.NewCodeError(DatabaseError, "DatabaseError") ErrNetwork = errs.NewCodeError(NetworkError, "NetworkError") ErrCallback = errs.NewCodeError(CallbackError, "CallbackError") diff --git a/pkg/common/startrpc/mw.go b/pkg/common/startrpc/mw.go new file mode 100644 index 000000000..c6cd55380 --- /dev/null +++ b/pkg/common/startrpc/mw.go @@ -0,0 +1,15 @@ +package startrpc + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "google.golang.org/grpc" +) + +func grpcServerIMAdminUserID(imAdminUserID []string) grpc.ServerOption { + return grpc.ChainUnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + ctx = authverify.WithIMAdminUserIDs(ctx, imAdminUserID) + return handler(ctx, req) + }) +} diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index 99df537f7..03621343b 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -19,214 +19,285 @@ import ( "errors" "fmt" "net" - "net/http" "os" "os/signal" + "reflect" "strconv" "syscall" "time" conf "github.com/openimsdk/open-im-server/v3/pkg/common/config" - disetcd "github.com/openimsdk/open-im-server/v3/pkg/common/discovery/etcd" - "github.com/openimsdk/tools/discovery/etcd" "github.com/openimsdk/tools/utils/datautil" "github.com/openimsdk/tools/utils/jsonutil" + "github.com/openimsdk/tools/utils/network" "google.golang.org/grpc/status" - "github.com/openimsdk/tools/utils/runtimeenv" - kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discovery" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" - "github.com/openimsdk/tools/mw" - "github.com/openimsdk/tools/utils/network" + grpccli "github.com/openimsdk/tools/mw/grpc/client" + grpcsrv "github.com/openimsdk/tools/mw/grpc/server" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) -// Start rpc server. -func Start[T any](ctx context.Context, discovery *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP, +func init() { + prommetrics.RegistryAll() +} + +func getConfigRpcMaxRequestBody(value reflect.Value) *conf.MaxRequestBody { + for value.Kind() == reflect.Pointer { + value = value.Elem() + } + if value.Kind() == reflect.Struct { + num := value.NumField() + for i := 0; i < num; i++ { + field := value.Field(i) + if !field.CanInterface() { + continue + } + for field.Kind() == reflect.Pointer { + field = field.Elem() + } + switch elem := field.Interface().(type) { + case conf.Share: + return &elem.RPCMaxBodySize + case conf.MaxRequestBody: + return &elem + } + if field.Kind() == reflect.Struct { + if elem := getConfigRpcMaxRequestBody(field); elem != nil { + return elem + } + } + } + } + return nil +} + +func getConfigShare(value reflect.Value) *conf.Share { + for value.Kind() == reflect.Pointer { + value = value.Elem() + } + if value.Kind() == reflect.Struct { + num := value.NumField() + for i := 0; i < num; i++ { + field := value.Field(i) + if !field.CanInterface() { + continue + } + for field.Kind() == reflect.Pointer { + field = field.Elem() + } + switch elem := field.Interface().(type) { + case conf.Share: + return &elem + } + if field.Kind() == reflect.Struct { + if elem := getConfigShare(field); elem != nil { + return elem + } + } + } + } + return nil +} + +func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP, registerIP string, autoSetPorts bool, rpcPorts []int, index int, rpcRegisterName string, notification *conf.Notification, config T, watchConfigNames []string, watchServiceNames []string, - rpcFn func(ctx context.Context, config T, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error, + rpcFn func(ctx context.Context, config T, client discovery.Conn, server grpc.ServiceRegistrar) error, options ...grpc.ServerOption) error { - watchConfigNames = append(watchConfigNames, conf.LogConfigFileName) - var ( - rpcTcpAddr string - netDone = make(chan struct{}, 2) - netErr error - prometheusPort int - ) - if notification != nil { conf.InitNotification(notification) } + maxRequestBody := getConfigRpcMaxRequestBody(reflect.ValueOf(config)) + shareConfig := getConfigShare(reflect.ValueOf(config)) + + log.ZDebug(ctx, "rpc start", "rpcMaxRequestBody", maxRequestBody, "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "listenIP", listenIP) + + options = append(options, + grpcsrv.GrpcServerMetadataContext(), + grpcsrv.GrpcServerLogger(), + grpcsrv.GrpcServerErrorConvert(), + grpcsrv.GrpcServerRequestValidate(), + grpcsrv.GrpcServerPanicCapture(), + ) + if shareConfig != nil && len(shareConfig.IMAdminUserID) > 0 { + options = append(options, grpcServerIMAdminUserID(shareConfig.IMAdminUserID)) + } + var clientOptions []grpc.DialOption + if maxRequestBody != nil { + if maxRequestBody.RequestMaxBodySize > 0 { + options = append(options, grpc.MaxRecvMsgSize(maxRequestBody.RequestMaxBodySize)) + clientOptions = append(clientOptions, grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxRequestBody.RequestMaxBodySize))) + } + if maxRequestBody.ResponseMaxBodySize > 0 { + options = append(options, grpc.MaxSendMsgSize(maxRequestBody.ResponseMaxBodySize)) + clientOptions = append(clientOptions, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxRequestBody.ResponseMaxBodySize))) + } + } + registerIP, err := network.GetRpcRegisterIP(registerIP) if err != nil { return err } - - runTimeEnv := runtimeenv.RuntimeEnvironment() - - if !autoSetPorts { - rpcPort, err := datautil.GetElemByIndex(rpcPorts, index) + var prometheusListenAddr string + if autoSetPorts { + prometheusListenAddr = net.JoinHostPort(listenIP, "0") + } else { + prometheusPort, err := datautil.GetElemByIndex(prometheusConfig.Ports, index) if err != nil { return err } - rpcTcpAddr = net.JoinHostPort(network.GetListenIP(listenIP), strconv.Itoa(rpcPort)) - } else { - rpcTcpAddr = net.JoinHostPort(network.GetListenIP(listenIP), "0") + prometheusListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(prometheusPort)) } - getAutoPort := func() (net.Listener, int, error) { - listener, err := net.Listen("tcp", rpcTcpAddr) - if err != nil { - return nil, 0, errs.WrapMsg(err, "listen err", "rpcTcpAddr", rpcTcpAddr) - } - _, portStr, _ := net.SplitHostPort(listener.Addr().String()) - port, _ := strconv.Atoi(portStr) - return listener, port, nil - } + watchConfigNames = append(watchConfigNames, conf.LogConfigFileName) - if autoSetPorts && discovery.Enable != conf.ETCD { - return errs.New("only etcd support autoSetPorts", "rpcRegisterName", rpcRegisterName).Wrap() - } - client, err := kdisc.NewDiscoveryRegister(discovery, runTimeEnv, watchServiceNames) + client, err := kdisc.NewDiscoveryRegister(disc, watchServiceNames) if err != nil { return err } defer client.Close() - client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) - - // var reg *prometheus.Registry - // var metric *grpcprometheus.ServerMetrics - if prometheusConfig.Enable { - // cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName, share) - // reg, metric, _ = prommetrics.NewGrpcPromObj(cusMetrics) - // options = append(options, mw.GrpcServer(), grpc.StreamInterceptor(metric.StreamServerInterceptor()), - // grpc.UnaryInterceptor(metric.UnaryServerInterceptor())) - options = append( - options, mw.GrpcServer(), - prommetricsUnaryInterceptor(rpcRegisterName), - prommetricsStreamInterceptor(rpcRegisterName), - ) + client.AddOption( + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")), - var ( - listener net.Listener - ) + grpccli.GrpcClientLogger(), + grpccli.GrpcClientContext(), + grpccli.GrpcClientErrorConvert(), + ) + if len(clientOptions) > 0 { + client.AddOption(clientOptions...) + } - if autoSetPorts { - listener, prometheusPort, err = getAutoPort() - if err != nil { - return err - } + ctx, cancel := context.WithCancelCause(ctx) - etcdClient := client.(*etcd.SvcDiscoveryRegistryImpl).GetClient() + go func() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL) + select { + case <-ctx.Done(): + return + case val := <-sigs: + log.ZDebug(ctx, "recv signal", "signal", val.String()) + cancel(fmt.Errorf("signal %s", val.String())) + } + }() - _, err = etcdClient.Put(ctx, prommetrics.BuildDiscoveryKey(rpcRegisterName), jsonutil.StructToJsonString(prommetrics.BuildDefaultTarget(registerIP, prometheusPort))) - if err != nil { - return errs.WrapMsg(err, "etcd put err") - } - } else { - prometheusPort, err = datautil.GetElemByIndex(prometheusConfig.Ports, index) - if err != nil { + if prometheusListenAddr != "" { + options = append( + options, + prommetricsUnaryInterceptor(rpcRegisterName), + prommetricsStreamInterceptor(rpcRegisterName), + ) + prometheusListener, prometheusPort, err := listenTCP(prometheusListenAddr) + if err != nil { + return err + } + log.ZDebug(ctx, "prometheus start", "addr", prometheusListener.Addr(), "rpcRegisterName", rpcRegisterName) + target, err := jsonutil.JsonMarshal(prommetrics.BuildDefaultTarget(registerIP, prometheusPort)) + if err != nil { + return err + } + if err := client.SetKey(ctx, prommetrics.BuildDiscoveryKey(prommetrics.APIKeyName), target); err != nil { + if !errors.Is(err, discovery.ErrNotSupportedKeyValue) { return err } - listener, err = net.Listen("tcp", fmt.Sprintf(":%d", prometheusPort)) - if err != nil { - return errs.WrapMsg(err, "listen err", "rpcTcpAddr", rpcTcpAddr) - } } - cs := prommetrics.GetGrpcCusMetrics(rpcRegisterName, discovery) go func() { - if err := prommetrics.RpcInit(cs, listener); err != nil && !errors.Is(err, http.ErrServerClosed) { - netErr = errs.WrapMsg(err, fmt.Sprintf("rpc %s prometheus start err: %d", rpcRegisterName, prometheusPort)) - netDone <- struct{}{} + err := prommetrics.Start(prometheusListener) + if err == nil { + err = fmt.Errorf("listener done") } - //metric.InitializeMetrics(srv) - // Create a HTTP server for prometheus. - // httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)} - // if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - // netErr = errs.WrapMsg(err, "prometheus start err", httpServer.Addr) - // netDone <- struct{}{} - // } + cancel(fmt.Errorf("prommetrics %s %w", rpcRegisterName, err)) }() - } else { - options = append(options, mw.GrpcServer()) } - listener, port, err := getAutoPort() - if err != nil { - return err - } + var ( + rpcServer *grpc.Server + rpcGracefulStop chan struct{} + ) - log.CInfo(ctx, "RPC server is initializing", "rpcRegisterName", rpcRegisterName, "rpcPort", port, - "prometheusPort", prometheusPort) + onGrpcServiceRegistrar := func(desc *grpc.ServiceDesc, impl any) { + if rpcServer != nil { + rpcServer.RegisterService(desc, impl) + return + } + var rpcListenAddr string + if autoSetPorts { + rpcListenAddr = net.JoinHostPort(listenIP, "0") + } else { + rpcPort, err := datautil.GetElemByIndex(rpcPorts, index) + if err != nil { + cancel(fmt.Errorf("rpcPorts index out of range %s %w", rpcRegisterName, err)) + return + } + rpcListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(rpcPort)) + } + rpcListener, err := net.Listen("tcp", rpcListenAddr) + if err != nil { + cancel(fmt.Errorf("listen rpc %s %s %w", rpcRegisterName, rpcListenAddr, err)) + return + } - defer listener.Close() - srv := grpc.NewServer(options...) + rpcServer = grpc.NewServer(options...) + rpcServer.RegisterService(desc, impl) + rpcGracefulStop = make(chan struct{}) + rpcPort := rpcListener.Addr().(*net.TCPAddr).Port + log.ZDebug(ctx, "rpc start register", "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "rpcPort", rpcPort) + grpcOpt := grpc.WithTransportCredentials(insecure.NewCredentials()) + rpcGracefulStop = make(chan struct{}) + go func() { + <-ctx.Done() + rpcServer.GracefulStop() + close(rpcGracefulStop) + }() + if err := client.Register(ctx, rpcRegisterName, registerIP, rpcListener.Addr().(*net.TCPAddr).Port, grpcOpt); err != nil { + cancel(fmt.Errorf("rpc register %s %w", rpcRegisterName, err)) + return + } - err = rpcFn(ctx, config, client, srv) - if err != nil { - return err + go func() { + err := rpcServer.Serve(rpcListener) + if err == nil { + err = fmt.Errorf("serve end") + } + cancel(fmt.Errorf("rpc %s %w", rpcRegisterName, err)) + }() } - err = client.Register( - ctx, - rpcRegisterName, - registerIP, - port, - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) + err = rpcFn(ctx, config, client, &grpcServiceRegistrar{onRegisterService: onGrpcServiceRegistrar}) if err != nil { return err } - - go func() { - err := srv.Serve(listener) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - netErr = errs.WrapMsg(err, "rpc start err: ", rpcTcpAddr) - netDone <- struct{}{} - } - }() - - if discovery.Enable == conf.ETCD { - cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), watchConfigNames) - cm.Watch(ctx) - } - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) - select { - case <-sigs: - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := gracefulStopWithCtx(ctx, srv.GracefulStop); err != nil { - return err + <-ctx.Done() + log.ZDebug(ctx, "cmd wait done", "err", context.Cause(ctx)) + if rpcGracefulStop != nil { + timeout := time.NewTimer(time.Second * 15) + defer timeout.Stop() + select { + case <-timeout.C: + log.ZWarn(ctx, "rcp graceful stop timeout", nil) + case <-rpcGracefulStop: + log.ZDebug(ctx, "rcp graceful stop done") } - return nil - case <-netDone: - return netErr } + return context.Cause(ctx) } -func gracefulStopWithCtx(ctx context.Context, f func()) error { - done := make(chan struct{}, 1) - go func() { - f() - close(done) - }() - select { - case <-ctx.Done(): - return errs.New("timeout, ctx graceful stop") - case <-done: - return nil +func listenTCP(addr string) (net.Listener, int, error) { + listener, err := net.Listen("tcp", addr) + if err != nil { + return nil, 0, errs.WrapMsg(err, "listen err", "addr", addr) } + return listener, listener.Addr().(*net.TCPAddr).Port, nil } func prommetricsUnaryInterceptor(rpcRegisterName string) grpc.ServerOption { @@ -250,3 +321,11 @@ func prommetricsUnaryInterceptor(rpcRegisterName string) grpc.ServerOption { func prommetricsStreamInterceptor(rpcRegisterName string) grpc.ServerOption { return grpc.ChainStreamInterceptor() } + +type grpcServiceRegistrar struct { + onRegisterService func(desc *grpc.ServiceDesc, impl any) +} + +func (x *grpcServiceRegistrar) RegisterService(desc *grpc.ServiceDesc, impl any) { + x.onRegisterService(desc, impl) +} diff --git a/pkg/common/storage/cache/cachekey/client_config.go b/pkg/common/storage/cache/cachekey/client_config.go new file mode 100644 index 000000000..16770adef --- /dev/null +++ b/pkg/common/storage/cache/cachekey/client_config.go @@ -0,0 +1,10 @@ +package cachekey + +const ClientConfig = "CLIENT_CONFIG" + +func GetClientConfigKey(userID string) string { + if userID == "" { + return ClientConfig + } + return ClientConfig + ":" + userID +} diff --git a/pkg/common/storage/cache/cachekey/token.go b/pkg/common/storage/cache/cachekey/token.go index 83ba2f211..6fe1bdfef 100644 --- a/pkg/common/storage/cache/cachekey/token.go +++ b/pkg/common/storage/cache/cachekey/token.go @@ -1,8 +1,9 @@ package cachekey import ( - "github.com/openimsdk/protocol/constant" "strings" + + "github.com/openimsdk/protocol/constant" ) const ( @@ -13,6 +14,10 @@ func GetTokenKey(userID string, platformID int) string { return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID) } +func GetTemporaryTokenKey(userID string, platformID int, token string) string { + return UidPidToken + ":TEMPORARY:" + userID + ":" + constant.PlatformIDToName(platformID) + ":" + token +} + func GetAllPlatformTokenKey(userID string) []string { res := make([]string, len(constant.PlatformID2Name)) for k := range constant.PlatformID2Name { diff --git a/pkg/common/storage/cache/client_config.go b/pkg/common/storage/cache/client_config.go new file mode 100644 index 000000000..329f25c59 --- /dev/null +++ b/pkg/common/storage/cache/client_config.go @@ -0,0 +1,8 @@ +package cache + +import "context" + +type ClientConfigCache interface { + DeleteUserCache(ctx context.Context, userIDs []string) error + GetUserConfig(ctx context.Context, userID string) (map[string]string, error) +} diff --git a/pkg/common/storage/cache/mcache/token.go b/pkg/common/storage/cache/mcache/token.go new file mode 100644 index 000000000..98b9cc066 --- /dev/null +++ b/pkg/common/storage/cache/mcache/token.go @@ -0,0 +1,166 @@ +package mcache + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/log" +) + +func NewTokenCacheModel(cache database.Cache, accessExpire int64) cache.TokenModel { + c := &tokenCache{cache: cache} + c.accessExpire = c.getExpireTime(accessExpire) + return c +} + +type tokenCache struct { + cache database.Cache + accessExpire time.Duration +} + +func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string { + return cachekey.GetTokenKey(userID, platformID) + ":" + token +} + +func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error { + return x.cache.Set(ctx, x.getTokenKey(userID, platformID, token), strconv.Itoa(flag), x.accessExpire) +} + +// SetTokenFlagEx set token and flag with expire time +func (x *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error { + return x.SetTokenFlag(ctx, userID, platformID, token, flag) +} + +func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) { + prefix := x.getTokenKey(userID, platformID, "") + m, err := x.cache.Prefix(ctx, prefix) + if err != nil { + return nil, errs.Wrap(err) + } + mm := make(map[string]int) + for k, v := range m { + state, err := strconv.Atoi(v) + if err != nil { + log.ZError(ctx, "token value is not int", err, "value", v, "userID", userID, "platformID", platformID) + continue + } + mm[strings.TrimPrefix(k, prefix)] = state + } + return mm, nil +} + +func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error { + key := cachekey.GetTemporaryTokenKey(userID, platformID, token) + if _, err := x.cache.Get(ctx, []string{key}); err != nil { + return err + } + return nil +} + +func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { + prefix := cachekey.UidPidToken + userID + ":" + tokens, err := x.cache.Prefix(ctx, prefix) + if err != nil { + return nil, err + } + res := make(map[int]map[string]int) + for key, flagStr := range tokens { + flag, err := strconv.Atoi(flagStr) + if err != nil { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + arr := strings.SplitN(strings.TrimPrefix(key, prefix), ":", 2) + if len(arr) != 2 { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + platformID, err := strconv.Atoi(arr[0]) + if err != nil { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + token := arr[1] + if token == "" { + log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID) + continue + } + tk, ok := res[platformID] + if !ok { + tk = make(map[string]int) + res[platformID] = tk + } + tk[token] = flag + } + return res, nil +} + +func (x *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { + for token, flag := range m { + err := x.SetTokenFlag(ctx, userID, platformID, token, flag) + if err != nil { + return err + } + } + return nil +} + +func (x *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { + for prefix, tokenFlag := range tokens { + for token, flag := range tokenFlag { + flagStr := fmt.Sprintf("%v", flag) + if err := x.cache.Set(ctx, prefix+":"+token, flagStr, x.accessExpire); err != nil { + return err + } + } + } + return nil +} + +func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error { + keys := make([]string, 0, len(fields)) + for _, token := range fields { + keys = append(keys, x.getTokenKey(userID, platformID, token)) + } + return x.cache.Del(ctx, keys) +} + +func (x *tokenCache) getExpireTime(t int64) time.Duration { + return time.Hour * 24 * time.Duration(t) +} + +func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error { + keys := make([]string, 0, len(tokens)) + for platformID, ts := range tokens { + for _, t := range ts { + keys = append(keys, x.getTokenKey(userID, platformID, t)) + } + } + return x.cache.Del(ctx, keys) +} + +func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error { + keys := make([]string, 0, len(fields)) + for _, f := range fields { + keys = append(keys, x.getTokenKey(userID, platformID, f)) + } + if err := x.cache.Del(ctx, keys); err != nil { + return err + } + + for _, f := range fields { + k := cachekey.GetTemporaryTokenKey(userID, platformID, f) + if err := x.cache.Set(ctx, k, "", time.Minute*5); err != nil { + return errs.Wrap(err) + } + } + + return nil +} diff --git a/pkg/common/storage/cache/redis/client_config.go b/pkg/common/storage/cache/redis/client_config.go new file mode 100644 index 000000000..c5a455146 --- /dev/null +++ b/pkg/common/storage/cache/redis/client_config.go @@ -0,0 +1,69 @@ +package redis + +import ( + "context" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/redis/go-redis/v9" +) + +func NewClientConfigCache(rdb redis.UniversalClient, mgo database.ClientConfig) cache.ClientConfigCache { + rc := newRocksCacheClient(rdb) + return &ClientConfigCache{ + mgo: mgo, + rcClient: rc, + delete: rc.GetBatchDeleter(), + } +} + +type ClientConfigCache struct { + mgo database.ClientConfig + rcClient *rocksCacheClient + delete cache.BatchDeleter +} + +func (x *ClientConfigCache) getExpireTime(userID string) time.Duration { + if userID == "" { + return time.Hour * 24 + } else { + return time.Hour + } +} + +func (x *ClientConfigCache) getClientConfigKey(userID string) string { + return cachekey.GetClientConfigKey(userID) +} + +func (x *ClientConfigCache) GetConfig(ctx context.Context, userID string) (map[string]string, error) { + return getCache(ctx, x.rcClient, x.getClientConfigKey(userID), x.getExpireTime(userID), func(ctx context.Context) (map[string]string, error) { + return x.mgo.Get(ctx, userID) + }) +} + +func (x *ClientConfigCache) DeleteUserCache(ctx context.Context, userIDs []string) error { + keys := make([]string, 0, len(userIDs)) + for _, userID := range userIDs { + keys = append(keys, x.getClientConfigKey(userID)) + } + return x.delete.ExecDelWithKeys(ctx, keys) +} + +func (x *ClientConfigCache) GetUserConfig(ctx context.Context, userID string) (map[string]string, error) { + config, err := x.GetConfig(ctx, "") + if err != nil { + return nil, err + } + if userID != "" { + userConfig, err := x.GetConfig(ctx, userID) + if err != nil { + return nil, err + } + for k, v := range userConfig { + config[k] = v + } + } + return config, nil +} diff --git a/pkg/common/storage/cache/redis/token.go b/pkg/common/storage/cache/redis/token.go index 510da43e3..b3870daee 100644 --- a/pkg/common/storage/cache/redis/token.go +++ b/pkg/common/storage/cache/redis/token.go @@ -9,6 +9,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/utils/datautil" "github.com/redis/go-redis/v9" ) @@ -55,6 +56,14 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p return mm, nil } +func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error { + err := c.rdb.Get(ctx, cachekey.GetTemporaryTokenKey(userID, platformID, token)).Err() + if err != nil { + return errs.Wrap(err) + } + return nil +} + func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) { var ( res = make(map[int]map[string]int) @@ -101,13 +110,19 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla } func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error { - pipe := c.rdb.Pipeline() - for k, v := range tokens { - pipe.HSet(ctx, k, v) - } - _, err := pipe.Exec(ctx) - if err != nil { - return errs.Wrap(err) + keys := datautil.Keys(tokens) + if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + pipe := c.rdb.Pipeline() + for k, v := range tokens { + pipe.HSet(ctx, k, v) + } + _, err := pipe.Exec(ctx) + if err != nil { + return errs.Wrap(err) + } + return nil + }); err != nil { + return err } return nil } @@ -119,3 +134,47 @@ func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla func (c *tokenCache) getExpireTime(t int64) time.Duration { return time.Hour * 24 * time.Duration(t) } + +// DeleteTokenByTokenMap tokens key is platformID, value is token slice +func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error { + var ( + keys = make([]string, 0, len(tokens)) + keyMap = make(map[string][]string) + ) + for k, v := range tokens { + k1 := cachekey.GetTokenKey(userID, k) + keys = append(keys, k1) + keyMap[k1] = v + } + + if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + pipe := c.rdb.Pipeline() + for k, v := range tokens { + pipe.HDel(ctx, cachekey.GetTokenKey(userID, k), v...) + } + _, err := pipe.Exec(ctx) + if err != nil { + return errs.Wrap(err) + } + return nil + }); err != nil { + return err + } + + return nil +} + +func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error { + key := cachekey.GetTokenKey(userID, platformID) + if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil { + return errs.Wrap(err) + } + for _, f := range fields { + k := cachekey.GetTemporaryTokenKey(userID, platformID, f) + if err := c.rdb.Set(ctx, k, "", time.Minute*5).Err(); err != nil { + return errs.Wrap(err) + } + } + + return nil +} diff --git a/pkg/common/storage/cache/token.go b/pkg/common/storage/cache/token.go index e5e0a9383..441c08939 100644 --- a/pkg/common/storage/cache/token.go +++ b/pkg/common/storage/cache/token.go @@ -9,8 +9,11 @@ type TokenModel interface { // SetTokenFlagEx set token and flag with expire time SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) + HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error + DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error + DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error } diff --git a/pkg/common/storage/controller/auth.go b/pkg/common/storage/controller/auth.go index f9061a73b..496a434bf 100644 --- a/pkg/common/storage/controller/auth.go +++ b/pkg/common/storage/controller/auth.go @@ -17,6 +17,8 @@ import ( type AuthDatabase interface { // If the result is empty, no error is returned. GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) + + GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error // Create token CreateToken(ctx context.Context, userID string, platformID int) (string, error) @@ -51,6 +53,10 @@ func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, return a.cache.GetTokensWithoutError(ctx, userID, platformID) } +func (a *authDatabase) GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error { + return a.cache.HasTemporaryToken(ctx, userID, platformID, token) +} + func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error { return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m) } @@ -86,19 +92,20 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return "", err } - deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) + deleteTokenKey, kickedTokenKey, adminTokens, err := a.checkToken(ctx, tokens, platformID) + if err != nil { + return "", err + } + if len(deleteTokenKey) != 0 { + err = a.cache.DeleteTokenByTokenMap(ctx, userID, deleteTokenKey) if err != nil { return "", err } - if len(deleteTokenKey) != 0 { - err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) - if err != nil { - return "", err - } - } - if len(kickedTokenKey) != 0 { - for _, k := range kickedTokenKey { - err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) + } + if len(kickedTokenKey) != 0 { + for plt, ks := range kickedTokenKey { + for _, k := range ks { + err := a.cache.SetTokenFlagEx(ctx, userID, plt, k, constant.KickedToken) if err != nil { return "", err } @@ -106,6 +113,11 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI } } } + if len(adminTokens) != 0 { + if err = a.cache.DeleteAndSetTemporary(ctx, userID, constant.AdminPlatformID, adminTokens); err != nil { + return "", err + } + } claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) @@ -123,12 +135,13 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return tokenString, nil } -func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) { - // todo: Move the logic for handling old data to another location. +// checkToken will check token by tokenPolicy and return deleteToken,kickToken,deleteAdminToken +func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) (map[int][]string, map[int][]string, []string, error) { + // todo: Asynchronous deletion of old data. var ( loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0 - deleteToken = make([]string, 0) - kickToken = make([]string, 0) + deleteToken = make(map[int][]string) + kickToken = make(map[int][]string) adminToken = make([]string, 0) unkickTerminal = "" ) @@ -137,7 +150,7 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string for k, v := range tks { _, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret)) if err != nil || v != constant.NormalToken { - deleteToken = append(deleteToken, k) + deleteToken[plfID] = append(deleteToken[plfID], k) } else { if plfID != constant.AdminPlatformID { loginTokenMap[plfID] = append(loginTokenMap[plfID], k) @@ -157,14 +170,15 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string } limit := a.multiLogin.MaxNumOneEnd if l > limit { - kickToken = append(kickToken, ts[:l-limit]...) + kickToken[plt] = ts[:l-limit] } } case constant.AllLoginButSameTermKick: for plt, ts := range loginTokenMap { - kickToken = append(kickToken, ts[:len(ts)-1]...) + kickToken[plt] = ts[:len(ts)-1] + if plt == platformID { - kickToken = append(kickToken, ts[len(ts)-1]) + kickToken[plt] = append(kickToken[plt], ts[len(ts)-1]) } } case constant.PCAndOther: @@ -172,29 +186,33 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string if constant.PlatformIDToClass(platformID) != unkickTerminal { for plt, ts := range loginTokenMap { if constant.PlatformIDToClass(plt) != unkickTerminal { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } } } else { var ( - preKick []string - isReserve = true + preKickToken string + preKickPlt int + reserveToken = false ) for plt, ts := range loginTokenMap { if constant.PlatformIDToClass(plt) != unkickTerminal { // Keep a token from another end - if isReserve { - isReserve = false - kickToken = append(kickToken, ts[:len(ts)-1]...) - preKick = append(preKick, ts[len(ts)-1]) + if !reserveToken { + reserveToken = true + kickToken[plt] = ts[:len(ts)-1] + preKickToken = ts[len(ts)-1] + preKickPlt = plt continue } else { // Prioritize keeping Android if plt == constant.AndroidPlatformID { - kickToken = append(kickToken, preKick...) - kickToken = append(kickToken, ts[:len(ts)-1]...) + if preKickToken != "" { + kickToken[preKickPlt] = append(kickToken[preKickPlt], preKickToken) + } + kickToken[plt] = ts[:len(ts)-1] } else { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } } } @@ -207,19 +225,19 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string for plt, ts := range loginTokenMap { if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } else { if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok { reserved[constant.PlatformIDToClass(plt)] = struct{}{} - kickToken = append(kickToken, ts[:len(ts)-1]...) + kickToken[plt] = ts[:len(ts)-1] continue } else { - kickToken = append(kickToken, ts...) + kickToken[plt] = ts } } } default: - return nil, nil, errs.New("unknown multiLogin policy").Wrap() + return nil, nil, nil, errs.New("unknown multiLogin policy").Wrap() } //var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd @@ -233,5 +251,9 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string //if l > adminTokenMaxNum { // kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) //} - return deleteToken, kickToken, nil + var deleteAdminToken []string + if platformID == constant.AdminPlatformID { + deleteAdminToken = adminToken + } + return deleteToken, kickToken, deleteAdminToken, nil } diff --git a/pkg/common/storage/controller/client_config.go b/pkg/common/storage/controller/client_config.go new file mode 100644 index 000000000..1c3787634 --- /dev/null +++ b/pkg/common/storage/controller/client_config.go @@ -0,0 +1,58 @@ +package controller + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/db/pagination" + "github.com/openimsdk/tools/db/tx" +) + +type ClientConfigDatabase interface { + SetUserConfig(ctx context.Context, userID string, config map[string]string) error + GetUserConfig(ctx context.Context, userID string) (map[string]string, error) + DelUserConfig(ctx context.Context, userID string, keys []string) error + GetUserConfigPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) +} + +func NewClientConfigDatabase(db database.ClientConfig, cache cache.ClientConfigCache, tx tx.Tx) ClientConfigDatabase { + return &clientConfigDatabase{ + tx: tx, + db: db, + cache: cache, + } +} + +type clientConfigDatabase struct { + tx tx.Tx + db database.ClientConfig + cache cache.ClientConfigCache +} + +func (x *clientConfigDatabase) SetUserConfig(ctx context.Context, userID string, config map[string]string) error { + return x.tx.Transaction(ctx, func(ctx context.Context) error { + if err := x.db.Set(ctx, userID, config); err != nil { + return err + } + return x.cache.DeleteUserCache(ctx, []string{userID}) + }) +} + +func (x *clientConfigDatabase) GetUserConfig(ctx context.Context, userID string) (map[string]string, error) { + return x.cache.GetUserConfig(ctx, userID) +} + +func (x *clientConfigDatabase) DelUserConfig(ctx context.Context, userID string, keys []string) error { + return x.tx.Transaction(ctx, func(ctx context.Context) error { + if err := x.db.Del(ctx, userID, keys); err != nil { + return err + } + return x.cache.DeleteUserCache(ctx, []string{userID}) + }) +} + +func (x *clientConfigDatabase) GetUserConfigPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) { + return x.db.GetPage(ctx, userID, key, pagination) +} diff --git a/pkg/common/storage/controller/s3.go b/pkg/common/storage/controller/s3.go index 30d8d20ec..9ab31c5a6 100644 --- a/pkg/common/storage/controller/s3.go +++ b/pkg/common/storage/controller/s3.go @@ -33,7 +33,7 @@ type S3Database interface { PartLimit() (*s3.PartLimit, error) PartSize(ctx context.Context, size int64) (int64, error) AuthSign(ctx context.Context, uploadID string, partNumbers []int) (*s3.AuthSignResult, error) - InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int) (*cont.InitiateUploadResult, error) + InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int, contentType string) (*cont.InitiateUploadResult, error) CompleteMultipartUpload(ctx context.Context, uploadID string, parts []string) (*cont.UploadResult, error) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (time.Time, string, error) SetObject(ctx context.Context, info *model.Object) error @@ -73,8 +73,8 @@ func (s *s3Database) AuthSign(ctx context.Context, uploadID string, partNumbers return s.s3.AuthSign(ctx, uploadID, partNumbers) } -func (s *s3Database) InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int) (*cont.InitiateUploadResult, error) { - return s.s3.InitiateUpload(ctx, hash, size, expire, maxParts) +func (s *s3Database) InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int, contentType string) (*cont.InitiateUploadResult, error) { + return s.s3.InitiateUploadContentType(ctx, hash, size, expire, maxParts, contentType) } func (s *s3Database) CompleteMultipartUpload(ctx context.Context, uploadID string, parts []string) (*cont.UploadResult, error) { diff --git a/pkg/common/storage/database/client_config.go b/pkg/common/storage/database/client_config.go new file mode 100644 index 000000000..7fa888d24 --- /dev/null +++ b/pkg/common/storage/database/client_config.go @@ -0,0 +1,15 @@ +package database + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/db/pagination" +) + +type ClientConfig interface { + Set(ctx context.Context, userID string, config map[string]string) error + Get(ctx context.Context, userID string) (map[string]string, error) + Del(ctx context.Context, userID string, keys []string) error + GetPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) +} diff --git a/pkg/common/storage/database/mgo/cache.go b/pkg/common/storage/database/mgo/cache.go new file mode 100644 index 000000000..991dfa874 --- /dev/null +++ b/pkg/common/storage/database/mgo/cache.go @@ -0,0 +1,183 @@ +package mgo + +import ( + "context" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/db/mongoutil" + "github.com/openimsdk/tools/errs" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +func NewCacheMgo(db *mongo.Database) (*CacheMgo, error) { + coll := db.Collection(database.CacheName) + _, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{ + { + Keys: bson.D{ + {Key: "key", Value: 1}, + }, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{ + {Key: "expire_at", Value: 1}, + }, + Options: options.Index().SetExpireAfterSeconds(0), + }, + }) + if err != nil { + return nil, errs.Wrap(err) + } + return &CacheMgo{coll: coll}, nil +} + +type CacheMgo struct { + coll *mongo.Collection +} + +func (x *CacheMgo) findToMap(res []model.Cache, now time.Time) map[string]string { + kv := make(map[string]string) + for _, re := range res { + if re.ExpireAt != nil && re.ExpireAt.Before(now) { + continue + } + kv[re.Key] = re.Value + } + return kv + +} + +func (x *CacheMgo) Get(ctx context.Context, key []string) (map[string]string, error) { + if len(key) == 0 { + return nil, nil + } + now := time.Now() + res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{ + "key": bson.M{"$in": key}, + "$or": []bson.M{ + {"expire_at": bson.M{"$gt": now}}, + {"expire_at": nil}, + }, + }) + if err != nil { + return nil, err + } + return x.findToMap(res, now), nil +} + +func (x *CacheMgo) Prefix(ctx context.Context, prefix string) (map[string]string, error) { + now := time.Now() + res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{ + "key": bson.M{"$regex": "^" + prefix}, + "$or": []bson.M{ + {"expire_at": bson.M{"$gt": now}}, + {"expire_at": nil}, + }, + }) + if err != nil { + return nil, err + } + return x.findToMap(res, now), nil +} + +func (x *CacheMgo) Set(ctx context.Context, key string, value string, expireAt time.Duration) error { + cv := &model.Cache{ + Key: key, + Value: value, + } + if expireAt > 0 { + now := time.Now().Add(expireAt) + cv.ExpireAt = &now + } + opt := options.Update().SetUpsert(true) + return mongoutil.UpdateOne(ctx, x.coll, bson.M{"key": key}, bson.M{"$set": cv}, false, opt) +} + +func (x *CacheMgo) Incr(ctx context.Context, key string, value int) (int, error) { + pipeline := mongo.Pipeline{ + { + {"$set", bson.M{ + "value": bson.M{ + "$toString": bson.M{ + "$add": bson.A{ + bson.M{"$toInt": "$value"}, + value, + }, + }, + }, + }}, + }, + } + opt := options.FindOneAndUpdate().SetReturnDocument(options.After) + res, err := mongoutil.FindOneAndUpdate[model.Cache](ctx, x.coll, bson.M{"key": key}, pipeline, opt) + if err != nil { + return 0, err + } + return strconv.Atoi(res.Value) +} + +func (x *CacheMgo) Del(ctx context.Context, key []string) error { + if len(key) == 0 { + return nil + } + _, err := x.coll.DeleteMany(ctx, bson.M{"key": bson.M{"$in": key}}) + return errs.Wrap(err) +} + +func (x *CacheMgo) lockKey(key string) string { + return "LOCK_" + key +} + +func (x *CacheMgo) Lock(ctx context.Context, key string, duration time.Duration) (string, error) { + tmp, err := uuid.NewUUID() + if err != nil { + return "", err + } + if duration <= 0 || duration > time.Minute*10 { + duration = time.Minute * 10 + } + cv := &model.Cache{ + Key: x.lockKey(key), + Value: tmp.String(), + ExpireAt: nil, + } + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + wait := func() error { + timeout := time.NewTimer(time.Millisecond * 100) + defer timeout.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timeout.C: + return nil + } + } + for { + if err := mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": key, "expire_at": bson.M{"$lt": time.Now()}}); err != nil { + return "", err + } + expireAt := time.Now().Add(duration) + cv.ExpireAt = &expireAt + if err := mongoutil.InsertMany[*model.Cache](ctx, x.coll, []*model.Cache{cv}); err != nil { + if mongo.IsDuplicateKeyError(err) { + if err := wait(); err != nil { + return "", err + } + continue + } + return "", err + } + return cv.Value, nil + } +} + +func (x *CacheMgo) Unlock(ctx context.Context, key string, value string) error { + return mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "value": value}) +} diff --git a/pkg/common/storage/database/mgo/client_config.go b/pkg/common/storage/database/mgo/client_config.go new file mode 100644 index 000000000..0aa462899 --- /dev/null +++ b/pkg/common/storage/database/mgo/client_config.go @@ -0,0 +1,99 @@ +// Copyright © 2023 OpenIM open source community. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mgo + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/db/mongoutil" + "github.com/openimsdk/tools/db/pagination" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + + "github.com/openimsdk/tools/errs" +) + +func NewClientConfig(db *mongo.Database) (database.ClientConfig, error) { + coll := db.Collection("config") + _, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{ + { + Keys: bson.D{ + {Key: "key", Value: 1}, + {Key: "user_id", Value: 1}, + }, + Options: options.Index().SetUnique(true), + }, + }) + if err != nil { + return nil, errs.Wrap(err) + } + return &ClientConfig{ + coll: coll, + }, nil +} + +type ClientConfig struct { + coll *mongo.Collection +} + +func (x *ClientConfig) Set(ctx context.Context, userID string, config map[string]string) error { + if len(config) == 0 { + return nil + } + for key, value := range config { + filter := bson.M{"key": key, "user_id": userID} + update := bson.M{ + "value": value, + } + err := mongoutil.UpdateOne(ctx, x.coll, filter, bson.M{"$set": update}, false, options.Update().SetUpsert(true)) + if err != nil { + return err + } + } + return nil +} + +func (x *ClientConfig) Get(ctx context.Context, userID string) (map[string]string, error) { + cs, err := mongoutil.Find[*model.ClientConfig](ctx, x.coll, bson.M{"user_id": userID}) + if err != nil { + return nil, err + } + cm := make(map[string]string) + for _, config := range cs { + cm[config.Key] = config.Value + } + return cm, nil +} + +func (x *ClientConfig) Del(ctx context.Context, userID string, keys []string) error { + if len(keys) == 0 { + return nil + } + return mongoutil.DeleteMany(ctx, x.coll, bson.M{"key": bson.M{"$in": keys}, "user_id": userID}) +} + +func (x *ClientConfig) GetPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) { + filter := bson.M{} + if userID != "" { + filter["user_id"] = userID + } + if key != "" { + filter["key"] = key + } + return mongoutil.FindPage[*model.ClientConfig](ctx, x.coll, filter, pagination) +} diff --git a/pkg/common/storage/model/client_config.go b/pkg/common/storage/model/client_config.go new file mode 100644 index 000000000..f06e29102 --- /dev/null +++ b/pkg/common/storage/model/client_config.go @@ -0,0 +1,7 @@ +package model + +type ClientConfig struct { + Key string `bson:"key"` + UserID string `bson:"user_id"` + Value string `bson:"value"` +} diff --git a/pkg/rpccache/conversation.go b/pkg/rpccache/conversation.go index 70f5acfd1..162fda596 100644 --- a/pkg/rpccache/conversation.go +++ b/pkg/rpccache/conversation.go @@ -16,6 +16,7 @@ package rpccache import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/open-im-server/v3/pkg/localcache" @@ -153,6 +154,26 @@ func (c *ConversationLocalCache) getConversationNotReceiveMessageUserIDs(ctx con })) } +func (c *ConversationLocalCache) getPinnedConversationIDs(ctx context.Context, userID string) (val []string, err error) { + log.ZDebug(ctx, "ConversationLocalCache getPinnedConversations req", "userID", userID) + defer func() { + if err == nil { + log.ZDebug(ctx, "ConversationLocalCache getPinnedConversations return", "userID", userID, "value", val) + } else { + log.ZError(ctx, "ConversationLocalCache getPinnedConversations return", err, "userID", userID) + } + }() + var cache cacheProto[pbconversation.GetPinnedConversationIDsResp] + resp, err := cache.Unmarshal(c.local.Get(ctx, cachekey.GetPinnedConversationIDs(userID), func(ctx context.Context) ([]byte, error) { + log.ZDebug(ctx, "ConversationLocalCache getConversationNotReceiveMessageUserIDs rpc", "userID", userID) + return cache.Marshal(c.client.ConversationClient.GetPinnedConversationIDs(ctx, &pbconversation.GetPinnedConversationIDsReq{UserID: userID})) + })) + if err != nil { + return nil, err + } + return resp.ConversationIDs, nil +} + func (c *ConversationLocalCache) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) { res, err := c.getConversationNotReceiveMessageUserIDs(ctx, conversationID) if err != nil { @@ -168,3 +189,7 @@ func (c *ConversationLocalCache) GetConversationNotReceiveMessageUserIDMap(ctx c } return datautil.SliceSet(res.UserIDs), nil } + +func (c *ConversationLocalCache) GetPinnedConversationIDs(ctx context.Context, userID string) ([]string, error) { + return c.getPinnedConversationIDs(ctx, userID) +} diff --git a/test/stress-test-v2/README.md b/test/stress-test-v2/README.md new file mode 100644 index 000000000..cbd4bdbde --- /dev/null +++ b/test/stress-test-v2/README.md @@ -0,0 +1,19 @@ +# Stress Test V2 + +## Usage + +You need set `TestTargetUserList` variables. + +### Build + +```bash + +go build -o test/stress-test-v2/stress-test-v2 test/stress-test-v2/main.go +``` + +### Excute + +```bash + +tools/stress-test-v2/stress-test-v2 -c config/ +``` diff --git a/test/stress-test-v2/main.go b/test/stress-test-v2/main.go new file mode 100644 index 000000000..0e4609964 --- /dev/null +++ b/test/stress-test-v2/main.go @@ -0,0 +1,759 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/apistruct" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/protocol/auth" + "github.com/openimsdk/protocol/constant" + "github.com/openimsdk/protocol/group" + "github.com/openimsdk/protocol/sdkws" + pbuser "github.com/openimsdk/protocol/user" + "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/system/program" +) + +// 1. Create 100K New Users +// 2. Create 100 100K Groups +// 3. Create 1000 999 Groups +// 4. Send message to 100K Groups every second +// 5. Send message to 999 Groups every minute + +var ( + // Use default userIDs List for testing, need to be created. + TestTargetUserList = []string{ + // "", + } + // DefaultGroupID = "" // Use default group ID for testing, need to be created. +) + +var ( + ApiAddress string + + // API method + GetAdminToken = "/auth/get_admin_token" + UserCheck = "/user/account_check" + CreateUser = "/user/user_register" + ImportFriend = "/friend/import_friend" + InviteToGroup = "/group/invite_user_to_group" + GetGroupMemberInfo = "/group/get_group_members_info" + SendMsg = "/msg/send_msg" + CreateGroup = "/group/create_group" + GetUserToken = "/auth/user_token" +) + +const ( + MaxUser = 100000 + Max1kUser = 1000 + Max100KGroup = 100 + Max999Group = 1000 + MaxInviteUserLimit = 999 + + CreateUserTicker = 1 * time.Second + CreateGroupTicker = 1 * time.Second + Create100KGroupTicker = 1 * time.Second + Create999GroupTicker = 1 * time.Second + SendMsgTo100KGroupTicker = 1 * time.Second + SendMsgTo999GroupTicker = 1 * time.Minute +) + +type BaseResp struct { + ErrCode int `json:"errCode"` + ErrMsg string `json:"errMsg"` + Data json.RawMessage `json:"data"` +} + +type StressTest struct { + Conf *conf + AdminUserID string + AdminToken string + DefaultGroupID string + DefaultUserID string + UserCounter int + CreateUserCounter int + Create100kGroupCounter int + Create999GroupCounter int + MsgCounter int + CreatedUsers []string + CreatedGroups []string + Mutex sync.Mutex + Ctx context.Context + Cancel context.CancelFunc + HttpClient *http.Client + Wg sync.WaitGroup + Once sync.Once +} + +type conf struct { + Share config.Share + Api config.API +} + +func initConfig(configDir string) (*config.Share, *config.API, error) { + var ( + share = &config.Share{} + apiConfig = &config.API{} + ) + + err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share) + if err != nil { + return nil, nil, err + } + + err = config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig) + if err != nil { + return nil, nil, err + } + + return share, apiConfig, nil +} + +// Post Request +func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) { + // Marshal body + jsonBody, err := json.Marshal(reqbody) + if err != nil { + log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody) + return nil, err + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("operationID", st.AdminUserID) + if st.AdminToken != "" { + req.Header.Set("token", st.AdminToken) + } + + // log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken) + + resp, err := st.HttpClient.Do(req) + if err != nil { + log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody) + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + log.ZError(ctx, "Failed to read response body", err, "url", url) + return nil, err + } + + var baseResp BaseResp + if err := json.Unmarshal(respBody, &baseResp); err != nil { + log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody)) + return nil, err + } + + if baseResp.ErrCode != 0 { + err = fmt.Errorf(baseResp.ErrMsg) + // log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp) + return nil, err + } + + return baseResp.Data, nil +} + +func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) { + req := auth.GetAdminTokenReq{ + Secret: st.Conf.Share.Secret, + UserID: st.AdminUserID, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &req) + if err != nil { + return "", err + } + + data := &auth.GetAdminTokenResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return "", err + } + + return data.Token, nil +} + +func (st *StressTest) CheckUser(ctx context.Context, userIDs []string) ([]string, error) { + req := pbuser.AccountCheckReq{ + CheckUserIDs: userIDs, + } + + resp, err := st.PostRequest(ctx, ApiAddress+UserCheck, &req) + if err != nil { + return nil, err + } + + data := &pbuser.AccountCheckResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return nil, err + } + + unRegisteredUserIDs := make([]string, 0) + + for _, res := range data.Results { + if res.AccountStatus == constant.UnRegistered { + unRegisteredUserIDs = append(unRegisteredUserIDs, res.UserID) + } + } + + return unRegisteredUserIDs, nil +} + +func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) { + user := &sdkws.UserInfo{ + UserID: userID, + Nickname: userID, + } + + req := pbuser.UserRegisterReq{ + Users: []*sdkws.UserInfo{user}, + } + + _, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req) + if err != nil { + return "", err + } + + st.UserCounter++ + return userID, nil +} + +func (st *StressTest) CreateUserBatch(ctx context.Context, userIDs []string) error { + // The method can import a large number of users at once. + var userList []*sdkws.UserInfo + + defer st.Once.Do( + func() { + st.DefaultUserID = userIDs[0] + fmt.Println("Default Send User Created ID:", st.DefaultUserID) + }) + + needUserIDs, err := st.CheckUser(ctx, userIDs) + if err != nil { + return err + } + + for _, userID := range needUserIDs { + user := &sdkws.UserInfo{ + UserID: userID, + Nickname: userID, + } + userList = append(userList, user) + } + + req := pbuser.UserRegisterReq{ + Users: userList, + } + + _, err = st.PostRequest(ctx, ApiAddress+CreateUser, &req) + if err != nil { + return err + } + + st.UserCounter += len(userList) + return nil +} + +func (st *StressTest) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]string, error) { + needInviteUserIDs := make([]string, 0) + + const maxBatchSize = 500 + if len(userIDs) > maxBatchSize { + for i := 0; i < len(userIDs); i += maxBatchSize { + end := min(i+maxBatchSize, len(userIDs)) + batchUserIDs := userIDs[i:end] + + // log.ZInfo(ctx, "Processing group members batch", "groupID", groupID, "batch", i/maxBatchSize+1, + // "batchUserCount", len(batchUserIDs)) + + // Process a single batch + batchReq := group.GetGroupMembersInfoReq{ + GroupID: groupID, + UserIDs: batchUserIDs, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &batchReq) + if err != nil { + log.ZError(ctx, "Batch query failed", err, "batch", i/maxBatchSize+1) + continue + } + + data := &group.GetGroupMembersInfoResp{} + if err := json.Unmarshal(resp, &data); err != nil { + log.ZError(ctx, "Failed to parse batch response", err, "batch", i/maxBatchSize+1) + continue + } + + // Process the batch results + existingMembers := make(map[string]bool) + for _, member := range data.Members { + existingMembers[member.UserID] = true + } + + for _, userID := range batchUserIDs { + if !existingMembers[userID] { + needInviteUserIDs = append(needInviteUserIDs, userID) + } + } + } + + return needInviteUserIDs, nil + } + + req := group.GetGroupMembersInfoReq{ + GroupID: groupID, + UserIDs: userIDs, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &req) + if err != nil { + return nil, err + } + + data := &group.GetGroupMembersInfoResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return nil, err + } + + existingMembers := make(map[string]bool) + for _, member := range data.Members { + existingMembers[member.UserID] = true + } + + for _, userID := range userIDs { + if !existingMembers[userID] { + needInviteUserIDs = append(needInviteUserIDs, userID) + } + } + + return needInviteUserIDs, nil +} + +func (st *StressTest) InviteToGroup(ctx context.Context, groupID string, userIDs []string) error { + req := group.InviteUserToGroupReq{ + GroupID: groupID, + InvitedUserIDs: userIDs, + } + _, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &req) + if err != nil { + return err + } + + return nil +} + +func (st *StressTest) SendMsg(ctx context.Context, userID string, groupID string) error { + contentObj := map[string]any{ + // "content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")), + "content": fmt.Sprintf("The current time is %s", time.Now().Format("2006-01-02 15:04:05.000")), + } + + req := &apistruct.SendMsgReq{ + SendMsg: apistruct.SendMsg{ + SendID: userID, + SenderNickname: userID, + GroupID: groupID, + ContentType: constant.Text, + SessionType: constant.ReadGroupChatType, + Content: contentObj, + }, + } + + _, err := st.PostRequest(ctx, ApiAddress+SendMsg, &req) + if err != nil { + log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &req) + return err + } + + st.MsgCounter++ + + return nil +} + +// Max userIDs number is 1000 +func (st *StressTest) CreateGroup(ctx context.Context, groupID string, userID string, userIDsList []string) (string, error) { + groupInfo := &sdkws.GroupInfo{ + GroupID: groupID, + GroupName: groupID, + GroupType: constant.WorkingGroup, + } + + req := group.CreateGroupReq{ + OwnerUserID: userID, + MemberUserIDs: userIDsList, + GroupInfo: groupInfo, + } + + resp := group.CreateGroupResp{} + + response, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &req) + if err != nil { + return "", err + } + + if err := json.Unmarshal(response, &resp); err != nil { + return "", err + } + + // st.GroupCounter++ + + return resp.GroupInfo.GroupID, nil +} + +func main() { + var configPath string + // defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config") + // flag.StringVar(&configPath, "c", defaultConfigDir, "config path") + flag.StringVar(&configPath, "c", "", "config path") + flag.Parse() + + if configPath == "" { + _, _ = fmt.Fprintln(os.Stderr, "config path is empty") + os.Exit(1) + return + } + + fmt.Printf(" Config Path: %s\n", configPath) + + share, apiConfig, err := initConfig(configPath) + if err != nil { + program.ExitWithError(err) + return + } + + ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0])) + + ctx, cancel := context.WithCancel(context.Background()) + // ch := make(chan struct{}) + + st := &StressTest{ + Conf: &conf{ + Share: *share, + Api: *apiConfig, + }, + AdminUserID: share.IMAdminUserID[0], + Ctx: ctx, + Cancel: cancel, + HttpClient: &http.Client{ + Timeout: 50 * time.Second, + }, + } + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + fmt.Println("\nReceived stop signal, stopping...") + + go func() { + // time.Sleep(5 * time.Second) + fmt.Println("Force exit") + os.Exit(0) + }() + + st.Cancel() + }() + + token, err := st.GetAdminToken(st.Ctx) + if err != nil { + log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID) + } + + st.AdminToken = token + fmt.Println("Admin Token:", st.AdminToken) + fmt.Println("ApiAddress:", ApiAddress) + + for i := range MaxUser { + userID := fmt.Sprintf("v2_StressTest_User_%d", i) + st.CreatedUsers = append(st.CreatedUsers, userID) + st.CreateUserCounter++ + } + + // err = st.CreateUserBatch(st.Ctx, st.CreatedUsers) + // if err != nil { + // log.ZError(ctx, "Create user failed.", err) + // } + + const batchSize = 1000 + totalUsers := len(st.CreatedUsers) + successCount := 0 + + if st.DefaultUserID == "" && len(st.CreatedUsers) > 0 { + st.DefaultUserID = st.CreatedUsers[0] + } + + for i := 0; i < totalUsers; i += batchSize { + end := min(i+batchSize, totalUsers) + + userBatch := st.CreatedUsers[i:end] + log.ZInfo(st.Ctx, "Creating user batch", "batch", i/batchSize+1, "count", len(userBatch)) + + err = st.CreateUserBatch(st.Ctx, userBatch) + if err != nil { + log.ZError(st.Ctx, "Batch user creation failed", err, "batch", i/batchSize+1) + } else { + successCount += len(userBatch) + log.ZInfo(st.Ctx, "Batch user creation succeeded", "batch", i/batchSize+1, + "progress", fmt.Sprintf("%d/%d", successCount, totalUsers)) + } + } + + // Execute create 100k group + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + create100kGroupTicker := time.NewTicker(Create100KGroupTicker) + defer create100kGroupTicker.Stop() + + for i := range Max100KGroup { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create 100K Group") + return + + case <-create100kGroupTicker.C: + // Create 100K groups + st.Wg.Add(1) + go func(idx int) { + startTime := time.Now() + defer func() { + elapsedTime := time.Since(startTime) + log.ZInfo(st.Ctx, "100K group creation completed", + "groupID", fmt.Sprintf("v2_StressTest_Group_100K_%d", idx), + "index", idx, + "duration", elapsedTime.String()) + }() + + defer st.Wg.Done() + defer func() { + st.Mutex.Lock() + st.Create100kGroupCounter++ + st.Mutex.Unlock() + }() + + groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", idx) + + if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil { + log.ZError(st.Ctx, "Create group failed.", err) + // continue + } + + for i := 0; i <= MaxUser/MaxInviteUserLimit; i++ { + InviteUserIDs := make([]string, 0) + // ensure TargetUserList is in group + InviteUserIDs = append(InviteUserIDs, TestTargetUserList...) + + startIdx := max(i*MaxInviteUserLimit, 1) + endIdx := min((i+1)*MaxInviteUserLimit, MaxUser) + + for j := startIdx; j < endIdx; j++ { + userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j) + InviteUserIDs = append(InviteUserIDs, userCreatedID) + } + + if len(InviteUserIDs) == 0 { + // log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs) + if err != nil { + log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID) + continue + } + + if len(InviteUserIDs) == 0 { + // log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + // Invite To Group + if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil { + log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs) + continue + // os.Exit(1) + // return + } + } + }(i) + } + } + }() + + // create 999 groups + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + create999GroupTicker := time.NewTicker(Create999GroupTicker) + defer create999GroupTicker.Stop() + + for i := range Max999Group { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create 999 Group") + return + + case <-create999GroupTicker.C: + // Create 999 groups + st.Wg.Add(1) + go func(idx int) { + startTime := time.Now() + defer func() { + elapsedTime := time.Since(startTime) + log.ZInfo(st.Ctx, "999 group creation completed", + "groupID", fmt.Sprintf("v2_StressTest_Group_1K_%d", idx), + "index", idx, + "duration", elapsedTime.String()) + }() + + defer st.Wg.Done() + defer func() { + st.Mutex.Lock() + st.Create999GroupCounter++ + st.Mutex.Unlock() + }() + + groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", idx) + + if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil { + log.ZError(st.Ctx, "Create group failed.", err) + // continue + } + for i := 0; i <= Max1kUser/MaxInviteUserLimit; i++ { + InviteUserIDs := make([]string, 0) + // ensure TargetUserList is in group + InviteUserIDs = append(InviteUserIDs, TestTargetUserList...) + + startIdx := max(i*MaxInviteUserLimit, 1) + endIdx := min((i+1)*MaxInviteUserLimit, Max1kUser) + + for j := startIdx; j < endIdx; j++ { + userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j) + InviteUserIDs = append(InviteUserIDs, userCreatedID) + } + + if len(InviteUserIDs) == 0 { + // log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs) + if err != nil { + log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID) + continue + } + + if len(InviteUserIDs) == 0 { + // log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + // Invite To Group + if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil { + log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs) + continue + // os.Exit(1) + // return + } + } + }(i) + } + } + }() + + // Send message to 100K groups + st.Wg.Wait() + fmt.Println("All groups created successfully, starting to send messages...") + log.ZInfo(ctx, "All groups created successfully, starting to send messages...") + + var groups100K []string + var groups999 []string + + for i := range Max100KGroup { + groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", i) + groups100K = append(groups100K, groupID) + } + + for i := range Max999Group { + groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", i) + groups999 = append(groups999, groupID) + } + + send100kGroupLimiter := make(chan struct{}, 20) + send999GroupLimiter := make(chan struct{}, 100) + + // execute Send message to 100K groups + go func() { + ticker := time.NewTicker(SendMsgTo100KGroupTicker) + defer ticker.Stop() + + for { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Send Message to 100K Group") + return + + case <-ticker.C: + // Send message to 100K groups + for _, groupID := range groups100K { + send100kGroupLimiter <- struct{}{} + go func(groupID string) { + defer func() { <-send100kGroupLimiter }() + if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil { + log.ZError(st.Ctx, "Send message to 100K group failed.", err) + } + }(groupID) + } + // log.ZInfo(st.Ctx, "Send message to 100K groups successfully.") + } + } + }() + + // execute Send message to 999 groups + go func() { + ticker := time.NewTicker(SendMsgTo999GroupTicker) + defer ticker.Stop() + + for { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Send Message to 999 Group") + return + + case <-ticker.C: + // Send message to 999 groups + for _, groupID := range groups999 { + send999GroupLimiter <- struct{}{} + go func(groupID string) { + defer func() { <-send999GroupLimiter }() + + if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil { + log.ZError(st.Ctx, "Send message to 999 group failed.", err) + } + }(groupID) + } + // log.ZInfo(st.Ctx, "Send message to 999 groups successfully.") + } + } + }() + + <-st.Ctx.Done() + fmt.Println("Received signal to exit, shutting down...") +} diff --git a/test/stress-test/README.md b/test/stress-test/README.md new file mode 100644 index 000000000..cba93e279 --- /dev/null +++ b/test/stress-test/README.md @@ -0,0 +1,19 @@ +# Stress Test + +## Usage + +You need set `TestTargetUserList` and `DefaultGroupID` variables. + +### Build + +```bash + +go build -o test/stress-test/stress-test test/stress-test/main.go +``` + +### Excute + +```bash + +tools/stress-test/stress-test -c config/ +``` diff --git a/test/stress-test/main.go b/test/stress-test/main.go new file mode 100755 index 000000000..6adbd12ee --- /dev/null +++ b/test/stress-test/main.go @@ -0,0 +1,458 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/apistruct" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/protocol/auth" + "github.com/openimsdk/protocol/constant" + "github.com/openimsdk/protocol/group" + "github.com/openimsdk/protocol/relation" + "github.com/openimsdk/protocol/sdkws" + pbuser "github.com/openimsdk/protocol/user" + "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/system/program" +) + +/* + 1. Create one user every minute + 2. Import target users as friends + 3. Add users to the default group + 4. Send a message to the default group every second, containing index and current timestamp + 5. Create a new group every minute and invite target users to join +*/ + +// !!! ATTENTION: This variable is must be added! +var ( + // Use default userIDs List for testing, need to be created. + TestTargetUserList = []string{ + "", + } + DefaultGroupID = "" // Use default group ID for testing, need to be created. +) + +var ( + ApiAddress string + + // API method + GetAdminToken = "/auth/get_admin_token" + CreateUser = "/user/user_register" + ImportFriend = "/friend/import_friend" + InviteToGroup = "/group/invite_user_to_group" + SendMsg = "/msg/send_msg" + CreateGroup = "/group/create_group" + GetUserToken = "/auth/user_token" +) + +const ( + MaxUser = 10000 + MaxGroup = 1000 + + CreateUserTicker = 1 * time.Minute // Ticker is 1min in create user + SendMessageTicker = 1 * time.Second // Ticker is 1s in send message + CreateGroupTicker = 1 * time.Minute +) + +type BaseResp struct { + ErrCode int `json:"errCode"` + ErrMsg string `json:"errMsg"` + Data json.RawMessage `json:"data"` +} + +type StressTest struct { + Conf *conf + AdminUserID string + AdminToken string + DefaultGroupID string + DefaultUserID string + UserCounter int + GroupCounter int + MsgCounter int + CreatedUsers []string + CreatedGroups []string + Mutex sync.Mutex + Ctx context.Context + Cancel context.CancelFunc + HttpClient *http.Client + Wg sync.WaitGroup + Once sync.Once +} + +type conf struct { + Share config.Share + Api config.API +} + +func initConfig(configDir string) (*config.Share, *config.API, error) { + var ( + share = &config.Share{} + apiConfig = &config.API{} + ) + + err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share) + if err != nil { + return nil, nil, err + } + + err = config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig) + if err != nil { + return nil, nil, err + } + + return share, apiConfig, nil +} + +// Post Request +func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) { + // Marshal body + jsonBody, err := json.Marshal(reqbody) + if err != nil { + log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody) + return nil, err + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("operationID", st.AdminUserID) + if st.AdminToken != "" { + req.Header.Set("token", st.AdminToken) + } + + // log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken) + + resp, err := st.HttpClient.Do(req) + if err != nil { + log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody) + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + log.ZError(ctx, "Failed to read response body", err, "url", url) + return nil, err + } + + var baseResp BaseResp + if err := json.Unmarshal(respBody, &baseResp); err != nil { + log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody)) + return nil, err + } + + if baseResp.ErrCode != 0 { + err = fmt.Errorf(baseResp.ErrMsg) + log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp) + return nil, err + } + + return baseResp.Data, nil +} + +func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) { + req := auth.GetAdminTokenReq{ + Secret: st.Conf.Share.Secret, + UserID: st.AdminUserID, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &req) + if err != nil { + return "", err + } + + data := &auth.GetAdminTokenResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return "", err + } + + return data.Token, nil +} + +func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) { + user := &sdkws.UserInfo{ + UserID: userID, + Nickname: userID, + } + + req := pbuser.UserRegisterReq{ + Users: []*sdkws.UserInfo{user}, + } + + _, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req) + if err != nil { + return "", err + } + + st.UserCounter++ + return userID, nil +} + +func (st *StressTest) ImportFriend(ctx context.Context, userID string) error { + req := relation.ImportFriendReq{ + OwnerUserID: userID, + FriendUserIDs: TestTargetUserList, + } + + _, err := st.PostRequest(ctx, ApiAddress+ImportFriend, &req) + if err != nil { + return err + } + + return nil +} + +func (st *StressTest) InviteToGroup(ctx context.Context, userID string) error { + req := group.InviteUserToGroupReq{ + GroupID: st.DefaultGroupID, + InvitedUserIDs: []string{userID}, + } + _, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &req) + if err != nil { + return err + } + + return nil +} + +func (st *StressTest) SendMsg(ctx context.Context, userID string) error { + contentObj := map[string]any{ + "content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")), + } + + req := &apistruct.SendMsgReq{ + SendMsg: apistruct.SendMsg{ + SendID: userID, + SenderNickname: userID, + GroupID: st.DefaultGroupID, + ContentType: constant.Text, + SessionType: constant.ReadGroupChatType, + Content: contentObj, + }, + } + + _, err := st.PostRequest(ctx, ApiAddress+SendMsg, &req) + if err != nil { + log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &req) + return err + } + + st.MsgCounter++ + + return nil +} + +func (st *StressTest) CreateGroup(ctx context.Context, userID string) (string, error) { + groupID := fmt.Sprintf("StressTestGroup_%d_%s", st.GroupCounter, time.Now().Format("20060102150405")) + + groupInfo := &sdkws.GroupInfo{ + GroupID: groupID, + GroupName: groupID, + GroupType: constant.WorkingGroup, + } + + req := group.CreateGroupReq{ + OwnerUserID: userID, + MemberUserIDs: TestTargetUserList, + GroupInfo: groupInfo, + } + + resp := group.CreateGroupResp{} + + response, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &req) + if err != nil { + return "", err + } + + if err := json.Unmarshal(response, &resp); err != nil { + return "", err + } + + st.GroupCounter++ + + return resp.GroupInfo.GroupID, nil +} + +func main() { + var configPath string + // defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config") + // flag.StringVar(&configPath, "c", defaultConfigDir, "config path") + flag.StringVar(&configPath, "c", "", "config path") + flag.Parse() + + if configPath == "" { + _, _ = fmt.Fprintln(os.Stderr, "config path is empty") + os.Exit(1) + return + } + + fmt.Printf(" Config Path: %s\n", configPath) + + share, apiConfig, err := initConfig(configPath) + if err != nil { + program.ExitWithError(err) + return + } + + ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0])) + + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan struct{}) + + defer cancel() + + st := &StressTest{ + Conf: &conf{ + Share: *share, + Api: *apiConfig, + }, + AdminUserID: share.IMAdminUserID[0], + Ctx: ctx, + Cancel: cancel, + HttpClient: &http.Client{ + Timeout: 50 * time.Second, + }, + } + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + fmt.Println("\nReceived stop signal, stopping...") + + select { + case <-ch: + default: + close(ch) + } + + st.Cancel() + }() + + token, err := st.GetAdminToken(st.Ctx) + if err != nil { + log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID) + } + + st.AdminToken = token + fmt.Println("Admin Token:", st.AdminToken) + fmt.Println("ApiAddress:", ApiAddress) + + st.DefaultGroupID = DefaultGroupID + + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + ticker := time.NewTicker(CreateUserTicker) + defer ticker.Stop() + + for st.UserCounter < MaxUser { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create user", "reason", "context done") + return + + case <-ticker.C: + // Create User + userID := fmt.Sprintf("%d_Stresstest_%s", st.UserCounter, time.Now().Format("0102150405")) + + userCreatedID, err := st.CreateUser(st.Ctx, userID) + if err != nil { + log.ZError(st.Ctx, "Create User failed.", err, "UserID", userID) + os.Exit(1) + return + } + // fmt.Println("User Created ID:", userCreatedID) + + // Import Friend + if err = st.ImportFriend(st.Ctx, userCreatedID); err != nil { + log.ZError(st.Ctx, "Import Friend failed.", err, "UserID", userCreatedID) + os.Exit(1) + return + } + // Invite To Group + if err = st.InviteToGroup(st.Ctx, userCreatedID); err != nil { + log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", userCreatedID) + os.Exit(1) + return + } + + st.Once.Do(func() { + st.DefaultUserID = userCreatedID + fmt.Println("Default Send User Created ID:", userCreatedID) + close(ch) + }) + } + } + }() + + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + ticker := time.NewTicker(SendMessageTicker) + defer ticker.Stop() + <-ch + + for { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Send message", "reason", "context done") + return + + case <-ticker.C: + // Send Message + if err = st.SendMsg(st.Ctx, st.DefaultUserID); err != nil { + log.ZError(st.Ctx, "Send Message failed.", err, "UserID", st.DefaultUserID) + continue + } + } + } + }() + + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + ticker := time.NewTicker(CreateGroupTicker) + defer ticker.Stop() + <-ch + + for st.GroupCounter < MaxGroup { + + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create Group", "reason", "context done") + return + + case <-ticker.C: + + // Create Group + _, err := st.CreateGroup(st.Ctx, st.DefaultUserID) + if err != nil { + log.ZError(st.Ctx, "Create Group failed.", err, "UserID", st.DefaultUserID) + os.Exit(1) + return + } + + // fmt.Println("Group Created ID:", groupID) + } + } + }() + + st.Wg.Wait() +} diff --git a/tools/s3/internal/conversion.go b/tools/s3/internal/conversion.go index ba2174535..af391ec42 100644 --- a/tools/s3/internal/conversion.go +++ b/tools/s3/internal/conversion.go @@ -4,6 +4,11 @@ import ( "context" "errors" "fmt" + "log" + "net/http" + "path/filepath" + "time" + "github.com/mitchellh/mapstructure" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" @@ -19,10 +24,6 @@ import ( "github.com/openimsdk/tools/s3/oss" "github.com/spf13/viper" "go.mongodb.org/mongo-driver/mongo" - "log" - "net/http" - "path/filepath" - "time" ) const defaultTimeout = time.Second * 10 @@ -159,7 +160,7 @@ func doObject(db database.ObjectInfo, newS3, oldS3 s3.Interface, skip int) (*Res if err != nil { return nil, err } - putURL, err := newS3.PresignedPutObject(ctx, obj.Key, time.Hour) + putURL, err := newS3.PresignedPutObject(ctx, obj.Key, time.Hour, &s3.PutOption{ContentType: obj.ContentType}) if err != nil { return nil, err } @@ -176,7 +177,7 @@ func doObject(db database.ObjectInfo, newS3, oldS3 s3.Interface, skip int) (*Res return nil, fmt.Errorf("download object failed %s", downloadResp.Status) } log.Printf("file size %d", obj.Size) - request, err := http.NewRequest(http.MethodPut, putURL, downloadResp.Body) + request, err := http.NewRequest(http.MethodPut, putURL.URL, downloadResp.Body) if err != nil { return nil, err } diff --git a/tools/seq/internal/main.go b/tools/seq/internal/main.go index 7e5d5598c..9fd352a96 100644 --- a/tools/seq/internal/main.go +++ b/tools/seq/internal/main.go @@ -337,7 +337,7 @@ func SetVersion(coll *mongo.Collection, key string, version int) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() option := options.Update().SetUpsert(true) - filter := bson.M{"key": key, "value": strconv.Itoa(version)} + filter := bson.M{"key": key} update := bson.M{"$set": bson.M{"key": key, "value": strconv.Itoa(version)}} return mongoutil.UpdateOne(ctx, coll, filter, update, false, option) } diff --git a/tools/stress-test-v2/main.go b/tools/stress-test-v2/main.go new file mode 100644 index 000000000..0c309b9c9 --- /dev/null +++ b/tools/stress-test-v2/main.go @@ -0,0 +1,736 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/apistruct" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/protocol/auth" + "github.com/openimsdk/protocol/constant" + "github.com/openimsdk/protocol/group" + "github.com/openimsdk/protocol/sdkws" + pbuser "github.com/openimsdk/protocol/user" + "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/system/program" +) + +// 1. Create 100K New Users +// 2. Create 100 100K Groups +// 3. Create 1000 999 Groups +// 4. Send message to 100K Groups every second +// 5. Send message to 999 Groups every minute + +var ( + // Use default userIDs List for testing, need to be created. + TestTargetUserList = []string{ + // "", + } + // DefaultGroupID = "" // Use default group ID for testing, need to be created. +) + +var ( + ApiAddress string + + // API method + GetAdminToken = "/auth/get_admin_token" + UserCheck = "/user/account_check" + CreateUser = "/user/user_register" + ImportFriend = "/friend/import_friend" + InviteToGroup = "/group/invite_user_to_group" + GetGroupMemberInfo = "/group/get_group_members_info" + SendMsg = "/msg/send_msg" + CreateGroup = "/group/create_group" + GetUserToken = "/auth/user_token" +) + +const ( + MaxUser = 100000 + Max100KGroup = 100 + Max999Group = 1000 + MaxInviteUserLimit = 999 + + CreateUserTicker = 1 * time.Second + CreateGroupTicker = 1 * time.Second + Create100KGroupTicker = 1 * time.Second + Create999GroupTicker = 1 * time.Second + SendMsgTo100KGroupTicker = 1 * time.Second + SendMsgTo999GroupTicker = 1 * time.Minute +) + +type BaseResp struct { + ErrCode int `json:"errCode"` + ErrMsg string `json:"errMsg"` + Data json.RawMessage `json:"data"` +} + +type StressTest struct { + Conf *conf + AdminUserID string + AdminToken string + DefaultGroupID string + DefaultUserID string + UserCounter int + CreateUserCounter int + Create100kGroupCounter int + Create999GroupCounter int + MsgCounter int + CreatedUsers []string + CreatedGroups []string + Mutex sync.Mutex + Ctx context.Context + Cancel context.CancelFunc + HttpClient *http.Client + Wg sync.WaitGroup + Once sync.Once +} + +type conf struct { + Share config.Share + Api config.API +} + +func initConfig(configDir string) (*config.Share, *config.API, error) { + var ( + share = &config.Share{} + apiConfig = &config.API{} + ) + + err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share) + if err != nil { + return nil, nil, err + } + + err = config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig) + if err != nil { + return nil, nil, err + } + + return share, apiConfig, nil +} + +// Post Request +func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) { + // Marshal body + jsonBody, err := json.Marshal(reqbody) + if err != nil { + log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody) + return nil, err + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("operationID", st.AdminUserID) + if st.AdminToken != "" { + req.Header.Set("token", st.AdminToken) + } + + // log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken) + + resp, err := st.HttpClient.Do(req) + if err != nil { + log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody) + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + log.ZError(ctx, "Failed to read response body", err, "url", url) + return nil, err + } + + var baseResp BaseResp + if err := json.Unmarshal(respBody, &baseResp); err != nil { + log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody)) + return nil, err + } + + if baseResp.ErrCode != 0 { + err = fmt.Errorf(baseResp.ErrMsg) + log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp) + return nil, err + } + + return baseResp.Data, nil +} + +func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) { + req := auth.GetAdminTokenReq{ + Secret: st.Conf.Share.Secret, + UserID: st.AdminUserID, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &req) + if err != nil { + return "", err + } + + data := &auth.GetAdminTokenResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return "", err + } + + return data.Token, nil +} + +func (st *StressTest) CheckUser(ctx context.Context, userIDs []string) ([]string, error) { + req := pbuser.AccountCheckReq{ + CheckUserIDs: userIDs, + } + + resp, err := st.PostRequest(ctx, ApiAddress+UserCheck, &req) + if err != nil { + return nil, err + } + + data := &pbuser.AccountCheckResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return nil, err + } + + unRegisteredUserIDs := make([]string, 0) + + for _, res := range data.Results { + if res.AccountStatus == constant.UnRegistered { + unRegisteredUserIDs = append(unRegisteredUserIDs, res.UserID) + } + } + + return unRegisteredUserIDs, nil +} + +func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) { + user := &sdkws.UserInfo{ + UserID: userID, + Nickname: userID, + } + + req := pbuser.UserRegisterReq{ + Users: []*sdkws.UserInfo{user}, + } + + _, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req) + if err != nil { + return "", err + } + + st.UserCounter++ + return userID, nil +} + +func (st *StressTest) CreateUserBatch(ctx context.Context, userIDs []string) error { + // The method can import a large number of users at once. + var userList []*sdkws.UserInfo + + defer st.Once.Do( + func() { + st.DefaultUserID = userIDs[0] + fmt.Println("Default Send User Created ID:", st.DefaultUserID) + }) + + needUserIDs, err := st.CheckUser(ctx, userIDs) + if err != nil { + return err + } + + for _, userID := range needUserIDs { + user := &sdkws.UserInfo{ + UserID: userID, + Nickname: userID, + } + userList = append(userList, user) + } + + req := pbuser.UserRegisterReq{ + Users: userList, + } + + _, err = st.PostRequest(ctx, ApiAddress+CreateUser, &req) + if err != nil { + return err + } + + st.UserCounter += len(userList) + return nil +} + +func (st *StressTest) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]string, error) { + needInviteUserIDs := make([]string, 0) + + const maxBatchSize = 500 + if len(userIDs) > maxBatchSize { + for i := 0; i < len(userIDs); i += maxBatchSize { + end := min(i+maxBatchSize, len(userIDs)) + batchUserIDs := userIDs[i:end] + + // log.ZInfo(ctx, "Processing group members batch", "groupID", groupID, "batch", i/maxBatchSize+1, + // "batchUserCount", len(batchUserIDs)) + + // Process a single batch + batchReq := group.GetGroupMembersInfoReq{ + GroupID: groupID, + UserIDs: batchUserIDs, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &batchReq) + if err != nil { + log.ZError(ctx, "Batch query failed", err, "batch", i/maxBatchSize+1) + continue + } + + data := &group.GetGroupMembersInfoResp{} + if err := json.Unmarshal(resp, &data); err != nil { + log.ZError(ctx, "Failed to parse batch response", err, "batch", i/maxBatchSize+1) + continue + } + + // Process the batch results + existingMembers := make(map[string]bool) + for _, member := range data.Members { + existingMembers[member.UserID] = true + } + + for _, userID := range batchUserIDs { + if !existingMembers[userID] { + needInviteUserIDs = append(needInviteUserIDs, userID) + } + } + } + + return needInviteUserIDs, nil + } + + req := group.GetGroupMembersInfoReq{ + GroupID: groupID, + UserIDs: userIDs, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &req) + if err != nil { + return nil, err + } + + data := &group.GetGroupMembersInfoResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return nil, err + } + + existingMembers := make(map[string]bool) + for _, member := range data.Members { + existingMembers[member.UserID] = true + } + + for _, userID := range userIDs { + if !existingMembers[userID] { + needInviteUserIDs = append(needInviteUserIDs, userID) + } + } + + return needInviteUserIDs, nil +} + +func (st *StressTest) InviteToGroup(ctx context.Context, groupID string, userIDs []string) error { + req := group.InviteUserToGroupReq{ + GroupID: groupID, + InvitedUserIDs: userIDs, + } + _, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &req) + if err != nil { + return err + } + + return nil +} + +func (st *StressTest) SendMsg(ctx context.Context, userID string, groupID string) error { + contentObj := map[string]any{ + // "content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")), + "content": fmt.Sprintf("The current time is %s", time.Now().Format("2006-01-02 15:04:05.000")), + } + + req := &apistruct.SendMsgReq{ + SendMsg: apistruct.SendMsg{ + SendID: userID, + SenderNickname: userID, + GroupID: groupID, + ContentType: constant.Text, + SessionType: constant.ReadGroupChatType, + Content: contentObj, + }, + } + + _, err := st.PostRequest(ctx, ApiAddress+SendMsg, &req) + if err != nil { + log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &req) + return err + } + + st.MsgCounter++ + + return nil +} + +// Max userIDs number is 1000 +func (st *StressTest) CreateGroup(ctx context.Context, groupID string, userID string, userIDsList []string) (string, error) { + groupInfo := &sdkws.GroupInfo{ + GroupID: groupID, + GroupName: groupID, + GroupType: constant.WorkingGroup, + } + + req := group.CreateGroupReq{ + OwnerUserID: userID, + MemberUserIDs: userIDsList, + GroupInfo: groupInfo, + } + + resp := group.CreateGroupResp{} + + response, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &req) + if err != nil { + return "", err + } + + if err := json.Unmarshal(response, &resp); err != nil { + return "", err + } + + // st.GroupCounter++ + + return resp.GroupInfo.GroupID, nil +} + +func main() { + var configPath string + // defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config") + // flag.StringVar(&configPath, "c", defaultConfigDir, "config path") + flag.StringVar(&configPath, "c", "", "config path") + flag.Parse() + + if configPath == "" { + _, _ = fmt.Fprintln(os.Stderr, "config path is empty") + os.Exit(1) + return + } + + fmt.Printf(" Config Path: %s\n", configPath) + + share, apiConfig, err := initConfig(configPath) + if err != nil { + program.ExitWithError(err) + return + } + + ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0])) + + ctx, cancel := context.WithCancel(context.Background()) + // ch := make(chan struct{}) + + st := &StressTest{ + Conf: &conf{ + Share: *share, + Api: *apiConfig, + }, + AdminUserID: share.IMAdminUserID[0], + Ctx: ctx, + Cancel: cancel, + HttpClient: &http.Client{ + Timeout: 50 * time.Second, + }, + } + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + fmt.Println("\nReceived stop signal, stopping...") + + go func() { + // time.Sleep(5 * time.Second) + fmt.Println("Force exit") + os.Exit(0) + }() + + st.Cancel() + }() + + token, err := st.GetAdminToken(st.Ctx) + if err != nil { + log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID) + } + + st.AdminToken = token + fmt.Println("Admin Token:", st.AdminToken) + fmt.Println("ApiAddress:", ApiAddress) + + for i := range MaxUser { + userID := fmt.Sprintf("v2_StressTest_User_%d", i) + st.CreatedUsers = append(st.CreatedUsers, userID) + st.CreateUserCounter++ + } + + // err = st.CreateUserBatch(st.Ctx, st.CreatedUsers) + // if err != nil { + // log.ZError(ctx, "Create user failed.", err) + // } + + const batchSize = 1000 + totalUsers := len(st.CreatedUsers) + successCount := 0 + + if st.DefaultUserID == "" && len(st.CreatedUsers) > 0 { + st.DefaultUserID = st.CreatedUsers[0] + } + + for i := 0; i < totalUsers; i += batchSize { + end := min(i+batchSize, totalUsers) + + userBatch := st.CreatedUsers[i:end] + log.ZInfo(st.Ctx, "Creating user batch", "batch", i/batchSize+1, "count", len(userBatch)) + + err = st.CreateUserBatch(st.Ctx, userBatch) + if err != nil { + log.ZError(st.Ctx, "Batch user creation failed", err, "batch", i/batchSize+1) + } else { + successCount += len(userBatch) + log.ZInfo(st.Ctx, "Batch user creation succeeded", "batch", i/batchSize+1, + "progress", fmt.Sprintf("%d/%d", successCount, totalUsers)) + } + } + + // Execute create 100k group + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + create100kGroupTicker := time.NewTicker(Create100KGroupTicker) + defer create100kGroupTicker.Stop() + + for i := range Max100KGroup { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create 100K Group") + return + + case <-create100kGroupTicker.C: + // Create 100K groups + st.Wg.Add(1) + go func(idx int) { + defer st.Wg.Done() + defer func() { + st.Create100kGroupCounter++ + }() + + groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", idx) + + if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil { + log.ZError(st.Ctx, "Create group failed.", err) + // continue + } + + for i := 0; i < MaxUser/MaxInviteUserLimit; i++ { + InviteUserIDs := make([]string, 0) + // ensure TargetUserList is in group + InviteUserIDs = append(InviteUserIDs, TestTargetUserList...) + + startIdx := max(i*MaxInviteUserLimit, 1) + endIdx := min((i+1)*MaxInviteUserLimit, MaxUser) + + for j := startIdx; j < endIdx; j++ { + userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j) + InviteUserIDs = append(InviteUserIDs, userCreatedID) + } + + if len(InviteUserIDs) == 0 { + log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs) + if err != nil { + log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID) + continue + } + + if len(InviteUserIDs) == 0 { + log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + // Invite To Group + if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil { + log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs) + continue + // os.Exit(1) + // return + } + } + }(i) + } + } + }() + + // create 999 groups + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + create999GroupTicker := time.NewTicker(Create999GroupTicker) + defer create999GroupTicker.Stop() + + for i := range Max999Group { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create 999 Group") + return + + case <-create999GroupTicker.C: + // Create 999 groups + st.Wg.Add(1) + go func(idx int) { + defer st.Wg.Done() + defer func() { + st.Create999GroupCounter++ + }() + + groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", idx) + + if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil { + log.ZError(st.Ctx, "Create group failed.", err) + // continue + } + for i := 0; i < MaxUser/MaxInviteUserLimit; i++ { + InviteUserIDs := make([]string, 0) + // ensure TargetUserList is in group + InviteUserIDs = append(InviteUserIDs, TestTargetUserList...) + + startIdx := max(i*MaxInviteUserLimit, 1) + endIdx := min((i+1)*MaxInviteUserLimit, MaxUser) + + for j := startIdx; j < endIdx; j++ { + userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j) + InviteUserIDs = append(InviteUserIDs, userCreatedID) + } + + if len(InviteUserIDs) == 0 { + log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs) + if err != nil { + log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID) + continue + } + + if len(InviteUserIDs) == 0 { + log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID) + continue + } + + // Invite To Group + if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil { + log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs) + continue + // os.Exit(1) + // return + } + } + }(i) + } + } + }() + + // Send message to 100K groups + st.Wg.Wait() + fmt.Println("All groups created successfully, starting to send messages...") + log.ZInfo(ctx, "All groups created successfully, starting to send messages...") + + var groups100K []string + var groups999 []string + + for i := range Max100KGroup { + groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", i) + groups100K = append(groups100K, groupID) + } + + for i := range Max999Group { + groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", i) + groups999 = append(groups999, groupID) + } + + send100kGroupLimiter := make(chan struct{}, 20) + send999GroupLimiter := make(chan struct{}, 100) + + // execute Send message to 100K groups + go func() { + ticker := time.NewTicker(SendMsgTo100KGroupTicker) + defer ticker.Stop() + + for { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Send Message to 100K Group") + return + + case <-ticker.C: + // Send message to 100K groups + for _, groupID := range groups100K { + send100kGroupLimiter <- struct{}{} + go func(groupID string) { + defer func() { <-send100kGroupLimiter }() + if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil { + log.ZError(st.Ctx, "Send message to 100K group failed.", err) + } + }(groupID) + } + // log.ZInfo(st.Ctx, "Send message to 100K groups successfully.") + } + } + }() + + // execute Send message to 999 groups + go func() { + ticker := time.NewTicker(SendMsgTo999GroupTicker) + defer ticker.Stop() + + for { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Send Message to 999 Group") + return + + case <-ticker.C: + // Send message to 999 groups + for _, groupID := range groups999 { + send999GroupLimiter <- struct{}{} + go func(groupID string) { + defer func() { <-send999GroupLimiter }() + + if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil { + log.ZError(st.Ctx, "Send message to 999 group failed.", err) + } + }(groupID) + } + // log.ZInfo(st.Ctx, "Send message to 999 groups successfully.") + } + } + }() + + <-st.Ctx.Done() + fmt.Println("Received signal to exit, shutting down...") +} diff --git a/tools/stress-test/README.md b/tools/stress-test/README.md new file mode 100644 index 000000000..531233a20 --- /dev/null +++ b/tools/stress-test/README.md @@ -0,0 +1,25 @@ +# Stress Test + +## Usage + +You need set `TestTargetUserList` and `DefaultGroupID` variables. + +### Build + +```bash +go build -o _output/bin/tools/linux/amd64/stress-test tools/stress-test/main.go + +# or + +go build -o tools/stress-test/stress-test tools/stress-test/main.go +``` + +### Excute + +```bash +_output/bin/tools/linux/amd64/stress-test -c config/ + +#or + +tools/stress-test/stress-test -c config/ +``` diff --git a/tools/stress-test/main.go b/tools/stress-test/main.go new file mode 100755 index 000000000..f845b5e93 --- /dev/null +++ b/tools/stress-test/main.go @@ -0,0 +1,459 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/apistruct" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/protocol/auth" + "github.com/openimsdk/protocol/constant" + "github.com/openimsdk/protocol/group" + "github.com/openimsdk/protocol/relation" + "github.com/openimsdk/protocol/sdkws" + pbuser "github.com/openimsdk/protocol/user" + "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/system/program" +) + +/* + 1. Create one user every minute + 2. Import target users as friends + 3. Add users to the default group + 4. Send a message to the default group every second, containing index and current timestamp + 5. Create a new group every minute and invite target users to join +*/ + +// !!! ATTENTION: This variable is must be added! +var ( + // Use default userIDs List for testing, need to be created. + TestTargetUserList = []string{ + "", + } + DefaultGroupID = "" // Use default group ID for testing, need to be created. +) + +var ( + ApiAddress string + + // API method + GetAdminToken = "/auth/get_admin_token" + CreateUser = "/user/user_register" + ImportFriend = "/friend/import_friend" + InviteToGroup = "/group/invite_user_to_group" + SendMsg = "/msg/send_msg" + CreateGroup = "/group/create_group" + GetUserToken = "/auth/user_token" +) + +const ( + MaxUser = 10000 + MaxGroup = 1000 + + CreateUserTicker = 1 * time.Minute // Ticker is 1min in create user + SendMessageTicker = 1 * time.Second // Ticker is 1s in send message + CreateGroupTicker = 1 * time.Minute +) + +type BaseResp struct { + ErrCode int `json:"errCode"` + ErrMsg string `json:"errMsg"` + Data json.RawMessage `json:"data"` +} + +type StressTest struct { + Conf *conf + AdminUserID string + AdminToken string + DefaultGroupID string + DefaultUserID string + UserCounter int + GroupCounter int + MsgCounter int + CreatedUsers []string + CreatedGroups []string + Mutex sync.Mutex + Ctx context.Context + Cancel context.CancelFunc + HttpClient *http.Client + Wg sync.WaitGroup + Once sync.Once +} + +type conf struct { + Share config.Share + Api config.API +} + +func initConfig(configDir string) (*config.Share, *config.API, error) { + var ( + share = &config.Share{} + apiConfig = &config.API{} + ) + + err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share) + if err != nil { + return nil, nil, err + } + + err = config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig) + if err != nil { + return nil, nil, err + } + + return share, apiConfig, nil +} + +// Post Request +func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) { + // Marshal body + jsonBody, err := json.Marshal(reqbody) + if err != nil { + log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody) + return nil, err + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("operationID", st.AdminUserID) + if st.AdminToken != "" { + req.Header.Set("token", st.AdminToken) + } + + // log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken) + + resp, err := st.HttpClient.Do(req) + if err != nil { + log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody) + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + log.ZError(ctx, "Failed to read response body", err, "url", url) + return nil, err + } + + var baseResp BaseResp + if err := json.Unmarshal(respBody, &baseResp); err != nil { + log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody)) + return nil, err + } + + if baseResp.ErrCode != 0 { + err = fmt.Errorf(baseResp.ErrMsg) + log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp) + return nil, err + } + + return baseResp.Data, nil +} + +func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) { + req := auth.GetAdminTokenReq{ + Secret: st.Conf.Share.Secret, + UserID: st.AdminUserID, + } + + resp, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &req) + if err != nil { + return "", err + } + + data := &auth.GetAdminTokenResp{} + if err := json.Unmarshal(resp, &data); err != nil { + return "", err + } + + return data.Token, nil +} + +func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) { + user := &sdkws.UserInfo{ + UserID: userID, + Nickname: userID, + } + + req := pbuser.UserRegisterReq{ + Users: []*sdkws.UserInfo{user}, + } + + _, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req) + if err != nil { + return "", err + } + + st.UserCounter++ + return userID, nil +} + +func (st *StressTest) ImportFriend(ctx context.Context, userID string) error { + req := relation.ImportFriendReq{ + OwnerUserID: userID, + FriendUserIDs: TestTargetUserList, + } + + _, err := st.PostRequest(ctx, ApiAddress+ImportFriend, &req) + if err != nil { + return err + } + + return nil +} + +func (st *StressTest) InviteToGroup(ctx context.Context, userID string) error { + req := group.InviteUserToGroupReq{ + GroupID: st.DefaultGroupID, + InvitedUserIDs: []string{userID}, + } + _, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &req) + if err != nil { + return err + } + + return nil +} + +func (st *StressTest) SendMsg(ctx context.Context, userID string) error { + contentObj := map[string]any{ + "content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")), + } + + req := &apistruct.SendMsgReq{ + SendMsg: apistruct.SendMsg{ + SendID: userID, + SenderNickname: userID, + GroupID: st.DefaultGroupID, + ContentType: constant.Text, + SessionType: constant.ReadGroupChatType, + Content: contentObj, + }, + } + + _, err := st.PostRequest(ctx, ApiAddress+SendMsg, &req) + if err != nil { + log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &req) + return err + } + + st.MsgCounter++ + + return nil +} + +func (st *StressTest) CreateGroup(ctx context.Context, userID string) (string, error) { + groupID := fmt.Sprintf("StressTestGroup_%d_%s", st.GroupCounter, time.Now().Format("20060102150405")) + + groupInfo := &sdkws.GroupInfo{ + GroupID: groupID, + GroupName: groupID, + GroupType: constant.WorkingGroup, + } + + req := group.CreateGroupReq{ + OwnerUserID: userID, + MemberUserIDs: TestTargetUserList, + GroupInfo: groupInfo, + } + + resp := group.CreateGroupResp{} + + response, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &req) + if err != nil { + return "", err + } + + if err := json.Unmarshal(response, &resp); err != nil { + return "", err + } + + st.GroupCounter++ + + return resp.GroupInfo.GroupID, nil +} + +func main() { + var configPath string + // defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config") + // flag.StringVar(&configPath, "c", defaultConfigDir, "config path") + flag.StringVar(&configPath, "c", "", "config path") + flag.Parse() + + if configPath == "" { + _, _ = fmt.Fprintln(os.Stderr, "config path is empty") + os.Exit(1) + return + } + + fmt.Printf(" Config Path: %s\n", configPath) + + share, apiConfig, err := initConfig(configPath) + if err != nil { + program.ExitWithError(err) + return + } + + ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0])) + + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan struct{}) + + defer cancel() + + st := &StressTest{ + Conf: &conf{ + Share: *share, + Api: *apiConfig, + }, + AdminUserID: share.IMAdminUserID[0], + Ctx: ctx, + Cancel: cancel, + HttpClient: &http.Client{ + Timeout: 50 * time.Second, + }, + } + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + fmt.Println("\nReceived stop signal, stopping...") + + select { + case <-ch: + default: + close(ch) + } + + st.Cancel() + }() + + token, err := st.GetAdminToken(st.Ctx) + if err != nil { + log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID) + } + + st.AdminToken = token + fmt.Println("Admin Token:", st.AdminToken) + fmt.Println("ApiAddress:", ApiAddress) + + st.DefaultGroupID = DefaultGroupID + + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + ticker := time.NewTicker(CreateUserTicker) + defer ticker.Stop() + + for st.UserCounter < MaxUser { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create user", "reason", "context done") + return + + case <-ticker.C: + // Create User + userID := fmt.Sprintf("%d_Stresstest_%s", st.UserCounter, time.Now().Format("0102150405")) + + userCreatedID, err := st.CreateUser(st.Ctx, userID) + if err != nil { + log.ZError(st.Ctx, "Create User failed.", err, "UserID", userID) + os.Exit(1) + return + } + // fmt.Println("User Created ID:", userCreatedID) + + // Import Friend + if err = st.ImportFriend(st.Ctx, userCreatedID); err != nil { + log.ZError(st.Ctx, "Import Friend failed.", err, "UserID", userCreatedID) + os.Exit(1) + return + } + + // Invite To Group + if err = st.InviteToGroup(st.Ctx, userCreatedID); err != nil { + log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", userCreatedID) + os.Exit(1) + return + } + + st.Once.Do(func() { + st.DefaultUserID = userCreatedID + fmt.Println("Default Send User Created ID:", userCreatedID) + close(ch) + }) + } + } + }() + + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + ticker := time.NewTicker(SendMessageTicker) + defer ticker.Stop() + <-ch + + for { + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Send message", "reason", "context done") + return + + case <-ticker.C: + // Send Message + if err = st.SendMsg(st.Ctx, st.DefaultSendUserID); err != nil { + log.ZError(st.Ctx, "Send Message failed.", err, "UserID", st.DefaultSendUserID) + continue + } + } + } + }() + + st.Wg.Add(1) + go func() { + defer st.Wg.Done() + + ticker := time.NewTicker(CreateGroupTicker) + defer ticker.Stop() + <-ch + + for st.GroupCounter < MaxGroup { + + select { + case <-st.Ctx.Done(): + log.ZInfo(st.Ctx, "Stop Create Group", "reason", "context done") + return + + case <-ticker.C: + + // Create Group + _, err := st.CreateGroup(st.Ctx, st.DefaultUserID) + if err != nil { + log.ZError(st.Ctx, "Create Group failed.", err, "UserID", st.DefaultUserID) + os.Exit(1) + return + } + + // fmt.Println("Group Created ID:", groupID) + } + } + }() + + st.Wg.Wait() +} diff --git a/version/version.go b/version/version.go index 23b3a82f5..32ad27808 100644 --- a/version/version.go +++ b/version/version.go @@ -1,6 +1,14 @@ package version -import _ "embed" +import ( + _ "embed" + "strings" +) //go:embed version var Version string + +func init() { + Version = strings.Trim(Version, "\n") + Version = strings.TrimSpace(Version) +}