fix: replace global config with dependency injection

pull/1960/head
luhaoling 2 years ago
parent c4cdf4de10
commit cd3a564477

@ -66,7 +66,7 @@ func Start(config *config.GlobalConfig, port int, proPort int) error {
var client discoveryregistry.SvcDiscoveryRegistry var client discoveryregistry.SvcDiscoveryRegistry
// Determine whether zk is passed according to whether it is a clustered deployment // Determine whether zk is passed according to whether it is a clustered deployment
client, err = kdisc.NewDiscoveryRegister(config.Envs.Discovery) client, err = kdisc.NewDiscoveryRegister(config)
if err != nil { if err != nil {
return errs.Wrap(err, "register discovery err") return errs.Wrap(err, "register discovery err")
} }
@ -318,6 +318,7 @@ func GinParseToken(rdb redis.UniversalClient, config *config.GlobalConfig) gin.H
cache.NewMsgCacheModel(rdb, config), cache.NewMsgCacheModel(rdb, config),
config.Secret, config.Secret,
config.TokenPolicy.Expire, config.TokenPolicy.Expire,
config,
) )
return func(c *gin.Context) { return func(c *gin.Context) {
switch c.Request.Method { switch c.Request.Method {
@ -329,7 +330,7 @@ func GinParseToken(rdb redis.UniversalClient, config *config.GlobalConfig) gin.H
c.Abort() c.Abort()
return return
} }
claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret()) claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret(config.Secret))
if err != nil { if err != nil {
log.ZWarn(c, "jwt get token error", errs.ErrTokenUnknown.Wrap()) log.ZWarn(c, "jwt get token error", errs.ErrTokenUnknown.Wrap())
apiresp.GinError(c, errs.ErrTokenUnknown.Wrap()) apiresp.GinError(c, errs.ErrTokenUnknown.Wrap())

@ -39,8 +39,8 @@ func (s *Server) InitServer(config *config.GlobalConfig, disCov discoveryregistr
return err return err
} }
msgModel := cache.NewMsgCacheModel(rdb) msgModel := cache.NewMsgCacheModel(rdb, config)
s.LongConnServer.SetDiscoveryRegistry(disCov) s.LongConnServer.SetDiscoveryRegistry(disCov, config)
s.LongConnServer.SetCacheHandler(msgModel) s.LongConnServer.SetCacheHandler(msgModel)
msggateway.RegisterMsgGatewayServer(server, s) msggateway.RegisterMsgGatewayServer(server, s)
return nil return nil
@ -61,18 +61,20 @@ type Server struct {
prometheusPort int prometheusPort int
LongConnServer LongConnServer LongConnServer LongConnServer
pushTerminal []int pushTerminal []int
config *config.GlobalConfig
} }
func (s *Server) SetLongConnServer(LongConnServer LongConnServer) { func (s *Server) SetLongConnServer(LongConnServer LongConnServer) {
s.LongConnServer = LongConnServer s.LongConnServer = LongConnServer
} }
func NewServer(rpcPort int, proPort int, longConnServer LongConnServer) *Server { func NewServer(rpcPort int, proPort int, longConnServer LongConnServer, config *config.GlobalConfig) *Server {
return &Server{ return &Server{
rpcPort: rpcPort, rpcPort: rpcPort,
prometheusPort: proPort, prometheusPort: proPort,
LongConnServer: longConnServer, LongConnServer: longConnServer,
pushTerminal: []int{constant.IOSPlatformID, constant.AndroidPlatformID}, pushTerminal: []int{constant.IOSPlatformID, constant.AndroidPlatformID},
config: config,
} }
} }
@ -87,7 +89,7 @@ func (s *Server) GetUsersOnlineStatus(
ctx context.Context, ctx context.Context,
req *msggateway.GetUsersOnlineStatusReq, req *msggateway.GetUsersOnlineStatusReq,
) (*msggateway.GetUsersOnlineStatusResp, error) { ) (*msggateway.GetUsersOnlineStatusResp, error) {
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
return nil, errs.ErrNoPermission.Wrap("only app manager") return nil, errs.ErrNoPermission.Wrap("only app manager")
} }
var resp msggateway.GetUsersOnlineStatusResp var resp msggateway.GetUsersOnlineStatusResp

