Merge branch 'openimsdk:pre-release-v3.8.4' into pre-release-v3.8.4

pull/3355/head
chao 4 months ago committed by GitHub
commit 2769a2fe81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -12,6 +12,10 @@ jobs:
go-build:
name: Test with go ${{ matrix.go_version }} on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
env:
SHARE_CONFIG_PATH: config/share.yml
permissions:
contents: write
pull-requests: write
@ -40,6 +44,10 @@ jobs:
with:
compose-file: "./docker-compose.yml"
- name: Modify Server Configuration
run: |
yq e '.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }}
# - name: Get Internal IP Address
# id: get-ip
# run: |
@ -71,6 +79,11 @@ jobs:
go mod download
go install github.com/magefile/mage@latest
- name: Modify Chat Configuration
run: |
cd ${{ github.workspace }}/chat-repo
yq e '.openIM.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }}
- name: Build and test Chat Services
run: |
cd ${{ github.workspace }}/chat-repo
@ -132,7 +145,7 @@ jobs:
# Test get admin token
get_admin_token_response=$(curl -X POST -H "Content-Type: application/json" -H "operationID: imAdmin" -d '{
"secret": "openIM123",
"secret": "123456",
"platformID": 2,
"userID": "imAdmin"
}' http://127.0.0.1:10002/auth/get_admin_token)
@ -169,7 +182,8 @@ jobs:
contents: write
env:
SDK_DIR: openim-sdk-core
CONFIG_PATH: config/notification.yml
NOTIFICATION_CONFIG_PATH: config/notification.yml
SHARE_CONFIG_PATH: config/share.yml
strategy:
matrix:
@ -184,7 +198,7 @@ jobs:
uses: actions/checkout@v4
with:
repository: "openimsdk/openim-sdk-core"
ref: "release-v3.8"
ref: "main"
path: ${{ env.SDK_DIR }}
- name: Set up Go ${{ matrix.go_version }}
@ -199,8 +213,9 @@ jobs:
- name: Modify Server Configuration
run: |
yq e '.groupCreated.isSendMsg = true' -i ${{ env.CONFIG_PATH }}
yq e '.friendApplicationApproved.isSendMsg = true' -i ${{ env.CONFIG_PATH }}
yq e '.groupCreated.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }}
yq e '.friendApplicationApproved.isSendMsg = true' -i ${{ env.NOTIFICATION_CONFIG_PATH }}
yq e '.secret = 123456' -i ${{ env.SHARE_CONFIG_PATH }}
- name: Start Server Services
run: |

@ -28,6 +28,8 @@ run:
# - util
# - .*~
# - api/swagger/docs
# - server/docs
# - components/mnt/config/certs
# - logs

@ -131,7 +131,15 @@ Thank you for contributing to building a powerful instant messaging solution!
## :closed_book: License
OpenIMSDK is available under the Apache License 2.0. See the [LICENSE file](https://github.com/openimsdk/open-im-server/blob/main/LICENSE) for more information.
This software is licensed under a dual-license model:
- The GNU Affero General Public License (AGPL), Version 3 or later; **OR**
- Commercial license terms from OpenIMSDK.
If you wish to use this software under commercial terms, please contact us at: contact@openim.io
For more information, see: https://www.openim.io/en/licensing

@ -131,9 +131,17 @@
感谢您的贡献,一起来打造强大的即时通讯解决方案!
## :closed_book: 许可证
## :closed_book: 开源许可证 License
本软件采用双重授权模型:
GNU Affero 通用公共许可证(AGPL)第 3 版或更高版本;或
来自 OpenIMSDK 的商业授权条款。
如需商用,请联系:contact@openim.io
详见:https://www.openim.io/en/licensing
OpenIMSDK 在 Apache License 2.0 许可下可用。查看[LICENSE 文件](https://github.com/openimsdk/open-im-server/blob/main/LICENSE)了解更多信息。
## 🔮 Thanks to our contributors!

@ -7,3 +7,7 @@ multiLogin:
policy: 1
# max num of tokens in one end
maxNumOneEnd: 30
rpcMaxBodySize:
requestMaxBodySize: 8388608
responseMaxBodySize: 8388608

@ -12,8 +12,8 @@ require (
github.com/gorilla/websocket v1.5.1
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/mitchellh/mapstructure v1.5.0
github.com/openimsdk/protocol v0.0.72-alpha.79
github.com/openimsdk/tools v0.0.50-alpha.74
github.com/openimsdk/protocol v0.0.73-alpha.6
github.com/openimsdk/tools v0.0.50-alpha.81
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.18.0
github.com/stretchr/testify v1.9.0
@ -219,3 +219,5 @@ require (
golang.org/x/crypto v0.27.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)
//replace github.com/openimsdk/protocol => /Users/chao/Desktop/code/protocol

@ -345,12 +345,12 @@ github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA
github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To=
github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y=
github.com/onsi/gomega v1.25.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM=
github.com/openimsdk/gomake v0.0.14-alpha.5 h1:VY9c5x515lTfmdhhPjMvR3BBRrRquAUCFsz7t7vbv7Y=
github.com/openimsdk/gomake v0.0.14-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI=
github.com/openimsdk/protocol v0.0.72-alpha.79 h1:e46no8WVAsmTzyy405klrdoUiG7u+1ohDsXvQuFng4s=
github.com/openimsdk/protocol v0.0.72-alpha.79/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw=
github.com/openimsdk/tools v0.0.50-alpha.74 h1:yh10SiMiivMEjicEQg+QAsH4pvaO+4noMPdlw+ew0Kc=
github.com/openimsdk/tools v0.0.50-alpha.74/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo=
github.com/openimsdk/gomake v0.0.15-alpha.5 h1:eEZCEHm+NsmcO3onXZPIUbGFCYPYbsX5beV3ZyOsGhY=
github.com/openimsdk/gomake v0.0.15-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI=
github.com/openimsdk/protocol v0.0.73-alpha.6 h1:sna9coWG7HN1zObBPtvG0Ki/vzqHXiB4qKbA5P3w7kc=
github.com/openimsdk/protocol v0.0.73-alpha.6/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw=
github.com/openimsdk/tools v0.0.50-alpha.81 h1:VbuJKtigNXLkCKB/Q6f2UHsqoSaTOAwS8F51c1nhOCA=
github.com/openimsdk/tools v0.0.50-alpha.81/go.mod h1:n2poR3asX1e1XZce4O+MOWAp+X02QJRFvhcLCXZdzRo=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=

@ -45,7 +45,7 @@ func NewConfigManager(IMAdminUserID []string, cfg *config.AllConfig, client *cli
}
func (cm *ConfigManager) CheckAdmin(c *gin.Context) {
if err := authverify.CheckAdmin(c, cm.imAdminUserID); err != nil {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
c.Abort()
}

@ -144,24 +144,23 @@ func Start(ctx context.Context, index int, config *Config) error {
}
}()
if config.Discovery.Enable == conf.ETCD {
cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames())
cm.Watch(ctx)
}
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM)
shutdown := func() error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
err := server.Shutdown(ctx)
if err != nil {
return errs.WrapMsg(err, "shutdown err")
}
return nil
}
disetcd.RegisterShutDown(shutdown)
//if config.Discovery.Enable == conf.ETCD {
// cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames())
// cm.Watch(ctx)
//}
//sigs := make(chan os.Signal, 1)
//signal.Notify(sigs, syscall.SIGTERM)
//select {
//case val := <-sigs:
// log.ZDebug(ctx, "recv exit", "signal", val.String())
// cancel(fmt.Errorf("signal %s", val.String()))
//case <-ctx.Done():
//}
<-apiCtx.Done()
exitCause := context.Cause(apiCtx)
log.ZWarn(ctx, "api server exit", exitCause)
timer := time.NewTimer(time.Second * 15)
defer timer.Stop()
select {
case <-sigs:
program.SIGTERMExit()

@ -281,7 +281,7 @@ func (m *MessageApi) SendMessage(c *gin.Context) {
}
// Check if the user has the app manager role.
if !authverify.IsAppManagerUid(c, m.imAdminUserID) {
if !authverify.IsAdmin(c) {
// Respond with a permission error if the user is not an app manager.
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
@ -355,7 +355,7 @@ func (m *MessageApi) SendBusinessNotification(c *gin.Context) {
if req.ReliabilityLevel == nil {
req.ReliabilityLevel = datautil.ToPtr(1)
}
if !authverify.IsAppManagerUid(c, m.imAdminUserID) {
if !authverify.IsAdmin(c) {
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}
@ -399,7 +399,7 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if err := authverify.CheckAdmin(c, m.imAdminUserID); err != nil {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}

@ -9,6 +9,11 @@ import (
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/openimsdk/open-im-server/v3/internal/api/jssdk"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
@ -96,7 +101,7 @@ func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cf
r.Use(gzip.Gzip(gzip.BestSpeed))
}
r.Use(prommetricsGin(), gin.RecoveryWithWriter(gin.DefaultErrorWriter, mw.GinPanicErr), mw.CorsHandler(),
mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn)))
mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn)), setGinIsAdmin(cfg.Share.IMAdminUserID))
u := NewUserApi(user.NewUserClient(userConn), client, cfg.Discovery.RpcService)
{
@ -124,6 +129,11 @@ func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cf
userRouterGroup.POST("/add_notification_account", u.AddNotificationAccount)
userRouterGroup.POST("/update_notification_account", u.UpdateNotificationAccountInfo)
userRouterGroup.POST("/search_notification_account", u.SearchNotificationAccount)
userRouterGroup.POST("/get_user_client_config", u.GetUserClientConfig)
userRouterGroup.POST("/set_user_client_config", u.SetUserClientConfig)
userRouterGroup.POST("/del_user_client_config", u.DelUserClientConfig)
userRouterGroup.POST("/page_user_client_config", u.PageUserClientConfig)
}
// friend routing group
{
@ -347,6 +357,14 @@ func GinParseToken(authClient *rpcli.AuthClient) gin.HandlerFunc {
}
}
// setGinIsAdmin returns a gin middleware that tags every request context
// with a boolean indicating whether the operating user (read from the
// request metadata via mcontext) is one of the configured IM admin user
// IDs. Downstream handlers consult this flag through the authverify
// package using authverify.CtxIsAdminKey.
func setGinIsAdmin(imAdminUserID []string) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Membership test of the op user ID against the admin list; the
		// result is stored on the context for later permission checks.
		c.Set(authverify.CtxIsAdminKey, datautil.Contain(mcontext.GetOpUserID(c), imAdminUserID...))
	}
}
// Whitelist api not parse token
var Whitelist = []string{
"/auth/get_admin_token",

@ -242,3 +242,19 @@ func (u *UserApi) UpdateNotificationAccountInfo(c *gin.Context) {
func (u *UserApi) SearchNotificationAccount(c *gin.Context) {
a2r.Call(c, user.UserClient.SearchNotificationAccount, u.Client)
}
// GetUserClientConfig forwards the HTTP request to the user RPC service's
// GetUserClientConfig method through the a2r (api-to-rpc) call adapter.
func (u *UserApi) GetUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.GetUserClientConfig, u.Client)
}
// SetUserClientConfig forwards the HTTP request to the user RPC service's
// SetUserClientConfig method through the a2r (api-to-rpc) call adapter.
func (u *UserApi) SetUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.SetUserClientConfig, u.Client)
}
// DelUserClientConfig forwards the HTTP request to the user RPC service's
// DelUserClientConfig method through the a2r (api-to-rpc) call adapter.
func (u *UserApi) DelUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.DelUserClientConfig, u.Client)
}
// PageUserClientConfig forwards the HTTP request to the user RPC service's
// PageUserClientConfig method through the a2r (api-to-rpc) call adapter.
func (u *UserApi) PageUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.PageUserClientConfig, u.Client)
}

