diff --git a/test/mongo/cmd/main.go b/test/mongo/cmd/main.go index ffc4a6eb0..2a8d41523 100644 --- a/test/mongo/cmd/main.go +++ b/test/mongo/cmd/main.go @@ -1,6 +1,7 @@ package main import ( + "Open_IM/pkg/common/config" mongo2 "Open_IM/test/mongo" "context" "flag" @@ -10,9 +11,22 @@ import ( ) func init() { - clientOptions := options.Client().ApplyURI("mongodb://127.0.0.1:37017/openIM/?maxPoolSize=100") + uri := "mongodb://sample.host:27017/?maxPoolSize=20&w=majority" + if config.Config.Mongo.DBUri != "" { + // example: mongodb://$user:$password@mongo1.mongo:27017,mongo2.mongo:27017,mongo3.mongo:27017/$DBDatabase/?replicaSet=rs0&readPreference=secondary&authSource=admin&maxPoolSize=$DBMaxPoolSize + uri = config.Config.Mongo.DBUri + } else { + if config.Config.Mongo.DBPassword != "" && config.Config.Mongo.DBUserName != "" { + uri = fmt.Sprintf("mongodb://%s:%s@%s/%s?maxPoolSize=%d", config.Config.Mongo.DBUserName, config.Config.Mongo.DBPassword, config.Config.Mongo.DBAddress[0], + config.Config.Mongo.DBDatabase, config.Config.Mongo.DBMaxPoolSize) + } else { + uri = fmt.Sprintf("mongodb://%s/%s/?maxPoolSize=%d", + config.Config.Mongo.DBAddress[0], config.Config.Mongo.DBDatabase, + config.Config.Mongo.DBMaxPoolSize) + } + } var err error - mongo2.Client, err = mongo.Connect(context.TODO(), clientOptions) + mongo2.Client, err = mongo.Connect(context.TODO(), options.Client().ApplyURI(uri)) if err != nil { panic(err) } diff --git a/test/mongo/mongo_utils.go b/test/mongo/mongo_utils.go index 86c305e26..53835c355 100644 --- a/test/mongo/mongo_utils.go +++ b/test/mongo/mongo_utils.go @@ -10,6 +10,7 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "gopkg.in/mgo.v2/bson" + "time" ) var ( @@ -17,14 +18,15 @@ var ( ) func GetUserAllChat(uid string) { + ctx, _ := context.WithTimeout(context.Background(), time.Duration(config.Config.Mongo.DBTimeout)*time.Second) collection := Client.Database(config.Config.Mongo.DBDatabase).Collection("msg") var userChatList []db.UserChat result, err := collection.Find(context.Background(), bson.M{"uid": primitive.Regex{Pattern: uid}}) if err != nil { - fmt.Println(err.Error()) + fmt.Println("find error", err.Error()) return } - if err := result.All(context.Background(), &userChatList); err != nil { + if err := result.All(ctx, &userChatList); err != nil { fmt.Println(err.Error()) } for _, userChat := range userChatList {