From 20158ec3ef75b3e98f8c734bf1a6b366dc1a1b2c Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Fri, 17 Nov 2023 16:16:25 +0800 Subject: [PATCH] tx --- internal/rpc/conversation/conversaion.go | 14 +- internal/rpc/group/group.go | 8 +- internal/rpc/user/user.go | 16 +-- internal/rpc/user/user_test.go | 132 +++++++++++++++++++ internal/tools/msg.go | 17 ++- pkg/common/db/controller/conversation.go | 88 +++++-------- pkg/common/db/newmgo/conversation.go | 5 - pkg/common/db/table/relation/conversation.go | 1 - pkg/common/db/tx/auto.go | 19 +++ pkg/common/db/tx/invalid.go | 16 +++ pkg/common/db/tx/tx.go | 28 ++++ 11 files changed, 256 insertions(+), 88 deletions(-) create mode 100644 internal/rpc/user/user_test.go create mode 100644 pkg/common/db/tx/auto.go create mode 100644 pkg/common/db/tx/invalid.go create mode 100644 pkg/common/db/tx/tx.go diff --git a/internal/rpc/conversation/conversaion.go b/internal/rpc/conversation/conversaion.go index 0ae9ed9a9..12e3e96fd 100644 --- a/internal/rpc/conversation/conversaion.go +++ b/internal/rpc/conversation/conversaion.go @@ -18,6 +18,7 @@ import ( "context" "errors" "github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo" + tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx" "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" "google.golang.org/grpc" @@ -27,13 +28,11 @@ import ( "github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" - "github.com/OpenIMSDK/tools/tx" "github.com/OpenIMSDK/tools/utils" "github.com/openimsdk/open-im-server/v3/pkg/common/convert" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/relation" tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification" @@ -46,18 +45,15 @@ type conversationServer struct { } func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - db, err := relation.NewGormDB() + rdb, err := cache.NewRedis() if err != nil { return err } - if err := db.AutoMigrate(&tablerelation.ConversationModel{}); err != nil { - return err - } - rdb, err := cache.NewRedis() + mongo, err := unrelation.NewMongo() if err != nil { return err } - mongo, err := unrelation.NewMongo() + tx, err := tx2.NewAuto(context.Background(), mongo.GetClient()) if err != nil { return err } @@ -70,7 +66,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e pbconversation.RegisterConversationServer(server, &conversationServer{ conversationNotificationSender: notification.NewConversationNotificationSender(&msgRpcClient), groupRpcClient: &groupRpcClient, - conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx.NewGorm(db)), + conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx), }) return nil } diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 6fef13676..7822cc31a 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -17,8 +17,8 @@ package group import ( "context" "fmt" - "github.com/OpenIMSDK/tools/tx" "github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo" + tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient/grouphash" "math/big" "math/rand" @@ -88,8 +88,12 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e userRpcClient := rpcclient.NewUserRpcClient(client) msgRpcClient := rpcclient.NewMessageRpcClient(client) conversationRpcClient := rpcclient.NewConversationRpcClient(client) + tx, err := tx2.NewAuto(context.Background(), mongo.GetClient()) + if err != nil { + return err + } var gs groupServer - database := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx.NewMongo(mongo.GetClient()), grouphash.NewGroupHashFromGroupServer(&gs)) + database := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx, grouphash.NewGroupHashFromGroupServer(&gs)) gs.db = database gs.User = userRpcClient gs.Notification = notification.NewGroupNotificationSender(database, &msgRpcClient, &userRpcClient, func(ctx context.Context, userIDs []string) ([]notification.CommonUser, error) { diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index 617b595ef..cd77c4219 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -18,6 +18,7 @@ import ( "context" "errors" "github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo" + tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx" "strings" "time" @@ -25,7 +26,6 @@ import ( "github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" - "github.com/OpenIMSDK/tools/tx" "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" @@ -36,7 +36,6 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/convert" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/relation" tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification" @@ -56,10 +55,6 @@ type userServer struct { } func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { - db, err := relation.NewGormDB() - if err != nil { - return err - } rdb, err := cache.NewRedis() if err != nil { return err @@ -68,9 +63,6 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { if err != nil { return err } - if err := db.AutoMigrate(&tablerelation.UserModel{}); err != nil { - return err - } users := make([]*tablerelation.UserModel, 0) if len(config.Config.Manager.UserID) != len(config.Config.Manager.Nickname) { return errors.New("len(config.Config.Manager.AppManagerUid) != len(config.Config.Manager.Nickname)") @@ -82,9 +74,13 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { if err != nil { return err } + tx, err := tx2.NewAuto(context.Background(), mongo.GetClient()) + if err != nil { + return err + } cache := cache.NewUserCacheRedis(rdb, userDB, cache.GetDefaultOpt()) userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase()) - database := controller.NewUserDatabase(userDB, cache, tx.NewMongo(mongo.GetClient()), userMongoDB) + database := controller.NewUserDatabase(userDB, cache, tx, userMongoDB) friendRpcClient := rpcclient.NewFriendRpcClient(client) groupRpcClient := rpcclient.NewGroupRpcClient(client) msgRpcClient := rpcclient.NewMessageRpcClient(client) diff --git a/internal/rpc/user/user_test.go b/internal/rpc/user/user_test.go new file mode 100644 index 000000000..abe1c5023 --- /dev/null +++ b/internal/rpc/user/user_test.go @@ -0,0 +1,132 @@ +package user + +import ( + "context" + "errors" + "github.com/OpenIMSDK/protocol/constant" + "github.com/OpenIMSDK/protocol/sdkws" + "github.com/OpenIMSDK/protocol/user" + "github.com/OpenIMSDK/tools/log" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo" + tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" + tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" + "github.com/redis/go-redis/v9" + "math/rand" + "net" + "strconv" + "testing" +) + +var ( + rdb redis.UniversalClient + mgo *unrelation.Mongo + ctx context.Context +) + +func InitDB() error { + addr := "172.16.8.142" + pwd := "openIM123" + + config.Config.Redis.Address = []string{net.JoinHostPort(addr, "16379")} + config.Config.Redis.Password = pwd + config.Config.Mongo.Address = []string{net.JoinHostPort(addr, "37017")} + config.Config.Mongo.Database = "openIM_v3" + config.Config.Mongo.Username = "root" + config.Config.Mongo.Password = pwd + config.Config.Mongo.MaxPoolSize = 100 + var err error + rdb, err = cache.NewRedis() + if err != nil { + return err + } + mgo, err = unrelation.NewMongo() + if err != nil { + return err + } + tx, err := tx2.NewAuto(context.Background(), mgo.GetClient()) + if err != nil { + return err + } + + config.Config.Object.Enable = "minio" + config.Config.Object.ApiURL = "http://" + net.JoinHostPort(addr, "10002") + config.Config.Object.Minio.Bucket = "openim" + config.Config.Object.Minio.Endpoint = "http://" + net.JoinHostPort(addr, "10005") + config.Config.Object.Minio.AccessKeyID = "root" + config.Config.Object.Minio.SecretAccessKey = pwd + config.Config.Object.Minio.SignEndpoint = config.Config.Object.Minio.Endpoint + + config.Config.Manager.UserID = []string{"openIM123456"} + config.Config.Manager.Nickname = []string{"openIM123456"} + + ctx = context.WithValue(context.Background(), constant.OperationID, "debugOperationID") + ctx = context.WithValue(context.Background(), constant.OpUserID, config.Config.Manager.UserID[0]) + + if err := log.InitFromConfig("", "", 6, true, false, "", 2, 1); err != nil { + panic(err) + } + + users := make([]*tablerelation.UserModel, 0) + if len(config.Config.Manager.UserID) != len(config.Config.Manager.Nickname) { + return errors.New("len(config.Config.Manager.AppManagerUid) != len(config.Config.Manager.Nickname)") + } + for k, v := range config.Config.Manager.UserID { + users = append(users, &tablerelation.UserModel{UserID: v, Nickname: config.Config.Manager.Nickname[k], AppMangerLevel: constant.AppAdmin}) + } + userDB, err := newmgo.NewUserMongo(mgo.GetDatabase()) + if err != nil { + return err + } + + //var client registry.SvcDiscoveryRegistry + //_= client + cache := cache.NewUserCacheRedis(rdb, userDB, cache.GetDefaultOpt()) + userMongoDB := unrelation.NewUserMongoDriver(mgo.GetDatabase()) + database := controller.NewUserDatabase(userDB, cache, tx, userMongoDB) + //friendRpcClient := rpcclient.NewFriendRpcClient(client) + //groupRpcClient := rpcclient.NewGroupRpcClient(client) + //msgRpcClient := rpcclient.NewMessageRpcClient(client) + + userSrv = &userServer{ + UserDatabase: database, + //RegisterCenter: client, + //friendRpcClient: &friendRpcClient, + //groupRpcClient: &groupRpcClient, + //friendNotificationSender: notification.NewFriendNotificationSender(&msgRpcClient, notification.WithDBFunc(database.FindWithError)), + //userNotificationSender: notification.NewUserNotificationSender(&msgRpcClient, notification.WithUserFunc(database.FindWithError)), + } + + return nil +} + +func init() { + if err := InitDB(); err != nil { + panic(err) + } +} + +var userSrv *userServer + +func TestName(t *testing.T) { + userID := strconv.Itoa(int(rand.Uint32())) + res, err := userSrv.UserRegister(ctx, &user.UserRegisterReq{ + Secret: config.Config.Secret, + Users: []*sdkws.UserInfo{ + { + UserID: userID, + Nickname: userID, + FaceURL: "", + Ex: "", + }, + }, + }) + if err != nil { + panic(err) + } + t.Log(res) + +} diff --git a/internal/tools/msg.go b/internal/tools/msg.go index 63dc16f21..bf0231786 100644 --- a/internal/tools/msg.go +++ b/internal/tools/msg.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/OpenIMSDK/protocol/sdkws" "github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo" + tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx" "math" "github.com/redis/go-redis/v9" @@ -33,13 +34,11 @@ import ( "github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/mcontext" "github.com/OpenIMSDK/tools/mw" - "github.com/OpenIMSDK/tools/tx" "github.com/OpenIMSDK/tools/utils" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/relation" "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification" @@ -74,10 +73,6 @@ func InitMsgTool() (*MsgTool, error) { if err != nil { return nil, err } - db, err := relation.NewGormDB() - if err != nil { - return nil, err - } discov, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) if err != nil { return nil, err @@ -87,12 +82,16 @@ func InitMsgTool() (*MsgTool, error) { if err != nil { return nil, err } + tx, err := tx2.NewAuto(context.Background(), mongo.GetClient()) + if err != nil { + return nil, err + } msgDatabase := controller.InitCommonMsgDatabase(rdb, mongo.GetDatabase()) userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase()) userDatabase := controller.NewUserDatabase( userDB, cache.NewUserCacheRedis(rdb, userDB, cache.GetDefaultOpt()), - tx.NewMongo(mongo.GetClient()), + tx, userMongoDB, ) groupDB, err := newmgo.NewGroupMongo(mongo.GetDatabase()) @@ -111,11 +110,11 @@ func InitMsgTool() (*MsgTool, error) { if err != nil { return nil, err } - groupDatabase := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx.NewMongo(mongo.GetClient()), nil) + groupDatabase := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx, nil) conversationDatabase := controller.NewConversationDatabase( conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), - tx.NewGorm(db), + tx, ) msgRpcClient := rpcclient.NewMessageRpcClient(discov) msgNotificationSender := notification.NewMsgNotificationSender(rpcclient.WithRpcClient(&msgRpcClient)) diff --git a/pkg/common/db/controller/conversation.go b/pkg/common/db/controller/conversation.go index f905eb723..6f7b8acb1 100644 --- a/pkg/common/db/controller/conversation.go +++ b/pkg/common/db/controller/conversation.go @@ -59,7 +59,7 @@ type ConversationDatabase interface { GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) } -func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.Tx) ConversationDatabase { +func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.CtxTx) ConversationDatabase { return &conversationDatabase{ conversationDB: conversation, cache: cache, @@ -70,22 +70,21 @@ func NewConversationDatabase(conversation relationtb.ConversationModelInterface, type conversationDatabase struct { conversationDB relationtb.ConversationModelInterface cache cache.ConversationCache - tx tx.Tx + tx tx.CtxTx } func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) (err error) { - cache := c.cache.NewCache() - if conversation.GroupID != "" { - cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(conversation.GroupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(conversation.GroupID) - } - if err := c.tx.Transaction(func(tx any) error { - conversationTx := c.conversationDB.NewTx(tx) - haveUserIDs, err := conversationTx.FindUserID(ctx, userIDs, []string{conversation.ConversationID}) + return c.tx.Transaction(ctx, func(ctx context.Context) error { + cache := c.cache.NewCache() + if conversation.GroupID != "" { + cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(conversation.GroupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(conversation.GroupID) + } + haveUserIDs, err := c.conversationDB.FindUserID(ctx, userIDs, []string{conversation.ConversationID}) if err != nil { return err } if len(haveUserIDs) > 0 { - _, err = conversationTx.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, filedMap) + _, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, filedMap) if err != nil { return err } @@ -113,17 +112,14 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, conversations = append(conversations, temp) } if len(conversations) > 0 { - err = conversationTx.Create(ctx, conversations) + err = c.conversationDB.Create(ctx, conversations) if err != nil { return err } cache = cache.DelConversationIDs(NotUserIDs...).DelUserConversationIDsHash(NotUserIDs...).DelConversations(conversation.ConversationID, NotUserIDs...) } - return nil - }); err != nil { - return err - } - return cache.ExecDel(ctx) + return cache.ExecDel(ctx) + }) } func (c *conversationDatabase) UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error { @@ -154,19 +150,18 @@ func (c *conversationDatabase) CreateConversation(ctx context.Context, conversat } func (c *conversationDatabase) SyncPeerUserPrivateConversationTx(ctx context.Context, conversations []*relationtb.ConversationModel) error { - cache := c.cache.NewCache() - if err := c.tx.Transaction(func(tx any) error { - conversationTx := c.conversationDB.NewTx(tx) + return c.tx.Transaction(ctx, func(ctx context.Context) error { + cache := c.cache.NewCache() for _, conversation := range conversations { for _, v := range [][2]string{{conversation.OwnerUserID, conversation.UserID}, {conversation.UserID, conversation.OwnerUserID}} { ownerUserID := v[0] userID := v[1] - haveUserIDs, err := conversationTx.FindUserID(ctx, []string{ownerUserID}, []string{conversation.ConversationID}) + haveUserIDs, err := c.conversationDB.FindUserID(ctx, []string{ownerUserID}, []string{conversation.ConversationID}) if err != nil { return err } if len(haveUserIDs) > 0 { - _, err := conversationTx.UpdateByMap(ctx, []string{ownerUserID}, conversation.ConversationID, map[string]any{"is_private_chat": conversation.IsPrivateChat}) + _, err := c.conversationDB.UpdateByMap(ctx, []string{ownerUserID}, conversation.ConversationID, map[string]any{"is_private_chat": conversation.IsPrivateChat}) if err != nil { return err } @@ -177,18 +172,15 @@ func (c *conversationDatabase) SyncPeerUserPrivateConversationTx(ctx context.Con newConversation.UserID = userID newConversation.ConversationID = conversation.ConversationID newConversation.IsPrivateChat = conversation.IsPrivateChat - if err := conversationTx.Create(ctx, []*relationtb.ConversationModel{&newConversation}); err != nil { + if err := c.conversationDB.Create(ctx, []*relationtb.ConversationModel{&newConversation}); err != nil { return err } cache = cache.DelConversationIDs(ownerUserID).DelUserConversationIDsHash(ownerUserID) } } } - return nil - }); err != nil { - return err - } - return cache.ExecDel(ctx) + return cache.ExecDel(ctx) + }) } func (c *conversationDatabase) FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.ConversationModel, error) { @@ -204,28 +196,26 @@ func (c *conversationDatabase) GetUserAllConversation(ctx context.Context, owner } func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.ConversationModel) error { - cache := c.cache.NewCache() - - groupIDs := utils.Distinct(utils.Filter(conversations, func(e *relationtb.ConversationModel) (string, bool) { - return e.GroupID, e.GroupID != "" - })) - for _, groupID := range groupIDs { - cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(groupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID) - } - if err := c.tx.Transaction(func(tx any) error { + return c.tx.Transaction(ctx, func(ctx context.Context) error { + cache := c.cache.NewCache() + groupIDs := utils.Distinct(utils.Filter(conversations, func(e *relationtb.ConversationModel) (string, bool) { + return e.GroupID, e.GroupID != "" + })) + for _, groupID := range groupIDs { + cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(groupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID) + } var conversationIDs []string for _, conversation := range conversations { conversationIDs = append(conversationIDs, conversation.ConversationID) cache = cache.DelConversations(conversation.OwnerUserID, conversation.ConversationID) } - conversationTx := c.conversationDB.NewTx(tx) - existConversations, err := conversationTx.Find(ctx, ownerUserID, conversationIDs) + existConversations, err := c.conversationDB.Find(ctx, ownerUserID, conversationIDs) if err != nil { return err } if len(existConversations) > 0 { for _, conversation := range conversations { - err = conversationTx.Update(ctx, conversation) + err = c.conversationDB.Update(ctx, conversation) if err != nil { return err } @@ -249,11 +239,8 @@ func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUs } cache = cache.DelConversationIDs(ownerUserID).DelUserConversationIDsHash(ownerUserID).DelConversationNotReceiveMessageUserIDs(utils.Slice(notExistConversations, func(e *relationtb.ConversationModel) string { return e.ConversationID })...) } - return nil - }); err != nil { - return err - } - return cache.ExecDel(ctx) + return cache.ExecDel(ctx) + }) } //func (c *conversationDatabase) FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error) { @@ -261,9 +248,9 @@ func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUs //} func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string) error { - cache := c.cache.NewCache() - conversationID := msgprocessor.GetConversationIDBySessionType(constant.SuperGroupChatType, groupID) - if err := c.tx.Transaction(func(tx any) error { + return c.tx.Transaction(ctx, func(ctx context.Context) error { + cache := c.cache.NewCache() + conversationID := msgprocessor.GetConversationIDBySessionType(constant.SuperGroupChatType, groupID) existConversationUserIDs, err := c.conversationDB.FindUserID(ctx, userIDs, []string{conversationID}) if err != nil { return err @@ -289,11 +276,8 @@ func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context, for _, v := range existConversationUserIDs { cache = cache.DelConversations(v, conversationID) } - return nil - }); err != nil { - return err - } - return cache.ExecDel(ctx) + return c.cache.ExecDel(ctx) + }) } func (c *conversationDatabase) GetConversationIDs(ctx context.Context, userID string) ([]string, error) { diff --git a/pkg/common/db/newmgo/conversation.go b/pkg/common/db/newmgo/conversation.go index 863b3ad6c..dbe6eef7c 100644 --- a/pkg/common/db/newmgo/conversation.go +++ b/pkg/common/db/newmgo/conversation.go @@ -127,8 +127,3 @@ func (c *ConversationMgo) GetConversationIDsNeedDestruct(ctx context.Context) ([ func (c *ConversationMgo) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) { return mgotool.Find[string](ctx, c.coll, bson.M{"conversation_id": conversationID, "recv_msg_opt": bson.M{"$ne": constant.ReceiveMessage}}, options.Find().SetProjection(bson.M{"owner_user_id": 1})) } - -func (c *ConversationMgo) NewTx(tx any) relation.ConversationModelInterface { - //TODO implement me - panic("implement me") -} diff --git a/pkg/common/db/table/relation/conversation.go b/pkg/common/db/table/relation/conversation.go index e6e9e249b..ffc82f244 100644 --- a/pkg/common/db/table/relation/conversation.go +++ b/pkg/common/db/table/relation/conversation.go @@ -91,5 +91,4 @@ type ConversationModelInterface interface { GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*ConversationModel, error) GetConversationIDsNeedDestruct(ctx context.Context) ([]*ConversationModel, error) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) - NewTx(tx any) ConversationModelInterface } diff --git a/pkg/common/db/tx/auto.go b/pkg/common/db/tx/auto.go new file mode 100644 index 000000000..bf6817a24 --- /dev/null +++ b/pkg/common/db/tx/auto.go @@ -0,0 +1,19 @@ +package tx + +import ( + "context" + "github.com/OpenIMSDK/tools/tx" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" +) + +func NewAuto(ctx context.Context, cli *mongo.Client) (tx.CtxTx, error) { + var res map[string]any + if err := cli.Database("admin").RunCommand(ctx, bson.M{"isMaster": 1}).Decode(&res); err != nil { + return nil, err + } + if _, ok := res["setName"]; ok { + return NewMongoTx(cli), nil + } + return NewInvalidTx(), nil +} diff --git a/pkg/common/db/tx/invalid.go b/pkg/common/db/tx/invalid.go new file mode 100644 index 000000000..193972af5 --- /dev/null +++ b/pkg/common/db/tx/invalid.go @@ -0,0 +1,16 @@ +package tx + +import ( + "context" + "github.com/OpenIMSDK/tools/tx" +) + +func NewInvalidTx() tx.CtxTx { + return invalid{} +} + +type invalid struct{} + +func (m invalid) Transaction(ctx context.Context, fn func(ctx context.Context) error) error { + return fn(ctx) +} diff --git a/pkg/common/db/tx/tx.go b/pkg/common/db/tx/tx.go new file mode 100644 index 000000000..baf9a9a5d --- /dev/null +++ b/pkg/common/db/tx/tx.go @@ -0,0 +1,28 @@ +package tx + +import ( + "context" + "github.com/OpenIMSDK/tools/tx" + "go.mongodb.org/mongo-driver/mongo" +) + +func NewMongoTx(client *mongo.Client) tx.CtxTx { + return &mongoTx{ + client: client, + } +} + +type mongoTx struct { + client *mongo.Client +} + +func (m *mongoTx) Transaction(ctx context.Context, fn func(ctx context.Context) error) error { + sess, err := m.client.StartSession() + if err != nil { + return err + } + _, err = sess.WithTransaction(ctx, func(ctx mongo.SessionContext) (interface{}, error) { + return nil, fn(ctx) + }) + return err +}