@ -101,7 +101,7 @@ func NewServer(longConnServer LongConnServer, conf *Config, ready func(srv *Serv
}
func (s *Server) GetUsersOnlineStatus(ctx context.Context, req *msggateway.GetUsersOnlineStatusReq) (*msggateway.GetUsersOnlineStatusResp, error) {
if !authverify.IsAppManagerUid(ctx, s.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
return nil, errs.ErrNoPermission.WrapMsg("only app manager")
}
var resp msggateway.GetUsersOnlineStatusResp

@ -18,11 +18,14 @@ import (
"context"
"errors"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/mcache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo"
"github.com/openimsdk/open-im-server/v3/pkg/dbbuild"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis"
"github.com/openimsdk/tools/db/redisutil"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
@ -43,7 +46,7 @@ import (
type authServer struct {
pbauth.UnimplementedAuthServer
authDatabase controller.AuthDatabase
RegisterCenter discovery.SvcDiscoveryRegistry
RegisterCenter discovery.Conn
config *Config
userClient *rpcli.UserClient
}
@ -51,15 +54,31 @@ type authServer struct {
type Config struct {
RpcConfig config.Auth
RedisConfig config.Redis
MongoConfig config.Mongo
Share config.Share
Discovery config.Discovery
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error {
rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build())
func Start(ctx context.Context, config *Config, client discovery.Conn, server grpc.ServiceRegistrar) error {
dbb := dbbuild.NewBuilder(&config.MongoConfig, &config.RedisConfig)
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
var token cache.TokenModel
if rdb == nil {
mdb, err := dbb.Mongo(ctx)
if err != nil {
return err
}
mc, err := mgo.NewCacheMgo(mdb.GetDB())
if err != nil {
return err
}
token = mcache.NewTokenCacheModel(mc, config.RpcConfig.TokenPolicy.Expire)
} else {
token = redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire)
}
userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
@ -67,7 +86,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg
pbauth.RegisterAuthServer(server, &authServer{
RegisterCenter: client,
authDatabase: controller.NewAuthDatabase(
redis2.NewTokenCacheModel(rdb, config.RpcConfig.TokenPolicy.Expire),
token,
config.Share.Secret,
config.RpcConfig.TokenPolicy.Expire,
config.Share.MultiLogin,
@ -106,7 +125,7 @@ func (s *authServer) GetAdminToken(ctx context.Context, req *pbauth.GetAdminToke
}
func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) {
if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
@ -116,7 +135,7 @@ func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenR
resp := pbauth.GetUserTokenResp{}
if authverify.IsManagerUserID(req.UserID, s.config.Share.IMAdminUserID) {
if authverify.CheckUserIsAdmin(ctx, req.UserID) {
return nil, errs.ErrNoPermission.WrapMsg("don't get Admin token")
}
user, err := s.userClient.GetUserInfo(ctx, req.UserID)
@ -140,15 +159,17 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
if err != nil {
return nil, err
}
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
if isAdmin {
return claims, nil
}
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
if err != nil {
return nil, err
}
if len(m) == 0 {
isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
return nil, servererrs.ErrTokenNotExist.Wrap()
}
if v, ok := m[tokensString]; ok {
@ -160,6 +181,13 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
default:
return nil, errs.Wrap(errs.ErrTokenUnknown)
}
} else {
isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
}
return nil, servererrs.ErrTokenNotExist.Wrap()
}
@ -177,7 +205,7 @@ func (s *authServer) ParseToken(ctx context.Context, req *pbauth.ParseTokenReq)
}
func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) {
if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
if err := s.forceKickOff(ctx, req.UserID, req.PlatformID); err != nil {

@ -4,12 +4,16 @@ import (
"context"
"github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/open-im-server/v3/pkg/util/hashutil"
"github.com/openimsdk/protocol/conversation"
)
func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, req *conversation.GetFullOwnerConversationIDsReq) (*conversation.GetFullOwnerConversationIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
vl, err := c.conversationDatabase.FindMaxConversationUserVersionCache(ctx, req.UserID)
if err != nil {
return nil, err

@ -156,7 +156,7 @@ func (g *groupServer) NotificationUserInfoUpdate(ctx context.Context, req *pbgro
}
func (g *groupServer) CheckGroupAdmin(ctx context.Context, groupID string) error {
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
groupMember, err := g.db.TakeGroupMember(ctx, groupID, mcontext.GetOpUserID(ctx))
if err != nil {
return err
@ -208,7 +208,7 @@ func (g *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR
if req.OwnerUserID == "" {
return nil, errs.ErrArgs.WrapMsg("no group owner")
}
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, g.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
userIDs := append(append(req.MemberUserIDs, req.AdminUserIDs...), req.OwnerUserID)
@ -311,7 +311,7 @@ func (g *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR
}
func (g *groupServer) GetJoinedGroupList(ctx context.Context, req *pbgroup.GetJoinedGroupListReq) (*pbgroup.GetJoinedGroupListResp, error) {
if err := authverify.CheckAccessV3(ctx, req.FromUserID, g.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil {
return nil, err
}
total, members, err := g.db.PageGetJoinGroup(ctx, req.FromUserID, req.Pagination)
@ -383,7 +383,7 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite
var groupMember *model.GroupMember
var opUserID string
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
opUserID = mcontext.GetOpUserID(ctx)
var err error
groupMember, err = g.db.TakeGroupMember(ctx, req.GroupID, opUserID)
@ -402,7 +402,7 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite
}
if group.NeedVerification == constant.AllNeedVerification {
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
if !(groupMember.RoleLevel == constant.GroupOwner || groupMember.RoleLevel == constant.GroupAdmin) {
var requests []*model.GroupRequest
for _, userID := range req.InvitedUserIDs {
@ -451,13 +451,26 @@ func (g *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite
return nil, err
}
if err := g.db.CreateGroup(ctx, nil, groupMembers); err != nil {
const singleQuantity = 50
for start := 0; start < len(groupMembers); start += singleQuantity {
end := start + singleQuantity
if end > len(groupMembers) {
end = len(groupMembers)
}
currentMembers := groupMembers[start:end]
if err := g.db.CreateGroup(ctx, nil, currentMembers); err != nil {
return nil, err
}
if err = g.notification.GroupApplicationAgreeMemberEnterNotification(ctx, req.GroupID, opUserID, req.InvitedUserIDs...); err != nil {
userIDs := datautil.Slice(currentMembers, func(e *model.GroupMember) string {
return e.UserID
})
if err = g.notification.GroupApplicationAgreeMemberEnterNotification(ctx, req.GroupID, req.SendMessage, opUserID, userIDs...); err != nil {
return nil, err
}
}
return &pbgroup.InviteUserToGroupResp{}, nil
}
@ -477,6 +490,11 @@ func (g *groupServer) GetGroupAllMember(ctx context.Context, req *pbgroup.GetGro
}
func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGroupMemberListReq) (*pbgroup.GetGroupMemberListResp, error) {
if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) {
if _, err := g.db.TakeGroupMember(ctx, req.GroupID, opUserID); err != nil {
return nil, err
}
}
var (
total int64
members []*model.GroupMember
@ -485,7 +503,7 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr
if req.Keyword == "" {
total, members, err = g.db.PageGetGroupMember(ctx, req.GroupID, req.Pagination)
} else {
members, err = g.db.FindGroupMemberAll(ctx, req.GroupID)
total, members, err = g.db.SearchGroupMember(ctx, req.GroupID, req.Keyword, req.Pagination)
}
if err != nil {
return nil, err
@ -493,27 +511,6 @@ func (g *groupServer) GetGroupMemberList(ctx context.Context, req *pbgroup.GetGr
if err := g.PopulateGroupMember(ctx, members...); err != nil {
return nil, err
}
if req.Keyword != "" {
groupMembers := make([]*model.GroupMember, 0)
for _, member := range members {
if member.UserID == req.Keyword {
groupMembers = append(groupMembers, member)
total++
continue
}
if member.Nickname == req.Keyword {
groupMembers = append(groupMembers, member)
total++
continue
}
}
members := datautil.Paginate(groupMembers, int(req.Pagination.GetPageNumber()), int(req.Pagination.GetShowNumber()))
return &pbgroup.GetGroupMemberListResp{
Total: uint32(total),
Members: datautil.Batch(convert.Db2PbGroupMember, members),
}, nil
}
return &pbgroup.GetGroupMemberListResp{
Total: uint32(total),
Members: datautil.Batch(convert.Db2PbGroupMember, members),
@ -554,7 +551,7 @@ func (g *groupServer) KickGroupMember(ctx context.Context, req *pbgroup.KickGrou
for i, member := range members {
memberMap[member.UserID] = members[i]
}
isAppManagerUid := authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID)
isAppManagerUid := authverify.IsAdmin(ctx)
opMember := memberMap[opUserID]
for _, userID := range req.KickedUserIDs {
member, ok := memberMap[userID]
@ -772,7 +769,7 @@ func (g *groupServer) GroupApplicationResponse(ctx context.Context, req *pbgroup
if !datautil.Contain(req.HandleResult, constant.GroupResponseAgree, constant.GroupResponseRefuse) {
return nil, errs.ErrArgs.WrapMsg("HandleResult unknown")
}
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
groupMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx))
if err != nil {
return nil, err
@ -920,7 +917,7 @@ func (g *groupServer) QuitGroup(ctx context.Context, req *pbgroup.QuitGroupReq)
if req.UserID == "" {
req.UserID = mcontext.GetOpUserID(ctx)
} else {
if err := authverify.CheckAccessV3(ctx, req.UserID, g.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
}
@ -958,7 +955,7 @@ func (g *groupServer) deleteMemberAndSetConversationSeq(ctx context.Context, gro
func (g *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInfoReq) (*pbgroup.SetGroupInfoResp, error) {
var opMember *model.GroupMember
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
var err error
opMember, err = g.db.TakeGroupMember(ctx, req.GroupInfoForSet.GroupID, mcontext.GetOpUserID(ctx))
if err != nil {
@ -1051,7 +1048,7 @@ func (g *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInf
func (g *groupServer) SetGroupInfoEx(ctx context.Context, req *pbgroup.SetGroupInfoExReq) (*pbgroup.SetGroupInfoExResp, error) {
var opMember *model.GroupMember
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
var err error
opMember, err = g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx))
@ -1203,7 +1200,7 @@ func (g *groupServer) TransferGroupOwner(ctx context.Context, req *pbgroup.Trans
return nil, errs.ErrArgs.WrapMsg("NewOwnerUser not in group " + req.NewOwnerUserID)
}
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
if !(mcontext.GetOpUserID(ctx) == oldOwner.UserID && oldOwner.RoleLevel == constant.GroupOwner) {
return nil, errs.ErrNoPermission.WrapMsg("no permission transfer group owner")
}
@ -1346,7 +1343,7 @@ func (g *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou
if err != nil {
return nil, err
}
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
if owner.UserID != mcontext.GetOpUserID(ctx) {
return nil, errs.ErrNoPermission.WrapMsg("not group owner")
}
@ -1369,6 +1366,7 @@ func (g *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou
if err != nil {
return nil, err
}
group.Status = constant.GroupStatusDismissed
tips := &sdkws.GroupDismissedTips{
Group: g.groupDB2PB(group, owner.UserID, num),
OpUser: &sdkws.GroupMemberFullInfo{},
@ -1402,7 +1400,7 @@ func (g *groupServer) MuteGroupMember(ctx context.Context, req *pbgroup.MuteGrou
if err := g.PopulateGroupMember(ctx, member); err != nil {
return nil, err
}
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
opMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx))
if err != nil {
return nil, err
@ -1438,7 +1436,7 @@ func (g *groupServer) CancelMuteGroupMember(ctx context.Context, req *pbgroup.Ca
return nil, err
}
if !authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
opMember, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx))
if err != nil {
return nil, err
@ -1498,7 +1496,7 @@ func (g *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr
if opUserID == "" {
return nil, errs.ErrNoPermission.WrapMsg("no op user id")
}
isAppManagerUid := authverify.IsAppManagerUid(ctx, g.config.Share.IMAdminUserID)
isAppManagerUid := authverify.IsAdmin(ctx)
groupMembers := make(map[string][]*pbgroup.SetGroupMemberInfo)
for i, member := range req.Members {
if member.RoleLevel != nil {

@ -242,8 +242,8 @@ func (g *NotificationSender) fillOpUserByUserID(ctx context.Context, userID stri
return errs.ErrInternalServer.WrapMsg("**sdkws.GroupMemberFullInfo is nil")
}
if groupID != "" {
if authverify.IsManagerUserID(userID, g.config.Share.IMAdminUserID) {
*opUser = &sdkws.GroupMemberFullInfo{
if authverify.CheckUserIsAdmin(ctx, userID) {
*targetUser = &sdkws.GroupMemberFullInfo{
GroupID: groupID,
UserID: userID,
RoleLevel: constant.GroupAdmin,
@ -283,7 +283,8 @@ func (g *NotificationSender) fillOpUserByUserID(ctx context.Context, userID stri
func (g *NotificationSender) setVersion(ctx context.Context, version *uint64, versionID *string, collName string, id string) {
versions := versionctx.GetVersionLog(ctx).Get()
for _, coll := range versions {
for i := len(versions) - 1; i >= 0; i-- {
coll := versions[i]
if coll.Name == collName && coll.Doc.DID == id {
*version = uint64(coll.Doc.Version)
*versionID = coll.Doc.ID.Hex()
@ -519,7 +520,11 @@ func (g *NotificationSender) MemberKickedNotification(ctx context.Context, tips
g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.MemberKickedNotification, tips)
}
func (g *NotificationSender) GroupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, invitedOpUserID string, entrantUserID ...string) error {
// GroupApplicationAgreeMemberEnterNotification is the exported entry point
// that delegates, with all arguments passed through unchanged, to the
// unexported groupApplicationAgreeMemberEnterNotification helper.
// SendMessage is a nullable flag forwarded to the helper — TODO confirm its
// semantics against the helper's body, which is not fully visible here.
func (g *NotificationSender) GroupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, SendMessage *bool, invitedOpUserID string, entrantUserID ...string) error {
return g.groupApplicationAgreeMemberEnterNotification(ctx, groupID, SendMessage, invitedOpUserID, entrantUserID...)
}
func (g *NotificationSender) groupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, SendMessage *bool, invitedOpUserID string, entrantUserID ...string) error {
var err error
defer func() {
if err != nil {

@ -12,15 +12,23 @@ import (
pbgroup "github.com/openimsdk/protocol/group"
"github.com/openimsdk/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) {
vl, err := s.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID)
const versionSyncLimit = 500
func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) {
userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID)
if err != nil {
return nil, err
}
userIDs, err := s.db.FindGroupMemberUserID(ctx, req.GroupID)
if opUserID := mcontext.GetOpUserID(ctx); !datautil.Contain(opUserID, g.config.Share.IMAdminUserID...) {
if !datautil.Contain(opUserID, userIDs...) {
return nil, errs.ErrNoPermission.WrapMsg("user not in group")
}
}
vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID)
if err != nil {
return nil, err
}
@ -36,8 +44,11 @@ func (s *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgrou
}, nil
}
func (s *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) {
vl, err := s.db.FindMaxJoinGroupVersionCache(ctx, req.UserID)
func (g *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
vl, err := g.db.FindMaxJoinGroupVersionCache(ctx, req.UserID)
if err != nil {
return nil, err
}
@ -65,6 +76,9 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou
if group.Status == constant.GroupStatusDismissed {
return nil, servererrs.ErrDismissedAlready.Wrap()
}
if _, err := g.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)); err != nil {
return nil, err
}
var (
hasGroupUpdate bool
sortVersion uint64
@ -132,152 +146,8 @@ func (s *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgrou
return resp, nil
}
// BatchGetIncrementalGroupMember returns incremental member changes for a
// batch of groups in a single call. For every requested group it replays the
// member version log since the client-supplied version, splitting entries into
// delete/insert/update sets, and re-attaches the full group info whenever the
// group itself changed or a full resync is signalled.
//
// Dismissed groups are logged and silently dropped from the response rather
// than failing the whole batch. The accumulated payload is soft-capped at
// roughly 200 member entries; groups not included must be fetched on a later
// sync round.
func (s *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (resp *pbgroup.BatchGetIncrementalGroupMemberResp, err error) {
	// VersionInfo carries the client's last-known version state per group.
	type VersionInfo struct {
		GroupID       string
		VersionID     string
		VersionNumber uint64
	}
	var groupIDs []string
	groupsVersionMap := make(map[string]*VersionInfo)
	groupsMap := make(map[string]*model.Group)
	hasGroupUpdateMap := make(map[string]bool) // groupID -> group info itself changed
	sortVersionMap := make(map[string]uint64)  // groupID -> latest member sort version
	var targetKeys, versionIDs []string
	var versionNumbers []uint64
	var requestBodyLen int
	for _, group := range req.ReqList {
		groupsVersionMap[group.GroupID] = &VersionInfo{
			GroupID:       group.GroupID,
			VersionID:     group.VersionID,
			VersionNumber: group.Version,
		}
		groupIDs = append(groupIDs, group.GroupID)
	}
	groups, err := s.db.FindGroup(ctx, groupIDs)
	if err != nil {
		return nil, errs.Wrap(err)
	}
	for _, group := range groups {
		if group.Status == constant.GroupStatusDismissed {
			// Skip dismissed groups instead of failing the whole batch.
			err = servererrs.ErrDismissedAlready.Wrap()
			log.ZError(ctx, "This group is Dismissed Already", err, "group is", group.GroupID)
			delete(groupsVersionMap, group.GroupID)
		} else {
			groupsMap[group.GroupID] = group
		}
	}
	for groupID, vInfo := range groupsVersionMap {
		targetKeys = append(targetKeys, groupID)
		versionIDs = append(versionIDs, vInfo.VersionID)
		versionNumbers = append(versionNumbers, vInfo.VersionNumber)
	}
	opt := incrversion.BatchOption[[]*sdkws.GroupMemberFullInfo, pbgroup.BatchGetIncrementalGroupMemberResp]{
		Ctx:            ctx,
		TargetKeys:     targetKeys,
		VersionIDs:     versionIDs,
		VersionNumbers: versionNumbers,
		Versions: func(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) {
			vLogs, err := s.db.BatchFindMemberIncrVersion(ctx, groupIDs, versions, limits)
			if err != nil {
				return nil, errs.Wrap(err)
			}
			for groupID, vlog := range vLogs {
				vlogElems := make([]model.VersionLogElem, 0, len(vlog.Logs))
				// NOTE: loop variable renamed from `log` to `elem` so it no
				// longer shadows the imported log package used above.
				for i, elem := range vlog.Logs {
					switch elem.EID {
					case model.VersionGroupChangeID:
						// Group-info change marker, not a member entry.
						vlog.LogLen--
						hasGroupUpdateMap[groupID] = true
					case model.VersionSortChangeID:
						// Sort-order change marker: record the sort version.
						vlog.LogLen--
						sortVersionMap[groupID] = uint64(elem.Version)
					default:
						vlogElems = append(vlogElems, vlog.Logs[i])
					}
				}
				vlog.Logs = vlogElems
				if vlog.LogLen > 0 {
					hasGroupUpdateMap[groupID] = true
				}
			}
			return vLogs, nil
		},
		CacheMaxVersions: s.db.BatchFindMaxGroupMemberVersionCache,
		Find: func(ctx context.Context, groupID string, ids []string) ([]*sdkws.GroupMemberFullInfo, error) {
			memberInfo, err := s.getGroupMembersInfo(ctx, groupID, ids)
			if err != nil {
				return nil, err
			}
			return memberInfo, err
		},
		Resp: func(versions map[string]*model.VersionLog, deleteIdsMap map[string][]string, insertListMap, updateListMap map[string][]*sdkws.GroupMemberFullInfo, fullMap map[string]bool) *pbgroup.BatchGetIncrementalGroupMemberResp {
			resList := make(map[string]*pbgroup.GetIncrementalGroupMemberResp)
			for groupID, versionLog := range versions {
				resList[groupID] = &pbgroup.GetIncrementalGroupMemberResp{
					VersionID:   versionLog.ID.Hex(),
					Version:     uint64(versionLog.Version),
					Full:        fullMap[groupID],
					Delete:      deleteIdsMap[groupID],
					Insert:      insertListMap[groupID],
					Update:      updateListMap[groupID],
					SortVersion: sortVersionMap[groupID],
				}
				requestBodyLen += len(insertListMap[groupID]) + len(updateListMap[groupID]) + len(deleteIdsMap[groupID])
				// Soft cap on payload size: stop adding groups once the
				// accumulated entry count exceeds 200.
				if requestBodyLen > 200 {
					break
				}
			}
			return &pbgroup.BatchGetIncrementalGroupMemberResp{
				RespList: resList,
			}
		},
	}
	resp, err = opt.Build()
	if err != nil {
		return nil, errs.Wrap(err)
	}
	for groupID, val := range resp.RespList {
		if val.Full || hasGroupUpdateMap[groupID] {
			count, err := s.db.FindGroupMemberNum(ctx, groupID)
			if err != nil {
				return nil, err
			}
			owner, err := s.db.TakeGroupOwner(ctx, groupID)
			if err != nil {
				return nil, err
			}
			resp.RespList[groupID].Group = s.groupDB2PB(groupsMap[groupID], owner.UserID, count)
		}
	}
	return resp, nil
}
func (s *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{

@ -61,6 +61,13 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m
return nil, err
}
resp := &msg.GetConversationsHasReadAndMaxSeqResp{Seqs: make(map[string]*msg.Seqs)}
if req.ReturnPinned {
pinnedConversationIDs, err := m.ConversationLocalCache.GetPinnedConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
resp.PinnedConversationIDs = pinnedConversationIDs
}
for conversationID, maxSeq := range maxSeqs {
resp.Seqs[conversationID] = &msg.Seqs{
HasReadSeq: hasReadSeqs[conversationID],

@ -2,15 +2,16 @@ package msg
import (
"context"
"strings"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/protocol/msg"
"github.com/openimsdk/tools/log"
"strings"
)
// DestructMsgs hard delete in Database.
func (m *msgServer) DestructMsgs(ctx context.Context, req *msg.DestructMsgsReq) (*msg.DestructMsgsResp, error) {
if err := authverify.CheckAdmin(ctx, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
docs, err := m.MsgDatabase.GetRandBeforeMsg(ctx, req.Timestamp, int(req.Limit))

@ -42,7 +42,7 @@ func (m *msgServer) validateDeleteSyncOpt(opt *msg.DeleteSyncOpt) (isSyncSelf, i
}
func (m *msgServer) ClearConversationsMsg(ctx context.Context, req *msg.ClearConversationsMsgReq) (*msg.ClearConversationsMsgResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
if err := m.clearConversation(ctx, req.ConversationIDs, req.UserID, req.DeleteSyncOpt); err != nil {
@ -52,7 +52,7 @@ func (m *msgServer) ClearConversationsMsg(ctx context.Context, req *msg.ClearCon
}
func (m *msgServer) UserClearAllMsg(ctx context.Context, req *msg.UserClearAllMsgReq) (*msg.UserClearAllMsgResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID)
@ -66,7 +66,7 @@ func (m *msgServer) UserClearAllMsg(ctx context.Context, req *msg.UserClearAllMs
}
func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*msg.DeleteMsgsResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(req.DeleteSyncOpt)
@ -102,7 +102,7 @@ func (m *msgServer) DeleteMsgPhysicalBySeq(ctx context.Context, req *msg.DeleteM
}
func (m *msgServer) DeleteMsgPhysical(ctx context.Context, req *msg.DeleteMsgPhysicalReq) (*msg.DeleteMsgPhysicalResp, error) {
if err := authverify.CheckAdmin(ctx, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
remainTime := timeutil.GetCurrentTimestampBySecond() - req.Timestamp

@ -42,7 +42,7 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg.
if req.Seq < 0 {
return nil, errs.ErrArgs.WrapMsg("seq is invalid")
}
if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
user, err := m.UserLocalCache.GetUserInfo(ctx, req.UserID)
@ -63,11 +63,11 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg.
data, _ := json.Marshal(msgs[0])
log.ZDebug(ctx, "GetMsgBySeqs", "conversationID", req.ConversationID, "seq", req.Seq, "msg", string(data))
var role int32
if !authverify.IsAppManagerUid(ctx, m.config.Share.IMAdminUserID) {
if !authverify.IsAdmin(ctx) {
sessionType := msgs[0].SessionType
switch sessionType {
case constant.SingleChatType:
if err := authverify.CheckAccessV3(ctx, msgs[0].SendID, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, msgs[0].SendID); err != nil {
return nil, err
}
role = user.AppMangerLevel

@ -118,7 +118,7 @@ func (m *msgServer) GetSeqMessage(ctx context.Context, req *msg.GetSeqMessageReq
}
func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, m.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID)

@ -30,10 +30,9 @@ import (
)
func (s *friendServer) GetPaginationBlacks(ctx context.Context, req *relation.GetPaginationBlacksReq) (resp *relation.GetPaginationBlacksResp, err error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
total, blacks, err := s.blackDatabase.FindOwnerBlacks(ctx, req.UserID, req.Pagination)
if err != nil {
return nil, err
@ -59,7 +58,7 @@ func (s *friendServer) IsBlack(ctx context.Context, req *relation.IsBlackReq) (*
}
func (s *friendServer) RemoveBlack(ctx context.Context, req *relation.RemoveBlackReq) (*relation.RemoveBlackResp, error) {
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
@ -74,7 +73,7 @@ func (s *friendServer) RemoveBlack(ctx context.Context, req *relation.RemoveBlac
}
func (s *friendServer) AddBlack(ctx context.Context, req *relation.AddBlackReq) (*relation.AddBlackResp, error) {
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
@ -100,7 +99,7 @@ func (s *friendServer) AddBlack(ctx context.Context, req *relation.AddBlackReq)
}
func (s *friendServer) GetSpecifiedBlacks(ctx context.Context, req *relation.GetSpecifiedBlacksReq) (*relation.GetSpecifiedBlacksResp, error) {
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}

@ -135,7 +135,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg
// ok.
func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *relation.ApplyToAddFriendReq) (resp *relation.ApplyToAddFriendResp, err error) {
resp = &relation.ApplyToAddFriendResp{}
if err := authverify.CheckAccessV3(ctx, req.FromUserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil {
return nil, err
}
if req.ToUserID == req.FromUserID {
@ -165,7 +165,7 @@ func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *relation.Apply
// ok.
func (s *friendServer) ImportFriends(ctx context.Context, req *relation.ImportFriendReq) (resp *relation.ImportFriendResp, err error) {
if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
@ -201,7 +201,7 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *relation.ImportFr
// ok.
func (s *friendServer) RespondFriendApply(ctx context.Context, req *relation.RespondFriendApplyReq) (resp *relation.RespondFriendApplyResp, err error) {
resp = &relation.RespondFriendApplyResp{}
if err := authverify.CheckAccessV3(ctx, req.ToUserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.ToUserID); err != nil {
return nil, err
}
@ -236,7 +236,7 @@ func (s *friendServer) RespondFriendApply(ctx context.Context, req *relation.Res
// ok.
func (s *friendServer) DeleteFriend(ctx context.Context, req *relation.DeleteFriendReq) (resp *relation.DeleteFriendResp, err error) {
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
@ -261,7 +261,7 @@ func (s *friendServer) SetFriendRemark(ctx context.Context, req *relation.SetFri
return nil, err
}
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
@ -331,7 +331,7 @@ func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context,
// Get received friend requests (i.e., those initiated by others).
func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *relation.GetPaginationFriendsApplyToReq) (resp *relation.GetPaginationFriendsApplyToResp, err error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
@ -354,7 +354,7 @@ func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *rel
func (s *friendServer) GetPaginationFriendsApplyFrom(ctx context.Context, req *relation.GetPaginationFriendsApplyFromReq) (resp *relation.GetPaginationFriendsApplyFromResp, err error) {
resp = &relation.GetPaginationFriendsApplyFromResp{}
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
@ -384,7 +384,7 @@ func (s *friendServer) IsFriend(ctx context.Context, req *relation.IsFriendReq)
}
func (s *friendServer) GetPaginationFriends(ctx context.Context, req *relation.GetPaginationFriendsReq) (resp *relation.GetPaginationFriendsResp, err error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
@ -405,7 +405,7 @@ func (s *friendServer) GetPaginationFriends(ctx context.Context, req *relation.G
}
func (s *friendServer) GetFriendIDs(ctx context.Context, req *relation.GetFriendIDsReq) (resp *relation.GetFriendIDsResp, err error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}

@ -2,10 +2,11 @@ package relation
import (
"context"
"slices"
"github.com/openimsdk/open-im-server/v3/pkg/util/hashutil"
"github.com/openimsdk/protocol/sdkws"
"github.com/openimsdk/tools/log"
"slices"
"github.com/openimsdk/open-im-server/v3/internal/rpc/incrversion"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
@ -39,6 +40,9 @@ func (s *friendServer) NotificationUserInfoUpdate(ctx context.Context, req *rela
}
func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.GetFullFriendUserIDsReq) (*relation.GetFullFriendUserIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
vl, err := s.db.FindMaxFriendVersionCache(ctx, req.UserID)
if err != nil {
return nil, err
@ -60,7 +64,7 @@ func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.G
}
func (s *friendServer) GetIncrementalFriends(ctx context.Context, req *relation.GetIncrementalFriendsReq) (*relation.GetIncrementalFriendsResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
var sortVersion uint64

@ -82,7 +82,7 @@ func (t *thirdServer) UploadLogs(ctx context.Context, req *third.UploadLogsReq)
}
func (t *thirdServer) DeleteLogs(ctx context.Context, req *third.DeleteLogsReq) (*third.DeleteLogsResp, error) {
if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
userID := ""
@ -123,7 +123,7 @@ func dbToPbLogInfos(logs []*relationtb.Log) []*third.LogInfo {
}
func (t *thirdServer) SearchLogs(ctx context.Context, req *third.SearchLogsReq) (*third.SearchLogsResp, error) {
if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
var (

@ -62,7 +62,7 @@ func (t *thirdServer) InitiateMultipartUpload(ctx context.Context, req *third.In
return nil, err
}
expireTime := time.Now().Add(t.defaultExpire)
result, err := t.s3dataBase.InitiateMultipartUpload(ctx, req.Hash, req.Size, t.defaultExpire, int(req.MaxParts))
result, err := t.s3dataBase.InitiateMultipartUpload(ctx, req.Hash, req.Size, t.defaultExpire, int(req.MaxParts), req.ContentType)
if err != nil {
if haErr, ok := errs.Unwrap(err).(*cont.HashAlreadyExistsError); ok {
obj := &model.Object{
@ -198,7 +198,7 @@ func (t *thirdServer) InitiateFormData(ctx context.Context, req *third.InitiateF
var duration time.Duration
opUserID := mcontext.GetOpUserID(ctx)
var key string
if t.IsManagerUserID(opUserID) {
if authverify.CheckUserIsAdmin(ctx, opUserID) {
if req.Millisecond <= 0 {
duration = time.Minute * 10
} else {
@ -289,7 +289,7 @@ func (t *thirdServer) apiAddress(prefix, name string) string {
}
func (t *thirdServer) DeleteOutdatedData(ctx context.Context, req *third.DeleteOutdatedDataReq) (*third.DeleteOutdatedDataResp, error) {
if err := authverify.CheckAdmin(ctx, t.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
engine := t.config.RpcConfig.Object.Enable

@ -54,7 +54,7 @@ func (t *thirdServer) checkUploadName(ctx context.Context, name string) error {
if opUserID == "" {
return errs.ErrNoPermission.WrapMsg("opUserID is empty")
}
if !authverify.IsManagerUserID(opUserID, t.config.Share.IMAdminUserID) {
if !authverify.CheckUserIsAdmin(ctx, opUserID) {
if !strings.HasPrefix(name, opUserID+"/") {
return errs.ErrNoPermission.WrapMsg(fmt.Sprintf("name must start with `%s/`", opUserID))
}
@ -79,10 +79,6 @@ func checkValidObjectName(objectName string) error {
return checkValidObjectNamePrefix(objectName)
}
func (t *thirdServer) IsManagerUserID(opUserID string) bool {
return authverify.IsManagerUserID(opUserID, t.config.Share.IMAdminUserID)
}
func putUpdate[T any](update map[string]any, name string, val interface{ GetValuePtr() *T }) {
ptrVal := val.GetValuePtr()
if ptrVal == nil {

@ -0,0 +1,71 @@
package user
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
pbuser "github.com/openimsdk/protocol/user"
"github.com/openimsdk/tools/utils/datautil"
)
// GetUserClientConfig returns the client configuration for a user. When
// req.UserID is non-empty, the caller must pass the access check for that user
// and the user must exist; an empty UserID fetches the default configuration.
func (s *userServer) GetUserClientConfig(ctx context.Context, req *pbuser.GetUserClientConfigReq) (*pbuser.GetUserClientConfigResp, error) {
	if userID := req.UserID; userID != "" {
		if err := authverify.CheckAccess(ctx, userID); err != nil {
			return nil, err
		}
		if _, err := s.db.GetUserByID(ctx, userID); err != nil {
			return nil, err
		}
	}
	configs, err := s.clientConfig.GetUserConfig(ctx, req.UserID)
	if err != nil {
		return nil, err
	}
	return &pbuser.GetUserClientConfigResp{Configs: configs}, nil
}
// SetUserClientConfig stores client configuration entries. Admin-only. When
// req.UserID is non-empty the target user must exist; an empty UserID writes
// the default configuration.
func (s *userServer) SetUserClientConfig(ctx context.Context, req *pbuser.SetUserClientConfigReq) (*pbuser.SetUserClientConfigResp, error) {
	if err := authverify.CheckAdmin(ctx); err != nil {
		return nil, err
	}
	if userID := req.UserID; userID != "" {
		if _, err := s.db.GetUserByID(ctx, userID); err != nil {
			return nil, err
		}
	}
	if err := s.clientConfig.SetUserConfig(ctx, req.UserID, req.Configs); err != nil {
		return nil, err
	}
	return &pbuser.SetUserClientConfigResp{}, nil
}
// DelUserClientConfig deletes the given configuration keys for a user.
// Admin-only. Unlike SetUserClientConfig, the target user's existence is not
// verified before deletion.
func (s *userServer) DelUserClientConfig(ctx context.Context, req *pbuser.DelUserClientConfigReq) (*pbuser.DelUserClientConfigResp, error) {
	if err := authverify.CheckAdmin(ctx); err != nil {
		return nil, err
	}
	if err := s.clientConfig.DelUserConfig(ctx, req.UserID, req.Keys); err != nil {
		return nil, err
	}
	return &pbuser.DelUserClientConfigResp{}, nil
}
// PageUserClientConfig lists client configuration entries page by page,
// optionally filtered by user ID and key. Admin-only.
func (s *userServer) PageUserClientConfig(ctx context.Context, req *pbuser.PageUserClientConfigReq) (*pbuser.PageUserClientConfigResp, error) {
	if err := authverify.CheckAdmin(ctx); err != nil {
		return nil, err
	}
	total, records, err := s.clientConfig.GetUserConfigPage(ctx, req.UserID, req.Key, req.Pagination)
	if err != nil {
		return nil, err
	}
	// Convert storage models to the protobuf representation.
	configs := datautil.Slice(records, func(c *model.ClientConfig) *pbuser.ClientConfig {
		return &pbuser.ClientConfig{
			UserID: c.UserID,
			Key:    c.Key,
			Value:  c.Value,
		}
	})
	return &pbuser.PageUserClientConfigResp{
		Total:   total,
		Configs: configs,
	}, nil
}

@ -23,45 +23,48 @@ import (
"time"
"github.com/openimsdk/open-im-server/v3/internal/rpc/relation"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/convert"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo"
tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/open-im-server/v3/pkg/common/webhook"
"github.com/openimsdk/open-im-server/v3/pkg/dbbuild"
"github.com/openimsdk/open-im-server/v3/pkg/localcache"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/group"
friendpb "github.com/openimsdk/protocol/relation"
"github.com/openimsdk/tools/db/redisutil"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/convert"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/sdkws"
pbuser "github.com/openimsdk/protocol/user"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
registry "github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/grpc"
)
const (
defaultSecret = "openIM123"
)
type userServer struct {
pbuser.UnimplementedUserServer
online cache.OnlineCache
db controller.UserDatabase
friendNotificationSender *relation.FriendNotificationSender
userNotificationSender *UserNotificationSender
RegisterCenter registry.SvcDiscoveryRegistry
RegisterCenter discovery.Conn
config *Config
webhookClient *webhook.Client
groupClient *rpcli.GroupClient
relationClient *rpcli.RelationClient
clientConfig controller.ClientConfigDatabase
}
type Config struct {
@ -76,24 +79,30 @@ type Config struct {
Discovery config.Discovery
}
func Start(ctx context.Context, config *Config, client registry.SvcDiscoveryRegistry, server *grpc.Server) error {
mgocli, err := mongoutil.NewMongoDB(ctx, config.MongodbConfig.Build())
func Start(ctx context.Context, config *Config, client discovery.Conn, server grpc.ServiceRegistrar) error {
dbb := dbbuild.NewBuilder(&config.MongodbConfig, &config.RedisConfig)
mgocli, err := dbb.Mongo(ctx)
if err != nil {
return err
}
rdb, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build())
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
users := make([]*tablerelation.User, 0)
for _, v := range config.Share.IMAdminUserID {
users = append(users, &tablerelation.User{UserID: v, Nickname: v, AppMangerLevel: constant.AppNotificationAdmin})
users = append(users, &tablerelation.User{UserID: v, Nickname: v, AppMangerLevel: constant.AppAdmin})
}
userDB, err := mgo.NewUserMongo(mgocli.GetDB())
if err != nil {
return err
}
clientConfigDB, err := mgo.NewClientConfig(mgocli.GetDB())
if err != nil {
return err
}
msgConn, err := client.GetConn(ctx, config.Discovery.RpcService.Msg)
if err != nil {
return err
@ -118,7 +127,7 @@ func Start(ctx context.Context, config *Config, client registry.SvcDiscoveryRegi
userNotificationSender: NewUserNotificationSender(config, msgClient, WithUserFunc(database.FindWithError)),
config: config,
webhookClient: webhook.NewWebhookClient(config.WebhooksConfig.URL),
clientConfig: controller.NewClientConfigDatabase(clientConfigDB, redis.NewClientConfigCache(rdb, clientConfigDB), mgocli.GetTx()),
groupClient: rpcli.NewGroupClient(groupConn),
relationClient: rpcli.NewRelationClient(friendConn),
}
@ -141,7 +150,7 @@ func (s *userServer) GetDesignateUsers(ctx context.Context, req *pbuser.GetDesig
// UpdateUserInfo
func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) (resp *pbuser.UpdateUserInfoResp, err error) {
resp = &pbuser.UpdateUserInfoResp{}
err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config.Share.IMAdminUserID)
err = authverify.CheckAccess(ctx, req.UserInfo.UserID)
if err != nil {
return nil, err
}
@ -168,7 +177,7 @@ func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserI
func (s *userServer) UpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) (resp *pbuser.UpdateUserInfoExResp, err error) {
resp = &pbuser.UpdateUserInfoExResp{}
err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config.Share.IMAdminUserID)
err = authverify.CheckAccess(ctx, req.UserInfo.UserID)
if err != nil {
return nil, err
}
@ -226,8 +235,7 @@ func (s *userServer) AccountCheck(ctx context.Context, req *pbuser.AccountCheckR
if datautil.Duplicate(req.CheckUserIDs) {
return nil, errs.ErrArgs.WrapMsg("userID repeated")
}
err = authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID)
if err != nil {
if err = authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
users, err := s.db.Find(ctx, req.CheckUserIDs)
@ -273,11 +281,13 @@ func (s *userServer) UserRegister(ctx context.Context, req *pbuser.UserRegisterR
if len(req.Users) == 0 {
return nil, errs.ErrArgs.WrapMsg("users is empty")
}
if err = authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
// check if secret is changed
//if s.config.Share.Secret == defaultSecret {
// return nil, servererrs.ErrSecretNotChanged.Wrap()
//}
if err = authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
if datautil.DuplicateAny(req.Users, func(e *sdkws.UserInfo) string { return e.UserID }) {
return nil, errs.ErrArgs.WrapMsg("userID repeated")
}
@ -343,7 +353,7 @@ func (s *userServer) GetAllUserID(ctx context.Context, req *pbuser.GetAllUserIDR
// ProcessUserCommandAdd user general function add.
func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.ProcessUserCommandAddReq) (*pbuser.ProcessUserCommandAddResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID)
err := authverify.CheckAccess(ctx, req.UserID)
if err != nil {
return nil, err
}
@ -371,7 +381,7 @@ func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.Proc
// ProcessUserCommandDelete user general function delete.
func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.ProcessUserCommandDeleteReq) (*pbuser.ProcessUserCommandDeleteResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID)
err := authverify.CheckAccess(ctx, req.UserID)
if err != nil {
return nil, err
}
@ -390,7 +400,7 @@ func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.P
// ProcessUserCommandUpdate user general function update.
func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.ProcessUserCommandUpdateReq) (*pbuser.ProcessUserCommandUpdateResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID)
err := authverify.CheckAccess(ctx, req.UserID)
if err != nil {
return nil, err
}
@ -419,7 +429,7 @@ func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.P
func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.ProcessUserCommandGetReq) (*pbuser.ProcessUserCommandGetResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID)
err := authverify.CheckAccess(ctx, req.UserID)
if err != nil {
return nil, err
}
@ -448,7 +458,7 @@ func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.Proc
}
func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.ProcessUserCommandGetAllReq) (*pbuser.ProcessUserCommandGetAllResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID)
err := authverify.CheckAccess(ctx, req.UserID)
if err != nil {
return nil, err
}
@ -477,7 +487,7 @@ func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.P
}
func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.AddNotificationAccountReq) (*pbuser.AddNotificationAccountResp, error) {
if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
if req.AppMangerLevel < constant.AppNotificationAdmin {
@ -523,7 +533,7 @@ func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.Add
}
func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbuser.UpdateNotificationAccountInfoReq) (*pbuser.UpdateNotificationAccountInfoResp, error) {
if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
@ -550,7 +560,7 @@ func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbu
func (s *userServer) SearchNotificationAccount(ctx context.Context, req *pbuser.SearchNotificationAccountReq) (*pbuser.SearchNotificationAccountResp, error) {
// Check if user is an admin
if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
@ -605,7 +615,7 @@ func (s *userServer) GetNotificationAccount(ctx context.Context, req *pbuser.Get
if err != nil {
return nil, servererrs.ErrUserIDNotFound.Wrap()
}
if user.AppMangerLevel == constant.AppAdmin || user.AppMangerLevel >= constant.AppNotificationAdmin {
if user.AppMangerLevel >= constant.AppAdmin {
return &pbuser.GetNotificationAccountResp{Account: &pbuser.NotificationAccountInfo{
UserID: user.UserID,
FaceURL: user.FaceURL,

@ -31,32 +31,49 @@ func Secret(secret string) jwt.Keyfunc {
}
}
func CheckAccessV3(ctx context.Context, ownerUserID string, imAdminUserID []string) (err error) {
opUserID := mcontext.GetOpUserID(ctx)
if datautil.Contain(opUserID, imAdminUserID...) {
func CheckAdmin(ctx context.Context) error {
if IsAdmin(ctx) {
return nil
}
if opUserID == ownerUserID {
return nil
return servererrs.ErrNoPermission.WrapMsg(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx)))
}
return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID)
//func IsManagerUserID(opUserID string, imAdminUserID []string) bool {
// return datautil.Contain(opUserID, imAdminUserID...)
//}
// CheckUserIsAdmin reports whether the given userID is one of the IM admin
// user IDs attached to ctx (see GetIMAdminUserIDs). Unlike IsAdmin, the user
// ID is passed explicitly rather than taken from the operation context.
func CheckUserIsAdmin(ctx context.Context, userID string) bool {
	return datautil.Contain(userID, GetIMAdminUserIDs(ctx)...)
}
func IsAppManagerUid(ctx context.Context, imAdminUserID []string) bool {
return datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...)
// CheckSystemAccount reports whether the given app-manager level grants
// system-account privileges, i.e. level >= constant.AppAdmin.
// NOTE(review): the ctx parameter is unused here; it appears to be kept for
// signature stability — confirm before removing.
func CheckSystemAccount(ctx context.Context, level int32) bool {
	return level >= constant.AppAdmin
}
func CheckAdmin(ctx context.Context, imAdminUserID []string) error {
if datautil.Contain(mcontext.GetOpUserID(ctx), imAdminUserID...) {
return nil
// CtxIsAdminKey is the context key under which the configured IM admin user
// IDs are stored (see WithIMAdminUserIDs / GetIMAdminUserIDs).
// NOTE(review): a plain string context key risks collisions with other
// packages; an unexported typed key is the conventional alternative, but this
// constant is exported, so retyping it could break external callers — confirm
// usage before changing.
const (
	CtxIsAdminKey = "CtxIsAdminKey"
)
// WithIMAdminUserIDs returns a child context carrying imAdminUserID under
// CtxIsAdminKey, making the admin ID list retrievable via GetIMAdminUserIDs.
func WithIMAdminUserIDs(ctx context.Context, imAdminUserID []string) context.Context {
	return context.WithValue(ctx, CtxIsAdminKey, imAdminUserID)
}
return servererrs.ErrNoPermission.WrapMsg(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx)))
// GetIMAdminUserIDs extracts the IM admin user IDs previously attached to ctx
// by WithIMAdminUserIDs. Returns nil when none were attached or the stored
// value has an unexpected type.
func GetIMAdminUserIDs(ctx context.Context) []string {
	if ids, ok := ctx.Value(CtxIsAdminKey).([]string); ok {
		return ids
	}
	return nil
}
func IsManagerUserID(opUserID string, imAdminUserID []string) bool {
return datautil.Contain(opUserID, imAdminUserID...)
func IsAdmin(ctx context.Context) bool {
return datautil.Contain(mcontext.GetOpUserID(ctx), GetIMAdminUserIDs(ctx)...)
}
func CheckSystemAccount(ctx context.Context, level int32) bool {
return level >= constant.AppAdmin
// CheckAccess returns nil when the operating user in ctx is ownerUserID
// itself or one of the IM admin users; otherwise it returns ErrNoPermission.
func CheckAccess(ctx context.Context, ownerUserID string) error {
	opUserID := mcontext.GetOpUserID(ctx)
	if opUserID == ownerUserID || datautil.Contain(opUserID, GetIMAdminUserIDs(ctx)...) {
		return nil
	}
	return servererrs.ErrNoPermission.WrapMsg("ownerUserID", ownerUserID)
}

@ -378,9 +378,15 @@ type AfterConfig struct {
}
type Share struct {
Secret string `mapstructure:"secret"`
IMAdminUserID []string `mapstructure:"imAdminUserID"`
MultiLogin MultiLogin `mapstructure:"multiLogin"`
Secret string `yaml:"secret"`
IMAdminUserID []string `yaml:"imAdminUserID"`
MultiLogin MultiLogin `yaml:"multiLogin"`
RPCMaxBodySize MaxRequestBody `yaml:"rpcMaxBodySize"`
}
// MaxRequestBody bounds gRPC message sizes in bytes; a zero value leaves the
// corresponding gRPC default untouched.
type MaxRequestBody struct {
	// RequestMaxBodySize caps request bodies received by the server
	// (applied via grpc.MaxRecvMsgSize).
	RequestMaxBodySize int `yaml:"requestMaxBodySize"`
	// ResponseMaxBodySize caps response bodies sent by the server
	// (applied via grpc.MaxSendMsgSize).
	ResponseMaxBodySize int `yaml:"responseMaxBodySize"`
}
type MultiLogin struct {

@ -38,6 +38,7 @@ const (
// General error codes.
const (
NoError = 0 // No error
DatabaseError = 90002 // Database error (redis/mysql, etc.)
NetworkError = 90004 // Network error
DataError = 90007 // Data error
@ -50,6 +51,7 @@ const (
NoPermissionError = 1002 // Insufficient permission
DuplicateKeyError = 1003
RecordNotFoundError = 1004 // Record does not exist
SecretNotChangedError = 1050 // secret not changed
// Account error codes.
UserIDNotFoundError = 1101 // UserID does not exist or is not registered

@ -17,6 +17,8 @@ package servererrs
import "github.com/openimsdk/tools/errs"
var (
ErrSecretNotChanged = errs.NewCodeError(SecretNotChangedError, "secret not changed, please change secret in config/share.yml for security reasons")
ErrDatabase = errs.NewCodeError(DatabaseError, "DatabaseError")
ErrNetwork = errs.NewCodeError(NetworkError, "NetworkError")
ErrCallback = errs.NewCodeError(CallbackError, "CallbackError")

@ -0,0 +1,15 @@
package startrpc
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"google.golang.org/grpc"
)
// grpcServerIMAdminUserID returns a grpc.ServerOption whose unary interceptor
// injects the configured IM admin user IDs into every request context, so the
// authverify helpers (IsAdmin, GetIMAdminUserIDs) work inside handlers.
func grpcServerIMAdminUserID(imAdminUserID []string) grpc.ServerOption {
	interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		return handler(authverify.WithIMAdminUserIDs(ctx, imAdminUserID), req)
	}
	return grpc.ChainUnaryInterceptor(interceptor)
}

@ -19,214 +19,285 @@ import (
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"reflect"
"strconv"
"syscall"
"time"
conf "github.com/openimsdk/open-im-server/v3/pkg/common/config"
disetcd "github.com/openimsdk/open-im-server/v3/pkg/common/discovery/etcd"
"github.com/openimsdk/tools/discovery/etcd"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/openimsdk/tools/utils/network"
"google.golang.org/grpc/status"
"github.com/openimsdk/tools/utils/runtimeenv"
kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discovery"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mw"
"github.com/openimsdk/tools/utils/network"
grpccli "github.com/openimsdk/tools/mw/grpc/client"
grpcsrv "github.com/openimsdk/tools/mw/grpc/server"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Start rpc server.
func Start[T any](ctx context.Context, discovery *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP,
func init() {
prommetrics.RegistryAll()
}
// getConfigRpcMaxRequestBody walks an arbitrary config value with reflection
// and returns a pointer to the first MaxRequestBody found, either embedded in
// a conf.Share field or as a standalone conf.MaxRequestBody field (depth-first,
// first match wins). It returns nil when the config carries neither.
//
// NOTE(review): the returned pointer refers to a copy produced by
// field.Interface(), not to the original config struct — fine for read-only
// use, but writes through it do not reach the source config.
func getConfigRpcMaxRequestBody(value reflect.Value) *conf.MaxRequestBody {
	// Unwrap pointers until the underlying value is reached.
	for value.Kind() == reflect.Pointer {
		value = value.Elem()
	}
	if value.Kind() == reflect.Struct {
		num := value.NumField()
		for i := 0; i < num; i++ {
			field := value.Field(i)
			// Unexported fields cannot be read via Interface(); skip them.
			if !field.CanInterface() {
				continue
			}
			for field.Kind() == reflect.Pointer {
				field = field.Elem()
			}
			switch elem := field.Interface().(type) {
			case conf.Share:
				return &elem.RPCMaxBodySize
			case conf.MaxRequestBody:
				return &elem
			}
			// Recurse into nested structs.
			if field.Kind() == reflect.Struct {
				if elem := getConfigRpcMaxRequestBody(field); elem != nil {
					return elem
				}
			}
		}
	}
	return nil
}
// getConfigShare walks an arbitrary config value with reflection and returns
// a pointer to the first conf.Share field found (depth-first), or nil when the
// config has none. Like getConfigRpcMaxRequestBody, the returned pointer
// refers to a copy of the field, not to the original struct.
func getConfigShare(value reflect.Value) *conf.Share {
	// Unwrap pointers until the underlying value is reached.
	for value.Kind() == reflect.Pointer {
		value = value.Elem()
	}
	if value.Kind() == reflect.Struct {
		num := value.NumField()
		for i := 0; i < num; i++ {
			field := value.Field(i)
			// Unexported fields cannot be read via Interface(); skip them.
			if !field.CanInterface() {
				continue
			}
			for field.Kind() == reflect.Pointer {
				field = field.Elem()
			}
			switch elem := field.Interface().(type) {
			case conf.Share:
				return &elem
			}
			// Recurse into nested structs.
			if field.Kind() == reflect.Struct {
				if elem := getConfigShare(field); elem != nil {
					return elem
				}
			}
		}
	}
	return nil
}
func Start[T any](ctx context.Context, disc *conf.Discovery, prometheusConfig *conf.Prometheus, listenIP,
registerIP string, autoSetPorts bool, rpcPorts []int, index int, rpcRegisterName string, notification *conf.Notification, config T,
watchConfigNames []string, watchServiceNames []string,
rpcFn func(ctx context.Context, config T, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error,
rpcFn func(ctx context.Context, config T, client discovery.Conn, server grpc.ServiceRegistrar) error,
options ...grpc.ServerOption) error {
watchConfigNames = append(watchConfigNames, conf.LogConfigFileName)
var (
rpcTcpAddr string
netDone = make(chan struct{}, 2)
netErr error
prometheusPort int
)
if notification != nil {
conf.InitNotification(notification)
}
registerIP, err := network.GetRpcRegisterIP(registerIP)
if err != nil {
return err
}
maxRequestBody := getConfigRpcMaxRequestBody(reflect.ValueOf(config))
shareConfig := getConfigShare(reflect.ValueOf(config))
runTimeEnv := runtimeenv.RuntimeEnvironment()
log.ZDebug(ctx, "rpc start", "rpcMaxRequestBody", maxRequestBody, "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "listenIP", listenIP)
if !autoSetPorts {
rpcPort, err := datautil.GetElemByIndex(rpcPorts, index)
options = append(options,
grpcsrv.GrpcServerMetadataContext(),
grpcsrv.GrpcServerLogger(),
grpcsrv.GrpcServerErrorConvert(),
grpcsrv.GrpcServerRequestValidate(),
grpcsrv.GrpcServerPanicCapture(),
)
if shareConfig != nil && len(shareConfig.IMAdminUserID) > 0 {
options = append(options, grpcServerIMAdminUserID(shareConfig.IMAdminUserID))
}
var clientOptions []grpc.DialOption
if maxRequestBody != nil {
if maxRequestBody.RequestMaxBodySize > 0 {
options = append(options, grpc.MaxRecvMsgSize(maxRequestBody.RequestMaxBodySize))
clientOptions = append(clientOptions, grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxRequestBody.RequestMaxBodySize)))
}
if maxRequestBody.ResponseMaxBodySize > 0 {
options = append(options, grpc.MaxSendMsgSize(maxRequestBody.ResponseMaxBodySize))
clientOptions = append(clientOptions, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxRequestBody.ResponseMaxBodySize)))
}
}
registerIP, err := network.GetRpcRegisterIP(registerIP)
if err != nil {
return err
}
rpcTcpAddr = net.JoinHostPort(network.GetListenIP(listenIP), strconv.Itoa(rpcPort))
var prometheusListenAddr string
if autoSetPorts {
prometheusListenAddr = net.JoinHostPort(listenIP, "0")
} else {
rpcTcpAddr = net.JoinHostPort(network.GetListenIP(listenIP), "0")
}
getAutoPort := func() (net.Listener, int, error) {
listener, err := net.Listen("tcp", rpcTcpAddr)
prometheusPort, err := datautil.GetElemByIndex(prometheusConfig.Ports, index)
if err != nil {
return nil, 0, errs.WrapMsg(err, "listen err", "rpcTcpAddr", rpcTcpAddr)
return err
}
_, portStr, _ := net.SplitHostPort(listener.Addr().String())
port, _ := strconv.Atoi(portStr)
return listener, port, nil
prometheusListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(prometheusPort))
}
if autoSetPorts && discovery.Enable != conf.ETCD {
return errs.New("only etcd support autoSetPorts", "rpcRegisterName", rpcRegisterName).Wrap()
}
client, err := kdisc.NewDiscoveryRegister(discovery, runTimeEnv, watchServiceNames)
watchConfigNames = append(watchConfigNames, conf.LogConfigFileName)
client, err := kdisc.NewDiscoveryRegister(disc, watchServiceNames)
if err != nil {
return err
}
defer client.Close()
client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")))
// var reg *prometheus.Registry
// var metric *grpcprometheus.ServerMetrics
if prometheusConfig.Enable {
// cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName, share)
// reg, metric, _ = prommetrics.NewGrpcPromObj(cusMetrics)
// options = append(options, mw.GrpcServer(), grpc.StreamInterceptor(metric.StreamServerInterceptor()),
// grpc.UnaryInterceptor(metric.UnaryServerInterceptor()))
options = append(
options, mw.GrpcServer(),
prommetricsUnaryInterceptor(rpcRegisterName),
prommetricsStreamInterceptor(rpcRegisterName),
)
client.AddOption(
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")),
var (
listener net.Listener
grpccli.GrpcClientLogger(),
grpccli.GrpcClientContext(),
grpccli.GrpcClientErrorConvert(),
)
if autoSetPorts {
listener, prometheusPort, err = getAutoPort()
if err != nil {
return err
if len(clientOptions) > 0 {
client.AddOption(clientOptions...)
}
etcdClient := client.(*etcd.SvcDiscoveryRegistryImpl).GetClient()
ctx, cancel := context.WithCancelCause(ctx)
_, err = etcdClient.Put(ctx, prommetrics.BuildDiscoveryKey(rpcRegisterName), jsonutil.StructToJsonString(prommetrics.BuildDefaultTarget(registerIP, prometheusPort)))
if err != nil {
return errs.WrapMsg(err, "etcd put err")
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL)
select {
case <-ctx.Done():
return
case val := <-sigs:
log.ZDebug(ctx, "recv signal", "signal", val.String())
cancel(fmt.Errorf("signal %s", val.String()))
}
} else {
prometheusPort, err = datautil.GetElemByIndex(prometheusConfig.Ports, index)
}()
if prometheusListenAddr != "" {
options = append(
options,
prommetricsUnaryInterceptor(rpcRegisterName),
prommetricsStreamInterceptor(rpcRegisterName),
)
prometheusListener, prometheusPort, err := listenTCP(prometheusListenAddr)
if err != nil {
return err
}
listener, err = net.Listen("tcp", fmt.Sprintf(":%d", prometheusPort))
log.ZDebug(ctx, "prometheus start", "addr", prometheusListener.Addr(), "rpcRegisterName", rpcRegisterName)
target, err := jsonutil.JsonMarshal(prommetrics.BuildDefaultTarget(registerIP, prometheusPort))
if err != nil {
return errs.WrapMsg(err, "listen err", "rpcTcpAddr", rpcTcpAddr)
return err
}
if err := client.SetKey(ctx, prommetrics.BuildDiscoveryKey(prommetrics.APIKeyName), target); err != nil {
if !errors.Is(err, discovery.ErrNotSupportedKeyValue) {
return err
}
}
cs := prommetrics.GetGrpcCusMetrics(rpcRegisterName, discovery)
go func() {
if err := prommetrics.RpcInit(cs, listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
netErr = errs.WrapMsg(err, fmt.Sprintf("rpc %s prometheus start err: %d", rpcRegisterName, prometheusPort))
netDone <- struct{}{}
}
//metric.InitializeMetrics(srv)
// Create a HTTP server for prometheus.
// httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)}
// if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
// netErr = errs.WrapMsg(err, "prometheus start err", httpServer.Addr)
// netDone <- struct{}{}
// }
err := prommetrics.Start(prometheusListener)
if err == nil {
err = fmt.Errorf("listener done")
}
cancel(fmt.Errorf("prommetrics %s %w", rpcRegisterName, err))
}()
} else {
options = append(options, mw.GrpcServer())
}
listener, port, err := getAutoPort()
var (
rpcServer *grpc.Server
rpcGracefulStop chan struct{}
)
onGrpcServiceRegistrar := func(desc *grpc.ServiceDesc, impl any) {
if rpcServer != nil {
rpcServer.RegisterService(desc, impl)
return
}
var rpcListenAddr string
if autoSetPorts {
rpcListenAddr = net.JoinHostPort(listenIP, "0")
} else {
rpcPort, err := datautil.GetElemByIndex(rpcPorts, index)
if err != nil {
return err
cancel(fmt.Errorf("rpcPorts index out of range %s %w", rpcRegisterName, err))
return
}
log.CInfo(ctx, "RPC server is initializing", "rpcRegisterName", rpcRegisterName, "rpcPort", port,
"prometheusPort", prometheusPort)
defer listener.Close()
srv := grpc.NewServer(options...)
err = rpcFn(ctx, config, client, srv)
rpcListenAddr = net.JoinHostPort(listenIP, strconv.Itoa(rpcPort))
}
rpcListener, err := net.Listen("tcp", rpcListenAddr)
if err != nil {
return err
cancel(fmt.Errorf("listen rpc %s %s %w", rpcRegisterName, rpcListenAddr, err))
return
}
err = client.Register(
ctx,
rpcRegisterName,
registerIP,
port,
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return err
rpcServer = grpc.NewServer(options...)
rpcServer.RegisterService(desc, impl)
rpcGracefulStop = make(chan struct{})
rpcPort := rpcListener.Addr().(*net.TCPAddr).Port
log.ZDebug(ctx, "rpc start register", "rpcRegisterName", rpcRegisterName, "registerIP", registerIP, "rpcPort", rpcPort)
grpcOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
rpcGracefulStop = make(chan struct{})
go func() {
<-ctx.Done()
rpcServer.GracefulStop()
close(rpcGracefulStop)
}()
if err := client.Register(ctx, rpcRegisterName, registerIP, rpcListener.Addr().(*net.TCPAddr).Port, grpcOpt); err != nil {
cancel(fmt.Errorf("rpc register %s %w", rpcRegisterName, err))
return
}
go func() {
err := srv.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
netErr = errs.WrapMsg(err, "rpc start err: ", rpcTcpAddr)
netDone <- struct{}{}
err := rpcServer.Serve(rpcListener)
if err == nil {
err = fmt.Errorf("serve end")
}
cancel(fmt.Errorf("rpc %s %w", rpcRegisterName, err))
}()
if discovery.Enable == conf.ETCD {
cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), watchConfigNames)
cm.Watch(ctx)
}
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM)
select {
case <-sigs:
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := gracefulStopWithCtx(ctx, srv.GracefulStop); err != nil {
err = rpcFn(ctx, config, client, &grpcServiceRegistrar{onRegisterService: onGrpcServiceRegistrar})
if err != nil {
return err
}
return nil
case <-netDone:
return netErr
<-ctx.Done()
log.ZDebug(ctx, "cmd wait done", "err", context.Cause(ctx))
if rpcGracefulStop != nil {
timeout := time.NewTimer(time.Second * 15)
defer timeout.Stop()
select {
case <-timeout.C:
log.ZWarn(ctx, "rcp graceful stop timeout", nil)
case <-rpcGracefulStop:
log.ZDebug(ctx, "rcp graceful stop done")
}
}
return context.Cause(ctx)
}
func gracefulStopWithCtx(ctx context.Context, f func()) error {
done := make(chan struct{}, 1)
go func() {
f()
close(done)
}()
select {
case <-ctx.Done():
return errs.New("timeout, ctx graceful stop")
case <-done:
return nil
// listenTCP opens a TCP listener on addr and also reports the port actually
// bound (useful when addr requests port 0 for auto-assignment).
func listenTCP(addr string) (net.Listener, int, error) {
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, 0, errs.WrapMsg(err, "listen err", "addr", addr)
	}
	return ln, ln.Addr().(*net.TCPAddr).Port, nil
}
func prommetricsUnaryInterceptor(rpcRegisterName string) grpc.ServerOption {
@ -250,3 +321,11 @@ func prommetricsUnaryInterceptor(rpcRegisterName string) grpc.ServerOption {
// prommetricsStreamInterceptor returns the stream-interceptor server option
// for rpcRegisterName. It currently chains no interceptors — a placeholder
// kept symmetric with prommetricsUnaryInterceptor; rpcRegisterName is
// presently unused.
func prommetricsStreamInterceptor(rpcRegisterName string) grpc.ServerOption {
	return grpc.ChainStreamInterceptor()
}
// grpcServiceRegistrar is a grpc.ServiceRegistrar that forwards every
// RegisterService call to a caller-supplied callback, allowing Start to defer
// creation of the real *grpc.Server until the first service registers.
type grpcServiceRegistrar struct {
	// onRegisterService receives every (desc, impl) pair passed to
	// RegisterService.
	onRegisterService func(desc *grpc.ServiceDesc, impl any)
}

// RegisterService implements grpc.ServiceRegistrar by delegating to the
// configured callback.
func (x *grpcServiceRegistrar) RegisterService(desc *grpc.ServiceDesc, impl any) {
	x.onRegisterService(desc, impl)
}

@ -0,0 +1,10 @@
package cachekey
// ClientConfig is the cache-key prefix for client configuration entries.
const ClientConfig = "CLIENT_CONFIG"

// GetClientConfigKey returns the cache key for a user's client config; an
// empty userID addresses the global (user-independent) configuration.
func GetClientConfigKey(userID string) string {
	if userID != "" {
		return ClientConfig + ":" + userID
	}
	return ClientConfig
}

@ -1,8 +1,9 @@
package cachekey
import (
"github.com/openimsdk/protocol/constant"
"strings"
"github.com/openimsdk/protocol/constant"
)
const (
@ -13,6 +14,10 @@ func GetTokenKey(userID string, platformID int) string {
return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
}
// GetTemporaryTokenKey returns the cache key marking token as a short-lived
// "temporary" token for userID on the given platform (written by
// DeleteAndSetTemporary, checked by HasTemporaryToken).
func GetTemporaryTokenKey(userID string, platformID int, token string) string {
	return UidPidToken + ":TEMPORARY:" + userID + ":" + constant.PlatformIDToName(platformID) + ":" + token
}
func GetAllPlatformTokenKey(userID string) []string {
res := make([]string, len(constant.PlatformID2Name))
for k := range constant.PlatformID2Name {

@ -0,0 +1,8 @@
package cache
import "context"
// ClientConfigCache caches per-user (and global) client configuration
// key/value pairs in front of the persistent store.
type ClientConfigCache interface {
	// DeleteUserCache evicts the cached config of every listed user.
	DeleteUserCache(ctx context.Context, userIDs []string) error
	// GetUserConfig returns the user's effective config: the global entries
	// overlaid with the user-specific ones.
	GetUserConfig(ctx context.Context, userID string) (map[string]string, error)
}

@ -0,0 +1,166 @@
package mcache
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
)
// NewTokenCacheModel builds a cache.TokenModel on top of the generic
// database.Cache KV store. accessExpire is the token lifetime in days
// (converted to a duration by getExpireTime).
//
// The first parameter is named kv rather than cache so it does not shadow
// the imported cache package inside the body.
func NewTokenCacheModel(kv database.Cache, accessExpire int64) cache.TokenModel {
	c := &tokenCache{cache: kv}
	c.accessExpire = c.getExpireTime(accessExpire)
	return c
}
// tokenCache implements cache.TokenModel on a generic database.Cache KV
// backend.
type tokenCache struct {
	cache        database.Cache // underlying KV store
	accessExpire time.Duration  // TTL applied on every token write
}
// getTokenKey builds the flat KV key for a single token by appending the
// token to the per-user/per-platform prefix from cachekey.GetTokenKey.
func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string {
	return cachekey.GetTokenKey(userID, platformID) + ":" + token
}
// SetTokenFlag stores the state flag of one token with the configured TTL.
func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
	return x.cache.Set(ctx, x.getTokenKey(userID, platformID, token), strconv.Itoa(flag), x.accessExpire)
}
// SetTokenFlagEx set token and flag with expire time
//
// In this backend every Set already applies accessExpire, so this simply
// delegates to SetTokenFlag.
func (x *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error {
	return x.SetTokenFlag(ctx, userID, platformID, token, flag)
}
// GetTokensWithoutError returns the token->flag map of one user/platform.
// Entries whose stored value is not an integer are logged and skipped rather
// than failing the call.
func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
	prefix := x.getTokenKey(userID, platformID, "")
	kv, err := x.cache.Prefix(ctx, prefix)
	if err != nil {
		return nil, errs.Wrap(err)
	}
	tokens := make(map[string]int, len(kv))
	for key, raw := range kv {
		flag, convErr := strconv.Atoi(raw)
		if convErr != nil {
			log.ZError(ctx, "token value is not int", convErr, "value", raw, "userID", userID, "platformID", platformID)
			continue
		}
		tokens[strings.TrimPrefix(key, prefix)] = flag
	}
	return tokens, nil
}
// HasTemporaryToken returns nil when the temporary-token marker exists for
// this token, and the underlying cache's lookup error otherwise.
func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
	key := cachekey.GetTemporaryTokenKey(userID, platformID, token)
	_, err := x.cache.Get(ctx, []string{key})
	return err
}
// GetAllTokensWithoutError returns every token flag of userID grouped by
// platformID. Malformed cache entries are logged and skipped instead of
// failing the whole call (hence "WithoutError").
//
// NOTE(review): keys are built with constant.PlatformIDToName (see
// cachekey.GetTokenKey), yet the platform segment is parsed back here with
// strconv.Atoi — confirm the key layout really stores numeric platform IDs,
// otherwise every entry is skipped.
func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
	prefix := cachekey.UidPidToken + userID + ":"
	tokens, err := x.cache.Prefix(ctx, prefix)
	if err != nil {
		return nil, err
	}
	res := make(map[int]map[string]int)
	for key, flagStr := range tokens {
		flag, err := strconv.Atoi(flagStr)
		if err != nil {
			log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
			continue
		}
		// Remaining layout after the prefix: "<platform>:<token>".
		arr := strings.SplitN(strings.TrimPrefix(key, prefix), ":", 2)
		if len(arr) != 2 {
			// Fixed: previously logged the copy-pasted message
			// "token value is not int" with a stale err value.
			log.ZError(ctx, "token key format invalid", nil, "key", key, "value", flagStr, "userID", userID)
			continue
		}
		platformID, err := strconv.Atoi(arr[0])
		if err != nil {
			log.ZError(ctx, "token key platformID is not int", err, "key", key, "value", flagStr, "userID", userID)
			continue
		}
		token := arr[1]
		if token == "" {
			log.ZError(ctx, "token in key is empty", nil, "key", key, "value", flagStr, "userID", userID)
			continue
		}
		tk, ok := res[platformID]
		if !ok {
			tk = make(map[string]int)
			res[platformID] = tk
		}
		tk[token] = flag
	}
	return res, nil
}
// SetTokenMapByUidPid stores every token->flag pair of m for the given user
// and platform, stopping at the first write error.
func (x *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
	for token, flag := range m {
		if err := x.SetTokenFlag(ctx, userID, platformID, token, flag); err != nil {
			return err
		}
	}
	return nil
}
// BatchSetTokenMapByUidPid writes a nested map of key-prefix -> token -> flag.
// Flags arrive as any, so each is rendered with %v before storage.
func (x *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
	for prefix, tokenFlag := range tokens {
		for token, flag := range tokenFlag {
			if err := x.cache.Set(ctx, prefix+":"+token, fmt.Sprintf("%v", flag), x.accessExpire); err != nil {
				return err
			}
		}
	}
	return nil
}
// DeleteTokenByUidPid removes the listed tokens of the user on one platform.
func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
	keys := make([]string, len(fields))
	for i, token := range fields {
		keys[i] = x.getTokenKey(userID, platformID, token)
	}
	return x.cache.Del(ctx, keys)
}
// getExpireTime converts an expiry expressed in days into a time.Duration.
func (x *tokenCache) getExpireTime(t int64) time.Duration {
	return time.Hour * 24 * time.Duration(t)
}
// DeleteTokenByTokenMap deletes tokens across several platforms in one call;
// the map key is the platformID and the value its token list.
func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
	// Pre-size for the total token count (previously sized by the number of
	// platforms only, forcing re-allocations for multi-token platforms).
	total := 0
	for _, ts := range tokens {
		total += len(ts)
	}
	keys := make([]string, 0, total)
	for platformID, ts := range tokens {
		for _, t := range ts {
			keys = append(keys, x.getTokenKey(userID, platformID, t))
		}
	}
	return x.cache.Del(ctx, keys)
}
// DeleteAndSetTemporary removes the listed tokens and then re-marks each of
// them as a temporary token valid for five minutes.
func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
	keys := make([]string, len(fields))
	for i, token := range fields {
		keys[i] = x.getTokenKey(userID, platformID, token)
	}
	if err := x.cache.Del(ctx, keys); err != nil {
		return err
	}
	for _, token := range fields {
		tempKey := cachekey.GetTemporaryTokenKey(userID, platformID, token)
		if err := x.cache.Set(ctx, tempKey, "", time.Minute*5); err != nil {
			return errs.Wrap(err)
		}
	}
	return nil
}

@ -0,0 +1,69 @@
package redis
import (
"context"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/redis/go-redis/v9"
)
// NewClientConfigCache wires a Redis-backed ClientConfigCache around the
// persistent client-config store mgo.
func NewClientConfigCache(rdb redis.UniversalClient, mgo database.ClientConfig) cache.ClientConfigCache {
	rocks := newRocksCacheClient(rdb)
	return &ClientConfigCache{
		mgo:      mgo,
		rcClient: rocks,
		delete:   rocks.GetBatchDeleter(),
	}
}
// ClientConfigCache is the Redis (rockscache) implementation of
// cache.ClientConfigCache.
type ClientConfigCache struct {
	mgo      database.ClientConfig // persistent source of truth
	rcClient *rocksCacheClient     // read-through cache layer
	delete   cache.BatchDeleter    // batched key invalidation
}
// getExpireTime returns the cache TTL: one day for the global config (empty
// userID), one hour for per-user configs.
func (x *ClientConfigCache) getExpireTime(userID string) time.Duration {
	if userID == "" {
		return time.Hour * 24
	}
	return time.Hour
}
// getClientConfigKey maps a userID (possibly empty, meaning global) to its
// cache key.
func (x *ClientConfigCache) getClientConfigKey(userID string) string {
	return cachekey.GetClientConfigKey(userID)
}
// GetConfig returns the raw config map of one scope (global when userID is
// empty), reading through the cache and falling back to the persistent store
// on a miss.
func (x *ClientConfigCache) GetConfig(ctx context.Context, userID string) (map[string]string, error) {
	return getCache(ctx, x.rcClient, x.getClientConfigKey(userID), x.getExpireTime(userID), func(ctx context.Context) (map[string]string, error) {
		return x.mgo.Get(ctx, userID)
	})
}
// DeleteUserCache evicts the cached config entries of every listed user.
func (x *ClientConfigCache) DeleteUserCache(ctx context.Context, userIDs []string) error {
	keys := make([]string, len(userIDs))
	for i, userID := range userIDs {
		keys[i] = x.getClientConfigKey(userID)
	}
	return x.delete.ExecDelWithKeys(ctx, keys)
}
// GetUserConfig returns the user's effective client config: the global
// entries overlaid with the user-specific ones.
func (x *ClientConfigCache) GetUserConfig(ctx context.Context, userID string) (map[string]string, error) {
	base, err := x.GetConfig(ctx, "")
	if err != nil {
		return nil, err
	}
	// Copy before overlaying: the map returned by GetConfig may be shared
	// with the cache layer (getCache), and writing user overrides into it
	// would corrupt the cached global config for other callers.
	merged := make(map[string]string, len(base))
	for k, v := range base {
		merged[k] = v
	}
	if userID != "" {
		userConfig, err := x.GetConfig(ctx, userID)
		if err != nil {
			return nil, err
		}
		for k, v := range userConfig {
			merged[k] = v
		}
	}
	return merged, nil
}

@ -9,6 +9,7 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
@ -55,6 +56,14 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p
return mm, nil
}
// HasTemporaryToken returns nil when the temporary-token marker exists in
// Redis, and the wrapped Redis error (e.g. a miss) otherwise.
func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
	key := cachekey.GetTemporaryTokenKey(userID, platformID, token)
	if err := c.rdb.Get(ctx, key).Err(); err != nil {
		return errs.Wrap(err)
	}
	return nil
}
func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
var (
res = make(map[int]map[string]int)
@ -101,6 +110,8 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla
}
func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
keys := datautil.Keys(tokens)
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline()
for k, v := range tokens {
pipe.HSet(ctx, k, v)
@ -110,6 +121,10 @@ func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[st
return errs.Wrap(err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
@ -119,3 +134,47 @@ func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla
func (c *tokenCache) getExpireTime(t int64) time.Duration {
return time.Hour * 24 * time.Duration(t)
}
// DeleteTokenByTokenMap tokens key is platformID, value is token slice.
// Hash keys are grouped by Redis cluster slot and each slot's deletions are
// pipelined together.
func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
	var (
		keys   = make([]string, 0, len(tokens))
		keyMap = make(map[string][]string, len(tokens))
	)
	for platformID, toks := range tokens {
		key := cachekey.GetTokenKey(userID, platformID)
		keys = append(keys, key)
		keyMap[key] = toks
	}
	if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
		pipe := c.rdb.Pipeline()
		// Fixed: only delete the keys belonging to this slot. The previous
		// version ignored the slot's keys and iterated the whole tokens map,
		// issuing duplicate (and, in cluster mode, cross-slot) HDel commands
		// for every slot.
		for _, k := range keys {
			pipe.HDel(ctx, k, keyMap[k]...)
		}
		if _, err := pipe.Exec(ctx); err != nil {
			return errs.Wrap(err)
		}
		return nil
	}); err != nil {
		return err
	}
	return nil
}
// DeleteAndSetTemporary removes fields from the user/platform token hash and
// then re-marks each removed token as temporary for five minutes.
func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
	hashKey := cachekey.GetTokenKey(userID, platformID)
	if err := c.rdb.HDel(ctx, hashKey, fields...).Err(); err != nil {
		return errs.Wrap(err)
	}
	for _, token := range fields {
		tempKey := cachekey.GetTemporaryTokenKey(userID, platformID, token)
		if err := c.rdb.Set(ctx, tempKey, "", time.Minute*5).Err(); err != nil {
			return errs.Wrap(err)
		}
	}
	return nil
}

@ -9,8 +9,11 @@ type TokenModel interface {
// SetTokenFlagEx set token and flag with expire time
SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error
GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error
DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error
}

@ -17,6 +17,8 @@ import (
type AuthDatabase interface {
// If the result is empty, no error is returned.
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error
// Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error)
@ -51,6 +53,10 @@ func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string,
return a.cache.GetTokensWithoutError(ctx, userID, platformID)
}
func (a *authDatabase) GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error {
return a.cache.HasTemporaryToken(ctx, userID, platformID, token)
}
func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m)
}
@ -86,19 +92,20 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return "", err
}
deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID)
deleteTokenKey, kickedTokenKey, adminTokens, err := a.checkToken(ctx, tokens, platformID)
if err != nil {
return "", err
}
if len(deleteTokenKey) != 0 {
err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
err = a.cache.DeleteTokenByTokenMap(ctx, userID, deleteTokenKey)
if err != nil {
return "", err
}
}
if len(kickedTokenKey) != 0 {
for _, k := range kickedTokenKey {
err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken)
for plt, ks := range kickedTokenKey {
for _, k := range ks {
err := a.cache.SetTokenFlagEx(ctx, userID, plt, k, constant.KickedToken)
if err != nil {
return "", err
}
@ -106,6 +113,11 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
}
}
}
if len(adminTokens) != 0 {
if err = a.cache.DeleteAndSetTemporary(ctx, userID, constant.AdminPlatformID, adminTokens); err != nil {
return "", err
}
}
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@ -123,12 +135,13 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
return tokenString, nil
}
func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) {
// todo: Move the logic for handling old data to another location.
// checkToken will check token by tokenPolicy and return deleteToken,kickToken,deleteAdminToken
func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) (map[int][]string, map[int][]string, []string, error) {
// todo: Asynchronous deletion of old data.
var (
loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0
deleteToken = make([]string, 0)
kickToken = make([]string, 0)
deleteToken = make(map[int][]string)
kickToken = make(map[int][]string)
adminToken = make([]string, 0)
unkickTerminal = ""
)
@ -137,7 +150,7 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
for k, v := range tks {
_, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret))
if err != nil || v != constant.NormalToken {
deleteToken = append(deleteToken, k)
deleteToken[plfID] = append(deleteToken[plfID], k)
} else {
if plfID != constant.AdminPlatformID {
loginTokenMap[plfID] = append(loginTokenMap[plfID], k)
@ -157,14 +170,15 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
}
limit := a.multiLogin.MaxNumOneEnd
if l > limit {
kickToken = append(kickToken, ts[:l-limit]...)
kickToken[plt] = ts[:l-limit]
}
}
case constant.AllLoginButSameTermKick:
for plt, ts := range loginTokenMap {
kickToken = append(kickToken, ts[:len(ts)-1]...)
kickToken[plt] = ts[:len(ts)-1]
if plt == platformID {
kickToken = append(kickToken, ts[len(ts)-1])
kickToken[plt] = append(kickToken[plt], ts[len(ts)-1])
}
}
case constant.PCAndOther:
@ -172,29 +186,33 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
if constant.PlatformIDToClass(platformID) != unkickTerminal {
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal {
kickToken = append(kickToken, ts...)
kickToken[plt] = ts
}
}
} else {
var (
preKick []string
isReserve = true
preKickToken string
preKickPlt int
reserveToken = false
)
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal {
// Keep a token from another end
if isReserve {
isReserve = false
kickToken = append(kickToken, ts[:len(ts)-1]...)
preKick = append(preKick, ts[len(ts)-1])
if !reserveToken {
reserveToken = true
kickToken[plt] = ts[:len(ts)-1]
preKickToken = ts[len(ts)-1]
preKickPlt = plt
continue
} else {
// Prioritize keeping Android
if plt == constant.AndroidPlatformID {
kickToken = append(kickToken, preKick...)
kickToken = append(kickToken, ts[:len(ts)-1]...)
if preKickToken != "" {
kickToken[preKickPlt] = append(kickToken[preKickPlt], preKickToken)
}
kickToken[plt] = ts[:len(ts)-1]
} else {
kickToken = append(kickToken, ts...)
kickToken[plt] = ts
}
}
}
@ -207,19 +225,19 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) {
kickToken = append(kickToken, ts...)
kickToken[plt] = ts
} else {
if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok {
reserved[constant.PlatformIDToClass(plt)] = struct{}{}
kickToken = append(kickToken, ts[:len(ts)-1]...)
kickToken[plt] = ts[:len(ts)-1]
continue
} else {
kickToken = append(kickToken, ts...)
kickToken[plt] = ts
}
}
}
default:
return nil, nil, errs.New("unknown multiLogin policy").Wrap()
return nil, nil, nil, errs.New("unknown multiLogin policy").Wrap()
}
//var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd
@ -233,5 +251,9 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string
//if l > adminTokenMaxNum {
// kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...)
//}
return deleteToken, kickToken, nil
var deleteAdminToken []string
if platformID == constant.AdminPlatformID {
deleteAdminToken = adminToken
}
return deleteToken, kickToken, deleteAdminToken, nil
}

@ -0,0 +1,58 @@
package controller
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/db/tx"
)
// ClientConfigDatabase persists per-user client configuration key/value
// pairs and keeps the read cache consistent with the backing store.
type ClientConfigDatabase interface {
	// SetUserConfig upserts the given key/value pairs for userID.
	SetUserConfig(ctx context.Context, userID string, config map[string]string) error
	// GetUserConfig returns all configuration entries of userID.
	GetUserConfig(ctx context.Context, userID string) (map[string]string, error)
	// DelUserConfig removes the given keys for userID.
	DelUserConfig(ctx context.Context, userID string, keys []string) error
	// GetUserConfigPage pages through entries filtered by userID and key.
	GetUserConfigPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error)
}
// NewClientConfigDatabase wires a ClientConfigDatabase from its storage,
// cache and transaction dependencies.
func NewClientConfigDatabase(db database.ClientConfig, cache cache.ClientConfigCache, tx tx.Tx) ClientConfigDatabase {
	d := &clientConfigDatabase{
		tx:    tx,
		db:    db,
		cache: cache,
	}
	return d
}
// clientConfigDatabase is the default ClientConfigDatabase implementation:
// writes go through tx-scoped transactions, reads go through the cache.
type clientConfigDatabase struct {
	tx    tx.Tx
	db    database.ClientConfig
	cache cache.ClientConfigCache
}
// SetUserConfig stores config for userID and, inside the same transaction,
// invalidates the user's cached configuration so later reads see the new
// values.
func (x *clientConfigDatabase) SetUserConfig(ctx context.Context, userID string, config map[string]string) error {
	return x.tx.Transaction(ctx, func(ctx context.Context) error {
		if err := x.db.Set(ctx, userID, config); err != nil {
			return err
		}
		return x.cache.DeleteUserCache(ctx, []string{userID})
	})
}
// GetUserConfig reads the user's configuration via the cache layer.
func (x *clientConfigDatabase) GetUserConfig(ctx context.Context, userID string) (map[string]string, error) {
	return x.cache.GetUserConfig(ctx, userID)
}
// DelUserConfig deletes the given keys for userID and invalidates the
// user's cache entry within the same transaction.
func (x *clientConfigDatabase) DelUserConfig(ctx context.Context, userID string, keys []string) error {
	return x.tx.Transaction(ctx, func(ctx context.Context) error {
		if err := x.db.Del(ctx, userID, keys); err != nil {
			return err
		}
		return x.cache.DeleteUserCache(ctx, []string{userID})
	})
}
// GetUserConfigPage reads a page of entries directly from the database,
// bypassing the cache.
func (x *clientConfigDatabase) GetUserConfigPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) {
	return x.db.GetPage(ctx, userID, key, pagination)
}

@ -33,7 +33,7 @@ type S3Database interface {
PartLimit() (*s3.PartLimit, error)
PartSize(ctx context.Context, size int64) (int64, error)
AuthSign(ctx context.Context, uploadID string, partNumbers []int) (*s3.AuthSignResult, error)
InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int) (*cont.InitiateUploadResult, error)
InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int, contentType string) (*cont.InitiateUploadResult, error)
CompleteMultipartUpload(ctx context.Context, uploadID string, parts []string) (*cont.UploadResult, error)
AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (time.Time, string, error)
SetObject(ctx context.Context, info *model.Object) error
@ -73,8 +73,8 @@ func (s *s3Database) AuthSign(ctx context.Context, uploadID string, partNumbers
return s.s3.AuthSign(ctx, uploadID, partNumbers)
}
func (s *s3Database) InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int) (*cont.InitiateUploadResult, error) {
return s.s3.InitiateUpload(ctx, hash, size, expire, maxParts)
func (s *s3Database) InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int, contentType string) (*cont.InitiateUploadResult, error) {
return s.s3.InitiateUploadContentType(ctx, hash, size, expire, maxParts, contentType)
}
func (s *s3Database) CompleteMultipartUpload(ctx context.Context, uploadID string, parts []string) (*cont.UploadResult, error) {

@ -0,0 +1,15 @@
package database
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
// ClientConfig is the storage-layer contract for client configuration
// key/value entries.
type ClientConfig interface {
	// Set upserts the given key/value pairs for userID.
	Set(ctx context.Context, userID string, config map[string]string) error
	// Get returns all entries of userID as a key/value map.
	Get(ctx context.Context, userID string) (map[string]string, error)
	// Del removes the given keys for userID.
	Del(ctx context.Context, userID string, keys []string) error
	// GetPage pages entries; empty userID or key omits that filter.
	GetPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error)
}

@ -0,0 +1,183 @@
package mgo
import (
	"context"
	"regexp"
	"strconv"
	"time"

	"github.com/google/uuid"
	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
	"github.com/openimsdk/tools/db/mongoutil"
	"github.com/openimsdk/tools/errs"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)
// NewCacheMgo opens the cache collection and ensures its indexes: a unique
// index on "key", and a TTL index on "expire_at" so MongoDB purges
// documents once their expiry passes (expireAfterSeconds=0 means "expire
// exactly at the stored time").
func NewCacheMgo(db *mongo.Database) (*CacheMgo, error) {
	coll := db.Collection(database.CacheName)
	_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
		{
			Keys: bson.D{
				{Key: "key", Value: 1},
			},
			Options: options.Index().SetUnique(true),
		},
		{
			Keys: bson.D{
				{Key: "expire_at", Value: 1},
			},
			// TTL index: background expiry of stale entries.
			Options: options.Index().SetExpireAfterSeconds(0),
		},
	})
	if err != nil {
		return nil, errs.Wrap(err)
	}
	return &CacheMgo{coll: coll}, nil
}
// CacheMgo is a MongoDB-backed key/value cache with per-entry expiry and a
// simple distributed lock (see Lock/Unlock).
type CacheMgo struct {
	coll *mongo.Collection
}
// findToMap converts query results into a key/value map, dropping every
// entry whose expiry is already in the past relative to now.
func (x *CacheMgo) findToMap(res []model.Cache, now time.Time) map[string]string {
	out := make(map[string]string)
	for i := range res {
		if expire := res[i].ExpireAt; expire != nil && expire.Before(now) {
			continue
		}
		out[res[i].Key] = res[i].Value
	}
	return out
}
// Get returns the values of the given keys, excluding expired entries
// (the query filters on expiry and findToMap re-checks against the same
// timestamp). A nil or empty key slice returns (nil, nil).
func (x *CacheMgo) Get(ctx context.Context, key []string) (map[string]string, error) {
	if len(key) == 0 {
		return nil, nil
	}
	now := time.Now()
	res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{
		"key": bson.M{"$in": key},
		// An entry is live if it has no expiry or its expiry is in the future.
		"$or": []bson.M{
			{"expire_at": bson.M{"$gt": now}},
			{"expire_at": nil},
		},
	})
	if err != nil {
		return nil, err
	}
	return x.findToMap(res, now), nil
}
// Prefix returns every live entry whose key starts with prefix.
func (x *CacheMgo) Prefix(ctx context.Context, prefix string) (map[string]string, error) {
	now := time.Now()
	// QuoteMeta escapes regex metacharacters so a prefix such as "a.b" or
	// "v1+" is matched literally rather than interpreted as a pattern
	// (the original concatenated the raw prefix into the $regex).
	res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{
		"key": bson.M{"$regex": "^" + regexp.QuoteMeta(prefix)},
		"$or": []bson.M{
			{"expire_at": bson.M{"$gt": now}},
			{"expire_at": nil},
		},
	})
	if err != nil {
		return nil, err
	}
	return x.findToMap(res, now), nil
}
// Set upserts a single key/value pair. A positive expireAt stores an
// absolute expiry time; expireAt <= 0 stores the entry without expiration.
func (x *CacheMgo) Set(ctx context.Context, key string, value string, expireAt time.Duration) error {
	cv := &model.Cache{
		Key:   key,
		Value: value,
	}
	if expireAt > 0 {
		now := time.Now().Add(expireAt)
		cv.ExpireAt = &now
	}
	opt := options.Update().SetUpsert(true)
	return mongoutil.UpdateOne(ctx, x.coll, bson.M{"key": key}, bson.M{"$set": cv}, false, opt)
}
// Incr atomically adds value to the integer stored (as a string) at key,
// using a server-side update pipeline ($toInt -> add -> $toString), and
// returns the new value. The stored value must be an integer-formatted
// string or the pipeline/Atoi fails.
// NOTE(review): no upsert is requested, so behavior when the key is absent
// depends on mongoutil.FindOneAndUpdate — confirm callers Set before Incr.
func (x *CacheMgo) Incr(ctx context.Context, key string, value int) (int, error) {
	pipeline := mongo.Pipeline{
		{
			{"$set", bson.M{
				"value": bson.M{
					"$toString": bson.M{
						"$add": bson.A{
							bson.M{"$toInt": "$value"},
							value,
						},
					},
				},
			}},
		},
	}
	// Return the post-update document so the new counter value is read back.
	opt := options.FindOneAndUpdate().SetReturnDocument(options.After)
	res, err := mongoutil.FindOneAndUpdate[model.Cache](ctx, x.coll, bson.M{"key": key}, pipeline, opt)
	if err != nil {
		return 0, err
	}
	return strconv.Atoi(res.Value)
}
// Del removes the given keys; a nil or empty slice is a no-op.
func (x *CacheMgo) Del(ctx context.Context, key []string) error {
	if len(key) == 0 {
		return nil
	}
	filter := bson.M{"key": bson.M{"$in": key}}
	if _, err := x.coll.DeleteMany(ctx, filter); err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// lockKey namespaces lock documents so they cannot collide with ordinary
// cache entries.
func (x *CacheMgo) lockKey(key string) string {
	const lockPrefix = "LOCK_"
	return lockPrefix + key
}
// Lock acquires a distributed lock named key, spinning (100ms backoff)
// until acquisition succeeds or the 30s deadline / caller context expires.
// It returns a random owner token that must be passed to Unlock.
func (x *CacheMgo) Lock(ctx context.Context, key string, duration time.Duration) (string, error) {
	tmp, err := uuid.NewUUID()
	if err != nil {
		return "", err
	}
	// Clamp the lock TTL to at most 10 minutes so a crashed holder cannot
	// block others indefinitely.
	if duration <= 0 || duration > time.Minute*10 {
		duration = time.Minute * 10
	}
	cv := &model.Cache{
		Key:      x.lockKey(key),
		Value:    tmp.String(),
		ExpireAt: nil,
	}
	ctx, cancel := context.WithTimeout(ctx, time.Second*30)
	defer cancel()
	wait := func() error {
		timeout := time.NewTimer(time.Millisecond * 100)
		defer timeout.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timeout.C:
			return nil
		}
	}
	for {
		// Remove a stale lock document whose TTL has passed. BUGFIX: the
		// lock is stored under x.lockKey(key), so the cleanup filter must
		// use the same prefixed key — the original filtered on the raw key,
		// which never matches and left expired locks to the TTL index.
		if err := mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "expire_at": bson.M{"$lt": time.Now()}}); err != nil {
			return "", err
		}
		expireAt := time.Now().Add(duration)
		cv.ExpireAt = &expireAt
		if err := mongoutil.InsertMany[*model.Cache](ctx, x.coll, []*model.Cache{cv}); err != nil {
			// Duplicate key means someone else holds the lock: back off and retry.
			if mongo.IsDuplicateKeyError(err) {
				if err := wait(); err != nil {
					return "", err
				}
				continue
			}
			return "", err
		}
		return cv.Value, nil
	}
}
// Unlock releases the lock named key, but only when value matches the
// owner token returned by Lock — preventing one client from releasing
// another client's lock.
func (x *CacheMgo) Unlock(ctx context.Context, key string, value string) error {
	return mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "value": value})
}

@ -0,0 +1,99 @@
// Copyright © 2023 OpenIM open source community. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/tools/errs"
)
// NewClientConfig opens the "config" collection and ensures the unique
// compound index on (key, user_id) that backs per-user configuration
// entries.
func NewClientConfig(db *mongo.Database) (database.ClientConfig, error) {
	coll := db.Collection("config")
	_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
		{
			Keys: bson.D{
				{Key: "key", Value: 1},
				{Key: "user_id", Value: 1},
			},
			Options: options.Index().SetUnique(true),
		},
	})
	if err != nil {
		return nil, errs.Wrap(err)
	}
	return &ClientConfig{
		coll: coll,
	}, nil
}
// ClientConfig is the MongoDB implementation of database.ClientConfig.
type ClientConfig struct {
	coll *mongo.Collection
}
// Set upserts every key/value pair in config for userID. An empty map is a
// no-op. Each pair is written with its own update, so there is no
// cross-key atomicity.
func (x *ClientConfig) Set(ctx context.Context, userID string, config map[string]string) error {
	if len(config) == 0 {
		return nil
	}
	upsert := options.Update().SetUpsert(true)
	for k, v := range config {
		filter := bson.M{"key": k, "user_id": userID}
		change := bson.M{"$set": bson.M{"value": v}}
		if err := mongoutil.UpdateOne(ctx, x.coll, filter, change, false, upsert); err != nil {
			return err
		}
	}
	return nil
}
// Get returns every configuration entry stored for userID as a key/value map.
func (x *ClientConfig) Get(ctx context.Context, userID string) (map[string]string, error) {
	docs, err := mongoutil.Find[*model.ClientConfig](ctx, x.coll, bson.M{"user_id": userID})
	if err != nil {
		return nil, err
	}
	result := make(map[string]string)
	for _, doc := range docs {
		result[doc.Key] = doc.Value
	}
	return result, nil
}
// Del removes the given keys for userID; an empty key slice is a no-op.
func (x *ClientConfig) Del(ctx context.Context, userID string, keys []string) error {
	if len(keys) == 0 {
		return nil
	}
	return mongoutil.DeleteMany(ctx, x.coll, bson.M{"key": bson.M{"$in": keys}, "user_id": userID})
}
// GetPage pages through configuration entries. An empty userID or key
// omits that filter, so GetPage(ctx, "", "", p) lists every entry.
func (x *ClientConfig) GetPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) {
	filter := bson.M{}
	if userID != "" {
		filter["user_id"] = userID
	}
	if key != "" {
		filter["key"] = key
	}
	return mongoutil.FindPage[*model.ClientConfig](ctx, x.coll, filter, pagination)
}

@ -0,0 +1,7 @@
package model
// ClientConfig is one configuration entry; documents are unique per
// (key, user_id) pair (see the index created in mgo.NewClientConfig).
type ClientConfig struct {
	Key    string `bson:"key"`     // configuration key
	UserID string `bson:"user_id"` // owning user; may be empty (unfiltered in GetPage)
	Value  string `bson:"value"`   // configuration value
}

@ -16,6 +16,7 @@ package rpccache
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/open-im-server/v3/pkg/localcache"
@ -153,6 +154,26 @@ func (c *ConversationLocalCache) getConversationNotReceiveMessageUserIDs(ctx con
}))
}
func (c *ConversationLocalCache) getPinnedConversationIDs(ctx context.Context, userID string) (val []string, err error) {
log.ZDebug(ctx, "ConversationLocalCache getPinnedConversations req", "userID", userID)
defer func() {
if err == nil {
log.ZDebug(ctx, "ConversationLocalCache getPinnedConversations return", "userID", userID, "value", val)
} else {
log.ZError(ctx, "ConversationLocalCache getPinnedConversations return", err, "userID", userID)
}
}()
var cache cacheProto[pbconversation.GetPinnedConversationIDsResp]
resp, err := cache.Unmarshal(c.local.Get(ctx, cachekey.GetPinnedConversationIDs(userID), func(ctx context.Context) ([]byte, error) {
log.ZDebug(ctx, "ConversationLocalCache getConversationNotReceiveMessageUserIDs rpc", "userID", userID)
return cache.Marshal(c.client.ConversationClient.GetPinnedConversationIDs(ctx, &pbconversation.GetPinnedConversationIDsReq{UserID: userID}))
}))
if err != nil {
return nil, err
}
return resp.ConversationIDs, nil
}
func (c *ConversationLocalCache) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) {
res, err := c.getConversationNotReceiveMessageUserIDs(ctx, conversationID)
if err != nil {
@ -168,3 +189,7 @@ func (c *ConversationLocalCache) GetConversationNotReceiveMessageUserIDMap(ctx c
}
return datautil.SliceSet(res.UserIDs), nil
}
func (c *ConversationLocalCache) GetPinnedConversationIDs(ctx context.Context, userID string) ([]string, error) {
return c.getPinnedConversationIDs(ctx, userID)
}

@ -0,0 +1,19 @@
# Stress Test V2
## Usage
You need to set the `TestTargetUserList` variable before building.
### Build
```bash
go build -o test/stress-test-v2/stress-test-v2 test/stress-test-v2/main.go
```
### Execute
```bash
test/stress-test-v2/stress-test-v2 -c config/
```

@ -0,0 +1,759 @@
package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/apistruct"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/protocol/auth"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/group"
"github.com/openimsdk/protocol/sdkws"
pbuser "github.com/openimsdk/protocol/user"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/system/program"
)
// 1. Create 100K New Users
// 2. Create 100 100K Groups
// 3. Create 1000 999 Groups
// 4. Send message to 100K Groups every second
// 5. Send message to 999 Groups every minute
var (
	// TestTargetUserList holds pre-existing user IDs that every created
	// group must contain. It MUST be filled in before running the test.
	TestTargetUserList = []string{
		// "<need-update-it>",
	}
	// DefaultGroupID = "<need-update-it>" // Use default group ID for testing, need to be created.
)
var (
	// ApiAddress is the base URL of the OpenIM API server; set in main().
	ApiAddress string

	// API method paths, appended to ApiAddress when building requests.
	GetAdminToken = "/auth/get_admin_token"
	UserCheck = "/user/account_check"
	CreateUser = "/user/user_register"
	ImportFriend = "/friend/import_friend"
	InviteToGroup = "/group/invite_user_to_group"
	GetGroupMemberInfo = "/group/get_group_members_info"
	SendMsg = "/msg/send_msg"
	CreateGroup = "/group/create_group"
	GetUserToken = "/auth/user_token"
)
const (
	MaxUser = 100000 // total users to register
	Max1kUser = 1000 // users per "999" group
	Max100KGroup = 100 // number of 100K-member groups
	Max999Group = 1000 // number of 999-member groups
	MaxInviteUserLimit = 999 // max members per invite request

	CreateUserTicker = 1 * time.Second
	CreateGroupTicker = 1 * time.Second
	Create100KGroupTicker = 1 * time.Second
	Create999GroupTicker = 1 * time.Second
	SendMsgTo100KGroupTicker = 1 * time.Second // message cadence per 100K group
	SendMsgTo999GroupTicker = 1 * time.Minute // message cadence per 999 group
)
// BaseResp is the standard OpenIM API response envelope; Data is kept raw
// so each caller can decode its own payload type.
type BaseResp struct {
	ErrCode int `json:"errCode"`
	ErrMsg string `json:"errMsg"`
	Data json.RawMessage `json:"data"`
}
// StressTest carries all shared state of the test run: configuration,
// admin credentials, progress counters, and the coordination primitives
// used by the worker goroutines.
type StressTest struct {
	Conf *conf
	AdminUserID string
	AdminToken string
	DefaultGroupID string
	DefaultUserID string // sender used for all test messages

	UserCounter int // users successfully registered
	CreateUserCounter int // user IDs generated
	Create100kGroupCounter int // guarded by Mutex
	Create999GroupCounter int // guarded by Mutex
	MsgCounter int

	CreatedUsers []string
	CreatedGroups []string

	Mutex sync.Mutex
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// kept here for test-harness simplicity.
	Ctx context.Context
	Cancel context.CancelFunc
	HttpClient *http.Client

	Wg sync.WaitGroup
	Once sync.Once // records the first created user as DefaultUserID
}
// conf bundles the loaded configuration files used by the test: Share
// provides the admin secret and admin user IDs, Api the API listen ports.
type conf struct {
	Share config.Share
	Api config.API
}
// initConfig loads the share and API configuration files from configDir.
func initConfig(configDir string) (*config.Share, *config.API, error) {
	share := &config.Share{}
	if err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share); err != nil {
		return nil, nil, err
	}
	apiConfig := &config.API{}
	if err := config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig); err != nil {
		return nil, nil, err
	}
	return share, apiConfig, nil
}
// PostRequest marshals reqbody, POSTs it to url with the admin credentials
// attached, and returns the "data" payload of the standard response
// envelope. A non-zero errCode is converted into an error.
func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) {
	// Marshal body
	jsonBody, err := json.Marshal(reqbody)
	if err != nil {
		log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody)
		return nil, err
	}

	// Bind the request to ctx so cancelling the test aborts in-flight calls
	// (the original built the request without the context).
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("operationID", st.AdminUserID)
	if st.AdminToken != "" {
		req.Header.Set("token", st.AdminToken)
	}

	// log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken)

	resp, err := st.HttpClient.Do(req)
	if err != nil {
		log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody)
		return nil, err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		log.ZError(ctx, "Failed to read response body", err, "url", url)
		return nil, err
	}

	var baseResp BaseResp
	if err := json.Unmarshal(respBody, &baseResp); err != nil {
		log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody))
		return nil, err
	}

	if baseResp.ErrCode != 0 {
		// Use a constant format string: fmt.Errorf(baseResp.ErrMsg) trips
		// `go vet` (non-constant format string) and would misinterpret any
		// '%' in the server-provided message.
		err = fmt.Errorf("%s", baseResp.ErrMsg)
		// log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp)
		return nil, err
	}

	return baseResp.Data, nil
}
// GetAdminToken exchanges the configured shared secret for an admin token
// used to authenticate all subsequent API calls.
func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) {
	req := auth.GetAdminTokenReq{
		Secret: st.Conf.Share.Secret,
		UserID: st.AdminUserID,
	}

	resp, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &req)
	if err != nil {
		return "", err
	}

	data := &auth.GetAdminTokenResp{}
	if err := json.Unmarshal(resp, &data); err != nil {
		return "", err
	}

	return data.Token, nil
}
// CheckUser asks the account-check API which of userIDs already exist and
// returns only the IDs that are still unregistered.
func (st *StressTest) CheckUser(ctx context.Context, userIDs []string) ([]string, error) {
	req := pbuser.AccountCheckReq{
		CheckUserIDs: userIDs,
	}

	resp, err := st.PostRequest(ctx, ApiAddress+UserCheck, &req)
	if err != nil {
		return nil, err
	}

	data := &pbuser.AccountCheckResp{}
	if err := json.Unmarshal(resp, &data); err != nil {
		return nil, err
	}

	// Keep only the IDs reported as unregistered.
	unRegisteredUserIDs := make([]string, 0)
	for _, res := range data.Results {
		if res.AccountStatus == constant.UnRegistered {
			unRegisteredUserIDs = append(unRegisteredUserIDs, res.UserID)
		}
	}

	return unRegisteredUserIDs, nil
}
// CreateUser registers a single user whose nickname equals its ID and
// increments UserCounter on success.
// NOTE(review): UserCounter++ is not Mutex-guarded — only safe while this
// method is called from a single goroutine.
func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) {
	user := &sdkws.UserInfo{
		UserID: userID,
		Nickname: userID,
	}

	req := pbuser.UserRegisterReq{
		Users: []*sdkws.UserInfo{user},
	}

	_, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req)
	if err != nil {
		return "", err
	}

	st.UserCounter++
	return userID, nil
}
// CreateUserBatch registers, in one user_register call, the users in
// userIDs that do not exist yet. The first non-empty batch also records
// userIDs[0] as the default message sender (guarded by st.Once).
func (st *StressTest) CreateUserBatch(ctx context.Context, userIDs []string) error {
	// The method can import a large number of users at once.
	if len(userIDs) == 0 {
		// Guard: the deferred Once below indexes userIDs[0], which would
		// panic on an empty slice.
		return nil
	}
	var userList []*sdkws.UserInfo

	defer st.Once.Do(
		func() {
			st.DefaultUserID = userIDs[0]
			fmt.Println("Default Send User Created ID:", st.DefaultUserID)
		})

	// Only register the IDs that are not yet registered.
	needUserIDs, err := st.CheckUser(ctx, userIDs)
	if err != nil {
		return err
	}

	for _, userID := range needUserIDs {
		user := &sdkws.UserInfo{
			UserID:   userID,
			Nickname: userID,
		}
		userList = append(userList, user)
	}

	req := pbuser.UserRegisterReq{
		Users: userList,
	}
	_, err = st.PostRequest(ctx, ApiAddress+CreateUser, &req)
	if err != nil {
		return err
	}

	st.UserCounter += len(userList)
	return nil
}
// GetGroupMembersInfo filters userIDs down to the users that are NOT yet
// members of groupID, querying the membership API in chunks of at most 500
// IDs. In the chunked path, failed chunks are logged and skipped — their
// users are silently omitted from the result.
func (st *StressTest) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]string, error) {
	needInviteUserIDs := make([]string, 0)

	const maxBatchSize = 500
	if len(userIDs) > maxBatchSize {
		for i := 0; i < len(userIDs); i += maxBatchSize {
			end := min(i+maxBatchSize, len(userIDs))
			batchUserIDs := userIDs[i:end]

			// log.ZInfo(ctx, "Processing group members batch", "groupID", groupID, "batch", i/maxBatchSize+1,
			// "batchUserCount", len(batchUserIDs))

			// Process a single batch
			batchReq := group.GetGroupMembersInfoReq{
				GroupID: groupID,
				UserIDs: batchUserIDs,
			}

			resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &batchReq)
			if err != nil {
				log.ZError(ctx, "Batch query failed", err, "batch", i/maxBatchSize+1)
				continue
			}

			data := &group.GetGroupMembersInfoResp{}
			if err := json.Unmarshal(resp, &data); err != nil {
				log.ZError(ctx, "Failed to parse batch response", err, "batch", i/maxBatchSize+1)
				continue
			}

			// Process the batch results: collect current members, then keep
			// the requested IDs that are not among them.
			existingMembers := make(map[string]bool)
			for _, member := range data.Members {
				existingMembers[member.UserID] = true
			}

			for _, userID := range batchUserIDs {
				if !existingMembers[userID] {
					needInviteUserIDs = append(needInviteUserIDs, userID)
				}
			}
		}
		return needInviteUserIDs, nil
	}

	// Small input: a single direct query. Unlike the chunked path above, an
	// API error here is returned to the caller.
	req := group.GetGroupMembersInfoReq{
		GroupID: groupID,
		UserIDs: userIDs,
	}

	resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &req)
	if err != nil {
		return nil, err
	}

	data := &group.GetGroupMembersInfoResp{}
	if err := json.Unmarshal(resp, &data); err != nil {
		return nil, err
	}

	existingMembers := make(map[string]bool)
	for _, member := range data.Members {
		existingMembers[member.UserID] = true
	}

	for _, userID := range userIDs {
		if !existingMembers[userID] {
			needInviteUserIDs = append(needInviteUserIDs, userID)
		}
	}

	return needInviteUserIDs, nil
}
// InviteToGroup invites userIDs into groupID via the invite API.
func (st *StressTest) InviteToGroup(ctx context.Context, groupID string, userIDs []string) error {
	req := group.InviteUserToGroupReq{
		GroupID:        groupID,
		InvitedUserIDs: userIDs,
	}
	if _, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &req); err != nil {
		return err
	}
	return nil
}
// SendMsg posts a text message containing the current timestamp from
// userID into groupID, then bumps MsgCounter.
// NOTE(review): MsgCounter++ is not Mutex-guarded although SendMsg is
// invoked from many goroutines in main — the counter may under-count.
func (st *StressTest) SendMsg(ctx context.Context, userID string, groupID string) error {
	contentObj := map[string]any{
		// "content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")),
		"content": fmt.Sprintf("The current time is %s", time.Now().Format("2006-01-02 15:04:05.000")),
	}

	req := &apistruct.SendMsgReq{
		SendMsg: apistruct.SendMsg{
			SendID: userID,
			SenderNickname: userID,
			GroupID: groupID,
			ContentType: constant.Text,
			SessionType: constant.ReadGroupChatType,
			Content: contentObj,
		},
	}

	_, err := st.PostRequest(ctx, ApiAddress+SendMsg, &req)
	if err != nil {
		log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &req)
		return err
	}

	st.MsgCounter++
	return nil
}
// CreateGroup creates a working group owned by userID with the given
// member list and returns the created group's ID.
// Max userIDs number is 1000.
func (st *StressTest) CreateGroup(ctx context.Context, groupID string, userID string, userIDsList []string) (string, error) {
	groupInfo := &sdkws.GroupInfo{
		GroupID:   groupID,
		GroupName: groupID,
		GroupType: constant.WorkingGroup,
	}

	req := group.CreateGroupReq{
		OwnerUserID:   userID,
		MemberUserIDs: userIDsList,
		GroupInfo:     groupInfo,
	}

	resp := group.CreateGroupResp{}

	response, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &req)
	if err != nil {
		return "", err
	}

	if err := json.Unmarshal(response, &resp); err != nil {
		return "", err
	}

	// Guard against a malformed success payload: dereferencing a nil
	// GroupInfo would panic inside the stress-test goroutines.
	if resp.GroupInfo == nil {
		return "", fmt.Errorf("create group %s: response missing groupInfo", groupID)
	}

	// st.GroupCounter++

	return resp.GroupInfo.GroupID, nil
}
// main drives the v2 stress test: it registers MaxUser users in batches,
// concurrently creates Max100KGroup large groups and Max999Group small
// groups (inviting the created users), then sends messages to every group
// on a ticker until the process is interrupted.
func main() {
	var configPath string
	// defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config")
	// flag.StringVar(&configPath, "c", defaultConfigDir, "config path")
	flag.StringVar(&configPath, "c", "", "config path")
	flag.Parse()

	if configPath == "" {
		_, _ = fmt.Fprintln(os.Stderr, "config path is empty")
		os.Exit(1)
		return
	}

	fmt.Printf(" Config Path: %s\n", configPath)

	share, apiConfig, err := initConfig(configPath)
	if err != nil {
		program.ExitWithError(err)
		return
	}

	// The API server is assumed to listen on localhost at the first
	// configured API port.
	ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0]))

	ctx, cancel := context.WithCancel(context.Background())
	// ch := make(chan struct{})

	st := &StressTest{
		Conf: &conf{
			Share: *share,
			Api: *apiConfig,
		},
		AdminUserID: share.IMAdminUserID[0],
		Ctx: ctx,
		Cancel: cancel,
		HttpClient: &http.Client{
			Timeout: 50 * time.Second,
		},
	}

	// On Ctrl-C / SIGTERM: cancel the test context, and force-exit from a
	// secondary goroutine so a stuck worker cannot block shutdown.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-c
		fmt.Println("\nReceived stop signal, stopping...")
		go func() {
			// time.Sleep(5 * time.Second)
			fmt.Println("Force exit")
			os.Exit(0)
		}()
		st.Cancel()
	}()

	token, err := st.GetAdminToken(st.Ctx)
	if err != nil {
		// NOTE(review): execution continues with an empty token here, so
		// subsequent requests would fail auth — consider exiting instead.
		log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID)
	}

	st.AdminToken = token
	fmt.Println("Admin Token:", st.AdminToken)
	fmt.Println("ApiAddress:", ApiAddress)

	// Pre-compute the deterministic test user ID list.
	for i := range MaxUser {
		userID := fmt.Sprintf("v2_StressTest_User_%d", i)
		st.CreatedUsers = append(st.CreatedUsers, userID)
		st.CreateUserCounter++
	}

	// err = st.CreateUserBatch(st.Ctx, st.CreatedUsers)
	// if err != nil {
	// log.ZError(ctx, "Create user failed.", err)
	// }

	const batchSize = 1000
	totalUsers := len(st.CreatedUsers)
	successCount := 0

	if st.DefaultUserID == "" && len(st.CreatedUsers) > 0 {
		st.DefaultUserID = st.CreatedUsers[0]
	}

	// Register users in fixed-size batches; failed batches are logged and
	// skipped rather than aborting the run.
	for i := 0; i < totalUsers; i += batchSize {
		end := min(i+batchSize, totalUsers)

		userBatch := st.CreatedUsers[i:end]

		log.ZInfo(st.Ctx, "Creating user batch", "batch", i/batchSize+1, "count", len(userBatch))
		err = st.CreateUserBatch(st.Ctx, userBatch)
		if err != nil {
			log.ZError(st.Ctx, "Batch user creation failed", err, "batch", i/batchSize+1)
		} else {
			successCount += len(userBatch)
			log.ZInfo(st.Ctx, "Batch user creation succeeded", "batch", i/batchSize+1,
				"progress", fmt.Sprintf("%d/%d", successCount, totalUsers))
		}
	}

	// Execute create 100k group
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()

		create100kGroupTicker := time.NewTicker(Create100KGroupTicker)
		defer create100kGroupTicker.Stop()

		for i := range Max100KGroup {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create 100K Group")
				return

			case <-create100kGroupTicker.C:
				// Create 100K groups
				st.Wg.Add(1)
				go func(idx int) {
					startTime := time.Now()
					defer func() {
						elapsedTime := time.Since(startTime)
						log.ZInfo(st.Ctx, "100K group creation completed",
							"groupID", fmt.Sprintf("v2_StressTest_Group_100K_%d", idx),
							"index", idx,
							"duration", elapsedTime.String())
					}()
					defer st.Wg.Done()
					defer func() {
						st.Mutex.Lock()
						st.Create100kGroupCounter++
						st.Mutex.Unlock()
					}()

					groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", idx)
					// NOTE(review): assigns the shared outer err from multiple
					// goroutines — racy; should use a goroutine-local err.
					if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil {
						log.ZError(st.Ctx, "Create group failed.", err)
						// continue
					}

					// Invite all users in chunks of at most MaxInviteUserLimit,
					// always prepending the mandatory TestTargetUserList.
					for i := 0; i <= MaxUser/MaxInviteUserLimit; i++ {
						InviteUserIDs := make([]string, 0)
						// ensure TargetUserList is in group
						InviteUserIDs = append(InviteUserIDs, TestTargetUserList...)

						startIdx := max(i*MaxInviteUserLimit, 1)
						endIdx := min((i+1)*MaxInviteUserLimit, MaxUser)
						for j := startIdx; j < endIdx; j++ {
							userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j)
							InviteUserIDs = append(InviteUserIDs, userCreatedID)
						}

						if len(InviteUserIDs) == 0 {
							// log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}

						// Drop users that are already members of the group.
						InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs)
						if err != nil {
							log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID)
							continue
						}

						if len(InviteUserIDs) == 0 {
							// log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}

						// Invite To Group
						if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil {
							log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs)
							continue
							// os.Exit(1)
							// return
						}
					}
				}(i)
			}
		}
	}()

	// create 999 groups
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()

		create999GroupTicker := time.NewTicker(Create999GroupTicker)
		defer create999GroupTicker.Stop()

		for i := range Max999Group {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create 999 Group")
				return

			case <-create999GroupTicker.C:
				// Create 999 groups
				st.Wg.Add(1)
				go func(idx int) {
					startTime := time.Now()
					defer func() {
						elapsedTime := time.Since(startTime)
						log.ZInfo(st.Ctx, "999 group creation completed",
							"groupID", fmt.Sprintf("v2_StressTest_Group_1K_%d", idx),
							"index", idx,
							"duration", elapsedTime.String())
					}()
					defer st.Wg.Done()
					defer func() {
						st.Mutex.Lock()
						st.Create999GroupCounter++
						st.Mutex.Unlock()
					}()

					groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", idx)
					// NOTE(review): same shared-err race as the 100K path.
					if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil {
						log.ZError(st.Ctx, "Create group failed.", err)
						// continue
					}

					for i := 0; i <= Max1kUser/MaxInviteUserLimit; i++ {
						InviteUserIDs := make([]string, 0)
						// ensure TargetUserList is in group
						InviteUserIDs = append(InviteUserIDs, TestTargetUserList...)

						startIdx := max(i*MaxInviteUserLimit, 1)
						endIdx := min((i+1)*MaxInviteUserLimit, Max1kUser)
						for j := startIdx; j < endIdx; j++ {
							userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j)
							InviteUserIDs = append(InviteUserIDs, userCreatedID)
						}

						if len(InviteUserIDs) == 0 {
							// log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}

						InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs)
						if err != nil {
							log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID)
							continue
						}

						if len(InviteUserIDs) == 0 {
							// log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}

						// Invite To Group
						if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil {
							log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs)
							continue
							// os.Exit(1)
							// return
						}
					}
				}(i)
			}
		}
	}()

	// Send message to 100K groups
	st.Wg.Wait()
	fmt.Println("All groups created successfully, starting to send messages...")
	log.ZInfo(ctx, "All groups created successfully, starting to send messages...")

	var groups100K []string
	var groups999 []string

	// Rebuild the deterministic group ID lists for the send phase.
	for i := range Max100KGroup {
		groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", i)
		groups100K = append(groups100K, groupID)
	}

	for i := range Max999Group {
		groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", i)
		groups999 = append(groups999, groupID)
	}

	// Bounded-concurrency semaphores for the two send loops.
	send100kGroupLimiter := make(chan struct{}, 20)
	send999GroupLimiter := make(chan struct{}, 100)

	// execute Send message to 100K groups
	go func() {
		ticker := time.NewTicker(SendMsgTo100KGroupTicker)
		defer ticker.Stop()

		for {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Send Message to 100K Group")
				return
			case <-ticker.C:
				// Send message to 100K groups
				for _, groupID := range groups100K {
					send100kGroupLimiter <- struct{}{}
					go func(groupID string) {
						defer func() { <-send100kGroupLimiter }()
						if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil {
							log.ZError(st.Ctx, "Send message to 100K group failed.", err)
						}
					}(groupID)
				}
				// log.ZInfo(st.Ctx, "Send message to 100K groups successfully.")
			}
		}
	}()

	// execute Send message to 999 groups
	go func() {
		ticker := time.NewTicker(SendMsgTo999GroupTicker)
		defer ticker.Stop()

		for {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Send Message to 999 Group")
				return
			case <-ticker.C:
				// Send message to 999 groups
				for _, groupID := range groups999 {
					send999GroupLimiter <- struct{}{}
					go func(groupID string) {
						defer func() { <-send999GroupLimiter }()
						if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil {
							log.ZError(st.Ctx, "Send message to 999 group failed.", err)
						}
					}(groupID)
				}
				// log.ZInfo(st.Ctx, "Send message to 999 groups successfully.")
			}
		}
	}()

	// Block until the signal handler cancels the context.
	<-st.Ctx.Done()
	fmt.Println("Received signal to exit, shutting down...")
}

@ -0,0 +1,19 @@
# Stress Test
## Usage
You need to set the `TestTargetUserList` and `DefaultGroupID` variables before building.
### Build
```bash
go build -o test/stress-test/stress-test test/stress-test/main.go
```
### Execute
```bash
tools/stress-test/stress-test -c config/
```

@ -0,0 +1,458 @@
package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/apistruct"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/protocol/auth"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/group"
"github.com/openimsdk/protocol/relation"
"github.com/openimsdk/protocol/sdkws"
pbuser "github.com/openimsdk/protocol/user"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/system/program"
)
/*
1. Create one user every minute
2. Import target users as friends
3. Add users to the default group
4. Send a message to the default group every second, containing index and current timestamp
5. Create a new group every minute and invite target users to join
*/
// !!! ATTENTION: These variables must be set before running!
var (
// Use default userIDs List for testing, need to be created.
TestTargetUserList = []string{
"<need-update-it>",
}
DefaultGroupID = "<need-update-it>" // Use default group ID for testing, need to be created.
)
var (
ApiAddress string
// API method
GetAdminToken = "/auth/get_admin_token"
CreateUser = "/user/user_register"
ImportFriend = "/friend/import_friend"
InviteToGroup = "/group/invite_user_to_group"
SendMsg = "/msg/send_msg"
CreateGroup = "/group/create_group"
GetUserToken = "/auth/user_token"
)
const (
MaxUser = 10000
MaxGroup = 1000
CreateUserTicker = 1 * time.Minute // Ticker is 1min in create user
SendMessageTicker = 1 * time.Second // Ticker is 1s in send message
CreateGroupTicker = 1 * time.Minute
)
// BaseResp is the standard envelope wrapping every API response.
type BaseResp struct {
	ErrCode int             `json:"errCode"` // non-zero indicates failure
	ErrMsg  string          `json:"errMsg"`  // human-readable error text
	Data    json.RawMessage `json:"data"`    // endpoint-specific payload, decoded by the caller
}
// StressTest bundles the shared state of one stress-test run: configuration,
// admin identity/token, progress counters, created IDs, and the lifecycle
// primitives (context, cancel, wait group, once) used by the worker goroutines.
type StressTest struct {
	Conf           *conf
	AdminUserID    string
	AdminToken     string
	DefaultGroupID string // group that every new user is invited to
	DefaultUserID  string // first created user; used as message sender
	UserCounter  int // users created so far (also drives generated user IDs)
	GroupCounter int // groups created so far (also drives generated group IDs)
	MsgCounter   int // messages sent so far (embedded in message content)
	CreatedUsers  []string
	CreatedGroups []string
	Mutex      sync.Mutex // NOTE(review): declared but not used by the visible methods
	Ctx        context.Context
	Cancel     context.CancelFunc
	HttpClient *http.Client
	Wg         sync.WaitGroup
	Once       sync.Once // guards one-time capture of DefaultUserID
}
// conf pairs the two config files this tool needs.
type conf struct {
	Share config.Share
	Api   config.API
}
// initConfig loads the share and API configuration files from configDir
// and returns them; the first load failure aborts with that error.
func initConfig(configDir string) (*config.Share, *config.API, error) {
	share := &config.Share{}
	if err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share); err != nil {
		return nil, nil, err
	}
	apiConfig := &config.API{}
	if err := config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig); err != nil {
		return nil, nil, err
	}
	return share, apiConfig, nil
}
// PostRequest sends reqbody as a JSON POST to url with the admin identity
// headers attached, decodes the standard BaseResp envelope, and returns the
// raw `data` payload.
//
// It returns an error if marshalling, the HTTP round trip, reading or
// decoding the body fails, or if the server reports a non-zero errCode.
func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) {
	// Marshal body
	jsonBody, err := json.Marshal(reqbody)
	if err != nil {
		log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody)
		return nil, err
	}
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("operationID", st.AdminUserID)
	if st.AdminToken != "" {
		req.Header.Set("token", st.AdminToken)
	}
	// log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken)
	resp, err := st.HttpClient.Do(req)
	if err != nil {
		log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody)
		return nil, err
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		log.ZError(ctx, "Failed to read response body", err, "url", url)
		return nil, err
	}
	var baseResp BaseResp
	if err := json.Unmarshal(respBody, &baseResp); err != nil {
		log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody))
		return nil, err
	}
	if baseResp.ErrCode != 0 {
		// Use a constant format string: fmt.Errorf(baseResp.ErrMsg) would
		// misinterpret any '%' verbs inside the server-provided message
		// (and is flagged by `go vet`'s printf check).
		err = fmt.Errorf("%s", baseResp.ErrMsg)
		log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp)
		return nil, err
	}
	return baseResp.Data, nil
}
// GetAdminToken exchanges the configured admin secret for an admin token
// via the auth API and returns it.
func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) {
	tokenReq := auth.GetAdminTokenReq{
		Secret: st.Conf.Share.Secret,
		UserID: st.AdminUserID,
	}
	raw, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &tokenReq)
	if err != nil {
		return "", err
	}
	var tokenResp auth.GetAdminTokenResp
	if err := json.Unmarshal(raw, &tokenResp); err != nil {
		return "", err
	}
	return tokenResp.Token, nil
}
// CreateUser registers a single user whose nickname equals its userID and
// bumps UserCounter on success. It returns the created userID.
func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) {
	registerReq := pbuser.UserRegisterReq{
		Users: []*sdkws.UserInfo{{
			UserID:   userID,
			Nickname: userID,
		}},
	}
	if _, err := st.PostRequest(ctx, ApiAddress+CreateUser, &registerReq); err != nil {
		return "", err
	}
	st.UserCounter++
	return userID, nil
}
// ImportFriend imports the global TestTargetUserList as friends of userID.
func (st *StressTest) ImportFriend(ctx context.Context, userID string) error {
	importReq := relation.ImportFriendReq{
		OwnerUserID:   userID,
		FriendUserIDs: TestTargetUserList,
	}
	_, err := st.PostRequest(ctx, ApiAddress+ImportFriend, &importReq)
	return err
}
// InviteToGroup invites userID into the default test group.
func (st *StressTest) InviteToGroup(ctx context.Context, userID string) error {
	inviteReq := group.InviteUserToGroupReq{
		GroupID:        st.DefaultGroupID,
		InvitedUserIDs: []string{userID},
	}
	_, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &inviteReq)
	return err
}
// SendMsg sends one text message from userID to the default group. The
// body carries the running message index plus a timestamp; MsgCounter is
// incremented only after a successful send.
func (st *StressTest) SendMsg(ctx context.Context, userID string) error {
	body := map[string]any{
		"content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")),
	}
	msgReq := &apistruct.SendMsgReq{
		SendMsg: apistruct.SendMsg{
			SendID:         userID,
			SenderNickname: userID,
			GroupID:        st.DefaultGroupID,
			ContentType:    constant.Text,
			SessionType:    constant.ReadGroupChatType,
			Content:        body,
		},
	}
	if _, err := st.PostRequest(ctx, ApiAddress+SendMsg, &msgReq); err != nil {
		log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &msgReq)
		return err
	}
	st.MsgCounter++
	return nil
}
// CreateGroup creates a working group owned by userID whose initial members
// are the target test users. The generated group ID embeds the group counter
// and a timestamp. It returns the server-assigned group ID.
func (st *StressTest) CreateGroup(ctx context.Context, userID string) (string, error) {
	groupID := fmt.Sprintf("StressTestGroup_%d_%s", st.GroupCounter, time.Now().Format("20060102150405"))
	createReq := group.CreateGroupReq{
		OwnerUserID:   userID,
		MemberUserIDs: TestTargetUserList,
		GroupInfo: &sdkws.GroupInfo{
			GroupID:   groupID,
			GroupName: groupID,
			GroupType: constant.WorkingGroup,
		},
	}
	raw, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &createReq)
	if err != nil {
		return "", err
	}
	var createResp group.CreateGroupResp
	if err := json.Unmarshal(raw, &createResp); err != nil {
		return "", err
	}
	st.GroupCounter++
	return createResp.GroupInfo.GroupID, nil
}
// main drives the v1 stress test: it loads the config, fetches an admin
// token, then runs three ticker-driven goroutines — create a user each tick
// (importing friends and joining the default group), send a message to the
// default group each second, and create a new group each minute — until the
// user/group caps are reached or the process is interrupted.
func main() {
	var configPath string
	// defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config")
	// flag.StringVar(&configPath, "c", defaultConfigDir, "config path")
	flag.StringVar(&configPath, "c", "", "config path")
	flag.Parse()
	if configPath == "" {
		_, _ = fmt.Fprintln(os.Stderr, "config path is empty")
		os.Exit(1)
		return
	}
	fmt.Printf(" Config Path: %s\n", configPath)
	share, apiConfig, err := initConfig(configPath)
	if err != nil {
		program.ExitWithError(err)
		return
	}
	// The API server is assumed to listen on the first configured port on localhost.
	ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0]))
	ctx, cancel := context.WithCancel(context.Background())
	// ch is closed once the first user exists; the sender and group-creator
	// goroutines wait on it so they have a valid DefaultUserID.
	ch := make(chan struct{})
	defer cancel()
	st := &StressTest{
		Conf: &conf{
			Share: *share,
			Api:   *apiConfig,
		},
		AdminUserID: share.IMAdminUserID[0],
		Ctx:         ctx,
		Cancel:      cancel,
		HttpClient: &http.Client{
			Timeout: 50 * time.Second,
		},
	}
	// On Ctrl-C / SIGTERM: unblock any waiter on ch (guarded against a
	// double close) and cancel the shared context.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-c
		fmt.Println("\nReceived stop signal, stopping...")
		select {
		case <-ch:
		default:
			close(ch)
		}
		st.Cancel()
	}()
	// NOTE(review): a failed token fetch is only logged; the run continues
	// with an empty AdminToken and subsequent requests will fail auth.
	token, err := st.GetAdminToken(st.Ctx)
	if err != nil {
		log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID)
	}
	st.AdminToken = token
	fmt.Println("Admin Token:", st.AdminToken)
	fmt.Println("ApiAddress:", ApiAddress)
	st.DefaultGroupID = DefaultGroupID
	// Goroutine 1: create one user per tick, import friends, join the
	// default group; the first created user becomes the default sender.
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()
		ticker := time.NewTicker(CreateUserTicker)
		defer ticker.Stop()
		for st.UserCounter < MaxUser {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create user", "reason", "context done")
				return
			case <-ticker.C:
				// Create User
				userID := fmt.Sprintf("%d_Stresstest_%s", st.UserCounter, time.Now().Format("0102150405"))
				userCreatedID, err := st.CreateUser(st.Ctx, userID)
				if err != nil {
					log.ZError(st.Ctx, "Create User failed.", err, "UserID", userID)
					os.Exit(1)
					return
				}
				// fmt.Println("User Created ID:", userCreatedID)
				// Import Friend
				if err = st.ImportFriend(st.Ctx, userCreatedID); err != nil {
					log.ZError(st.Ctx, "Import Friend failed.", err, "UserID", userCreatedID)
					os.Exit(1)
					return
				}
				// Invite To Group
				if err = st.InviteToGroup(st.Ctx, userCreatedID); err != nil {
					log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", userCreatedID)
					os.Exit(1)
					return
				}
				// First success: record the default sender and release the
				// goroutines blocked on ch.
				st.Once.Do(func() {
					st.DefaultUserID = userCreatedID
					fmt.Println("Default Send User Created ID:", userCreatedID)
					close(ch)
				})
			}
		}
	}()
	// Goroutine 2: once the first user exists, send one message per tick to
	// the default group; send failures are logged and retried next tick.
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()
		ticker := time.NewTicker(SendMessageTicker)
		defer ticker.Stop()
		<-ch
		for {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Send message", "reason", "context done")
				return
			case <-ticker.C:
				// Send Message
				if err = st.SendMsg(st.Ctx, st.DefaultUserID); err != nil {
					log.ZError(st.Ctx, "Send Message failed.", err, "UserID", st.DefaultUserID)
					continue
				}
			}
		}
	}()
	// Goroutine 3: once the first user exists, create one group per tick
	// owned by the default user until MaxGroup is reached.
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()
		ticker := time.NewTicker(CreateGroupTicker)
		defer ticker.Stop()
		<-ch
		for st.GroupCounter < MaxGroup {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create Group", "reason", "context done")
				return
			case <-ticker.C:
				// Create Group
				_, err := st.CreateGroup(st.Ctx, st.DefaultUserID)
				if err != nil {
					log.ZError(st.Ctx, "Create Group failed.", err, "UserID", st.DefaultUserID)
					os.Exit(1)
					return
				}
				// fmt.Println("Group Created ID:", groupID)
			}
		}
	}()
	st.Wg.Wait()
}