@ -43,7 +43,7 @@ func RunWsAndServer(conf *config.GlobalConfig, rpcPort, wsPort, prometheusPort i
return err return err
} }
hubServer := NewServer(rpcPort, prometheusPort, longServer) hubServer := NewServer(rpcPort, prometheusPort, longServer, conf)
netDone := make(chan error) netDone := make(chan error)
go func() { go func() {
err = hubServer.Start(conf) err = hubServer.Start(conf)

@ -110,7 +110,7 @@ type GrpcHandler struct {
func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *GrpcHandler { func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *GrpcHandler {
msgRpcClient := rpcclient.NewMessageRpcClient(client, config) msgRpcClient := rpcclient.NewMessageRpcClient(client, config)
pushRpcClient := rpcclient.NewPushRpcClient(client) pushRpcClient := rpcclient.NewPushRpcClient(client, config)
return &GrpcHandler{ return &GrpcHandler{
msgRpcClient: &msgRpcClient, msgRpcClient: &msgRpcClient,
pushClient: &pushRpcClient, validate: validate, pushClient: &pushRpcClient, validate: validate,

@ -58,7 +58,7 @@ func StartTransfer(config *config.GlobalConfig, prometheusPort int) error {
return err return err
} }
mongo, err := unrelation.NewMongo() mongo, err := unrelation.NewMongo(config)
if err != nil { if err != nil {
return err return err
} }
@ -66,7 +66,7 @@ func StartTransfer(config *config.GlobalConfig, prometheusPort int) error {
if err = mongo.CreateMsgIndex(); err != nil { if err = mongo.CreateMsgIndex(); err != nil {
return err return err
} }
client, err := kdisc.NewDiscoveryRegister(config.Envs.Discovery) client, err := kdisc.NewDiscoveryRegister(config)
if err != nil { if err != nil {
return err return err
} }
@ -75,14 +75,14 @@ func StartTransfer(config *config.GlobalConfig, prometheusPort int) error {
return err return err
} }
client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")))
msgModel := cache.NewMsgCacheModel(rdb) msgModel := cache.NewMsgCacheModel(rdb, config)
msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase()) msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase(config.Mongo.Database))
msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, msgModel) msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, msgModel, config)
if err != nil { if err != nil {
return err return err
} }
conversationRpcClient := rpcclient.NewConversationRpcClient(client) conversationRpcClient := rpcclient.NewConversationRpcClient(client, config)
groupRpcClient := rpcclient.NewGroupRpcClient(client) groupRpcClient := rpcclient.NewGroupRpcClient(client, config)
msgTransfer, err := NewMsgTransfer(config, msgDatabase, &conversationRpcClient, &groupRpcClient) msgTransfer, err := NewMsgTransfer(config, msgDatabase, &conversationRpcClient, &groupRpcClient)
if err != nil { if err != nil {
return err return err

@ -101,11 +101,24 @@ func NewOnlineHistoryRedisConsumerHandler(
och.conversationRpcClient = conversationRpcClient och.conversationRpcClient = conversationRpcClient
och.groupRpcClient = groupRpcClient och.groupRpcClient = groupRpcClient
var err error var err error
tlsConfig := &kafka.TLSConfig{
CACrt: config.Kafka.TLS.CACrt,
ClientCrt: config.Kafka.TLS.ClientCrt,
ClientKey: config.Kafka.TLS.ClientKey,
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
InsecureSkipVerify: false,
}
och.historyConsumerGroup, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ och.historyConsumerGroup, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{
KafkaVersion: sarama.V2_0_0_0, KafkaVersion: sarama.V2_0_0_0,
OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, OffsetsInitial: sarama.OffsetNewest,
IsReturnErr: false,
UserName: config.Kafka.Username,
Password: config.Kafka.Password,
}, []string{config.Kafka.LatestMsgToRedis.Topic}, }, []string{config.Kafka.LatestMsgToRedis.Topic},
config.Kafka.Addr, config.Kafka.ConsumerGroupID.MsgToRedis) config.Kafka.Addr,
config.Kafka.ConsumerGroupID.MsgToRedis,
tlsConfig,
)
// statistics.NewStatistics(&och.singleMsgSuccessCount, config.Config.ModuleName.MsgTransferName, fmt.Sprintf("%d // statistics.NewStatistics(&och.singleMsgSuccessCount, config.Config.ModuleName.MsgTransferName, fmt.Sprintf("%d
// second singleMsgCount insert to mongo", constant.StatisticsTimeInterval), constant.StatisticsTimeInterval) // second singleMsgCount insert to mongo", constant.StatisticsTimeInterval), constant.StatisticsTimeInterval)
return &och, err return &och, err

@ -35,11 +35,24 @@ type OnlineHistoryMongoConsumerHandler struct {
} }
func NewOnlineHistoryMongoConsumerHandler(config *config.GlobalConfig, database controller.CommonMsgDatabase) (*OnlineHistoryMongoConsumerHandler, error) { func NewOnlineHistoryMongoConsumerHandler(config *config.GlobalConfig, database controller.CommonMsgDatabase) (*OnlineHistoryMongoConsumerHandler, error) {
tlsConfig := &kfk.TLSConfig{
CACrt: config.Kafka.TLS.CACrt,
ClientCrt: config.Kafka.TLS.ClientCrt,
ClientKey: config.Kafka.TLS.ClientKey,
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
InsecureSkipVerify: false,
}
historyConsumerGroup, err := kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{ historyConsumerGroup, err := kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{
KafkaVersion: sarama.V2_0_0_0, KafkaVersion: sarama.V2_0_0_0,
OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, OffsetsInitial: sarama.OffsetNewest,
IsReturnErr: false,
UserName: config.Kafka.Username,
Password: config.Kafka.Password,
}, []string{config.Kafka.MsgToMongo.Topic}, }, []string{config.Kafka.MsgToMongo.Topic},
config.Kafka.Addr, config.Kafka.ConsumerGroupID.MsgToMongo) config.Kafka.Addr,
config.Kafka.ConsumerGroupID.MsgToMongo,
tlsConfig,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -39,11 +39,22 @@ func NewConsumerHandler(config *config.GlobalConfig, pusher *Pusher) (*ConsumerH
var consumerHandler ConsumerHandler var consumerHandler ConsumerHandler
consumerHandler.pusher = pusher consumerHandler.pusher = pusher
var err error var err error
tlsConfig := &kfk.TLSConfig{
CACrt: config.Kafka.TLS.CACrt,
ClientCrt: config.Kafka.TLS.ClientCrt,
ClientKey: config.Kafka.TLS.ClientKey,
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
InsecureSkipVerify: false,
}
consumerHandler.pushConsumerGroup, err = kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{ consumerHandler.pushConsumerGroup, err = kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{
KafkaVersion: sarama.V2_0_0_0, KafkaVersion: sarama.V2_0_0_0,
OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, OffsetsInitial: sarama.OffsetNewest,
IsReturnErr: false,
UserName: config.Kafka.Username,
Password: config.Kafka.Password,
}, []string{config.Kafka.MsgToPush.Topic}, config.Kafka.Addr, }, []string{config.Kafka.MsgToPush.Topic}, config.Kafka.Addr,
config.Kafka.ConsumerGroupID.MsgToPush) config.Kafka.ConsumerGroupID.MsgToPush,
tlsConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -43,12 +43,12 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
if err != nil { if err != nil {
return err return err
} }
cacheModel := cache.NewMsgCacheModel(rdb) cacheModel := cache.NewMsgCacheModel(rdb, config)
offlinePusher := NewOfflinePusher(config, cacheModel) offlinePusher := NewOfflinePusher(config, cacheModel)
database := controller.NewPushDatabase(cacheModel) database := controller.NewPushDatabase(cacheModel)
groupRpcClient := rpcclient.NewGroupRpcClient(client) groupRpcClient := rpcclient.NewGroupRpcClient(client, config)
conversationRpcClient := rpcclient.NewConversationRpcClient(client) conversationRpcClient := rpcclient.NewConversationRpcClient(client, config)
msgRpcClient := rpcclient.NewMessageRpcClient(client) msgRpcClient := rpcclient.NewMessageRpcClient(client, config)
pusher := NewPusher( pusher := NewPusher(
config, config,
client, client,

@ -55,9 +55,10 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
userRpcClient: &userRpcClient, userRpcClient: &userRpcClient,
RegisterCenter: client, RegisterCenter: client,
authDatabase: controller.NewAuthDatabase( authDatabase: controller.NewAuthDatabase(
cache.NewMsgCacheModel(rdb), cache.NewMsgCacheModel(rdb, config),
config.Secret, config.Secret,
config.TokenPolicy.Expire, config.TokenPolicy.Expire,
config,
), ),
config: config, config: config,
}) })
@ -83,12 +84,12 @@ func (s *authServer) UserToken(ctx context.Context, req *pbauth.UserTokenReq) (*
} }
func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) { func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil { if err := authverify.CheckAdmin(ctx, s.config); err != nil {
return nil, err return nil, err
} }
resp := pbauth.GetUserTokenResp{} resp := pbauth.GetUserTokenResp{}
if authverify.IsManagerUserID(req.UserID) { if authverify.IsManagerUserID(req.UserID, s.config) {
return nil, errs.ErrNoPermission.Wrap("don't get Admin token") return nil, errs.ErrNoPermission.Wrap("don't get Admin token")
} }
@ -105,7 +106,7 @@ func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenR
} }
func (s *authServer) parseToken(ctx context.Context, tokensString string) (claims *tokenverify.Claims, err error) { func (s *authServer) parseToken(ctx context.Context, tokensString string) (claims *tokenverify.Claims, err error) {
claims, err = tokenverify.GetClaimFromToken(tokensString, authverify.Secret()) claims, err = tokenverify.GetClaimFromToken(tokensString, authverify.Secret(s.config.Secret))
if err != nil { if err != nil {
return nil, utils.Wrap(err, "") return nil, utils.Wrap(err, "")
} }
@ -145,7 +146,7 @@ func (s *authServer) ParseToken(
} }
func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) { func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil { if err := authverify.CheckAdmin(ctx, s.config); err != nil {
return nil, err return nil, err
} }
if err := s.forceKickOff(ctx, req.UserID, req.PlatformID, mcontext.GetOperationID(ctx)); err != nil { if err := s.forceKickOff(ctx, req.UserID, req.PlatformID, mcontext.GetOperationID(ctx)); err != nil {

@ -80,7 +80,7 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
pbconversation.RegisterConversationServer(server, &conversationServer{ pbconversation.RegisterConversationServer(server, &conversationServer{
msgRpcClient: &msgRpcClient, msgRpcClient: &msgRpcClient,
user: &userRpcClient, user: &userRpcClient,
conversationNotificationSender: notification.NewConversationNotificationSender(&msgRpcClient), conversationNotificationSender: notification.NewConversationNotificationSender(config, &msgRpcClient),
groupRpcClient: &groupRpcClient, groupRpcClient: &groupRpcClient,
conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx.NewMongo(mongo.GetClient())), conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx.NewMongo(mongo.GetClient())),
config: config, config: config,

@ -67,7 +67,7 @@ func (s *friendServer) RemoveBlack(ctx context.Context, req *pbfriend.RemoveBlac
} }
func (s *friendServer) AddBlack(ctx context.Context, req *pbfriend.AddBlackReq) (*pbfriend.AddBlackResp, error) { func (s *friendServer) AddBlack(ctx context.Context, req *pbfriend.AddBlackReq) (*pbfriend.AddBlackResp, error) {
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config); err != nil {
return nil, err return nil, err
} }
_, err := s.userRpcClient.GetUsersInfo(ctx, []string{req.OwnerUserID, req.BlackUserID}) _, err := s.userRpcClient.GetUsersInfo(ctx, []string{req.OwnerUserID, req.BlackUserID})

@ -89,6 +89,7 @@ func Start(config *config.GlobalConfig, client registry.SvcDiscoveryRegistry, se
// Initialize notification sender // Initialize notification sender
notificationSender := notification.NewFriendNotificationSender( notificationSender := notification.NewFriendNotificationSender(
config,
&msgRpcClient, &msgRpcClient,
notification.WithRpcFunc(userRpcClient.GetUsersInfo), notification.WithRpcFunc(userRpcClient.GetUsersInfo),
) )
@ -117,7 +118,7 @@ func Start(config *config.GlobalConfig, client registry.SvcDiscoveryRegistry, se
// ok. // ok.
func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriendReq) (resp *pbfriend.ApplyToAddFriendResp, err error) { func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriendReq) (resp *pbfriend.ApplyToAddFriendResp, err error) {
resp = &pbfriend.ApplyToAddFriendResp{} resp = &pbfriend.ApplyToAddFriendResp{}
if err := authverify.CheckAccessV3(ctx, req.FromUserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.FromUserID, s.config); err != nil {
return nil, err return nil, err
} }
if req.ToUserID == req.FromUserID { if req.ToUserID == req.FromUserID {
@ -149,7 +150,7 @@ func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *pbfriend.Apply
// ok. // ok.
func (s *friendServer) ImportFriends(ctx context.Context, req *pbfriend.ImportFriendReq) (resp *pbfriend.ImportFriendResp, err error) { func (s *friendServer) ImportFriends(ctx context.Context, req *pbfriend.ImportFriendReq) (resp *pbfriend.ImportFriendResp, err error) {
defer log.ZInfo(ctx, utils.GetFuncName()+" Return") defer log.ZInfo(ctx, utils.GetFuncName()+" Return")
if err := authverify.CheckAdmin(ctx); err != nil { if err := authverify.CheckAdmin(ctx, s.config); err != nil {
return nil, err return nil, err
} }
if _, err := s.userRpcClient.GetUsersInfo(ctx, append([]string{req.OwnerUserID}, req.FriendUserIDs...)); err != nil { if _, err := s.userRpcClient.GetUsersInfo(ctx, append([]string{req.OwnerUserID}, req.FriendUserIDs...)); err != nil {
@ -185,7 +186,7 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *pbfriend.ImportFr
func (s *friendServer) RespondFriendApply(ctx context.Context, req *pbfriend.RespondFriendApplyReq) (resp *pbfriend.RespondFriendApplyResp, err error) { func (s *friendServer) RespondFriendApply(ctx context.Context, req *pbfriend.RespondFriendApplyReq) (resp *pbfriend.RespondFriendApplyResp, err error) {
defer log.ZInfo(ctx, utils.GetFuncName()+" Return") defer log.ZInfo(ctx, utils.GetFuncName()+" Return")
resp = &pbfriend.RespondFriendApplyResp{} resp = &pbfriend.RespondFriendApplyResp{}
if err := authverify.CheckAccessV3(ctx, req.ToUserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.ToUserID, s.config); err != nil {
return nil, err return nil, err
} }

@ -88,7 +88,7 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
database := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx.NewMongo(mongo.GetClient()), grouphash.NewGroupHashFromGroupServer(&gs)) database := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx.NewMongo(mongo.GetClient()), grouphash.NewGroupHashFromGroupServer(&gs))
gs.db = database gs.db = database
gs.User = userRpcClient gs.User = userRpcClient
gs.Notification = notification.NewGroupNotificationSender(database, &msgRpcClient, &userRpcClient, func(ctx context.Context, userIDs []string) ([]notification.CommonUser, error) { gs.Notification = notification.NewGroupNotificationSender(database, &msgRpcClient, &userRpcClient, config, func(ctx context.Context, userIDs []string) ([]notification.CommonUser, error) {
users, err := userRpcClient.GetUsersInfo(ctx, userIDs) users, err := userRpcClient.GetUsersInfo(ctx, userIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -141,7 +141,7 @@ func (s *groupServer) NotificationUserInfoUpdate(ctx context.Context, req *pbgro
} }
func (s *groupServer) CheckGroupAdmin(ctx context.Context, groupID string) error { func (s *groupServer) CheckGroupAdmin(ctx context.Context, groupID string) error {
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
groupMember, err := s.db.TakeGroupMember(ctx, groupID, mcontext.GetOpUserID(ctx)) groupMember, err := s.db.TakeGroupMember(ctx, groupID, mcontext.GetOpUserID(ctx))
if err != nil { if err != nil {
return err return err
@ -206,7 +206,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR
if req.OwnerUserID == "" { if req.OwnerUserID == "" {
return nil, errs.ErrArgs.Wrap("no group owner") return nil, errs.ErrArgs.Wrap("no group owner")
} }
if err := authverify.CheckAccessV3(ctx, req.OwnerUserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config); err != nil {
return nil, err return nil, err
} }
userIDs := append(append(req.MemberUserIDs, req.AdminUserIDs...), req.OwnerUserID) userIDs := append(append(req.MemberUserIDs, req.AdminUserIDs...), req.OwnerUserID)
@ -321,7 +321,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR
func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbgroup.GetJoinedGroupListReq) (*pbgroup.GetJoinedGroupListResp, error) { func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbgroup.GetJoinedGroupListReq) (*pbgroup.GetJoinedGroupListResp, error) {
resp := &pbgroup.GetJoinedGroupListResp{} resp := &pbgroup.GetJoinedGroupListResp{}
if err := authverify.CheckAccessV3(ctx, req.FromUserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.FromUserID, s.config); err != nil {
return nil, err return nil, err
} }
total, members, err := s.db.PageGetJoinGroup(ctx, req.FromUserID, req.Pagination) total, members, err := s.db.PageGetJoinGroup(ctx, req.FromUserID, req.Pagination)
@ -391,7 +391,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite
} }
var groupMember *relationtb.GroupMemberModel var groupMember *relationtb.GroupMemberModel
var opUserID string var opUserID string
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
opUserID = mcontext.GetOpUserID(ctx) opUserID = mcontext.GetOpUserID(ctx)
var err error var err error
groupMember, err = s.db.TakeGroupMember(ctx, req.GroupID, opUserID) groupMember, err = s.db.TakeGroupMember(ctx, req.GroupID, opUserID)
@ -407,7 +407,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite
return nil, err return nil, err
} }
if group.NeedVerification == constant.AllNeedVerification { if group.NeedVerification == constant.AllNeedVerification {
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
if !(groupMember.RoleLevel == constant.GroupOwner || groupMember.RoleLevel == constant.GroupAdmin) { if !(groupMember.RoleLevel == constant.GroupOwner || groupMember.RoleLevel == constant.GroupAdmin) {
var requests []*relationtb.GroupRequestModel var requests []*relationtb.GroupRequestModel
for _, userID := range req.InvitedUserIDs { for _, userID := range req.InvitedUserIDs {
@ -547,7 +547,7 @@ func (s *groupServer) KickGroupMember(ctx context.Context, req *pbgroup.KickGrou
for i, member := range members { for i, member := range members {
memberMap[member.UserID] = members[i] memberMap[member.UserID] = members[i]
} }
isAppManagerUid := authverify.IsAppManagerUid(ctx) isAppManagerUid := authverify.IsAppManagerUid(ctx, s.config)
opMember := memberMap[opUserID] opMember := memberMap[opUserID]
for _, userID := range req.KickedUserIDs { for _, userID := range req.KickedUserIDs {
member, ok := memberMap[userID] member, ok := memberMap[userID]
@ -745,7 +745,7 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbgroup
if !utils.Contain(req.HandleResult, constant.GroupResponseAgree, constant.GroupResponseRefuse) { if !utils.Contain(req.HandleResult, constant.GroupResponseAgree, constant.GroupResponseRefuse) {
return nil, errs.ErrArgs.Wrap("HandleResult unknown") return nil, errs.ErrArgs.Wrap("HandleResult unknown")
} }
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
groupMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) groupMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
@ -895,7 +895,7 @@ func (s *groupServer) QuitGroup(ctx context.Context, req *pbgroup.QuitGroupReq)
if req.UserID == "" { if req.UserID == "" {
req.UserID = mcontext.GetOpUserID(ctx) req.UserID = mcontext.GetOpUserID(ctx)
} else { } else {
if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, s.config); err != nil {
return nil, err return nil, err
} }
} }
@ -936,7 +936,7 @@ func (s *groupServer) deleteMemberAndSetConversationSeq(ctx context.Context, gro
func (s *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInfoReq) (*pbgroup.SetGroupInfoResp, error) { func (s *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInfoReq) (*pbgroup.SetGroupInfoResp, error) {
var opMember *relationtb.GroupMemberModel var opMember *relationtb.GroupMemberModel
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
var err error var err error
opMember, err = s.db.TakeGroupMember(ctx, req.GroupInfoForSet.GroupID, mcontext.GetOpUserID(ctx)) opMember, err = s.db.TakeGroupMember(ctx, req.GroupInfoForSet.GroupID, mcontext.GetOpUserID(ctx))
if err != nil { if err != nil {
@ -1055,7 +1055,7 @@ func (s *groupServer) TransferGroupOwner(ctx context.Context, req *pbgroup.Trans
if newOwner == nil { if newOwner == nil {
return nil, errs.ErrArgs.Wrap("NewOwnerUser not in group " + req.NewOwnerUserID) return nil, errs.ErrArgs.Wrap("NewOwnerUser not in group " + req.NewOwnerUserID)
} }
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
if !(mcontext.GetOpUserID(ctx) == oldOwner.UserID && oldOwner.RoleLevel == constant.GroupOwner) { if !(mcontext.GetOpUserID(ctx) == oldOwner.UserID && oldOwner.RoleLevel == constant.GroupOwner) {
return nil, errs.ErrNoPermission.Wrap("no permission transfer group owner") return nil, errs.ErrNoPermission.Wrap("no permission transfer group owner")
} }
@ -1196,7 +1196,7 @@ func (s *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
if owner.UserID != mcontext.GetOpUserID(ctx) { if owner.UserID != mcontext.GetOpUserID(ctx) {
return nil, errs.ErrNoPermission.Wrap("not group owner") return nil, errs.ErrNoPermission.Wrap("not group owner")
} }
@ -1254,7 +1254,7 @@ func (s *groupServer) MuteGroupMember(ctx context.Context, req *pbgroup.MuteGrou
if err := s.PopulateGroupMember(ctx, member); err != nil { if err := s.PopulateGroupMember(ctx, member); err != nil {
return nil, err return nil, err
} }
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
opMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) opMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
@ -1288,7 +1288,7 @@ func (s *groupServer) CancelMuteGroupMember(ctx context.Context, req *pbgroup.Ca
if err := s.PopulateGroupMember(ctx, member); err != nil { if err := s.PopulateGroupMember(ctx, member); err != nil {
return nil, err return nil, err
} }
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, s.config) {
opMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) opMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
@ -1347,7 +1347,7 @@ func (s *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr
if opUserID == "" { if opUserID == "" {
return nil, errs.ErrNoPermission.Wrap("no op user id") return nil, errs.ErrNoPermission.Wrap("no op user id")
} }
isAppManagerUid := authverify.IsAppManagerUid(ctx) isAppManagerUid := authverify.IsAppManagerUid(ctx, s.config)
for i := range req.Members { for i := range req.Members {
req.Members[i].FaceURL = nil req.Members[i].FaceURL = nil
} }

@ -46,7 +46,7 @@ func (m *msgServer) ClearConversationsMsg(
ctx context.Context, ctx context.Context,
req *msg.ClearConversationsMsgReq, req *msg.ClearConversationsMsgReq,
) (*msg.ClearConversationsMsgResp, error) { ) (*msg.ClearConversationsMsgResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil {
return nil, err return nil, err
} }
if err := m.clearConversation(ctx, req.ConversationIDs, req.UserID, req.DeleteSyncOpt); err != nil { if err := m.clearConversation(ctx, req.ConversationIDs, req.UserID, req.DeleteSyncOpt); err != nil {
@ -59,7 +59,7 @@ func (m *msgServer) UserClearAllMsg(
ctx context.Context, ctx context.Context,
req *msg.UserClearAllMsgReq, req *msg.UserClearAllMsgReq,
) (*msg.UserClearAllMsgResp, error) { ) (*msg.UserClearAllMsgResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil {
return nil, err return nil, err
} }
conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID)
@ -74,7 +74,7 @@ func (m *msgServer) UserClearAllMsg(
} }
func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*msg.DeleteMsgsResp, error) { func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*msg.DeleteMsgsResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil {
return nil, err return nil, err
} }
isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(req.DeleteSyncOpt) isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(req.DeleteSyncOpt)
@ -122,7 +122,7 @@ func (m *msgServer) DeleteMsgPhysical(
ctx context.Context, ctx context.Context,
req *msg.DeleteMsgPhysicalReq, req *msg.DeleteMsgPhysicalReq,
) (*msg.DeleteMsgPhysicalResp, error) { ) (*msg.DeleteMsgPhysicalResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil { if err := authverify.CheckAdmin(ctx, m.config); err != nil {
return nil, err return nil, err
} }
remainTime := utils.GetCurrentTimestampBySecond() - req.Timestamp remainTime := utils.GetCurrentTimestampBySecond() - req.Timestamp

@ -43,7 +43,7 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg.
if req.Seq < 0 { if req.Seq < 0 {
return nil, errs.ErrArgs.Wrap("seq is invalid") return nil, errs.ErrArgs.Wrap("seq is invalid")
} }
if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil {
return nil, err return nil, err
} }
user, err := m.User.GetUserInfo(ctx, req.UserID) user, err := m.User.GetUserInfo(ctx, req.UserID)
@ -64,10 +64,10 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg.
data, _ := json.Marshal(msgs[0]) data, _ := json.Marshal(msgs[0])
log.ZInfo(ctx, "GetMsgBySeqs", "conversationID", req.ConversationID, "seq", req.Seq, "msg", string(data)) log.ZInfo(ctx, "GetMsgBySeqs", "conversationID", req.ConversationID, "seq", req.Seq, "msg", string(data))
var role int32 var role int32
if !authverify.IsAppManagerUid(ctx) { if !authverify.IsAppManagerUid(ctx, m.config) {
switch msgs[0].SessionType { switch msgs[0].SessionType {
case constant.SingleChatType: case constant.SingleChatType:
if err := authverify.CheckAccessV3(ctx, msgs[0].SendID); err != nil { if err := authverify.CheckAccessV3(ctx, msgs[0].SendID, m.config); err != nil {
return nil, err return nil, err
} }
role = user.AppMangerLevel role = user.AppMangerLevel

@ -76,13 +76,13 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
if err := mongo.CreateMsgIndex(); err != nil { if err := mongo.CreateMsgIndex(); err != nil {
return err return err
} }
cacheModel := cache.NewMsgCacheModel(rdb) cacheModel := cache.NewMsgCacheModel(rdb, config)
msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase(config.Mongo.Database)) msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase(config.Mongo.Database))
conversationClient := rpcclient.NewConversationRpcClient(client, config) conversationClient := rpcclient.NewConversationRpcClient(client, config)
userRpcClient := rpcclient.NewUserRpcClient(client, config) userRpcClient := rpcclient.NewUserRpcClient(client, config)
groupRpcClient := rpcclient.NewGroupRpcClient(client, config) groupRpcClient := rpcclient.NewGroupRpcClient(client, config)
friendRpcClient := rpcclient.NewFriendRpcClient(client, config) friendRpcClient := rpcclient.NewFriendRpcClient(client, config)
msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, cacheModel) msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, cacheModel, config)
if err != nil { if err != nil {
return err return err
} }
@ -97,7 +97,7 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
friend: &friendRpcClient, friend: &friendRpcClient,
config: config, config: config,
} }
s.notificationSender = rpcclient.NewNotificationSender(rpcclient.WithLocalSendMsg(s.SendMsg)) s.notificationSender = rpcclient.NewNotificationSender(config, rpcclient.WithLocalSendMsg(s.SendMsg))
s.addInterceptorHandler(MessageHasReadEnabled) s.addInterceptorHandler(MessageHasReadEnabled)
msg.RegisterMsgServer(server, s) msg.RegisterMsgServer(server, s)
return nil return nil

@ -90,7 +90,7 @@ func (m *msgServer) PullMessageBySeqs(
} }
func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) { func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) {
if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil {
return nil, err return nil, err
} }
conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID)

@ -83,7 +83,7 @@ func (t *thirdServer) UploadLogs(ctx context.Context, req *third.UploadLogsReq)
} }
func (t *thirdServer) DeleteLogs(ctx context.Context, req *third.DeleteLogsReq) (*third.DeleteLogsResp, error) { func (t *thirdServer) DeleteLogs(ctx context.Context, req *third.DeleteLogsReq) (*third.DeleteLogsResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil { if err := authverify.CheckAdmin(ctx, t.config); err != nil {
return nil, err return nil, err
} }
userID := "" userID := ""
@ -124,7 +124,7 @@ func dbToPbLogInfos(logs []*relationtb.LogModel) []*third.LogInfo {
} }
func (t *thirdServer) SearchLogs(ctx context.Context, req *third.SearchLogsReq) (*third.SearchLogsResp, error) { func (t *thirdServer) SearchLogs(ctx context.Context, req *third.SearchLogsReq) (*third.SearchLogsResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil { if err := authverify.CheckAdmin(ctx, t.config); err != nil {
return nil, err return nil, err
} }
var ( var (

@ -58,7 +58,7 @@ func (t *thirdServer) PartSize(ctx context.Context, req *third.PartSizeReq) (*th
func (t *thirdServer) InitiateMultipartUpload(ctx context.Context, req *third.InitiateMultipartUploadReq) (*third.InitiateMultipartUploadResp, error) { func (t *thirdServer) InitiateMultipartUpload(ctx context.Context, req *third.InitiateMultipartUploadReq) (*third.InitiateMultipartUploadResp, error) {
defer log.ZDebug(ctx, "return") defer log.ZDebug(ctx, "return")
if err := checkUploadName(ctx, req.Name); err != nil { if err := checkUploadName(ctx, req.Name, t.config); err != nil {
return nil, err return nil, err
} }
expireTime := time.Now().Add(t.defaultExpire) expireTime := time.Now().Add(t.defaultExpire)
@ -137,7 +137,7 @@ func (t *thirdServer) AuthSign(ctx context.Context, req *third.AuthSignReq) (*th
func (t *thirdServer) CompleteMultipartUpload(ctx context.Context, req *third.CompleteMultipartUploadReq) (*third.CompleteMultipartUploadResp, error) { func (t *thirdServer) CompleteMultipartUpload(ctx context.Context, req *third.CompleteMultipartUploadReq) (*third.CompleteMultipartUploadResp, error) {
defer log.ZDebug(ctx, "return") defer log.ZDebug(ctx, "return")
if err := checkUploadName(ctx, req.Name); err != nil { if err := checkUploadName(ctx, req.Name, t.config); err != nil {
return nil, err return nil, err
} }
result, err := t.s3dataBase.CompleteMultipartUpload(ctx, req.UploadID, req.Parts) result, err := t.s3dataBase.CompleteMultipartUpload(ctx, req.UploadID, req.Parts)
@ -194,13 +194,13 @@ func (t *thirdServer) InitiateFormData(ctx context.Context, req *third.InitiateF
if req.Size <= 0 { if req.Size <= 0 {
return nil, errs.ErrArgs.Wrap("size must be greater than 0") return nil, errs.ErrArgs.Wrap("size must be greater than 0")
} }
if err := checkUploadName(ctx, req.Name); err != nil { if err := checkUploadName(ctx, req.Name, t.config); err != nil {
return nil, err return nil, err
} }
var duration time.Duration var duration time.Duration
opUserID := mcontext.GetOpUserID(ctx) opUserID := mcontext.GetOpUserID(ctx)
var key string var key string
if authverify.IsManagerUserID(opUserID) { if authverify.IsManagerUserID(opUserID, t.config) {
if req.Millisecond <= 0 { if req.Millisecond <= 0 {
duration = time.Minute * 10 duration = time.Minute * 10
} else { } else {
@ -260,7 +260,7 @@ func (t *thirdServer) CompleteFormData(ctx context.Context, req *third.CompleteF
if err := json.Unmarshal(data, &mate); err != nil { if err := json.Unmarshal(data, &mate); err != nil {
return nil, errs.ErrArgs.Wrap("invalid id " + err.Error()) return nil, errs.ErrArgs.Wrap("invalid id " + err.Error())
} }
if err := checkUploadName(ctx, mate.Name); err != nil { if err := checkUploadName(ctx, mate.Name, t.config); err != nil {
return nil, err return nil, err
} }
info, err := t.s3dataBase.StatObject(ctx, mate.Key) info, err := t.s3dataBase.StatObject(ctx, mate.Key)

@ -72,11 +72,11 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
var o s3.Interface var o s3.Interface
switch config.Object.Enable { switch config.Object.Enable {
case "minio": case "minio":
o, err = minio.NewMinio(cache.NewMinioCache(rdb)) o, err = minio.NewMinio(cache.NewMinioCache(rdb), config)
case "cos": case "cos":
o, err = cos.NewCos() o, err = cos.NewCos(config)
case "oss": case "oss":
o, err = oss.NewOSS() o, err = oss.NewOSS(config)
default: default:
err = fmt.Errorf("invalid object enable: %s", enable) err = fmt.Errorf("invalid object enable: %s", enable)
} }
@ -85,7 +85,7 @@ func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryReg
} }
third.RegisterThirdServer(server, &thirdServer{ third.RegisterThirdServer(server, &thirdServer{
apiURL: apiURL, apiURL: apiURL,
thirdDatabase: controller.NewThirdDatabase(cache.NewMsgCacheModel(rdb), logdb), thirdDatabase: controller.NewThirdDatabase(cache.NewMsgCacheModel(rdb, config), logdb),
userRpcClient: rpcclient.NewUserRpcClient(client, config), userRpcClient: rpcclient.NewUserRpcClient(client, config),
s3dataBase: controller.NewS3Database(rdb, o, s3db), s3dataBase: controller.NewS3Database(rdb, o, s3db),
defaultExpire: time.Hour * 24 * 7, defaultExpire: time.Hour * 24 * 7,

@ -18,6 +18,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
@ -42,7 +43,7 @@ func toPbMapArray(m map[string][]string) []*third.KeyValues {
return res return res
} }
func checkUploadName(ctx context.Context, name string) error { func checkUploadName(ctx context.Context, name string, config *config.GlobalConfig) error {
if name == "" { if name == "" {
return errs.ErrArgs.Wrap("name is empty") return errs.ErrArgs.Wrap("name is empty")
} }
@ -56,7 +57,7 @@ func checkUploadName(ctx context.Context, name string) error {
if opUserID == "" { if opUserID == "" {
return errs.ErrNoPermission.Wrap("opUserID is empty") return errs.ErrNoPermission.Wrap("opUserID is empty")
} }
if !authverify.IsManagerUserID(opUserID) { if !authverify.IsManagerUserID(opUserID, config) {
if !strings.HasPrefix(name, opUserID+"/") { if !strings.HasPrefix(name, opUserID+"/") {
return errs.ErrNoPermission.Wrap(fmt.Sprintf("name must start with `%s/`", opUserID)) return errs.ErrNoPermission.Wrap(fmt.Sprintf("name must start with `%s/`", opUserID))
} }

@ -98,8 +98,8 @@ func Start(config *config.GlobalConfig, client registry.SvcDiscoveryRegistry, se
RegisterCenter: client, RegisterCenter: client,
friendRpcClient: &friendRpcClient, friendRpcClient: &friendRpcClient,
groupRpcClient: &groupRpcClient, groupRpcClient: &groupRpcClient,
friendNotificationSender: notification.NewFriendNotificationSender(&msgRpcClient, notification.WithDBFunc(database.FindWithError)), friendNotificationSender: notification.NewFriendNotificationSender(config, &msgRpcClient, notification.WithDBFunc(database.FindWithError)),
userNotificationSender: notification.NewUserNotificationSender(&msgRpcClient, notification.WithUserFunc(database.FindWithError)), userNotificationSender: notification.NewUserNotificationSender(config, &msgRpcClient, notification.WithUserFunc(database.FindWithError)),
config: config, config: config,
} }
pbuser.RegisterUserServer(server, u) pbuser.RegisterUserServer(server, u)
@ -121,7 +121,7 @@ func (s *userServer) GetDesignateUsers(ctx context.Context, req *pbuser.GetDesig
func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) (resp *pbuser.UpdateUserInfoResp, err error) { func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) (resp *pbuser.UpdateUserInfoResp, err error) {
resp = &pbuser.UpdateUserInfoResp{} resp = &pbuser.UpdateUserInfoResp{}
err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID) err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -155,7 +155,7 @@ func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserI
} }
func (s *userServer) UpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) (resp *pbuser.UpdateUserInfoExResp, err error) { func (s *userServer) UpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) (resp *pbuser.UpdateUserInfoExResp, err error) {
resp = &pbuser.UpdateUserInfoExResp{} resp = &pbuser.UpdateUserInfoExResp{}
err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID) err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -207,7 +207,7 @@ func (s *userServer) AccountCheck(ctx context.Context, req *pbuser.AccountCheckR
if utils.Duplicate(req.CheckUserIDs) { if utils.Duplicate(req.CheckUserIDs) {
return nil, errs.ErrArgs.Wrap("userID repeated") return nil, errs.ErrArgs.Wrap("userID repeated")
} }
err = authverify.CheckAdmin(ctx) err = authverify.CheckAdmin(ctx, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -393,7 +393,7 @@ func (s *userServer) GetSubscribeUsersStatus(ctx context.Context,
// ProcessUserCommandAdd user general function add. // ProcessUserCommandAdd user general function add.
func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.ProcessUserCommandAddReq) (*pbuser.ProcessUserCommandAddResp, error) { func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.ProcessUserCommandAddReq) (*pbuser.ProcessUserCommandAddResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID) err := authverify.CheckAccessV3(ctx, req.UserID, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -424,7 +424,7 @@ func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.Proc
// ProcessUserCommandDelete user general function delete. // ProcessUserCommandDelete user general function delete.
func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.ProcessUserCommandDeleteReq) (*pbuser.ProcessUserCommandDeleteResp, error) { func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.ProcessUserCommandDeleteReq) (*pbuser.ProcessUserCommandDeleteResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID) err := authverify.CheckAccessV3(ctx, req.UserID, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -447,7 +447,7 @@ func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.P
// ProcessUserCommandUpdate user general function update. // ProcessUserCommandUpdate user general function update.
func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.ProcessUserCommandUpdateReq) (*pbuser.ProcessUserCommandUpdateResp, error) { func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.ProcessUserCommandUpdateReq) (*pbuser.ProcessUserCommandUpdateResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID) err := authverify.CheckAccessV3(ctx, req.UserID, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -479,7 +479,7 @@ func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.P
func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.ProcessUserCommandGetReq) (*pbuser.ProcessUserCommandGetResp, error) { func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.ProcessUserCommandGetReq) (*pbuser.ProcessUserCommandGetResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID) err := authverify.CheckAccessV3(ctx, req.UserID, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -508,7 +508,7 @@ func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.Proc
} }
func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.ProcessUserCommandGetAllReq) (*pbuser.ProcessUserCommandGetAllResp, error) { func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.ProcessUserCommandGetAllReq) (*pbuser.ProcessUserCommandGetAllResp, error) {
err := authverify.CheckAccessV3(ctx, req.UserID) err := authverify.CheckAccessV3(ctx, req.UserID, s.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -537,7 +537,7 @@ func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.P
} }
func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.AddNotificationAccountReq) (*pbuser.AddNotificationAccountResp, error) { func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.AddNotificationAccountReq) (*pbuser.AddNotificationAccountResp, error) {
if err := authverify.CheckIMAdmin(ctx); err != nil { if err := authverify.CheckIMAdmin(ctx, s.config); err != nil {
return nil, err return nil, err
} }
@ -580,7 +580,7 @@ func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.Add
} }
func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbuser.UpdateNotificationAccountInfoReq) (*pbuser.UpdateNotificationAccountInfoResp, error) { func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbuser.UpdateNotificationAccountInfoReq) (*pbuser.UpdateNotificationAccountInfoResp, error) {
if err := authverify.CheckIMAdmin(ctx); err != nil { if err := authverify.CheckIMAdmin(ctx, s.config); err != nil {
return nil, err return nil, err
} }
@ -607,7 +607,7 @@ func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbu
func (s *userServer) SearchNotificationAccount(ctx context.Context, req *pbuser.SearchNotificationAccountReq) (*pbuser.SearchNotificationAccountResp, error) { func (s *userServer) SearchNotificationAccount(ctx context.Context, req *pbuser.SearchNotificationAccountReq) (*pbuser.SearchNotificationAccountResp, error) {
// Check if user is an admin // Check if user is an admin
if err := authverify.CheckIMAdmin(ctx); err != nil { if err := authverify.CheckIMAdmin(ctx, s.config); err != nil {
return nil, err return nil, err
} }

@ -78,7 +78,7 @@ func InitMsgTool(config *config.GlobalConfig) (*MsgTool, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
discov, err := kdisc.NewDiscoveryRegister(config.Envs.Discovery) discov, err := kdisc.NewDiscoveryRegister(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,7 +87,7 @@ func InitMsgTool(config *config.GlobalConfig) (*MsgTool, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
msgDatabase, err := controller.InitCommonMsgDatabase(rdb, mongo.GetDatabase(config.Mongo.Database)) msgDatabase, err := controller.InitCommonMsgDatabase(rdb, mongo.GetDatabase(config.Mongo.Database), config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -185,8 +185,8 @@ func (c *MsgTool) AllConversationClearMsgAndFixSeq() {
func (c *MsgTool) ClearConversationsMsg(ctx context.Context, conversationIDs []string) { func (c *MsgTool) ClearConversationsMsg(ctx context.Context, conversationIDs []string) {
for _, conversationID := range conversationIDs { for _, conversationID := range conversationIDs {
if err := c.msgDatabase.DeleteConversationMsgsAndSetMinSeq(ctx, conversationID, int64(c.config.RetainChatRecords*24*60*60)); err != nil { if err := c.msgDatabase.DeleteConversationMsgsAndSetMinSeq(ctx, conversationID, int64(c.Config.RetainChatRecords*24*60*60)); err != nil {
log.ZError(ctx, "DeleteUserSuperGroupMsgsAndSetMinSeq failed", err, "conversationID", conversationID, "DBRetainChatRecords", c.config.RetainChatRecords) log.ZError(ctx, "DeleteUserSuperGroupMsgsAndSetMinSeq failed", err, "conversationID", conversationID, "DBRetainChatRecords", c.Config.RetainChatRecords)
} }
if err := c.checkMaxSeq(ctx, conversationID); err != nil { if err := c.checkMaxSeq(ctx, conversationID); err != nil {
log.ZError(ctx, "fixSeq failed", err, "conversationID", conversationID) log.ZError(ctx, "fixSeq failed", err, "conversationID", conversationID)

@ -75,7 +75,7 @@ func ParseRedisInterfaceToken(redisToken any, secret string) (*tokenverify.Claim
return tokenverify.GetClaimFromToken(string(redisToken.([]uint8)), Secret(secret)) return tokenverify.GetClaimFromToken(string(redisToken.([]uint8)), Secret(secret))
} }
func IsManagerUserID(opUserID string, config config.GlobalConfig) bool { func IsManagerUserID(opUserID string, config *config.GlobalConfig) bool {
return (len(config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Manager.UserID)) || utils.IsContain(opUserID, config.IMAdmin.UserID) return (len(config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Manager.UserID)) || utils.IsContain(opUserID, config.IMAdmin.UserID)
} }

@ -16,6 +16,7 @@ package controller
import ( import (
"context" "context"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
@ -36,14 +37,14 @@ type AuthDatabase interface {
} }
type authDatabase struct { type authDatabase struct {
cache cache.MsgModel cache cache.MsgModel
accessSecret string accessSecret string
accessExpire int64 accessExpire int64
config *config.GlobalConfig
} }
func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64) AuthDatabase { func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64, config *config.GlobalConfig) AuthDatabase {
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire} return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, config: config}
} }
// 结果为空 不返回错误. // 结果为空 不返回错误.
@ -63,7 +64,7 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
} }
var deleteTokenKey []string var deleteTokenKey []string
for k, v := range tokens { for k, v := range tokens {
_, err = tokenverify.GetClaimFromToken(k, authverify.Secret()) _, err = tokenverify.GetClaimFromToken(k, authverify.Secret(a.config.Secret))
if err != nil || v != constant.NormalToken { if err != nil || v != constant.NormalToken {
deleteTokenKey = append(deleteTokenKey, k) deleteTokenKey = append(deleteTokenKey, k)
} }

@ -127,15 +127,28 @@ type CommonMsgDatabase interface {
} }
func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel, config *config.GlobalConfig) (CommonMsgDatabase, error) { func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel, config *config.GlobalConfig) (CommonMsgDatabase, error) {
producerToRedis, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.LatestMsgToRedis.Topic) producerConfig := &kafka.ProducerConfig{
ProducerAck: config.Kafka.ProducerAck,
CompressType: config.Kafka.CompressType,
Username: config.Kafka.Username,
Password: config.Kafka.Password,
}
tlsConfig := &kafka.TLSConfig{
CACrt: config.Kafka.TLS.CACrt,
ClientCrt: config.Kafka.TLS.ClientCrt,
ClientKey: config.Kafka.TLS.ClientKey,
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
InsecureSkipVerify: config.Kafka.TLS.InsecureSkipVerify,
}
producerToRedis, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.LatestMsgToRedis.Topic, producerConfig, tlsConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
producerToMongo, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToMongo.Topic) producerToMongo, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToMongo.Topic, producerConfig, tlsConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
producerToPush, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToPush.Topic) producerToPush, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToPush.Topic, producerConfig, tlsConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -33,27 +33,28 @@ import (
) )
func Test_BatchInsertChat2DB(t *testing.T) { func Test_BatchInsertChat2DB(t *testing.T) {
config.Config.Mongo.Address = []string{"192.168.44.128:37017"} conf := config.NewGlobalConfig()
// config.Config.Mongo.Timeout = 60 conf.Mongo.Address = []string{"192.168.44.128:37017"}
config.Config.Mongo.Database = "openIM" // conf.Mongo.Timeout = 60
// config.Config.Mongo.Source = "admin" conf.Mongo.Database = "openIM"
config.Config.Mongo.Username = "root" // conf.Mongo.Source = "admin"
config.Config.Mongo.Password = "openIM123" conf.Mongo.Username = "root"
config.Config.Mongo.MaxPoolSize = 100 conf.Mongo.Password = "openIM123"
config.Config.RetainChatRecords = 3650 conf.Mongo.MaxPoolSize = 100
config.Config.ChatRecordsClearTime = "0 2 * * 3" conf.RetainChatRecords = 3650
conf.ChatRecordsClearTime = "0 2 * * 3"
mongo, err := unrelation.NewMongo()
mongo, err := unrelation.NewMongo(conf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = mongo.GetDatabase().Client().Ping(context.Background(), nil) err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
db := &commonMsgDatabase{ db := &commonMsgDatabase{
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()), msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)),
} }
//ctx := context.Background() //ctx := context.Background()
@ -70,7 +71,7 @@ func Test_BatchInsertChat2DB(t *testing.T) {
//} //}
_ = db.BatchInsertChat2DB _ = db.BatchInsertChat2DB
c := mongo.GetDatabase().Collection("msg") c := mongo.GetDatabase(conf.Mongo.Database).Collection("msg")
ch := make(chan int) ch := make(chan int)
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
@ -144,26 +145,27 @@ func Test_BatchInsertChat2DB(t *testing.T) {
} }
func GetDB() *commonMsgDatabase { func GetDB() *commonMsgDatabase {
config.Config.Mongo.Address = []string{"203.56.175.233:37017"} conf := config.NewGlobalConfig()
// config.Config.Mongo.Timeout = 60 conf.Mongo.Address = []string{"203.56.175.233:37017"}
config.Config.Mongo.Database = "openim_v3" // conf.Mongo.Timeout = 60
// config.Config.Mongo.Source = "admin" conf.Mongo.Database = "openim_v3"
config.Config.Mongo.Username = "root" // conf.Mongo.Source = "admin"
config.Config.Mongo.Password = "openIM123" conf.Mongo.Username = "root"
config.Config.Mongo.MaxPoolSize = 100 conf.Mongo.Password = "openIM123"
config.Config.RetainChatRecords = 3650 conf.Mongo.MaxPoolSize = 100
config.Config.ChatRecordsClearTime = "0 2 * * 3" conf.RetainChatRecords = 3650
conf.ChatRecordsClearTime = "0 2 * * 3"
mongo, err := unrelation.NewMongo()
mongo, err := unrelation.NewMongo(conf)
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = mongo.GetDatabase().Client().Ping(context.Background(), nil) err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return &commonMsgDatabase{ return &commonMsgDatabase{
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()), msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)),
} }
} }

@ -59,7 +59,7 @@ const (
const successCode = http.StatusOK const successCode = http.StatusOK
func NewMinio(cache cache.MinioCache, config config.GlobalConfig) (s3.Interface, error) { func NewMinio(cache cache.MinioCache, config *config.GlobalConfig) (s3.Interface, error) {
u, err := url.Parse(config.Object.Minio.Endpoint) u, err := url.Parse(config.Object.Minio.Endpoint)
if err != nil { if err != nil {
return nil, err return nil, err
@ -124,7 +124,7 @@ type Minio struct {
init bool init bool
prefix string prefix string
cache cache.MinioCache cache cache.MinioCache
config config.GlobalConfig config *config.GlobalConfig
} }
func (m *Minio) initMinio(ctx context.Context) error { func (m *Minio) initMinio(ctx context.Context) error {

@ -57,8 +57,8 @@ const (
videoSnapshotImageJpg = "jpg" videoSnapshotImageJpg = "jpg"
) )
func NewOSS() (s3.Interface, error) { func NewOSS(config *config.GlobalConfig) (s3.Interface, error) {
conf := config.Config.Object.Oss conf := config.Object.Oss
if conf.BucketURL == "" { if conf.BucketURL == "" {
return nil, errors.New("bucket url is empty") return nil, errors.New("bucket url is empty")
} }
@ -78,6 +78,7 @@ func NewOSS() (s3.Interface, error) {
bucket: bucket, bucket: bucket,
credentials: client.Config.GetCredentials(), credentials: client.Config.GetCredentials(),
um: *(*urlMaker)(reflect.ValueOf(bucket.Client.Conn).Elem().FieldByName("url").UnsafePointer()), um: *(*urlMaker)(reflect.ValueOf(bucket.Client.Conn).Elem().FieldByName("url").UnsafePointer()),
PublicRead: conf.PublicRead,
}, nil }, nil
} }
@ -86,6 +87,7 @@ type OSS struct {
bucket *oss.Bucket bucket *oss.Bucket
credentials oss.Credentials credentials oss.Credentials
um urlMaker um urlMaker
PublicRead bool
} }
func (o *OSS) Engine() string { func (o *OSS) Engine() string {
@ -282,7 +284,7 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin
} }
func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) { func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) {
publicRead := config.Config.Object.Oss.PublicRead publicRead := o.PublicRead
var opts []oss.Option var opts []oss.Option
if opt != nil { if opt != nil {
if opt.Image != nil { if opt.Image != nil {

@ -16,6 +16,7 @@ package discoveryregister
import ( import (
"errors" "errors"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"os" "os"
"github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister/direct" "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister/direct"
@ -27,17 +28,17 @@ import (
) )
// NewDiscoveryRegister creates a new service discovery and registry client based on the provided environment type. // NewDiscoveryRegister creates a new service discovery and registry client based on the provided environment type.
func NewDiscoveryRegister(envType string) (discoveryregistry.SvcDiscoveryRegistry, error) { func NewDiscoveryRegister(config *config.GlobalConfig) (discoveryregistry.SvcDiscoveryRegistry, error) {
if os.Getenv("ENVS_DISCOVERY") != "" { if os.Getenv("ENVS_DISCOVERY") != "" {
envType = os.Getenv("ENVS_DISCOVERY") config.Envs.Discovery = os.Getenv("ENVS_DISCOVERY")
} }
switch envType { switch config.Envs.Discovery {
case "zookeeper": case "zookeeper":
return zookeeper.NewZookeeperDiscoveryRegister() return zookeeper.NewZookeeperDiscoveryRegister(config)
case "k8s": case "k8s":
return kubernetes.NewK8sDiscoveryRegister() return kubernetes.NewK8sDiscoveryRegister(config.RpcRegisterName.OpenImMessageGatewayName)
case "direct": case "direct":
return direct.NewConnDirect() return direct.NewConnDirect()
default: default:

@ -15,6 +15,7 @@
package discoveryregister package discoveryregister
import ( import (
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"os" "os"
"testing" "testing"
@ -32,20 +33,23 @@ func setupTestEnvironment() {
func TestNewDiscoveryRegister(t *testing.T) { func TestNewDiscoveryRegister(t *testing.T) {
setupTestEnvironment() setupTestEnvironment()
conf := config.NewGlobalConfig()
tests := []struct { tests := []struct {
envType string envType string
gatewayName string
expectedError bool expectedError bool
expectedResult bool expectedResult bool
}{ }{
{"zookeeper", false, true}, {"zookeeper", "MessageGateway", false, true},
{"k8s", false, true}, // 假设 k8s 配置也已正确设置 {"k8s", "MessageGateway", false, true}, // 假设 k8s 配置也已正确设置
{"direct", false, true}, {"direct", "MessageGateway", false, true},
{"invalid", true, false}, {"invalid", "MessageGateway", true, false},
} }
for _, test := range tests { for _, test := range tests {
client, err := NewDiscoveryRegister(test.envType) conf.Envs.Discovery = test.envType
conf.RpcRegisterName.OpenImMessageGatewayName = test.gatewayName
client, err := NewDiscoveryRegister(conf)
if test.expectedError { if test.expectedError {
assert.Error(t, err) assert.Error(t, err)

@ -28,8 +28,6 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/log"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
) )
// K8sDR represents the Kubernetes service discovery and registration client. // K8sDR represents the Kubernetes service discovery and registration client.
@ -37,11 +35,12 @@ type K8sDR struct {
options []grpc.DialOption options []grpc.DialOption
rpcRegisterAddr string rpcRegisterAddr string
gatewayHostConsistent *consistent.Consistent gatewayHostConsistent *consistent.Consistent
gatewayName string
} }
func NewK8sDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) { func NewK8sDiscoveryRegister(gatewayName string) (discoveryregistry.SvcDiscoveryRegistry, error) {
gatewayConsistent := consistent.New() gatewayConsistent := consistent.New()
gatewayHosts := getMsgGatewayHost(context.Background()) gatewayHosts := getMsgGatewayHost(context.Background(), gatewayName)
for _, v := range gatewayHosts { for _, v := range gatewayHosts {
gatewayConsistent.Add(v) gatewayConsistent.Add(v)
} }
@ -49,10 +48,10 @@ func NewK8sDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) {
} }
func (cli *K8sDR) Register(serviceName, host string, port int, opts ...grpc.DialOption) error { func (cli *K8sDR) Register(serviceName, host string, port int, opts ...grpc.DialOption) error {
if serviceName != config.Config.RpcRegisterName.OpenImMessageGatewayName { if serviceName != cli.gatewayName {
cli.rpcRegisterAddr = serviceName cli.rpcRegisterAddr = serviceName
} else { } else {
cli.rpcRegisterAddr = getSelfHost(context.Background()) cli.rpcRegisterAddr = getSelfHost(context.Background(), cli.gatewayName)
} }
return nil return nil
@ -84,15 +83,15 @@ func (cli *K8sDR) GetUserIdHashGatewayHost(ctx context.Context, userId string) (
} }
return host, err return host, err
} }
func getSelfHost(ctx context.Context) string { func getSelfHost(ctx context.Context, gatewayName string) string {
port := 88 port := 88
instance := "openimserver" instance := "openimserver"
selfPodName := os.Getenv("MY_POD_NAME") selfPodName := os.Getenv("MY_POD_NAME")
ns := os.Getenv("MY_POD_NAMESPACE") ns := os.Getenv("MY_POD_NAMESPACE")
statefuleIndex := 0 statefuleIndex := 0
gatewayEnds := strings.Split(config.Config.RpcRegisterName.OpenImMessageGatewayName, ":") gatewayEnds := strings.Split(gatewayName, ":")
if len(gatewayEnds) != 2 { if len(gatewayEnds) != 2 {
log.ZError(ctx, "msggateway RpcRegisterName is error:config.Config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error")) log.ZError(ctx, "msggateway RpcRegisterName is error:config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error"))
} else { } else {
port, _ = strconv.Atoi(gatewayEnds[1]) port, _ = strconv.Atoi(gatewayEnds[1])
} }
@ -105,15 +104,15 @@ func getSelfHost(ctx context.Context) string {
} }
// like openimserver-openim-msggateway-0.openimserver-openim-msggateway-headless.openim-lin.svc.cluster.local:88. // like openimserver-openim-msggateway-0.openimserver-openim-msggateway-headless.openim-lin.svc.cluster.local:88.
func getMsgGatewayHost(ctx context.Context) []string { func getMsgGatewayHost(ctx context.Context, gatewayName string) []string {
port := 88 port := 88
instance := "openimserver" instance := "openimserver"
selfPodName := os.Getenv("MY_POD_NAME") selfPodName := os.Getenv("MY_POD_NAME")
replicas := os.Getenv("MY_MSGGATEWAY_REPLICACOUNT") replicas := os.Getenv("MY_MSGGATEWAY_REPLICACOUNT")
ns := os.Getenv("MY_POD_NAMESPACE") ns := os.Getenv("MY_POD_NAMESPACE")
gatewayEnds := strings.Split(config.Config.RpcRegisterName.OpenImMessageGatewayName, ":") gatewayEnds := strings.Split(gatewayName, ":")
if len(gatewayEnds) != 2 { if len(gatewayEnds) != 2 {
log.ZError(ctx, "msggateway RpcRegisterName is error:config.Config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error")) log.ZError(ctx, "msggateway RpcRegisterName is error:config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error"))
} else { } else {
port, _ = strconv.Atoi(gatewayEnds[1]) port, _ = strconv.Atoi(gatewayEnds[1])
} }
@ -134,7 +133,7 @@ func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc
// This conditional checks if the serviceName is not the OpenImMessageGatewayName. // This conditional checks if the serviceName is not the OpenImMessageGatewayName.
// It seems to handle a special case for the OpenImMessageGateway. // It seems to handle a special case for the OpenImMessageGateway.
if serviceName != config.Config.RpcRegisterName.OpenImMessageGatewayName { if serviceName != cli.gatewayName {
// DialContext creates a client connection to the given target (serviceName) using the specified context. // DialContext creates a client connection to the given target (serviceName) using the specified context.
// 'cli.options' are likely default or common options for all connections in this struct. // 'cli.options' are likely default or common options for all connections in this struct.
// 'opts...' allows for additional gRPC dial options to be passed and used. // 'opts...' allows for additional gRPC dial options to be passed and used.
@ -149,7 +148,7 @@ func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc
// getMsgGatewayHost presumably retrieves hosts for the message gateway service. // getMsgGatewayHost presumably retrieves hosts for the message gateway service.
// The context is passed, likely for cancellation and timeout control. // The context is passed, likely for cancellation and timeout control.
gatewayHosts := getMsgGatewayHost(ctx) gatewayHosts := getMsgGatewayHost(ctx, cli.gatewayName)
// Iterating over the retrieved gateway hosts. // Iterating over the retrieved gateway hosts.
for _, host := range gatewayHosts { for _, host := range gatewayHosts {

@ -30,11 +30,11 @@ import (
) )
// NewZookeeperDiscoveryRegister creates a new instance of ZookeeperDR for Zookeeper service discovery and registration. // NewZookeeperDiscoveryRegister creates a new instance of ZookeeperDR for Zookeeper service discovery and registration.
func NewZookeeperDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) { func NewZookeeperDiscoveryRegister(config *config.GlobalConfig) (discoveryregistry.SvcDiscoveryRegistry, error) {
schema := getEnv("ZOOKEEPER_SCHEMA", config.Config.Zookeeper.Schema) schema := getEnv("ZOOKEEPER_SCHEMA", config.Zookeeper.Schema)
zkAddr := getZkAddrFromEnv(config.Config.Zookeeper.ZkAddr) zkAddr := getZkAddrFromEnv(config.Zookeeper.ZkAddr)
username := getEnv("ZOOKEEPER_USERNAME", config.Config.Zookeeper.Username) username := getEnv("ZOOKEEPER_USERNAME", config.Zookeeper.Username)
password := getEnv("ZOOKEEPER_PASSWORD", config.Config.Zookeeper.Password) password := getEnv("ZOOKEEPER_PASSWORD", config.Zookeeper.Password)
zk, err := openkeeper.NewClient( zk, err := openkeeper.NewClient(
zkAddr, zkAddr,
@ -48,10 +48,10 @@ func NewZookeeperDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, er
if err != nil { if err != nil {
uriFormat := "address:%s, username:%s, password:%s, schema:%s." uriFormat := "address:%s, username:%s, password:%s, schema:%s."
errInfo := fmt.Sprintf(uriFormat, errInfo := fmt.Sprintf(uriFormat,
config.Config.Zookeeper.ZkAddr, config.Zookeeper.ZkAddr,
config.Config.Zookeeper.Username, config.Zookeeper.Username,
config.Config.Zookeeper.Password, config.Zookeeper.Password,
config.Config.Zookeeper.Schema) config.Zookeeper.Schema)
return nil, errs.Wrap(err, errInfo) return nil, errs.Wrap(err, errInfo)
} }
return zk, nil return zk, nil

@ -30,17 +30,24 @@ type Consumer struct {
Consumer sarama.Consumer Consumer sarama.Consumer
} }
func NewKafkaConsumer(addr []string, topic string) *Consumer { func NewKafkaConsumer(addr []string, topic string, config *config.GlobalConfig) *Consumer {
p := Consumer{} p := Consumer{}
p.Topic = topic p.Topic = topic
p.addr = addr p.addr = addr
consumerConfig := sarama.NewConfig() consumerConfig := sarama.NewConfig()
if config.Config.Kafka.Username != "" && config.Config.Kafka.Password != "" { if config.Kafka.Username != "" && config.Kafka.Password != "" {
consumerConfig.Net.SASL.Enable = true consumerConfig.Net.SASL.Enable = true
consumerConfig.Net.SASL.User = config.Config.Kafka.Username consumerConfig.Net.SASL.User = config.Kafka.Username
consumerConfig.Net.SASL.Password = config.Config.Kafka.Password consumerConfig.Net.SASL.Password = config.Kafka.Password
} }
SetupTLSConfig(consumerConfig) tlsConfig := &TLSConfig{
CACrt: config.Kafka.TLS.CACrt,
ClientCrt: config.Kafka.TLS.ClientCrt,
ClientKey: config.Kafka.TLS.ClientKey,
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
InsecureSkipVerify: config.Kafka.TLS.InsecureSkipVerify,
}
SetupTLSConfig(consumerConfig, tlsConfig)
consumer, err := sarama.NewConsumer(p.addr, consumerConfig) consumer, err := sarama.NewConsumer(p.addr, consumerConfig)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())

@ -24,8 +24,6 @@ import (
"github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/log"
"strings" "strings"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
) )
type MConsumerGroup struct { type MConsumerGroup struct {
@ -38,22 +36,25 @@ type MConsumerGroupConfig struct {
KafkaVersion sarama.KafkaVersion KafkaVersion sarama.KafkaVersion
OffsetsInitial int64 OffsetsInitial int64
IsReturnErr bool IsReturnErr bool
UserName string
Password string
} }
func NewMConsumerGroup(consumerConfig *MConsumerGroupConfig, topics, addrs []string, groupID string) (*MConsumerGroup, error) { func NewMConsumerGroup(consumerConfig *MConsumerGroupConfig, topics, addrs []string, groupID string, tlsConfig *TLSConfig) (*MConsumerGroup, error) {
consumerGroupConfig := sarama.NewConfig() consumerGroupConfig := sarama.NewConfig()
consumerGroupConfig.Version = consumerConfig.KafkaVersion consumerGroupConfig.Version = consumerConfig.KafkaVersion
consumerGroupConfig.Consumer.Offsets.Initial = consumerConfig.OffsetsInitial consumerGroupConfig.Consumer.Offsets.Initial = consumerConfig.OffsetsInitial
consumerGroupConfig.Consumer.Return.Errors = consumerConfig.IsReturnErr consumerGroupConfig.Consumer.Return.Errors = consumerConfig.IsReturnErr
if config.Config.Kafka.Username != "" && config.Config.Kafka.Password != "" { if consumerConfig.UserName != "" && consumerConfig.Password != "" {
consumerGroupConfig.Net.SASL.Enable = true consumerGroupConfig.Net.SASL.Enable = true
consumerGroupConfig.Net.SASL.User = config.Config.Kafka.Username consumerGroupConfig.Net.SASL.User = consumerConfig.UserName
consumerGroupConfig.Net.SASL.Password = config.Config.Kafka.Password consumerGroupConfig.Net.SASL.Password = consumerConfig.Password
} }
SetupTLSConfig(consumerGroupConfig)
SetupTLSConfig(consumerGroupConfig, tlsConfig)
consumerGroup, err := sarama.NewConsumerGroup(addrs, groupID, consumerGroupConfig) consumerGroup, err := sarama.NewConsumerGroup(addrs, groupID, consumerGroupConfig)
if err != nil { if err != nil {
return nil, errs.Wrap(err, strings.Join(topics, ","), strings.Join(addrs, ","), groupID, config.Config.Kafka.Username, config.Config.Kafka.Password) return nil, errs.Wrap(err, strings.Join(topics, ","), strings.Join(addrs, ","), groupID, consumerConfig.UserName, consumerConfig.Password)
} }
return &MConsumerGroup{ return &MConsumerGroup{

@ -29,8 +29,6 @@ import (
"github.com/OpenIMSDK/tools/mcontext" "github.com/OpenIMSDK/tools/mcontext"
"github.com/OpenIMSDK/tools/utils" "github.com/OpenIMSDK/tools/utils"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
) )
const maxRetry = 10 // number of retries const maxRetry = 10 // number of retries
@ -45,8 +43,15 @@ type Producer struct {
producer sarama.SyncProducer producer sarama.SyncProducer
} }
type ProducerConfig struct {
ProducerAck string
CompressType string
Username string
Password string
}
// NewKafkaProducer initializes a new Kafka producer. // NewKafkaProducer initializes a new Kafka producer.
func NewKafkaProducer(addr []string, topic string) (*Producer, error) { func NewKafkaProducer(addr []string, topic string, producerConfig *ProducerConfig, tlsConfig *TLSConfig) (*Producer, error) {
p := Producer{ p := Producer{
addr: addr, addr: addr,
topic: topic, topic: topic,
@ -61,14 +66,14 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) {
p.config.Producer.Partitioner = sarama.NewHashPartitioner p.config.Producer.Partitioner = sarama.NewHashPartitioner
// Configure producer acknowledgement level // Configure producer acknowledgement level
configureProducerAck(&p, config.Config.Kafka.ProducerAck) configureProducerAck(&p, producerConfig.ProducerAck)
// Configure message compression // Configure message compression
configureCompression(&p, config.Config.Kafka.CompressType) configureCompression(&p, producerConfig.CompressType)
// Get Kafka configuration from environment variables or fallback to config file // Get Kafka configuration from environment variables or fallback to config file
kafkaUsername := getEnvOrConfig("KAFKA_USERNAME", config.Config.Kafka.Username) kafkaUsername := getEnvOrConfig("KAFKA_USERNAME", producerConfig.Username)
kafkaPassword := getEnvOrConfig("KAFKA_PASSWORD", config.Config.Kafka.Password) kafkaPassword := getEnvOrConfig("KAFKA_PASSWORD", producerConfig.Password)
kafkaAddr := getKafkaAddrFromEnv(addr) // Updated to use the new function kafkaAddr := getKafkaAddrFromEnv(addr) // Updated to use the new function
// Configure SASL authentication if credentials are provided // Configure SASL authentication if credentials are provided
@ -82,7 +87,7 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) {
p.addr = kafkaAddr p.addr = kafkaAddr
// Set up TLS configuration (if required) // Set up TLS configuration (if required)
SetupTLSConfig(p.config) SetupTLSConfig(p.config, tlsConfig)
// Create the producer with retries // Create the producer with retries
var err error var err error

@ -21,19 +21,27 @@ import (
"github.com/IBM/sarama" "github.com/IBM/sarama"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/common/tls" "github.com/openimsdk/open-im-server/v3/pkg/common/tls"
) )
type TLSConfig struct {
CACrt string
ClientCrt string
ClientKey string
ClientKeyPwd string
InsecureSkipVerify bool
}
// SetupTLSConfig set up the TLS config from config file. // SetupTLSConfig set up the TLS config from config file.
func SetupTLSConfig(cfg *sarama.Config) { func SetupTLSConfig(cfg *sarama.Config, tlsConfig *TLSConfig) {
if config.Config.Kafka.TLS != nil { if tlsConfig != nil {
cfg.Net.TLS.Enable = true cfg.Net.TLS.Enable = true
cfg.Net.TLS.Config = tls.NewTLSConfig( cfg.Net.TLS.Config = tls.NewTLSConfig(
config.Config.Kafka.TLS.ClientCrt, tlsConfig.ClientCrt,
config.Config.Kafka.TLS.ClientKey, tlsConfig.ClientKey,
config.Config.Kafka.TLS.CACrt, tlsConfig.CACrt,
[]byte(config.Config.Kafka.TLS.ClientKeyPwd), []byte(tlsConfig.ClientKeyPwd),
tlsConfig.InsecureSkipVerify,
) )
} }
} }

@ -70,7 +70,7 @@ func Start(
} }
defer listener.Close() defer listener.Close()
client, err := kdisc.NewDiscoveryRegister(config.Envs.Discovery) client, err := kdisc.NewDiscoveryRegister(config)
if err != nil { if err != nil {
return errs.Wrap(err) return errs.Wrap(err)
} }

@ -20,8 +20,6 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"os" "os"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
) )
// decryptPEM decrypts a PEM block using a password. // decryptPEM decrypts a PEM block using a password.
@ -49,7 +47,7 @@ func readEncryptablePEMBlock(path string, pwd []byte) ([]byte, error) {
} }
// NewTLSConfig setup the TLS config from general config file. // NewTLSConfig setup the TLS config from general config file.
func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte) *tls.Config { func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte, insecureSkipVerify bool) *tls.Config {
tlsConfig := tls.Config{} tlsConfig := tls.Config{}
if clientCertFile != "" && clientKeyFile != "" { if clientCertFile != "" && clientKeyFile != "" {
@ -79,7 +77,7 @@ func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byt
} }
tlsConfig.RootCAs = caCertPool tlsConfig.RootCAs = caCertPool
tlsConfig.InsecureSkipVerify = config.Config.Kafka.TLS.InsecureSkipVerify tlsConfig.InsecureSkipVerify = insecureSkipVerify
return &tlsConfig return &tlsConfig
} }

@ -34,47 +34,47 @@ import (
// "google.golang.org/protobuf/proto". // "google.golang.org/protobuf/proto".
) )
func newContentTypeConf() map[int32]config.NotificationConf { func newContentTypeConf(conf *config.GlobalConfig) map[int32]config.NotificationConf {
return map[int32]config.NotificationConf{ return map[int32]config.NotificationConf{
// group // group
constant.GroupCreatedNotification: config.Config.Notification.GroupCreated, constant.GroupCreatedNotification: conf.Notification.GroupCreated,
constant.GroupInfoSetNotification: config.Config.Notification.GroupInfoSet, constant.GroupInfoSetNotification: conf.Notification.GroupInfoSet,
constant.JoinGroupApplicationNotification: config.Config.Notification.JoinGroupApplication, constant.JoinGroupApplicationNotification: conf.Notification.JoinGroupApplication,
constant.MemberQuitNotification: config.Config.Notification.MemberQuit, constant.MemberQuitNotification: conf.Notification.MemberQuit,
constant.GroupApplicationAcceptedNotification: config.Config.Notification.GroupApplicationAccepted, constant.GroupApplicationAcceptedNotification: conf.Notification.GroupApplicationAccepted,
constant.GroupApplicationRejectedNotification: config.Config.Notification.GroupApplicationRejected, constant.GroupApplicationRejectedNotification: conf.Notification.GroupApplicationRejected,
constant.GroupOwnerTransferredNotification: config.Config.Notification.GroupOwnerTransferred, constant.GroupOwnerTransferredNotification: conf.Notification.GroupOwnerTransferred,
constant.MemberKickedNotification: config.Config.Notification.MemberKicked, constant.MemberKickedNotification: conf.Notification.MemberKicked,
constant.MemberInvitedNotification: config.Config.Notification.MemberInvited, constant.MemberInvitedNotification: conf.Notification.MemberInvited,
constant.MemberEnterNotification: config.Config.Notification.MemberEnter, constant.MemberEnterNotification: conf.Notification.MemberEnter,
constant.GroupDismissedNotification: config.Config.Notification.GroupDismissed, constant.GroupDismissedNotification: conf.Notification.GroupDismissed,
constant.GroupMutedNotification: config.Config.Notification.GroupMuted, constant.GroupMutedNotification: conf.Notification.GroupMuted,
constant.GroupCancelMutedNotification: config.Config.Notification.GroupCancelMuted, constant.GroupCancelMutedNotification: conf.Notification.GroupCancelMuted,
constant.GroupMemberMutedNotification: config.Config.Notification.GroupMemberMuted, constant.GroupMemberMutedNotification: conf.Notification.GroupMemberMuted,
constant.GroupMemberCancelMutedNotification: config.Config.Notification.GroupMemberCancelMuted, constant.GroupMemberCancelMutedNotification: conf.Notification.GroupMemberCancelMuted,
constant.GroupMemberInfoSetNotification: config.Config.Notification.GroupMemberInfoSet, constant.GroupMemberInfoSetNotification: conf.Notification.GroupMemberInfoSet,
constant.GroupMemberSetToAdminNotification: config.Config.Notification.GroupMemberSetToAdmin, constant.GroupMemberSetToAdminNotification: conf.Notification.GroupMemberSetToAdmin,
constant.GroupMemberSetToOrdinaryUserNotification: config.Config.Notification.GroupMemberSetToOrdinary, constant.GroupMemberSetToOrdinaryUserNotification: conf.Notification.GroupMemberSetToOrdinary,
constant.GroupInfoSetAnnouncementNotification: config.Config.Notification.GroupInfoSetAnnouncement, constant.GroupInfoSetAnnouncementNotification: conf.Notification.GroupInfoSetAnnouncement,
constant.GroupInfoSetNameNotification: config.Config.Notification.GroupInfoSetName, constant.GroupInfoSetNameNotification: conf.Notification.GroupInfoSetName,
// user // user
constant.UserInfoUpdatedNotification: config.Config.Notification.UserInfoUpdated, constant.UserInfoUpdatedNotification: conf.Notification.UserInfoUpdated,
constant.UserStatusChangeNotification: config.Config.Notification.UserStatusChanged, constant.UserStatusChangeNotification: conf.Notification.UserStatusChanged,
// friend // friend
constant.FriendApplicationNotification: config.Config.Notification.FriendApplicationAdded, constant.FriendApplicationNotification: conf.Notification.FriendApplicationAdded,
constant.FriendApplicationApprovedNotification: config.Config.Notification.FriendApplicationApproved, constant.FriendApplicationApprovedNotification: conf.Notification.FriendApplicationApproved,
constant.FriendApplicationRejectedNotification: config.Config.Notification.FriendApplicationRejected, constant.FriendApplicationRejectedNotification: conf.Notification.FriendApplicationRejected,
constant.FriendAddedNotification: config.Config.Notification.FriendAdded, constant.FriendAddedNotification: conf.Notification.FriendAdded,
constant.FriendDeletedNotification: config.Config.Notification.FriendDeleted, constant.FriendDeletedNotification: conf.Notification.FriendDeleted,
constant.FriendRemarkSetNotification: config.Config.Notification.FriendRemarkSet, constant.FriendRemarkSetNotification: conf.Notification.FriendRemarkSet,
constant.BlackAddedNotification: config.Config.Notification.BlackAdded, constant.BlackAddedNotification: conf.Notification.BlackAdded,
constant.BlackDeletedNotification: config.Config.Notification.BlackDeleted, constant.BlackDeletedNotification: conf.Notification.BlackDeleted,
constant.FriendInfoUpdatedNotification: config.Config.Notification.FriendInfoUpdated, constant.FriendInfoUpdatedNotification: conf.Notification.FriendInfoUpdated,
constant.FriendsInfoUpdateNotification: config.Config.Notification.FriendInfoUpdated, //use the same FriendInfoUpdated constant.FriendsInfoUpdateNotification: conf.Notification.FriendInfoUpdated, //use the same FriendInfoUpdated
// conversation // conversation
constant.ConversationChangeNotification: config.Config.Notification.ConversationChanged, constant.ConversationChangeNotification: conf.Notification.ConversationChanged,
constant.ConversationUnreadNotification: config.Config.Notification.ConversationChanged, constant.ConversationUnreadNotification: conf.Notification.ConversationChanged,
constant.ConversationPrivateChatNotification: config.Config.Notification.ConversationSetPrivate, constant.ConversationPrivateChatNotification: conf.Notification.ConversationSetPrivate,
// msg // msg
constant.MsgRevokeNotification: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg}, constant.MsgRevokeNotification: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg},
constant.HasReadReceipt: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg}, constant.HasReadReceipt: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg},
@ -224,8 +224,8 @@ func WithUserRpcClient(userRpcClient *UserRpcClient) NotificationSenderOptions {
} }
} }
func NewNotificationSender(opts ...NotificationSenderOptions) *NotificationSender { func NewNotificationSender(config *config.GlobalConfig, opts ...NotificationSenderOptions) *NotificationSender {
notificationSender := &NotificationSender{contentTypeConf: newContentTypeConf(), sessionTypeConf: newSessionTypeConf()} notificationSender := &NotificationSender{contentTypeConf: newContentTypeConf(config), sessionTypeConf: newSessionTypeConf()}
for _, opt := range opts { for _, opt := range opts {
opt(notificationSender) opt(notificationSender)
} }

@ -16,6 +16,7 @@ package notification
import ( import (
"context" "context"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/protocol/sdkws"
@ -27,8 +28,8 @@ type ConversationNotificationSender struct {
*rpcclient.NotificationSender *rpcclient.NotificationSender
} }
func NewConversationNotificationSender(msgRpcClient *rpcclient.MessageRpcClient) *ConversationNotificationSender { func NewConversationNotificationSender(config *config.GlobalConfig, msgRpcClient *rpcclient.MessageRpcClient) *ConversationNotificationSender {
return &ConversationNotificationSender{rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient))} return &ConversationNotificationSender{rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient))}
} }
// SetPrivate调用. // SetPrivate调用.

@ -16,6 +16,7 @@ package notification
import ( import (
"context" "context"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/tools/mcontext" "github.com/OpenIMSDK/tools/mcontext"
@ -82,11 +83,12 @@ func WithRpcFunc(
} }
func NewFriendNotificationSender( func NewFriendNotificationSender(
config *config.GlobalConfig,
msgRpcClient *rpcclient.MessageRpcClient, msgRpcClient *rpcclient.MessageRpcClient,
opts ...friendNotificationSenderOptions, opts ...friendNotificationSenderOptions,
) *FriendNotificationSender { ) *FriendNotificationSender {
f := &FriendNotificationSender{ f := &FriendNotificationSender{
NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient)), NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient)),
} }
for _, opt := range opts { for _, opt := range opts {
opt(f) opt(f)

@ -17,6 +17,7 @@ package notification
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
@ -37,12 +38,14 @@ func NewGroupNotificationSender(
db controller.GroupDatabase, db controller.GroupDatabase,
msgRpcClient *rpcclient.MessageRpcClient, msgRpcClient *rpcclient.MessageRpcClient,
userRpcClient *rpcclient.UserRpcClient, userRpcClient *rpcclient.UserRpcClient,
config *config.GlobalConfig,
fn func(ctx context.Context, userIDs []string) ([]CommonUser, error), fn func(ctx context.Context, userIDs []string) ([]CommonUser, error),
) *GroupNotificationSender { ) *GroupNotificationSender {
return &GroupNotificationSender{ return &GroupNotificationSender{
NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient), rpcclient.WithUserRpcClient(userRpcClient)), NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient), rpcclient.WithUserRpcClient(userRpcClient)),
getUsersInfo: fn, getUsersInfo: fn,
db: db, db: db,
config: config,
} }
} }
@ -50,6 +53,7 @@ type GroupNotificationSender struct {
*rpcclient.NotificationSender *rpcclient.NotificationSender
getUsersInfo func(ctx context.Context, userIDs []string) ([]CommonUser, error) getUsersInfo func(ctx context.Context, userIDs []string) ([]CommonUser, error)
db controller.GroupDatabase db controller.GroupDatabase
config *config.GlobalConfig
} }
func (g *GroupNotificationSender) PopulateGroupMember(ctx context.Context, members ...*relation.GroupMemberModel) error { func (g *GroupNotificationSender) PopulateGroupMember(ctx context.Context, members ...*relation.GroupMemberModel) error {
@ -252,7 +256,7 @@ func (g *GroupNotificationSender) fillOpUser(ctx context.Context, opUser **sdkws
} }
userID := mcontext.GetOpUserID(ctx) userID := mcontext.GetOpUserID(ctx)
if groupID != "" { if groupID != "" {
if authverify.IsManagerUserID(userID) { if authverify.IsManagerUserID(userID, g.config) {
*opUser = &sdkws.GroupMemberFullInfo{ *opUser = &sdkws.GroupMemberFullInfo{
GroupID: groupID, GroupID: groupID,
UserID: userID, UserID: userID,

@ -16,6 +16,7 @@ package notification
import ( import (
"context" "context"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/protocol/sdkws"
@ -27,8 +28,8 @@ type MsgNotificationSender struct {
*rpcclient.NotificationSender *rpcclient.NotificationSender
} }
func NewMsgNotificationSender(opts ...rpcclient.NotificationSenderOptions) *MsgNotificationSender { func NewMsgNotificationSender(config *config.GlobalConfig, opts ...rpcclient.NotificationSenderOptions) *MsgNotificationSender {
return &MsgNotificationSender{rpcclient.NewNotificationSender(opts...)} return &MsgNotificationSender{rpcclient.NewNotificationSender(config, opts...)}
} }
func (m *MsgNotificationSender) UserDeleteMsgsNotification(ctx context.Context, userID, conversationID string, seqs []int64) error { func (m *MsgNotificationSender) UserDeleteMsgsNotification(ctx context.Context, userID, conversationID string, seqs []int64) error {

@ -16,6 +16,7 @@ package notification
import ( import (
"context" "context"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/protocol/sdkws"
@ -59,11 +60,12 @@ func WithUserFunc(
} }
func NewUserNotificationSender( func NewUserNotificationSender(
config *config.GlobalConfig,
msgRpcClient *rpcclient.MessageRpcClient, msgRpcClient *rpcclient.MessageRpcClient,
opts ...userNotificationSenderOptions, opts ...userNotificationSenderOptions,
) *UserNotificationSender { ) *UserNotificationSender {
f := &UserNotificationSender{ f := &UserNotificationSender{
NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient)), NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient)),
} }
for _, opt := range opts { for _, opt := range opts {
opt(f) opt(f)

@ -31,8 +31,8 @@ type Push struct {
discov discoveryregistry.SvcDiscoveryRegistry discov discoveryregistry.SvcDiscoveryRegistry
} }
func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push { func NewPush(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Push {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImPushName) conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImPushName)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -45,8 +45,8 @@ func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push {
type PushRpcClient Push type PushRpcClient Push
func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) PushRpcClient { func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) PushRpcClient {
return PushRpcClient(*NewPush(discov)) return PushRpcClient(*NewPush(discov, config))
} }
func (p *PushRpcClient) DelUserPushToken( func (p *PushRpcClient) DelUserPushToken(

@ -42,21 +42,21 @@ func NewThird(discov discoveryregistry.SvcDiscoveryRegistry, config *config.Glob
panic(err) panic(err)
} }
client := third.NewThirdClient(conn) client := third.NewThirdClient(conn)
minioClient, err := minioInit() minioClient, err := minioInit(config)
return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient, Config: config} return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient, Config: config}
} }
func minioInit() (*minio.Client, error) { func minioInit(config *config.GlobalConfig) (*minio.Client, error) {
minioClient := &minio.Client{} minioClient := &minio.Client{}
var initUrl string var initUrl string
initUrl = config.Config.Object.Minio.Endpoint initUrl = config.Object.Minio.Endpoint
minioUrl, err := url.Parse(initUrl) minioUrl, err := url.Parse(initUrl)
if err != nil { if err != nil {
return nil, err return nil, err
} }
opts := &minio.Options{ opts := &minio.Options{
Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, ""), Creds: credentials.NewStaticV4(config.Object.Minio.AccessKeyID, config.Object.Minio.SecretAccessKey, ""),
// Region: config.Config.Credential.Minio.Location, // Region: config.Credential.Minio.Location,
} }
if minioUrl.Scheme == "http" { if minioUrl.Scheme == "http" {
opts.Secure = false opts.Secure = false

@ -163,7 +163,7 @@ func (u *UserRpcClient) Access(ctx context.Context, ownerUserID string) error {
if err != nil { if err != nil {
return err return err
} }
return authverify.CheckAccessV3(ctx, ownerUserID) return authverify.CheckAccessV3(ctx, ownerUserID, u.Config)
} }
// GetAllUserIDs retrieves all user IDs with pagination options. // GetAllUserIDs retrieves all user IDs with pagination options.

@ -47,27 +47,33 @@ var (
cfgPath = flag.String("c", defaultCfgPath, "Path to the configuration file") cfgPath = flag.String("c", defaultCfgPath, "Path to the configuration file")
) )
func initCfg() error { func initCfg(path string) (*config.GlobalConfig, error) {
data, err := os.ReadFile(*cfgPath) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return err return nil, errs.Wrap(err, "ReadFile unmarshal failed")
} }
return yaml.Unmarshal(data, &config.Config) conf := config.NewGlobalConfig()
err = yaml.Unmarshal(data, &conf)
if err != nil {
return nil, errs.Wrap(err, "InitConfig unmarshal failed")
}
return conf, nil
} }
type checkFunc struct { type checkFunc struct {
name string name string
function func() error function func(*config.GlobalConfig) error
flag bool flag bool
config *config.GlobalConfig
} }
func main() { func main() {
flag.Parse() flag.Parse()
if err := initCfg(); err != nil { conf, err := initCfg(defaultCfgPath)
if err != nil {
fmt.Printf("Read config failed: %v\n", err) fmt.Printf("Read config failed: %v\n", err)
return return
} }
@ -75,11 +81,11 @@ func main() {
checks := []checkFunc{ checks := []checkFunc{
//{name: "Mysql", function: checkMysql}, //{name: "Mysql", function: checkMysql},
{name: "Mongo", function: checkMongo}, {name: "Mongo", function: checkMongo, config: conf},
{name: "Redis", function: checkRedis}, {name: "Redis", function: checkRedis, config: conf},
{name: "Minio", function: checkMinio}, {name: "Minio", function: checkMinio, config: conf},
{name: "Zookeeper", function: checkZookeeper}, {name: "Zookeeper", function: checkZookeeper, config: conf},
{name: "Kafka", function: checkKafka}, {name: "Kafka", function: checkKafka, config: conf},
} }
for i := 0; i < maxRetry; i++ { for i := 0; i < maxRetry; i++ {
@ -92,7 +98,7 @@ func main() {
allSuccess := true allSuccess := true
for index, check := range checks { for index, check := range checks {
if !check.flag { if !check.flag {
err = check.function() err = check.function(check.config)
if err != nil { if err != nil {
component.ErrorPrint(fmt.Sprintf("Starting %s failed:%v.", check.name, err)) component.ErrorPrint(fmt.Sprintf("Starting %s failed:%v.", check.name, err))
allSuccess = false allSuccess = false
@ -112,30 +118,30 @@ func main() {
} }
// checkMongo checks the MongoDB connection without retries // checkMongo checks the MongoDB connection without retries
func checkMongo() error { func checkMongo(config *config.GlobalConfig) error {
_, err := unrelation.NewMongo() _, err := unrelation.NewMongo(config)
return err return err
} }
// checkRedis checks the Redis connection // checkRedis checks the Redis connection
func checkRedis() error { func checkRedis(config *config.GlobalConfig) error {
_, err := cache.NewRedis() _, err := cache.NewRedis(config)
return err return err
} }
// checkMinio checks the MinIO connection // checkMinio checks the MinIO connection
func checkMinio() error { func checkMinio(config *config.GlobalConfig) error {
// Check if MinIO is enabled // Check if MinIO is enabled
if config.Config.Object.Enable != "minio" { if config.Object.Enable != "minio" {
return errs.Wrap(errors.New("minio.Enable is empty")) return errs.Wrap(errors.New("minio.Enable is empty"))
} }
minio := &component.Minio{ minio := &component.Minio{
ApiURL: config.Config.Object.ApiURL, ApiURL: config.Object.ApiURL,
Endpoint: config.Config.Object.Minio.Endpoint, Endpoint: config.Object.Minio.Endpoint,
AccessKeyID: config.Config.Object.Minio.AccessKeyID, AccessKeyID: config.Object.Minio.AccessKeyID,
SecretAccessKey: config.Config.Object.Minio.SecretAccessKey, SecretAccessKey: config.Object.Minio.SecretAccessKey,
SignEndpoint: config.Config.Object.Minio.SignEndpoint, SignEndpoint: config.Object.Minio.SignEndpoint,
UseSSL: getEnv("MINIO_USE_SSL", "false"), UseSSL: getEnv("MINIO_USE_SSL", "false"),
} }
err := component.CheckMinio(minio) err := component.CheckMinio(minio)
@ -143,18 +149,18 @@ func checkMinio() error {
} }
// checkZookeeper checks the Zookeeper connection // checkZookeeper checks the Zookeeper connection
func checkZookeeper() error { func checkZookeeper(config *config.GlobalConfig) error {
_, err := zookeeper.NewZookeeperDiscoveryRegister() _, err := zookeeper.NewZookeeperDiscoveryRegister(config)
return err return err
} }
// checkKafka checks the Kafka connection // checkKafka checks the Kafka connection
func checkKafka() error { func checkKafka(config *config.GlobalConfig) error {
// Prioritize environment variables // Prioritize environment variables
kafkaStu := &component.Kafka{ kafkaStu := &component.Kafka{
Username: config.Config.Kafka.Username, Username: config.Kafka.Username,
Password: config.Config.Kafka.Password, Password: config.Kafka.Password,
Addr: config.Config.Kafka.Addr, Addr: config.Kafka.Addr,
} }
kafkaClient, err := component.CheckKafka(kafkaStu) kafkaClient, err := component.CheckKafka(kafkaStu)
@ -170,9 +176,9 @@ func checkKafka() error {
} }
requiredTopics := []string{ requiredTopics := []string{
config.Config.Kafka.MsgToMongo.Topic, config.Kafka.MsgToMongo.Topic,
config.Config.Kafka.MsgToPush.Topic, config.Kafka.MsgToPush.Topic,
config.Config.Kafka.LatestMsgToRedis.Topic, config.Kafka.LatestMsgToRedis.Topic,
} }
for _, requiredTopic := range requiredTopics { for _, requiredTopic := range requiredTopics {
@ -181,11 +187,22 @@ func checkKafka() error {
} }
} }
tlsConfig := &kafka.TLSConfig{
CACrt: config.Kafka.TLS.CACrt,
ClientCrt: config.Kafka.TLS.ClientCrt,
ClientKey: config.Kafka.TLS.ClientKey,
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
InsecureSkipVerify: config.Kafka.TLS.InsecureSkipVerify,
}
_, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ _, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{
KafkaVersion: sarama.V2_0_0_0, KafkaVersion: sarama.V2_0_0_0,
OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, OffsetsInitial: sarama.OffsetNewest,
}, []string{config.Config.Kafka.LatestMsgToRedis.Topic}, IsReturnErr: false,
config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToRedis) UserName: config.Kafka.Username,
Password: config.Kafka.Password,
}, []string{config.Kafka.LatestMsgToRedis.Topic},
config.Kafka.Addr, config.Kafka.ConsumerGroupID.MsgToRedis, tlsConfig)
if err != nil { if err != nil {
return err return err
} }
@ -193,8 +210,8 @@ func checkKafka() error {
_, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ _, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{
KafkaVersion: sarama.V2_0_0_0, KafkaVersion: sarama.V2_0_0_0,
OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false,
}, []string{config.Config.Kafka.MsgToPush.Topic}, }, []string{config.Kafka.MsgToPush.Topic},
config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToMongo) config.Kafka.Addr, config.Kafka.ConsumerGroupID.MsgToMongo, tlsConfig)
if err != nil { if err != nil {
return err return err
} }
@ -202,8 +219,8 @@ func checkKafka() error {
kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{
KafkaVersion: sarama.V2_0_0_0, KafkaVersion: sarama.V2_0_0_0,
OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false,
}, []string{config.Config.Kafka.MsgToPush.Topic}, config.Config.Kafka.Addr, }, []string{config.Kafka.MsgToPush.Topic}, config.Kafka.Addr,
config.Config.Kafka.ConsumerGroupID.MsgToPush) config.Kafka.ConsumerGroupID.MsgToPush, tlsConfig)
if err != nil { if err != nil {
return err return err
} }

@ -34,7 +34,8 @@ func mockInitCfg() error {
} }
func TestRedis(t *testing.T) { func TestRedis(t *testing.T) {
config.Config.Redis.Address = []string{ conf, err := initCfg(defaultCfgPath)
conf.Redis.Address = []string{
"172.16.8.142:7000", "172.16.8.142:7000",
//"172.16.8.142:7000", "172.16.8.142:7001", "172.16.8.142:7002", "172.16.8.142:7003", "172.16.8.142:7004", "172.16.8.142:7005", //"172.16.8.142:7000", "172.16.8.142:7001", "172.16.8.142:7002", "172.16.8.142:7003", "172.16.8.142:7004", "172.16.8.142:7005",
} }
@ -45,20 +46,20 @@ func TestRedis(t *testing.T) {
redisClient.Close() redisClient.Close()
} }
}() }()
if len(config.Config.Redis.Address) > 1 { if len(conf.Redis.Address) > 1 {
redisClient = redis.NewClusterClient(&redis.ClusterOptions{ redisClient = redis.NewClusterClient(&redis.ClusterOptions{
Addrs: config.Config.Redis.Address, Addrs: conf.Redis.Address,
Username: config.Config.Redis.Username, Username: conf.Redis.Username,
Password: config.Config.Redis.Password, Password: conf.Redis.Password,
}) })
} else { } else {
redisClient = redis.NewClient(&redis.Options{ redisClient = redis.NewClient(&redis.Options{
Addr: config.Config.Redis.Address[0], Addr: conf.Redis.Address[0],
Username: config.Config.Redis.Username, Username: conf.Redis.Username,
Password: config.Config.Redis.Password, Password: conf.Redis.Password,
}) })
} }
_, err := redisClient.Ping(context.Background()).Result() _, err = redisClient.Ping(context.Background()).Result()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -18,6 +18,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/OpenIMSDK/tools/errs"
"log" "log"
"os" "os"
"reflect" "reflect"
@ -45,36 +46,43 @@ const (
versionValue = 35 versionValue = 35
) )
func InitConfig(path string) error { func InitConfig(path string) (*config.GlobalConfig, error) {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return err return nil, errs.Wrap(err, "ReadFile unmarshal failed")
} }
return yaml.Unmarshal(data, &config.Config)
conf := config.NewGlobalConfig()
err = yaml.Unmarshal(data, &conf)
if err != nil {
return nil, errs.Wrap(err, "InitConfig unmarshal failed")
}
return conf, nil
} }
func GetMysql() (*gorm.DB, error) { func GetMysql(config *config.GlobalConfig) (*gorm.DB, error) {
conf := config.Config.Mysql conf := config.Mysql
mysqlDSN := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", conf.Username, conf.Password, conf.Address[0], conf.Database) mysqlDSN := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", conf.Username, conf.Password, conf.Address[0], conf.Database)
return gorm.Open(gormmysql.Open(mysqlDSN), &gorm.Config{Logger: logger.Discard}) return gorm.Open(gormmysql.Open(mysqlDSN), &gorm.Config{Logger: logger.Discard})
} }
func GetMongo() (*mongo.Database, error) { func GetMongo(config *config.GlobalConfig) (*mongo.Database, error) {
mgo, err := unrelation.NewMongo() mgo, err := unrelation.NewMongo(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return mgo.GetDatabase(), nil return mgo.GetDatabase(config.Mongo.Database), nil
} }
func Main(path string) error { func Main(path string) error {
if err := InitConfig(path); err != nil { conf, err := InitConfig(path)
if err != nil {
return err return err
} }
if config.Config.Mysql == nil { if config.Config.Mysql == nil {
return nil return nil
} }
mongoDB, err := GetMongo() mongoDB, err := GetMongo(conf)
if err != nil { if err != nil {
return err return err
} }
@ -91,7 +99,7 @@ func Main(path string) error {
default: default:
return err return err
} }
mysqlDB, err := GetMysql() mysqlDB, err := GetMysql(conf)
if err != nil { if err != nil {
if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == 1049 { if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == 1049 {
if err := SetMongoDataVersion(mongoDB, version.Value); err != nil { if err := SetMongoDataVersion(mongoDB, version.Value); err != nil {

Loading…
Cancel
Save