diff --git a/cmd/openim-api/main.go b/cmd/openim-api/main.go index 0fb609aa3..9a307686d 100644 --- a/cmd/openim-api/main.go +++ b/cmd/openim-api/main.go @@ -15,115 +15,17 @@ package main import ( - "context" - "fmt" - "net" - "net/http" + util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" _ "net/http/pprof" - "os" - "os/signal" - "strconv" - "syscall" - "time" - "github.com/OpenIMSDK/protocol/constant" - "github.com/OpenIMSDK/tools/discoveryregistry" - "github.com/OpenIMSDK/tools/errs" - "github.com/openimsdk/open-im-server/v3/internal/api" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" - kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister" - ginprom "github.com/openimsdk/open-im-server/v3/pkg/common/ginprometheus" - "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" - util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { apiCmd := cmd.NewApiCmd() apiCmd.AddPortFlag() apiCmd.AddPrometheusPortFlag() - apiCmd.AddApi(run) if err := apiCmd.Execute(); err != nil { util.ExitWithError(err) } } - -func run(port int, proPort int) error { - if port == 0 || proPort == 0 { - err := "port or proPort is empty:" + strconv.Itoa(port) + "," + strconv.Itoa(proPort) - return errs.Wrap(fmt.Errorf(err)) - } - rdb, err := cache.NewRedis() - if err != nil { - return err - } - - var client discoveryregistry.SvcDiscoveryRegistry - - // Determine whether zk is passed according to whether it is a clustered deployment - client, err = kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) - if err != nil { - return err - } - - if err = client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil { - return err - } - - if err = client.RegisterConf2Registry(constant.OpenIMCommonConfigKey, config.Config.EncodeConfig()); err != nil { - return err - } - - var ( - netDone = make(chan struct{}, 1) - netErr error - ) - - router := api.NewGinRouter(client, rdb) - if config.Config.Prometheus.Enable { - go func() { - p := ginprom.NewPrometheus("app", prommetrics.GetGinCusMetrics("Api")) - p.SetListenAddress(fmt.Sprintf(":%d", proPort)) - if err = p.Use(router); err != nil && err != http.ErrServerClosed { - netErr = errs.Wrap(err, fmt.Sprintf("prometheus start err: %d", proPort)) - netDone <- struct{}{} - } - }() - } - - var address string - if config.Config.Api.ListenIP != "" { - address = net.JoinHostPort(config.Config.Api.ListenIP, strconv.Itoa(port)) - } else { - address = net.JoinHostPort("0.0.0.0", strconv.Itoa(port)) - } - - server := http.Server{Addr: address, Handler: router} - - go func() { - err = server.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - netErr = errs.Wrap(err, fmt.Sprintf("api start err: %s", server.Addr)) - netDone <- struct{}{} - } - }() - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) - - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - select { - case <-sigs: - util.SIGTERMExit() - err := server.Shutdown(ctx) - if err != nil { - return errs.Wrap(err, "api shutdown err") - } - case <-netDone: - close(netDone) - return netErr - } - return nil -} diff --git a/cmd/openim-crontask/main.go b/cmd/openim-crontask/main.go index b284fd773..b52029c64 100644 --- a/cmd/openim-crontask/main.go +++ b/cmd/openim-crontask/main.go @@ -15,14 +15,13 @@ package main import ( - "github.com/openimsdk/open-im-server/v3/internal/tools" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { cronTaskCmd := cmd.NewCronTaskCmd() - if err := cronTaskCmd.Exec(tools.StartTask); err != nil { + if err := cronTaskCmd.Exec(); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-msggateway/main.go b/cmd/openim-msggateway/main.go index ed67b8f5d..01b13560d 100644 --- a/cmd/openim-msggateway/main.go +++ b/cmd/openim-msggateway/main.go @@ -24,7 +24,6 @@ func main() { msgGatewayCmd.AddWsPortFlag() msgGatewayCmd.AddPortFlag() msgGatewayCmd.AddPrometheusPortFlag() - if err := msgGatewayCmd.Exec(); err != nil { util.ExitWithError(err) } diff --git a/cmd/openim-push/main.go b/cmd/openim-push/main.go index bd31ffdef..c7d29fc97 100644 --- a/cmd/openim-push/main.go +++ b/cmd/openim-push/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/push" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - pushCmd := cmd.NewRpcCmd(cmd.RpcPushServer) + pushCmd := cmd.NewRpcCmd(cmd.RpcPushServer, push.Start) pushCmd.AddPortFlag() pushCmd.AddPrometheusPortFlag() if err := pushCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := pushCmd.StartSvr(config.Config.RpcRegisterName.OpenImPushName, push.Start); err != nil { - util.ExitWithError(err) - } } diff --git a/cmd/openim-rpc/openim-rpc-auth/main.go b/cmd/openim-rpc/openim-rpc-auth/main.go index 992a2b432..da281b70e 100644 --- a/cmd/openim-rpc/openim-rpc-auth/main.go +++ b/cmd/openim-rpc/openim-rpc-auth/main.go @@ -17,19 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/auth" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - authCmd := cmd.NewRpcCmd(cmd.RpcAuthServer) + authCmd := cmd.NewRpcCmd(cmd.RpcAuthServer, auth.Start) authCmd.AddPortFlag() authCmd.AddPrometheusPortFlag() if err := authCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := authCmd.StartSvr(config.Config.RpcRegisterName.OpenImAuthName, auth.Start); err != nil { - util.ExitWithError(err) - } - } diff --git a/cmd/openim-rpc/openim-rpc-conversation/main.go b/cmd/openim-rpc/openim-rpc-conversation/main.go index 10fe0b46c..6e74b3251 100644 --- a/cmd/openim-rpc/openim-rpc-conversation/main.go +++ b/cmd/openim-rpc/openim-rpc-conversation/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/conversation" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcConversationServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcConversationServer, conversation.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImConversationName, conversation.Start); err != nil { - util.ExitWithError(err) - } } diff --git a/cmd/openim-rpc/openim-rpc-friend/main.go b/cmd/openim-rpc/openim-rpc-friend/main.go index 63de23293..a307c01a1 100644 --- a/cmd/openim-rpc/openim-rpc-friend/main.go +++ b/cmd/openim-rpc/openim-rpc-friend/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/friend" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcFriendServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcFriendServer, friend.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImFriendName, friend.Start); err != nil { - util.ExitWithError(err) - } } diff --git a/cmd/openim-rpc/openim-rpc-group/main.go b/cmd/openim-rpc/openim-rpc-group/main.go index c0780acab..2afb7963c 100644 --- a/cmd/openim-rpc/openim-rpc-group/main.go +++ b/cmd/openim-rpc/openim-rpc-group/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/group" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcGroupServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcGroupServer, group.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImGroupName, group.Start); err != nil { - util.ExitWithError(err) - } } diff --git a/cmd/openim-rpc/openim-rpc-msg/main.go b/cmd/openim-rpc/openim-rpc-msg/main.go index 62bdff0a5..bbffbcae7 100644 --- a/cmd/openim-rpc/openim-rpc-msg/main.go +++ b/cmd/openim-rpc/openim-rpc-msg/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/msg" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcMsgServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcMsgServer, msg.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImMsgName, msg.Start); err != nil { - util.ExitWithError(err) - } } diff --git a/cmd/openim-rpc/openim-rpc-third/main.go b/cmd/openim-rpc/openim-rpc-third/main.go index c2893a398..09a8409e6 100644 --- a/cmd/openim-rpc/openim-rpc-third/main.go +++ b/cmd/openim-rpc/openim-rpc-third/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/third" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcThirdServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcThirdServer, third.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImThirdName, third.Start); err != nil { - util.ExitWithError(err) - } } diff --git a/cmd/openim-rpc/openim-rpc-user/main.go b/cmd/openim-rpc/openim-rpc-user/main.go index f7948bda0..18adbfae5 100644 --- a/cmd/openim-rpc/openim-rpc-user/main.go +++ b/cmd/openim-rpc/openim-rpc-user/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/user" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcUserServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcUserServer, user.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { util.ExitWithError(err) } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImUserName, user.Start); err != nil { - util.ExitWithError(err) - } } diff --git a/internal/api/msg.go b/internal/api/msg.go index 61cd1ae6c..d38c14d4e 100644 --- a/internal/api/msg.go +++ b/internal/api/msg.go @@ -198,7 +198,7 @@ func (m *MessageApi) SendMessage(c *gin.Context) { } // Check if the user has the app manager role. - if !authverify.IsAppManagerUid(c) { + if !authverify.IsAppManagerUid(c, m.Config) { // Respond with a permission error if the user is not an app manager. apiresp.GinError(c, errs.ErrNoPermission.Wrap("only app manager can send message")) return @@ -256,7 +256,7 @@ func (m *MessageApi) SendBusinessNotification(c *gin.Context) { return } - if !authverify.IsAppManagerUid(c) { + if !authverify.IsAppManagerUid(c, m.Config) { apiresp.GinError(c, errs.ErrNoPermission.Wrap("only app manager can send message")) return } @@ -300,7 +300,7 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) { return } log.ZInfo(c, "BatchSendMsg", "req", req) - if err := authverify.CheckAdmin(c); err != nil { + if err := authverify.CheckAdmin(c, m.Config); err != nil { apiresp.GinError(c, errs.ErrNoPermission.Wrap("only app manager can send message")) return } diff --git a/internal/api/route.go b/internal/api/route.go index 5c920ca05..ce8e3a62b 100644 --- a/internal/api/route.go +++ b/internal/api/route.go @@ -17,53 +17,142 @@ package api import ( "context" "fmt" + kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister" + ginprom "github.com/openimsdk/open-im-server/v3/pkg/common/ginprometheus" + "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" + "net" "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/tools/apiresp" - "github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/errs" - "github.com/OpenIMSDK/tools/log" - "github.com/OpenIMSDK/tools/mw" "github.com/OpenIMSDK/tools/tokenverify" - "github.com/gin-gonic/gin" - "github.com/gin-gonic/gin/binding" - "github.com/go-playground/validator/v10" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" - "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" + + "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" + "github.com/go-playground/validator/v10" "github.com/redis/go-redis/v9" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + + "github.com/OpenIMSDK/tools/discoveryregistry" + "github.com/OpenIMSDK/tools/log" + "github.com/OpenIMSDK/tools/mw" + + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) -func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.UniversalClient) *gin.Engine { - discov.AddOption( - mw.GrpcClient(), - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")), - ) // Default RPC middleware +func Start(config *config.GlobalConfig, port int, proPort int) error { + log.ZDebug(context.Background(), "configAPI1111111111111111111", config, "port", port, "javafdasfs") + if port == 0 || proPort == 0 { + err := "port or proPort is empty:" + strconv.Itoa(port) + "," + strconv.Itoa(proPort) + return errs.Wrap(fmt.Errorf(err)) + } + rdb, err := cache.NewRedis(config) + if err != nil { + return err + } + + var client discoveryregistry.SvcDiscoveryRegistry + + // Determine whether zk is passed according to whether it is a clustered deployment + client, err = kdisc.NewDiscoveryRegister(config) + if err != nil { + return errs.Wrap(err, "register discovery err") + } + + if err = client.CreateRpcRootNodes(config.GetServiceNames()); err != nil { + return errs.Wrap(err, "create rpc root nodes error") + } + + if err = client.RegisterConf2Registry(constant.OpenIMCommonConfigKey, config.EncodeConfig()); err != nil { + return errs.Wrap(err) + } + var ( + netDone = make(chan struct{}, 1) + netErr error + ) + router := newGinRouter(client, rdb, config) + if config.Prometheus.Enable { + go func() { + p := ginprom.NewPrometheus("app", prommetrics.GetGinCusMetrics("Api")) + p.SetListenAddress(fmt.Sprintf(":%d", proPort)) + if err = p.Use(router); err != nil && err != http.ErrServerClosed { + netErr = errs.Wrap(err, fmt.Sprintf("prometheus start err: %d", proPort)) + netDone <- struct{}{} + } + }() + + } + + var address string + if config.Api.ListenIP != "" { + address = net.JoinHostPort(config.Api.ListenIP, strconv.Itoa(port)) + } else { + address = net.JoinHostPort("0.0.0.0", strconv.Itoa(port)) + } + + server := http.Server{Addr: address, Handler: router} + + go func() { + err = server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + netErr = errs.Wrap(err, fmt.Sprintf("api start err: %s", server.Addr)) + netDone <- struct{}{} + + } + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + select { + case <-sigs: + util.SIGTERMExit() + err := server.Shutdown(ctx) + if err != nil { + return errs.Wrap(err, "shutdown err") + } + case <-netDone: + close(netDone) + return netErr + } + return nil +} + +func newGinRouter(disCov discoveryregistry.SvcDiscoveryRegistry, rdb redis.UniversalClient, config *config.GlobalConfig) *gin.Engine { + disCov.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) gin.SetMode(gin.ReleaseMode) r := gin.New() if v, ok := binding.Validator.Engine().(*validator.Validate); ok { _ = v.RegisterValidation("required_if", RequiredIf) } - log.ZInfo(context.Background(), "load config", "config", config.Config) r.Use(gin.Recovery(), mw.CorsHandler(), mw.GinParseOperationID()) // init rpc client here - userRpc := rpcclient.NewUser(discov) - groupRpc := rpcclient.NewGroup(discov) - friendRpc := rpcclient.NewFriend(discov) - messageRpc := rpcclient.NewMessage(discov) - conversationRpc := rpcclient.NewConversation(discov) - authRpc := rpcclient.NewAuth(discov) - thirdRpc := rpcclient.NewThird(discov) + userRpc := rpcclient.NewUser(disCov, config) + groupRpc := rpcclient.NewGroup(disCov, config) + friendRpc := rpcclient.NewFriend(disCov, config) + messageRpc := rpcclient.NewMessage(disCov, config) + conversationRpc := rpcclient.NewConversation(disCov, config) + authRpc := rpcclient.NewAuth(disCov, config) + thirdRpc := rpcclient.NewThird(disCov, config) u := NewUserApi(*userRpc) m := NewMessageApi(messageRpc, userRpc) - ParseToken := GinParseToken(rdb) + ParseToken := GinParseToken(rdb, config) userRouterGroup := r.Group("/user") { userRouterGroup.POST("/user_register", u.UserRegister) @@ -157,8 +246,8 @@ func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.Unive // Third service thirdGroup := r.Group("/third", ParseToken) { - thirdGroup.GET("/prometheus", GetPrometheus) t := NewThirdApi(*thirdRpc) + thirdGroup.GET("/prometheus", t.GetPrometheus) thirdGroup.POST("/fcm_update_token", t.FcmUpdateToken) thirdGroup.POST("/set_app_badge", t.SetAppBadge) @@ -225,12 +314,12 @@ func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.Unive return r } -// GinParseToken is a middleware that parses the token in the request header and verifies it. -func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { +func GinParseToken(rdb redis.UniversalClient, config *config.GlobalConfig) gin.HandlerFunc { dataBase := controller.NewAuthDatabase( - cache.NewMsgCacheModel(rdb), - config.Config.Secret, - config.Config.TokenPolicy.Expire, + cache.NewMsgCacheModel(rdb, config), + config.Secret, + config.TokenPolicy.Expire, + config, ) return func(c *gin.Context) { switch c.Request.Method { @@ -242,7 +331,7 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { c.Abort() return } - claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret()) + claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret(config.Secret)) if err != nil { log.ZWarn(c, "jwt get token error", errs.ErrTokenUnknown.Wrap()) apiresp.GinError(c, errs.ErrTokenUnknown.Wrap()) diff --git a/internal/api/third.go b/internal/api/third.go index f00c0d8d3..190e0d540 100644 --- a/internal/api/third.go +++ b/internal/api/third.go @@ -19,12 +19,12 @@ import ( "net/http" "strconv" + "github.com/gin-gonic/gin" + "github.com/OpenIMSDK/protocol/third" "github.com/OpenIMSDK/tools/a2r" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/mcontext" - "github.com/gin-gonic/gin" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) @@ -126,6 +126,6 @@ func (o *ThirdApi) SearchLogs(c *gin.Context) { a2r.Call(third.ThirdClient.SearchLogs, o.Client, c) } -func GetPrometheus(c *gin.Context) { - c.Redirect(http.StatusFound, config2.Config.Prometheus.GrafanaUrl) +func (o *ThirdApi) GetPrometheus(c *gin.Context) { + c.Redirect(http.StatusFound, o.Config.Prometheus.GrafanaUrl) } diff --git a/internal/api/user.go b/internal/api/user.go index 3cc5470a7..468432ee0 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -23,7 +23,7 @@ import ( "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" "github.com/gin-gonic/gin" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) @@ -70,7 +70,7 @@ func (u *UserApi) GetUsersOnlineStatus(c *gin.Context) { apiresp.GinError(c, err) return } - conns, err := u.Discov.GetConns(c, config.Config.RpcRegisterName.OpenImMessageGatewayName) + conns, err := u.Discov.GetConns(c, u.Config.RpcRegisterName.OpenImMessageGatewayName) if err != nil { apiresp.GinError(c, err) return @@ -134,7 +134,7 @@ func (u *UserApi) GetUsersOnlineTokenDetail(c *gin.Context) { apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap()) return } - conns, err := u.Discov.GetConns(c, config.Config.RpcRegisterName.OpenImMessageGatewayName) + conns, err := u.Discov.GetConns(c, u.Config.RpcRegisterName.OpenImMessageGatewayName) if err != nil { apiresp.GinError(c, err) return diff --git a/internal/msggateway/callback.go b/internal/msggateway/callback.go index ede48f74a..ab8c1f51f 100644 --- a/internal/msggateway/callback.go +++ b/internal/msggateway/callback.go @@ -25,12 +25,8 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/http" ) -func callBackURL() string { - return config.Config.Callback.CallbackUrl -} - -func CallbackUserOnline(ctx context.Context, userID string, platformID int, isAppBackground bool, connID string) error { - if !config.Config.Callback.CallbackUserOnline.Enable { +func CallbackUserOnline(ctx context.Context, globalConfig *config.GlobalConfig, userID string, platformID int, isAppBackground bool, connID string) error { + if !globalConfig.Callback.CallbackUserOnline.Enable { return nil } req := cbapi.CallbackUserOnlineReq{ @@ -48,14 +44,14 @@ func CallbackUserOnline(ctx context.Context, userID string, platformID int, isAp ConnID: connID, } resp := cbapi.CommonCallbackResp{} - if err := http.CallBackPostReturn(ctx, callBackURL(), &req, &resp, config.Config.Callback.CallbackUserOnline); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, &req, &resp, globalConfig.Callback.CallbackUserOnline); err != nil { return err } return nil } -func CallbackUserOffline(ctx context.Context, userID string, platformID int, connID string) error { - if !config.Config.Callback.CallbackUserOffline.Enable { +func CallbackUserOffline(ctx context.Context, globalConfig *config.GlobalConfig, userID string, platformID int, connID string) error { + if !globalConfig.Callback.CallbackUserOffline.Enable { return nil } req := &cbapi.CallbackUserOfflineReq{ @@ -72,14 +68,14 @@ func CallbackUserOffline(ctx context.Context, userID string, platformID int, con ConnID: connID, } resp := &cbapi.CallbackUserOfflineResp{} - if err := http.CallBackPostReturn(ctx, callBackURL(), req, resp, config.Config.Callback.CallbackUserOffline); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackUserOffline); err != nil { return err } return nil } -func CallbackUserKickOff(ctx context.Context, userID string, platformID int) error { - if !config.Config.Callback.CallbackUserKickOff.Enable { +func CallbackUserKickOff(ctx context.Context, globalConfig *config.GlobalConfig, userID string, platformID int) error { + if !globalConfig.Callback.CallbackUserKickOff.Enable { return nil } req := &cbapi.CallbackUserKickOffReq{ @@ -95,7 +91,7 @@ func CallbackUserKickOff(ctx context.Context, userID string, platformID int) err Seq: time.Now().UnixMilli(), } resp := &cbapi.CommonCallbackResp{} - if err := http.CallBackPostReturn(ctx, callBackURL(), req, resp, config.Config.Callback.CallbackUserOffline); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackUserOffline); err != nil { return err } return nil diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index 826c2488b..739c4232f 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -31,24 +31,25 @@ import ( "google.golang.org/grpc" ) -func (s *Server) InitServer(disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := cache.NewRedis() +func (s *Server) InitServer(config *config.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis(config) if err != nil { return err } - msgModel := cache.NewMsgCacheModel(rdb) - s.LongConnServer.SetDiscoveryRegistry(disCov) + msgModel := cache.NewMsgCacheModel(rdb, config) + s.LongConnServer.SetDiscoveryRegistry(disCov, config) s.LongConnServer.SetCacheHandler(msgModel) msggateway.RegisterMsgGatewayServer(server, s) return nil } -func (s *Server) Start() error { +func (s *Server) Start(conf *config.GlobalConfig) error { return startrpc.Start( s.rpcPort, - config.Config.RpcRegisterName.OpenImMessageGatewayName, + conf.RpcRegisterName.OpenImMessageGatewayName, s.prometheusPort, + conf, s.InitServer, ) } @@ -58,18 +59,20 @@ type Server struct { prometheusPort int LongConnServer LongConnServer pushTerminal []int + config *config.GlobalConfig } func (s *Server) SetLongConnServer(LongConnServer LongConnServer) { s.LongConnServer = LongConnServer } -func NewServer(rpcPort int, proPort int, longConnServer LongConnServer) *Server { +func NewServer(rpcPort int, proPort int, longConnServer LongConnServer, config *config.GlobalConfig) *Server { return &Server{ rpcPort: rpcPort, prometheusPort: proPort, LongConnServer: longConnServer, pushTerminal: []int{constant.IOSPlatformID, constant.AndroidPlatformID}, + config: config, } } @@ -84,7 +87,7 @@ func (s *Server) GetUsersOnlineStatus( ctx context.Context, req *msggateway.GetUsersOnlineStatusReq, ) (*msggateway.GetUsersOnlineStatusResp, error) { - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { return nil, errs.ErrNoPermission.Wrap("only app manager") } var resp msggateway.GetUsersOnlineStatusResp diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index 5d19ad16d..4efbb7cdf 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -22,23 +22,24 @@ import ( ) // RunWsAndServer run ws server. -func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { +func RunWsAndServer(conf *config.GlobalConfig, rpcPort, wsPort, prometheusPort int) error { fmt.Println("start rpc/msg_gateway server, port: ", rpcPort, wsPort, prometheusPort, ", OpenIM version: ", config.Version) longServer, err := NewWsServer( + conf, WithPort(wsPort), - WithMaxConnNum(int64(config.Config.LongConnSvr.WebsocketMaxConnNum)), - WithHandshakeTimeout(time.Duration(config.Config.LongConnSvr.WebsocketTimeout)*time.Second), - WithMessageMaxMsgLength(config.Config.LongConnSvr.WebsocketMaxMsgLen), - WithWriteBufferSize(config.Config.LongConnSvr.WebsocketWriteBufferSize), + WithMaxConnNum(int64(conf.LongConnSvr.WebsocketMaxConnNum)), + WithHandshakeTimeout(time.Duration(conf.LongConnSvr.WebsocketTimeout)*time.Second), + WithMessageMaxMsgLength(conf.LongConnSvr.WebsocketMaxMsgLen), + WithWriteBufferSize(conf.LongConnSvr.WebsocketWriteBufferSize), ) if err != nil { return err } - hubServer := NewServer(rpcPort, prometheusPort, longServer) + hubServer := NewServer(rpcPort, prometheusPort, longServer, conf) netDone := make(chan error) go func() { - err = hubServer.Start() + err = hubServer.Start(conf) netDone <- err }() return hubServer.LongConnServer.Run(netDone) diff --git a/internal/msggateway/message_handler.go b/internal/msggateway/message_handler.go index 105a77336..208cd6bf7 100644 --- a/internal/msggateway/message_handler.go +++ b/internal/msggateway/message_handler.go @@ -16,6 +16,7 @@ package msggateway import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "sync" "github.com/OpenIMSDK/protocol/msg" @@ -105,9 +106,9 @@ type GrpcHandler struct { validate *validator.Validate } -func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDiscoveryRegistry) *GrpcHandler { - msgRpcClient := rpcclient.NewMessageRpcClient(client) - pushRpcClient := rpcclient.NewPushRpcClient(client) +func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *GrpcHandler { + msgRpcClient := rpcclient.NewMessageRpcClient(client, config) + pushRpcClient := rpcclient.NewPushRpcClient(client, config) return &GrpcHandler{ msgRpcClient: &msgRpcClient, pushClient: &pushRpcClient, validate: validate, diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index e8ba9939a..f5838c703 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -49,7 +49,7 @@ type LongConnServer interface { GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) Validate(s any) error SetCacheHandler(cache cache.MsgModel) - SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) + SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) KickUserConn(client *Client) error UnRegister(c *Client) SetKickHandlerInfo(i *kickHandler) @@ -66,6 +66,7 @@ type LongConnServer interface { // } type WsServer struct { + globalConfig *config.GlobalConfig port int wsMaxConnNum int64 registerChan chan *Client @@ -92,9 +93,9 @@ type kickHandler struct { newClient *Client } -func (ws *WsServer) SetDiscoveryRegistry(disCov discoveryregistry.SvcDiscoveryRegistry) { - ws.MessageHandler = NewGrpcHandler(ws.validate, disCov) - u := rpcclient.NewUserRpcClient(disCov) +func (ws *WsServer) SetDiscoveryRegistry(disCov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) { + ws.MessageHandler = NewGrpcHandler(ws.validate, disCov, config) + u := rpcclient.NewUserRpcClient(disCov, config) ws.userClient = &u ws.disCov = disCov } @@ -106,12 +107,12 @@ func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, sta } switch status { case constant.Online: - err := CallbackUserOnline(ctx, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID()) + err := CallbackUserOnline(ctx, ws.globalConfig, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID()) if err != nil { log.ZWarn(ctx, "CallbackUserOnline err", err) } case constant.Offline: - err := CallbackUserOffline(ctx, client.UserID, client.PlatformID, client.ctx.GetConnID()) + err := CallbackUserOffline(ctx, ws.globalConfig, client.UserID, client.PlatformID, client.ctx.GetConnID()) if err != nil { log.ZWarn(ctx, "CallbackUserOffline err", err) } @@ -141,13 +142,14 @@ func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, return ws.clients.Get(userID, platform) } -func NewWsServer(opts ...Option) (*WsServer, error) { +func NewWsServer(globalConfig *config.GlobalConfig, opts ...Option) (*WsServer, error) { var config configs for _, o := range opts { o(&config) } v := validator.New() return &WsServer{ + globalConfig: globalConfig, port: config.port, wsMaxConnNum: config.maxConnNum, writeBufferSize: config.writeBufferSize, @@ -221,7 +223,7 @@ func (ws *WsServer) Run(done chan error) error { var concurrentRequest = 3 func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error { - conns, err := ws.disCov.GetConns(ctx, config.Config.RpcRegisterName.OpenImMessageGatewayName) + conns, err := ws.disCov.GetConns(ctx, ws.globalConfig.RpcRegisterName.OpenImMessageGatewayName) if err != nil { return err } @@ -286,7 +288,7 @@ func (ws *WsServer) registerClient(client *Client) { } wg := sync.WaitGroup{} - if config.Config.Envs.Discovery == "zookeeper" { + if ws.globalConfig.Envs.Discovery == "zookeeper" { wg.Add(1) go func() { defer wg.Done() @@ -329,7 +331,7 @@ func (ws *WsServer) KickUserConn(client *Client) error { } func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) { - switch config.Config.MultiLoginPolicy { + switch ws.globalConfig.MultiLoginPolicy { case constant.DefalutNotKick: case constant.PCAndOther: if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC { @@ -441,7 +443,7 @@ func (ws *WsServer) ParseWSArgs(r *http.Request) (args *WSArgs, err error) { return nil, errs.ErrConnArgsErr.Wrap("platformID is not int") } v.PlatformID = platformID - if err = authverify.WsVerifyToken(v.Token, v.UserID, platformID); err != nil { + if err = authverify.WsVerifyToken(v.Token, v.UserID, ws.globalConfig.Secret, platformID); err != nil { return nil, err } if query.Get(Compression) == GzipCompressionProtocol { diff --git a/internal/msgtransfer/init.go b/internal/msgtransfer/init.go index dfb7c5307..5e9e80663 100644 --- a/internal/msgtransfer/init.go +++ b/internal/msgtransfer/init.go @@ -44,21 +44,22 @@ type MsgTransfer struct { // This consumer aggregated messages, subscribed to the topic:ws2ms_chat, // the modification notification is sent to msg_to_modify topic, the message is stored in redis, Incr Redis, // and then the message is sent to ms2pschat topic for push, and the message is sent to msg_to_mongo topic for persistence - historyCH *OnlineHistoryRedisConsumerHandler + historyCH *OnlineHistoryRedisConsumerHandler + historyMongoCH *OnlineHistoryMongoConsumerHandler // mongoDB batch insert, delete messages in redis after success, // and handle the deletion notification message deleted subscriptions topic: msg_to_mongo - historyMongoCH *OnlineHistoryMongoConsumerHandler - ctx context.Context - cancel context.CancelFunc + ctx context.Context + cancel context.CancelFunc + config *config.GlobalConfig } -func StartTransfer(prometheusPort int) error { - rdb, err := cache.NewRedis() +func StartTransfer(config *config.GlobalConfig, prometheusPort int) error { + rdb, err := cache.NewRedis(config) if err != nil { return err } - mongo, err := unrelation.NewMongo() + mongo, err := unrelation.NewMongo(config) if err != nil { return err } @@ -66,38 +67,37 @@ func StartTransfer(prometheusPort int) error { if err = mongo.CreateMsgIndex(); err != nil { return err } - - client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) + client, err := kdisc.NewDiscoveryRegister(config) if err != nil { return err } - if err := client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil { + if err := client.CreateRpcRootNodes(config.GetServiceNames()); err != nil { return err } client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) - msgModel := cache.NewMsgCacheModel(rdb) - msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase()) - msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, msgModel) + msgModel := cache.NewMsgCacheModel(rdb, config) + msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase(config.Mongo.Database)) + msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, msgModel, config) if err != nil { return err } - conversationRpcClient := rpcclient.NewConversationRpcClient(client) - groupRpcClient := rpcclient.NewGroupRpcClient(client) - msgTransfer, err := NewMsgTransfer(msgDatabase, &conversationRpcClient, &groupRpcClient) + conversationRpcClient := rpcclient.NewConversationRpcClient(client, config) + groupRpcClient := rpcclient.NewGroupRpcClient(client, config) + msgTransfer, err := NewMsgTransfer(config, msgDatabase, &conversationRpcClient, &groupRpcClient) if err != nil { return err } - return msgTransfer.Start(prometheusPort) + return msgTransfer.Start(prometheusPort, config) } -func NewMsgTransfer(msgDatabase controller.CommonMsgDatabase, conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient) (*MsgTransfer, error) { - historyCH, err := NewOnlineHistoryRedisConsumerHandler(msgDatabase, conversationRpcClient, groupRpcClient) +func NewMsgTransfer(config *config.GlobalConfig, msgDatabase controller.CommonMsgDatabase, conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient) (*MsgTransfer, error) { + historyCH, err := NewOnlineHistoryRedisConsumerHandler(config, msgDatabase, conversationRpcClient, groupRpcClient) if err != nil { return nil, err } - historyMongoCH, err := NewOnlineHistoryMongoConsumerHandler(msgDatabase) + historyMongoCH, err := NewOnlineHistoryMongoConsumerHandler(config, msgDatabase) if err != nil { return nil, err } @@ -105,11 +105,12 @@ func NewMsgTransfer(msgDatabase controller.CommonMsgDatabase, conversationRpcCli return &MsgTransfer{ historyCH: historyCH, historyMongoCH: historyMongoCH, + config: config, }, nil } -func (m *MsgTransfer) Start(prometheusPort int) error { - fmt.Println("Start msg transfer", "prometheusPort:", prometheusPort) +func (m *MsgTransfer) Start(prometheusPort int, config *config.GlobalConfig) error { + fmt.Println("start msg transfer", "prometheusPort:", prometheusPort) if prometheusPort <= 0 { return errs.Wrap(errors.New("prometheusPort not correct")) } @@ -123,13 +124,13 @@ func (m *MsgTransfer) Start(prometheusPort int) error { go m.historyCH.historyConsumerGroup.RegisterHandleAndConsumer(m.ctx, m.historyCH) go m.historyMongoCH.historyConsumerGroup.RegisterHandleAndConsumer(m.ctx, m.historyMongoCH) - if config.Config.Prometheus.Enable { + if config.Prometheus.Enable { go func() { proreg := prometheus.NewRegistry() proreg.MustRegister( collectors.NewGoCollector(), ) - proreg.MustRegister(prommetrics.GetGrpcCusMetrics("Transfer")...) + proreg.MustRegister(prommetrics.GetGrpcCusMetrics("Transfer", config)...) http.Handle("/metrics", promhttp.HandlerFor(proreg, promhttp.HandlerOpts{Registry: proreg})) err := http.ListenAndServe(fmt.Sprintf(":%d", prometheusPort), nil) if err != nil && err != http.ErrServerClosed { diff --git a/internal/msgtransfer/online_history_msg_handler.go b/internal/msgtransfer/online_history_msg_handler.go index 4995d10e8..b81bd12b8 100644 --- a/internal/msgtransfer/online_history_msg_handler.go +++ b/internal/msgtransfer/online_history_msg_handler.go @@ -81,6 +81,7 @@ type OnlineHistoryRedisConsumerHandler struct { } func NewOnlineHistoryRedisConsumerHandler( + config *config.GlobalConfig, database controller.CommonMsgDatabase, conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient, @@ -96,11 +97,29 @@ func NewOnlineHistoryRedisConsumerHandler( och.conversationRpcClient = conversationRpcClient och.groupRpcClient = groupRpcClient var err error + + var tlsConfig *kafka.TLSConfig + if config.Kafka.TLS != nil { + tlsConfig = &kafka.TLSConfig{ + CACrt: config.Kafka.TLS.CACrt, + ClientCrt: config.Kafka.TLS.ClientCrt, + ClientKey: config.Kafka.TLS.ClientKey, + ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd, + InsecureSkipVerify: false, + } + } + och.historyConsumerGroup, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ KafkaVersion: sarama.V2_0_0_0, - OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, - }, []string{config.Config.Kafka.LatestMsgToRedis.Topic}, - config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToRedis) + OffsetsInitial: sarama.OffsetNewest, + IsReturnErr: false, + UserName: config.Kafka.Username, + Password: config.Kafka.Password, + }, []string{config.Kafka.LatestMsgToRedis.Topic}, + config.Kafka.Addr, + config.Kafka.ConsumerGroupID.MsgToRedis, + tlsConfig, + ) // statistics.NewStatistics(&och.singleMsgSuccessCount, config.Config.ModuleName.MsgTransferName, fmt.Sprintf("%d // second singleMsgCount insert to mongo", constant.StatisticsTimeInterval), constant.StatisticsTimeInterval) return &och, err diff --git a/internal/msgtransfer/online_msg_to_mongo_handler.go b/internal/msgtransfer/online_msg_to_mongo_handler.go index efffc191f..045f82220 100644 --- a/internal/msgtransfer/online_msg_to_mongo_handler.go +++ b/internal/msgtransfer/online_msg_to_mongo_handler.go @@ -32,12 +32,28 @@ type OnlineHistoryMongoConsumerHandler struct { msgDatabase controller.CommonMsgDatabase } -func NewOnlineHistoryMongoConsumerHandler(database controller.CommonMsgDatabase) (*OnlineHistoryMongoConsumerHandler, error) { +func NewOnlineHistoryMongoConsumerHandler(config *config.GlobalConfig, database controller.CommonMsgDatabase) (*OnlineHistoryMongoConsumerHandler, error) { + var tlsConfig *kfk.TLSConfig + if config.Kafka.TLS != nil { + tlsConfig = &kfk.TLSConfig{ + CACrt: config.Kafka.TLS.CACrt, + ClientCrt: config.Kafka.TLS.ClientCrt, + ClientKey: config.Kafka.TLS.ClientKey, + ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd, + InsecureSkipVerify: false, + } + } historyConsumerGroup, err := kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{ KafkaVersion: sarama.V2_0_0_0, - OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, - }, []string{config.Config.Kafka.MsgToMongo.Topic}, - config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToMongo) + OffsetsInitial: sarama.OffsetNewest, + IsReturnErr: false, + UserName: config.Kafka.Username, + Password: config.Kafka.Password, + }, []string{config.Kafka.MsgToMongo.Topic}, + config.Kafka.Addr, + config.Kafka.ConsumerGroupID.MsgToMongo, + tlsConfig, + ) if err != nil { return nil, err } diff --git a/internal/push/callback.go b/internal/push/callback.go index 70862e4d2..6415d63d6 100644 --- a/internal/push/callback.go +++ b/internal/push/callback.go @@ -26,12 +26,14 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/http" ) -func url() string { - return config.Config.Callback.CallbackUrl -} - -func callbackOfflinePush(ctx context.Context, userIDs []string, msg *sdkws.MsgData, offlinePushUserIDs *[]string) error { - if !config.Config.Callback.CallbackOfflinePush.Enable || msg.ContentType == constant.Typing { +func callbackOfflinePush( + ctx context.Context, + config *config.GlobalConfig, + userIDs []string, + msg *sdkws.MsgData, + offlinePushUserIDs *[]string, +) error { + if !config.Callback.CallbackOfflinePush.Enable || msg.ContentType == constant.Typing { return nil } req := &callbackstruct.CallbackBeforePushReq{ @@ -55,7 +57,7 @@ func callbackOfflinePush(ctx context.Context, userIDs []string, msg *sdkws.MsgDa } resp := &callbackstruct.CallbackBeforePushResp{} - if err := http.CallBackPostReturn(ctx, url(), req, resp, config.Config.Callback.CallbackOfflinePush); err != nil { + if err := http.CallBackPostReturn(ctx, config.Callback.CallbackUrl, req, resp, config.Callback.CallbackOfflinePush); err != nil { return err } @@ -68,8 +70,8 @@ func callbackOfflinePush(ctx context.Context, userIDs []string, msg *sdkws.MsgDa return nil } -func callbackOnlinePush(ctx context.Context, userIDs []string, msg *sdkws.MsgData) error { - if !config.Config.Callback.CallbackOnlinePush.Enable || utils.Contain(msg.SendID, userIDs...) || msg.ContentType == constant.Typing { +func callbackOnlinePush(ctx context.Context, config *config.GlobalConfig, userIDs []string, msg *sdkws.MsgData) error { + if !config.Callback.CallbackOnlinePush.Enable || utils.Contain(msg.SendID, userIDs...) || msg.ContentType == constant.Typing { return nil } req := callbackstruct.CallbackBeforePushReq{ @@ -91,7 +93,7 @@ func callbackOnlinePush(ctx context.Context, userIDs []string, msg *sdkws.MsgDat Content: GetContent(msg), } resp := &callbackstruct.CallbackBeforePushResp{} - if err := http.CallBackPostReturn(ctx, url(), req, resp, config.Config.Callback.CallbackOnlinePush); err != nil { + if err := http.CallBackPostReturn(ctx, config.Callback.CallbackUrl, req, resp, config.Callback.CallbackOnlinePush); err != nil { return err } return nil @@ -99,11 +101,12 @@ func callbackOnlinePush(ctx context.Context, userIDs []string, msg *sdkws.MsgDat func callbackBeforeSuperGroupOnlinePush( ctx context.Context, + config *config.GlobalConfig, groupID string, msg *sdkws.MsgData, pushToUserIDs *[]string, ) error { - if !config.Config.Callback.CallbackBeforeSuperGroupOnlinePush.Enable || msg.ContentType == constant.Typing { + if !config.Callback.CallbackBeforeSuperGroupOnlinePush.Enable || msg.ContentType == constant.Typing { return nil } req := callbackstruct.CallbackBeforeSuperGroupOnlinePushReq{ @@ -123,7 +126,7 @@ func callbackBeforeSuperGroupOnlinePush( Seq: msg.Seq, } resp := &callbackstruct.CallbackBeforeSuperGroupOnlinePushResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, req, resp, config.Config.Callback.CallbackBeforeSuperGroupOnlinePush); err != nil { + if err := http.CallBackPostReturn(ctx, config.Callback.CallbackUrl, req, resp, config.Callback.CallbackBeforeSuperGroupOnlinePush); err != nil { return err } diff --git a/internal/push/consumer_init.go b/internal/push/consumer_init.go index 92ce4714e..4ad77de2c 100644 --- a/internal/push/consumer_init.go +++ b/internal/push/consumer_init.go @@ -16,6 +16,7 @@ package push import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) type Consumer struct { @@ -24,8 +25,8 @@ type Consumer struct { // successCount uint64 } -func NewConsumer(pusher *Pusher) (*Consumer, error) { - c, err := NewConsumerHandler(pusher) +func NewConsumer(config *config.GlobalConfig, pusher *Pusher) (*Consumer, error) { + c, err := NewConsumerHandler(config, pusher) if err != nil { return nil, err } @@ -36,5 +37,4 @@ func NewConsumer(pusher *Pusher) (*Consumer, error) { func (c *Consumer) Start() { go c.pushCh.pushConsumerGroup.RegisterHandleAndConsumer(context.Background(), &c.pushCh) - } diff --git a/internal/push/offlinepush/fcm/push.go b/internal/push/offlinepush/fcm/push.go index aa6ec186f..977254462 100644 --- a/internal/push/offlinepush/fcm/push.go +++ b/internal/push/offlinepush/fcm/push.go @@ -20,12 +20,14 @@ import ( firebase "firebase.google.com/go" "firebase.google.com/go/messaging" + "github.com/redis/go-redis/v9" + "google.golang.org/api/option" + "github.com/OpenIMSDK/protocol/constant" + "github.com/openimsdk/open-im-server/v3/internal/push/offlinepush" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" - "github.com/redis/go-redis/v9" - "google.golang.org/api/option" ) const SinglePushCountLimit = 400 @@ -39,9 +41,9 @@ type Fcm struct { // NewClient initializes a new FCM client using the Firebase Admin SDK. // It requires the FCM service account credentials file located within the project's configuration directory. -func NewClient(cache cache.MsgModel) *Fcm { - projectRoot, _ := config.GetProjectRoot() - credentialsFilePath := filepath.Join(projectRoot, "config", config.Config.Push.Fcm.ServiceAccount) +func NewClient(globalConfig *config.GlobalConfig, cache cache.MsgModel) *Fcm { + projectRoot := config.GetProjectRoot() + credentialsFilePath := filepath.Join(projectRoot, "config", globalConfig.Push.Fcm.ServiceAccount) opt := option.WithCredentialsFile(credentialsFilePath) fcmApp, err := firebase.NewApp(context.Background(), nil, opt) if err != nil { diff --git a/internal/push/offlinepush/getui/body.go b/internal/push/offlinepush/getui/body.go index 01eb22e73..46479163f 100644 --- a/internal/push/offlinepush/getui/body.go +++ b/internal/push/offlinepush/getui/body.go @@ -133,13 +133,13 @@ type Payload struct { IsSignal bool `json:"isSignal"` } -func newPushReq(title, content string) PushReq { +func newPushReq(config *config.GlobalConfig, title, content string) PushReq { pushReq := PushReq{PushMessage: &PushMessage{Notification: &Notification{ Title: title, Body: content, ClickType: "startapp", - ChannelID: config.Config.Push.GeTui.ChannelID, - ChannelName: config.Config.Push.GeTui.ChannelName, + ChannelID: config.Push.GeTui.ChannelID, + ChannelName: config.Push.GeTui.ChannelName, }}} return pushReq } diff --git a/internal/push/offlinepush/getui/push.go b/internal/push/offlinepush/getui/push.go index e950585a1..67f6292db 100644 --- a/internal/push/offlinepush/getui/push.go +++ b/internal/push/offlinepush/getui/push.go @@ -55,10 +55,15 @@ type Client struct { cache cache.MsgModel tokenExpireTime int64 taskIDTTL int64 + config *config.GlobalConfig } -func NewClient(cache cache.MsgModel) *Client { - return &Client{cache: cache, tokenExpireTime: tokenExpireTime, taskIDTTL: taskIDTTL} +func NewClient(config *config.GlobalConfig, cache cache.MsgModel) *Client { + return &Client{cache: cache, + tokenExpireTime: tokenExpireTime, + taskIDTTL: taskIDTTL, + config: config, + } } func (g *Client) Push(ctx context.Context, userIDs []string, title, content string, opts *offlinepush.Opts) error { @@ -74,7 +79,7 @@ func (g *Client) Push(ctx context.Context, userIDs []string, title, content stri return err } } - pushReq := newPushReq(title, content) + pushReq := newPushReq(g.config, title, content) pushReq.setPushChannel(title, content) if len(userIDs) > 1 { maxNum := 999 @@ -85,9 +90,9 @@ func (g *Client) Push(ctx context.Context, userIDs []string, title, content stri for i, v := range s.GetSplitResult() { go func(index int, userIDs []string) { defer wg.Done() - if err2 := g.batchPush(ctx, token, userIDs, pushReq); err2 != nil { - log.ZError(ctx, "batchPush failed", err2, "index", index, "token", token, "req", pushReq) - err = err2 + if err := g.batchPush(ctx, token, userIDs, pushReq); err != nil { + log.ZError(ctx, "batchPush failed", err, "index", index, "token", token, "req", pushReq) + err = err } }(i, v.Item) } @@ -110,13 +115,13 @@ func (g *Client) Push(ctx context.Context, userIDs []string, title, content stri func (g *Client) Auth(ctx context.Context, timeStamp int64) (token string, expireTime int64, err error) { h := sha256.New() h.Write( - []byte(config.Config.Push.GeTui.AppKey + strconv.Itoa(int(timeStamp)) + config.Config.Push.GeTui.MasterSecret), + []byte(g.config.Push.GeTui.AppKey + strconv.Itoa(int(timeStamp)) + g.config.Push.GeTui.MasterSecret), ) sign := hex.EncodeToString(h.Sum(nil)) reqAuth := AuthReq{ Sign: sign, Timestamp: strconv.Itoa(int(timeStamp)), - AppKey: config.Config.Push.GeTui.AppKey, + AppKey: g.config.Push.GeTui.AppKey, } respAuth := AuthResp{} err = g.request(ctx, authURL, reqAuth, "", &respAuth) @@ -159,7 +164,7 @@ func (g *Client) request(ctx context.Context, url string, input any, token strin header := map[string]string{"token": token} resp := &Resp{} resp.Data = output - return g.postReturn(ctx, config.Config.Push.GeTui.PushUrl+url, header, input, resp, 3) + return g.postReturn(ctx, g.config.Push.GeTui.PushUrl+url, header, input, resp, 3) } func (g *Client) postReturn( diff --git a/internal/push/offlinepush/jpush/body/notification.go b/internal/push/offlinepush/jpush/body/notification.go index ddf3802af..b25882ea5 100644 --- a/internal/push/offlinepush/jpush/body/notification.go +++ b/internal/push/offlinepush/jpush/body/notification.go @@ -46,7 +46,6 @@ type Extras struct { func (n *Notification) SetAlert(alert string) { n.Alert = alert n.Android.Alert = alert - n.SetAndroidIntent() n.IOS.Alert = alert n.IOS.Sound = "default" n.IOS.Badge = "+1" @@ -57,8 +56,8 @@ func (n *Notification) SetExtras(extras Extras) { n.Android.Extras = extras } -func (n *Notification) SetAndroidIntent() { - n.Android.Intent.URL = config.Config.Push.Jpns.PushIntent +func (n *Notification) SetAndroidIntent(config *config.GlobalConfig) { + n.Android.Intent.URL = config.Push.Jpns.PushIntent } func (n *Notification) IOSEnableMutableContent() { diff --git a/internal/push/offlinepush/jpush/push.go b/internal/push/offlinepush/jpush/push.go index 567269f3c..2ced4bfd3 100644 --- a/internal/push/offlinepush/jpush/push.go +++ b/internal/push/offlinepush/jpush/push.go @@ -25,10 +25,12 @@ import ( http2 "github.com/openimsdk/open-im-server/v3/pkg/common/http" ) -type JPush struct{} +type JPush struct { + config *config.GlobalConfig +} -func NewClient() *JPush { - return &JPush{} +func NewClient(config *config.GlobalConfig) *JPush { + return &JPush{config: config} } func (j *JPush) Auth(apiKey, secretKey string, timeStamp int64) (token string, err error) { @@ -59,10 +61,12 @@ func (j *JPush) Push(ctx context.Context, userIDs []string, title, content strin no.IOSEnableMutableContent() no.SetExtras(extras) no.SetAlert(title) + no.SetAndroidIntent(j.config) + var msg body.Message msg.SetMsgContent(content) var opt body.Options - opt.SetApnsProduction(config.Config.IOSPush.Production) + opt.SetApnsProduction(j.config.IOSPush.Production) var pushObj body.PushObj pushObj.SetPlatform(&pf) pushObj.SetAudience(&au) @@ -76,9 +80,9 @@ func (j *JPush) Push(ctx context.Context, userIDs []string, title, content strin func (j *JPush) request(ctx context.Context, po body.PushObj, resp any, timeout int) error { return http2.PostReturn( ctx, - config.Config.Push.Jpns.PushUrl, + j.config.Push.Jpns.PushUrl, map[string]string{ - "Authorization": j.getAuthorization(config.Config.Push.Jpns.AppKey, config.Config.Push.Jpns.MasterSecret), + "Authorization": j.getAuthorization(j.config.Push.Jpns.AppKey, j.config.Push.Jpns.MasterSecret), }, po, resp, diff --git a/internal/push/push_handler.go b/internal/push/push_handler.go index 7b7a1150a..0e68e76b3 100644 --- a/internal/push/push_handler.go +++ b/internal/push/push_handler.go @@ -33,15 +33,29 @@ type ConsumerHandler struct { pusher *Pusher } -func NewConsumerHandler(pusher *Pusher) (*ConsumerHandler, error) { +func NewConsumerHandler(config *config.GlobalConfig, pusher *Pusher) (*ConsumerHandler, error) { var consumerHandler ConsumerHandler consumerHandler.pusher = pusher var err error + var tlsConfig *kfk.TLSConfig + if config.Kafka.TLS != nil { + tlsConfig = &kfk.TLSConfig{ + CACrt: config.Kafka.TLS.CACrt, + ClientCrt: config.Kafka.TLS.ClientCrt, + ClientKey: config.Kafka.TLS.ClientKey, + ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd, + InsecureSkipVerify: false, + } + } consumerHandler.pushConsumerGroup, err = kfk.NewMConsumerGroup(&kfk.MConsumerGroupConfig{ KafkaVersion: sarama.V2_0_0_0, - OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, - }, []string{config.Config.Kafka.MsgToPush.Topic}, config.Config.Kafka.Addr, - config.Config.Kafka.ConsumerGroupID.MsgToPush) + OffsetsInitial: sarama.OffsetNewest, + IsReturnErr: false, + UserName: config.Kafka.Username, + Password: config.Kafka.Password, + }, []string{config.Kafka.MsgToPush.Topic}, config.Kafka.Addr, + config.Kafka.ConsumerGroupID.MsgToPush, + tlsConfig) if err != nil { return nil, err } diff --git a/internal/push/push_rpc_server.go b/internal/push/push_rpc_server.go index c5516e7cf..ed2a0b1ef 100644 --- a/internal/push/push_rpc_server.go +++ b/internal/push/push_rpc_server.go @@ -16,6 +16,7 @@ package push import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/protocol/constant" pbpush "github.com/OpenIMSDK/protocol/push" @@ -31,20 +32,22 @@ import ( type pushServer struct { pusher *Pusher + config *config.GlobalConfig } -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := cache.NewRedis() +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis(config) if err != nil { return err } - cacheModel := cache.NewMsgCacheModel(rdb) - offlinePusher := NewOfflinePusher(cacheModel) + cacheModel := cache.NewMsgCacheModel(rdb, config) + offlinePusher := NewOfflinePusher(config, cacheModel) database := controller.NewPushDatabase(cacheModel) - groupRpcClient := rpcclient.NewGroupRpcClient(client) - conversationRpcClient := rpcclient.NewConversationRpcClient(client) - msgRpcClient := rpcclient.NewMessageRpcClient(client) + groupRpcClient := rpcclient.NewGroupRpcClient(client, config) + conversationRpcClient := rpcclient.NewConversationRpcClient(client, config) + msgRpcClient := rpcclient.NewMessageRpcClient(client, config) pusher := NewPusher( + config, client, offlinePusher, database, @@ -57,9 +60,10 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e pbpush.RegisterPushMsgServiceServer(server, &pushServer{ pusher: pusher, + config: config, }) - consumer, err := NewConsumer(pusher) + consumer, err := NewConsumer(config, pusher) if err != nil { return err } diff --git a/internal/push/push_to_client.go b/internal/push/push_to_client.go index 84efdba0c..49bce70ab 100644 --- a/internal/push/push_to_client.go +++ b/internal/push/push_to_client.go @@ -45,6 +45,7 @@ import ( ) type Pusher struct { + config *config.GlobalConfig database controller.PushDatabase discov discoveryregistry.SvcDiscoveryRegistry offlinePusher offlinepush.OfflinePusher @@ -57,11 +58,12 @@ type Pusher struct { var errNoOfflinePusher = errors.New("no offlinePusher is configured") -func NewPusher(discov discoveryregistry.SvcDiscoveryRegistry, offlinePusher offlinepush.OfflinePusher, database controller.PushDatabase, +func NewPusher(config *config.GlobalConfig, discov discoveryregistry.SvcDiscoveryRegistry, offlinePusher offlinepush.OfflinePusher, database controller.PushDatabase, groupLocalCache *localcache.GroupLocalCache, conversationLocalCache *localcache.ConversationLocalCache, conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient, msgRpcClient *rpcclient.MessageRpcClient, ) *Pusher { return &Pusher{ + config: config, discov: discov, database: database, offlinePusher: offlinePusher, @@ -73,15 +75,15 @@ func NewPusher(discov discoveryregistry.SvcDiscoveryRegistry, offlinePusher offl } } -func NewOfflinePusher(cache cache.MsgModel) offlinepush.OfflinePusher { +func NewOfflinePusher(config *config.GlobalConfig, cache cache.MsgModel) offlinepush.OfflinePusher { var offlinePusher offlinepush.OfflinePusher - switch config.Config.Push.Enable { + switch config.Push.Enable { case "getui": - offlinePusher = getui.NewClient(cache) + offlinePusher = getui.NewClient(config, cache) case "fcm": - offlinePusher = fcm.NewClient(cache) + offlinePusher = fcm.NewClient(config, cache) case "jpush": - offlinePusher = jpush.NewClient() + offlinePusher = jpush.NewClient(config) default: offlinePusher = dummy.NewClient() } @@ -99,7 +101,7 @@ func (p *Pusher) DeleteMemberAndSetConversationSeq(ctx context.Context, groupID func (p *Pusher) Push2User(ctx context.Context, userIDs []string, msg *sdkws.MsgData) error { log.ZDebug(ctx, "Get msg from msg_transfer And push msg", "userIDs", userIDs, "msg", msg.String()) - if err := callbackOnlinePush(ctx, userIDs, msg); err != nil { + if err := callbackOnlinePush(ctx, p.config, userIDs, msg); err != nil { return err } // push @@ -127,7 +129,7 @@ func (p *Pusher) Push2User(ctx context.Context, userIDs []string, msg *sdkws.Msg }) if len(offlinePushUserIDList) > 0 { - if err = callbackOfflinePush(ctx, offlinePushUserIDList, msg, &[]string{}); err != nil { + if err = callbackOfflinePush(ctx, p.config, offlinePushUserIDList, msg, &[]string{}); err != nil { return err } err = p.offlinePushMsg(ctx, msg.SendID, msg, offlinePushUserIDList) @@ -160,7 +162,7 @@ func (p *Pusher) k8sOfflinePush2SuperGroup(ctx context.Context, groupID string, } if len(needOfflinePushUserIDs) > 0 { var offlinePushUserIDs []string - err := callbackOfflinePush(ctx, needOfflinePushUserIDs, msg, &offlinePushUserIDs) + err := callbackOfflinePush(ctx, p.config, needOfflinePushUserIDs, msg, &offlinePushUserIDs) if err != nil { return err } @@ -191,7 +193,7 @@ func (p *Pusher) k8sOfflinePush2SuperGroup(ctx context.Context, groupID string, func (p *Pusher) Push2SuperGroup(ctx context.Context, groupID string, msg *sdkws.MsgData) (err error) { log.ZDebug(ctx, "Get super group msg from msg_transfer and push msg", "msg", msg.String(), "groupID", groupID) var pushToUserIDs []string - if err = callbackBeforeSuperGroupOnlinePush(ctx, groupID, msg, &pushToUserIDs); err != nil { + if err = callbackBeforeSuperGroupOnlinePush(ctx, p.config, groupID, msg, &pushToUserIDs); err != nil { return err } @@ -233,11 +235,11 @@ func (p *Pusher) Push2SuperGroup(ctx context.Context, groupID string, msg *sdkws return err } log.ZInfo(ctx, "GroupDismissedNotificationInfo****", "groupID", groupID, "num", len(pushToUserIDs), "list", pushToUserIDs) - if len(config.Config.Manager.UserID) > 0 { - ctx = mcontext.WithOpUserIDContext(ctx, config.Config.Manager.UserID[0]) + if len(p.config.Manager.UserID) > 0 { + ctx = mcontext.WithOpUserIDContext(ctx, p.config.Manager.UserID[0]) } - if len(config.Config.Manager.UserID) == 0 && len(config.Config.IMAdmin.UserID) > 0 { - ctx = mcontext.WithOpUserIDContext(ctx, config.Config.IMAdmin.UserID[0]) + if len(p.config.Manager.UserID) == 0 && len(p.config.IMAdmin.UserID) > 0 { + ctx = mcontext.WithOpUserIDContext(ctx, p.config.IMAdmin.UserID[0]) } defer func(groupID string) { if err = p.groupRpcClient.DismissGroup(ctx, groupID); err != nil { @@ -255,10 +257,10 @@ func (p *Pusher) Push2SuperGroup(ctx context.Context, groupID string, msg *sdkws log.ZDebug(ctx, "get conn and online push success", "result", wsResults, "msg", msg) isOfflinePush := utils.GetSwitchFromOptions(msg.Options, constant.IsOfflinePush) - if isOfflinePush && config.Config.Envs.Discovery == "k8s" { + if isOfflinePush && p.config.Envs.Discovery == "k8s" { return p.k8sOfflinePush2SuperGroup(ctx, groupID, msg, wsResults) } - if isOfflinePush && config.Config.Envs.Discovery == "zookeeper" { + if isOfflinePush && p.config.Envs.Discovery == "zookeeper" { var ( onlineSuccessUserIDs = []string{msg.SendID} webAndPcBackgroundUserIDs []string @@ -296,7 +298,7 @@ func (p *Pusher) Push2SuperGroup(ctx context.Context, groupID string, msg *sdkws // Use offline push messaging if len(needOfflinePushUserIDs) > 0 { var offlinePushUserIDs []string - err = callbackOfflinePush(ctx, needOfflinePushUserIDs, msg, &offlinePushUserIDs) + err = callbackOfflinePush(ctx, p.config, needOfflinePushUserIDs, msg, &offlinePushUserIDs) if err != nil { return err } @@ -355,7 +357,7 @@ func (p *Pusher) k8sOnlinePush(ctx context.Context, msg *sdkws.MsgData, pushToUs var ( mu sync.Mutex wg = errgroup.Group{} - maxWorkers = config.Config.Push.MaxConcurrentWorkers + maxWorkers = p.config.Push.MaxConcurrentWorkers ) if maxWorkers < 3 { maxWorkers = 3 @@ -384,10 +386,10 @@ func (p *Pusher) k8sOnlinePush(ctx context.Context, msg *sdkws.MsgData, pushToUs return wsResults, nil } func (p *Pusher) GetConnsAndOnlinePush(ctx context.Context, msg *sdkws.MsgData, pushToUserIDs []string) (wsResults []*msggateway.SingleMsgToUserResults, err error) { - if config.Config.Envs.Discovery == "k8s" { + if p.config.Envs.Discovery == "k8s" { return p.k8sOnlinePush(ctx, msg, pushToUserIDs) } - conns, err := p.discov.GetConns(ctx, config.Config.RpcRegisterName.OpenImMessageGatewayName) + conns, err := p.discov.GetConns(ctx, p.config.RpcRegisterName.OpenImMessageGatewayName) log.ZDebug(ctx, "get gateway conn", "conn length", len(conns)) if err != nil { return nil, err @@ -397,7 +399,7 @@ func (p *Pusher) GetConnsAndOnlinePush(ctx context.Context, msg *sdkws.MsgData, mu sync.Mutex wg = errgroup.Group{} input = &msggateway.OnlineBatchPushOneMsgReq{MsgData: msg, PushToUserIDs: pushToUserIDs} - maxWorkers = config.Config.Push.MaxConcurrentWorkers + maxWorkers = p.config.Push.MaxConcurrentWorkers ) if maxWorkers < 3 { diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 5f53911a4..eb1e2f68a 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -38,29 +38,32 @@ type authServer struct { authDatabase controller.AuthDatabase userRpcClient *rpcclient.UserRpcClient RegisterCenter discoveryregistry.SvcDiscoveryRegistry + config *config.GlobalConfig } -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := cache.NewRedis() +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis(config) if err != nil { return err } - userRpcClient := rpcclient.NewUserRpcClient(client) + userRpcClient := rpcclient.NewUserRpcClient(client, config) pbauth.RegisterAuthServer(server, &authServer{ userRpcClient: &userRpcClient, RegisterCenter: client, authDatabase: controller.NewAuthDatabase( - cache.NewMsgCacheModel(rdb), - config.Config.Secret, - config.Config.TokenPolicy.Expire, + cache.NewMsgCacheModel(rdb, config), + config.Secret, + config.TokenPolicy.Expire, + config, ), + config: config, }) return nil } func (s *authServer) UserToken(ctx context.Context, req *pbauth.UserTokenReq) (*pbauth.UserTokenResp, error) { resp := pbauth.UserTokenResp{} - if req.Secret != config.Config.Secret { + if req.Secret != s.config.Secret { return nil, errs.ErrNoPermission.Wrap("secret invalid") } if _, err := s.userRpcClient.GetUserInfo(ctx, req.UserID); err != nil { @@ -72,17 +75,17 @@ func (s *authServer) UserToken(ctx context.Context, req *pbauth.UserTokenReq) (* } prommetrics.UserLoginCounter.Inc() resp.Token = token - resp.ExpireTimeSeconds = config.Config.TokenPolicy.Expire * 24 * 60 * 60 + resp.ExpireTimeSeconds = s.config.TokenPolicy.Expire * 24 * 60 * 60 return &resp, nil } func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) { - if err := authverify.CheckAdmin(ctx); err != nil { + if err := authverify.CheckAdmin(ctx, s.config); err != nil { return nil, err } resp := pbauth.GetUserTokenResp{} - if authverify.IsManagerUserID(req.UserID) { + if authverify.IsManagerUserID(req.UserID, s.config) { return nil, errs.ErrNoPermission.Wrap("don't get Admin token") } @@ -94,12 +97,12 @@ func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenR return nil, err } resp.Token = token - resp.ExpireTimeSeconds = config.Config.TokenPolicy.Expire * 24 * 60 * 60 + resp.ExpireTimeSeconds = s.config.TokenPolicy.Expire * 24 * 60 * 60 return &resp, nil } func (s *authServer) parseToken(ctx context.Context, tokensString string) (claims *tokenverify.Claims, err error) { - claims, err = tokenverify.GetClaimFromToken(tokensString, authverify.Secret()) + claims, err = tokenverify.GetClaimFromToken(tokensString, authverify.Secret(s.config.Secret)) if err != nil { return nil, errs.Wrap(err) } @@ -139,7 +142,7 @@ func (s *authServer) ParseToken( } func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) { - if err := authverify.CheckAdmin(ctx); err != nil { + if err := authverify.CheckAdmin(ctx, s.config); err != nil { return nil, err } if err := s.forceKickOff(ctx, req.UserID, req.PlatformID, mcontext.GetOperationID(ctx)); err != nil { @@ -149,7 +152,7 @@ func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq } func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID int32, operationID string) error { - conns, err := s.RegisterCenter.GetConns(ctx, config.Config.RpcRegisterName.OpenImMessageGatewayName) + conns, err := s.RegisterCenter.GetConns(ctx, s.config.RpcRegisterName.OpenImMessageGatewayName) if err != nil { return err } diff --git a/internal/rpc/conversation/conversaion.go b/internal/rpc/conversation/conversaion.go index f3dc7ef17..cc08de298 100644 --- a/internal/rpc/conversation/conversaion.go +++ b/internal/rpc/conversation/conversaion.go @@ -17,6 +17,7 @@ package conversation import ( "context" "errors" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "sort" "github.com/OpenIMSDK/protocol/constant" @@ -44,6 +45,7 @@ type conversationServer struct { groupRpcClient *rpcclient.GroupRpcClient conversationDatabase controller.ConversationDatabase conversationNotificationSender *notification.ConversationNotificationSender + config *config.GlobalConfig } func (c *conversationServer) GetConversationNotReceiveMessageUserIDs( @@ -54,28 +56,29 @@ func (c *conversationServer) GetConversationNotReceiveMessageUserIDs( panic("implement me") } -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := cache.NewRedis() +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis(config) if err != nil { return err } - mongo, err := unrelation.NewMongo() + mongo, err := unrelation.NewMongo(config) if err != nil { return err } - conversationDB, err := mgo.NewConversationMongo(mongo.GetDatabase()) + conversationDB, err := mgo.NewConversationMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - groupRpcClient := rpcclient.NewGroupRpcClient(client) - msgRpcClient := rpcclient.NewMessageRpcClient(client) - userRpcClient := rpcclient.NewUserRpcClient(client) + groupRpcClient := rpcclient.NewGroupRpcClient(client, config) + msgRpcClient := rpcclient.NewMessageRpcClient(client, config) + userRpcClient := rpcclient.NewUserRpcClient(client, config) pbconversation.RegisterConversationServer(server, &conversationServer{ msgRpcClient: &msgRpcClient, user: &userRpcClient, - conversationNotificationSender: notification.NewConversationNotificationSender(&msgRpcClient), + conversationNotificationSender: notification.NewConversationNotificationSender(config, &msgRpcClient), groupRpcClient: &groupRpcClient, conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx.NewMongo(mongo.GetClient())), + config: config, }) return nil } diff --git a/internal/rpc/friend/black.go b/internal/rpc/friend/black.go index 57edd26ba..64c63eb73 100644 --- a/internal/rpc/friend/black.go +++ b/internal/rpc/friend/black.go @@ -65,7 +65,7 @@ func (s *friendServer) RemoveBlack(ctx context.Context, req *pbfriend.RemoveBlac } func (s *friendServer) AddBlack(ctx context.Context, req *pbfriend.AddBlackReq) (*pbfriend.AddBlackResp, error) { - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config); err != nil { return nil, err } _, err := s.userRpcClient.GetUsersInfo(ctx, []string{req.OwnerUserID, req.BlackUserID}) diff --git a/internal/rpc/friend/callback.go b/internal/rpc/friend/callback.go index 2b6e31899..78d4fc926 100644 --- a/internal/rpc/friend/callback.go +++ b/internal/rpc/friend/callback.go @@ -24,8 +24,8 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/http" ) -func CallbackBeforeAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriendReq) error { - if !config.Config.Callback.CallbackBeforeAddFriend.Enable { +func CallbackBeforeAddFriend(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.ApplyToAddFriendReq) error { + if !globalConfig.Callback.CallbackBeforeAddFriend.Enable { return nil } cbReq := &cbapi.CallbackBeforeAddFriendReq{ @@ -36,14 +36,14 @@ func CallbackBeforeAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriend Ex: req.Ex, } resp := &cbapi.CallbackBeforeAddFriendResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeAddFriend); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeAddFriend); err != nil { return err } return nil } -func CallbackBeforeSetFriendRemark(ctx context.Context, req *pbfriend.SetFriendRemarkReq) error { - if !config.Config.Callback.CallbackBeforeSetFriendRemark.Enable { +func CallbackBeforeSetFriendRemark(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.SetFriendRemarkReq) error { + if !globalConfig.Callback.CallbackBeforeSetFriendRemark.Enable { return nil } cbReq := &cbapi.CallbackBeforeSetFriendRemarkReq{ @@ -53,15 +53,15 @@ func CallbackBeforeSetFriendRemark(ctx context.Context, req *pbfriend.SetFriendR Remark: req.Remark, } resp := &cbapi.CallbackBeforeSetFriendRemarkResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeAddFriend); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeAddFriend); err != nil { return err } utils.NotNilReplace(&req.Remark, &resp.Remark) return nil } -func CallbackAfterSetFriendRemark(ctx context.Context, req *pbfriend.SetFriendRemarkReq) error { - if !config.Config.Callback.CallbackAfterSetFriendRemark.Enable { +func CallbackAfterSetFriendRemark(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.SetFriendRemarkReq) error { + if !globalConfig.Callback.CallbackAfterSetFriendRemark.Enable { return nil } cbReq := &cbapi.CallbackAfterSetFriendRemarkReq{ @@ -71,13 +71,13 @@ func CallbackAfterSetFriendRemark(ctx context.Context, req *pbfriend.SetFriendRe Remark: req.Remark, } resp := &cbapi.CallbackAfterSetFriendRemarkResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeAddFriend); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeAddFriend); err != nil { return err } return nil } -func CallbackBeforeAddBlack(ctx context.Context, req *pbfriend.AddBlackReq) error { - if !config.Config.Callback.CallbackBeforeAddBlack.Enable { +func CallbackBeforeAddBlack(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.AddBlackReq) error { + if !globalConfig.Callback.CallbackBeforeAddBlack.Enable { return nil } cbReq := &cbapi.CallbackBeforeAddBlackReq{ @@ -86,13 +86,13 @@ func CallbackBeforeAddBlack(ctx context.Context, req *pbfriend.AddBlackReq) erro BlackUserID: req.BlackUserID, } resp := &cbapi.CallbackBeforeAddBlackResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeAddBlack); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeAddBlack); err != nil { return err } return nil } -func CallbackAfterAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriendReq) error { - if !config.Config.Callback.CallbackAfterAddFriend.Enable { +func CallbackAfterAddFriend(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.ApplyToAddFriendReq) error { + if !globalConfig.Callback.CallbackAfterAddFriend.Enable { return nil } cbReq := &cbapi.CallbackAfterAddFriendReq{ @@ -102,14 +102,14 @@ func CallbackAfterAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriendR ReqMsg: req.ReqMsg, } resp := &cbapi.CallbackAfterAddFriendResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackAfterAddFriend); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackAfterAddFriend); err != nil { return err } return nil } -func CallbackBeforeAddFriendAgree(ctx context.Context, req *pbfriend.RespondFriendApplyReq) error { - if !config.Config.Callback.CallbackBeforeAddFriendAgree.Enable { +func CallbackBeforeAddFriendAgree(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.RespondFriendApplyReq) error { + if !globalConfig.Callback.CallbackBeforeAddFriendAgree.Enable { return nil } cbReq := &cbapi.CallbackBeforeAddFriendAgreeReq{ @@ -120,13 +120,13 @@ func CallbackBeforeAddFriendAgree(ctx context.Context, req *pbfriend.RespondFrie HandleResult: req.HandleResult, } resp := &cbapi.CallbackBeforeAddFriendAgreeResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeAddFriendAgree); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeAddFriendAgree); err != nil { return err } return nil } -func CallbackAfterDeleteFriend(ctx context.Context, req *pbfriend.DeleteFriendReq) error { - if !config.Config.Callback.CallbackAfterDeleteFriend.Enable { +func CallbackAfterDeleteFriend(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.DeleteFriendReq) error { + if !globalConfig.Callback.CallbackAfterDeleteFriend.Enable { return nil } cbReq := &cbapi.CallbackAfterDeleteFriendReq{ @@ -135,13 +135,13 @@ func CallbackAfterDeleteFriend(ctx context.Context, req *pbfriend.DeleteFriendRe FriendUserID: req.FriendUserID, } resp := &cbapi.CallbackAfterDeleteFriendResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackAfterDeleteFriend); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackAfterDeleteFriend); err != nil { return err } return nil } -func CallbackBeforeImportFriends(ctx context.Context, req *pbfriend.ImportFriendReq) error { - if !config.Config.Callback.CallbackBeforeImportFriends.Enable { +func CallbackBeforeImportFriends(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.ImportFriendReq) error { + if !globalConfig.Callback.CallbackBeforeImportFriends.Enable { return nil } cbReq := &cbapi.CallbackBeforeImportFriendsReq{ @@ -150,7 +150,7 @@ func CallbackBeforeImportFriends(ctx context.Context, req *pbfriend.ImportFriend FriendUserIDs: req.FriendUserIDs, } resp := &cbapi.CallbackBeforeImportFriendsResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeImportFriends); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeImportFriends); err != nil { return err } if len(resp.FriendUserIDs) != 0 { @@ -158,8 +158,8 @@ func CallbackBeforeImportFriends(ctx context.Context, req *pbfriend.ImportFriend } return nil } -func CallbackAfterImportFriends(ctx context.Context, req *pbfriend.ImportFriendReq) error { - if !config.Config.Callback.CallbackAfterImportFriends.Enable { +func CallbackAfterImportFriends(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.ImportFriendReq) error { + if !globalConfig.Callback.CallbackAfterImportFriends.Enable { return nil } cbReq := &cbapi.CallbackAfterImportFriendsReq{ @@ -168,14 +168,14 @@ func CallbackAfterImportFriends(ctx context.Context, req *pbfriend.ImportFriendR FriendUserIDs: req.FriendUserIDs, } resp := &cbapi.CallbackAfterImportFriendsResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackAfterImportFriends); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackAfterImportFriends); err != nil { return err } return nil } -func CallbackAfterRemoveBlack(ctx context.Context, req *pbfriend.RemoveBlackReq) error { - if !config.Config.Callback.CallbackAfterRemoveBlack.Enable { +func CallbackAfterRemoveBlack(ctx context.Context, globalConfig *config.GlobalConfig, req *pbfriend.RemoveBlackReq) error { + if !globalConfig.Callback.CallbackAfterRemoveBlack.Enable { return nil } cbReq := &cbapi.CallbackAfterRemoveBlackReq{ @@ -184,7 +184,7 @@ func CallbackAfterRemoveBlack(ctx context.Context, req *pbfriend.RemoveBlackReq) BlackUserID: req.BlackUserID, } resp := &cbapi.CallbackAfterRemoveBlackResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackAfterRemoveBlack); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackAfterRemoveBlack); err != nil { return err } return nil diff --git a/internal/rpc/friend/friend.go b/internal/rpc/friend/friend.go index 4168731f7..ffdeee98f 100644 --- a/internal/rpc/friend/friend.go +++ b/internal/rpc/friend/friend.go @@ -16,25 +16,33 @@ package friend import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + + "github.com/OpenIMSDK/tools/tx" + + "github.com/OpenIMSDK/protocol/sdkws" + + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + + "github.com/OpenIMSDK/tools/log" + + "github.com/openimsdk/open-im-server/v3/pkg/common/convert" + "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" + + "google.golang.org/grpc" "github.com/OpenIMSDK/protocol/constant" pbfriend "github.com/OpenIMSDK/protocol/friend" - "github.com/OpenIMSDK/protocol/sdkws" registry "github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/errs" - "github.com/OpenIMSDK/tools/log" - "github.com/OpenIMSDK/tools/tx" "github.com/OpenIMSDK/tools/utils" - "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/common/convert" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" "github.com/openimsdk/open-im-server/v3/pkg/common/db/mgo" tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" - "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification" - "google.golang.org/grpc" ) type friendServer struct { @@ -44,42 +52,44 @@ type friendServer struct { notificationSender *notification.FriendNotificationSender conversationRpcClient rpcclient.ConversationRpcClient RegisterCenter registry.SvcDiscoveryRegistry + config *config.GlobalConfig } -func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client registry.SvcDiscoveryRegistry, server *grpc.Server) error { // Initialize MongoDB - mongo, err := unrelation.NewMongo() + mongo, err := unrelation.NewMongo(config) if err != nil { return err } // Initialize Redis - rdb, err := cache.NewRedis() + rdb, err := cache.NewRedis(config) if err != nil { return err } - friendMongoDB, err := mgo.NewFriendMongo(mongo.GetDatabase()) + friendMongoDB, err := mgo.NewFriendMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - friendRequestMongoDB, err := mgo.NewFriendRequestMongo(mongo.GetDatabase()) + friendRequestMongoDB, err := mgo.NewFriendRequestMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - blackMongoDB, err := mgo.NewBlackMongo(mongo.GetDatabase()) + blackMongoDB, err := mgo.NewBlackMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } // Initialize RPC clients - userRpcClient := rpcclient.NewUserRpcClient(client) - msgRpcClient := rpcclient.NewMessageRpcClient(client) + userRpcClient := rpcclient.NewUserRpcClient(client, config) + msgRpcClient := rpcclient.NewMessageRpcClient(client, config) // Initialize notification sender notificationSender := notification.NewFriendNotificationSender( + config, &msgRpcClient, notification.WithRpcFunc(userRpcClient.GetUsersInfo), ) @@ -98,7 +108,8 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { userRpcClient: &userRpcClient, notificationSender: notificationSender, RegisterCenter: client, - conversationRpcClient: rpcclient.NewConversationRpcClient(client), + conversationRpcClient: rpcclient.NewConversationRpcClient(client, config), + config: config, }) return nil @@ -106,15 +117,14 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { // ok. func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriendReq) (resp *pbfriend.ApplyToAddFriendResp, err error) { - defer log.ZInfo(ctx, utils.GetFuncName()+" Return") resp = &pbfriend.ApplyToAddFriendResp{} - if err := authverify.CheckAccessV3(ctx, req.FromUserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.FromUserID,s.config); err != nil { return nil, err } if req.ToUserID == req.FromUserID { - return nil, errs.ErrCanNotAddYourself.Wrap() + return nil, errs.ErrCanNotAddYourself.Wrap("req.ToUserID", req.ToUserID) } - if err = CallbackBeforeAddFriend(ctx, req); err != nil && err != errs.ErrCallbackContinue { + if err = CallbackBeforeAddFriend(ctx, s.config, req); err != nil && err != errs.ErrCallbackContinue { return nil, err } if _, err := s.userRpcClient.GetUsersInfoMap(ctx, []string{req.ToUserID, req.FromUserID}); err != nil { @@ -131,7 +141,7 @@ func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *pbfriend.Apply return nil, err } s.notificationSender.FriendApplicationAddNotification(ctx, req) - if err = CallbackAfterAddFriend(ctx, req); err != nil && err != errs.ErrCallbackContinue { + if err = CallbackAfterAddFriend(ctx,s.config ,req); err != nil && err != errs.ErrCallbackContinue { return nil, err } return resp, nil @@ -140,7 +150,7 @@ func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *pbfriend.Apply // ok. func (s *friendServer) ImportFriends(ctx context.Context, req *pbfriend.ImportFriendReq) (resp *pbfriend.ImportFriendResp, err error) { defer log.ZInfo(ctx, utils.GetFuncName()+" Return") - if err := authverify.CheckAdmin(ctx); err != nil { + if err := authverify.CheckAdmin(ctx, s.config); err != nil { return nil, err } if _, err := s.userRpcClient.GetUsersInfo(ctx, append([]string{req.OwnerUserID}, req.FriendUserIDs...)); err != nil { @@ -152,7 +162,7 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *pbfriend.ImportFr if utils.Duplicate(req.FriendUserIDs) { return nil, errs.ErrArgs.Wrap("friend userID repeated") } - if err := CallbackBeforeImportFriends(ctx, req); err != nil { + if err := CallbackBeforeImportFriends(ctx, s.config, req); err != nil { return nil, err } @@ -166,7 +176,7 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *pbfriend.ImportFr HandleResult: constant.FriendResponseAgree, }) } - if err := CallbackAfterImportFriends(ctx, req); err != nil { + if err := CallbackAfterImportFriends(ctx, s.config, req); err != nil { return nil, err } return &pbfriend.ImportFriendResp{}, nil @@ -176,7 +186,7 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *pbfriend.ImportFr func (s *friendServer) RespondFriendApply(ctx context.Context, req *pbfriend.RespondFriendApplyReq) (resp *pbfriend.RespondFriendApplyResp, err error) { defer log.ZInfo(ctx, utils.GetFuncName()+" Return") resp = &pbfriend.RespondFriendApplyResp{} - if err := authverify.CheckAccessV3(ctx, req.ToUserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.ToUserID, s.config); err != nil { return nil, err } @@ -187,7 +197,7 @@ func (s *friendServer) RespondFriendApply(ctx context.Context, req *pbfriend.Res HandleResult: req.HandleResult, } if req.HandleResult == constant.FriendResponseAgree { - if err := CallbackBeforeAddFriendAgree(ctx, req); err != nil && err != errs.ErrCallbackContinue { + if err := CallbackBeforeAddFriendAgree(ctx, s.config, req); err != nil && err != errs.ErrCallbackContinue { return nil, err } err := s.friendDatabase.AgreeFriendRequest(ctx, &friendRequest) @@ -223,7 +233,7 @@ func (s *friendServer) DeleteFriend(ctx context.Context, req *pbfriend.DeleteFri return nil, err } s.notificationSender.FriendDeletedNotification(ctx, req) - if err := CallbackAfterDeleteFriend(ctx, req); err != nil { + if err := CallbackAfterDeleteFriend(ctx, s.config, req); err != nil { return nil, err } return resp, nil @@ -233,7 +243,7 @@ func (s *friendServer) DeleteFriend(ctx context.Context, req *pbfriend.DeleteFri func (s *friendServer) SetFriendRemark(ctx context.Context, req *pbfriend.SetFriendRemarkReq) (resp *pbfriend.SetFriendRemarkResp, err error) { defer log.ZInfo(ctx, utils.GetFuncName()+" Return") - if err = CallbackBeforeSetFriendRemark(ctx, req); err != nil && err != errs.ErrCallbackContinue { + if err = CallbackBeforeSetFriendRemark(ctx, s.config, req); err != nil && err != errs.ErrCallbackContinue { return nil, err } resp = &pbfriend.SetFriendRemarkResp{} @@ -247,7 +257,7 @@ func (s *friendServer) SetFriendRemark(ctx context.Context, req *pbfriend.SetFri if err := s.friendDatabase.UpdateRemark(ctx, req.OwnerUserID, req.FriendUserID, req.Remark); err != nil { return nil, err } - if err := CallbackAfterSetFriendRemark(ctx, req); err != nil && err != errs.ErrCallbackContinue { + if err := CallbackAfterSetFriendRemark(ctx, s.config, req); err != nil && err != errs.ErrCallbackContinue { return nil, err } s.notificationSender.FriendRemarkSetNotification(ctx, req.OwnerUserID, req.FriendUserID) diff --git a/internal/rpc/group/callback.go b/internal/rpc/group/callback.go index f09c23ec6..e82177dde 100644 --- a/internal/rpc/group/callback.go +++ b/internal/rpc/group/callback.go @@ -32,8 +32,8 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/http" ) -func CallbackBeforeCreateGroup(ctx context.Context, req *group.CreateGroupReq) (err error) { - if !config.Config.Callback.CallbackBeforeCreateGroup.Enable { +func CallbackBeforeCreateGroup(ctx context.Context, globalConfig *config.GlobalConfig, req *group.CreateGroupReq) (err error) { + if !globalConfig.Callback.CallbackBeforeCreateGroup.Enable { return nil } cbReq := &callbackstruct.CallbackBeforeCreateGroupReq{ @@ -58,7 +58,7 @@ func CallbackBeforeCreateGroup(ctx context.Context, req *group.CreateGroupReq) ( }) } resp := &callbackstruct.CallbackBeforeCreateGroupResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeCreateGroup); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeCreateGroup); err != nil { return err } utils.NotNilReplace(&req.GroupInfo.GroupID, resp.GroupID) @@ -76,8 +76,8 @@ func CallbackBeforeCreateGroup(ctx context.Context, req *group.CreateGroupReq) ( return nil } -func CallbackAfterCreateGroup(ctx context.Context, req *group.CreateGroupReq) (err error) { - if !config.Config.Callback.CallbackAfterCreateGroup.Enable { +func CallbackAfterCreateGroup(ctx context.Context, globalConfig *config.GlobalConfig, req *group.CreateGroupReq) (err error) { + if !globalConfig.Callback.CallbackAfterCreateGroup.Enable { return nil } cbReq := &callbackstruct.CallbackAfterCreateGroupReq{ @@ -101,7 +101,7 @@ func CallbackAfterCreateGroup(ctx context.Context, req *group.CreateGroupReq) (e }) } resp := &callbackstruct.CallbackAfterCreateGroupResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackAfterCreateGroup); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackAfterCreateGroup); err != nil { return err } return nil @@ -109,10 +109,11 @@ func CallbackAfterCreateGroup(ctx context.Context, req *group.CreateGroupReq) (e func CallbackBeforeMemberJoinGroup( ctx context.Context, + globalConfig *config.GlobalConfig, groupMember *relation.GroupMemberModel, groupEx string, ) (err error) { - if !config.Config.Callback.CallbackBeforeMemberJoinGroup.Enable { + if !globalConfig.Callback.CallbackBeforeMemberJoinGroup.Enable { return nil } callbackReq := &callbackstruct.CallbackBeforeMemberJoinGroupReq{ @@ -125,10 +126,10 @@ func CallbackBeforeMemberJoinGroup( resp := &callbackstruct.CallbackBeforeMemberJoinGroupResp{} err = http.CallBackPostReturn( ctx, - config.Config.Callback.CallbackUrl, + globalConfig.Callback.CallbackUrl, callbackReq, resp, - config.Config.Callback.CallbackBeforeMemberJoinGroup, + globalConfig.Callback.CallbackBeforeMemberJoinGroup, ) if err != nil { return err @@ -143,8 +144,8 @@ func CallbackBeforeMemberJoinGroup( return nil } -func CallbackBeforeSetGroupMemberInfo(ctx context.Context, req *group.SetGroupMemberInfo) (err error) { - if !config.Config.Callback.CallbackBeforeSetGroupMemberInfo.Enable { +func CallbackBeforeSetGroupMemberInfo(ctx context.Context, globalConfig *config.GlobalConfig, req *group.SetGroupMemberInfo) (err error) { + if !globalConfig.Callback.CallbackBeforeSetGroupMemberInfo.Enable { return nil } callbackReq := callbackstruct.CallbackBeforeSetGroupMemberInfoReq{ @@ -167,10 +168,10 @@ func CallbackBeforeSetGroupMemberInfo(ctx context.Context, req *group.SetGroupMe resp := &callbackstruct.CallbackBeforeSetGroupMemberInfoResp{} err = http.CallBackPostReturn( ctx, - config.Config.Callback.CallbackUrl, + globalConfig.Callback.CallbackUrl, callbackReq, resp, - config.Config.Callback.CallbackBeforeSetGroupMemberInfo, + globalConfig.Callback.CallbackBeforeSetGroupMemberInfo, ) if err != nil { return err @@ -189,8 +190,8 @@ func CallbackBeforeSetGroupMemberInfo(ctx context.Context, req *group.SetGroupMe } return nil } -func CallbackAfterSetGroupMemberInfo(ctx context.Context, req *group.SetGroupMemberInfo) (err error) { - if !config.Config.Callback.CallbackBeforeSetGroupMemberInfo.Enable { +func CallbackAfterSetGroupMemberInfo(ctx context.Context, globalConfig *config.GlobalConfig, req *group.SetGroupMemberInfo) (err error) { + if !globalConfig.Callback.CallbackBeforeSetGroupMemberInfo.Enable { return nil } callbackReq := callbackstruct.CallbackAfterSetGroupMemberInfoReq{ @@ -211,14 +212,14 @@ func CallbackAfterSetGroupMemberInfo(ctx context.Context, req *group.SetGroupMem callbackReq.Ex = &req.Ex.Value } resp := &callbackstruct.CallbackAfterSetGroupMemberInfoResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, callbackReq, resp, config.Config.Callback.CallbackAfterSetGroupMemberInfo); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, callbackReq, resp, globalConfig.Callback.CallbackAfterSetGroupMemberInfo); err != nil { return err } return nil } -func CallbackQuitGroup(ctx context.Context, req *group.QuitGroupReq) (err error) { - if !config.Config.Callback.CallbackQuitGroup.Enable { +func CallbackQuitGroup(ctx context.Context, globalConfig *config.GlobalConfig, req *group.QuitGroupReq) (err error) { + if !globalConfig.Callback.CallbackQuitGroup.Enable { return nil } cbReq := &callbackstruct.CallbackQuitGroupReq{ @@ -227,14 +228,14 @@ func CallbackQuitGroup(ctx context.Context, req *group.QuitGroupReq) (err error) UserID: req.UserID, } resp := &callbackstruct.CallbackQuitGroupResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackQuitGroup); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackQuitGroup); err != nil { return err } return nil } -func CallbackKillGroupMember(ctx context.Context, req *pbgroup.KickGroupMemberReq) (err error) { - if !config.Config.Callback.CallbackKillGroupMember.Enable { +func CallbackKillGroupMember(ctx context.Context, globalConfig *config.GlobalConfig, req *pbgroup.KickGroupMemberReq) (err error) { + if !globalConfig.Callback.CallbackKillGroupMember.Enable { return nil } cbReq := &callbackstruct.CallbackKillGroupMemberReq{ @@ -243,41 +244,41 @@ func CallbackKillGroupMember(ctx context.Context, req *pbgroup.KickGroupMemberRe KickedUserIDs: req.KickedUserIDs, } resp := &callbackstruct.CallbackKillGroupMemberResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackQuitGroup); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackQuitGroup); err != nil { return err } return nil } -func CallbackDismissGroup(ctx context.Context, req *callbackstruct.CallbackDisMissGroupReq) (err error) { - if !config.Config.Callback.CallbackDismissGroup.Enable { +func CallbackDismissGroup(ctx context.Context, globalConfig *config.GlobalConfig, req *callbackstruct.CallbackDisMissGroupReq) (err error) { + if !globalConfig.Callback.CallbackDismissGroup.Enable { return nil } req.CallbackCommand = callbackstruct.CallbackDisMissGroupCommand resp := &callbackstruct.CallbackDisMissGroupResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, req, resp, config.Config.Callback.CallbackQuitGroup); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackQuitGroup); err != nil { return err } return nil } -func CallbackApplyJoinGroupBefore(ctx context.Context, req *callbackstruct.CallbackJoinGroupReq) (err error) { - if !config.Config.Callback.CallbackBeforeJoinGroup.Enable { +func CallbackApplyJoinGroupBefore(ctx context.Context, globalConfig *config.GlobalConfig, req *callbackstruct.CallbackJoinGroupReq) (err error) { + if !globalConfig.Callback.CallbackBeforeJoinGroup.Enable { return nil } req.CallbackCommand = callbackstruct.CallbackBeforeJoinGroupCommand resp := &callbackstruct.CallbackJoinGroupResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, req, resp, config.Config.Callback.CallbackBeforeJoinGroup); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackBeforeJoinGroup); err != nil { return err } return nil } -func CallbackAfterTransferGroupOwner(ctx context.Context, req *pbgroup.TransferGroupOwnerReq) (err error) { - if !config.Config.Callback.CallbackAfterTransferGroupOwner.Enable { +func CallbackAfterTransferGroupOwner(ctx context.Context, globalConfig *config.GlobalConfig, req *pbgroup.TransferGroupOwnerReq) (err error) { + if !globalConfig.Callback.CallbackAfterTransferGroupOwner.Enable { return nil } @@ -289,13 +290,13 @@ func CallbackAfterTransferGroupOwner(ctx context.Context, req *pbgroup.TransferG } resp := &callbackstruct.CallbackTransferGroupOwnerResp{} - if err = http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackAfterTransferGroupOwner); err != nil { + if err = http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackAfterTransferGroupOwner); err != nil { return err } return nil } -func CallbackBeforeInviteUserToGroup(ctx context.Context, req *group.InviteUserToGroupReq) (err error) { - if !config.Config.Callback.CallbackBeforeInviteUserToGroup.Enable { +func CallbackBeforeInviteUserToGroup(ctx context.Context, globalConfig *config.GlobalConfig, req *group.InviteUserToGroupReq) (err error) { + if !globalConfig.Callback.CallbackBeforeInviteUserToGroup.Enable { return nil } @@ -310,10 +311,10 @@ func CallbackBeforeInviteUserToGroup(ctx context.Context, req *group.InviteUserT resp := &callbackstruct.CallbackBeforeInviteUserToGroupResp{} err = http.CallBackPostReturn( ctx, - config.Config.Callback.CallbackUrl, + globalConfig.Callback.CallbackUrl, callbackReq, resp, - config.Config.Callback.CallbackBeforeInviteUserToGroup, + globalConfig.Callback.CallbackBeforeInviteUserToGroup, ) if err != nil { @@ -327,8 +328,8 @@ func CallbackBeforeInviteUserToGroup(ctx context.Context, req *group.InviteUserT return nil } -func CallbackAfterJoinGroup(ctx context.Context, req *group.JoinGroupReq) error { - if !config.Config.Callback.CallbackAfterJoinGroup.Enable { +func CallbackAfterJoinGroup(ctx context.Context, globalConfig *config.GlobalConfig, req *group.JoinGroupReq) error { + if !globalConfig.Callback.CallbackAfterJoinGroup.Enable { return nil } callbackReq := &callbackstruct.CallbackAfterJoinGroupReq{ @@ -340,14 +341,14 @@ func CallbackAfterJoinGroup(ctx context.Context, req *group.JoinGroupReq) error InviterUserID: req.InviterUserID, } resp := &callbackstruct.CallbackAfterJoinGroupResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, callbackReq, resp, config.Config.Callback.CallbackAfterJoinGroup); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, callbackReq, resp, globalConfig.Callback.CallbackAfterJoinGroup); err != nil { return err } return nil } -func CallbackBeforeSetGroupInfo(ctx context.Context, req *group.SetGroupInfoReq) error { - if !config.Config.Callback.CallbackBeforeSetGroupInfo.Enable { +func CallbackBeforeSetGroupInfo(ctx context.Context, globalConfig *config.GlobalConfig, req *group.SetGroupInfoReq) error { + if !globalConfig.Callback.CallbackBeforeSetGroupInfo.Enable { return nil } callbackReq := &callbackstruct.CallbackBeforeSetGroupInfoReq{ @@ -374,7 +375,7 @@ func CallbackBeforeSetGroupInfo(ctx context.Context, req *group.SetGroupInfoReq) } resp := &callbackstruct.CallbackBeforeSetGroupInfoResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, callbackReq, resp, config.Config.Callback.CallbackBeforeSetGroupInfo); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, callbackReq, resp, globalConfig.Callback.CallbackBeforeSetGroupInfo); err != nil { return err } @@ -396,8 +397,8 @@ func CallbackBeforeSetGroupInfo(ctx context.Context, req *group.SetGroupInfoReq) utils.NotNilReplace(&req.GroupInfoForSet.Introduction, &resp.Introduction) return nil } -func CallbackAfterSetGroupInfo(ctx context.Context, req *group.SetGroupInfoReq) error { - if !config.Config.Callback.CallbackAfterSetGroupInfo.Enable { +func CallbackAfterSetGroupInfo(ctx context.Context, globalConfig *config.GlobalConfig, req *group.SetGroupInfoReq) error { + if !globalConfig.Callback.CallbackAfterSetGroupInfo.Enable { return nil } callbackReq := &callbackstruct.CallbackAfterSetGroupInfoReq{ @@ -421,7 +422,7 @@ func CallbackAfterSetGroupInfo(ctx context.Context, req *group.SetGroupInfoReq) callbackReq.ApplyMemberFriend = &req.GroupInfoForSet.ApplyMemberFriend.Value } resp := &callbackstruct.CallbackAfterSetGroupInfoResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, callbackReq, resp, config.Config.Callback.CallbackAfterSetGroupInfo); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, callbackReq, resp, globalConfig.Callback.CallbackAfterSetGroupInfo); err != nil { return err } return nil diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index fa060efda..325c1edce 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -17,6 +17,7 @@ package group import ( "context" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "math/big" "math/rand" "strconv" @@ -50,35 +51,35 @@ import ( "google.golang.org/grpc" ) -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - mongo, err := unrelation.NewMongo() +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + mongo, err := unrelation.NewMongo(config) if err != nil { return err } - rdb, err := cache.NewRedis() + rdb, err := cache.NewRedis(config) if err != nil { return err } - groupDB, err := mgo.NewGroupMongo(mongo.GetDatabase()) + groupDB, err := mgo.NewGroupMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - groupMemberDB, err := mgo.NewGroupMember(mongo.GetDatabase()) + groupMemberDB, err := mgo.NewGroupMember(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - groupRequestDB, err := mgo.NewGroupRequestMgo(mongo.GetDatabase()) + groupRequestDB, err := mgo.NewGroupRequestMgo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - userRpcClient := rpcclient.NewUserRpcClient(client) - msgRpcClient := rpcclient.NewMessageRpcClient(client) - conversationRpcClient := rpcclient.NewConversationRpcClient(client) + userRpcClient := rpcclient.NewUserRpcClient(client, config) + msgRpcClient := rpcclient.NewMessageRpcClient(client, config) + conversationRpcClient := rpcclient.NewConversationRpcClient(client, config) var gs groupServer database := controller.NewGroupDatabase(rdb, groupDB, groupMemberDB, groupRequestDB, tx.NewMongo(mongo.GetClient()), grouphash.NewGroupHashFromGroupServer(&gs)) gs.db = database gs.User = userRpcClient - gs.Notification = notification.NewGroupNotificationSender(database, &msgRpcClient, &userRpcClient, func(ctx context.Context, userIDs []string) ([]notification.CommonUser, error) { + gs.Notification = notification.NewGroupNotificationSender(database, &msgRpcClient, &userRpcClient, config, func(ctx context.Context, userIDs []string) ([]notification.CommonUser, error) { users, err := userRpcClient.GetUsersInfo(ctx, userIDs) if err != nil { return nil, err @@ -87,6 +88,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e }) gs.conversationRpcClient = conversationRpcClient gs.msgRpcClient = msgRpcClient + gs.config = config pbgroup.RegisterGroupServer(server, &gs) return nil } @@ -97,6 +99,7 @@ type groupServer struct { Notification *notification.GroupNotificationSender conversationRpcClient rpcclient.ConversationRpcClient msgRpcClient rpcclient.MessageRpcClient + config *config.GlobalConfig } func (s *groupServer) GetJoinedGroupIDs(ctx context.Context, req *pbgroup.GetJoinedGroupIDsReq) (*pbgroup.GetJoinedGroupIDsResp, error) { @@ -105,7 +108,6 @@ func (s *groupServer) GetJoinedGroupIDs(ctx context.Context, req *pbgroup.GetJoi } func (s *groupServer) NotificationUserInfoUpdate(ctx context.Context, req *pbgroup.NotificationUserInfoUpdateReq) (*pbgroup.NotificationUserInfoUpdateResp, error) { - defer log.ZDebug(ctx, "NotificationUserInfoUpdate return") members, err := s.db.FindGroupMemberUser(ctx, nil, req.UserID) if err != nil { return nil, err @@ -117,7 +119,6 @@ func (s *groupServer) NotificationUserInfoUpdate(ctx context.Context, req *pbgro } groupIDs = append(groupIDs, member.GroupID) } - log.ZInfo(ctx, "NotificationUserInfoUpdate", "joinGroupNum", len(members), "updateNum", len(groupIDs), "updateGroupIDs", groupIDs) for _, groupID := range groupIDs { if err := s.Notification.GroupMemberInfoSetNotification(ctx, groupID, req.UserID); err != nil { log.ZError(ctx, "NotificationUserInfoUpdate setGroupMemberInfo notification failed", err, "groupID", groupID) @@ -131,7 +132,7 @@ func (s *groupServer) NotificationUserInfoUpdate(ctx context.Context, req *pbgro } func (s *groupServer) CheckGroupAdmin(ctx context.Context, groupID string) error { - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { groupMember, err := s.db.TakeGroupMember(ctx, groupID, mcontext.GetOpUserID(ctx)) if err != nil { return err @@ -196,7 +197,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR if req.OwnerUserID == "" { return nil, errs.ErrArgs.Wrap("no group owner") } - if err := authverify.CheckAccessV3(ctx, req.OwnerUserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.OwnerUserID, s.config); err != nil { return nil, err } userIDs := append(append(req.MemberUserIDs, req.AdminUserIDs...), req.OwnerUserID) @@ -215,7 +216,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR return nil, errs.ErrUserIDNotFound.Wrap("user not found") } // Callback Before create Group - if err := CallbackBeforeCreateGroup(ctx, req); err != nil { + if err := CallbackBeforeCreateGroup(ctx, s.config, req); err != nil { return nil, err } var groupMembers []*relationtb.GroupMemberModel @@ -234,7 +235,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR JoinTime: time.Now(), MuteEndTime: time.UnixMilli(0), } - if err := CallbackBeforeMemberJoinGroup(ctx, groupMember, group.Ex); err != nil { + if err := CallbackBeforeMemberJoinGroup(ctx, s.config, groupMember, group.Ex); err != nil { return err } groupMembers = append(groupMembers, groupMember) @@ -302,7 +303,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR AdminUserIDs: req.AdminUserIDs, } - if err := CallbackAfterCreateGroup(ctx, reqCallBackAfter); err != nil { + if err := CallbackAfterCreateGroup(ctx, s.config, reqCallBackAfter); err != nil { return nil, err } @@ -311,7 +312,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR func (s *groupServer) GetJoinedGroupList(ctx context.Context, req *pbgroup.GetJoinedGroupListReq) (*pbgroup.GetJoinedGroupListResp, error) { resp := &pbgroup.GetJoinedGroupListResp{} - if err := authverify.CheckAccessV3(ctx, req.FromUserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.FromUserID, s.config); err != nil { return nil, err } total, members, err := s.db.PageGetJoinGroup(ctx, req.FromUserID, req.Pagination) @@ -381,7 +382,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite } var groupMember *relationtb.GroupMemberModel var opUserID string - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { opUserID = mcontext.GetOpUserID(ctx) var err error groupMember, err = s.db.TakeGroupMember(ctx, req.GroupID, opUserID) @@ -393,11 +394,11 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite } } - if err := CallbackBeforeInviteUserToGroup(ctx, req); err != nil { + if err := CallbackBeforeInviteUserToGroup(ctx, s.config, req); err != nil { return nil, err } if group.NeedVerification == constant.AllNeedVerification { - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { if !(groupMember.RoleLevel == constant.GroupOwner || groupMember.RoleLevel == constant.GroupAdmin) { var requests []*relationtb.GroupRequestModel for _, userID := range req.InvitedUserIDs { @@ -437,7 +438,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbgroup.Invite JoinTime: time.Now(), MuteEndTime: time.UnixMilli(0), } - if err := CallbackBeforeMemberJoinGroup(ctx, member, group.Ex); err != nil { + if err := CallbackBeforeMemberJoinGroup(ctx, s.config, member, group.Ex); err != nil { return nil, err } groupMembers = append(groupMembers, member) @@ -537,7 +538,7 @@ func (s *groupServer) KickGroupMember(ctx context.Context, req *pbgroup.KickGrou for i, member := range members { memberMap[member.UserID] = members[i] } - isAppManagerUid := authverify.IsAppManagerUid(ctx) + isAppManagerUid := authverify.IsAppManagerUid(ctx, s.config) opMember := memberMap[opUserID] for _, userID := range req.KickedUserIDs { member, ok := memberMap[userID] @@ -609,7 +610,7 @@ func (s *groupServer) KickGroupMember(ctx context.Context, req *pbgroup.KickGrou return nil, err } - if err := CallbackKillGroupMember(ctx, req); err != nil { + if err := CallbackKillGroupMember(ctx, s.config, req); err != nil { return nil, err } return resp, nil @@ -735,7 +736,7 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbgroup if !utils.Contain(req.HandleResult, constant.GroupResponseAgree, constant.GroupResponseRefuse) { return nil, errs.ErrArgs.Wrap("HandleResult unknown") } - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { groupMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -779,7 +780,7 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbgroup OperatorUserID: mcontext.GetOpUserID(ctx), Ex: groupRequest.Ex, } - if err = CallbackBeforeMemberJoinGroup(ctx, member, group.Ex); err != nil { + if err = CallbackBeforeMemberJoinGroup(ctx, s.config, member, group.Ex); err != nil { return nil, err } } @@ -827,7 +828,7 @@ func (s *groupServer) JoinGroup(ctx context.Context, req *pbgroup.JoinGroupReq) Ex: req.Ex, } - if err = CallbackApplyJoinGroupBefore(ctx, reqCall); err != nil { + if err = CallbackApplyJoinGroupBefore(ctx, s.config, reqCall); err != nil { return nil, err } _, err = s.db.TakeGroupMember(ctx, req.GroupID, req.InviterUserID) @@ -848,7 +849,7 @@ func (s *groupServer) JoinGroup(ctx context.Context, req *pbgroup.JoinGroupReq) JoinTime: time.Now(), MuteEndTime: time.UnixMilli(0), } - if err := CallbackBeforeMemberJoinGroup(ctx, groupMember, group.Ex); err != nil { + if err := CallbackBeforeMemberJoinGroup(ctx, s.config, groupMember, group.Ex); err != nil { return nil, err } if err := s.db.CreateGroup(ctx, nil, []*relationtb.GroupMemberModel{groupMember}); err != nil { @@ -859,7 +860,7 @@ func (s *groupServer) JoinGroup(ctx context.Context, req *pbgroup.JoinGroupReq) return nil, err } s.Notification.MemberEnterNotification(ctx, req.GroupID, req.InviterUserID) - if err = CallbackAfterJoinGroup(ctx, req); err != nil { + if err = CallbackAfterJoinGroup(ctx, s.config, req); err != nil { return nil, err } return resp, nil @@ -873,7 +874,7 @@ func (s *groupServer) JoinGroup(ctx context.Context, req *pbgroup.JoinGroupReq) HandledTime: time.Unix(0, 0), Ex: req.Ex, } - if err := s.db.CreateGroupRequest(ctx, []*relationtb.GroupRequestModel{&groupRequest}); err != nil { + if err = s.db.CreateGroupRequest(ctx, []*relationtb.GroupRequestModel{&groupRequest}); err != nil { return nil, err } s.Notification.JoinGroupApplicationNotification(ctx, req) @@ -885,7 +886,7 @@ func (s *groupServer) QuitGroup(ctx context.Context, req *pbgroup.QuitGroupReq) if req.UserID == "" { req.UserID = mcontext.GetOpUserID(ctx) } else { - if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.UserID, s.config); err != nil { return nil, err } } @@ -909,7 +910,7 @@ func (s *groupServer) QuitGroup(ctx context.Context, req *pbgroup.QuitGroupReq) } // callback - if err := CallbackQuitGroup(ctx, req); err != nil { + if err := CallbackQuitGroup(ctx, s.config, req); err != nil { return nil, err } return resp, nil @@ -926,7 +927,7 @@ func (s *groupServer) deleteMemberAndSetConversationSeq(ctx context.Context, gro func (s *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInfoReq) (*pbgroup.SetGroupInfoResp, error) { var opMember *relationtb.GroupMemberModel - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { var err error opMember, err = s.db.TakeGroupMember(ctx, req.GroupInfoForSet.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { @@ -939,7 +940,7 @@ func (s *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInf return nil, err } } - if err := CallbackBeforeSetGroupInfo(ctx, req); err != nil { + if err := CallbackBeforeSetGroupInfo(ctx, s.config, req); err != nil { return nil, err } group, err := s.db.TakeGroup(ctx, req.GroupInfoForSet.GroupID) @@ -1008,7 +1009,7 @@ func (s *groupServer) SetGroupInfo(ctx context.Context, req *pbgroup.SetGroupInf if num > 0 { _ = s.Notification.GroupInfoSetNotification(ctx, tips) } - if err := CallbackAfterSetGroupInfo(ctx, req); err != nil { + if err := CallbackAfterSetGroupInfo(ctx, s.config, req); err != nil { return nil, err } return resp, nil @@ -1045,7 +1046,7 @@ func (s *groupServer) TransferGroupOwner(ctx context.Context, req *pbgroup.Trans if newOwner == nil { return nil, errs.ErrArgs.Wrap("NewOwnerUser not in group " + req.NewOwnerUserID) } - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { if !(mcontext.GetOpUserID(ctx) == oldOwner.UserID && oldOwner.RoleLevel == constant.GroupOwner) { return nil, errs.ErrNoPermission.Wrap("no permission transfer group owner") } @@ -1054,7 +1055,7 @@ func (s *groupServer) TransferGroupOwner(ctx context.Context, req *pbgroup.Trans return nil, err } - if err := CallbackAfterTransferGroupOwner(ctx, req); err != nil { + if err := CallbackAfterTransferGroupOwner(ctx, s.config, req); err != nil { return nil, err } s.Notification.GroupOwnerTransferredNotification(ctx, req) @@ -1186,7 +1187,7 @@ func (s *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou if err != nil { return nil, err } - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { if owner.UserID != mcontext.GetOpUserID(ctx) { return nil, errs.ErrNoPermission.Wrap("not group owner") } @@ -1228,7 +1229,7 @@ func (s *groupServer) DismissGroup(ctx context.Context, req *pbgroup.DismissGrou MembersID: membersID, GroupType: string(group.GroupType), } - if err := CallbackDismissGroup(ctx, reqCall); err != nil { + if err := CallbackDismissGroup(ctx, s.config, reqCall); err != nil { return nil, err } @@ -1244,7 +1245,7 @@ func (s *groupServer) MuteGroupMember(ctx context.Context, req *pbgroup.MuteGrou if err := s.PopulateGroupMember(ctx, member); err != nil { return nil, err } - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { opMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -1278,7 +1279,7 @@ func (s *groupServer) CancelMuteGroupMember(ctx context.Context, req *pbgroup.Ca if err := s.PopulateGroupMember(ctx, member); err != nil { return nil, err } - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, s.config) { opMember, err := s.db.TakeGroupMember(ctx, req.GroupID, mcontext.GetOpUserID(ctx)) if err != nil { return nil, err @@ -1337,7 +1338,7 @@ func (s *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr if opUserID == "" { return nil, errs.ErrNoPermission.Wrap("no op user id") } - isAppManagerUid := authverify.IsAppManagerUid(ctx) + isAppManagerUid := authverify.IsAppManagerUid(ctx, s.config) for i := range req.Members { req.Members[i].FaceURL = nil } @@ -1420,7 +1421,7 @@ func (s *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr } } for i := 0; i < len(req.Members); i++ { - if err := CallbackBeforeSetGroupMemberInfo(ctx, req.Members[i]); err != nil { + if err := CallbackBeforeSetGroupMemberInfo(ctx, s.config, req.Members[i]); err != nil { return nil, err } } @@ -1447,7 +1448,7 @@ func (s *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbgroup.SetGr } } for i := 0; i < len(req.Members); i++ { - if err := CallbackAfterSetGroupMemberInfo(ctx, req.Members[i]); err != nil { + if err := CallbackAfterSetGroupMemberInfo(ctx, s.config, req.Members[i]); err != nil { return nil, err } } diff --git a/internal/rpc/msg/as_read.go b/internal/rpc/msg/as_read.go index cac9102fa..f16fdc62f 100644 --- a/internal/rpc/msg/as_read.go +++ b/internal/rpc/msg/as_read.go @@ -129,7 +129,7 @@ func (m *msgServer) MarkMsgsAsRead( Seqs: req.Seqs, ContentType: conversation.ConversationType, } - if err = CallbackSingleMsgRead(ctx, req_callback); err != nil { + if err = CallbackSingleMsgRead(ctx, m.config, req_callback); err != nil { return nil, err } @@ -206,7 +206,7 @@ func (m *msgServer) MarkConversationAsRead( UnreadMsgNum: req.HasReadSeq, ContentType: int64(conversation.ConversationType), } - if err := CallbackGroupMsgRead(ctx, reqCall); err != nil { + if err := CallbackGroupMsgRead(ctx, m.config, reqCall); err != nil { return nil, err } diff --git a/internal/rpc/msg/callback.go b/internal/rpc/msg/callback.go index 1a7cad70c..536402bf9 100644 --- a/internal/rpc/msg/callback.go +++ b/internal/rpc/msg/callback.go @@ -29,10 +29,6 @@ import ( "google.golang.org/protobuf/proto" ) -func cbURL() string { - return config.Config.Callback.CallbackUrl -} - func toCommonCallback(ctx context.Context, msg *pbchat.SendMsgReq, command string) cbapi.CommonCallbackReq { return cbapi.CommonCallbackReq{ SendID: msg.MsgData.SendID, @@ -66,8 +62,8 @@ func GetContent(msg *sdkws.MsgData) string { } } -func callbackBeforeSendSingleMsg(ctx context.Context, msg *pbchat.SendMsgReq) error { - if !config.Config.Callback.CallbackBeforeSendSingleMsg.Enable || msg.MsgData.ContentType == constant.Typing { +func callbackBeforeSendSingleMsg(ctx context.Context, globalConfig *config.GlobalConfig, msg *pbchat.SendMsgReq) error { + if !globalConfig.Callback.CallbackBeforeSendSingleMsg.Enable || msg.MsgData.ContentType == constant.Typing { return nil } req := &cbapi.CallbackBeforeSendSingleMsgReq{ @@ -75,14 +71,14 @@ func callbackBeforeSendSingleMsg(ctx context.Context, msg *pbchat.SendMsgReq) er RecvID: msg.MsgData.RecvID, } resp := &cbapi.CallbackBeforeSendSingleMsgResp{} - if err := http.CallBackPostReturn(ctx, cbURL(), req, resp, config.Config.Callback.CallbackBeforeSendSingleMsg); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackBeforeSendSingleMsg); err != nil { return err } return nil } -func callbackAfterSendSingleMsg(ctx context.Context, msg *pbchat.SendMsgReq) error { - if !config.Config.Callback.CallbackAfterSendSingleMsg.Enable || msg.MsgData.ContentType == constant.Typing { +func callbackAfterSendSingleMsg(ctx context.Context, globalConfig *config.GlobalConfig, msg *pbchat.SendMsgReq) error { + if !globalConfig.Callback.CallbackAfterSendSingleMsg.Enable || msg.MsgData.ContentType == constant.Typing { return nil } req := &cbapi.CallbackAfterSendSingleMsgReq{ @@ -90,14 +86,14 @@ func callbackAfterSendSingleMsg(ctx context.Context, msg *pbchat.SendMsgReq) err RecvID: msg.MsgData.RecvID, } resp := &cbapi.CallbackAfterSendSingleMsgResp{} - if err := http.CallBackPostReturn(ctx, cbURL(), req, resp, config.Config.Callback.CallbackAfterSendSingleMsg); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackAfterSendSingleMsg); err != nil { return err } return nil } -func callbackBeforeSendGroupMsg(ctx context.Context, msg *pbchat.SendMsgReq) error { - if !config.Config.Callback.CallbackBeforeSendGroupMsg.Enable || msg.MsgData.ContentType == constant.Typing { +func callbackBeforeSendGroupMsg(ctx context.Context, globalConfig *config.GlobalConfig, msg *pbchat.SendMsgReq) error { + if !globalConfig.Callback.CallbackBeforeSendGroupMsg.Enable || msg.MsgData.ContentType == constant.Typing { return nil } req := &cbapi.CallbackBeforeSendGroupMsgReq{ @@ -105,14 +101,14 @@ func callbackBeforeSendGroupMsg(ctx context.Context, msg *pbchat.SendMsgReq) err GroupID: msg.MsgData.GroupID, } resp := &cbapi.CallbackBeforeSendGroupMsgResp{} - if err := http.CallBackPostReturn(ctx, cbURL(), req, resp, config.Config.Callback.CallbackBeforeSendGroupMsg); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackBeforeSendGroupMsg); err != nil { return err } return nil } -func callbackAfterSendGroupMsg(ctx context.Context, msg *pbchat.SendMsgReq) error { - if !config.Config.Callback.CallbackAfterSendGroupMsg.Enable || msg.MsgData.ContentType == constant.Typing { +func callbackAfterSendGroupMsg(ctx context.Context, globalConfig *config.GlobalConfig, msg *pbchat.SendMsgReq) error { + if !globalConfig.Callback.CallbackAfterSendGroupMsg.Enable || msg.MsgData.ContentType == constant.Typing { return nil } req := &cbapi.CallbackAfterSendGroupMsgReq{ @@ -120,21 +116,21 @@ func callbackAfterSendGroupMsg(ctx context.Context, msg *pbchat.SendMsgReq) erro GroupID: msg.MsgData.GroupID, } resp := &cbapi.CallbackAfterSendGroupMsgResp{} - if err := http.CallBackPostReturn(ctx, cbURL(), req, resp, config.Config.Callback.CallbackAfterSendGroupMsg); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackAfterSendGroupMsg); err != nil { return err } return nil } -func callbackMsgModify(ctx context.Context, msg *pbchat.SendMsgReq) error { - if !config.Config.Callback.CallbackMsgModify.Enable || msg.MsgData.ContentType != constant.Text { +func callbackMsgModify(ctx context.Context, globalConfig *config.GlobalConfig, msg *pbchat.SendMsgReq) error { + if !globalConfig.Callback.CallbackMsgModify.Enable || msg.MsgData.ContentType != constant.Text { return nil } req := &cbapi.CallbackMsgModifyCommandReq{ CommonCallbackReq: toCommonCallback(ctx, msg, cbapi.CallbackMsgModifyCommand), } resp := &cbapi.CallbackMsgModifyCommandResp{} - if err := http.CallBackPostReturn(ctx, cbURL(), req, resp, config.Config.Callback.CallbackMsgModify); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackMsgModify); err != nil { return err } if resp.Content != nil { @@ -159,34 +155,34 @@ func callbackMsgModify(ctx context.Context, msg *pbchat.SendMsgReq) error { log.ZDebug(ctx, "callbackMsgModify", "msg", msg.MsgData) return nil } -func CallbackGroupMsgRead(ctx context.Context, req *cbapi.CallbackGroupMsgReadReq) error { - if !config.Config.Callback.CallbackGroupMsgRead.Enable || req.ContentType != constant.Text { +func CallbackGroupMsgRead(ctx context.Context, globalConfig *config.GlobalConfig, req *cbapi.CallbackGroupMsgReadReq) error { + if !globalConfig.Callback.CallbackGroupMsgRead.Enable || req.ContentType != constant.Text { return nil } req.CallbackCommand = cbapi.CallbackGroupMsgReadCommand resp := &cbapi.CallbackGroupMsgReadResp{} - if err := http.CallBackPostReturn(ctx, cbURL(), req, resp, config.Config.Callback.CallbackMsgModify); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackMsgModify); err != nil { return err } return nil } -func CallbackSingleMsgRead(ctx context.Context, req *cbapi.CallbackSingleMsgReadReq) error { - if !config.Config.Callback.CallbackSingleMsgRead.Enable || req.ContentType != constant.Text { +func CallbackSingleMsgRead(ctx context.Context, globalConfig *config.GlobalConfig, req *cbapi.CallbackSingleMsgReadReq) error { + if !globalConfig.Callback.CallbackSingleMsgRead.Enable || req.ContentType != constant.Text { return nil } req.CallbackCommand = cbapi.CallbackSingleMsgRead resp := &cbapi.CallbackSingleMsgReadResp{} - if err := http.CallBackPostReturn(ctx, cbURL(), req, resp, config.Config.Callback.CallbackMsgModify); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, req, resp, globalConfig.Callback.CallbackMsgModify); err != nil { return err } return nil } -func CallbackAfterRevokeMsg(ctx context.Context, req *pbchat.RevokeMsgReq) error { - if !config.Config.Callback.CallbackAfterRevokeMsg.Enable { +func CallbackAfterRevokeMsg(ctx context.Context, globalConfig *config.GlobalConfig, req *pbchat.RevokeMsgReq) error { + if !globalConfig.Callback.CallbackAfterRevokeMsg.Enable { return nil } callbackReq := &cbapi.CallbackAfterRevokeMsgReq{ @@ -196,7 +192,7 @@ func CallbackAfterRevokeMsg(ctx context.Context, req *pbchat.RevokeMsgReq) error UserID: req.UserID, } resp := &cbapi.CallbackAfterRevokeMsgResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, callbackReq, resp, config.Config.Callback.CallbackAfterRevokeMsg); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, callbackReq, resp, globalConfig.Callback.CallbackAfterRevokeMsg); err != nil { return err } return nil diff --git a/internal/rpc/msg/delete.go b/internal/rpc/msg/delete.go index 091f11f4c..14e24d23e 100644 --- a/internal/rpc/msg/delete.go +++ b/internal/rpc/msg/delete.go @@ -45,7 +45,7 @@ func (m *msgServer) ClearConversationsMsg( ctx context.Context, req *msg.ClearConversationsMsgReq, ) (*msg.ClearConversationsMsgResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil { return nil, err } if err := m.clearConversation(ctx, req.ConversationIDs, req.UserID, req.DeleteSyncOpt); err != nil { @@ -58,7 +58,7 @@ func (m *msgServer) UserClearAllMsg( ctx context.Context, req *msg.UserClearAllMsgReq, ) (*msg.UserClearAllMsgResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil { return nil, err } conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) @@ -73,7 +73,7 @@ func (m *msgServer) UserClearAllMsg( } func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*msg.DeleteMsgsResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil { return nil, err } isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(req.DeleteSyncOpt) @@ -121,7 +121,7 @@ func (m *msgServer) DeleteMsgPhysical( ctx context.Context, req *msg.DeleteMsgPhysicalReq, ) (*msg.DeleteMsgPhysicalResp, error) { - if err := authverify.CheckAdmin(ctx); err != nil { + if err := authverify.CheckAdmin(ctx, m.config); err != nil { return nil, err } remainTime := utils.GetCurrentTimestampBySecond() - req.Timestamp diff --git a/internal/rpc/msg/message_interceptor.go b/internal/rpc/msg/message_interceptor.go index 3a2731fea..97eac613d 100644 --- a/internal/rpc/msg/message_interceptor.go +++ b/internal/rpc/msg/message_interceptor.go @@ -24,17 +24,17 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) -type MessageInterceptorFunc func(ctx context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) +type MessageInterceptorFunc func(ctx context.Context, globalConfig *config.GlobalConfig, req *msg.SendMsgReq) (*sdkws.MsgData, error) -func MessageHasReadEnabled(_ context.Context, req *msg.SendMsgReq) (*sdkws.MsgData, error) { +func MessageHasReadEnabled(_ context.Context, globalConfig *config.GlobalConfig, req *msg.SendMsgReq) (*sdkws.MsgData, error) { switch { case req.MsgData.ContentType == constant.HasReadReceipt && req.MsgData.SessionType == constant.SingleChatType: - if !config.Config.SingleMessageHasReadReceiptEnable { + if !globalConfig.SingleMessageHasReadReceiptEnable { return nil, errs.ErrMessageHasReadDisable.Wrap() } return req.MsgData, nil case req.MsgData.ContentType == constant.HasReadReceipt && req.MsgData.SessionType == constant.SuperGroupChatType: - if !config.Config.GroupMessageHasReadReceiptEnable { + if !globalConfig.GroupMessageHasReadReceiptEnable { return nil, errs.ErrMessageHasReadDisable.Wrap() } return req.MsgData, nil diff --git a/internal/rpc/msg/revoke.go b/internal/rpc/msg/revoke.go index eebfdf779..4f844369f 100644 --- a/internal/rpc/msg/revoke.go +++ b/internal/rpc/msg/revoke.go @@ -19,6 +19,8 @@ import ( "encoding/json" "time" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/msg" "github.com/OpenIMSDK/protocol/sdkws" @@ -26,8 +28,7 @@ import ( "github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/mcontext" "github.com/OpenIMSDK/tools/utils" - "github.com/openimsdk/open-im-server/v3/pkg/authverify" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" + unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation" ) @@ -42,7 +43,7 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. if req.Seq < 0 { return nil, errs.ErrArgs.Wrap("seq is invalid") } - if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil { return nil, err } user, err := m.User.GetUserInfo(ctx, req.UserID) @@ -63,10 +64,10 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. data, _ := json.Marshal(msgs[0]) log.ZInfo(ctx, "GetMsgBySeqs", "conversationID", req.ConversationID, "seq", req.Seq, "msg", string(data)) var role int32 - if !authverify.IsAppManagerUid(ctx) { + if !authverify.IsAppManagerUid(ctx, m.config) { switch msgs[0].SessionType { case constant.SingleChatType: - if err := authverify.CheckAccessV3(ctx, msgs[0].SendID); err != nil { + if err := authverify.CheckAccessV3(ctx, msgs[0].SendID, m.config); err != nil { return nil, err } role = user.AppMangerLevel @@ -110,11 +111,11 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. } revokerUserID := mcontext.GetOpUserID(ctx) var flag bool - if len(config.Config.Manager.UserID) > 0 { - flag = utils.Contain(revokerUserID, config.Config.Manager.UserID...) + if len(m.config.Manager.UserID) > 0 { + flag = utils.Contain(revokerUserID, m.config.Manager.UserID...) } - if len(config.Config.Manager.UserID) == 0 && len(config.Config.IMAdmin.UserID) > 0 { - flag = utils.Contain(revokerUserID, config.Config.IMAdmin.UserID...) + if len(m.config.Manager.UserID) == 0 && len(m.config.IMAdmin.UserID) > 0 { + flag = utils.Contain(revokerUserID, m.config.IMAdmin.UserID...) } tips := sdkws.RevokeMsgTips{ RevokerUserID: revokerUserID, @@ -134,7 +135,7 @@ func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg. if err := m.notificationSender.NotificationWithSesstionType(ctx, req.UserID, recvID, constant.MsgRevokeNotification, msgs[0].SessionType, &tips); err != nil { return nil, err } - if err = CallbackAfterRevokeMsg(ctx, req); err != nil { + if err = CallbackAfterRevokeMsg(ctx, m.config, req); err != nil { return nil, err } return &msg.RevokeMsgResp{}, nil diff --git a/internal/rpc/msg/send.go b/internal/rpc/msg/send.go index e04cebb3b..4bac4d1e0 100644 --- a/internal/rpc/msg/send.go +++ b/internal/rpc/msg/send.go @@ -17,6 +17,9 @@ package msg import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" + "github.com/OpenIMSDK/protocol/constant" pbconversation "github.com/OpenIMSDK/protocol/conversation" pbmsg "github.com/OpenIMSDK/protocol/msg" @@ -26,14 +29,12 @@ import ( "github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/mcontext" "github.com/OpenIMSDK/tools/utils" - "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" - "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" ) func (m *msgServer) SendMsg(ctx context.Context, req *pbmsg.SendMsgReq) (resp *pbmsg.SendMsgResp, error error) { resp = &pbmsg.SendMsgResp{} if req.MsgData != nil { - flag := isMessageHasReadEnabled(req.MsgData) + flag := isMessageHasReadEnabled(req.MsgData, m.config) if !flag { return nil, errs.ErrMessageHasReadDisable.Wrap() } @@ -61,11 +62,11 @@ func (m *msgServer) sendMsgSuperGroupChat( prommetrics.GroupChatMsgProcessFailedCounter.Inc() return nil, err } - if err = callbackBeforeSendGroupMsg(ctx, req); err != nil { + if err = callbackBeforeSendGroupMsg(ctx, m.config, req); err != nil { return nil, err } - if err := callbackMsgModify(ctx, req); err != nil { + if err := callbackMsgModify(ctx, m.config, req); err != nil { return nil, err } err = m.MsgDatabase.MsgToMQ(ctx, utils.GenConversationUniqueKeyForGroup(req.MsgData.GroupID), req.MsgData) @@ -75,7 +76,7 @@ func (m *msgServer) sendMsgSuperGroupChat( if req.MsgData.ContentType == constant.AtText { go m.setConversationAtInfo(ctx, req.MsgData) } - if err = callbackAfterSendGroupMsg(ctx, req); err != nil { + if err = callbackAfterSendGroupMsg(ctx, m.config, req); err != nil { log.ZWarn(ctx, "CallbackAfterSendGroupMsg", err) } prommetrics.GroupChatMsgProcessSuccessCounter.Inc() @@ -107,7 +108,7 @@ func (m *msgServer) setConversationAtInfo(nctx context.Context, msg *sdkws.MsgDa conversation.GroupAtType = &wrapperspb.Int32Value{Value: constant.AtAll} } else { //@Everyone and @other people conversation.GroupAtType = &wrapperspb.Int32Value{Value: constant.AtAllAtMe} - err := m.Conversation.SetConversations(ctx, atUserID, conversation) + err = m.Conversation.SetConversations(ctx, atUserID, conversation) if err != nil { log.ZWarn(ctx, "SetConversations", err, "userID", atUserID, "conversation", conversation) } @@ -164,18 +165,18 @@ func (m *msgServer) sendMsgSingleChat(ctx context.Context, req *pbmsg.SendMsgReq prommetrics.SingleChatMsgProcessFailedCounter.Inc() return nil, nil } else { - if err = callbackBeforeSendSingleMsg(ctx, req); err != nil { + if err = callbackBeforeSendSingleMsg(ctx, m.config, req); err != nil { return nil, err } - if err := callbackMsgModify(ctx, req); err != nil { + if err := callbackMsgModify(ctx, m.config, req); err != nil { return nil, err } if err := m.MsgDatabase.MsgToMQ(ctx, utils.GenConversationUniqueKeyForSingle(req.MsgData.SendID, req.MsgData.RecvID), req.MsgData); err != nil { prommetrics.SingleChatMsgProcessFailedCounter.Inc() return nil, err } - err = callbackAfterSendSingleMsg(ctx, req) + err = callbackAfterSendSingleMsg(ctx, m.config, req) if err != nil { log.ZWarn(ctx, "CallbackAfterSendSingleMsg", err, "req", req) } diff --git a/internal/rpc/msg/server.go b/internal/rpc/msg/server.go index 79ae483a2..5b7cd2f66 100644 --- a/internal/rpc/msg/server.go +++ b/internal/rpc/msg/server.go @@ -15,16 +15,20 @@ package msg import ( + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + + "google.golang.org/grpc" + "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/conversation" "github.com/OpenIMSDK/protocol/msg" "github.com/OpenIMSDK/tools/discoveryregistry" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" "github.com/openimsdk/open-im-server/v3/pkg/common/db/localcache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" - "google.golang.org/grpc" ) type ( @@ -40,6 +44,7 @@ type ( ConversationLocalCache *localcache.ConversationLocalCache Handlers MessageInterceptorChain notificationSender *rpcclient.NotificationSender + config *config.GlobalConfig } ) @@ -47,37 +52,36 @@ func (m *msgServer) addInterceptorHandler(interceptorFunc ...MessageInterceptorF m.Handlers = append(m.Handlers, interceptorFunc...) } -// func `(*msgServer).execInterceptorHandler` is unused -// func (m *msgServer) execInterceptorHandler(ctx context.Context, req *msg.SendMsgReq) error { -// for _, handler := range m.Handlers { -// msgData, err := handler(ctx, req) -// if err != nil { -// return err -// } -// req.MsgData = msgData -// } -// return nil -// } +//func (m *msgServer) execInterceptorHandler(ctx context.Context, config *config.GlobalConfig, req *msg.SendMsgReq) error { +// for _, handler := range m.Handlers { +// msgData, err := handler(ctx, config, req) +// if err != nil { +// return err +// } +// req.MsgData = msgData +// } +// return nil +//} -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := cache.NewRedis() +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis(config) if err != nil { return err } - mongo, err := unrelation.NewMongo() + mongo, err := unrelation.NewMongo(config) if err != nil { return err } if err := mongo.CreateMsgIndex(); err != nil { return err } - cacheModel := cache.NewMsgCacheModel(rdb) - msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase()) - conversationClient := rpcclient.NewConversationRpcClient(client) - userRpcClient := rpcclient.NewUserRpcClient(client) - groupRpcClient := rpcclient.NewGroupRpcClient(client) - friendRpcClient := rpcclient.NewFriendRpcClient(client) - msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, cacheModel) + cacheModel := cache.NewMsgCacheModel(rdb, config) + msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase(config.Mongo.Database)) + conversationClient := rpcclient.NewConversationRpcClient(client, config) + userRpcClient := rpcclient.NewUserRpcClient(client, config) + groupRpcClient := rpcclient.NewGroupRpcClient(client, config) + friendRpcClient := rpcclient.NewFriendRpcClient(client, config) + msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, cacheModel, config) if err != nil { return err } @@ -90,8 +94,9 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e GroupLocalCache: localcache.NewGroupLocalCache(&groupRpcClient), ConversationLocalCache: localcache.NewConversationLocalCache(&conversationClient), friend: &friendRpcClient, + config: config, } - s.notificationSender = rpcclient.NewNotificationSender(rpcclient.WithLocalSendMsg(s.SendMsg)) + s.notificationSender = rpcclient.NewNotificationSender(config, rpcclient.WithLocalSendMsg(s.SendMsg)) s.addInterceptorHandler(MessageHasReadEnabled) msg.RegisterMsgServer(server, s) return nil diff --git a/internal/rpc/msg/sync_msg.go b/internal/rpc/msg/sync_msg.go index fa894e034..b714da375 100644 --- a/internal/rpc/msg/sync_msg.go +++ b/internal/rpc/msg/sync_msg.go @@ -88,7 +88,7 @@ func (m *msgServer) PullMessageBySeqs( } func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) { - if err := authverify.CheckAccessV3(ctx, req.UserID); err != nil { + if err := authverify.CheckAccessV3(ctx, req.UserID, m.config); err != nil { return nil, err } conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID) diff --git a/internal/rpc/msg/utils.go b/internal/rpc/msg/utils.go index d8a45e875..48665562a 100644 --- a/internal/rpc/msg/utils.go +++ b/internal/rpc/msg/utils.go @@ -23,16 +23,16 @@ import ( "go.mongodb.org/mongo-driver/mongo" ) -func isMessageHasReadEnabled(msgData *sdkws.MsgData) bool { +func isMessageHasReadEnabled(msgData *sdkws.MsgData, config *config.GlobalConfig) bool { switch { case msgData.ContentType == constant.HasReadReceipt && msgData.SessionType == constant.SingleChatType: - if config.Config.SingleMessageHasReadReceiptEnable { + if config.SingleMessageHasReadReceiptEnable { return true } else { return false } case msgData.ContentType == constant.HasReadReceipt && msgData.SessionType == constant.SuperGroupChatType: - if config.Config.GroupMessageHasReadReceiptEnable { + if config.GroupMessageHasReadReceiptEnable { return true } else { return false diff --git a/internal/rpc/msg/verify.go b/internal/rpc/msg/verify.go index 11055fac1..d72e4923e 100644 --- a/internal/rpc/msg/verify.go +++ b/internal/rpc/msg/verify.go @@ -25,7 +25,6 @@ import ( "github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/utils" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) var ExcludeContentType = []int{constant.HasReadReceipt} @@ -50,10 +49,10 @@ type MessageRevoked struct { func (m *msgServer) messageVerification(ctx context.Context, data *msg.SendMsgReq) error { switch data.MsgData.SessionType { case constant.SingleChatType: - if len(config.Config.Manager.UserID) > 0 && utils.IsContain(data.MsgData.SendID, config.Config.Manager.UserID) { + if len(m.config.Manager.UserID) > 0 && utils.IsContain(data.MsgData.SendID, m.config.Manager.UserID) { return nil } - if utils.IsContain(data.MsgData.SendID, config.Config.IMAdmin.UserID) { + if utils.IsContain(data.MsgData.SendID, m.config.IMAdmin.UserID) { return nil } if data.MsgData.ContentType <= constant.NotificationEnd && @@ -67,7 +66,7 @@ func (m *msgServer) messageVerification(ctx context.Context, data *msg.SendMsgRe if black { return errs.ErrBlockedByPeer.Wrap() } - if *config.Config.MessageVerify.FriendVerify { + if *m.config.MessageVerify.FriendVerify { friend, err := m.friend.IsFriend(ctx, data.MsgData.SendID, data.MsgData.RecvID) if err != nil { return err @@ -90,10 +89,10 @@ func (m *msgServer) messageVerification(ctx context.Context, data *msg.SendMsgRe if groupInfo.GroupType == constant.SuperGroup { return nil } - if len(config.Config.Manager.UserID) > 0 && utils.IsContain(data.MsgData.SendID, config.Config.Manager.UserID) { + if len(m.config.Manager.UserID) > 0 && utils.IsContain(data.MsgData.SendID, m.config.Manager.UserID) { return nil } - if utils.IsContain(data.MsgData.SendID, config.Config.IMAdmin.UserID) { + if utils.IsContain(data.MsgData.SendID, m.config.IMAdmin.UserID) { return nil } if data.MsgData.ContentType <= constant.NotificationEnd && @@ -158,9 +157,6 @@ func (m *msgServer) encapsulateMsgData(msg *sdkws.MsgData) { case constant.Custom: fallthrough case constant.Quote: - utils.SetSwitchFromOptions(msg.Options, constant.IsConversationUpdate, true) - utils.SetSwitchFromOptions(msg.Options, constant.IsUnreadCount, true) - utils.SetSwitchFromOptions(msg.Options, constant.IsSenderSync, true) case constant.Revoke: utils.SetSwitchFromOptions(msg.Options, constant.IsUnreadCount, false) utils.SetSwitchFromOptions(msg.Options, constant.IsOfflinePush, false) diff --git a/internal/rpc/third/log.go b/internal/rpc/third/log.go index 420f399ba..b425dd819 100644 --- a/internal/rpc/third/log.go +++ b/internal/rpc/third/log.go @@ -82,7 +82,7 @@ func (t *thirdServer) UploadLogs(ctx context.Context, req *third.UploadLogsReq) } func (t *thirdServer) DeleteLogs(ctx context.Context, req *third.DeleteLogsReq) (*third.DeleteLogsResp, error) { - if err := authverify.CheckAdmin(ctx); err != nil { + if err := authverify.CheckAdmin(ctx, t.config); err != nil { return nil, err } userID := "" @@ -123,7 +123,7 @@ func dbToPbLogInfos(logs []*relationtb.LogModel) []*third.LogInfo { } func (t *thirdServer) SearchLogs(ctx context.Context, req *third.SearchLogsReq) (*third.SearchLogsResp, error) { - if err := authverify.CheckAdmin(ctx); err != nil { + if err := authverify.CheckAdmin(ctx, t.config); err != nil { return nil, err } var ( diff --git a/internal/rpc/third/s3.go b/internal/rpc/third/s3.go index 7f68f4da8..1975163e5 100644 --- a/internal/rpc/third/s3.go +++ b/internal/rpc/third/s3.go @@ -29,7 +29,6 @@ import ( "github.com/OpenIMSDK/tools/mcontext" "github.com/OpenIMSDK/tools/utils" "github.com/google/uuid" - "github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3/cont" "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" @@ -54,7 +53,7 @@ func (t *thirdServer) PartSize(ctx context.Context, req *third.PartSizeReq) (*th func (t *thirdServer) InitiateMultipartUpload(ctx context.Context, req *third.InitiateMultipartUploadReq) (*third.InitiateMultipartUploadResp, error) { defer log.ZDebug(ctx, "return") - if err := checkUploadName(ctx, req.Name); err != nil { + if err := t.checkUploadName(ctx, req.Name); err != nil { return nil, err } expireTime := time.Now().Add(t.defaultExpire) @@ -133,7 +132,7 @@ func (t *thirdServer) AuthSign(ctx context.Context, req *third.AuthSignReq) (*th func (t *thirdServer) CompleteMultipartUpload(ctx context.Context, req *third.CompleteMultipartUploadReq) (*third.CompleteMultipartUploadResp, error) { defer log.ZDebug(ctx, "return") - if err := checkUploadName(ctx, req.Name); err != nil { + if err := t.checkUploadName(ctx, req.Name); err != nil { return nil, err } result, err := t.s3dataBase.CompleteMultipartUpload(ctx, req.UploadID, req.Parts) @@ -190,13 +189,13 @@ func (t *thirdServer) InitiateFormData(ctx context.Context, req *third.InitiateF if req.Size <= 0 { return nil, errs.ErrArgs.Wrap("size must be greater than 0") } - if err := checkUploadName(ctx, req.Name); err != nil { + if err := t.checkUploadName(ctx, req.Name); err != nil { return nil, err } var duration time.Duration opUserID := mcontext.GetOpUserID(ctx) var key string - if authverify.IsManagerUserID(opUserID) { + if t.IsManagerUserID(opUserID) { if req.Millisecond <= 0 { duration = time.Minute * 10 } else { @@ -256,7 +255,7 @@ func (t *thirdServer) CompleteFormData(ctx context.Context, req *third.CompleteF if err := json.Unmarshal(data, &mate); err != nil { return nil, errs.ErrArgs.Wrap("invalid id " + err.Error()) } - if err := checkUploadName(ctx, mate.Name); err != nil { + if err := t.checkUploadName(ctx, mate.Name); err != nil { return nil, err } info, err := t.s3dataBase.StatObject(ctx, mate.Key) diff --git a/internal/rpc/third/third.go b/internal/rpc/third/third.go index 9dd3ffd65..2bccb5c78 100644 --- a/internal/rpc/third/third.go +++ b/internal/rpc/third/third.go @@ -20,59 +20,63 @@ import ( "net/url" "time" - "github.com/OpenIMSDK/protocol/third" - "github.com/OpenIMSDK/tools/discoveryregistry" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" "github.com/openimsdk/open-im-server/v3/pkg/common/db/mgo" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3/cos" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3/minio" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3/oss" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" - "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" + "google.golang.org/grpc" + + "github.com/OpenIMSDK/protocol/third" + "github.com/OpenIMSDK/tools/discoveryregistry" + + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" + "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { - mongo, err := unrelation.NewMongo() +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { + mongo, err := unrelation.NewMongo(config) if err != nil { return err } - logdb, err := mgo.NewLogMongo(mongo.GetDatabase()) + logdb, err := mgo.NewLogMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - s3db, err := mgo.NewS3Mongo(mongo.GetDatabase()) + s3db, err := mgo.NewS3Mongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } - apiURL := config.Config.Object.ApiURL + apiURL := config.Object.ApiURL if apiURL == "" { return fmt.Errorf("api url is empty") } - if _, err := url.Parse(config.Config.Object.ApiURL); err != nil { + if _, parseErr := url.Parse(config.Object.ApiURL); parseErr != nil { return err } if apiURL[len(apiURL)-1] != '/' { apiURL += "/" } apiURL += "object/" - rdb, err := cache.NewRedis() + rdb, err := cache.NewRedis(config) if err != nil { return err } - // Select based on the configuration file strategy - enable := config.Config.Object.Enable + // 根据配置文件策略选择 oss 方式 + enable := config.Object.Enable var o s3.Interface - switch config.Config.Object.Enable { + switch enable { case "minio": - o, err = minio.NewMinio(cache.NewMinioCache(rdb)) + o, err = minio.NewMinio(cache.NewMinioCache(rdb), minio.Config(config.Object.Minio)) case "cos": - o, err = cos.NewCos() + o, err = cos.NewCos(cos.Config(config.Object.Cos)) case "oss": - o, err = oss.NewOSS() + o, err = oss.NewOSS(oss.Config(config.Object.Oss)) default: err = fmt.Errorf("invalid object enable: %s", enable) } @@ -81,10 +85,11 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e } third.RegisterThirdServer(server, &thirdServer{ apiURL: apiURL, - thirdDatabase: controller.NewThirdDatabase(cache.NewMsgCacheModel(rdb), logdb), - userRpcClient: rpcclient.NewUserRpcClient(client), + thirdDatabase: controller.NewThirdDatabase(cache.NewMsgCacheModel(rdb, config), logdb), + userRpcClient: rpcclient.NewUserRpcClient(client, config), s3dataBase: controller.NewS3Database(rdb, o, s3db), defaultExpire: time.Hour * 24 * 7, + config: config, }) return nil } @@ -95,6 +100,7 @@ type thirdServer struct { s3dataBase controller.S3Database userRpcClient rpcclient.UserRpcClient defaultExpire time.Duration + config *config.GlobalConfig } func (t *thirdServer) FcmUpdateToken(ctx context.Context, req *third.FcmUpdateTokenReq) (resp *third.FcmUpdateTokenResp, err error) { diff --git a/internal/rpc/third/tool.go b/internal/rpc/third/tool.go index cf25f9820..6591134d6 100644 --- a/internal/rpc/third/tool.go +++ b/internal/rpc/third/tool.go @@ -21,10 +21,11 @@ import ( "strings" "unicode/utf8" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/OpenIMSDK/protocol/third" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/mcontext" - "github.com/openimsdk/open-im-server/v3/pkg/authverify" ) func toPbMapArray(m map[string][]string) []*third.KeyValues { @@ -41,7 +42,7 @@ func toPbMapArray(m map[string][]string) []*third.KeyValues { return res } -func checkUploadName(ctx context.Context, name string) error { +func (t *thirdServer) checkUploadName(ctx context.Context, name string) error { if name == "" { return errs.ErrArgs.Wrap("name is empty") } @@ -55,7 +56,7 @@ func checkUploadName(ctx context.Context, name string) error { if opUserID == "" { return errs.ErrNoPermission.Wrap("opUserID is empty") } - if !authverify.IsManagerUserID(opUserID) { + if !authverify.IsManagerUserID(opUserID, t.config) { if !strings.HasPrefix(name, opUserID+"/") { return errs.ErrNoPermission.Wrap(fmt.Sprintf("name must start with `%s/`", opUserID)) } @@ -79,3 +80,7 @@ func checkValidObjectName(objectName string) error { } return checkValidObjectNamePrefix(objectName) } + +func (t *thirdServer) IsManagerUserID(opUserID string) bool { + return authverify.IsManagerUserID(opUserID, t.config) +} diff --git a/internal/rpc/user/callback.go b/internal/rpc/user/callback.go index 1437257f7..34f211973 100644 --- a/internal/rpc/user/callback.go +++ b/internal/rpc/user/callback.go @@ -24,8 +24,8 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/http" ) -func CallbackBeforeUpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) error { - if !config.Config.Callback.CallbackBeforeUpdateUserInfo.Enable { +func CallbackBeforeUpdateUserInfo(ctx context.Context, globalConfig *config.GlobalConfig, req *pbuser.UpdateUserInfoReq) error { + if !globalConfig.Callback.CallbackBeforeUpdateUserInfo.Enable { return nil } cbReq := &cbapi.CallbackBeforeUpdateUserInfoReq{ @@ -35,7 +35,7 @@ func CallbackBeforeUpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInf Nickname: &req.UserInfo.Nickname, } resp := &cbapi.CallbackBeforeUpdateUserInfoResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeUpdateUserInfo); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeUpdateUserInfo); err != nil { return err } utils.NotNilReplace(&req.UserInfo.FaceURL, resp.FaceURL) @@ -43,8 +43,8 @@ func CallbackBeforeUpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInf utils.NotNilReplace(&req.UserInfo.Nickname, resp.Nickname) return nil } -func CallbackAfterUpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) error { - if !config.Config.Callback.CallbackAfterUpdateUserInfo.Enable { +func CallbackAfterUpdateUserInfo(ctx context.Context, globalConfig *config.GlobalConfig, req *pbuser.UpdateUserInfoReq) error { + if !globalConfig.Callback.CallbackAfterUpdateUserInfo.Enable { return nil } cbReq := &cbapi.CallbackAfterUpdateUserInfoReq{ @@ -54,13 +54,13 @@ func CallbackAfterUpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfo Nickname: req.UserInfo.Nickname, } resp := &cbapi.CallbackAfterUpdateUserInfoResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeUpdateUserInfo); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeUpdateUserInfo); err != nil { return err } return nil } -func CallbackBeforeUpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) error { - if !config.Config.Callback.CallbackBeforeUpdateUserInfoEx.Enable { +func CallbackBeforeUpdateUserInfoEx(ctx context.Context, globalConfig *config.GlobalConfig, req *pbuser.UpdateUserInfoExReq) error { + if !globalConfig.Callback.CallbackBeforeUpdateUserInfoEx.Enable { return nil } cbReq := &cbapi.CallbackBeforeUpdateUserInfoExReq{ @@ -70,7 +70,7 @@ func CallbackBeforeUpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserI Nickname: req.UserInfo.Nickname, } resp := &cbapi.CallbackBeforeUpdateUserInfoExResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeUpdateUserInfoEx); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeUpdateUserInfoEx); err != nil { return err } utils.NotNilReplace(req.UserInfo.FaceURL, resp.FaceURL) @@ -78,8 +78,8 @@ func CallbackBeforeUpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserI utils.NotNilReplace(req.UserInfo.Nickname, resp.Nickname) return nil } -func CallbackAfterUpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) error { - if !config.Config.Callback.CallbackAfterUpdateUserInfoEx.Enable { +func CallbackAfterUpdateUserInfoEx(ctx context.Context, globalConfig *config.GlobalConfig, req *pbuser.UpdateUserInfoExReq) error { + if !globalConfig.Callback.CallbackAfterUpdateUserInfoEx.Enable { return nil } cbReq := &cbapi.CallbackAfterUpdateUserInfoExReq{ @@ -89,14 +89,14 @@ func CallbackAfterUpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserIn Nickname: req.UserInfo.Nickname, } resp := &cbapi.CallbackAfterUpdateUserInfoExResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeUpdateUserInfoEx); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeUpdateUserInfoEx); err != nil { return err } return nil } -func CallbackBeforeUserRegister(ctx context.Context, req *pbuser.UserRegisterReq) error { - if !config.Config.Callback.CallbackBeforeUserRegister.Enable { +func CallbackBeforeUserRegister(ctx context.Context, globalConfig *config.GlobalConfig, req *pbuser.UserRegisterReq) error { + if !globalConfig.Callback.CallbackBeforeUserRegister.Enable { return nil } cbReq := &cbapi.CallbackBeforeUserRegisterReq{ @@ -106,7 +106,7 @@ func CallbackBeforeUserRegister(ctx context.Context, req *pbuser.UserRegisterReq } resp := &cbapi.CallbackBeforeUserRegisterResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackBeforeUpdateUserInfo); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackBeforeUpdateUserInfo); err != nil { return err } if len(resp.Users) != 0 { @@ -115,8 +115,8 @@ func CallbackBeforeUserRegister(ctx context.Context, req *pbuser.UserRegisterReq return nil } -func CallbackAfterUserRegister(ctx context.Context, req *pbuser.UserRegisterReq) error { - if !config.Config.Callback.CallbackAfterUserRegister.Enable { +func CallbackAfterUserRegister(ctx context.Context, globalConfig *config.GlobalConfig, req *pbuser.UserRegisterReq) error { + if !globalConfig.Callback.CallbackAfterUserRegister.Enable { return nil } cbReq := &cbapi.CallbackAfterUserRegisterReq{ @@ -126,7 +126,7 @@ func CallbackAfterUserRegister(ctx context.Context, req *pbuser.UserRegisterReq) } resp := &cbapi.CallbackAfterUserRegisterResp{} - if err := http.CallBackPostReturn(ctx, config.Config.Callback.CallbackUrl, cbReq, resp, config.Config.Callback.CallbackAfterUpdateUserInfo); err != nil { + if err := http.CallBackPostReturn(ctx, globalConfig.Callback.CallbackUrl, cbReq, resp, globalConfig.Callback.CallbackAfterUpdateUserInfo); err != nil { return err } return nil diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index 7ed3ff7d6..02e641d20 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -21,26 +21,34 @@ import ( "strings" "time" + "github.com/OpenIMSDK/tools/pagination" + + "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" + + "github.com/OpenIMSDK/tools/tx" + + "github.com/openimsdk/open-im-server/v3/pkg/common/db/mgo" + "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/sdkws" - pbuser "github.com/OpenIMSDK/protocol/user" - registry "github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" - "github.com/OpenIMSDK/tools/pagination" - "github.com/OpenIMSDK/tools/tx" - "github.com/OpenIMSDK/tools/utils" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" + + registry "github.com/OpenIMSDK/tools/discoveryregistry" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/convert" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/controller" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/mgo" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" tablerelation "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient/notification" + + pbuser "github.com/OpenIMSDK/protocol/user" + "github.com/OpenIMSDK/tools/utils" "google.golang.org/grpc" ) @@ -51,6 +59,7 @@ type userServer struct { friendRpcClient *rpcclient.FriendRpcClient groupRpcClient *rpcclient.GroupRpcClient RegisterCenter registry.SvcDiscoveryRegistry + config *config.GlobalConfig } func (s *userServer) GetGroupOnlineUser(ctx context.Context, req *pbuser.GetGroupOnlineUserReq) (*pbuser.GetGroupOnlineUserResp, error) { @@ -58,39 +67,40 @@ func (s *userServer) GetGroupOnlineUser(ctx context.Context, req *pbuser.GetGrou panic("implement me") } -func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { - rdb, err := cache.NewRedis() +func Start(config *config.GlobalConfig, client registry.SvcDiscoveryRegistry, server *grpc.Server) error { + rdb, err := cache.NewRedis(config) if err != nil { return err } - mongo, err := unrelation.NewMongo() + mongo, err := unrelation.NewMongo(config) if err != nil { return err } users := make([]*tablerelation.UserModel, 0) - if len(config.Config.IMAdmin.UserID) != len(config.Config.IMAdmin.Nickname) { - return errors.New("len(config.Config.AppNotificationAdmin.AppManagerUid) != len(config.Config.AppNotificationAdmin.Nickname)") + if len(config.IMAdmin.UserID) != len(config.IMAdmin.Nickname) { + return errors.New("len(s.config.AppNotificationAdmin.AppManagerUid) != len(s.config.AppNotificationAdmin.Nickname)") } - for k, v := range config.Config.IMAdmin.UserID { - users = append(users, &tablerelation.UserModel{UserID: v, Nickname: config.Config.IMAdmin.Nickname[k], AppMangerLevel: constant.AppNotificationAdmin}) + for k, v := range config.IMAdmin.UserID { + users = append(users, &tablerelation.UserModel{UserID: v, Nickname: config.IMAdmin.Nickname[k], AppMangerLevel: constant.AppNotificationAdmin}) } - userDB, err := mgo.NewUserMongo(mongo.GetDatabase()) + userDB, err := mgo.NewUserMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return err } cache := cache.NewUserCacheRedis(rdb, userDB, cache.GetDefaultOpt()) - userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase()) + userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase(config.Mongo.Database)) database := controller.NewUserDatabase(userDB, cache, tx.NewMongo(mongo.GetClient()), userMongoDB) - friendRpcClient := rpcclient.NewFriendRpcClient(client) - groupRpcClient := rpcclient.NewGroupRpcClient(client) - msgRpcClient := rpcclient.NewMessageRpcClient(client) + friendRpcClient := rpcclient.NewFriendRpcClient(client, config) + groupRpcClient := rpcclient.NewGroupRpcClient(client, config) + msgRpcClient := rpcclient.NewMessageRpcClient(client, config) u := &userServer{ UserDatabase: database, RegisterCenter: client, friendRpcClient: &friendRpcClient, groupRpcClient: &groupRpcClient, - friendNotificationSender: notification.NewFriendNotificationSender(&msgRpcClient, notification.WithDBFunc(database.FindWithError)), - userNotificationSender: notification.NewUserNotificationSender(&msgRpcClient, notification.WithUserFunc(database.FindWithError)), + friendNotificationSender: notification.NewFriendNotificationSender(config, &msgRpcClient, notification.WithDBFunc(database.FindWithError)), + userNotificationSender: notification.NewUserNotificationSender(config, &msgRpcClient, notification.WithUserFunc(database.FindWithError)), + config: config, } pbuser.RegisterUserServer(server, u) return u.UserDatabase.InitOnce(context.Background(), users) @@ -111,11 +121,11 @@ func (s *userServer) GetDesignateUsers(ctx context.Context, req *pbuser.GetDesig func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserInfoReq) (resp *pbuser.UpdateUserInfoResp, err error) { resp = &pbuser.UpdateUserInfoResp{} - err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID) + err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config) if err != nil { return nil, err } - if err := CallbackBeforeUpdateUserInfo(ctx, req); err != nil { + if err := CallbackBeforeUpdateUserInfo(ctx, s.config, req); err != nil { return nil, err } data := convert.UserPb2DBMap(req.UserInfo) @@ -128,29 +138,29 @@ func (s *userServer) UpdateUserInfo(ctx context.Context, req *pbuser.UpdateUserI return nil, err } if req.UserInfo.Nickname != "" || req.UserInfo.FaceURL != "" { - if err := s.groupRpcClient.NotificationUserInfoUpdate(ctx, req.UserInfo.UserID); err != nil { + if err = s.groupRpcClient.NotificationUserInfoUpdate(ctx, req.UserInfo.UserID); err != nil { log.ZError(ctx, "NotificationUserInfoUpdate", err) } } for _, friendID := range friends { s.friendNotificationSender.FriendInfoUpdatedNotification(ctx, req.UserInfo.UserID, friendID) } - if err := CallbackAfterUpdateUserInfo(ctx, req); err != nil { + if err = CallbackAfterUpdateUserInfo(ctx, s.config, req); err != nil { return nil, err } - if err := s.groupRpcClient.NotificationUserInfoUpdate(ctx, req.UserInfo.UserID); err != nil { + if err = s.groupRpcClient.NotificationUserInfoUpdate(ctx, req.UserInfo.UserID); err != nil { log.ZError(ctx, "NotificationUserInfoUpdate", err, "userID", req.UserInfo.UserID) } return resp, nil } func (s *userServer) UpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUserInfoExReq) (resp *pbuser.UpdateUserInfoExResp, err error) { resp = &pbuser.UpdateUserInfoExResp{} - err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID) + err = authverify.CheckAccessV3(ctx, req.UserInfo.UserID, s.config) if err != nil { return nil, err } - if err = CallbackBeforeUpdateUserInfoEx(ctx, req); err != nil { + if err = CallbackBeforeUpdateUserInfoEx(ctx, s.config, req); err != nil { return nil, err } data := convert.UserPb2DBMapEx(req.UserInfo) @@ -170,7 +180,7 @@ func (s *userServer) UpdateUserInfoEx(ctx context.Context, req *pbuser.UpdateUse for _, friendID := range friends { s.friendNotificationSender.FriendInfoUpdatedNotification(ctx, req.UserInfo.UserID, friendID) } - if err := CallbackAfterUpdateUserInfoEx(ctx, req); err != nil { + if err := CallbackAfterUpdateUserInfoEx(ctx, s.config, req); err != nil { return nil, err } if err := s.groupRpcClient.NotificationUserInfoUpdate(ctx, req.UserInfo.UserID); err != nil { @@ -197,7 +207,7 @@ func (s *userServer) AccountCheck(ctx context.Context, req *pbuser.AccountCheckR if utils.Duplicate(req.CheckUserIDs) { return nil, errs.ErrArgs.Wrap("userID repeated") } - err = authverify.CheckAdmin(ctx) + err = authverify.CheckAdmin(ctx, s.config) if err != nil { return nil, err } @@ -244,8 +254,8 @@ func (s *userServer) UserRegister(ctx context.Context, req *pbuser.UserRegisterR if len(req.Users) == 0 { return nil, errs.ErrArgs.Wrap("users is empty") } - if req.Secret != config.Config.Secret { - log.ZDebug(ctx, "UserRegister", config.Config.Secret, req.Secret) + if req.Secret != s.config.Secret { + log.ZDebug(ctx, "UserRegister", s.config.Secret, req.Secret) return nil, errs.ErrNoPermission.Wrap("secret invalid") } if utils.DuplicateAny(req.Users, func(e *sdkws.UserInfo) string { return e.UserID }) { @@ -268,7 +278,7 @@ func (s *userServer) UserRegister(ctx context.Context, req *pbuser.UserRegisterR if exist { return nil, errs.ErrRegisteredAlready.Wrap("userID registered already") } - if err := CallbackBeforeUserRegister(ctx, req); err != nil { + if err := CallbackBeforeUserRegister(ctx, s.config, req); err != nil { return nil, err } now := time.Now() @@ -288,7 +298,7 @@ func (s *userServer) UserRegister(ctx context.Context, req *pbuser.UserRegisterR return nil, err } - if err := CallbackAfterUserRegister(ctx, req); err != nil { + if err := CallbackAfterUserRegister(ctx, s.config, req); err != nil { return nil, err } return resp, nil @@ -383,7 +393,7 @@ func (s *userServer) GetSubscribeUsersStatus(ctx context.Context, // ProcessUserCommandAdd user general function add. func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.ProcessUserCommandAddReq) (*pbuser.ProcessUserCommandAddResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID) + err := authverify.CheckAccessV3(ctx, req.UserID, s.config) if err != nil { return nil, err } @@ -414,7 +424,7 @@ func (s *userServer) ProcessUserCommandAdd(ctx context.Context, req *pbuser.Proc // ProcessUserCommandDelete user general function delete. func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.ProcessUserCommandDeleteReq) (*pbuser.ProcessUserCommandDeleteResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID) + err := authverify.CheckAccessV3(ctx, req.UserID, s.config) if err != nil { return nil, err } @@ -437,7 +447,7 @@ func (s *userServer) ProcessUserCommandDelete(ctx context.Context, req *pbuser.P // ProcessUserCommandUpdate user general function update. func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.ProcessUserCommandUpdateReq) (*pbuser.ProcessUserCommandUpdateResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID) + err := authverify.CheckAccessV3(ctx, req.UserID, s.config) if err != nil { return nil, err } @@ -469,7 +479,7 @@ func (s *userServer) ProcessUserCommandUpdate(ctx context.Context, req *pbuser.P func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.ProcessUserCommandGetReq) (*pbuser.ProcessUserCommandGetResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID) + err := authverify.CheckAccessV3(ctx, req.UserID, s.config) if err != nil { return nil, err } @@ -498,7 +508,7 @@ func (s *userServer) ProcessUserCommandGet(ctx context.Context, req *pbuser.Proc } func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.ProcessUserCommandGetAllReq) (*pbuser.ProcessUserCommandGetAllResp, error) { - err := authverify.CheckAccessV3(ctx, req.UserID) + err := authverify.CheckAccessV3(ctx, req.UserID, s.config) if err != nil { return nil, err } @@ -527,7 +537,7 @@ func (s *userServer) ProcessUserCommandGetAll(ctx context.Context, req *pbuser.P } func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.AddNotificationAccountReq) (*pbuser.AddNotificationAccountResp, error) { - if err := authverify.CheckIMAdmin(ctx); err != nil { + if err := authverify.CheckIMAdmin(ctx, s.config); err != nil { return nil, err } @@ -570,7 +580,7 @@ func (s *userServer) AddNotificationAccount(ctx context.Context, req *pbuser.Add } func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbuser.UpdateNotificationAccountInfoReq) (*pbuser.UpdateNotificationAccountInfoResp, error) { - if err := authverify.CheckIMAdmin(ctx); err != nil { + if err := authverify.CheckIMAdmin(ctx, s.config); err != nil { return nil, err } @@ -597,7 +607,7 @@ func (s *userServer) UpdateNotificationAccountInfo(ctx context.Context, req *pbu func (s *userServer) SearchNotificationAccount(ctx context.Context, req *pbuser.SearchNotificationAccountReq) (*pbuser.SearchNotificationAccountResp, error) { // Check if user is an admin - if err := authverify.CheckIMAdmin(ctx); err != nil { + if err := authverify.CheckIMAdmin(ctx, s.config); err != nil { return nil, err } @@ -671,7 +681,7 @@ func (s *userServer) userModelToResp(users []*relation.UserModel, pagination pag accounts := make([]*pbuser.NotificationAccountInfo, 0) var total int64 for _, v := range users { - if v.AppMangerLevel == constant.AppNotificationAdmin && !utils.IsContain(v.UserID, config.Config.IMAdmin.UserID) { + if v.AppMangerLevel == constant.AppNotificationAdmin && !utils.IsContain(v.UserID, s.config.IMAdmin.UserID) { temp := &pbuser.NotificationAccountInfo{ UserID: v.UserID, FaceURL: v.FaceURL, diff --git a/internal/tools/cron_task.go b/internal/tools/cron_task.go index 9b74a5767..ce87e9b90 100644 --- a/internal/tools/cron_task.go +++ b/internal/tools/cron_task.go @@ -23,37 +23,38 @@ import ( "time" "github.com/OpenIMSDK/tools/errs" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/redis/go-redis/v9" "github.com/robfig/cron/v3" + + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" ) -func StartTask() error { - fmt.Println("Cron task start, config:", config.Config.ChatRecordsClearTime) +func StartTask(config *config.GlobalConfig) error { + fmt.Println("cron task start, config", config.ChatRecordsClearTime) - msgTool, err := InitMsgTool() + msgTool, err := InitMsgTool(config) if err != nil { return err } msgTool.convertTools() - rdb, err := cache.NewRedis() + rdb, err := cache.NewRedis(config) if err != nil { return err } // register cron tasks var crontab = cron.New() - fmt.Printf("Start chatRecordsClearTime cron task, cron config: %s\n", config.Config.ChatRecordsClearTime) - _, err = crontab.AddFunc(config.Config.ChatRecordsClearTime, cronWrapFunc(rdb, "cron_clear_msg_and_fix_seq", msgTool.AllConversationClearMsgAndFixSeq)) + fmt.Printf("Start chatRecordsClearTime cron task, cron config: %s\n", config.ChatRecordsClearTime) + _, err = crontab.AddFunc(config.ChatRecordsClearTime, cronWrapFunc(config,rdb, "cron_clear_msg_and_fix_seq", msgTool.AllConversationClearMsgAndFixSeq)) if err != nil { return errs.Wrap(err) } - fmt.Printf("Start msgDestruct cron task, cron config: %s\n", config.Config.MsgDestructTime) - _, err = crontab.AddFunc(config.Config.MsgDestructTime, cronWrapFunc(rdb, "cron_conversations_destruct_msgs", msgTool.ConversationsDestructMsgs)) + fmt.Printf("Start msgDestruct cron task, cron config: %s\n", config.MsgDestructTime) + _, err = crontab.AddFunc(config.MsgDestructTime, cronWrapFunc(config,rdb, "cron_conversations_destruct_msgs", msgTool.ConversationsDestructMsgs)) if err != nil { return errs.Wrap(err, "cron_conversations_destruct_msgs") } @@ -91,8 +92,8 @@ func netlock(rdb redis.UniversalClient, key string, ttl time.Duration) bool { return ok } -func cronWrapFunc(rdb redis.UniversalClient, key string, fn func()) func() { - enableCronLocker := config.Config.EnableCronLocker +func cronWrapFunc(config *config.GlobalConfig, rdb redis.UniversalClient, key string, fn func()) func() { + enableCronLocker := config.EnableCronLocker return func() { // if don't enable cron-locker, call fn directly. if !enableCronLocker { diff --git a/internal/tools/cron_task_test.go b/internal/tools/cron_task_test.go index 28bc2c945..fcae5a5f6 100644 --- a/internal/tools/cron_task_test.go +++ b/internal/tools/cron_task_test.go @@ -15,8 +15,12 @@ package tools import ( + "flag" "fmt" + "github.com/OpenIMSDK/tools/errs" + "gopkg.in/yaml.v3" "math/rand" + "os" "sync" "testing" "time" @@ -61,7 +65,7 @@ func TestCronWrapFunc(t *testing.T) { start := time.Now() key := fmt.Sprintf("cron-%v", rand.Int31()) crontab := cron.New(cron.WithSeconds()) - crontab.AddFunc("*/1 * * * * *", cronWrapFunc(rdb, key, cb)) + crontab.AddFunc("*/1 * * * * *", cronWrapFunc(config.NewGlobalConfig(), rdb, key, cb)) crontab.Start() <-done @@ -71,7 +75,11 @@ func TestCronWrapFunc(t *testing.T) { } func TestCronWrapFuncWithNetlock(t *testing.T) { - config.Config.EnableCronLocker = true + conf, err := initCfg() + if err != nil { + panic(err) + } + conf.EnableCronLocker = true rdb := redis.NewClient(&redis.Options{}) defer rdb.Close() @@ -80,10 +88,10 @@ func TestCronWrapFuncWithNetlock(t *testing.T) { crontab := cron.New(cron.WithSeconds()) key := fmt.Sprintf("cron-%v", rand.Int31()) - crontab.AddFunc("*/1 * * * * *", cronWrapFunc(rdb, key, func() { + crontab.AddFunc("*/1 * * * * *", cronWrapFunc(conf, rdb, key, func() { done <- "host1" })) - crontab.AddFunc("*/1 * * * * *", cronWrapFunc(rdb, key, func() { + crontab.AddFunc("*/1 * * * * *", cronWrapFunc(conf, rdb, key, func() { done <- "host2" })) crontab.Start() @@ -94,3 +102,22 @@ func TestCronWrapFuncWithNetlock(t *testing.T) { crontab.Stop() } + +func initCfg() (*config.GlobalConfig, error) { + const ( + defaultCfgPath = "../../../../../config/config.yaml" + ) + + cfgPath := flag.String("c", defaultCfgPath, "Path to the configuration file") + data, err := os.ReadFile(*cfgPath) + if err != nil { + return nil, errs.Wrap(err, "ReadFile unmarshal failed") + } + + conf := config.NewGlobalConfig() + err = yaml.Unmarshal(data, &conf) + if err != nil { + return nil, errs.Wrap(err, "InitConfig unmarshal failed") + } + return conf, nil +} diff --git a/internal/tools/msg.go b/internal/tools/msg.go index f2df0d337..67c3895cb 100644 --- a/internal/tools/msg.go +++ b/internal/tools/msg.go @@ -46,10 +46,12 @@ type MsgTool struct { userDatabase controller.UserDatabase groupDatabase controller.GroupDatabase msgNotificationSender *notification.MsgNotificationSender + Config *config.GlobalConfig } func NewMsgTool(msgDatabase controller.CommonMsgDatabase, userDatabase controller.UserDatabase, - groupDatabase controller.GroupDatabase, conversationDatabase controller.ConversationDatabase, msgNotificationSender *notification.MsgNotificationSender, + groupDatabase controller.GroupDatabase, conversationDatabase controller.ConversationDatabase, + msgNotificationSender *notification.MsgNotificationSender, config *config.GlobalConfig, ) *MsgTool { return &MsgTool{ msgDatabase: msgDatabase, @@ -57,32 +59,33 @@ func NewMsgTool(msgDatabase controller.CommonMsgDatabase, userDatabase controlle groupDatabase: groupDatabase, conversationDatabase: conversationDatabase, msgNotificationSender: msgNotificationSender, + Config: config, } } -func InitMsgTool() (*MsgTool, error) { - rdb, err := cache.NewRedis() +func InitMsgTool(config *config.GlobalConfig) (*MsgTool, error) { + rdb, err := cache.NewRedis(config) if err != nil { return nil, err } - mongo, err := unrelation.NewMongo() + mongo, err := unrelation.NewMongo(config) if err != nil { return nil, err } - discov, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) + discov, err := kdisc.NewDiscoveryRegister(config) if err != nil { return nil, err } discov.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) - userDB, err := mgo.NewUserMongo(mongo.GetDatabase()) + userDB, err := mgo.NewUserMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return nil, err } - msgDatabase, err := controller.InitCommonMsgDatabase(rdb, mongo.GetDatabase()) + msgDatabase, err := controller.InitCommonMsgDatabase(rdb, mongo.GetDatabase(config.Mongo.Database), config) if err != nil { return nil, err } - userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase()) + userMongoDB := unrelation.NewUserMongoDriver(mongo.GetDatabase(config.Mongo.Database)) ctxTx := tx.NewMongo(mongo.GetClient()) userDatabase := controller.NewUserDatabase( userDB, @@ -90,19 +93,19 @@ func InitMsgTool() (*MsgTool, error) { ctxTx, userMongoDB, ) - groupDB, err := mgo.NewGroupMongo(mongo.GetDatabase()) + groupDB, err := mgo.NewGroupMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return nil, err } - groupMemberDB, err := mgo.NewGroupMember(mongo.GetDatabase()) + groupMemberDB, err := mgo.NewGroupMember(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return nil, err } - groupRequestDB, err := mgo.NewGroupRequestMgo(mongo.GetDatabase()) + groupRequestDB, err := mgo.NewGroupRequestMgo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return nil, err } - conversationDB, err := mgo.NewConversationMongo(mongo.GetDatabase()) + conversationDB, err := mgo.NewConversationMongo(mongo.GetDatabase(config.Mongo.Database)) if err != nil { return nil, err } @@ -112,9 +115,9 @@ func InitMsgTool() (*MsgTool, error) { cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), ctxTx, ) - msgRpcClient := rpcclient.NewMessageRpcClient(discov) - msgNotificationSender := notification.NewMsgNotificationSender(rpcclient.WithRpcClient(&msgRpcClient)) - msgTool := NewMsgTool(msgDatabase, userDatabase, groupDatabase, conversationDatabase, msgNotificationSender) + msgRpcClient := rpcclient.NewMessageRpcClient(discov, config) + msgNotificationSender := notification.NewMsgNotificationSender(config, rpcclient.WithRpcClient(&msgRpcClient)) + msgTool := NewMsgTool(msgDatabase, userDatabase, groupDatabase, conversationDatabase, msgNotificationSender, config) return msgTool, nil } @@ -176,8 +179,8 @@ func (c *MsgTool) AllConversationClearMsgAndFixSeq() { func (c *MsgTool) ClearConversationsMsg(ctx context.Context, conversationIDs []string) { for _, conversationID := range conversationIDs { - if err := c.msgDatabase.DeleteConversationMsgsAndSetMinSeq(ctx, conversationID, int64(config.Config.RetainChatRecords*24*60*60)); err != nil { - log.ZError(ctx, "DeleteUserSuperGroupMsgsAndSetMinSeq failed", err, "conversationID", conversationID, "DBRetainChatRecords", config.Config.RetainChatRecords) + if err := c.msgDatabase.DeleteConversationMsgsAndSetMinSeq(ctx, conversationID, int64(c.Config.RetainChatRecords*24*60*60)); err != nil { + log.ZError(ctx, "DeleteUserSuperGroupMsgsAndSetMinSeq failed", err, "conversationID", conversationID, "DBRetainChatRecords", c.Config.RetainChatRecords) } if err := c.checkMaxSeq(ctx, conversationID); err != nil { log.ZError(ctx, "fixSeq failed", err, "conversationID", conversationID) diff --git a/pkg/authverify/token.go b/pkg/authverify/token.go index 0a46af3ec..26c43532d 100644 --- a/pkg/authverify/token.go +++ b/pkg/authverify/token.go @@ -26,61 +26,60 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) -func Secret() jwt.Keyfunc { +func Secret(secret string) jwt.Keyfunc { return func(token *jwt.Token) (any, error) { - return []byte(config.Config.Secret), nil + return []byte(secret), nil } } -func CheckAccessV3(ctx context.Context, ownerUserID string) (err error) { +func CheckAccessV3(ctx context.Context, ownerUserID string, config *config.GlobalConfig) (err error) { opUserID := mcontext.GetOpUserID(ctx) - if len(config.Config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Config.Manager.UserID) { + if len(config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Manager.UserID) { return nil } - if utils.IsContain(opUserID, config.Config.IMAdmin.UserID) { + if utils.IsContain(opUserID, config.IMAdmin.UserID) { return nil } if opUserID == ownerUserID { return nil } - return errs.Wrap(errs.ErrNoPermission, "CheckAccessV3: no permission for user "+opUserID) + return errs.ErrNoPermission.Wrap("ownerUserID", ownerUserID) } -func IsAppManagerUid(ctx context.Context) bool { - return (len(config.Config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.Manager.UserID)) || - utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) +func IsAppManagerUid(ctx context.Context, config *config.GlobalConfig) bool { + return (len(config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Manager.UserID)) || + utils.IsContain(mcontext.GetOpUserID(ctx), config.IMAdmin.UserID) } -func CheckAdmin(ctx context.Context) error { - if len(config.Config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.Manager.UserID) { +func CheckAdmin(ctx context.Context, config *config.GlobalConfig) error { + if len(config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Manager.UserID) { return nil } - if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) { + if utils.IsContain(mcontext.GetOpUserID(ctx), config.IMAdmin.UserID) { return nil } return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx))) } - -func CheckIMAdmin(ctx context.Context) error { - if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) { +func CheckIMAdmin(ctx context.Context, config *config.GlobalConfig) error { + if utils.IsContain(mcontext.GetOpUserID(ctx), config.IMAdmin.UserID) { return nil } - if len(config.Config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.Manager.UserID) { + if len(config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Manager.UserID) { return nil } return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not CheckIMAdmin userID", mcontext.GetOpUserID(ctx))) } -func ParseRedisInterfaceToken(redisToken any) (*tokenverify.Claims, error) { - return tokenverify.GetClaimFromToken(string(redisToken.([]uint8)), Secret()) +func ParseRedisInterfaceToken(redisToken any, secret string) (*tokenverify.Claims, error) { + return tokenverify.GetClaimFromToken(string(redisToken.([]uint8)), Secret(secret)) } -func IsManagerUserID(opUserID string) bool { - return (len(config.Config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Config.Manager.UserID)) || utils.IsContain(opUserID, config.Config.IMAdmin.UserID) +func IsManagerUserID(opUserID string, config *config.GlobalConfig) bool { + return (len(config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Manager.UserID)) || utils.IsContain(opUserID, config.IMAdmin.UserID) } -func WsVerifyToken(token, userID string, platformID int) error { - claim, err := tokenverify.GetClaimFromToken(token, Secret()) +func WsVerifyToken(token, userID, secret string, platformID int) error { + claim, err := tokenverify.GetClaimFromToken(token, Secret(secret)) if err != nil { return err } diff --git a/pkg/common/cmd/api.go b/pkg/common/cmd/api.go index 47e4116e3..859508ce3 100644 --- a/pkg/common/cmd/api.go +++ b/pkg/common/cmd/api.go @@ -15,54 +15,44 @@ package cmd import ( - "errors" - "fmt" - "github.com/OpenIMSDK/protocol/constant" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/internal/api" "github.com/spf13/cobra" + + "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) type ApiCmd struct { *RootCmd + initFunc func(config *config.GlobalConfig, port int, promPort int) error } func NewApiCmd() *ApiCmd { - ret := &ApiCmd{NewRootCmd("api")} + ret := &ApiCmd{RootCmd: NewRootCmd("api"), initFunc: api.Start} ret.SetRootCmdPt(ret) - + ret.addPreRun() + ret.addRunE() return ret } -// AddApi configures the API command to run with specified ports for the API and Prometheus monitoring. -// It ensures error handling for port retrieval and only proceeds if both port numbers are successfully obtained. -func (a *ApiCmd) AddApi(f func(port int, promPort int) error) { - a.Command.RunE = func(cmd *cobra.Command, args []string) error { - port, err := a.getPortFlag(cmd) - if err != nil { - return err - } - - promPort, err := a.getPrometheusPortFlag(cmd) - if err != nil { - return err - } +func (a *ApiCmd) addPreRun() { + a.Command.PreRun = func(cmd *cobra.Command, args []string) { + a.port = a.getPortFlag(cmd) + a.prometheusPort = a.getPrometheusPortFlag(cmd) + } +} - return f(port, promPort) +func (a *ApiCmd) addRunE() { + a.Command.RunE = func(cmd *cobra.Command, args []string) error { + return a.initFunc(a.config, a.port, a.prometheusPort) } } -func (a *ApiCmd) GetPortFromConfig(portType string) (int, error) { +func (a *ApiCmd) GetPortFromConfig(portType string) int { if portType == constant.FlagPort { - if len(config2.Config.Api.OpenImApiPort) > 0 { - return config2.Config.Api.OpenImApiPort[0], nil - } - return 0, errors.New("API port configuration is empty or missing") + return a.config.Api.OpenImApiPort[0] } else if portType == constant.FlagPrometheusPort { - if len(config2.Config.Prometheus.ApiPrometheusPort) > 0 { - return config2.Config.Prometheus.ApiPrometheusPort[0], nil - } - return 0, errors.New("Prometheus port configuration is empty or missing") + return a.config.Prometheus.ApiPrometheusPort[0] } - return 0, fmt.Errorf("unknown port type: %s", portType) + return 0 } diff --git a/pkg/common/cmd/cron_task.go b/pkg/common/cmd/cron_task.go index fa7a46351..d8c9dd2a8 100644 --- a/pkg/common/cmd/cron_task.go +++ b/pkg/common/cmd/cron_task.go @@ -14,29 +14,35 @@ package cmd -import "github.com/spf13/cobra" +import ( + "github.com/openimsdk/open-im-server/v3/internal/tools" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/spf13/cobra" +) type CronTaskCmd struct { *RootCmd + initFunc func(config *config.GlobalConfig) error } func NewCronTaskCmd() *CronTaskCmd { - ret := &CronTaskCmd{NewRootCmd("cronTask", WithCronTaskLogName())} + ret := &CronTaskCmd{RootCmd: NewRootCmd("cronTask", WithCronTaskLogName()), + initFunc: tools.StartTask} + ret.addRunE() ret.SetRootCmdPt(ret) return ret } -func (c *CronTaskCmd) addRunE(f func() error) { +func (c *CronTaskCmd) addRunE() { c.Command.RunE = func(cmd *cobra.Command, args []string) error { - return f() + return c.initFunc(c.config) } } -func (c *CronTaskCmd) Exec(f func() error) error { - c.addRunE(f) +func (c *CronTaskCmd) Exec() error { return c.Execute() } -func (c *CronTaskCmd) GetPortFromConfig(portType string) (int, error) { - return 0, nil +func (c *CronTaskCmd) GetPortFromConfig(portType string) int { + return 0 } diff --git a/pkg/common/cmd/msg_gateway.go b/pkg/common/cmd/msg_gateway.go index 403d0ec84..3a939fa97 100644 --- a/pkg/common/cmd/msg_gateway.go +++ b/pkg/common/cmd/msg_gateway.go @@ -15,13 +15,13 @@ package cmd import ( - "errors" + "log" + + "github.com/spf13/cobra" "github.com/OpenIMSDK/protocol/constant" - "github.com/OpenIMSDK/tools/errs" + "github.com/openimsdk/open-im-server/v3/internal/msggateway" - v3config "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/spf13/cobra" ) type MsgGatewayCmd struct { @@ -30,6 +30,7 @@ type MsgGatewayCmd struct { func NewMsgGatewayCmd() *MsgGatewayCmd { ret := &MsgGatewayCmd{NewRootCmd("msgGateway")} + ret.addRunE() ret.SetRootCmdPt(ret) return ret } @@ -38,67 +39,39 @@ func (m *MsgGatewayCmd) AddWsPortFlag() { m.Command.Flags().IntP(constant.FlagWsPort, "w", 0, "ws server listen port") } -func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) (int, error) { +func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int { port, err := cmd.Flags().GetInt(constant.FlagWsPort) if err != nil { - return 0, errs.Wrap(err, "error getting ws port flag") + log.Println("Error getting ws port flag:", err) } if port == 0 { - port, _ = m.PortFromConfig(constant.FlagWsPort) + port = m.PortFromConfig(constant.FlagWsPort) } - return port, nil + return port } func (m *MsgGatewayCmd) addRunE() { m.Command.RunE = func(cmd *cobra.Command, args []string) error { - wsPort, err := m.getWsPortFlag(cmd) - if err != nil { - return errs.Wrap(err, "failed to get WS port flag") - } - port, err := m.getPortFlag(cmd) - if err != nil { - return err - } - prometheusPort, err := m.getPrometheusPortFlag(cmd) - if err != nil { - return err - } - return msggateway.RunWsAndServer(port, wsPort, prometheusPort) + return msggateway.RunWsAndServer(m.config, m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd)) } } func (m *MsgGatewayCmd) Exec() error { - m.addRunE() return m.Execute() } -func (m *MsgGatewayCmd) GetPortFromConfig(portType string) (int, error) { - var port int - var exists bool - +func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int { switch portType { case constant.FlagWsPort: - if len(v3config.Config.LongConnSvr.OpenImWsPort) > 0 { - port = v3config.Config.LongConnSvr.OpenImWsPort[0] - exists = true - } + return m.config.LongConnSvr.OpenImWsPort[0] case constant.FlagPort: - if len(v3config.Config.LongConnSvr.OpenImMessageGatewayPort) > 0 { - port = v3config.Config.LongConnSvr.OpenImMessageGatewayPort[0] - exists = true - } + return m.config.LongConnSvr.OpenImMessageGatewayPort[0] case constant.FlagPrometheusPort: - if len(v3config.Config.Prometheus.MessageGatewayPrometheusPort) > 0 { - port = v3config.Config.Prometheus.MessageGatewayPrometheusPort[0] - exists = true - } - } + return m.config.Prometheus.MessageGatewayPrometheusPort[0] - if !exists { - return 0, errs.Wrap(errors.New("port type '%s' not found in configuration"), portType) + default: + return 0 } - - return port, nil } diff --git a/pkg/common/cmd/msg_gateway_test.go b/pkg/common/cmd/msg_gateway_test.go index 106ad74ec..c0ea2b057 100644 --- a/pkg/common/cmd/msg_gateway_test.go +++ b/pkg/common/cmd/msg_gateway_test.go @@ -44,7 +44,7 @@ func TestMsgGatewayCmd_GetPortFromConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.portType, func(t *testing.T) { - got, _ := msgGatewayCmd.GetPortFromConfig(tt.portType) + got := msgGatewayCmd.GetPortFromConfig(tt.portType) assert.Equal(t, tt.want, got) }) } diff --git a/pkg/common/cmd/msg_transfer.go b/pkg/common/cmd/msg_transfer.go index 4db24fac5..e46b66b52 100644 --- a/pkg/common/cmd/msg_transfer.go +++ b/pkg/common/cmd/msg_transfer.go @@ -16,11 +16,10 @@ package cmd import ( "fmt" - "github.com/OpenIMSDK/protocol/constant" - "github.com/openimsdk/open-im-server/v3/internal/msgtransfer" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/spf13/cobra" + + "github.com/openimsdk/open-im-server/v3/internal/msgtransfer" ) type MsgTransferCmd struct { @@ -29,37 +28,29 @@ type MsgTransferCmd struct { func NewMsgTransferCmd() *MsgTransferCmd { ret := &MsgTransferCmd{NewRootCmd("msgTransfer")} + ret.addRunE() ret.SetRootCmdPt(ret) return ret } func (m *MsgTransferCmd) addRunE() { m.Command.RunE = func(cmd *cobra.Command, args []string) error { - prometheusPort, err := m.getPrometheusPortFlag(cmd) - if err != nil { - return err - } - return msgtransfer.StartTransfer(prometheusPort) + return msgtransfer.StartTransfer(m.config, m.getPrometheusPortFlag(cmd)) } } func (m *MsgTransferCmd) Exec() error { - m.addRunE() return m.Execute() } -func (m *MsgTransferCmd) GetPortFromConfig(portType string) (int, error) { +func (m *MsgTransferCmd) GetPortFromConfig(portType string) int { if portType == constant.FlagPort { - return 0, nil + return 0 } else if portType == constant.FlagPrometheusPort { n := m.getTransferProgressFlagValue() - - if n < len(config2.Config.Prometheus.MessageTransferPrometheusPort) { - return config2.Config.Prometheus.MessageTransferPrometheusPort[n], nil - } - return 0, fmt.Errorf("index out of range for MessageTransferPrometheusPort with index %d", n) + return m.config.Prometheus.MessageTransferPrometheusPort[n] } - return 0, fmt.Errorf("unknown port type: %s", portType) + return 0 } func (m *MsgTransferCmd) AddTransferProgressFlag() { @@ -67,10 +58,10 @@ func (m *MsgTransferCmd) AddTransferProgressFlag() { } func (m *MsgTransferCmd) getTransferProgressFlagValue() int { - nindex, err := m.Command.Flags().GetInt(constant.FlagTransferProgressIndex) + nIndex, err := m.Command.Flags().GetInt(constant.FlagTransferProgressIndex) if err != nil { - fmt.Println("get transfercmd error,make sure it is k8s env or not") + fmt.Println("get transfer cmd error,make sure it is k8s env or not") return 0 } - return nindex + return nIndex } diff --git a/pkg/common/cmd/msg_utils.go b/pkg/common/cmd/msg_utils.go index 03c1cab67..df15acd87 100644 --- a/pkg/common/cmd/msg_utils.go +++ b/pkg/common/cmd/msg_utils.go @@ -22,6 +22,7 @@ import ( type MsgUtilsCmd struct { cobra.Command + MsgTool *tools.MsgTool } func (m *MsgUtilsCmd) AddUserIDFlag() { @@ -135,7 +136,7 @@ func NewSeqCmd() *SeqCmd { func (s *SeqCmd) GetSeqCmd() *cobra.Command { s.Command.Run = func(cmdLines *cobra.Command, args []string) { - _, err := tools.InitMsgTool() + _, err := tools.InitMsgTool(s.MsgTool.Config) if err != nil { util.ExitWithError(err) } diff --git a/pkg/common/cmd/root.go b/pkg/common/cmd/root.go index 591bfc804..478942a5b 100644 --- a/pkg/common/cmd/root.go +++ b/pkg/common/cmd/root.go @@ -26,7 +26,7 @@ import ( ) type RootCmdPt interface { - GetPortFromConfig(portType string) (int, error) + GetPortFromConfig(portType string) int } type RootCmd struct { @@ -35,6 +35,11 @@ type RootCmd struct { port int prometheusPort int cmdItf RootCmdPt + config *config.GlobalConfig +} + +func (rc *RootCmd) Port() int { + return rc.port } type CmdOpts struct { @@ -54,7 +59,7 @@ func WithLogName(logName string) func(*CmdOpts) { } func NewRootCmd(name string, opts ...func(*CmdOpts)) *RootCmd { - rootCmd := &RootCmd{Name: name} + rootCmd := &RootCmd{Name: name, config: config.NewGlobalConfig()} cmd := cobra.Command{ Use: "Start openIM application", Short: fmt.Sprintf(`Start %s `, name), @@ -96,7 +101,7 @@ func (rc *RootCmd) applyOptions(opts ...func(*CmdOpts)) *CmdOpts { } func (rc *RootCmd) initializeLogger(cmdOpts *CmdOpts) error { - logConfig := config.Config.Log + logConfig := rc.config.Log return log.InitFromConfig( @@ -129,41 +134,36 @@ func (r *RootCmd) AddPortFlag() { r.Command.Flags().IntP(constant.FlagPort, "p", 0, "server listen port") } -func (r *RootCmd) getPortFlag(cmd *cobra.Command) (int, error) { +func (r *RootCmd) getPortFlag(cmd *cobra.Command) int { port, err := cmd.Flags().GetInt(constant.FlagPort) if err != nil { // Wrapping the error with additional context - return 0, errs.Wrap(err, "error getting port flag") + return 0 } if port == 0 { - port, _ = r.PortFromConfig(constant.FlagPort) - // port, err := r.PortFromConfig(constant.FlagPort) - // if err != nil { - // // Optionally wrap the error if it's an internal error needing context - // return 0, errs.Wrap(err, "error getting port from config") - // } + port = r.PortFromConfig(constant.FlagPort) } - return port, nil + return port } // // GetPortFlag returns the port flag. -func (r *RootCmd) GetPortFlag() (int, error) { - return r.port, nil +func (r *RootCmd) GetPortFlag() int { + return r.port } func (r *RootCmd) AddPrometheusPortFlag() { r.Command.Flags().IntP(constant.FlagPrometheusPort, "", 0, "server prometheus listen port") } -func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) (int, error) { +func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) int { port, err := cmd.Flags().GetInt(constant.FlagPrometheusPort) if err != nil || port == 0 { - port, err = r.PortFromConfig(constant.FlagPrometheusPort) + port = r.PortFromConfig(constant.FlagPrometheusPort) if err != nil { - return 0, err + return 0 } } - return port, nil + return port } func (r *RootCmd) GetPrometheusPortFlag() int { @@ -173,7 +173,7 @@ func (r *RootCmd) GetPrometheusPortFlag() int { func (r *RootCmd) getConfFromCmdAndInit(cmdLines *cobra.Command) error { configFolderPath, _ := cmdLines.Flags().GetString(constant.FlagConf) fmt.Println("The directory of the configuration file to start the process:", configFolderPath) - return config2.InitConfig(configFolderPath) + return config2.InitConfig(r.config, configFolderPath) } func (r *RootCmd) Execute() error { @@ -184,11 +184,8 @@ func (r *RootCmd) AddCommand(cmds ...*cobra.Command) { r.Command.AddCommand(cmds...) } -func (r *RootCmd) PortFromConfig(portType string) (int, error) { +func (r *RootCmd) PortFromConfig(portType string) int { // Retrieve the port and cache it - port, err := r.cmdItf.GetPortFromConfig(portType) - if err != nil { - return 0, err - } - return port, nil + port := r.cmdItf.GetPortFromConfig(portType) + return port } diff --git a/pkg/common/cmd/rpc.go b/pkg/common/cmd/rpc.go index e30de93a9..5199524e7 100644 --- a/pkg/common/cmd/rpc.go +++ b/pkg/common/cmd/rpc.go @@ -16,100 +16,144 @@ package cmd import ( "errors" - "fmt" + "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/protocol/constant" - "github.com/OpenIMSDK/tools/discoveryregistry" - "github.com/OpenIMSDK/tools/errs" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/pkg/common/startrpc" "github.com/spf13/cobra" "google.golang.org/grpc" + + config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" + + "github.com/OpenIMSDK/tools/discoveryregistry" + + "github.com/openimsdk/open-im-server/v3/pkg/common/startrpc" ) +type rpcInitFuc func(config *config2.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error + type RpcCmd struct { *RootCmd + RpcRegisterName string + initFunc rpcInitFuc } -func NewRpcCmd(name string) *RpcCmd { - ret := &RpcCmd{NewRootCmd(name)} +func NewRpcCmd(name string, initFunc rpcInitFuc) *RpcCmd { + ret := &RpcCmd{RootCmd: NewRootCmd(name), initFunc: initFunc} + ret.addPreRun() + ret.addRunE() ret.SetRootCmdPt(ret) return ret } -func (a *RpcCmd) Exec() error { - a.Command.RunE = func(cmd *cobra.Command, args []string) error { - portFlag, err := a.getPortFlag(cmd) - if err != nil { - return err - } - a.port = portFlag +func (a *RpcCmd) addPreRun() { + a.Command.PreRun = func(cmd *cobra.Command, args []string) { + a.port = a.getPortFlag(cmd) + a.prometheusPort = a.getPrometheusPortFlag(cmd) + } +} - prometheusPort, err := a.getPrometheusPortFlag(cmd) +func (a *RpcCmd) addRunE() { + a.Command.RunE = func(cmd *cobra.Command, args []string) error { + rpcRegisterName, err := a.GetRpcRegisterNameFromConfig() if err != nil { return err + } else { + return a.StartSvr(rpcRegisterName, a.initFunc) } - a.prometheusPort = prometheusPort - - return nil } - return a.Execute() } -func (a *RpcCmd) StartSvr(name string, rpcFn func(discov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error { - portFlag, err := a.GetPortFlag() - if err != nil { - return err - } else { - a.port = portFlag - } - - return startrpc.Start(portFlag, name, a.GetPrometheusPortFlag(), rpcFn) +func (a *RpcCmd) Exec() error { + return a.Execute() } -func (a *RpcCmd) GetPortFromConfig(portType string) (int, error) { - portConfigMap := map[string]map[string]int{ - RpcPushServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImPushPort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.PushPrometheusPort[0], - }, - RpcAuthServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImAuthPort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.AuthPrometheusPort[0], - }, - RpcConversationServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImConversationPort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.ConversationPrometheusPort[0], - }, - RpcFriendServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImFriendPort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.FriendPrometheusPort[0], - }, - RpcGroupServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImGroupPort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.GroupPrometheusPort[0], - }, - RpcMsgServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImMessagePort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.MessagePrometheusPort[0], - }, - RpcThirdServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImThirdPort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.ThirdPrometheusPort[0], - }, - RpcUserServer: { - constant.FlagPort: config2.Config.RpcPort.OpenImUserPort[0], - constant.FlagPrometheusPort: config2.Config.Prometheus.UserPrometheusPort[0], - }, +func (a *RpcCmd) StartSvr(name string, rpcFn func(config *config2.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error { + if a.GetPortFlag() == 0 { + return errs.Wrap(errors.New("port is required")) } + return startrpc.Start(a.GetPortFlag(), name, a.GetPrometheusPortFlag(), a.config, rpcFn) +} - if portMap, ok := portConfigMap[a.Name]; ok { - if port, ok := portMap[portType]; ok { - return port, nil - } else { - return 0, errs.Wrap(errors.New("port type not found"), fmt.Sprintf("Failed to get port for %s", a.Name)) +func (a *RpcCmd) GetPortFromConfig(portType string) int { + switch a.Name { + case RpcPushServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImPushPort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.PushPrometheusPort[0] + } + case RpcAuthServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImAuthPort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.AuthPrometheusPort[0] + } + case RpcConversationServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImConversationPort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.ConversationPrometheusPort[0] + } + case RpcFriendServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImFriendPort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.FriendPrometheusPort[0] + } + case RpcGroupServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImGroupPort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.GroupPrometheusPort[0] + } + case RpcMsgServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImMessagePort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.MessagePrometheusPort[0] + } + case RpcThirdServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImThirdPort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.ThirdPrometheusPort[0] + } + case RpcUserServer: + if portType == constant.FlagPort { + return a.config.RpcPort.OpenImUserPort[0] + } + if portType == constant.FlagPrometheusPort { + return a.config.Prometheus.UserPrometheusPort[0] } } + return 0 +} - return 0, errs.Wrap(fmt.Errorf("server name '%s' not found", a.Name), "Failed to get port configuration") +func (a *RpcCmd) GetRpcRegisterNameFromConfig() (string, error) { + switch a.Name { + case RpcPushServer: + return a.config.RpcRegisterName.OpenImPushName, nil + case RpcAuthServer: + return a.config.RpcRegisterName.OpenImAuthName, nil + case RpcConversationServer: + return a.config.RpcRegisterName.OpenImConversationName, nil + case RpcFriendServer: + return a.config.RpcRegisterName.OpenImFriendName, nil + case RpcGroupServer: + return a.config.RpcRegisterName.OpenImGroupName, nil + case RpcMsgServer: + return a.config.RpcRegisterName.OpenImMsgName, nil + case RpcThirdServer: + return a.config.RpcRegisterName.OpenImThirdName, nil + case RpcUserServer: + return a.config.RpcRegisterName.OpenImUserName, nil + } + return "", errs.Wrap(errors.New("can not get rpc register name"), a.Name) } diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index 7ee55d876..ac42395bf 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -21,7 +21,7 @@ import ( "gopkg.in/yaml.v3" ) -var Config configStruct +var Config GlobalConfig const ConfKey = "conf" @@ -57,7 +57,7 @@ type MYSQL struct { SlowThreshold int `yaml:"slowThreshold"` } -type configStruct struct { +type GlobalConfig struct { Envs struct { Discovery string `yaml:"discovery"` } @@ -339,6 +339,10 @@ type configStruct struct { Notification notification `yaml:"notification"` } +func NewGlobalConfig() *GlobalConfig { + return &GlobalConfig{} +} + type notification struct { GroupCreated NotificationConf `yaml:"groupCreated"` GroupInfoSet NotificationConf `yaml:"groupInfoSet"` @@ -378,7 +382,7 @@ type notification struct { ConversationSetPrivate NotificationConf `yaml:"conversationSetPrivate"` } -func (c *configStruct) GetServiceNames() []string { +func (c *GlobalConfig) GetServiceNames() []string { return []string{ c.RpcRegisterName.OpenImUserName, c.RpcRegisterName.OpenImFriendName, @@ -392,7 +396,7 @@ func (c *configStruct) GetServiceNames() []string { } } -func (c *configStruct) RegisterConf2Registry(registry discoveryregistry.SvcDiscoveryRegistry) error { +func (c *GlobalConfig) RegisterConf2Registry(registry discoveryregistry.SvcDiscoveryRegistry) error { data, err := yaml.Marshal(c) if err != nil { return err @@ -400,11 +404,11 @@ func (c *configStruct) RegisterConf2Registry(registry discoveryregistry.SvcDisco return registry.RegisterConf2Registry(ConfKey, data) } -func (c *configStruct) GetConfFromRegistry(registry discoveryregistry.SvcDiscoveryRegistry) ([]byte, error) { +func (c *GlobalConfig) GetConfFromRegistry(registry discoveryregistry.SvcDiscoveryRegistry) ([]byte, error) { return registry.GetConfFromRegistry(ConfKey) } -func (c *configStruct) EncodeConfig() []byte { +func (c *GlobalConfig) EncodeConfig() []byte { buf := bytes.NewBuffer(nil) if err := yaml.NewEncoder(buf).Encode(c); err != nil { panic(err) diff --git a/pkg/common/config/parse.go b/pkg/common/config/parse.go index a73665386..bfbf6daf7 100644 --- a/pkg/common/config/parse.go +++ b/pkg/common/config/parse.go @@ -21,10 +21,10 @@ import ( "path/filepath" "github.com/OpenIMSDK/protocol/constant" - "github.com/OpenIMSDK/tools/errs" + "gopkg.in/yaml.v3" + "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" - "gopkg.in/yaml.v3" ) //go:embed version @@ -36,38 +36,32 @@ const ( DefaultFolderPath = "../config/" ) -// GetDefaultConfigPath returns the absolute path to the default configuration directory -// relative to the executable's location. It is intended for use in Kubernetes container configurations. -// Errors are returned to the caller to allow for flexible error handling. -func GetDefaultConfigPath() (string, error) { +// return absolude path join ../config/, this is k8s container config path. +func GetDefaultConfigPath() string { executablePath, err := os.Executable() if err != nil { - return "", errs.Wrap(err, "failed to get executable path") + fmt.Println("GetDefaultConfigPath error:", err.Error()) + return "" } - // Calculate the config path as a directory relative to the executable's location configPath, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../config/")) if err != nil { - return "", errs.Wrap(err, "failed to get output directory") + fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err) + os.Exit(1) } - return configPath, nil + return configPath } -// GetProjectRoot returns the absolute path of the project root directory by navigating up from the directory -// containing the executable. It provides a detailed error if the path cannot be determined. -func GetProjectRoot() (string, error) { - executablePath, err := os.Executable() - if err != nil { - return "", errs.Wrap(err, "failed to retrieve executable path") - } +// getProjectRoot returns the absolute path of the project root directory. +func GetProjectRoot() string { + executablePath, _ := os.Executable() - // Attempt to compute the project root by navigating up from the executable's directory projectRoot, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../../../../..")) if err != nil { - return "", err + fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err) + os.Exit(1) } - - return projectRoot, nil + return projectRoot } func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options { @@ -93,62 +87,41 @@ func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options { // If the specified config file does not exist, it attempts to load from the project's default "config" directory. // It logs informative messages regarding the configuration path being used. func initConfig(config any, configName, configFolderPath string) error { - configFilePath := filepath.Join(configFolderPath, configName) - _, err := os.Stat(configFilePath) + configFolderPath = filepath.Join(configFolderPath, configName) + _, err := os.Stat(configFolderPath) if err != nil { if !os.IsNotExist(err) { - return errs.Wrap(err, fmt.Sprintf("failed to check existence of config file at path: %s", configFilePath)) + fmt.Println("stat config path error:", err.Error()) + return fmt.Errorf("stat config path error: %w", err) } - var projectRoot string - projectRoot, err = GetProjectRoot() - if err != nil { - return err - } - configFilePath = filepath.Join(projectRoot, "config", configName) - fmt.Printf("Configuration file not found at specified path. Falling back to project path: %s\n", configFilePath) + configFolderPath = filepath.Join(GetProjectRoot(), "config", configName) + fmt.Println("flag's path,enviment's path,default path all is not exist,using project path:", configFolderPath) } - - data, err := os.ReadFile(configFilePath) + data, err := os.ReadFile(configFolderPath) if err != nil { - // Wrap and return the error if reading the configuration file fails. - return errs.Wrap(err, fmt.Sprintf("failed to read configuration file at path: %s", configFilePath)) + return fmt.Errorf("read file error: %w", err) } - if err = yaml.Unmarshal(data, config); err != nil { - // Wrap and return the error if unmarshalling the YAML configuration fails. - return errs.Wrap(err, "failed to unmarshal YAML configuration") + return fmt.Errorf("unmarshal yaml error: %w", err) } + fmt.Println("The path of the configuration file to start the process:", configFolderPath) - fmt.Printf("Configuration file loaded successfully from path: %s\n", configFilePath) return nil } -// InitConfig initializes the application configuration by loading it from a specified folder path. -// If the folder path is not provided, it attempts to use the OPENIMCONFIG environment variable, -// and as a fallback, it uses the default configuration path. It loads both the main configuration -// and notification configuration, wrapping errors for better context. -func InitConfig(configFolderPath string) error { - // Use the provided config folder path, or fallback to environment variable or default path +func InitConfig(config *GlobalConfig, configFolderPath string) error { if configFolderPath == "" { - configFolderPath = os.Getenv("OPENIMCONFIG") - if configFolderPath == "" { - var err error - configFolderPath, err = GetDefaultConfigPath() - if err != nil { - return err - } + envConfigPath := os.Getenv("OPENIMCONFIG") + if envConfigPath != "" { + configFolderPath = envConfigPath + } else { + configFolderPath = GetDefaultConfigPath() } } - // Initialize the main configuration - if err := initConfig(&Config, FileName, configFolderPath); err != nil { + if err := initConfig(config, FileName, configFolderPath); err != nil { return err } - // Initialize the notification configuration - if err := initConfig(&Config.Notification, NotificationFileName, configFolderPath); err != nil { - return err - } - - return nil + return initConfig(&config.Notification, NotificationFileName, configFolderPath) } diff --git a/pkg/common/config/parse_test.go b/pkg/common/config/parse_test.go index b980de7bd..84dee1165 100644 --- a/pkg/common/config/parse_test.go +++ b/pkg/common/config/parse_test.go @@ -103,13 +103,14 @@ func TestInitConfig(t *testing.T) { tests := []struct { name string args args + config *GlobalConfig wantErr bool }{ // TODO: Add test cases. } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := InitConfig(tt.args.configFolderPath); (err != nil) != tt.wantErr { + if err := InitConfig(tt.config, tt.args.configFolderPath); (err != nil) != tt.wantErr { t.Errorf("InitConfig() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/pkg/common/db/cache/init_redis.go b/pkg/common/db/cache/init_redis.go index b81edf9e5..8f4e2c592 100644 --- a/pkg/common/db/cache/init_redis.go +++ b/pkg/common/db/cache/init_redis.go @@ -38,34 +38,34 @@ const ( ) // NewRedis Initialize redis connection. -func NewRedis() (redis.UniversalClient, error) { +func NewRedis(config *config.GlobalConfig) (redis.UniversalClient, error) { if redisClient != nil { return redisClient, nil } // Read configuration from environment variables - overrideConfigFromEnv() + overrideConfigFromEnv(config) - if len(config.Config.Redis.Address) == 0 { - return nil, errs.Wrap(errors.New("redis address is empty"), "Redis configuration error") + if len(config.Redis.Address) == 0 { + return nil, errs.Wrap(errors.New("redis address is empty")) } specialerror.AddReplace(redis.Nil, errs.ErrRecordNotFound) var rdb redis.UniversalClient - if len(config.Config.Redis.Address) > 1 || config.Config.Redis.ClusterMode { + if len(config.Redis.Address) > 1 || config.Redis.ClusterMode { rdb = redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: config.Config.Redis.Address, - Username: config.Config.Redis.Username, - Password: config.Config.Redis.Password, // no password set + Addrs: config.Redis.Address, + Username: config.Redis.Username, + Password: config.Redis.Password, // no password set PoolSize: 50, MaxRetries: maxRetry, }) } else { rdb = redis.NewClient(&redis.Options{ - Addr: config.Config.Redis.Address[0], - Username: config.Config.Redis.Username, - Password: config.Config.Redis.Password, // no password set - DB: 0, // use default DB - PoolSize: 100, // connection pool size + Addr: config.Redis.Address[0], + Username: config.Redis.Username, + Password: config.Redis.Password, + DB: 0, // use default DB + PoolSize: 100, // connection pool size MaxRetries: maxRetry, }) } @@ -75,33 +75,33 @@ func NewRedis() (redis.UniversalClient, error) { defer cancel() err = rdb.Ping(ctx).Err() if err != nil { - uriFormat := "address:%v, username:%s, clusterMode:%t, enablePipeline:%t" - errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline) - return nil, errs.Wrap(err, "Redis connection failed: %s", errMsg) + errMsg := fmt.Sprintf("address:%s, username:%s, password:%s, clusterMode:%t, enablePipeline:%t", config.Redis.Address, config.Redis.Username, + config.Redis.Password, config.Redis.ClusterMode, config.Redis.EnablePipeline) + return nil, errs.Wrap(err, errMsg) } redisClient = rdb return rdb, err } // overrideConfigFromEnv overrides configuration fields with environment variables if present. -func overrideConfigFromEnv() { +func overrideConfigFromEnv(config *config.GlobalConfig) { if envAddr := os.Getenv("REDIS_ADDRESS"); envAddr != "" { if envPort := os.Getenv("REDIS_PORT"); envPort != "" { addresses := strings.Split(envAddr, ",") for i, addr := range addresses { addresses[i] = addr + ":" + envPort } - config.Config.Redis.Address = addresses + config.Redis.Address = addresses } else { - config.Config.Redis.Address = strings.Split(envAddr, ",") + config.Redis.Address = strings.Split(envAddr, ",") } } if envUser := os.Getenv("REDIS_USERNAME"); envUser != "" { - config.Config.Redis.Username = envUser + config.Redis.Username = envUser } if envPass := os.Getenv("REDIS_PASSWORD"); envPass != "" { - config.Config.Redis.Password = envPass + config.Redis.Password = envPass } } diff --git a/pkg/common/db/cache/meta_cache.go b/pkg/common/db/cache/meta_cache.go index 4c25754d6..1fbb6c3b6 100644 --- a/pkg/common/db/cache/meta_cache.go +++ b/pkg/common/db/cache/meta_cache.go @@ -18,13 +18,16 @@ import ( "context" "encoding/json" "errors" + "fmt" "time" + "github.com/OpenIMSDK/tools/mw/specialerror" + + "github.com/dtm-labs/rockscache" + "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" - "github.com/OpenIMSDK/tools/mw/specialerror" "github.com/OpenIMSDK/tools/utils" - "github.com/dtm-labs/rockscache" ) const ( @@ -128,7 +131,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin v, err := rcClient.Fetch2(ctx, key, expire, func() (s string, err error) { t, err = fn(ctx) if err != nil { - return "", err + return "", errs.Wrap(err) } bs, err := json.Marshal(t) if err != nil { @@ -139,7 +142,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin return string(bs), nil }) if err != nil { - return t, err + return t, errs.Wrap(err) } if write { return t, nil @@ -149,8 +152,8 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin } err = json.Unmarshal([]byte(v), &t) if err != nil { - log.ZError(ctx, "cache json.Unmarshal failed", err, "key", key, "value", v, "expire", expire) - return t, errs.Wrap(err, "unmarshal failed") + errInfo := fmt.Sprintf("cache json.Unmarshal failed, key:%s, value:%s, expire:%s", key, v, expire) + return t, errs.Wrap(err, errInfo) } return t, nil @@ -203,7 +206,7 @@ func batchGetCache2[T any, K comparable]( fns func(ctx context.Context, key K) (T, error), ) ([]T, error) { if len(keys) == 0 { - return nil, nil + return nil, errs.ErrArgs.Wrap("groupID is empty") } res := make([]T, 0, len(keys)) for _, key := range keys { @@ -214,7 +217,7 @@ func batchGetCache2[T any, K comparable]( if errs.ErrRecordNotFound.Is(specialerror.ErrCode(errs.Unwrap(err))) { continue } - return nil, err + return nil, errs.Wrap(err) } res = append(res, val) } diff --git a/pkg/common/db/cache/msg.go b/pkg/common/db/cache/msg.go index 889f36baa..1266875f1 100644 --- a/pkg/common/db/cache/msg.go +++ b/pkg/common/db/cache/msg.go @@ -121,13 +121,14 @@ type MsgModel interface { UnLockMessageTypeKey(ctx context.Context, clientMsgID string, TypeKey string) error } -func NewMsgCacheModel(client redis.UniversalClient) MsgModel { - return &msgCache{rdb: client} +func NewMsgCacheModel(client redis.UniversalClient, config *config.GlobalConfig) MsgModel { + return &msgCache{rdb: client, config: config} } type msgCache struct { metaCache - rdb redis.UniversalClient + rdb redis.UniversalClient + config *config.GlobalConfig } func (c *msgCache) getMaxSeqKey(conversationID string) string { @@ -315,7 +316,7 @@ func (c *msgCache) allMessageCacheKey(conversationID string) string { } func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) { - if config.Config.Redis.EnablePipeline { + if c.config.Redis.EnablePipeline { return c.PipeGetMessagesBySeq(ctx, conversationID, seqs) } @@ -416,7 +417,7 @@ func (c *msgCache) ParallelGetMessagesBySeq(ctx context.Context, conversationID } func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { - if config.Config.Redis.EnablePipeline { + if c.config.Redis.EnablePipeline { return c.PipeSetMessageToCache(ctx, conversationID, msgs) } return c.ParallelSetMessageToCache(ctx, conversationID, msgs) @@ -431,7 +432,7 @@ func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID str } key := c.getMessageCacheKey(conversationID, msg.Seq) - _ = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second) + _ = pipe.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second) } results, err := pipe.Exec(ctx) @@ -461,7 +462,7 @@ func (c *msgCache) ParallelSetMessageToCache(ctx context.Context, conversationID } key := c.getMessageCacheKey(conversationID, msg.Seq) - if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil { + if err := c.rdb.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil { return errs.Wrap(err) } return nil @@ -496,10 +497,10 @@ func (c *msgCache) UserDeleteMsgs(ctx context.Context, conversationID string, se if err != nil { return errs.Wrap(err) } - if err := c.rdb.Expire(ctx, delUserListKey, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil { + if err := c.rdb.Expire(ctx, delUserListKey, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil { return errs.Wrap(err) } - if err := c.rdb.Expire(ctx, userDelListKey, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil { + if err := c.rdb.Expire(ctx, userDelListKey, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil { return errs.Wrap(err) } } @@ -604,7 +605,7 @@ func (c *msgCache) DelUserDeleteMsgsList(ctx context.Context, conversationID str } func (c *msgCache) DeleteMessages(ctx context.Context, conversationID string, seqs []int64) error { - if config.Config.Redis.EnablePipeline { + if c.config.Redis.EnablePipeline { return c.PipeDeleteMessages(ctx, conversationID, seqs) } @@ -686,7 +687,7 @@ func (c *msgCache) DelMsgFromCache(ctx context.Context, userID string, seqs []in if err != nil { return errs.Wrap(err) } - if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil { + if err := c.rdb.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil { return errs.Wrap(err) } } diff --git a/pkg/common/db/cache/user.go b/pkg/common/db/cache/user.go index c6c6966f3..34c220624 100644 --- a/pkg/common/db/cache/user.go +++ b/pkg/common/db/cache/user.go @@ -22,12 +22,16 @@ import ( "strconv" "time" + relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" + + "github.com/OpenIMSDK/tools/log" + "github.com/OpenIMSDK/protocol/constant" + "github.com/OpenIMSDK/protocol/user" "github.com/OpenIMSDK/tools/errs" - "github.com/OpenIMSDK/tools/log" + "github.com/dtm-labs/rockscache" - relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" "github.com/redis/go-redis/v9" ) @@ -62,7 +66,11 @@ type UserCacheRedis struct { rcClient *rockscache.Client } -func NewUserCacheRedis(rdb redis.UniversalClient, userDB relationtb.UserModelInterface, options rockscache.Options) UserCache { +func NewUserCacheRedis( + rdb redis.UniversalClient, + userDB relationtb.UserModelInterface, + options rockscache.Options, +) UserCache { rcClient := rockscache.NewClient(rdb, options) return &UserCacheRedis{ @@ -193,13 +201,13 @@ func (u *UserCacheRedis) SetUserStatus(ctx context.Context, userID string, statu Status: constant.Online, PlatformIDs: []int32{platformID}, } - jsonData, err2 := json.Marshal(&onlineStatus) - if err2 != nil { - return errs.Wrap(err2) + jsonData, err := json.Marshal(&onlineStatus) + if err != nil { + return errs.Wrap(err) } - _, err2 = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result() - if err2 != nil { - return errs.Wrap(err2) + _, err = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result() + if err != nil { + return errs.Wrap(err) } u.rdb.Expire(ctx, key, userOlineStatusExpireTime) @@ -273,9 +281,9 @@ func (u *UserCacheRedis) refreshStatusOffline(ctx context.Context, userID string func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string, platformID int32, isNil bool, err error, result, key string) error { var onlineStatus user.OnlineStatus if !isNil { - err2 := json.Unmarshal([]byte(result), &onlineStatus) - if err2 != nil { - return errs.Wrap(err, "json.Unmarshal failed") + err := json.Unmarshal([]byte(result), &onlineStatus) + if err != nil { + return errs.Wrap(err) } onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID)) } else { diff --git a/pkg/common/db/controller/auth.go b/pkg/common/db/controller/auth.go index dfd7b3e78..19bea5981 100644 --- a/pkg/common/db/controller/auth.go +++ b/pkg/common/db/controller/auth.go @@ -16,6 +16,7 @@ package controller import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/tools/errs" @@ -33,14 +34,14 @@ type AuthDatabase interface { } type authDatabase struct { - cache cache.MsgModel - + cache cache.MsgModel accessSecret string accessExpire int64 + config *config.GlobalConfig } -func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64) AuthDatabase { - return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire} +func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64, config *config.GlobalConfig) AuthDatabase { + return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, config: config} } // If the result is empty. @@ -56,7 +57,7 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI } var deleteTokenKey []string for k, v := range tokens { - _, err = tokenverify.GetClaimFromToken(k, authverify.Secret()) + _, err = tokenverify.GetClaimFromToken(k, authverify.Secret(a.config.Secret)) if err != nil || v != constant.NormalToken { deleteTokenKey = append(deleteTokenKey, k) } diff --git a/pkg/common/db/controller/msg.go b/pkg/common/db/controller/msg.go index 56cb52703..ccf209b7a 100644 --- a/pkg/common/db/controller/msg.go +++ b/pkg/common/db/controller/msg.go @@ -120,16 +120,33 @@ type CommonMsgDatabase interface { ConvertMsgsDocLen(ctx context.Context, conversationIDs []string) } -func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel) (CommonMsgDatabase, error) { - producerToRedis, err := kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.LatestMsgToRedis.Topic) +func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel, config *config.GlobalConfig) (CommonMsgDatabase, error) { + producerConfig := &kafka.ProducerConfig{ + ProducerAck: config.Kafka.ProducerAck, + CompressType: config.Kafka.CompressType, + Username: config.Kafka.Username, + Password: config.Kafka.Password, + } + + var tlsConfig *kafka.TLSConfig + if config.Kafka.TLS != nil { + tlsConfig = &kafka.TLSConfig{ + CACrt: config.Kafka.TLS.CACrt, + ClientCrt: config.Kafka.TLS.ClientCrt, + ClientKey: config.Kafka.TLS.ClientKey, + ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd, + InsecureSkipVerify: false, + } + } + producerToRedis, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.LatestMsgToRedis.Topic, producerConfig, tlsConfig) if err != nil { return nil, err } - producerToMongo, err := kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.MsgToMongo.Topic) + producerToMongo, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToMongo.Topic, producerConfig, tlsConfig) if err != nil { return nil, err } - producerToPush, err := kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.MsgToPush.Topic) + producerToPush, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToPush.Topic, producerConfig, tlsConfig) if err != nil { return nil, err } @@ -142,10 +159,10 @@ func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheMo }, nil } -func InitCommonMsgDatabase(rdb redis.UniversalClient, database *mongo.Database) (CommonMsgDatabase, error) { - cacheModel := cache.NewMsgCacheModel(rdb) +func InitCommonMsgDatabase(rdb redis.UniversalClient, database *mongo.Database, config *config.GlobalConfig) (CommonMsgDatabase, error) { + cacheModel := cache.NewMsgCacheModel(rdb, config) msgDocModel := unrelation.NewMsgMongoDriver(database) - return NewCommonMsgDatabase(msgDocModel, cacheModel) + return NewCommonMsgDatabase(msgDocModel, cacheModel, config) } type commonMsgDatabase struct { @@ -397,9 +414,9 @@ func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversa log.ZError(ctx, "db.cache.SetMaxSeq error", err, "conversationID", conversationID) prommetrics.SeqSetFailedCounter.Inc() } - err2 := db.cache.SetHasReadSeqs(ctx, conversationID, userSeqMap) + err = db.cache.SetHasReadSeqs(ctx, conversationID, userSeqMap) if err != nil { - log.ZError(ctx, "SetHasReadSeqs error", err2, "userSeqMap", userSeqMap, "conversationID", conversationID) + log.ZError(ctx, "SetHasReadSeqs error", err, "userSeqMap", userSeqMap, "conversationID", conversationID) prommetrics.SeqSetFailedCounter.Inc() } return lastMaxSeq, isNew, errs.Wrap(err) diff --git a/pkg/common/db/controller/msg_test.go b/pkg/common/db/controller/msg_test.go index 70c055bf3..4c2ab20da 100644 --- a/pkg/common/db/controller/msg_test.go +++ b/pkg/common/db/controller/msg_test.go @@ -33,27 +33,28 @@ import ( ) func Test_BatchInsertChat2DB(t *testing.T) { - config.Config.Mongo.Address = []string{"192.168.44.128:37017"} - // config.Config.Mongo.Timeout = 60 - config.Config.Mongo.Database = "openIM" - // config.Config.Mongo.Source = "admin" - config.Config.Mongo.Username = "root" - config.Config.Mongo.Password = "openIM123" - config.Config.Mongo.MaxPoolSize = 100 - config.Config.RetainChatRecords = 3650 - config.Config.ChatRecordsClearTime = "0 2 * * 3" - - mongo, err := unrelation.NewMongo() + conf := config.NewGlobalConfig() + conf.Mongo.Address = []string{"192.168.44.128:37017"} + // conf.Mongo.Timeout = 60 + conf.Mongo.Database = "openIM" + // conf.Mongo.Source = "admin" + conf.Mongo.Username = "root" + conf.Mongo.Password = "openIM123" + conf.Mongo.MaxPoolSize = 100 + conf.RetainChatRecords = 3650 + conf.ChatRecordsClearTime = "0 2 * * 3" + + mongo, err := unrelation.NewMongo(conf) if err != nil { t.Fatal(err) } - err = mongo.GetDatabase().Client().Ping(context.Background(), nil) + err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil) if err != nil { panic(err) } db := &commonMsgDatabase{ - msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()), + msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)), } //ctx := context.Background() @@ -70,7 +71,7 @@ func Test_BatchInsertChat2DB(t *testing.T) { //} _ = db.BatchInsertChat2DB - c := mongo.GetDatabase().Collection("msg") + c := mongo.GetDatabase(conf.Mongo.Database).Collection("msg") ch := make(chan int) rand.Seed(time.Now().UnixNano()) @@ -144,26 +145,27 @@ func Test_BatchInsertChat2DB(t *testing.T) { } func GetDB() *commonMsgDatabase { - config.Config.Mongo.Address = []string{"203.56.175.233:37017"} - // config.Config.Mongo.Timeout = 60 - config.Config.Mongo.Database = "openim_v3" - // config.Config.Mongo.Source = "admin" - config.Config.Mongo.Username = "root" - config.Config.Mongo.Password = "openIM123" - config.Config.Mongo.MaxPoolSize = 100 - config.Config.RetainChatRecords = 3650 - config.Config.ChatRecordsClearTime = "0 2 * * 3" - - mongo, err := unrelation.NewMongo() + conf := config.NewGlobalConfig() + conf.Mongo.Address = []string{"203.56.175.233:37017"} + // conf.Mongo.Timeout = 60 + conf.Mongo.Database = "openim_v3" + // conf.Mongo.Source = "admin" + conf.Mongo.Username = "root" + conf.Mongo.Password = "openIM123" + conf.Mongo.MaxPoolSize = 100 + conf.RetainChatRecords = 3650 + conf.ChatRecordsClearTime = "0 2 * * 3" + + mongo, err := unrelation.NewMongo(conf) if err != nil { panic(err) } - err = mongo.GetDatabase().Client().Ping(context.Background(), nil) + err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil) if err != nil { panic(err) } return &commonMsgDatabase{ - msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()), + msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)), } } diff --git a/pkg/common/db/s3/aws/aws.go b/pkg/common/db/s3/aws/aws.go deleted file mode 100644 index dd54ed155..000000000 --- a/pkg/common/db/s3/aws/aws.go +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright © 2023 OpenIM. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// docURL: https://docs.aws.amazon.com/AmazonS3/latest/API/Welcome.html - -package aws - -import ( - "context" - "errors" - "fmt" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - sdk "github.com/aws/aws-sdk-go/service/s3" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3" -) - -const ( - minPartSize int64 = 1024 * 1024 * 1 // 1MB - maxPartSize int64 = 1024 * 1024 * 1024 * 5 // 5GB - maxNumSize int64 = 10000 -) - -// const ( -// imagePng = "png" -// imageJpg = "jpg" -// imageJpeg = "jpeg" -// imageGif = "gif" -// imageWebp = "webp" -// ) - -// const successCode = http.StatusOK - -// const ( -// videoSnapshotImagePng = "png" -// videoSnapshotImageJpg = "jpg" -// ) - -func NewAWS() (s3.Interface, error) { - conf := config.Config.Object.Aws - credential := credentials.NewStaticCredentials( - conf.AccessKeyID, // accessKey - conf.AccessKeySecret, // secretKey - "") // stoken - - sess, err := session.NewSession(&aws.Config{ - Region: aws.String(conf.Region), // The area where the bucket is located - Credentials: credential, - }) - - if err != nil { - return nil, err - } - return &Aws{ - bucket: conf.Bucket, - client: sdk.New(sess), - credential: credential, - }, nil -} - -type Aws struct { - bucket string - client *sdk.S3 - credential *credentials.Credentials -} - -func (a *Aws) Engine() string { - return "aws" -} - -func (a *Aws) InitiateMultipartUpload(ctx context.Context, name string) (*s3.InitiateMultipartUploadResult, error) { - input := &sdk.CreateMultipartUploadInput{ - Bucket: aws.String(a.bucket), // TODO: To be verified whether it is required - Key: aws.String(name), - } - result, err := a.client.CreateMultipartUploadWithContext(ctx, input) - if err != nil { - return nil, err - } - return &s3.InitiateMultipartUploadResult{ - Bucket: *result.Bucket, - Key: *result.Key, - UploadID: *result.UploadId, - }, nil -} - -func (a *Aws) CompleteMultipartUpload(ctx context.Context, uploadID string, name string, parts []s3.Part) (*s3.CompleteMultipartUploadResult, error) { - sdkParts := make([]*sdk.CompletedPart, len(parts)) - for i, part := range parts { - sdkParts[i] = &sdk.CompletedPart{ - ETag: aws.String(part.ETag), - PartNumber: aws.Int64(int64(part.PartNumber)), - } - } - input := &sdk.CompleteMultipartUploadInput{ - Bucket: aws.String(a.bucket), // TODO: To be verified whether it is required - Key: aws.String(name), - UploadId: aws.String(uploadID), - MultipartUpload: &sdk.CompletedMultipartUpload{ - Parts: sdkParts, - }, - } - result, err := a.client.CompleteMultipartUploadWithContext(ctx, input) - if err != nil { - return nil, err - } - return &s3.CompleteMultipartUploadResult{ - Location: *result.Location, - Bucket: *result.Bucket, - Key: *result.Key, - ETag: *result.ETag, - }, nil -} - -func (a *Aws) PartSize(ctx context.Context, size int64) (int64, error) { - if size <= 0 { - return 0, errors.New("size must be greater than 0") - } - if size > maxPartSize*maxNumSize { - return 0, fmt.Errorf("AWS size must be less than the maximum allowed limit") - } - if size <= minPartSize*maxNumSize { - return minPartSize, nil - } - partSize := size / maxNumSize - if size%maxNumSize != 0 { - partSize++ - } - return partSize, nil -} - -func (a *Aws) DeleteObject(ctx context.Context, name string) error { - _, err := a.client.DeleteObjectWithContext(ctx, &sdk.DeleteObjectInput{ - Bucket: aws.String(a.bucket), - Key: aws.String(name), - }) - return err -} - -func (a *Aws) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) { - result, err := a.client.CopyObjectWithContext(ctx, &sdk.CopyObjectInput{ - Bucket: aws.String(a.bucket), - Key: aws.String(dst), - CopySource: aws.String(src), - }) - if err != nil { - return nil, err - } - return &s3.CopyObjectInfo{ - ETag: *result.CopyObjectResult.ETag, - Key: dst, - }, nil -} - -func (a *Aws) IsNotFound(err error) bool { - if err == nil { - return false - } - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case sdk.ErrCodeNoSuchKey: - return true - default: - return false - } - } - return false -} - -func (a *Aws) AbortMultipartUpload(ctx context.Context, uploadID string, name string) error { - _, err := a.client.AbortMultipartUploadWithContext(ctx, &sdk.AbortMultipartUploadInput{ - Bucket: aws.String(a.bucket), - Key: aws.String(name), - UploadId: aws.String(uploadID), - }) - return err -} - -func (a *Aws) ListUploadedParts(ctx context.Context, uploadID string, name string, partNumberMarker int, maxParts int) (*s3.ListUploadedPartsResult, error) { - result, err := a.client.ListPartsWithContext(ctx, &sdk.ListPartsInput{ - Bucket: aws.String(a.bucket), - Key: aws.String(name), - UploadId: aws.String(uploadID), - MaxParts: aws.Int64(int64(maxParts)), - PartNumberMarker: aws.Int64(int64(partNumberMarker)), - }) - if err != nil { - return nil, err - } - parts := make([]s3.UploadedPart, len(result.Parts)) - for i, part := range result.Parts { - parts[i] = s3.UploadedPart{ - PartNumber: int(*part.PartNumber), - LastModified: *part.LastModified, - Size: *part.Size, - ETag: *part.ETag, - } - } - return &s3.ListUploadedPartsResult{ - Key: *result.Key, - UploadID: *result.UploadId, - NextPartNumberMarker: int(*result.NextPartNumberMarker), - MaxParts: int(*result.MaxParts), - UploadedParts: parts, - }, nil -} - -func (a *Aws) PartLimit() *s3.PartLimit { - return &s3.PartLimit{ - MinPartSize: minPartSize, - MaxPartSize: maxPartSize, - MaxNumSize: maxNumSize, - } -} - -func (a *Aws) PresignedPutObject(ctx context.Context, name string, expire time.Duration) (string, error) { - req, _ := a.client.PutObjectRequest(&sdk.PutObjectInput{ - Bucket: aws.String(a.bucket), - Key: aws.String(name), - }) - url, err := req.Presign(expire) - if err != nil { - return "", err - } - return url, nil -} - -func (a *Aws) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, error) { - result, err := a.client.GetObjectWithContext(ctx, &sdk.GetObjectInput{ - Bucket: aws.String(a.bucket), - Key: aws.String(name), - }) - if err != nil { - return nil, err - } - res := &s3.ObjectInfo{ - Key: name, - ETag: *result.ETag, - Size: *result.ContentLength, - LastModified: *result.LastModified, - } - return res, nil -} - -// AccessURL todo. -func (a *Aws) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) { - // todo - return "", nil -} - -func (a *Aws) FormData(ctx context.Context, name string, size int64, contentType string, duration time.Duration) (*s3.FormData, error) { - // todo - return nil, nil -} - -func (a *Aws) AuthSign(ctx context.Context, uploadID string, name string, expire time.Duration, partNumbers []int) (*s3.AuthSignResult, error) { - // todo - return nil, nil -} diff --git a/pkg/common/db/s3/cos/cos.go b/pkg/common/db/s3/cos/cos.go index 619f142ab..a7c26fcc1 100644 --- a/pkg/common/db/s3/cos/cos.go +++ b/pkg/common/db/s3/cos/cos.go @@ -23,13 +23,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/OpenIMSDK/tools/errs" "net/http" "net/url" "strconv" "strings" "time" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3" "github.com/tencentyun/cos-go-sdk-v5" ) @@ -50,13 +50,15 @@ const ( const successCode = http.StatusOK -const ( -// videoSnapshotImagePng = "png" -// videoSnapshotImageJpg = "jpg" -) +type Config struct { + BucketURL string + SecretID string + SecretKey string + SessionToken string + PublicRead bool +} -func NewCos() (s3.Interface, error) { - conf := config.Config.Object.Cos +func NewCos(conf Config) (s3.Interface, error) { u, err := url.Parse(conf.BucketURL) if err != nil { panic(err) @@ -69,6 +71,7 @@ func NewCos() (s3.Interface, error) { }, }) return &Cos{ + publicRead: conf.PublicRead, copyURL: u.Host + "/", client: client, credential: client.GetCredential(), @@ -76,6 +79,7 @@ func NewCos() (s3.Interface, error) { } type Cos struct { + publicRead bool copyURL string client *cos.Client credential *cos.Credential @@ -226,7 +230,7 @@ func (c *Cos) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO } func (c *Cos) IsNotFound(err error) bool { - switch e := err.(type) { + switch e := errs.Unwrap(err).(type) { case *cos.ErrorResponse: return e.Response.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey" default: @@ -327,7 +331,7 @@ func (c *Cos) AccessURL(ctx context.Context, name string, expire time.Duration, } func (c *Cos) getPresignedURL(ctx context.Context, name string, expire time.Duration, opt *cos.PresignedURLOptions) (*url.URL, error) { - if !config.Config.Object.Cos.PublicRead { + if !c.publicRead { return c.client.Object.GetPresignedURL(ctx, http.MethodGet, name, c.credential.SecretID, c.credential.SecretKey, expire, opt) } return c.client.Object.GetObjectURL(name), nil diff --git a/pkg/common/db/s3/minio/image.go b/pkg/common/db/s3/minio/image.go index f363f94b1..71db1ea51 100644 --- a/pkg/common/db/s3/minio/image.go +++ b/pkg/common/db/s3/minio/image.go @@ -42,51 +42,79 @@ func ImageWidthHeight(img image.Image) (int, int) { return bounds.X, bounds.Y } -// resizeImage resizes an image to a specified maximum width and height, maintaining the aspect ratio. -// If both maxWidth and maxHeight are set to 0, the original image is returned. -// If both are non-zero, the image is scaled to fit within the constraints while maintaining aspect ratio. -// If only one of maxWidth or maxHeight is non-zero, the image is scaled accordingly. func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image { bounds := img.Bounds() - imgWidth, imgHeight := bounds.Dx(), bounds.Dy() + imgWidth := bounds.Max.X + imgHeight := bounds.Max.Y - // Return original image if no resizing is needed. + // 计算缩放比例 + scaleWidth := float64(maxWidth) / float64(imgWidth) + scaleHeight := float64(maxHeight) / float64(imgHeight) + + // 如果都为0,则不缩放,返回原始图片 if maxWidth == 0 && maxHeight == 0 { return img } - var scale float64 = 1 + // 如果宽度和高度都大于0,则选择较小的缩放比例,以保持宽高比 if maxWidth > 0 && maxHeight > 0 { - scaleWidth := float64(maxWidth) / float64(imgWidth) - scaleHeight := float64(maxHeight) / float64(imgHeight) - // Choose the smaller scale to fit both constraints. - scale = min(scaleWidth, scaleHeight) - } else if maxWidth > 0 { - scale = float64(maxWidth) / float64(imgWidth) - } else if maxHeight > 0 { - scale = float64(maxHeight) / float64(imgHeight) + scale := scaleWidth + if scaleHeight < scaleWidth { + scale = scaleHeight + } + + // 计算缩略图尺寸 + thumbnailWidth := int(float64(imgWidth) * scale) + thumbnailHeight := int(float64(imgHeight) * scale) + + // 使用"image"库的Resample方法生成缩略图 + thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight)) + for y := 0; y < thumbnailHeight; y++ { + for x := 0; x < thumbnailWidth; x++ { + srcX := int(float64(x) / scale) + srcY := int(float64(y) / scale) + thumbnail.Set(x, y, img.At(srcX, srcY)) + } + } + + return thumbnail } - newWidth := int(float64(imgWidth) * scale) - newHeight := int(float64(imgHeight) * scale) + // 如果只指定了宽度或高度,则根据最大不超过的规则生成缩略图 + if maxWidth > 0 { + thumbnailWidth := maxWidth + thumbnailHeight := int(float64(imgHeight) * scaleWidth) - // Resize the image by creating a new image and manually copying pixels. - thumbnail := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight)) - for y := 0; y < newHeight; y++ { - for x := 0; x < newWidth; x++ { - srcX := int(float64(x) / scale) - srcY := int(float64(y) / scale) - thumbnail.Set(x, y, img.At(srcX, srcY)) + // 使用"image"库的Resample方法生成缩略图 + thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight)) + for y := 0; y < thumbnailHeight; y++ { + for x := 0; x < thumbnailWidth; x++ { + srcX := int(float64(x) / scaleWidth) + srcY := int(float64(y) / scaleWidth) + thumbnail.Set(x, y, img.At(srcX, srcY)) + } } + + return thumbnail } - return thumbnail -} + if maxHeight > 0 { + thumbnailWidth := int(float64(imgWidth) * scaleHeight) + thumbnailHeight := maxHeight -// min returns the smaller of x or y. -func min(x, y float64) float64 { - if x < y { - return x + // 使用"image"库的Resample方法生成缩略图 + thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight)) + for y := 0; y < thumbnailHeight; y++ { + for x := 0; x < thumbnailWidth; x++ { + srcX := int(float64(x) / scaleHeight) + srcY := int(float64(y) / scaleHeight) + thumbnail.Set(x, y, img.At(srcX, srcY)) + } + } + + return thumbnail } - return y + + // 默认情况下,返回原始图片 + return img } diff --git a/pkg/common/db/s3/minio/minio.go b/pkg/common/db/s3/minio/minio.go index 1eb3257e1..cd77948d4 100644 --- a/pkg/common/db/s3/minio/minio.go +++ b/pkg/common/db/s3/minio/minio.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "github.com/OpenIMSDK/tools/errs" "io" "net/http" "net/url" @@ -33,7 +34,6 @@ import ( "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "github.com/minio/minio-go/v7/pkg/signer" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3" ) @@ -43,7 +43,7 @@ const ( ) const ( - minPartSize int64 = 1024 * 1024 * 5 // 1MB + minPartSize int64 = 1024 * 1024 * 5 // 5MB maxPartSize int64 = 1024 * 1024 * 1024 * 5 // 5GB maxNumSize int64 = 10000 ) @@ -57,13 +57,23 @@ const ( const successCode = http.StatusOK -func NewMinio(cache cache.MinioCache) (s3.Interface, error) { - u, err := url.Parse(config.Config.Object.Minio.Endpoint) +type Config struct { + Bucket string + Endpoint string + AccessKeyID string + SecretAccessKey string + SessionToken string + SignEndpoint string + PublicRead bool +} + +func NewMinio(cache cache.MinioCache, conf Config) (s3.Interface, error) { + u, err := url.Parse(conf.Endpoint) if err != nil { return nil, err } opts := &minio.Options{ - Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, config.Config.Object.Minio.SessionToken), + Creds: credentials.NewStaticV4(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken), Secure: u.Scheme == "https", } client, err := minio.New(u.Host, opts) @@ -71,26 +81,27 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) { return nil, err } m := &Minio{ - bucket: config.Config.Object.Minio.Bucket, + conf: conf, + bucket: conf.Bucket, core: &minio.Core{Client: client}, lock: &sync.Mutex{}, init: false, cache: cache, } - if config.Config.Object.Minio.SignEndpoint == "" || config.Config.Object.Minio.SignEndpoint == config.Config.Object.Minio.Endpoint { + if conf.SignEndpoint == "" || conf.SignEndpoint == conf.Endpoint { m.opts = opts m.sign = m.core.Client m.prefix = u.Path u.Path = "" - config.Config.Object.Minio.Endpoint = u.String() - m.signEndpoint = config.Config.Object.Minio.Endpoint + conf.Endpoint = u.String() + m.signEndpoint = conf.Endpoint } else { - su, err := url.Parse(config.Config.Object.Minio.SignEndpoint) + su, err := url.Parse(conf.SignEndpoint) if err != nil { return nil, err } m.opts = &minio.Options{ - Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, config.Config.Object.Minio.SessionToken), + Creds: credentials.NewStaticV4(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken), Secure: su.Scheme == "https", } m.sign, err = minio.New(su.Host, m.opts) @@ -99,8 +110,8 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) { } m.prefix = su.Path su.Path = "" - config.Config.Object.Minio.SignEndpoint = su.String() - m.signEndpoint = config.Config.Object.Minio.SignEndpoint + conf.SignEndpoint = su.String() + m.signEndpoint = conf.SignEndpoint } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -111,6 +122,7 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) { } type Minio struct { + conf Config bucket string signEndpoint string location string @@ -132,31 +144,30 @@ func (m *Minio) initMinio(ctx context.Context) error { if m.init { return nil } - conf := config.Config.Object.Minio - exists, err := m.core.Client.BucketExists(ctx, conf.Bucket) + exists, err := m.core.Client.BucketExists(ctx, m.conf.Bucket) if err != nil { return fmt.Errorf("check bucket exists error: %w", err) } if !exists { - if err = m.core.Client.MakeBucket(ctx, conf.Bucket, minio.MakeBucketOptions{}); err != nil { + if err = m.core.Client.MakeBucket(ctx, m.conf.Bucket, minio.MakeBucketOptions{}); err != nil { return fmt.Errorf("make bucket error: %w", err) } } - if conf.PublicRead { + if m.conf.PublicRead { policy := fmt.Sprintf( `{"Version": "2012-10-17","Statement": [{"Action": ["s3:GetObject","s3:PutObject"],"Effect": "Allow","Principal": {"AWS": ["*"]},"Resource": ["arn:aws:s3:::%s/*"],"Sid": ""}]}`, - conf.Bucket, + m.conf.Bucket, ) - if err = m.core.Client.SetBucketPolicy(ctx, conf.Bucket, policy); err != nil { + if err = m.core.Client.SetBucketPolicy(ctx, m.conf.Bucket, policy); err != nil { return err } } - m.location, err = m.core.Client.GetBucketLocation(ctx, conf.Bucket) + m.location, err = m.core.Client.GetBucketLocation(ctx, m.conf.Bucket) if err != nil { return err } func() { - if conf.SignEndpoint == "" || conf.SignEndpoint == conf.Endpoint { + if m.conf.SignEndpoint == "" || m.conf.SignEndpoint == m.conf.Endpoint { return } defer func() { @@ -176,7 +187,7 @@ func (m *Minio) initMinio(ctx context.Context) error { blc := reflect.ValueOf(m.sign).Elem().FieldByName("bucketLocCache") vblc := reflect.New(reflect.PtrTo(blc.Type())) *(*unsafe.Pointer)(vblc.UnsafePointer()) = unsafe.Pointer(blc.UnsafeAddr()) - vblc.Elem().Elem().Interface().(interface{ Set(string, string) }).Set(conf.Bucket, m.location) + vblc.Elem().Elem().Interface().(interface{ Set(string, string) }).Set(m.conf.Bucket, m.location) }() m.init = true return nil @@ -341,10 +352,7 @@ func (m *Minio) CopyObject(ctx context.Context, src string, dst string) (*s3.Cop } func (m *Minio) IsNotFound(err error) bool { - if err == nil { - return false - } - switch e := err.(type) { + switch e := errs.Unwrap(err).(type) { case minio.ErrorResponse: return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey" case *minio.ErrorResponse: @@ -397,7 +405,7 @@ func (m *Minio) PresignedGetObject(ctx context.Context, name string, expire time rawURL *url.URL err error ) - if config.Config.Object.Minio.PublicRead { + if m.conf.PublicRead { rawURL, err = makeTargetURL(m.sign, m.bucket, name, m.location, false, query) } else { rawURL, err = m.sign.PresignedGetObject(ctx, m.bucket, name, expire, query) diff --git a/pkg/common/db/s3/oss/oss.go b/pkg/common/db/s3/oss/oss.go index 442f4e52f..e485db277 100644 --- a/pkg/common/db/s3/oss/oss.go +++ b/pkg/common/db/s3/oss/oss.go @@ -32,7 +32,6 @@ import ( "github.com/OpenIMSDK/tools/errs" "github.com/aliyun/aliyun-oss-go-sdk/oss" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3" ) @@ -52,13 +51,17 @@ const ( const successCode = http.StatusOK -/* const ( - videoSnapshotImagePng = "png" - videoSnapshotImageJpg = "jpg" -) */ +type Config struct { + Endpoint string + Bucket string + BucketURL string + AccessKeyID string + AccessKeySecret string + SessionToken string + PublicRead bool +} -func NewOSS() (s3.Interface, error) { - conf := config.Config.Object.Oss +func NewOSS(conf Config) (s3.Interface, error) { if conf.BucketURL == "" { return nil, errs.Wrap(errors.New("bucket url is empty")) } @@ -78,6 +81,7 @@ func NewOSS() (s3.Interface, error) { bucket: bucket, credentials: client.Config.GetCredentials(), um: *(*urlMaker)(reflect.ValueOf(bucket.Client.Conn).Elem().FieldByName("url").UnsafePointer()), + publicRead: conf.PublicRead, }, nil } @@ -86,6 +90,7 @@ type OSS struct { bucket *oss.Bucket credentials oss.Credentials um urlMaker + publicRead bool } func (o *OSS) Engine() string { @@ -236,7 +241,7 @@ func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO } func (o *OSS) IsNotFound(err error) bool { - switch e := err.(type) { + switch e := errs.Unwrap(err).(type) { case oss.ServiceError: return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey" case *oss.ServiceError: @@ -282,7 +287,6 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin } func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) { - publicRead := config.Config.Object.Oss.PublicRead var opts []oss.Option if opt != nil { if opt.Image != nil { @@ -310,7 +314,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, process += ",format," + format opts = append(opts, oss.Process(process)) } - if !publicRead { + if !o.publicRead { if opt.ContentType != "" { opts = append(opts, oss.ResponseContentType(opt.ContentType)) } @@ -324,7 +328,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, } else if expire < time.Second { expire = time.Second } - if !publicRead { + if !o.publicRead { return o.bucket.SignURL(name, http.MethodGet, int64(expire/time.Second), opts...) } rawParams, err := oss.GetRawParams(opts) diff --git a/pkg/common/db/unrelation/mongo.go b/pkg/common/db/unrelation/mongo.go index 09880fb37..363e97867 100644 --- a/pkg/common/db/unrelation/mongo.go +++ b/pkg/common/db/unrelation/mongo.go @@ -36,13 +36,14 @@ const ( ) type Mongo struct { - db *mongo.Client + db *mongo.Client + config *config.GlobalConfig } // NewMongo Initialize MongoDB connection. -func NewMongo() (*Mongo, error) { +func NewMongo(config *config.GlobalConfig) (*Mongo, error) { specialerror.AddReplace(mongo.ErrNoDocuments, errs.ErrRecordNotFound) - uri := buildMongoURI() + uri := buildMongoURI(config) var mongoClient *mongo.Client var err error @@ -56,7 +57,7 @@ func NewMongo() (*Mongo, error) { if err = mongoClient.Ping(ctx, nil); err != nil { return nil, errs.Wrap(err, uri) } - return &Mongo{db: mongoClient}, nil + return &Mongo{db: mongoClient, config: config}, nil } if shouldRetry(err) { time.Sleep(time.Second) // exponential backoff could be implemented here @@ -66,14 +67,14 @@ func NewMongo() (*Mongo, error) { return nil, errs.Wrap(err, uri) } -func buildMongoURI() string { +func buildMongoURI(config *config.GlobalConfig) string { uri := os.Getenv("MONGO_URI") if uri != "" { return uri } - if config.Config.Mongo.Uri != "" { - return config.Config.Mongo.Uri + if config.Mongo.Uri != "" { + return config.Mongo.Uri } username := os.Getenv("MONGO_OPENIM_USERNAME") @@ -84,21 +85,21 @@ func buildMongoURI() string { maxPoolSize := os.Getenv("MONGO_MAX_POOL_SIZE") if username == "" { - username = config.Config.Mongo.Username + username = config.Mongo.Username } if password == "" { - password = config.Config.Mongo.Password + password = config.Mongo.Password } if address == "" { - address = strings.Join(config.Config.Mongo.Address, ",") + address = strings.Join(config.Mongo.Address, ",") } else if port != "" { address = fmt.Sprintf("%s:%s", address, port) } if database == "" { - database = config.Config.Mongo.Database + database = config.Mongo.Database } if maxPoolSize == "" { - maxPoolSize = fmt.Sprint(config.Config.Mongo.MaxPoolSize) + maxPoolSize = fmt.Sprint(config.Mongo.MaxPoolSize) } uriFormat := "mongodb://%s/%s?maxPoolSize=%s" @@ -122,8 +123,8 @@ func (m *Mongo) GetClient() *mongo.Client { } // GetDatabase returns the specific database from MongoDB. -func (m *Mongo) GetDatabase() *mongo.Database { - return m.db.Database(config.Config.Mongo.Database) +func (m *Mongo) GetDatabase(database string) *mongo.Database { + return m.db.Database(database) } // CreateMsgIndex creates an index for messages in MongoDB. @@ -133,7 +134,7 @@ func (m *Mongo) CreateMsgIndex() error { // createMongoIndex creates an index in a MongoDB collection. func (m *Mongo) createMongoIndex(collection string, isUnique bool, keys ...string) error { - db := m.GetDatabase().Collection(collection) + db := m.GetDatabase(m.config.Mongo.Database).Collection(collection) opts := options.CreateIndexes().SetMaxTime(10 * time.Second) indexView := db.Indexes() diff --git a/pkg/common/discoveryregister/direct/directconn.go b/pkg/common/discoveryregister/direct/directconn.go index 2ae0de170..ced209602 100644 --- a/pkg/common/discoveryregister/direct/directconn.go +++ b/pkg/common/discoveryregister/direct/directconn.go @@ -27,17 +27,17 @@ import ( type ServiceAddresses map[string][]int -func getServiceAddresses() ServiceAddresses { +func getServiceAddresses(config *config2.GlobalConfig) ServiceAddresses { return ServiceAddresses{ - config2.Config.RpcRegisterName.OpenImUserName: config2.Config.RpcPort.OpenImUserPort, - config2.Config.RpcRegisterName.OpenImFriendName: config2.Config.RpcPort.OpenImFriendPort, - config2.Config.RpcRegisterName.OpenImMsgName: config2.Config.RpcPort.OpenImMessagePort, - config2.Config.RpcRegisterName.OpenImMessageGatewayName: config2.Config.LongConnSvr.OpenImMessageGatewayPort, - config2.Config.RpcRegisterName.OpenImGroupName: config2.Config.RpcPort.OpenImGroupPort, - config2.Config.RpcRegisterName.OpenImAuthName: config2.Config.RpcPort.OpenImAuthPort, - config2.Config.RpcRegisterName.OpenImPushName: config2.Config.RpcPort.OpenImPushPort, - config2.Config.RpcRegisterName.OpenImConversationName: config2.Config.RpcPort.OpenImConversationPort, - config2.Config.RpcRegisterName.OpenImThirdName: config2.Config.RpcPort.OpenImThirdPort, + config.RpcRegisterName.OpenImUserName: config.RpcPort.OpenImUserPort, + config.RpcRegisterName.OpenImFriendName: config.RpcPort.OpenImFriendPort, + config.RpcRegisterName.OpenImMsgName: config.RpcPort.OpenImMessagePort, + config.RpcRegisterName.OpenImMessageGatewayName: config.LongConnSvr.OpenImMessageGatewayPort, + config.RpcRegisterName.OpenImGroupName: config.RpcPort.OpenImGroupPort, + config.RpcRegisterName.OpenImAuthName: config.RpcPort.OpenImAuthPort, + config.RpcRegisterName.OpenImPushName: config.RpcPort.OpenImPushPort, + config.RpcRegisterName.OpenImConversationName: config.RpcPort.OpenImConversationPort, + config.RpcRegisterName.OpenImThirdName: config.RpcPort.OpenImThirdPort, } } @@ -46,6 +46,7 @@ type ConnDirect struct { currentServiceAddress string conns map[string][]*grpc.ClientConn resolverDirect *ResolverDirect + config *config2.GlobalConfig } func (cd *ConnDirect) GetClientLocalConns() map[string][]*grpc.ClientConn { @@ -80,10 +81,11 @@ func (cd *ConnDirect) Close() { } -func NewConnDirect() (*ConnDirect, error) { +func NewConnDirect(config *config2.GlobalConfig) (*ConnDirect, error) { return &ConnDirect{ conns: make(map[string][]*grpc.ClientConn), resolverDirect: NewResolverDirect(), + config: config, }, nil } @@ -93,12 +95,12 @@ func (cd *ConnDirect) GetConns(ctx context.Context, if conns, exists := cd.conns[serviceName]; exists { return conns, nil } - ports := getServiceAddresses()[serviceName] + ports := getServiceAddresses(cd.config)[serviceName] var connections []*grpc.ClientConn for _, port := range ports { - conn, err := cd.dialServiceWithoutResolver(ctx, fmt.Sprintf(config2.Config.Rpc.ListenIP+":%d", port), append(cd.additionalOpts, opts...)...) + conn, err := cd.dialServiceWithoutResolver(ctx, fmt.Sprintf(cd.config.Rpc.ListenIP+":%d", port), append(cd.additionalOpts, opts...)...) if err != nil { - fmt.Printf("connect to port %d failed,serviceName %s, IP %s\n", port, serviceName, config2.Config.Rpc.ListenIP) + fmt.Printf("connect to port %d failed,serviceName %s, IP %s\n", port, serviceName, cd.config.Rpc.ListenIP) } connections = append(connections, conn) } @@ -111,7 +113,7 @@ func (cd *ConnDirect) GetConns(ctx context.Context, func (cd *ConnDirect) GetConn(ctx context.Context, serviceName string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { // Get service addresses - addresses := getServiceAddresses() + addresses := getServiceAddresses(cd.config) address, ok := addresses[serviceName] if !ok { return nil, errs.Wrap(errors.New("unknown service name"), "serviceName", serviceName) @@ -119,9 +121,9 @@ func (cd *ConnDirect) GetConn(ctx context.Context, serviceName string, opts ...g var result string for _, addr := range address { if result != "" { - result = result + "," + fmt.Sprintf(config2.Config.Rpc.ListenIP+":%d", addr) + result = result + "," + fmt.Sprintf(cd.config.Rpc.ListenIP+":%d", addr) } else { - result = fmt.Sprintf(config2.Config.Rpc.ListenIP+":%d", addr) + result = fmt.Sprintf(cd.config.Rpc.ListenIP+":%d", addr) } } // Try to dial a new connection diff --git a/pkg/common/discoveryregister/discoveryregister.go b/pkg/common/discoveryregister/discoveryregister.go index a21d8d62a..c43583a80 100644 --- a/pkg/common/discoveryregister/discoveryregister.go +++ b/pkg/common/discoveryregister/discoveryregister.go @@ -16,6 +16,7 @@ package discoveryregister import ( "errors" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "os" "github.com/OpenIMSDK/tools/discoveryregistry" @@ -26,19 +27,19 @@ import ( ) // NewDiscoveryRegister creates a new service discovery and registry client based on the provided environment type. -func NewDiscoveryRegister(envType string) (discoveryregistry.SvcDiscoveryRegistry, error) { +func NewDiscoveryRegister(config *config.GlobalConfig) (discoveryregistry.SvcDiscoveryRegistry, error) { if os.Getenv("ENVS_DISCOVERY") != "" { - envType = os.Getenv("ENVS_DISCOVERY") + config.Envs.Discovery = os.Getenv("ENVS_DISCOVERY") } - switch envType { + switch config.Envs.Discovery { case "zookeeper": - return zookeeper.NewZookeeperDiscoveryRegister() + return zookeeper.NewZookeeperDiscoveryRegister(config) case "k8s": - return kubernetes.NewK8sDiscoveryRegister() + return kubernetes.NewK8sDiscoveryRegister(config.RpcRegisterName.OpenImMessageGatewayName) case "direct": - return direct.NewConnDirect() + return direct.NewConnDirect(config) default: return nil, errs.Wrap(errors.New("envType not correct")) } diff --git a/pkg/common/discoveryregister/discoveryregister_test.go b/pkg/common/discoveryregister/discoveryregister_test.go index e7a5fb276..7d4fa53cd 100644 --- a/pkg/common/discoveryregister/discoveryregister_test.go +++ b/pkg/common/discoveryregister/discoveryregister_test.go @@ -15,6 +15,7 @@ package discoveryregister import ( + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "os" "testing" @@ -32,20 +33,23 @@ func setupTestEnvironment() { func TestNewDiscoveryRegister(t *testing.T) { setupTestEnvironment() - + conf := config.NewGlobalConfig() tests := []struct { envType string + gatewayName string expectedError bool expectedResult bool }{ - {"zookeeper", false, true}, - {"k8s", false, true}, // Assume that the k8s configuration is also set up correctly - {"direct", false, true}, - {"invalid", true, false}, + {"zookeeper", "MessageGateway", false, true}, + {"k8s", "MessageGateway", false, true}, + {"direct", "MessageGateway", false, true}, + {"invalid", "MessageGateway", true, false}, } for _, test := range tests { - client, err := NewDiscoveryRegister(test.envType) + conf.Envs.Discovery = test.envType + conf.RpcRegisterName.OpenImMessageGatewayName = test.gatewayName + client, err := NewDiscoveryRegister(conf) if test.expectedError { assert.Error(t, err) diff --git a/pkg/common/discoveryregister/kubernetes/kubernetes.go b/pkg/common/discoveryregister/kubernetes/kubernetes.go index 1292c64a8..b5d603fd1 100644 --- a/pkg/common/discoveryregister/kubernetes/kubernetes.go +++ b/pkg/common/discoveryregister/kubernetes/kubernetes.go @@ -22,11 +22,12 @@ import ( "strconv" "strings" - "github.com/OpenIMSDK/tools/discoveryregistry" - "github.com/OpenIMSDK/tools/log" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/stathat/consistent" + "google.golang.org/grpc" + + "github.com/OpenIMSDK/tools/discoveryregistry" + "github.com/OpenIMSDK/tools/log" ) // K8sDR represents the Kubernetes service discovery and registration client. @@ -34,11 +35,12 @@ type K8sDR struct { options []grpc.DialOption rpcRegisterAddr string gatewayHostConsistent *consistent.Consistent + gatewayName string } -func NewK8sDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) { +func NewK8sDiscoveryRegister(gatewayName string) (discoveryregistry.SvcDiscoveryRegistry, error) { gatewayConsistent := consistent.New() - gatewayHosts := getMsgGatewayHost(context.Background()) + gatewayHosts := getMsgGatewayHost(context.Background(), gatewayName) for _, v := range gatewayHosts { gatewayConsistent.Add(v) } @@ -46,10 +48,10 @@ func NewK8sDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) { } func (cli *K8sDR) Register(serviceName, host string, port int, opts ...grpc.DialOption) error { - if serviceName != config.Config.RpcRegisterName.OpenImMessageGatewayName { + if serviceName != cli.gatewayName { cli.rpcRegisterAddr = serviceName } else { - cli.rpcRegisterAddr = getSelfHost(context.Background()) + cli.rpcRegisterAddr = getSelfHost(context.Background(), cli.gatewayName) } return nil @@ -81,15 +83,15 @@ func (cli *K8sDR) GetUserIdHashGatewayHost(ctx context.Context, userId string) ( } return host, err } -func getSelfHost(ctx context.Context) string { +func getSelfHost(ctx context.Context, gatewayName string) string { port := 88 instance := "openimserver" selfPodName := os.Getenv("MY_POD_NAME") ns := os.Getenv("MY_POD_NAMESPACE") statefuleIndex := 0 - gatewayEnds := strings.Split(config.Config.RpcRegisterName.OpenImMessageGatewayName, ":") + gatewayEnds := strings.Split(gatewayName, ":") if len(gatewayEnds) != 2 { - log.ZError(ctx, "msggateway RpcRegisterName is error:config.Config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error")) + log.ZError(ctx, "msggateway RpcRegisterName is error:config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error")) } else { port, _ = strconv.Atoi(gatewayEnds[1]) } @@ -102,15 +104,15 @@ func getSelfHost(ctx context.Context) string { } // like openimserver-openim-msggateway-0.openimserver-openim-msggateway-headless.openim-lin.svc.cluster.local:88. -func getMsgGatewayHost(ctx context.Context) []string { +func getMsgGatewayHost(ctx context.Context, gatewayName string) []string { port := 88 instance := "openimserver" selfPodName := os.Getenv("MY_POD_NAME") replicas := os.Getenv("MY_MSGGATEWAY_REPLICACOUNT") ns := os.Getenv("MY_POD_NAMESPACE") - gatewayEnds := strings.Split(config.Config.RpcRegisterName.OpenImMessageGatewayName, ":") + gatewayEnds := strings.Split(gatewayName, ":") if len(gatewayEnds) != 2 { - log.ZError(ctx, "msggateway RpcRegisterName is error:config.Config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error")) + log.ZError(ctx, "msggateway RpcRegisterName is error:config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error")) } else { port, _ = strconv.Atoi(gatewayEnds[1]) } @@ -131,7 +133,7 @@ func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc // This conditional checks if the serviceName is not the OpenImMessageGatewayName. // It seems to handle a special case for the OpenImMessageGateway. - if serviceName != config.Config.RpcRegisterName.OpenImMessageGatewayName { + if serviceName != cli.gatewayName { // DialContext creates a client connection to the given target (serviceName) using the specified context. // 'cli.options' are likely default or common options for all connections in this struct. // 'opts...' allows for additional gRPC dial options to be passed and used. @@ -146,7 +148,7 @@ func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc // getMsgGatewayHost presumably retrieves hosts for the message gateway service. // The context is passed, likely for cancellation and timeout control. - gatewayHosts := getMsgGatewayHost(ctx) + gatewayHosts := getMsgGatewayHost(ctx, cli.gatewayName) // Iterating over the retrieved gateway hosts. for _, host := range gatewayHosts { diff --git a/pkg/common/discoveryregister/zookeeper/zookeeper.go b/pkg/common/discoveryregister/zookeeper/zookeeper.go index 5f9a3b6bd..0aa40a907 100644 --- a/pkg/common/discoveryregister/zookeeper/zookeeper.go +++ b/pkg/common/discoveryregister/zookeeper/zookeeper.go @@ -28,11 +28,11 @@ import ( ) // NewZookeeperDiscoveryRegister creates a new instance of ZookeeperDR for Zookeeper service discovery and registration. -func NewZookeeperDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) { - schema := getEnv("ZOOKEEPER_SCHEMA", config.Config.Zookeeper.Schema) - zkAddr := getZkAddrFromEnv(config.Config.Zookeeper.ZkAddr) - username := getEnv("ZOOKEEPER_USERNAME", config.Config.Zookeeper.Username) - password := getEnv("ZOOKEEPER_PASSWORD", config.Config.Zookeeper.Password) +func NewZookeeperDiscoveryRegister(config *config.GlobalConfig) (discoveryregistry.SvcDiscoveryRegistry, error) { + schema := getEnv("ZOOKEEPER_SCHEMA", config.Zookeeper.Schema) + zkAddr := getZkAddrFromEnv(config.Zookeeper.ZkAddr) + username := getEnv("ZOOKEEPER_USERNAME", config.Zookeeper.Username) + password := getEnv("ZOOKEEPER_PASSWORD", config.Zookeeper.Password) zk, err := openkeeper.NewClient( zkAddr, @@ -46,10 +46,10 @@ func NewZookeeperDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, er if err != nil { uriFormat := "address:%s, username:%s, password:%s, schema:%s." errInfo := fmt.Sprintf(uriFormat, - config.Config.Zookeeper.ZkAddr, - config.Config.Zookeeper.Username, - config.Config.Zookeeper.Password, - config.Config.Zookeeper.Schema) + config.Zookeeper.ZkAddr, + config.Zookeeper.Username, + config.Zookeeper.Password, + config.Zookeeper.Schema) return nil, errs.Wrap(err, errInfo) } return zk, nil diff --git a/pkg/common/kafka/consumer.go b/pkg/common/kafka/consumer.go index d0e06d482..664e5d468 100644 --- a/pkg/common/kafka/consumer.go +++ b/pkg/common/kafka/consumer.go @@ -17,6 +17,8 @@ package kafka import ( "sync" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/IBM/sarama" "github.com/OpenIMSDK/tools/errs" ) @@ -29,22 +31,31 @@ type Consumer struct { Consumer sarama.Consumer } -func NewKafkaConsumer(addr []string, topic string, kafkaConfig *sarama.Config) (*Consumer, error) { - p := Consumer{ - Topic: topic, - addr: addr, +func NewKafkaConsumer(addr []string, topic string, config *config.GlobalConfig) (*Consumer,error) { + p := Consumer{} + p.Topic = topic + p.addr = addr + consumerConfig := sarama.NewConfig() + if config.Kafka.Username != "" && config.Kafka.Password != "" { + consumerConfig.Net.SASL.Enable = true + consumerConfig.Net.SASL.User = config.Kafka.Username + consumerConfig.Net.SASL.Password = config.Kafka.Password } - - if kafkaConfig.Net.SASL.User != "" && kafkaConfig.Net.SASL.Password != "" { - kafkaConfig.Net.SASL.Enable = true + var tlsConfig *TLSConfig + if config.Kafka.TLS != nil { + tlsConfig = &TLSConfig{ + CACrt: config.Kafka.TLS.CACrt, + ClientCrt: config.Kafka.TLS.ClientCrt, + ClientKey: config.Kafka.TLS.ClientKey, + ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd, + InsecureSkipVerify: false, + } } - - err := SetupTLSConfig(kafkaConfig) - if err != nil { - return nil, err + err:=SetupTLSConfig(consumerConfig, tlsConfig) + if err!=nil{ + return nil,err } - - consumer, err := sarama.NewConsumer(p.addr, kafkaConfig) + consumer, err := sarama.NewConsumer(p.addr, consumerConfig) if err != nil { return nil, errs.Wrap(err, "NewKafkaConsumer: creating consumer failed") } diff --git a/pkg/common/kafka/consumer_group.go b/pkg/common/kafka/consumer_group.go index 824c5be2e..b6b4435ab 100644 --- a/pkg/common/kafka/consumer_group.go +++ b/pkg/common/kafka/consumer_group.go @@ -17,12 +17,12 @@ package kafka import ( "context" "errors" - "strings" "github.com/IBM/sarama" "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" + + "strings" ) type MConsumerGroup struct { @@ -35,22 +35,25 @@ type MConsumerGroupConfig struct { KafkaVersion sarama.KafkaVersion OffsetsInitial int64 IsReturnErr bool + UserName string + Password string } -func NewMConsumerGroup(consumerConfig *MConsumerGroupConfig, topics, addrs []string, groupID string) (*MConsumerGroup, error) { +func NewMConsumerGroup(consumerConfig *MConsumerGroupConfig, topics, addrs []string, groupID string, tlsConfig *TLSConfig) (*MConsumerGroup, error) { consumerGroupConfig := sarama.NewConfig() consumerGroupConfig.Version = consumerConfig.KafkaVersion consumerGroupConfig.Consumer.Offsets.Initial = consumerConfig.OffsetsInitial consumerGroupConfig.Consumer.Return.Errors = consumerConfig.IsReturnErr - if config.Config.Kafka.Username != "" && config.Config.Kafka.Password != "" { + if consumerConfig.UserName != "" && consumerConfig.Password != "" { consumerGroupConfig.Net.SASL.Enable = true - consumerGroupConfig.Net.SASL.User = config.Config.Kafka.Username - consumerGroupConfig.Net.SASL.Password = config.Config.Kafka.Password + consumerGroupConfig.Net.SASL.User = consumerConfig.UserName + consumerGroupConfig.Net.SASL.Password = consumerConfig.Password } - SetupTLSConfig(consumerGroupConfig) + + SetupTLSConfig(consumerGroupConfig, tlsConfig) consumerGroup, err := sarama.NewConsumerGroup(addrs, groupID, consumerGroupConfig) if err != nil { - return nil, errs.Wrap(err, strings.Join(topics, ","), strings.Join(addrs, ","), groupID, config.Config.Kafka.Username, config.Config.Kafka.Password) + return nil, errs.Wrap(err, strings.Join(topics, ","), strings.Join(addrs, ","), groupID, consumerConfig.UserName, consumerConfig.Password) } return &MConsumerGroup{ diff --git a/pkg/common/kafka/producer.go b/pkg/common/kafka/producer.go index 4c2378c59..afc53b35a 100644 --- a/pkg/common/kafka/producer.go +++ b/pkg/common/kafka/producer.go @@ -22,12 +22,12 @@ import ( "strings" "time" + "github.com/OpenIMSDK/tools/errs" + "github.com/IBM/sarama" "github.com/OpenIMSDK/protocol/constant" - "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/mcontext" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "google.golang.org/protobuf/proto" ) @@ -43,8 +43,15 @@ type Producer struct { producer sarama.SyncProducer } +type ProducerConfig struct { + ProducerAck string + CompressType string + Username string + Password string +} + // NewKafkaProducer initializes a new Kafka producer. -func NewKafkaProducer(addr []string, topic string) (*Producer, error) { +func NewKafkaProducer(addr []string, topic string, producerConfig *ProducerConfig, tlsConfig *TLSConfig) (*Producer, error) { p := Producer{ addr: addr, topic: topic, @@ -59,14 +66,14 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) { p.config.Producer.Partitioner = sarama.NewHashPartitioner // Configure producer acknowledgement level - configureProducerAck(&p, config.Config.Kafka.ProducerAck) + configureProducerAck(&p, producerConfig.ProducerAck) // Configure message compression - configureCompression(&p, config.Config.Kafka.CompressType) + configureCompression(&p, producerConfig.CompressType) // Get Kafka configuration from environment variables or fallback to config file - kafkaUsername := getEnvOrConfig("KAFKA_USERNAME", config.Config.Kafka.Username) - kafkaPassword := getEnvOrConfig("KAFKA_PASSWORD", config.Config.Kafka.Password) + kafkaUsername := getEnvOrConfig("KAFKA_USERNAME", producerConfig.Username) + kafkaPassword := getEnvOrConfig("KAFKA_PASSWORD", producerConfig.Password) kafkaAddr := getKafkaAddrFromEnv(addr) // Updated to use the new function // Configure SASL authentication if credentials are provided @@ -80,7 +87,7 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) { p.addr = kafkaAddr // Set up TLS configuration (if required) - SetupTLSConfig(p.config) + SetupTLSConfig(p.config, tlsConfig) // Create the producer with retries var err error diff --git a/pkg/common/kafka/util.go b/pkg/common/kafka/util.go index a94397819..4e2a02714 100644 --- a/pkg/common/kafka/util.go +++ b/pkg/common/kafka/util.go @@ -20,19 +20,27 @@ import ( "strings" "github.com/IBM/sarama" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/tls" ) +type TLSConfig struct { + CACrt string + ClientCrt string + ClientKey string + ClientKeyPwd string + InsecureSkipVerify bool +} + // SetupTLSConfig set up the TLS config from config file. -func SetupTLSConfig(cfg *sarama.Config) error { - if config.Config.Kafka.TLS != nil { +func SetupTLSConfig(cfg *sarama.Config, tlsConfig *TLSConfig) error { + if tlsConfig != nil { cfg.Net.TLS.Enable = true tlsConfig, err := tls.NewTLSConfig( - config.Config.Kafka.TLS.ClientCrt, - config.Config.Kafka.TLS.ClientKey, - config.Config.Kafka.TLS.CACrt, - []byte(config.Config.Kafka.TLS.ClientKeyPwd), + tlsConfig.ClientCrt, + tlsConfig.ClientKey, + tlsConfig.CACrt, + []byte(tlsConfig.ClientKeyPwd), + tlsConfig.InsecureSkipVerify, ) if err != nil { return err diff --git a/pkg/common/prommetrics/prommetrics.go b/pkg/common/prommetrics/prommetrics.go index 52694168f..9089e7b5f 100644 --- a/pkg/common/prommetrics/prommetrics.go +++ b/pkg/common/prommetrics/prommetrics.go @@ -31,17 +31,17 @@ func NewGrpcPromObj(cusMetrics []prometheus.Collector) (*prometheus.Registry, *g return reg, grpcMetrics, nil } -func GetGrpcCusMetrics(registerName string) []prometheus.Collector { +func GetGrpcCusMetrics(registerName string, config *config2.GlobalConfig) []prometheus.Collector { switch registerName { - case config2.Config.RpcRegisterName.OpenImMessageGatewayName: + case config.RpcRegisterName.OpenImMessageGatewayName: return []prometheus.Collector{OnlineUserGauge} - case config2.Config.RpcRegisterName.OpenImMsgName: + case config.RpcRegisterName.OpenImMsgName: return []prometheus.Collector{SingleChatMsgProcessSuccessCounter, SingleChatMsgProcessFailedCounter, GroupChatMsgProcessSuccessCounter, GroupChatMsgProcessFailedCounter} case "Transfer": return []prometheus.Collector{MsgInsertRedisSuccessCounter, MsgInsertRedisFailedCounter, MsgInsertMongoSuccessCounter, MsgInsertMongoFailedCounter, SeqSetFailedCounter} - case config2.Config.RpcRegisterName.OpenImPushName: + case config.RpcRegisterName.OpenImPushName: return []prometheus.Collector{MsgOfflinePushFailedCounter} - case config2.Config.RpcRegisterName.OpenImAuthName: + case config.RpcRegisterName.OpenImAuthName: return []prometheus.Collector{UserLoginCounter} default: return nil diff --git a/pkg/common/prommetrics/prommetrics_test.go b/pkg/common/prommetrics/prommetrics_test.go index 1e48c63ba..eb6f3c771 100644 --- a/pkg/common/prommetrics/prommetrics_test.go +++ b/pkg/common/prommetrics/prommetrics_test.go @@ -58,17 +58,20 @@ func TestNewGrpcPromObj(t *testing.T) { } func TestGetGrpcCusMetrics(t *testing.T) { + conf := config2.NewGlobalConfig() + + config2.InitConfig(conf, "../../config") // Test various cases based on the switch statement in the GetGrpcCusMetrics function. testCases := []struct { name string expected int // The expected number of metrics for each case. }{ - {config2.Config.RpcRegisterName.OpenImMessageGatewayName, 1}, + {conf.RpcRegisterName.OpenImMessageGatewayName, 1}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - metrics := GetGrpcCusMetrics(tc.name) + metrics := GetGrpcCusMetrics(tc.name, conf) assert.Len(t, metrics, tc.expected) }) } diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index 1576762e8..8af894ac0 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "net" "net/http" "os" @@ -27,19 +28,25 @@ import ( "syscall" "time" - "github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/errs" - "github.com/OpenIMSDK/tools/mw" - "github.com/OpenIMSDK/tools/network" - grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" - kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister" - "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + + config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + + grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + + kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister" + + "github.com/OpenIMSDK/tools/discoveryregistry" + "github.com/OpenIMSDK/tools/mw" + "github.com/OpenIMSDK/tools/network" ) // Start rpc server. @@ -47,37 +54,38 @@ func Start( rpcPort int, rpcRegisterName string, prometheusPort int, - rpcFn func(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error, + config *config2.GlobalConfig, + rpcFn func(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error, options ...grpc.ServerOption, ) error { fmt.Printf("start %s server, port: %d, prometheusPort: %d, OpenIM version: %s\n", - rpcRegisterName, rpcPort, prometheusPort, config.Version) - rpcTcpAddr := net.JoinHostPort(network.GetListenIP(config.Config.Rpc.ListenIP), strconv.Itoa(rpcPort)) + rpcRegisterName, rpcPort, prometheusPort, config2.Version) + rpcTcpAddr := net.JoinHostPort(network.GetListenIP(config.Rpc.ListenIP), strconv.Itoa(rpcPort)) listener, err := net.Listen( "tcp", rpcTcpAddr, ) if err != nil { - return errs.Wrap(err, "rpc start err", rpcTcpAddr) + return errs.Wrap(err, "listen err", rpcTcpAddr) } defer listener.Close() - client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) + client, err := kdisc.NewDiscoveryRegister(config) if err != nil { return err } defer client.Close() client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) - registerIP, err := network.GetRpcRegisterIP(config.Config.Rpc.RegisterIP) + registerIP, err := network.GetRpcRegisterIP(config.Rpc.RegisterIP) if err != nil { return errs.Wrap(err) } var reg *prometheus.Registry var metric *grpcprometheus.ServerMetrics - if config.Config.Prometheus.Enable { - cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName) + if config.Prometheus.Enable { + cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName, config) reg, metric, _ = prommetrics.NewGrpcPromObj(cusMetrics) options = append(options, mw.GrpcServer(), grpc.StreamInterceptor(metric.StreamServerInterceptor()), grpc.UnaryInterceptor(metric.UnaryServerInterceptor())) @@ -91,7 +99,7 @@ func Start( once.Do(srv.GracefulStop) }() - err = rpcFn(client, srv) + err = rpcFn(config, client, srv) if err != nil { return err } @@ -111,7 +119,7 @@ func Start( httpServer *http.Server ) go func() { - if config.Config.Prometheus.Enable && prometheusPort != 0 { + if config.Prometheus.Enable && prometheusPort != 0 { metric.InitializeMetrics(srv) // Create a HTTP server for prometheus. httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)} diff --git a/pkg/common/startrpc/start_test.go b/pkg/common/startrpc/start_test.go index 481986e15..e5e37e221 100644 --- a/pkg/common/startrpc/start_test.go +++ b/pkg/common/startrpc/start_test.go @@ -16,6 +16,7 @@ package startrpc import ( "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "net" "testing" "time" @@ -25,7 +26,7 @@ import ( ) // mockRpcFn is a mock gRPC function for testing. -func mockRpcFn(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func mockRpcFn(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { // Implement a mock gRPC service registration logic if needed return nil } @@ -40,7 +41,8 @@ func TestStart(t *testing.T) { doneChan := make(chan error, 1) go func() { - err := Start(testRpcPort, testRpcRegisterName, testPrometheusPort, mockRpcFn) + err := Start(testRpcPort, testRpcRegisterName, testPrometheusPort, + config.NewGlobalConfig(), mockRpcFn) doneChan <- err }() diff --git a/pkg/common/tls/tls.go b/pkg/common/tls/tls.go old mode 100644 new mode 100755 index a52f46df7..736913758 --- a/pkg/common/tls/tls.go +++ b/pkg/common/tls/tls.go @@ -22,7 +22,6 @@ import ( "os" "github.com/OpenIMSDK/tools/errs" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) // decryptPEM decrypts a PEM block using a password. @@ -50,15 +49,14 @@ func readEncryptablePEMBlock(path string, pwd []byte) ([]byte, error) { } // NewTLSConfig setup the TLS config from general config file. -func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte) (*tls.Config, error) { - var tlsConfig tls.Config +func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte, insecureSkipVerify bool) (*tls.Config,error) { + tlsConfig := tls.Config{} if clientCertFile != "" && clientKeyFile != "" { certPEMBlock, err := os.ReadFile(clientCertFile) if err != nil { return nil, errs.Wrap(err, "NewTLSConfig: failed to read client cert file") } - keyPEMBlock, err := readEncryptablePEMBlock(clientKeyFile, keyPwd) if err != nil { return nil, err @@ -84,7 +82,7 @@ func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byt tlsConfig.RootCAs = caCertPool } - tlsConfig.InsecureSkipVerify = config.Config.Kafka.TLS.InsecureSkipVerify + tlsConfig.InsecureSkipVerify = insecureSkipVerify return &tlsConfig, nil } diff --git a/pkg/rpcclient/auth.go b/pkg/rpcclient/auth.go index bfd4b1119..24597120f 100644 --- a/pkg/rpcclient/auth.go +++ b/pkg/rpcclient/auth.go @@ -24,17 +24,18 @@ import ( "google.golang.org/grpc" ) -func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry) *Auth { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImAuthName) +func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Auth { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImAuthName) if err != nil { util.ExitWithError(err) } client := auth.NewAuthClient(conn) - return &Auth{discov: discov, conn: conn, Client: client} + return &Auth{discov: discov, conn: conn, Client: client, Config: config} } type Auth struct { conn grpc.ClientConnInterface Client auth.AuthClient discov discoveryregistry.SvcDiscoveryRegistry + Config *config.GlobalConfig } diff --git a/pkg/rpcclient/conversation.go b/pkg/rpcclient/conversation.go index ee9818f8f..6981844b6 100644 --- a/pkg/rpcclient/conversation.go +++ b/pkg/rpcclient/conversation.go @@ -30,21 +30,22 @@ type Conversation struct { Client pbconversation.ConversationClient conn grpc.ClientConnInterface discov discoveryregistry.SvcDiscoveryRegistry + Config *config.GlobalConfig } -func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry) *Conversation { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImConversationName) +func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Conversation { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImConversationName) if err != nil { util.ExitWithError(err) } client := pbconversation.NewConversationClient(conn) - return &Conversation{discov: discov, conn: conn, Client: client} + return &Conversation{discov: discov, conn: conn, Client: client, Config: config} } type ConversationRpcClient Conversation -func NewConversationRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) ConversationRpcClient { - return ConversationRpcClient(*NewConversation(discov)) +func NewConversationRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) ConversationRpcClient { + return ConversationRpcClient(*NewConversation(discov, config)) } func (c *ConversationRpcClient) GetSingleConversationRecvMsgOpt(ctx context.Context, userID, conversationID string) (int32, error) { diff --git a/pkg/rpcclient/friend.go b/pkg/rpcclient/friend.go index 520616564..e1f6ed076 100644 --- a/pkg/rpcclient/friend.go +++ b/pkg/rpcclient/friend.go @@ -29,21 +29,22 @@ type Friend struct { conn grpc.ClientConnInterface Client friend.FriendClient discov discoveryregistry.SvcDiscoveryRegistry + Config *config.GlobalConfig } -func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry) *Friend { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImFriendName) +func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Friend { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImFriendName) if err != nil { util.ExitWithError(err) } client := friend.NewFriendClient(conn) - return &Friend{discov: discov, conn: conn, Client: client} + return &Friend{discov: discov, conn: conn, Client: client, Config: config} } type FriendRpcClient Friend -func NewFriendRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) FriendRpcClient { - return FriendRpcClient(*NewFriend(discov)) +func NewFriendRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) FriendRpcClient { + return FriendRpcClient(*NewFriend(discov, config)) } func (f *FriendRpcClient) GetFriendsInfo( diff --git a/pkg/rpcclient/group.go b/pkg/rpcclient/group.go index bc9a3c75c..773637858 100644 --- a/pkg/rpcclient/group.go +++ b/pkg/rpcclient/group.go @@ -33,21 +33,22 @@ type Group struct { conn grpc.ClientConnInterface Client group.GroupClient discov discoveryregistry.SvcDiscoveryRegistry + Config *config.GlobalConfig } -func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry) *Group { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImGroupName) +func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Group { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImGroupName) if err != nil { util.ExitWithError(err) } client := group.NewGroupClient(conn) - return &Group{discov: discov, conn: conn, Client: client} + return &Group{discov: discov, conn: conn, Client: client, Config: config} } type GroupRpcClient Group -func NewGroupRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) GroupRpcClient { - return GroupRpcClient(*NewGroup(discov)) +func NewGroupRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) GroupRpcClient { + return GroupRpcClient(*NewGroup(discov, config)) } func (g *GroupRpcClient) GetGroupInfos( diff --git a/pkg/rpcclient/msg.go b/pkg/rpcclient/msg.go index e4008ed20..3db39925e 100644 --- a/pkg/rpcclient/msg.go +++ b/pkg/rpcclient/msg.go @@ -17,6 +17,8 @@ package rpcclient import ( "context" "encoding/json" + "fmt" + "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/msg" @@ -29,47 +31,47 @@ import ( "google.golang.org/protobuf/proto" ) -func newContentTypeConf() map[int32]config.NotificationConf { +func newContentTypeConf(conf *config.GlobalConfig) map[int32]config.NotificationConf { return map[int32]config.NotificationConf{ // group - constant.GroupCreatedNotification: config.Config.Notification.GroupCreated, - constant.GroupInfoSetNotification: config.Config.Notification.GroupInfoSet, - constant.JoinGroupApplicationNotification: config.Config.Notification.JoinGroupApplication, - constant.MemberQuitNotification: config.Config.Notification.MemberQuit, - constant.GroupApplicationAcceptedNotification: config.Config.Notification.GroupApplicationAccepted, - constant.GroupApplicationRejectedNotification: config.Config.Notification.GroupApplicationRejected, - constant.GroupOwnerTransferredNotification: config.Config.Notification.GroupOwnerTransferred, - constant.MemberKickedNotification: config.Config.Notification.MemberKicked, - constant.MemberInvitedNotification: config.Config.Notification.MemberInvited, - constant.MemberEnterNotification: config.Config.Notification.MemberEnter, - constant.GroupDismissedNotification: config.Config.Notification.GroupDismissed, - constant.GroupMutedNotification: config.Config.Notification.GroupMuted, - constant.GroupCancelMutedNotification: config.Config.Notification.GroupCancelMuted, - constant.GroupMemberMutedNotification: config.Config.Notification.GroupMemberMuted, - constant.GroupMemberCancelMutedNotification: config.Config.Notification.GroupMemberCancelMuted, - constant.GroupMemberInfoSetNotification: config.Config.Notification.GroupMemberInfoSet, - constant.GroupMemberSetToAdminNotification: config.Config.Notification.GroupMemberSetToAdmin, - constant.GroupMemberSetToOrdinaryUserNotification: config.Config.Notification.GroupMemberSetToOrdinary, - constant.GroupInfoSetAnnouncementNotification: config.Config.Notification.GroupInfoSetAnnouncement, - constant.GroupInfoSetNameNotification: config.Config.Notification.GroupInfoSetName, + constant.GroupCreatedNotification: conf.Notification.GroupCreated, + constant.GroupInfoSetNotification: conf.Notification.GroupInfoSet, + constant.JoinGroupApplicationNotification: conf.Notification.JoinGroupApplication, + constant.MemberQuitNotification: conf.Notification.MemberQuit, + constant.GroupApplicationAcceptedNotification: conf.Notification.GroupApplicationAccepted, + constant.GroupApplicationRejectedNotification: conf.Notification.GroupApplicationRejected, + constant.GroupOwnerTransferredNotification: conf.Notification.GroupOwnerTransferred, + constant.MemberKickedNotification: conf.Notification.MemberKicked, + constant.MemberInvitedNotification: conf.Notification.MemberInvited, + constant.MemberEnterNotification: conf.Notification.MemberEnter, + constant.GroupDismissedNotification: conf.Notification.GroupDismissed, + constant.GroupMutedNotification: conf.Notification.GroupMuted, + constant.GroupCancelMutedNotification: conf.Notification.GroupCancelMuted, + constant.GroupMemberMutedNotification: conf.Notification.GroupMemberMuted, + constant.GroupMemberCancelMutedNotification: conf.Notification.GroupMemberCancelMuted, + constant.GroupMemberInfoSetNotification: conf.Notification.GroupMemberInfoSet, + constant.GroupMemberSetToAdminNotification: conf.Notification.GroupMemberSetToAdmin, + constant.GroupMemberSetToOrdinaryUserNotification: conf.Notification.GroupMemberSetToOrdinary, + constant.GroupInfoSetAnnouncementNotification: conf.Notification.GroupInfoSetAnnouncement, + constant.GroupInfoSetNameNotification: conf.Notification.GroupInfoSetName, // user - constant.UserInfoUpdatedNotification: config.Config.Notification.UserInfoUpdated, - constant.UserStatusChangeNotification: config.Config.Notification.UserStatusChanged, + constant.UserInfoUpdatedNotification: conf.Notification.UserInfoUpdated, + constant.UserStatusChangeNotification: conf.Notification.UserStatusChanged, // friend - constant.FriendApplicationNotification: config.Config.Notification.FriendApplicationAdded, - constant.FriendApplicationApprovedNotification: config.Config.Notification.FriendApplicationApproved, - constant.FriendApplicationRejectedNotification: config.Config.Notification.FriendApplicationRejected, - constant.FriendAddedNotification: config.Config.Notification.FriendAdded, - constant.FriendDeletedNotification: config.Config.Notification.FriendDeleted, - constant.FriendRemarkSetNotification: config.Config.Notification.FriendRemarkSet, - constant.BlackAddedNotification: config.Config.Notification.BlackAdded, - constant.BlackDeletedNotification: config.Config.Notification.BlackDeleted, - constant.FriendInfoUpdatedNotification: config.Config.Notification.FriendInfoUpdated, - constant.FriendsInfoUpdateNotification: config.Config.Notification.FriendInfoUpdated, //use the same FriendInfoUpdated + constant.FriendApplicationNotification: conf.Notification.FriendApplicationAdded, + constant.FriendApplicationApprovedNotification: conf.Notification.FriendApplicationApproved, + constant.FriendApplicationRejectedNotification: conf.Notification.FriendApplicationRejected, + constant.FriendAddedNotification: conf.Notification.FriendAdded, + constant.FriendDeletedNotification: conf.Notification.FriendDeleted, + constant.FriendRemarkSetNotification: conf.Notification.FriendRemarkSet, + constant.BlackAddedNotification: conf.Notification.BlackAdded, + constant.BlackDeletedNotification: conf.Notification.BlackDeleted, + constant.FriendInfoUpdatedNotification: conf.Notification.FriendInfoUpdated, + constant.FriendsInfoUpdateNotification: conf.Notification.FriendInfoUpdated, //use the same FriendInfoUpdated // conversation - constant.ConversationChangeNotification: config.Config.Notification.ConversationChanged, - constant.ConversationUnreadNotification: config.Config.Notification.ConversationChanged, - constant.ConversationPrivateChatNotification: config.Config.Notification.ConversationSetPrivate, + constant.ConversationChangeNotification: conf.Notification.ConversationChanged, + constant.ConversationUnreadNotification: conf.Notification.ConversationChanged, + constant.ConversationPrivateChatNotification: conf.Notification.ConversationSetPrivate, // msg constant.MsgRevokeNotification: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg}, constant.HasReadReceipt: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg}, @@ -127,21 +129,22 @@ type Message struct { conn grpc.ClientConnInterface Client msg.MsgClient discov discoveryregistry.SvcDiscoveryRegistry + Config *config.GlobalConfig } -func NewMessage(discov discoveryregistry.SvcDiscoveryRegistry) *Message { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImMsgName) +func NewMessage(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Message { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImMsgName) if err != nil { panic(err) } client := msg.NewMsgClient(conn) - return &Message{discov: discov, conn: conn, Client: client} + return &Message{discov: discov, conn: conn, Client: client, Config: config} } type MessageRpcClient Message -func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) MessageRpcClient { - return MessageRpcClient(*NewMessage(discov)) +func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) MessageRpcClient { + return MessageRpcClient(*NewMessage(discov, config)) } // SendMsg sends a message through the gRPC client and returns the response. @@ -234,8 +237,8 @@ func WithUserRpcClient(userRpcClient *UserRpcClient) NotificationSenderOptions { } } -func NewNotificationSender(opts ...NotificationSenderOptions) *NotificationSender { - notificationSender := &NotificationSender{contentTypeConf: newContentTypeConf(), sessionTypeConf: newSessionTypeConf()} +func NewNotificationSender(config *config.GlobalConfig, opts ...NotificationSenderOptions) *NotificationSender { + notificationSender := &NotificationSender{contentTypeConf: newContentTypeConf(config), sessionTypeConf: newSessionTypeConf()} for _, opt := range opts { opt(notificationSender) } @@ -258,8 +261,8 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s n := sdkws.NotificationElem{Detail: utils.StructToJsonString(m)} content, err := json.Marshal(&n) if err != nil { - log.ZError(ctx, "MsgClient Notification json.Marshal failed", err, "sendID", sendID, "recvID", recvID, "contentType", contentType, "msg", m) - return err + errInfo := fmt.Sprintf("MsgClient Notification json.Marshal failed, sendID:%s, recvID:%s, contentType:%d, msg:%s", sendID, recvID, contentType, m) + return errs.Wrap(err, errInfo) } notificationOpt := ¬ificationOpt{} for _, opt := range opts { @@ -271,7 +274,8 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s if notificationOpt.WithRpcGetUsername && s.getUserInfo != nil { userInfo, err = s.getUserInfo(ctx, sendID) if err != nil { - log.ZWarn(ctx, "getUserInfo failed", err, "sendID", sendID) + errInfo := fmt.Sprintf("getUserInfo failed, sendID:%s", sendID) + return errs.Wrap(err, errInfo) } else { msg.SenderNickname = userInfo.Nickname msg.SenderFaceURL = userInfo.FaceURL @@ -303,10 +307,9 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s msg.OfflinePushInfo = &offlineInfo req.MsgData = &msg _, err = s.sendMsg(ctx, &req) - if err == nil { - log.ZDebug(ctx, "MsgClient Notification SendMsg success", "req", &req) - } else { - log.ZError(ctx, "MsgClient Notification SendMsg failed", err, "req", &req) + if err != nil { + errInfo := fmt.Sprintf("MsgClient Notification SendMsg failed, req:%s", &req) + return errs.Wrap(err, errInfo) } return err } diff --git a/pkg/rpcclient/notification/conversation.go b/pkg/rpcclient/notification/conversation.go index b43e5494a..115fb81a2 100644 --- a/pkg/rpcclient/notification/conversation.go +++ b/pkg/rpcclient/notification/conversation.go @@ -16,6 +16,7 @@ package notification import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/sdkws" @@ -26,8 +27,8 @@ type ConversationNotificationSender struct { *rpcclient.NotificationSender } -func NewConversationNotificationSender(msgRpcClient *rpcclient.MessageRpcClient) *ConversationNotificationSender { - return &ConversationNotificationSender{rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient))} +func NewConversationNotificationSender(config *config.GlobalConfig, msgRpcClient *rpcclient.MessageRpcClient) *ConversationNotificationSender { + return &ConversationNotificationSender{rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient))} } // SetPrivate invote. diff --git a/pkg/rpcclient/notification/friend.go b/pkg/rpcclient/notification/friend.go index 9237111be..31426da31 100644 --- a/pkg/rpcclient/notification/friend.go +++ b/pkg/rpcclient/notification/friend.go @@ -16,6 +16,7 @@ package notification import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/protocol/constant" pbfriend "github.com/OpenIMSDK/protocol/friend" @@ -80,11 +81,12 @@ func WithRpcFunc( } func NewFriendNotificationSender( + config *config.GlobalConfig, msgRpcClient *rpcclient.MessageRpcClient, opts ...friendNotificationSenderOptions, ) *FriendNotificationSender { f := &FriendNotificationSender{ - NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient)), + NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient)), } for _, opt := range opts { opt(f) diff --git a/pkg/rpcclient/notification/group.go b/pkg/rpcclient/notification/group.go index 1778a498d..5500f4f43 100644 --- a/pkg/rpcclient/notification/group.go +++ b/pkg/rpcclient/notification/group.go @@ -17,6 +17,7 @@ package notification import ( "context" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/protocol/constant" pbgroup "github.com/OpenIMSDK/protocol/group" @@ -35,12 +36,14 @@ func NewGroupNotificationSender( db controller.GroupDatabase, msgRpcClient *rpcclient.MessageRpcClient, userRpcClient *rpcclient.UserRpcClient, + config *config.GlobalConfig, fn func(ctx context.Context, userIDs []string) ([]CommonUser, error), ) *GroupNotificationSender { return &GroupNotificationSender{ - NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient), rpcclient.WithUserRpcClient(userRpcClient)), + NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient), rpcclient.WithUserRpcClient(userRpcClient)), getUsersInfo: fn, db: db, + config: config, } } @@ -48,6 +51,7 @@ type GroupNotificationSender struct { *rpcclient.NotificationSender getUsersInfo func(ctx context.Context, userIDs []string) ([]CommonUser, error) db controller.GroupDatabase + config *config.GlobalConfig } func (g *GroupNotificationSender) PopulateGroupMember(ctx context.Context, members ...*relation.GroupMemberModel) error { @@ -243,21 +247,15 @@ func (g *GroupNotificationSender) groupMemberDB2PB(member *relation.GroupMemberM } */ func (g *GroupNotificationSender) fillOpUser(ctx context.Context, opUser **sdkws.GroupMemberFullInfo, groupID string) (err error) { - defer log.ZDebug(ctx, "return") - defer func() { - if err != nil { - log.ZError(ctx, utils.GetFuncName(1)+" failed", err) - } - }() if opUser == nil { return errs.ErrInternalServer.Wrap("**sdkws.GroupMemberFullInfo is nil") } if *opUser != nil { - return nil + return errs.ErrArgs.Wrap("*opUser is not nil") } userID := mcontext.GetOpUserID(ctx) if groupID != "" { - if authverify.IsManagerUserID(userID) { + if authverify.IsManagerUserID(userID, g.config) { *opUser = &sdkws.GroupMemberFullInfo{ GroupID: groupID, UserID: userID, @@ -265,11 +263,11 @@ func (g *GroupNotificationSender) fillOpUser(ctx context.Context, opUser **sdkws AppMangerLevel: constant.AppAdmin, } } else { - member, err2 := g.db.TakeGroupMember(ctx, groupID, userID) - if err2 == nil { + member, err := g.db.TakeGroupMember(ctx, groupID, userID) + if err == nil { *opUser = g.groupMemberDB2PB(member, 0) - } else if !errs.ErrRecordNotFound.Is(err2) { - return err2 + } else if !errs.ErrRecordNotFound.Is(err) { + return err } } } @@ -650,12 +648,6 @@ func (g *GroupNotificationSender) GroupCancelMutedNotification(ctx context.Conte } func (g *GroupNotificationSender) GroupMemberInfoSetNotification(ctx context.Context, groupID, groupMemberUserID string) (err error) { - defer log.ZDebug(ctx, "return") - defer func() { - if err != nil { - log.ZError(ctx, utils.GetFuncName(1)+" failed", err) - } - }() group, err := g.getGroupInfo(ctx, groupID) if err != nil { return err @@ -672,12 +664,6 @@ func (g *GroupNotificationSender) GroupMemberInfoSetNotification(ctx context.Con } func (g *GroupNotificationSender) GroupMemberSetToAdminNotification(ctx context.Context, groupID, groupMemberUserID string) (err error) { - defer log.ZDebug(ctx, "return") - defer func() { - if err != nil { - log.ZError(ctx, utils.GetFuncName(1)+" failed", err) - } - }() group, err := g.getGroupInfo(ctx, groupID) if err != nil { return err diff --git a/pkg/rpcclient/notification/msg.go b/pkg/rpcclient/notification/msg.go index 83280f80a..19819e7b7 100644 --- a/pkg/rpcclient/notification/msg.go +++ b/pkg/rpcclient/notification/msg.go @@ -16,6 +16,7 @@ package notification import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/sdkws" @@ -26,8 +27,8 @@ type MsgNotificationSender struct { *rpcclient.NotificationSender } -func NewMsgNotificationSender(opts ...rpcclient.NotificationSenderOptions) *MsgNotificationSender { - return &MsgNotificationSender{rpcclient.NewNotificationSender(opts...)} +func NewMsgNotificationSender(config *config.GlobalConfig, opts ...rpcclient.NotificationSenderOptions) *MsgNotificationSender { + return &MsgNotificationSender{rpcclient.NewNotificationSender(config, opts...)} } func (m *MsgNotificationSender) UserDeleteMsgsNotification(ctx context.Context, userID, conversationID string, seqs []int64) error { diff --git a/pkg/rpcclient/notification/user.go b/pkg/rpcclient/notification/user.go index f94e59a33..204b13e61 100644 --- a/pkg/rpcclient/notification/user.go +++ b/pkg/rpcclient/notification/user.go @@ -16,6 +16,7 @@ package notification import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/sdkws" @@ -58,11 +59,12 @@ func WithUserFunc( } func NewUserNotificationSender( + config *config.GlobalConfig, msgRpcClient *rpcclient.MessageRpcClient, opts ...userNotificationSenderOptions, ) *UserNotificationSender { f := &UserNotificationSender{ - NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient)), + NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient)), } for _, opt := range opts { opt(f) diff --git a/pkg/rpcclient/push.go b/pkg/rpcclient/push.go index 2f540da81..c0aa9efa4 100644 --- a/pkg/rpcclient/push.go +++ b/pkg/rpcclient/push.go @@ -30,8 +30,8 @@ type Push struct { discov discoveryregistry.SvcDiscoveryRegistry } -func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImPushName) +func NewPush(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Push { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImPushName) if err != nil { util.ExitWithError(err) } @@ -44,8 +44,8 @@ func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push { type PushRpcClient Push -func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) PushRpcClient { - return PushRpcClient(*NewPush(discov)) +func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) PushRpcClient { + return PushRpcClient(*NewPush(discov, config)) } func (p *PushRpcClient) DelUserPushToken(ctx context.Context, req *push.DelUserPushTokenReq) (*push.DelUserPushTokenResp, error) { diff --git a/pkg/rpcclient/third.go b/pkg/rpcclient/third.go old mode 100644 new mode 100755 index be40335d5..50d545a64 --- a/pkg/rpcclient/third.go +++ b/pkg/rpcclient/third.go @@ -16,13 +16,15 @@ package rpcclient import ( "context" + "github.com/OpenIMSDK/tools/errs" "net/url" - "github.com/OpenIMSDK/protocol/third" - "github.com/OpenIMSDK/tools/discoveryregistry" - "github.com/OpenIMSDK/tools/errs" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" + + "github.com/OpenIMSDK/protocol/third" + "github.com/OpenIMSDK/tools/discoveryregistry" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" "google.golang.org/grpc" @@ -33,47 +35,41 @@ type Third struct { Client third.ThirdClient discov discoveryregistry.SvcDiscoveryRegistry MinioClient *minio.Client + Config *config.GlobalConfig } -func NewThird(discov discoveryregistry.SvcDiscoveryRegistry) *Third { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImThirdName) +func NewThird(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Third { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImThirdName) if err != nil { util.ExitWithError(err) } client := third.NewThirdClient(conn) - minioClient, err := minioInit() + minioClient, err := minioInit(config) if err != nil { util.ExitWithError(err) } - return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient} + return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient, Config: config} } -func minioInit() (*minio.Client, error) { - // Retrieve MinIO configuration details - endpoint := config.Config.Object.Minio.Endpoint - accessKeyID := config.Config.Object.Minio.AccessKeyID - secretAccessKey := config.Config.Object.Minio.SecretAccessKey - - // Parse the MinIO URL to determine if the connection should be secure - minioURL, err := url.Parse(endpoint) +func minioInit(config *config.GlobalConfig) (*minio.Client, error) { + minioClient := &minio.Client{} + initUrl := config.Object.Minio.Endpoint + minioUrl, err := url.Parse(initUrl) if err != nil { return nil, errs.Wrap(err, "minioInit: failed to parse MinIO endpoint URL") } - - // Determine the security of the connection based on the scheme - secure := minioURL.Scheme == "https" - - // Setup MinIO client options opts := &minio.Options{ - Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""), - Secure: secure, + Creds: credentials.NewStaticV4(config.Object.Minio.AccessKeyID, config.Object.Minio.SecretAccessKey, ""), + // Region: config.Credential.Minio.Location, } - - // Initialize MinIO client - minioClient, err := minio.New(minioURL.Host, opts) + if minioUrl.Scheme == "http" { + opts.Secure = false + } else if minioUrl.Scheme == "https" { + opts.Secure = true + } + minioClient, err = minio.New(minioUrl.Host, opts) if err != nil { return nil, errs.Wrap(err, "minioInit: failed to create MinIO client") } - return minioClient, nil } diff --git a/pkg/rpcclient/user.go b/pkg/rpcclient/user.go index a6c202129..08ad41dc1 100644 --- a/pkg/rpcclient/user.go +++ b/pkg/rpcclient/user.go @@ -34,16 +34,17 @@ type User struct { conn grpc.ClientConnInterface Client user.UserClient Discov discoveryregistry.SvcDiscoveryRegistry + Config *config.GlobalConfig } // NewUser initializes and returns a User instance based on the provided service discovery registry. -func NewUser(discov discoveryregistry.SvcDiscoveryRegistry) *User { - conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImUserName) +func NewUser(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *User { + conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImUserName) if err != nil { util.ExitWithError(err) } client := user.NewUserClient(conn) - return &User{Discov: discov, Client: client, conn: conn} + return &User{Discov: discov, Client: client, conn: conn, Config: config} } // UserRpcClient represents the structure for a User RPC client. @@ -56,8 +57,8 @@ func NewUserRpcClientByUser(user *User) *UserRpcClient { } // NewUserRpcClient initializes a UserRpcClient based on the provided service discovery registry. -func NewUserRpcClient(client discoveryregistry.SvcDiscoveryRegistry) UserRpcClient { - return UserRpcClient(*NewUser(client)) +func NewUserRpcClient(client discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) UserRpcClient { + return UserRpcClient(*NewUser(client, config)) } // GetUsersInfo retrieves information for multiple users based on their user IDs. @@ -160,7 +161,7 @@ func (u *UserRpcClient) Access(ctx context.Context, ownerUserID string) error { if err != nil { return err } - return authverify.CheckAccessV3(ctx, ownerUserID) + return authverify.CheckAccessV3(ctx, ownerUserID, u.Config) } // GetAllUserIDs retrieves all user IDs with pagination options. diff --git a/tools/component/component.go b/tools/component/component.go index 6b879d7f8..bb3a2d46f 100644 --- a/tools/component/component.go +++ b/tools/component/component.go @@ -47,27 +47,33 @@ var ( cfgPath = flag.String("c", defaultCfgPath, "Path to the configuration file") ) -func initCfg() error { +func initCfg() (*config.GlobalConfig, error) { data, err := os.ReadFile(*cfgPath) if err != nil { - return err + return nil, errs.Wrap(err, "ReadFile unmarshal failed") } - return yaml.Unmarshal(data, &config.Config) + conf := config.NewGlobalConfig() + err = yaml.Unmarshal(data, &conf) + if err != nil { + return nil, errs.Wrap(err, "InitConfig unmarshal failed") + } + return conf, nil } type checkFunc struct { name string - function func() error + function func(*config.GlobalConfig) error flag bool + config *config.GlobalConfig } func main() { flag.Parse() - if err := initCfg(); err != nil { + conf, err := initCfg() + if err != nil { fmt.Printf("Read config failed: %v\n", err) - return } @@ -75,11 +81,11 @@ func main() { checks := []checkFunc{ //{name: "Mysql", function: checkMysql}, - {name: "Mongo", function: checkMongo}, - {name: "Redis", function: checkRedis}, - {name: "Minio", function: checkMinio}, - {name: "Zookeeper", function: checkZookeeper}, - {name: "Kafka", function: checkKafka}, + {name: "Mongo", function: checkMongo, config: conf}, + {name: "Redis", function: checkRedis, config: conf}, + {name: "Minio", function: checkMinio, config: conf}, + {name: "Zookeeper", function: checkZookeeper, config: conf}, + {name: "Kafka", function: checkKafka, config: conf}, } for i := 0; i < maxRetry; i++ { @@ -92,7 +98,7 @@ func main() { allSuccess := true for index, check := range checks { if !check.flag { - err = check.function() + err = check.function(check.config) if err != nil { component.ErrorPrint(fmt.Sprintf("Starting %s failed:%v.", check.name, err)) allSuccess = false @@ -112,30 +118,30 @@ func main() { } // checkMongo checks the MongoDB connection without retries -func checkMongo() error { - _, err := unrelation.NewMongo() +func checkMongo(config *config.GlobalConfig) error { + _, err := unrelation.NewMongo(config) return err } // checkRedis checks the Redis connection -func checkRedis() error { - _, err := cache.NewRedis() +func checkRedis(config *config.GlobalConfig) error { + _, err := cache.NewRedis(config) return err } // checkMinio checks the MinIO connection -func checkMinio() error { +func checkMinio(config *config.GlobalConfig) error { // Check if MinIO is enabled - if config.Config.Object.Enable != "minio" { + if config.Object.Enable != "minio" { return errs.Wrap(errors.New("minio.Enable is empty")) } minio := &component.Minio{ - ApiURL: config.Config.Object.ApiURL, - Endpoint: config.Config.Object.Minio.Endpoint, - AccessKeyID: config.Config.Object.Minio.AccessKeyID, - SecretAccessKey: config.Config.Object.Minio.SecretAccessKey, - SignEndpoint: config.Config.Object.Minio.SignEndpoint, + ApiURL: config.Object.ApiURL, + Endpoint: config.Object.Minio.Endpoint, + AccessKeyID: config.Object.Minio.AccessKeyID, + SecretAccessKey: config.Object.Minio.SecretAccessKey, + SignEndpoint: config.Object.Minio.SignEndpoint, UseSSL: getEnv("MINIO_USE_SSL", "false"), } err := component.CheckMinio(minio) @@ -143,18 +149,18 @@ func checkMinio() error { } // checkZookeeper checks the Zookeeper connection -func checkZookeeper() error { - _, err := zookeeper.NewZookeeperDiscoveryRegister() +func checkZookeeper(config *config.GlobalConfig) error { + _, err := zookeeper.NewZookeeperDiscoveryRegister(config) return err } // checkKafka checks the Kafka connection -func checkKafka() error { +func checkKafka(config *config.GlobalConfig) error { // Prioritize environment variables kafkaStu := &component.Kafka{ - Username: config.Config.Kafka.Username, - Password: config.Config.Kafka.Password, - Addr: config.Config.Kafka.Addr, + Username: config.Kafka.Username, + Password: config.Kafka.Password, + Addr: config.Kafka.Addr, } kafkaClient, err := component.CheckKafka(kafkaStu) @@ -170,9 +176,9 @@ func checkKafka() error { } requiredTopics := []string{ - config.Config.Kafka.MsgToMongo.Topic, - config.Config.Kafka.MsgToPush.Topic, - config.Config.Kafka.LatestMsgToRedis.Topic, + config.Kafka.MsgToMongo.Topic, + config.Kafka.MsgToPush.Topic, + config.Kafka.LatestMsgToRedis.Topic, } for _, requiredTopic := range requiredTopics { @@ -181,11 +187,25 @@ func checkKafka() error { } } + var tlsConfig *kafka.TLSConfig + if config.Kafka.TLS != nil { + tlsConfig = &kafka.TLSConfig{ + CACrt: config.Kafka.TLS.CACrt, + ClientCrt: config.Kafka.TLS.ClientCrt, + ClientKey: config.Kafka.TLS.ClientKey, + ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd, + InsecureSkipVerify: config.Kafka.TLS.InsecureSkipVerify, + } + } + _, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ KafkaVersion: sarama.V2_0_0_0, - OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, - }, []string{config.Config.Kafka.LatestMsgToRedis.Topic}, - config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToRedis) + OffsetsInitial: sarama.OffsetNewest, + IsReturnErr: false, + UserName: config.Kafka.Username, + Password: config.Kafka.Password, + }, []string{config.Kafka.LatestMsgToRedis.Topic}, + config.Kafka.Addr, config.Kafka.ConsumerGroupID.MsgToRedis, tlsConfig) if err != nil { return err } @@ -193,8 +213,8 @@ func checkKafka() error { _, err = kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ KafkaVersion: sarama.V2_0_0_0, OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, - }, []string{config.Config.Kafka.MsgToPush.Topic}, - config.Config.Kafka.Addr, config.Config.Kafka.ConsumerGroupID.MsgToMongo) + }, []string{config.Kafka.MsgToPush.Topic}, + config.Kafka.Addr, config.Kafka.ConsumerGroupID.MsgToMongo, tlsConfig) if err != nil { return err } @@ -202,8 +222,8 @@ func checkKafka() error { kafka.NewMConsumerGroup(&kafka.MConsumerGroupConfig{ KafkaVersion: sarama.V2_0_0_0, OffsetsInitial: sarama.OffsetNewest, IsReturnErr: false, - }, []string{config.Config.Kafka.MsgToPush.Topic}, config.Config.Kafka.Addr, - config.Config.Kafka.ConsumerGroupID.MsgToPush) + }, []string{config.Kafka.MsgToPush.Topic}, config.Kafka.Addr, + config.Kafka.ConsumerGroupID.MsgToPush, tlsConfig) if err != nil { return err } diff --git a/tools/component/component_test.go b/tools/component/component_test.go index 4488c029e..c56361b2c 100644 --- a/tools/component/component_test.go +++ b/tools/component/component_test.go @@ -21,20 +21,11 @@ import ( "time" "github.com/redis/go-redis/v9" - - "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) -// Mock for initCfg for testing purpose -func mockInitCfg() error { - config.Config.Mysql.Username = "root" - config.Config.Mysql.Password = "openIM123" - config.Config.Mysql.Address = []string{"127.0.0.1:13306"} - return nil -} - func TestRedis(t *testing.T) { - config.Config.Redis.Address = []string{ + conf, err := initCfg() + conf.Redis.Address = []string{ "172.16.8.142:7000", //"172.16.8.142:7000", "172.16.8.142:7001", "172.16.8.142:7002", "172.16.8.142:7003", "172.16.8.142:7004", "172.16.8.142:7005", } @@ -45,20 +36,20 @@ func TestRedis(t *testing.T) { redisClient.Close() } }() - if len(config.Config.Redis.Address) > 1 { + if len(conf.Redis.Address) > 1 { redisClient = redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: config.Config.Redis.Address, - Username: config.Config.Redis.Username, - Password: config.Config.Redis.Password, + Addrs: conf.Redis.Address, + Username: conf.Redis.Username, + Password: conf.Redis.Password, }) } else { redisClient = redis.NewClient(&redis.Options{ - Addr: config.Config.Redis.Address[0], - Username: config.Config.Redis.Username, - Password: config.Config.Redis.Password, + Addr: conf.Redis.Address[0], + Username: conf.Redis.Username, + Password: conf.Redis.Password, }) } - _, err := redisClient.Ping(context.Background()).Result() + _, err = redisClient.Ping(context.Background()).Result() if err != nil { t.Fatal(err) } diff --git a/tools/up35/pkg/pkg.go b/tools/up35/pkg/pkg.go index b7e7c01f5..1348172d2 100644 --- a/tools/up35/pkg/pkg.go +++ b/tools/up35/pkg/pkg.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "github.com/OpenIMSDK/tools/errs" "log" "os" "reflect" @@ -45,36 +46,43 @@ const ( versionValue = 35 ) -func InitConfig(path string) error { +func InitConfig(path string) (*config.GlobalConfig, error) { data, err := os.ReadFile(path) if err != nil { - return err + return nil, errs.Wrap(err, "ReadFile unmarshal failed") } - return yaml.Unmarshal(data, &config.Config) + + conf := config.NewGlobalConfig() + err = yaml.Unmarshal(data, &conf) + if err != nil { + return nil, errs.Wrap(err, "InitConfig unmarshal failed") + } + return conf, nil } -func GetMysql() (*gorm.DB, error) { - conf := config.Config.Mysql +func GetMysql(config *config.GlobalConfig) (*gorm.DB, error) { + conf := config.Mysql mysqlDSN := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", conf.Username, conf.Password, conf.Address[0], conf.Database) return gorm.Open(gormmysql.Open(mysqlDSN), &gorm.Config{Logger: logger.Discard}) } -func GetMongo() (*mongo.Database, error) { - mgo, err := unrelation.NewMongo() +func GetMongo(config *config.GlobalConfig) (*mongo.Database, error) { + mgo, err := unrelation.NewMongo(config) if err != nil { return nil, err } - return mgo.GetDatabase(), nil + return mgo.GetDatabase(config.Mongo.Database), nil } func Main(path string) error { - if err := InitConfig(path); err != nil { + conf, err := InitConfig(path) + if err != nil { return err } - if config.Config.Mysql == nil { + if conf.Mysql == nil { return nil } - mongoDB, err := GetMongo() + mongoDB, err := GetMongo(conf) if err != nil { return err } @@ -91,7 +99,7 @@ func Main(path string) error { default: return err } - mysqlDB, err := GetMysql() + mysqlDB, err := GetMysql(conf) if err != nil { if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == 1049 { if err := SetMongoDataVersion(mongoDB, version.Value); err != nil { @@ -113,7 +121,7 @@ func Main(path string) error { func() error { return NewTask(mysqlDB, mongoDB, mgo.NewGroupMember, c.GroupMember) }, func() error { return NewTask(mysqlDB, mongoDB, mgo.NewGroupRequestMgo, c.GroupRequest) }, func() error { return NewTask(mysqlDB, mongoDB, mgo.NewConversationMongo, c.Conversation) }, - func() error { return NewTask(mysqlDB, mongoDB, mgo.NewS3Mongo, c.Object(config.Config.Object.Enable)) }, + func() error { return NewTask(mysqlDB, mongoDB, mgo.NewS3Mongo, c.Object(conf.Object.Enable)) }, func() error { return NewTask(mysqlDB, mongoDB, mgo.NewLogMongo, c.Log) }, func() error { return NewTask(mysqlDB, mongoDB, rtcmgo.NewSignal, c.SignalModel) },