@ -4,6 +4,11 @@ import (
"context"
"errors"
"fmt"
"log"
"net/http"
"path/filepath"
"time"
"github.com/mitchellh/mapstructure"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis"
@ -19,10 +24,6 @@ import (
"github.com/openimsdk/tools/s3/oss"
"github.com/spf13/viper"
"go.mongodb.org/mongo-driver/mongo"
"log"
"net/http"
"path/filepath"
"time"
)
const defaultTimeout = time.Second * 10
@ -159,7 +160,7 @@ func doObject(db database.ObjectInfo, newS3, oldS3 s3.Interface, skip int) (*Res
if err != nil {
return nil, err
}
putURL, err := newS3.PresignedPutObject(ctx, obj.Key, time.Hour)
putURL, err := newS3.PresignedPutObject(ctx, obj.Key, time.Hour, &s3.PutOption{ContentType: obj.ContentType})
if err != nil {
return nil, err
}
@ -176,7 +177,7 @@ func doObject(db database.ObjectInfo, newS3, oldS3 s3.Interface, skip int) (*Res
return nil, fmt.Errorf("download object failed %s", downloadResp.Status)
}
log.Printf("file size %d", obj.Size)
request, err := http.NewRequest(http.MethodPut, putURL, downloadResp.Body)
request, err := http.NewRequest(http.MethodPut, putURL.URL, downloadResp.Body)
if err != nil {
return nil, err
}

@ -337,7 +337,7 @@ func SetVersion(coll *mongo.Collection, key string, version int) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
option := options.Update().SetUpsert(true)
filter := bson.M{"key": key, "value": strconv.Itoa(version)}
filter := bson.M{"key": key}
update := bson.M{"$set": bson.M{"key": key, "value": strconv.Itoa(version)}}
return mongoutil.UpdateOne(ctx, coll, filter, update, false, option)
}

@ -0,0 +1,736 @@
package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/apistruct"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/protocol/auth"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/group"
"github.com/openimsdk/protocol/sdkws"
pbuser "github.com/openimsdk/protocol/user"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/system/program"
)
// 1. Create 100K New Users
// 2. Create 100 100K Groups
// 3. Create 1000 999 Groups
// 4. Send message to 100K Groups every second
// 5. Send message to 999 Groups every minute
var (
// Use default userIDs List for testing, need to be created.
TestTargetUserList = []string{
// "<need-update-it>",
}
// DefaultGroupID = "<need-update-it>" // Use default group ID for testing, need to be created.
)
var (
ApiAddress string
// API method
GetAdminToken = "/auth/get_admin_token"
UserCheck = "/user/account_check"
CreateUser = "/user/user_register"
ImportFriend = "/friend/import_friend"
InviteToGroup = "/group/invite_user_to_group"
GetGroupMemberInfo = "/group/get_group_members_info"
SendMsg = "/msg/send_msg"
CreateGroup = "/group/create_group"
GetUserToken = "/auth/user_token"
)
const (
MaxUser = 100000
Max100KGroup = 100
Max999Group = 1000
MaxInviteUserLimit = 999
CreateUserTicker = 1 * time.Second
CreateGroupTicker = 1 * time.Second
Create100KGroupTicker = 1 * time.Second
Create999GroupTicker = 1 * time.Second
SendMsgTo100KGroupTicker = 1 * time.Second
SendMsgTo999GroupTicker = 1 * time.Minute
)
// BaseResp is the standard envelope wrapping every API response.
type BaseResp struct {
	ErrCode int             `json:"errCode"` // non-zero indicates failure
	ErrMsg  string          `json:"errMsg"`  // human-readable error text
	Data    json.RawMessage `json:"data"`    // endpoint-specific payload, decoded by the caller
}
// StressTest bundles the shared state of one v2 stress-test run: config,
// admin identity/token, per-phase counters, generated IDs, and the
// lifecycle primitives used by the worker goroutines.
type StressTest struct {
	Conf           *conf
	AdminUserID    string
	AdminToken     string
	DefaultGroupID string
	DefaultUserID  string // first generated user; owner/sender for all groups
	UserCounter            int // users actually registered on the server
	CreateUserCounter      int // user IDs generated locally
	Create100kGroupCounter int // large (100K-member) groups processed
	Create999GroupCounter  int // small (999-member) groups processed
	MsgCounter             int // messages sent
	CreatedUsers  []string
	CreatedGroups []string
	Mutex      sync.Mutex // NOTE(review): declared but not used by the visible methods
	Ctx        context.Context
	Cancel     context.CancelFunc
	HttpClient *http.Client
	Wg         sync.WaitGroup
	Once       sync.Once // guards one-time capture of DefaultUserID
}
// conf pairs the two config files this tool needs.
type conf struct {
	Share config.Share
	Api   config.API
}
// initConfig loads the share and API configuration files from configDir,
// returning both or the first load error encountered.
func initConfig(configDir string) (*config.Share, *config.API, error) {
	var (
		share     = &config.Share{}
		apiConfig = &config.API{}
	)
	err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share)
	if err != nil {
		return nil, nil, err
	}
	err = config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig)
	if err != nil {
		return nil, nil, err
	}
	return share, apiConfig, nil
}
// PostRequest sends reqbody as a JSON POST to url with the admin identity
// headers attached, decodes the standard BaseResp envelope, and returns the
// raw `data` payload.
//
// It returns an error if marshalling, the HTTP round trip, reading or
// decoding the body fails, or if the server reports a non-zero errCode.
func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) {
	// Marshal body
	jsonBody, err := json.Marshal(reqbody)
	if err != nil {
		log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody)
		return nil, err
	}
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("operationID", st.AdminUserID)
	if st.AdminToken != "" {
		req.Header.Set("token", st.AdminToken)
	}
	// log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken)
	resp, err := st.HttpClient.Do(req)
	if err != nil {
		log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody)
		return nil, err
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		log.ZError(ctx, "Failed to read response body", err, "url", url)
		return nil, err
	}
	var baseResp BaseResp
	if err := json.Unmarshal(respBody, &baseResp); err != nil {
		log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody))
		return nil, err
	}
	if baseResp.ErrCode != 0 {
		// Use a constant format string: fmt.Errorf(baseResp.ErrMsg) would
		// misinterpret any '%' verbs inside the server-provided message
		// (and is flagged by `go vet`'s printf check).
		err = fmt.Errorf("%s", baseResp.ErrMsg)
		log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp)
		return nil, err
	}
	return baseResp.Data, nil
}
// GetAdminToken exchanges the configured admin secret for an admin token
// via the auth API and returns it.
func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) {
	req := auth.GetAdminTokenReq{
		Secret: st.Conf.Share.Secret,
		UserID: st.AdminUserID,
	}
	resp, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &req)
	if err != nil {
		return "", err
	}
	data := &auth.GetAdminTokenResp{}
	// &data is a **GetAdminTokenResp; encoding/json follows the double
	// indirection correctly.
	if err := json.Unmarshal(resp, &data); err != nil {
		return "", err
	}
	return data.Token, nil
}
// CheckUser asks the account-check API which of userIDs are not yet
// registered and returns that subset.
func (st *StressTest) CheckUser(ctx context.Context, userIDs []string) ([]string, error) {
	checkReq := pbuser.AccountCheckReq{
		CheckUserIDs: userIDs,
	}
	raw, err := st.PostRequest(ctx, ApiAddress+UserCheck, &checkReq)
	if err != nil {
		return nil, err
	}
	var checkResp pbuser.AccountCheckResp
	if err := json.Unmarshal(raw, &checkResp); err != nil {
		return nil, err
	}
	unregistered := make([]string, 0)
	for _, result := range checkResp.Results {
		if result.AccountStatus == constant.UnRegistered {
			unregistered = append(unregistered, result.UserID)
		}
	}
	return unregistered, nil
}
// CreateUser registers a single user whose nickname equals its ID and
// increments UserCounter on success; returns the created userID.
func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) {
	user := &sdkws.UserInfo{
		UserID:   userID,
		Nickname: userID,
	}
	req := pbuser.UserRegisterReq{
		Users: []*sdkws.UserInfo{user},
	}
	_, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req)
	if err != nil {
		return "", err
	}
	st.UserCounter++
	return userID, nil
}
// CreateUserBatch registers every user in userIDs that does not already
// exist. It first asks the server which IDs are unregistered (CheckUser)
// and then registers only those in a single request, so re-running is safe.
// The first ID of the first batch processed becomes the default sender.
func (st *StressTest) CreateUserBatch(ctx context.Context, userIDs []string) error {
	// Guard: the deferred Once below indexes userIDs[0]; an empty batch
	// would otherwise panic with index out of range.
	if len(userIDs) == 0 {
		return nil
	}
	// Record the default sending user exactly once, whichever batch runs first.
	defer st.Once.Do(
		func() {
			st.DefaultUserID = userIDs[0]
			fmt.Println("Default Send User Created ID:", st.DefaultUserID)
		})
	needUserIDs, err := st.CheckUser(ctx, userIDs)
	if err != nil {
		return err
	}
	// All IDs already exist on the server: skip the empty register request.
	if len(needUserIDs) == 0 {
		return nil
	}
	userList := make([]*sdkws.UserInfo, 0, len(needUserIDs))
	for _, userID := range needUserIDs {
		userList = append(userList, &sdkws.UserInfo{
			UserID:   userID,
			Nickname: userID,
		})
	}
	req := pbuser.UserRegisterReq{
		Users: userList,
	}
	if _, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req); err != nil {
		return err
	}
	st.UserCounter += len(userList)
	return nil
}
// GetGroupMembersInfo returns the subset of userIDs that are NOT yet members
// of groupID — i.e. the users that still need an invitation. Membership is
// queried via the get-group-members-info API in chunks of at most 500 IDs.
//
// NOTE(review): in the chunked path a failed batch is logged and skipped, so
// its users are silently treated as already-members; the small (<=500) path
// propagates the error instead. Confirm this asymmetry is intentional.
func (st *StressTest) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]string, error) {
	needInviteUserIDs := make([]string, 0)
	const maxBatchSize = 500
	if len(userIDs) > maxBatchSize {
		for i := 0; i < len(userIDs); i += maxBatchSize {
			end := min(i+maxBatchSize, len(userIDs))
			batchUserIDs := userIDs[i:end]
			// log.ZInfo(ctx, "Processing group members batch", "groupID", groupID, "batch", i/maxBatchSize+1,
			// "batchUserCount", len(batchUserIDs))
			// Process a single batch
			batchReq := group.GetGroupMembersInfoReq{
				GroupID: groupID,
				UserIDs: batchUserIDs,
			}
			resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &batchReq)
			if err != nil {
				log.ZError(ctx, "Batch query failed", err, "batch", i/maxBatchSize+1)
				continue
			}
			data := &group.GetGroupMembersInfoResp{}
			if err := json.Unmarshal(resp, &data); err != nil {
				log.ZError(ctx, "Failed to parse batch response", err, "batch", i/maxBatchSize+1)
				continue
			}
			// Process the batch results
			existingMembers := make(map[string]bool)
			for _, member := range data.Members {
				existingMembers[member.UserID] = true
			}
			// Keep only IDs absent from the current member set.
			for _, userID := range batchUserIDs {
				if !existingMembers[userID] {
					needInviteUserIDs = append(needInviteUserIDs, userID)
				}
			}
		}
		return needInviteUserIDs, nil
	}
	// Small input: a single request; errors here are returned to the caller.
	req := group.GetGroupMembersInfoReq{
		GroupID: groupID,
		UserIDs: userIDs,
	}
	resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &req)
	if err != nil {
		return nil, err
	}
	data := &group.GetGroupMembersInfoResp{}
	if err := json.Unmarshal(resp, &data); err != nil {
		return nil, err
	}
	existingMembers := make(map[string]bool)
	for _, member := range data.Members {
		existingMembers[member.UserID] = true
	}
	for _, userID := range userIDs {
		if !existingMembers[userID] {
			needInviteUserIDs = append(needInviteUserIDs, userID)
		}
	}
	return needInviteUserIDs, nil
}
// InviteToGroup invites the given userIDs into groupID.
func (st *StressTest) InviteToGroup(ctx context.Context, groupID string, userIDs []string) error {
	inviteReq := group.InviteUserToGroupReq{
		GroupID:        groupID,
		InvitedUserIDs: userIDs,
	}
	_, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &inviteReq)
	return err
}
// SendMsg sends one text message from userID to groupID via the send-msg
// API. The body carries the current wall-clock time; MsgCounter is bumped
// after each successful send.
func (st *StressTest) SendMsg(ctx context.Context, userID string, groupID string) error {
	contentObj := map[string]any{
		// "content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")),
		"content": fmt.Sprintf("The current time is %s", time.Now().Format("2006-01-02 15:04:05.000")),
	}
	req := &apistruct.SendMsgReq{
		SendMsg: apistruct.SendMsg{
			SendID:         userID,
			SenderNickname: userID,
			GroupID:        groupID,
			ContentType:    constant.Text,
			SessionType:    constant.ReadGroupChatType,
			Content:        contentObj,
		},
	}
	_, err := st.PostRequest(ctx, ApiAddress+SendMsg, &req)
	if err != nil {
		log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &req)
		return err
	}
	st.MsgCounter++
	return nil
}
// Max userIDs number is 1000
// CreateGroup creates a working group with the given ID, owner, and initial
// member list, returning the server-assigned group ID.
func (st *StressTest) CreateGroup(ctx context.Context, groupID string, userID string, userIDsList []string) (string, error) {
	groupInfo := &sdkws.GroupInfo{
		GroupID:   groupID,
		GroupName: groupID,
		GroupType: constant.WorkingGroup,
	}
	req := group.CreateGroupReq{
		OwnerUserID:   userID,
		MemberUserIDs: userIDsList,
		GroupInfo:     groupInfo,
	}
	resp := group.CreateGroupResp{}
	response, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &req)
	if err != nil {
		return "", err
	}
	if err := json.Unmarshal(response, &resp); err != nil {
		return "", err
	}
	// st.GroupCounter++
	return resp.GroupInfo.GroupID, nil
}
// main drives the v2 stress test: it registers 100K users in batches of
// 1000, concurrently creates 100 large ("100K") groups and 1000 small
// ("1K") groups while inviting users in chunks of 999, then sends messages
// to every group on tickers until the context is cancelled.
func main() {
	var configPath string
	// defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config")
	// flag.StringVar(&configPath, "c", defaultConfigDir, "config path")
	flag.StringVar(&configPath, "c", "", "config path")
	flag.Parse()
	if configPath == "" {
		_, _ = fmt.Fprintln(os.Stderr, "config path is empty")
		os.Exit(1)
		return
	}
	fmt.Printf(" Config Path: %s\n", configPath)
	share, apiConfig, err := initConfig(configPath)
	if err != nil {
		program.ExitWithError(err)
		return
	}
	// The API server is assumed to listen on the first configured port on localhost.
	ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0]))
	ctx, cancel := context.WithCancel(context.Background())
	// ch := make(chan struct{})
	st := &StressTest{
		Conf: &conf{
			Share: *share,
			Api:   *apiConfig,
		},
		AdminUserID: share.IMAdminUserID[0],
		Ctx:         ctx,
		Cancel:      cancel,
		HttpClient: &http.Client{
			Timeout: 50 * time.Second,
		},
	}
	// Ctrl-C / SIGTERM cancels the context; the nested goroutine force-exits
	// immediately (the delay that used to precede it is commented out).
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-c
		fmt.Println("\nReceived stop signal, stopping...")
		go func() {
			// time.Sleep(5 * time.Second)
			fmt.Println("Force exit")
			os.Exit(0)
		}()
		st.Cancel()
	}()
	// NOTE(review): a failed token fetch is only logged; the run continues
	// with an empty AdminToken and every subsequent request will fail auth.
	token, err := st.GetAdminToken(st.Ctx)
	if err != nil {
		log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID)
	}
	st.AdminToken = token
	fmt.Println("Admin Token:", st.AdminToken)
	fmt.Println("ApiAddress:", ApiAddress)
	// Pre-compute the deterministic IDs of every user to create.
	for i := range MaxUser {
		userID := fmt.Sprintf("v2_StressTest_User_%d", i)
		st.CreatedUsers = append(st.CreatedUsers, userID)
		st.CreateUserCounter++
	}
	// err = st.CreateUserBatch(st.Ctx, st.CreatedUsers)
	// if err != nil {
	// 	log.ZError(ctx, "Create user failed.", err)
	// }
	// Register users in batches; failed batches are logged and skipped.
	const batchSize = 1000
	totalUsers := len(st.CreatedUsers)
	successCount := 0
	if st.DefaultUserID == "" && len(st.CreatedUsers) > 0 {
		st.DefaultUserID = st.CreatedUsers[0]
	}
	for i := 0; i < totalUsers; i += batchSize {
		end := min(i+batchSize, totalUsers)
		userBatch := st.CreatedUsers[i:end]
		log.ZInfo(st.Ctx, "Creating user batch", "batch", i/batchSize+1, "count", len(userBatch))
		err = st.CreateUserBatch(st.Ctx, userBatch)
		if err != nil {
			log.ZError(st.Ctx, "Batch user creation failed", err, "batch", i/batchSize+1)
		} else {
			successCount += len(userBatch)
			log.ZInfo(st.Ctx, "Batch user creation succeeded", "batch", i/batchSize+1,
				"progress", fmt.Sprintf("%d/%d", successCount, totalUsers))
		}
	}
	// Execute create 100k group
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()
		create100kGroupTicker := time.NewTicker(Create100KGroupTicker)
		defer create100kGroupTicker.Stop()
		for i := range Max100KGroup {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create 100K Group")
				return
			case <-create100kGroupTicker.C:
				// Create 100K groups
				st.Wg.Add(1)
				go func(idx int) {
					defer st.Wg.Done()
					defer func() {
						// NOTE(review): counter increment is not mutex-guarded;
						// concurrent goroutines may race on it.
						st.Create100kGroupCounter++
					}()
					groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", idx)
					// NOTE(review): `err =` here writes main's shared err variable
					// from multiple goroutines — a data race; a local err would be safer.
					if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil {
						log.ZError(st.Ctx, "Create group failed.", err)
						// continue
					}
					// Invite users in chunks of MaxInviteUserLimit (999); index 0
					// is the owner, so invitations start at user 1.
					for i := 0; i < MaxUser/MaxInviteUserLimit; i++ {
						InviteUserIDs := make([]string, 0)
						// ensure TargetUserList is in group
						InviteUserIDs = append(InviteUserIDs, TestTargetUserList...)
						startIdx := max(i*MaxInviteUserLimit, 1)
						endIdx := min((i+1)*MaxInviteUserLimit, MaxUser)
						for j := startIdx; j < endIdx; j++ {
							userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j)
							InviteUserIDs = append(InviteUserIDs, userCreatedID)
						}
						if len(InviteUserIDs) == 0 {
							log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}
						// Drop IDs that are already members so the invite is idempotent.
						// NOTE(review): uses outer `ctx` while siblings use `st.Ctx`
						// (the same context here, but inconsistent).
						InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs)
						if err != nil {
							log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID)
							continue
						}
						if len(InviteUserIDs) == 0 {
							log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}
						// Invite To Group
						if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil {
							log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs)
							continue
							// os.Exit(1)
							// return
						}
					}
				}(i)
			}
		}
	}()
	// create 999 groups
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()
		create999GroupTicker := time.NewTicker(Create999GroupTicker)
		defer create999GroupTicker.Stop()
		for i := range Max999Group {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create 999 Group")
				return
			case <-create999GroupTicker.C:
				// Create 999 groups
				st.Wg.Add(1)
				go func(idx int) {
					defer st.Wg.Done()
					defer func() {
						st.Create999GroupCounter++
					}()
					groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", idx)
					if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil {
						log.ZError(st.Ctx, "Create group failed.", err)
						// continue
					}
					for i := 0; i < MaxUser/MaxInviteUserLimit; i++ {
						InviteUserIDs := make([]string, 0)
						// ensure TargetUserList is in group
						InviteUserIDs = append(InviteUserIDs, TestTargetUserList...)
						startIdx := max(i*MaxInviteUserLimit, 1)
						endIdx := min((i+1)*MaxInviteUserLimit, MaxUser)
						for j := startIdx; j < endIdx; j++ {
							userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j)
							InviteUserIDs = append(InviteUserIDs, userCreatedID)
						}
						if len(InviteUserIDs) == 0 {
							log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}
						InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs)
						if err != nil {
							log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID)
							continue
						}
						if len(InviteUserIDs) == 0 {
							log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
							continue
						}
						// Invite To Group
						if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil {
							log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs)
							continue
							// os.Exit(1)
							// return
						}
					}
				}(i)
			}
		}
	}()
	// Send message to 100K groups
	st.Wg.Wait()
	fmt.Println("All groups created successfully, starting to send messages...")
	log.ZInfo(ctx, "All groups created successfully, starting to send messages...")
	// Reconstruct the deterministic group IDs for the send phase.
	var groups100K []string
	var groups999 []string
	for i := range Max100KGroup {
		groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", i)
		groups100K = append(groups100K, groupID)
	}
	for i := range Max999Group {
		groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", i)
		groups999 = append(groups999, groupID)
	}
	// Concurrency limiters: 20 in-flight sends for 100K groups, 100 for 1K groups.
	send100kGroupLimiter := make(chan struct{}, 20)
	send999GroupLimiter := make(chan struct{}, 100)
	// execute Send message to 100K groups
	go func() {
		ticker := time.NewTicker(SendMsgTo100KGroupTicker)
		defer ticker.Stop()
		for {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Send Message to 100K Group")
				return
			case <-ticker.C:
				// Send message to 100K groups
				for _, groupID := range groups100K {
					send100kGroupLimiter <- struct{}{}
					go func(groupID string) {
						defer func() { <-send100kGroupLimiter }()
						if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil {
							log.ZError(st.Ctx, "Send message to 100K group failed.", err)
						}
					}(groupID)
				}
				// log.ZInfo(st.Ctx, "Send message to 100K groups successfully.")
			}
		}
	}()
	// execute Send message to 999 groups
	go func() {
		ticker := time.NewTicker(SendMsgTo999GroupTicker)
		defer ticker.Stop()
		for {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Send Message to 999 Group")
				return
			case <-ticker.C:
				// Send message to 999 groups
				for _, groupID := range groups999 {
					send999GroupLimiter <- struct{}{}
					go func(groupID string) {
						defer func() { <-send999GroupLimiter }()
						if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil {
							log.ZError(st.Ctx, "Send message to 999 group failed.", err)
						}
					}(groupID)
				}
				// log.ZInfo(st.Ctx, "Send message to 999 groups successfully.")
			}
		}
	}()
	// Block until cancelled (the signal handler also force-exits the process).
	<-st.Ctx.Done()
	fmt.Println("Received signal to exit, shutting down...")
}

