pull/1427/head
withchao 2 years ago
parent 1abfc8c805
commit 20158ec3ef

@ -18,6 +18,7 @@ import (
"context"
"errors"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo"
tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation"
"google.golang.org/grpc"
@ -27,13 +28,11 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/tx"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/convert"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/controller"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/relation"
tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification"
@ -46,18 +45,15 @@ type conversationServer struct {
}
func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
db, err := relation.NewGormDB()
rdb, err := cache.NewRedis()
if err != nil {
return err
}
if err := db.AutoMigrate(&tablerelation.ConversationModel{}); err != nil {
return err
}
rdb, err := cache.NewRedis()
mongo, err := unrelation.NewMongo()
if err != nil {
return err
}
mongo, err := unrelation.NewMongo()
tx, err := tx2.NewAuto(context.Background(), mongo.GetClient())
if err != nil {
return err
}
@ -70,7 +66,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e
pbconversation.RegisterConversationServer(server, &conversationServer{
conversationNotificationSender: notification.NewConversationNotificationSender(&msgRpcClient),
groupRpcClient: &groupRpcClient,
conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx.NewGorm(db)),
conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx),
})
return nil
}

@ -17,8 +17,8 @@ package group
import (
"context"
"fmt"
"github.com/OpenIMSDK/tools/tx"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo"
tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient/grouphash"
"math/big"
"math/rand"
@ -88,8 +88,12 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e
userRpcClient := rpcclient.NewUserRpcClient(client)
msgRpcClient := rpcclient.NewMessageRpcClient(client)
conversationRpcClient := rpcclient.NewConversationRpcClient(client)
tx, err := tx2.NewAuto(context.Background(), mongo.GetClient())
if err != nil {
return err
}
var gs groupServer
database := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx.NewMongo(mongo.GetClient()), grouphash.NewGroupHashFromGroupServer(&gs))
database := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx, grouphash.NewGroupHashFromGroupServer(&gs))
gs.db = database
gs.User = userRpcClient
gs.Notification = notification.NewGroupNotificationSender(database, &msgRpcClient, &userRpcClient, func(ctx context.Context, userIDs []string) ([]notification.CommonUser, error) {

@ -18,6 +18,7 @@ import (
"context"
"errors"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo"
tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx"
"strings"
"time"
@ -25,7 +26,6 @@ import (
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/tx"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation"
@ -36,7 +36,6 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/convert"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/controller"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/relation"
tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification"
@ -56,10 +55,6 @@ type userServer struct {
}
func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error {
db, err := relation.NewGormDB()
if err != nil {
return err
}
rdb, err := cache.NewRedis()
if err != nil {
return err
@ -68,9 +63,6 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error {
if err != nil {
return err
}
if err := db.AutoMigrate(&tablerelation.UserModel{}); err != nil {
return err
}
users := make([]*tablerelation.UserModel, 0)
if len(config.Config.Manager.UserID) != len(config.Config.Manager.Nickname) {
return errors.New("len(config.Config.Manager.AppManagerUid) != len(config.Config.Manager.Nickname)")
@ -82,9 +74,13 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error {
if err != nil {
return err
}
tx, err := tx2.NewAuto(context.Background(), mongo.GetClient())
if err != nil {
return err
}
cache := cache.NewUserCacheRedis(rdb, userDB, cache.GetDefaultOpt())
userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase())
database := controller.NewUserDatabase(userDB, cache, tx.NewMongo(mongo.GetClient()), userMongoDB)
database := controller.NewUserDatabase(userDB, cache, tx, userMongoDB)
friendRpcClient := rpcclient.NewFriendRpcClient(client)
groupRpcClient := rpcclient.NewGroupRpcClient(client)
msgRpcClient := rpcclient.NewMessageRpcClient(client)

@ -0,0 +1,132 @@
package user
import (
"context"
"errors"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/protocol/user"
"github.com/OpenIMSDK/tools/log"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/controller"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo"
tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation"
"github.com/redis/go-redis/v9"
"math/rand"
"net"
"strconv"
"testing"
)
var (
rdb redis.UniversalClient
mgo *unrelation.Mongo
ctx context.Context
)
func InitDB() error {
addr := "172.16.8.142"
pwd := "openIM123"
config.Config.Redis.Address = []string{net.JoinHostPort(addr, "16379")}
config.Config.Redis.Password = pwd
config.Config.Mongo.Address = []string{net.JoinHostPort(addr, "37017")}
config.Config.Mongo.Database = "openIM_v3"
config.Config.Mongo.Username = "root"
config.Config.Mongo.Password = pwd
config.Config.Mongo.MaxPoolSize = 100
var err error
rdb, err = cache.NewRedis()
if err != nil {
return err
}
mgo, err = unrelation.NewMongo()
if err != nil {
return err
}
tx, err := tx2.NewAuto(context.Background(), mgo.GetClient())
if err != nil {
return err
}
config.Config.Object.Enable = "minio"
config.Config.Object.ApiURL = "http://" + net.JoinHostPort(addr, "10002")
config.Config.Object.Minio.Bucket = "openim"
config.Config.Object.Minio.Endpoint = "http://" + net.JoinHostPort(addr, "10005")
config.Config.Object.Minio.AccessKeyID = "root"
config.Config.Object.Minio.SecretAccessKey = pwd
config.Config.Object.Minio.SignEndpoint = config.Config.Object.Minio.Endpoint
config.Config.Manager.UserID = []string{"openIM123456"}
config.Config.Manager.Nickname = []string{"openIM123456"}
ctx = context.WithValue(context.Background(), constant.OperationID, "debugOperationID")
ctx = context.WithValue(context.Background(), constant.OpUserID, config.Config.Manager.UserID[0])
if err := log.InitFromConfig("", "", 6, true, false, "", 2, 1); err != nil {
panic(err)
}
users := make([]*tablerelation.UserModel, 0)
if len(config.Config.Manager.UserID) != len(config.Config.Manager.Nickname) {
return errors.New("len(config.Config.Manager.AppManagerUid) != len(config.Config.Manager.Nickname)")
}
for k, v := range config.Config.Manager.UserID {
users = append(users, &tablerelation.UserModel{UserID: v, Nickname: config.Config.Manager.Nickname[k], AppMangerLevel: constant.AppAdmin})
}
userDB, err := newmgo.NewUserMongo(mgo.GetDatabase())
if err != nil {
return err
}
//var client registry.SvcDiscoveryRegistry
//_= client
cache := cache.NewUserCacheRedis(rdb, userDB, cache.GetDefaultOpt())
userMongoDB := unrelation.NewUserMongoDriver(mgo.GetDatabase())
database := controller.NewUserDatabase(userDB, cache, tx, userMongoDB)
//friendRpcClient := rpcclient.NewFriendRpcClient(client)
//groupRpcClient := rpcclient.NewGroupRpcClient(client)
//msgRpcClient := rpcclient.NewMessageRpcClient(client)
userSrv = &userServer{
UserDatabase: database,
//RegisterCenter: client,
//friendRpcClient: &friendRpcClient,
//groupRpcClient: &groupRpcClient,
//friendNotificationSender: notification.NewFriendNotificationSender(&msgRpcClient, notification.WithDBFunc(database.FindWithError)),
//userNotificationSender: notification.NewUserNotificationSender(&msgRpcClient, notification.WithUserFunc(database.FindWithError)),
}
return nil
}
func init() {
if err := InitDB(); err != nil {
panic(err)
}
}
var userSrv *userServer
func TestName(t *testing.T) {
userID := strconv.Itoa(int(rand.Uint32()))
res, err := userSrv.UserRegister(ctx, &user.UserRegisterReq{
Secret: config.Config.Secret,
Users: []*sdkws.UserInfo{
{
UserID: userID,
Nickname: userID,
FaceURL: "",
Ex: "",
},
},
})
if err != nil {
panic(err)
}
t.Log(res)
}

@ -19,6 +19,7 @@ import (
"fmt"
"github.com/OpenIMSDK/protocol/sdkws"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/newmgo"
tx2 "github.com/openimsdk/open-im-server/v3/pkg/common/db/tx"
"math"
"github.com/redis/go-redis/v9"
@ -33,13 +34,11 @@ import (
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/mcontext"
"github.com/OpenIMSDK/tools/mw"
"github.com/OpenIMSDK/tools/tx"
"github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/controller"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/relation"
"github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient"
"github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification"
@ -74,10 +73,6 @@ func InitMsgTool() (*MsgTool, error) {
if err != nil {
return nil, err
}
db, err := relation.NewGormDB()
if err != nil {
return nil, err
}
discov, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery)
if err != nil {
return nil, err
@ -87,12 +82,16 @@ func InitMsgTool() (*MsgTool, error) {
if err != nil {
return nil, err
}
tx, err := tx2.NewAuto(context.Background(), mongo.GetClient())
if err != nil {
return nil, err
}
msgDatabase := controller.InitCommonMsgDatabase(rdb, mongo.GetDatabase())
userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase())
userDatabase := controller.NewUserDatabase(
userDB,
cache.NewUserCacheRedis(rdb, userDB, cache.GetDefaultOpt()),
tx.NewMongo(mongo.GetClient()),
tx,
userMongoDB,
)
groupDB, err := newmgo.NewGroupMongo(mongo.GetDatabase())
@ -111,11 +110,11 @@ func InitMsgTool() (*MsgTool, error) {
if err != nil {
return nil, err
}
groupDatabase := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx.NewMongo(mongo.GetClient()), nil)
groupDatabase := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx, nil)
conversationDatabase := controller.NewConversationDatabase(
conversationDB,
cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB),
tx.NewGorm(db),
tx,
)
msgRpcClient := rpcclient.NewMessageRpcClient(discov)
msgNotificationSender := notification.NewMsgNotificationSender(rpcclient.WithRpcClient(&msgRpcClient))

@ -59,7 +59,7 @@ type ConversationDatabase interface {
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
}
func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.Tx) ConversationDatabase {
func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.CtxTx) ConversationDatabase {
return &conversationDatabase{
conversationDB: conversation,
cache: cache,
@ -70,22 +70,21 @@ func NewConversationDatabase(conversation relationtb.ConversationModelInterface,
type conversationDatabase struct {
conversationDB relationtb.ConversationModelInterface
cache cache.ConversationCache
tx tx.Tx
tx tx.CtxTx
}
func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) (err error) {
cache := c.cache.NewCache()
if conversation.GroupID != "" {
cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(conversation.GroupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(conversation.GroupID)
}
if err := c.tx.Transaction(func(tx any) error {
conversationTx := c.conversationDB.NewTx(tx)
haveUserIDs, err := conversationTx.FindUserID(ctx, userIDs, []string{conversation.ConversationID})
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.NewCache()
if conversation.GroupID != "" {
cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(conversation.GroupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(conversation.GroupID)
}
haveUserIDs, err := c.conversationDB.FindUserID(ctx, userIDs, []string{conversation.ConversationID})
if err != nil {
return err
}
if len(haveUserIDs) > 0 {
_, err = conversationTx.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, filedMap)
_, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, filedMap)
if err != nil {
return err
}
@ -113,17 +112,14 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
conversations = append(conversations, temp)
}
if len(conversations) > 0 {
err = conversationTx.Create(ctx, conversations)
err = c.conversationDB.Create(ctx, conversations)
if err != nil {
return err
}
cache = cache.DelConversationIDs(NotUserIDs...).DelUserConversationIDsHash(NotUserIDs...).DelConversations(conversation.ConversationID, NotUserIDs...)
}
return nil
}); err != nil {
return err
}
return cache.ExecDel(ctx)
return cache.ExecDel(ctx)
})
}
func (c *conversationDatabase) UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error {
@ -154,19 +150,18 @@ func (c *conversationDatabase) CreateConversation(ctx context.Context, conversat
}
func (c *conversationDatabase) SyncPeerUserPrivateConversationTx(ctx context.Context, conversations []*relationtb.ConversationModel) error {
cache := c.cache.NewCache()
if err := c.tx.Transaction(func(tx any) error {
conversationTx := c.conversationDB.NewTx(tx)
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.NewCache()
for _, conversation := range conversations {
for _, v := range [][2]string{{conversation.OwnerUserID, conversation.UserID}, {conversation.UserID, conversation.OwnerUserID}} {
ownerUserID := v[0]
userID := v[1]
haveUserIDs, err := conversationTx.FindUserID(ctx, []string{ownerUserID}, []string{conversation.ConversationID})
haveUserIDs, err := c.conversationDB.FindUserID(ctx, []string{ownerUserID}, []string{conversation.ConversationID})
if err != nil {
return err
}
if len(haveUserIDs) > 0 {
_, err := conversationTx.UpdateByMap(ctx, []string{ownerUserID}, conversation.ConversationID, map[string]any{"is_private_chat": conversation.IsPrivateChat})
_, err := c.conversationDB.UpdateByMap(ctx, []string{ownerUserID}, conversation.ConversationID, map[string]any{"is_private_chat": conversation.IsPrivateChat})
if err != nil {
return err
}
@ -177,18 +172,15 @@ func (c *conversationDatabase) SyncPeerUserPrivateConversationTx(ctx context.Con
newConversation.UserID = userID
newConversation.ConversationID = conversation.ConversationID
newConversation.IsPrivateChat = conversation.IsPrivateChat
if err := conversationTx.Create(ctx, []*relationtb.ConversationModel{&newConversation}); err != nil {
if err := c.conversationDB.Create(ctx, []*relationtb.ConversationModel{&newConversation}); err != nil {
return err
}
cache = cache.DelConversationIDs(ownerUserID).DelUserConversationIDsHash(ownerUserID)
}
}
}
return nil
}); err != nil {
return err
}
return cache.ExecDel(ctx)
return cache.ExecDel(ctx)
})
}
func (c *conversationDatabase) FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.ConversationModel, error) {
@ -204,28 +196,26 @@ func (c *conversationDatabase) GetUserAllConversation(ctx context.Context, owner
}
func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.ConversationModel) error {
cache := c.cache.NewCache()
groupIDs := utils.Distinct(utils.Filter(conversations, func(e *relationtb.ConversationModel) (string, bool) {
return e.GroupID, e.GroupID != ""
}))
for _, groupID := range groupIDs {
cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(groupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID)
}
if err := c.tx.Transaction(func(tx any) error {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.NewCache()
groupIDs := utils.Distinct(utils.Filter(conversations, func(e *relationtb.ConversationModel) (string, bool) {
return e.GroupID, e.GroupID != ""
}))
for _, groupID := range groupIDs {
cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(groupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID)
}
var conversationIDs []string
for _, conversation := range conversations {
conversationIDs = append(conversationIDs, conversation.ConversationID)
cache = cache.DelConversations(conversation.OwnerUserID, conversation.ConversationID)
}
conversationTx := c.conversationDB.NewTx(tx)
existConversations, err := conversationTx.Find(ctx, ownerUserID, conversationIDs)
existConversations, err := c.conversationDB.Find(ctx, ownerUserID, conversationIDs)
if err != nil {
return err
}
if len(existConversations) > 0 {
for _, conversation := range conversations {
err = conversationTx.Update(ctx, conversation)
err = c.conversationDB.Update(ctx, conversation)
if err != nil {
return err
}
@ -249,11 +239,8 @@ func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUs
}
cache = cache.DelConversationIDs(ownerUserID).DelUserConversationIDsHash(ownerUserID).DelConversationNotReceiveMessageUserIDs(utils.Slice(notExistConversations, func(e *relationtb.ConversationModel) string { return e.ConversationID })...)
}
return nil
}); err != nil {
return err
}
return cache.ExecDel(ctx)
return cache.ExecDel(ctx)
})
}
//func (c *conversationDatabase) FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error) {
@ -261,9 +248,9 @@ func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUs
//}
func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string) error {
cache := c.cache.NewCache()
conversationID := msgprocessor.GetConversationIDBySessionType(constant.SuperGroupChatType, groupID)
if err := c.tx.Transaction(func(tx any) error {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.NewCache()
conversationID := msgprocessor.GetConversationIDBySessionType(constant.SuperGroupChatType, groupID)
existConversationUserIDs, err := c.conversationDB.FindUserID(ctx, userIDs, []string{conversationID})
if err != nil {
return err
@ -289,11 +276,8 @@ func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context,
for _, v := range existConversationUserIDs {
cache = cache.DelConversations(v, conversationID)
}
return nil
}); err != nil {
return err
}
return cache.ExecDel(ctx)
return c.cache.ExecDel(ctx)
})
}
func (c *conversationDatabase) GetConversationIDs(ctx context.Context, userID string) ([]string, error) {

@ -127,8 +127,3 @@ func (c *ConversationMgo) GetConversationIDsNeedDestruct(ctx context.Context) ([
func (c *ConversationMgo) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) {
return mgotool.Find[string](ctx, c.coll, bson.M{"conversation_id": conversationID, "recv_msg_opt": bson.M{"$ne": constant.ReceiveMessage}}, options.Find().SetProjection(bson.M{"owner_user_id": 1}))
}
func (c *ConversationMgo) NewTx(tx any) relation.ConversationModelInterface {
//TODO implement me
panic("implement me")
}

@ -91,5 +91,4 @@ type ConversationModelInterface interface {
GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*ConversationModel, error)
GetConversationIDsNeedDestruct(ctx context.Context) ([]*ConversationModel, error)
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
NewTx(tx any) ConversationModelInterface
}

@ -0,0 +1,19 @@
package tx
import (
"context"
"github.com/OpenIMSDK/tools/tx"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
func NewAuto(ctx context.Context, cli *mongo.Client) (tx.CtxTx, error) {
var res map[string]any
if err := cli.Database("admin").RunCommand(ctx, bson.M{"isMaster": 1}).Decode(&res); err != nil {
return nil, err
}
if _, ok := res["setName"]; ok {
return NewMongoTx(cli), nil
}
return NewInvalidTx(), nil
}

@ -0,0 +1,16 @@
package tx
import (
"context"
"github.com/OpenIMSDK/tools/tx"
)
func NewInvalidTx() tx.CtxTx {
return invalid{}
}
type invalid struct{}
func (m invalid) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
return fn(ctx)
}

@ -0,0 +1,28 @@
package tx
import (
"context"
"github.com/OpenIMSDK/tools/tx"
"go.mongodb.org/mongo-driver/mongo"
)
func NewMongoTx(client *mongo.Client) tx.CtxTx {
return &mongoTx{
client: client,
}
}
type mongoTx struct {
client *mongo.Client
}
func (m *mongoTx) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
sess, err := m.client.StartSession()
if err != nil {
return err
}
_, err = sess.WithTransaction(ctx, func(ctx mongo.SessionContext) (interface{}, error) {
return nil, fn(ctx)
})
return err
}
Loading…
Cancel
Save