@ -0,0 +1,25 @@
# Stress Test
## Usage
You need to set the `TestTargetUserList` and `DefaultGroupID` variables before building.
### Build
```bash
go build -o _output/bin/tools/linux/amd64/stress-test tools/stress-test/main.go
# or
go build -o tools/stress-test/stress-test tools/stress-test/main.go
```
### Execute
```bash
_output/bin/tools/linux/amd64/stress-test -c config/
#or
tools/stress-test/stress-test -c config/
```

@ -0,0 +1,459 @@
package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/openimsdk/open-im-server/v3/pkg/apistruct"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/protocol/auth"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/group"
"github.com/openimsdk/protocol/relation"
"github.com/openimsdk/protocol/sdkws"
pbuser "github.com/openimsdk/protocol/user"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/system/program"
)
/*
1. Create one user every minute
2. Import target users as friends
3. Add users to the default group
4. Send a message to the default group every second, containing index and current timestamp
5. Create a new group every minute and invite target users to join
*/
// !!! ATTENTION: These variables must be set before running!
var (
// Use default userIDs List for testing, need to be created.
TestTargetUserList = []string{
"<need-update-it>",
}
DefaultGroupID = "<need-update-it>" // Use default group ID for testing, need to be created.
)
var (
ApiAddress string
// API method
GetAdminToken = "/auth/get_admin_token"
CreateUser = "/user/user_register"
ImportFriend = "/friend/import_friend"
InviteToGroup = "/group/invite_user_to_group"
SendMsg = "/msg/send_msg"
CreateGroup = "/group/create_group"
GetUserToken = "/auth/user_token"
)
const (
MaxUser = 10000
MaxGroup = 1000
CreateUserTicker = 1 * time.Minute // Ticker is 1min in create user
SendMessageTicker = 1 * time.Second // Ticker is 1s in send message
CreateGroupTicker = 1 * time.Minute
)
type BaseResp struct {
ErrCode int `json:"errCode"`
ErrMsg string `json:"errMsg"`
Data json.RawMessage `json:"data"`
}
type StressTest struct {
Conf *conf
AdminUserID string
AdminToken string
DefaultGroupID string
DefaultUserID string
UserCounter int
GroupCounter int
MsgCounter int
CreatedUsers []string
CreatedGroups []string
Mutex sync.Mutex
Ctx context.Context
Cancel context.CancelFunc
HttpClient *http.Client
Wg sync.WaitGroup
Once sync.Once
}
type conf struct {
Share config.Share
Api config.API
}
// initConfig loads the share and API configuration files from configDir,
// returning both or the first load error encountered.
func initConfig(configDir string) (*config.Share, *config.API, error) {
	var (
		share     = &config.Share{}
		apiConfig = &config.API{}
	)
	err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share)
	if err != nil {
		return nil, nil, err
	}
	err = config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig)
	if err != nil {
		return nil, nil, err
	}
	return share, apiConfig, nil
}
// PostRequest marshals reqbody to JSON, POSTs it to url with the admin
// operationID/token headers, and decodes the common BaseResp envelope.
// It returns the raw data payload on success; a non-zero errCode in the
// response is converted into an error. The request is bound to ctx, so
// cancelling the stress test aborts in-flight calls.
func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) {
	// Marshal body
	jsonBody, err := json.Marshal(reqbody)
	if err != nil {
		log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody)
		return nil, err
	}

	// Bind the request to ctx (the original used http.NewRequest, which
	// ignored the passed-in context and could not be cancelled).
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("operationID", st.AdminUserID)
	if st.AdminToken != "" {
		req.Header.Set("token", st.AdminToken)
	}

	// log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken)

	resp, err := st.HttpClient.Do(req)
	if err != nil {
		log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody)
		return nil, err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		log.ZError(ctx, "Failed to read response body", err, "url", url)
		return nil, err
	}

	var baseResp BaseResp
	if err := json.Unmarshal(respBody, &baseResp); err != nil {
		log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody))
		return nil, err
	}

	if baseResp.ErrCode != 0 {
		// Use a literal "%s" verb: ErrMsg is server-provided and may contain
		// '%' characters, which fmt.Errorf(ErrMsg) would misinterpret.
		err = fmt.Errorf("%s", baseResp.ErrMsg)
		log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp)
		return nil, err
	}

	return baseResp.Data, nil
}
// GetAdminToken requests an administrator token for st.AdminUserID using the
// configured shared secret.
func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) {
	reqBody := auth.GetAdminTokenReq{
		Secret: st.Conf.Share.Secret,
		UserID: st.AdminUserID,
	}

	raw, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &reqBody)
	if err != nil {
		return "", err
	}

	tokenResp := &auth.GetAdminTokenResp{}
	if err := json.Unmarshal(raw, &tokenResp); err != nil {
		return "", err
	}

	return tokenResp.Token, nil
}
// CreateUser registers one user whose nickname equals its userID, bumps the
// user counter, and returns the created userID.
func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) {
	reqBody := pbuser.UserRegisterReq{
		Users: []*sdkws.UserInfo{
			{
				UserID:   userID,
				Nickname: userID,
			},
		},
	}

	if _, err := st.PostRequest(ctx, ApiAddress+CreateUser, &reqBody); err != nil {
		return "", err
	}

	st.UserCounter++
	return userID, nil
}
// ImportFriend makes every user in TestTargetUserList a friend of userID.
func (st *StressTest) ImportFriend(ctx context.Context, userID string) error {
	reqBody := relation.ImportFriendReq{
		OwnerUserID:   userID,
		FriendUserIDs: TestTargetUserList,
	}

	_, err := st.PostRequest(ctx, ApiAddress+ImportFriend, &reqBody)
	return err
}
// InviteToGroup invites userID into the default test group.
func (st *StressTest) InviteToGroup(ctx context.Context, userID string) error {
	reqBody := group.InviteUserToGroupReq{
		GroupID:        st.DefaultGroupID,
		InvitedUserIDs: []string{userID},
	}

	_, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &reqBody)
	return err
}
// SendMsg posts one text message to the default group as userID. The message
// body carries the running message index and the current wall-clock time;
// the counter is incremented only on success.
func (st *StressTest) SendMsg(ctx context.Context, userID string) error {
	text := fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000"))

	reqBody := &apistruct.SendMsgReq{
		SendMsg: apistruct.SendMsg{
			SendID:         userID,
			SenderNickname: userID,
			GroupID:        st.DefaultGroupID,
			ContentType:    constant.Text,
			SessionType:    constant.ReadGroupChatType,
			Content: map[string]any{
				"content": text,
			},
		},
	}

	if _, err := st.PostRequest(ctx, ApiAddress+SendMsg, &reqBody); err != nil {
		log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &reqBody)
		return err
	}

	st.MsgCounter++
	return nil
}
// CreateGroup creates a working group owned by userID with TestTargetUserList
// as initial members, bumps the group counter, and returns the new groupID.
func (st *StressTest) CreateGroup(ctx context.Context, userID string) (string, error) {
	groupID := fmt.Sprintf("StressTestGroup_%d_%s", st.GroupCounter, time.Now().Format("20060102150405"))

	reqBody := group.CreateGroupReq{
		OwnerUserID:   userID,
		MemberUserIDs: TestTargetUserList,
		GroupInfo: &sdkws.GroupInfo{
			GroupID:   groupID,
			GroupName: groupID,
			GroupType: constant.WorkingGroup,
		},
	}

	raw, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &reqBody)
	if err != nil {
		return "", err
	}

	createResp := group.CreateGroupResp{}
	if err := json.Unmarshal(raw, &createResp); err != nil {
		return "", err
	}

	st.GroupCounter++
	return createResp.GroupInfo.GroupID, nil
}
// main wires up the stress test: it loads configuration, obtains an admin
// token, then runs three workers until interrupted (SIGINT/SIGTERM):
//  1. create a user every minute, import friends, and join the default group;
//  2. send a message to the default group every second;
//  3. create a new group every minute.
//
// Workers 2 and 3 wait until the first user exists before starting.
func main() {
	var configPath string
	flag.StringVar(&configPath, "c", "", "config path")
	flag.Parse()

	if configPath == "" {
		_, _ = fmt.Fprintln(os.Stderr, "config path is empty")
		os.Exit(1)
		return
	}

	fmt.Printf(" Config Path: %s\n", configPath)

	share, apiConfig, err := initConfig(configPath)
	if err != nil {
		program.ExitWithError(err)
		return
	}

	ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0]))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// ch is closed when the first user has been created; closeOnce guards it
	// because both the signal handler and the create-user worker may close it
	// (the original's select-with-default guard could double-close and panic).
	ch := make(chan struct{})
	var closeOnce sync.Once
	closeCh := func() { closeOnce.Do(func() { close(ch) }) }

	st := &StressTest{
		Conf: &conf{
			Share: *share,
			Api:   *apiConfig,
		},
		AdminUserID: share.IMAdminUserID[0],
		Ctx:         ctx,
		Cancel:      cancel,
		HttpClient: &http.Client{
			Timeout: 50 * time.Second,
		},
	}

	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-c
		fmt.Println("\nReceived stop signal, stopping...")
		// Release any worker still waiting for the first user, then cancel.
		closeCh()
		st.Cancel()
	}()

	token, err := st.GetAdminToken(st.Ctx)
	if err != nil {
		// Without an admin token every subsequent API call would fail, so
		// abort instead of continuing with an empty token as before.
		log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID)
		os.Exit(1)
		return
	}
	st.AdminToken = token
	fmt.Println("Admin Token:", st.AdminToken)
	fmt.Println("ApiAddress:", ApiAddress)

	st.DefaultGroupID = DefaultGroupID

	// Worker 1: create users, import friends, invite into the default group.
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()

		ticker := time.NewTicker(CreateUserTicker)
		defer ticker.Stop()

		for st.UserCounter < MaxUser {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create user", "reason", "context done")
				return

			case <-ticker.C:
				// Create User
				userID := fmt.Sprintf("%d_Stresstest_%s", st.UserCounter, time.Now().Format("0102150405"))

				userCreatedID, err := st.CreateUser(st.Ctx, userID)
				if err != nil {
					log.ZError(st.Ctx, "Create User failed.", err, "UserID", userID)
					os.Exit(1)
					return
				}

				// Import Friend
				if err = st.ImportFriend(st.Ctx, userCreatedID); err != nil {
					log.ZError(st.Ctx, "Import Friend failed.", err, "UserID", userCreatedID)
					os.Exit(1)
					return
				}

				// Invite To Group
				if err = st.InviteToGroup(st.Ctx, userCreatedID); err != nil {
					log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", userCreatedID)
					os.Exit(1)
					return
				}

				// The first created user becomes the default sender/owner and
				// unblocks the message and group workers.
				st.Once.Do(func() {
					st.DefaultUserID = userCreatedID
					fmt.Println("Default Send User Created ID:", userCreatedID)
					closeCh()
				})
			}
		}
	}()

	// Worker 2: send a message every second once the default user exists.
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()

		ticker := time.NewTicker(SendMessageTicker)
		defer ticker.Stop()

		<-ch
		for {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Send message", "reason", "context done")
				return

			case <-ticker.C:
				// Send Message (fixed: the field is DefaultUserID; the original
				// referenced a non-existent st.DefaultSendUserID).
				if err := st.SendMsg(st.Ctx, st.DefaultUserID); err != nil {
					log.ZError(st.Ctx, "Send Message failed.", err, "UserID", st.DefaultUserID)
					continue
				}
			}
		}
	}()

	// Worker 3: create a group every minute once the default user exists.
	st.Wg.Add(1)
	go func() {
		defer st.Wg.Done()

		ticker := time.NewTicker(CreateGroupTicker)
		defer ticker.Stop()

		<-ch
		for st.GroupCounter < MaxGroup {
			select {
			case <-st.Ctx.Done():
				log.ZInfo(st.Ctx, "Stop Create Group", "reason", "context done")
				return

			case <-ticker.C:
				// Create Group
				if _, err := st.CreateGroup(st.Ctx, st.DefaultUserID); err != nil {
					log.ZError(st.Ctx, "Create Group failed.", err, "UserID", st.DefaultUserID)
					os.Exit(1)
					return
				}
			}
		}
	}()

	st.Wg.Wait()
}

@ -1,6 +1,14 @@
package version
import _ "embed"
import (
_ "embed"
"strings"
)
//go:embed version
var Version string
func init() {
Version = strings.Trim(Version, "\n")
Version = strings.TrimSpace(Version)
}

Loading…
Cancel
Save