diff --git a/cmd/openim-api/main.go b/cmd/openim-api/main.go index 7a7e06293..9a307686d 100644 --- a/cmd/openim-api/main.go +++ b/cmd/openim-api/main.go @@ -15,118 +15,17 @@ package main import ( - "context" - "fmt" - "net" - "net/http" - _ "net/http/pprof" - "os" - "os/signal" - "strconv" - "syscall" - "time" - util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" + _ "net/http/pprof" - "github.com/OpenIMSDK/tools/errs" - - "github.com/OpenIMSDK/protocol/constant" - "github.com/OpenIMSDK/tools/discoveryregistry" - - "github.com/openimsdk/open-im-server/v3/internal/api" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" - kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister" - ginprom "github.com/openimsdk/open-im-server/v3/pkg/common/ginprometheus" - "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" ) func main() { apiCmd := cmd.NewApiCmd() apiCmd.AddPortFlag() apiCmd.AddPrometheusPortFlag() - apiCmd.AddApi(run) if err := apiCmd.Execute(); err != nil { util.ExitWithError(err) } } - -func run(port int, proPort int) error { - if port == 0 || proPort == 0 { - err := "port or proPort is empty:" + strconv.Itoa(port) + "," + strconv.Itoa(proPort) - return errs.Wrap(fmt.Errorf(err)) - } - rdb, err := cache.NewRedis() - if err != nil { - return err - } - - var client discoveryregistry.SvcDiscoveryRegistry - - // Determine whether zk is passed according to whether it is a clustered deployment - client, err = kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) - if err != nil { - return errs.Wrap(err, "register discovery err") - } - - if err = client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil { - return errs.Wrap(err, "create rpc root nodes error") - } - - if err = client.RegisterConf2Registry(constant.OpenIMCommonConfigKey, config.Config.EncodeConfig()); err != nil { - return err - } - var ( - netDone = make(chan struct{}, 1) - netErr error - ) - router := api.NewGinRouter(client, rdb) - if config.Config.Prometheus.Enable { - go func() { - p := ginprom.NewPrometheus("app", prommetrics.GetGinCusMetrics("Api")) - p.SetListenAddress(fmt.Sprintf(":%d", proPort)) - if err = p.Use(router); err != nil && err != http.ErrServerClosed { - netErr = errs.Wrap(err, fmt.Sprintf("prometheus start err: %d", proPort)) - netDone <- struct{}{} - } - }() - - } - - var address string - if config.Config.Api.ListenIP != "" { - address = net.JoinHostPort(config.Config.Api.ListenIP, strconv.Itoa(port)) - } else { - address = net.JoinHostPort("0.0.0.0", strconv.Itoa(port)) - } - - server := http.Server{Addr: address, Handler: router} - - go func() { - err = server.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - netErr = errs.Wrap(err, fmt.Sprintf("api start err: %s", server.Addr)) - netDone <- struct{}{} - - } - }() - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) - - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - select { - case <-sigs: - util.SIGUSR1Exit() - err := server.Shutdown(ctx) - if err != nil { - return errs.Wrap(err, "shutdown err") - } - case <-netDone: - close(netDone) - return netErr - } - return nil -} diff --git a/cmd/openim-crontask/main.go b/cmd/openim-crontask/main.go index b284fd773..b52029c64 100644 --- a/cmd/openim-crontask/main.go +++ b/cmd/openim-crontask/main.go @@ -15,14 +15,13 @@ package main import ( - "github.com/openimsdk/open-im-server/v3/internal/tools" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { cronTaskCmd := cmd.NewCronTaskCmd() - if err := cronTaskCmd.Exec(tools.StartTask); err != nil { + if err := cronTaskCmd.Exec(); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-msggateway/main.go b/cmd/openim-msggateway/main.go index ed67b8f5d..01b13560d 100644 --- a/cmd/openim-msggateway/main.go +++ b/cmd/openim-msggateway/main.go @@ -24,7 +24,6 @@ func main() { msgGatewayCmd.AddWsPortFlag() msgGatewayCmd.AddPortFlag() msgGatewayCmd.AddPrometheusPortFlag() - if err := msgGatewayCmd.Exec(); err != nil { util.ExitWithError(err) } diff --git a/cmd/openim-push/main.go b/cmd/openim-push/main.go index e0539fa52..c7d29fc97 100644 --- a/cmd/openim-push/main.go +++ b/cmd/openim-push/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/push" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - pushCmd := cmd.NewRpcCmd(cmd.RpcPushServer) + pushCmd := cmd.NewRpcCmd(cmd.RpcPushServer, push.Start) pushCmd.AddPortFlag() pushCmd.AddPrometheusPortFlag() if err := pushCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := pushCmd.StartSvr(config.Config.RpcRegisterName.OpenImPushName, push.Start); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-rpc/openim-rpc-auth/main.go b/cmd/openim-rpc/openim-rpc-auth/main.go index b526c3b86..da281b70e 100644 --- a/cmd/openim-rpc/openim-rpc-auth/main.go +++ b/cmd/openim-rpc/openim-rpc-auth/main.go @@ -17,19 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/auth" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - authCmd := cmd.NewRpcCmd(cmd.RpcAuthServer) + authCmd := cmd.NewRpcCmd(cmd.RpcAuthServer, auth.Start) authCmd.AddPortFlag() authCmd.AddPrometheusPortFlag() if err := authCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := authCmd.StartSvr(config.Config.RpcRegisterName.OpenImAuthName, auth.Start); err != nil { util.ExitWithError(err) } - } diff --git a/cmd/openim-rpc/openim-rpc-conversation/main.go b/cmd/openim-rpc/openim-rpc-conversation/main.go index bde191c51..6e74b3251 100644 --- a/cmd/openim-rpc/openim-rpc-conversation/main.go +++ b/cmd/openim-rpc/openim-rpc-conversation/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/conversation" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcConversationServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcConversationServer, conversation.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImConversationName, conversation.Start); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-rpc/openim-rpc-friend/main.go b/cmd/openim-rpc/openim-rpc-friend/main.go index 8eeb9c8e1..a307c01a1 100644 --- a/cmd/openim-rpc/openim-rpc-friend/main.go +++ b/cmd/openim-rpc/openim-rpc-friend/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/friend" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcFriendServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcFriendServer, friend.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImFriendName, friend.Start); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-rpc/openim-rpc-group/main.go b/cmd/openim-rpc/openim-rpc-group/main.go index a5842ffd1..2afb7963c 100644 --- a/cmd/openim-rpc/openim-rpc-group/main.go +++ b/cmd/openim-rpc/openim-rpc-group/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/group" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcGroupServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcGroupServer, group.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImGroupName, group.Start); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-rpc/openim-rpc-msg/main.go b/cmd/openim-rpc/openim-rpc-msg/main.go index b3895a502..bbffbcae7 100644 --- a/cmd/openim-rpc/openim-rpc-msg/main.go +++ b/cmd/openim-rpc/openim-rpc-msg/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/msg" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcMsgServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcMsgServer, msg.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImMsgName, msg.Start); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-rpc/openim-rpc-third/main.go b/cmd/openim-rpc/openim-rpc-third/main.go index 8f390bb6a..09a8409e6 100644 --- a/cmd/openim-rpc/openim-rpc-third/main.go +++ b/cmd/openim-rpc/openim-rpc-third/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/third" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcThirdServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcThirdServer, third.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImThirdName, third.Start); err != nil { util.ExitWithError(err) } } diff --git a/cmd/openim-rpc/openim-rpc-user/main.go b/cmd/openim-rpc/openim-rpc-user/main.go index 6994ea2b1..18adbfae5 100644 --- a/cmd/openim-rpc/openim-rpc-user/main.go +++ b/cmd/openim-rpc/openim-rpc-user/main.go @@ -17,18 +17,14 @@ package main import ( "github.com/openimsdk/open-im-server/v3/internal/rpc/user" "github.com/openimsdk/open-im-server/v3/pkg/common/cmd" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" ) func main() { - rpcCmd := cmd.NewRpcCmd(cmd.RpcUserServer) + rpcCmd := cmd.NewRpcCmd(cmd.RpcUserServer, user.Start) rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) - } - if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImUserName, user.Start); err != nil { util.ExitWithError(err) } } diff --git a/internal/api/route.go b/internal/api/route.go index 24ed5f6bb..88004e802 100644 --- a/internal/api/route.go +++ b/internal/api/route.go @@ -17,7 +17,17 @@ package api import ( "context" "fmt" + kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister" + ginprom "github.com/openimsdk/open-im-server/v3/pkg/common/ginprometheus" + "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil" + "net" "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" "github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/tools/apiresp" @@ -43,8 +53,87 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) -func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.UniversalClient) *gin.Engine { - discov.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) // 默认RPC中间件 +func Start(config *config.GlobalConfig, port int, proPort int) error { + if port == 0 || proPort == 0 { + err := "port or proPort is empty:" + strconv.Itoa(port) + "," + strconv.Itoa(proPort) + return errs.Wrap(fmt.Errorf(err)) + } + rdb, err := cache.NewRedis() + if err != nil { + return err + } + + var client discoveryregistry.SvcDiscoveryRegistry + + // Determine whether zk is passed according to whether it is a clustered deployment + client, err = kdisc.NewDiscoveryRegister(config.Envs.Discovery) + if err != nil { + return errs.Wrap(err, "register discovery err") + } + + if err = client.CreateRpcRootNodes(config.GetServiceNames()); err != nil { + return errs.Wrap(err, "create rpc root nodes error") + } + + if err = client.RegisterConf2Registry(constant.OpenIMCommonConfigKey, config.EncodeConfig()); err != nil { + return errs.Wrap(err) + } + var ( + netDone = make(chan struct{}, 1) + netErr error + ) + router := newGinRouter(client, rdb) + if config.Prometheus.Enable { + go func() { + p := ginprom.NewPrometheus("app", prommetrics.GetGinCusMetrics("Api")) + p.SetListenAddress(fmt.Sprintf(":%d", proPort)) + if err = p.Use(router); err != nil && err != http.ErrServerClosed { + netErr = errs.Wrap(err, fmt.Sprintf("prometheus start err: %d", proPort)) + netDone <- struct{}{} + } + }() + + } + + var address string + if config.Api.ListenIP != "" { + address = net.JoinHostPort(config.Api.ListenIP, strconv.Itoa(port)) + } else { + address = net.JoinHostPort("0.0.0.0", strconv.Itoa(port)) + } + + server := http.Server{Addr: address, Handler: router} + + go func() { + err = server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + netErr = errs.Wrap(err, fmt.Sprintf("api start err: %s", server.Addr)) + netDone <- struct{}{} + + } + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + select { + case <-sigs: + util.SIGUSR1Exit() + err := server.Shutdown(ctx) + if err != nil { + return errs.Wrap(err, "shutdown err") + } + case <-netDone: + close(netDone) + return netErr + } + return nil +} + +func newGinRouter(disCov discoveryregistry.SvcDiscoveryRegistry, rdb redis.UniversalClient) *gin.Engine { + disCov.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) // 默认RPC中间件 gin.SetMode(gin.ReleaseMode) r := gin.New() if v, ok := binding.Validator.Engine().(*validator.Validate); ok { @@ -53,13 +142,13 @@ func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.Unive log.ZInfo(context.Background(), "load config", "config", config.Config) r.Use(gin.Recovery(), mw.CorsHandler(), mw.GinParseOperationID()) // init rpc client here - userRpc := rpcclient.NewUser(discov) - groupRpc := rpcclient.NewGroup(discov) - friendRpc := rpcclient.NewFriend(discov) - messageRpc := rpcclient.NewMessage(discov) - conversationRpc := rpcclient.NewConversation(discov) - authRpc := rpcclient.NewAuth(discov) - thirdRpc := rpcclient.NewThird(discov) + userRpc := rpcclient.NewUser(disCov) + groupRpc := rpcclient.NewGroup(disCov) + friendRpc := rpcclient.NewFriend(disCov) + messageRpc := rpcclient.NewMessage(disCov) + conversationRpc := rpcclient.NewConversation(disCov) + authRpc := rpcclient.NewAuth(disCov) + thirdRpc := rpcclient.NewThird(disCov) u := NewUserApi(*userRpc) m := NewMessageApi(messageRpc, userRpc) diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index 807c4af3b..97c98e8cd 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -33,7 +33,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/startrpc" ) -func (s *Server) InitServer(disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func (s *Server) InitServer(config *config.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { rdb, err := cache.NewRedis() if err != nil { return err @@ -46,11 +46,12 @@ func (s *Server) InitServer(disCov discoveryregistry.SvcDiscoveryRegistry, serve return nil } -func (s *Server) Start() error { +func (s *Server) Start(conf *config.GlobalConfig) error { return startrpc.Start( s.rpcPort, config.Config.RpcRegisterName.OpenImMessageGatewayName, s.prometheusPort, + conf, s.InitServer, ) } diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index 321407f7e..36b6ea88e 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -22,7 +22,7 @@ import ( ) // RunWsAndServer run ws server. -func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { +func RunWsAndServer(conf *config.GlobalConfig, rpcPort, wsPort, prometheusPort int) error { fmt.Println( "start rpc/msg_gateway server, port: ", rpcPort, @@ -33,10 +33,10 @@ func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { ) longServer, err := NewWsServer( WithPort(wsPort), - WithMaxConnNum(int64(config.Config.LongConnSvr.WebsocketMaxConnNum)), - WithHandshakeTimeout(time.Duration(config.Config.LongConnSvr.WebsocketTimeout)*time.Second), - WithMessageMaxMsgLength(config.Config.LongConnSvr.WebsocketMaxMsgLen), - WithWriteBufferSize(config.Config.LongConnSvr.WebsocketWriteBufferSize), + WithMaxConnNum(int64(conf.LongConnSvr.WebsocketMaxConnNum)), + WithHandshakeTimeout(time.Duration(conf.LongConnSvr.WebsocketTimeout)*time.Second), + WithMessageMaxMsgLength(conf.LongConnSvr.WebsocketMaxMsgLen), + WithWriteBufferSize(conf.LongConnSvr.WebsocketWriteBufferSize), ) if err != nil { return err @@ -45,7 +45,7 @@ func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { hubServer := NewServer(rpcPort, prometheusPort, longServer) netDone := make(chan error) go func() { - err = hubServer.Start() + err = hubServer.Start(conf) if err != nil { netDone <- err } diff --git a/internal/msgtransfer/init.go b/internal/msgtransfer/init.go index 16d8613db..2493933f3 100644 --- a/internal/msgtransfer/init.go +++ b/internal/msgtransfer/init.go @@ -52,7 +52,7 @@ type MsgTransfer struct { cancel context.CancelFunc } -func StartTransfer(prometheusPort int) error { +func StartTransfer(config *config.GlobalConfig, prometheusPort int) error { rdb, err := cache.NewRedis() if err != nil { return err @@ -66,12 +66,12 @@ func StartTransfer(prometheusPort int) error { if err = mongo.CreateMsgIndex(); err != nil { return err } - client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) + client, err := kdisc.NewDiscoveryRegister(config.Envs.Discovery) if err != nil { return err } - if err := client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil { + if err := client.CreateRpcRootNodes(config.GetServiceNames()); err != nil { return err } client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) diff --git a/internal/push/push_rpc_server.go b/internal/push/push_rpc_server.go index f558aeec3..c7ad3fba8 100644 --- a/internal/push/push_rpc_server.go +++ b/internal/push/push_rpc_server.go @@ -16,6 +16,7 @@ package push import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/tools/utils" @@ -34,9 +35,10 @@ import ( type pushServer struct { pusher *Pusher + config *config.GlobalConfig } -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { rdb, err := cache.NewRedis() if err != nil { return err @@ -60,6 +62,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e pbpush.RegisterPushMsgServiceServer(server, &pushServer{ pusher: pusher, + config: config, }) consumer, err := NewConsumer(pusher) diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index eaf63f868..301204e2e 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -42,9 +42,10 @@ type authServer struct { authDatabase controller.AuthDatabase userRpcClient *rpcclient.UserRpcClient RegisterCenter discoveryregistry.SvcDiscoveryRegistry + config *config.GlobalConfig } -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { rdb, err := cache.NewRedis() if err != nil { return err @@ -55,9 +56,10 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e RegisterCenter: client, authDatabase: controller.NewAuthDatabase( cache.NewMsgCacheModel(rdb), - config.Config.Secret, - config.Config.TokenPolicy.Expire, + config.Secret, + config.TokenPolicy.Expire, ), + config: config, }) return nil } diff --git a/internal/rpc/conversation/conversaion.go b/internal/rpc/conversation/conversaion.go index 8558a23ea..311e82d4b 100644 --- a/internal/rpc/conversation/conversaion.go +++ b/internal/rpc/conversation/conversaion.go @@ -17,6 +17,7 @@ package conversation import ( "context" "errors" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "sort" "github.com/OpenIMSDK/protocol/sdkws" @@ -49,6 +50,7 @@ type conversationServer struct { groupRpcClient *rpcclient.GroupRpcClient conversationDatabase controller.ConversationDatabase conversationNotificationSender *notification.ConversationNotificationSender + config *config.GlobalConfig } func (c *conversationServer) GetConversationNotReceiveMessageUserIDs( @@ -59,7 +61,7 @@ func (c *conversationServer) GetConversationNotReceiveMessageUserIDs( panic("implement me") } -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { rdb, err := cache.NewRedis() if err != nil { return err @@ -81,6 +83,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e conversationNotificationSender: notification.NewConversationNotificationSender(&msgRpcClient), groupRpcClient: &groupRpcClient, conversationDatabase: controller.NewConversationDatabase(conversationDB, cache.NewConversationRedis(rdb, cache.GetDefaultOpt(), conversationDB), tx.NewMongo(mongo.GetClient())), + config: config, }) return nil } diff --git a/internal/rpc/friend/friend.go b/internal/rpc/friend/friend.go index 84702f548..eacb5b921 100644 --- a/internal/rpc/friend/friend.go +++ b/internal/rpc/friend/friend.go @@ -16,6 +16,7 @@ package friend import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/OpenIMSDK/tools/tx" @@ -51,9 +52,10 @@ type friendServer struct { notificationSender *notification.FriendNotificationSender conversationRpcClient rpcclient.ConversationRpcClient RegisterCenter registry.SvcDiscoveryRegistry + config *config.GlobalConfig } -func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client registry.SvcDiscoveryRegistry, server *grpc.Server) error { // Initialize MongoDB mongo, err := unrelation.NewMongo() if err != nil { @@ -106,6 +108,7 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { notificationSender: notificationSender, RegisterCenter: client, conversationRpcClient: rpcclient.NewConversationRpcClient(client), + config: config, }) return nil @@ -113,13 +116,12 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { // ok. func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *pbfriend.ApplyToAddFriendReq) (resp *pbfriend.ApplyToAddFriendResp, err error) { - defer log.ZInfo(ctx, utils.GetFuncName()+" Return") resp = &pbfriend.ApplyToAddFriendResp{} if err := authverify.CheckAccessV3(ctx, req.FromUserID); err != nil { return nil, err } if req.ToUserID == req.FromUserID { - return nil, errs.ErrCanNotAddYourself.Wrap() + return nil, errs.ErrCanNotAddYourself.Wrap("req.ToUserID", req.ToUserID) } if err = CallbackBeforeAddFriend(ctx, req); err != nil && err != errs.ErrCallbackContinue { return nil, err diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 1d068b1b2..07a8ed9dd 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -17,6 +17,7 @@ package group import ( "context" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "math/big" "math/rand" "strconv" @@ -59,7 +60,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/db/unrelation" ) -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { mongo, err := unrelation.NewMongo() if err != nil { return err @@ -96,6 +97,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e }) gs.conversationRpcClient = conversationRpcClient gs.msgRpcClient = msgRpcClient + gs.config = config pbgroup.RegisterGroupServer(server, &gs) return nil } @@ -106,6 +108,7 @@ type groupServer struct { Notification *notification.GroupNotificationSender conversationRpcClient rpcclient.ConversationRpcClient msgRpcClient rpcclient.MessageRpcClient + config *config.GlobalConfig } func (s *groupServer) GetJoinedGroupIDs(ctx context.Context, req *pbgroup.GetJoinedGroupIDsReq) (*pbgroup.GetJoinedGroupIDsResp, error) { diff --git a/internal/rpc/msg/server.go b/internal/rpc/msg/server.go index fe1baa453..bbf8149a7 100644 --- a/internal/rpc/msg/server.go +++ b/internal/rpc/msg/server.go @@ -16,6 +16,7 @@ package msg import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "google.golang.org/grpc" @@ -44,6 +45,7 @@ type ( ConversationLocalCache *localcache.ConversationLocalCache Handlers MessageInterceptorChain notificationSender *rpcclient.NotificationSender + config *config.GlobalConfig } ) @@ -62,7 +64,7 @@ func (m *msgServer) execInterceptorHandler(ctx context.Context, req *msg.SendMsg return nil } -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { rdb, err := cache.NewRedis() if err != nil { return err @@ -93,6 +95,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e GroupLocalCache: localcache.NewGroupLocalCache(&groupRpcClient), ConversationLocalCache: localcache.NewConversationLocalCache(&conversationClient), friend: &friendRpcClient, + config: config, } s.notificationSender = rpcclient.NewNotificationSender(rpcclient.WithLocalSendMsg(s.SendMsg)) s.addInterceptorHandler(MessageHasReadEnabled) diff --git a/internal/rpc/third/third.go b/internal/rpc/third/third.go index 7a63d3526..dac59da73 100644 --- a/internal/rpc/third/third.go +++ b/internal/rpc/third/third.go @@ -39,7 +39,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" ) -func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error { mongo, err := unrelation.NewMongo() if err != nil { return err @@ -52,11 +52,11 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e if err != nil { return err } - apiURL := config.Config.Object.ApiURL + apiURL := config.Object.ApiURL if apiURL == "" { return fmt.Errorf("api url is empty") } - if _, err := url.Parse(config.Config.Object.ApiURL); err != nil { + if _, err := url.Parse(config.Object.ApiURL); err != nil { return err } if apiURL[len(apiURL)-1] != '/' { @@ -68,9 +68,9 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e return err } // 根据配置文件策略选择 oss 方式 - enable := config.Config.Object.Enable + enable := config.Object.Enable var o s3.Interface - switch config.Config.Object.Enable { + switch config.Object.Enable { case "minio": o, err = minio.NewMinio(cache.NewMinioCache(rdb)) case "cos": @@ -89,6 +89,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e userRpcClient: rpcclient.NewUserRpcClient(client), s3dataBase: controller.NewS3Database(rdb, o, s3db), defaultExpire: time.Hour * 24 * 7, + config: config, }) return nil } @@ -99,6 +100,7 @@ type thirdServer struct { s3dataBase controller.S3Database userRpcClient rpcclient.UserRpcClient defaultExpire time.Duration + config *config.GlobalConfig } func (t *thirdServer) FcmUpdateToken(ctx context.Context, req *third.FcmUpdateTokenReq) (resp *third.FcmUpdateTokenResp, err error) { diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go index 6f9e2949f..288022254 100644 --- a/internal/rpc/user/user.go +++ b/internal/rpc/user/user.go @@ -59,6 +59,7 @@ type userServer struct { friendRpcClient *rpcclient.FriendRpcClient groupRpcClient *rpcclient.GroupRpcClient RegisterCenter registry.SvcDiscoveryRegistry + config *config.GlobalConfig } func (s *userServer) GetGroupOnlineUser(ctx context.Context, req *pbuser.GetGroupOnlineUserReq) (*pbuser.GetGroupOnlineUserResp, error) { @@ -66,7 +67,7 @@ func (s *userServer) GetGroupOnlineUser(ctx context.Context, req *pbuser.GetGrou panic("implement me") } -func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { +func Start(config *config.GlobalConfig, client registry.SvcDiscoveryRegistry, server *grpc.Server) error { rdb, err := cache.NewRedis() if err != nil { return err @@ -76,11 +77,11 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { return err } users := make([]*tablerelation.UserModel, 0) - if len(config.Config.IMAdmin.UserID) != len(config.Config.IMAdmin.Nickname) { + if len(config.IMAdmin.UserID) != len(config.IMAdmin.Nickname) { return errors.New("len(config.Config.AppNotificationAdmin.AppManagerUid) != len(config.Config.AppNotificationAdmin.Nickname)") } - for k, v := range config.Config.IMAdmin.UserID { - users = append(users, &tablerelation.UserModel{UserID: v, Nickname: config.Config.IMAdmin.Nickname[k], AppMangerLevel: constant.AppNotificationAdmin}) + for k, v := range config.IMAdmin.UserID { + users = append(users, &tablerelation.UserModel{UserID: v, Nickname: config.IMAdmin.Nickname[k], AppMangerLevel: constant.AppNotificationAdmin}) } userDB, err := mgo.NewUserMongo(mongo.GetDatabase()) if err != nil { @@ -99,6 +100,7 @@ func Start(client registry.SvcDiscoveryRegistry, server *grpc.Server) error { groupRpcClient: &groupRpcClient, friendNotificationSender: notification.NewFriendNotificationSender(&msgRpcClient, notification.WithDBFunc(database.FindWithError)), userNotificationSender: notification.NewUserNotificationSender(&msgRpcClient, notification.WithUserFunc(database.FindWithError)), + config: config, } pbuser.RegisterUserServer(server, u) return u.UserDatabase.InitOnce(context.Background(), users) diff --git a/internal/tools/cron_task.go b/internal/tools/cron_task.go index decc1aa82..6d0d9fdf2 100644 --- a/internal/tools/cron_task.go +++ b/internal/tools/cron_task.go @@ -31,8 +31,8 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/db/cache" ) -func StartTask() error { - fmt.Println("cron task start, config", config.Config.ChatRecordsClearTime) +func StartTask(config *config.GlobalConfig) error { + fmt.Println("cron task start, config", config.ChatRecordsClearTime) msgTool, err := InitMsgTool() if err != nil { @@ -48,14 +48,14 @@ func StartTask() error { // register cron tasks var crontab = cron.New() - fmt.Println("start chatRecordsClearTime cron task", "cron config", config.Config.ChatRecordsClearTime) - _, err = crontab.AddFunc(config.Config.ChatRecordsClearTime, cronWrapFunc(rdb, "cron_clear_msg_and_fix_seq", msgTool.AllConversationClearMsgAndFixSeq)) + fmt.Println("start chatRecordsClearTime cron task", "cron config", config.ChatRecordsClearTime) + _, err = crontab.AddFunc(config.ChatRecordsClearTime, cronWrapFunc(config, rdb, "cron_clear_msg_and_fix_seq", msgTool.AllConversationClearMsgAndFixSeq)) if err != nil { return errs.Wrap(err) } - fmt.Println("start msgDestruct cron task", "cron config", config.Config.MsgDestructTime) - _, err = crontab.AddFunc(config.Config.MsgDestructTime, cronWrapFunc(rdb, "cron_conversations_destruct_msgs", msgTool.ConversationsDestructMsgs)) + fmt.Println("start msgDestruct cron task", "cron config", config.MsgDestructTime) + _, err = crontab.AddFunc(config.MsgDestructTime, cronWrapFunc(config, rdb, "cron_conversations_destruct_msgs", msgTool.ConversationsDestructMsgs)) if err != nil { return errs.Wrap(err) } @@ -93,8 +93,8 @@ func netlock(rdb redis.UniversalClient, key string, ttl time.Duration) bool { return ok } -func cronWrapFunc(rdb redis.UniversalClient, key string, fn func()) func() { - enableCronLocker := config.Config.EnableCronLocker +func cronWrapFunc(config *config.GlobalConfig, rdb redis.UniversalClient, key string, fn func()) func() { + enableCronLocker := config.EnableCronLocker return func() { // if don't enable cron-locker, call fn directly. if !enableCronLocker { diff --git a/pkg/authverify/token.go b/pkg/authverify/token.go index 97bb03391..ce2dca0c2 100644 --- a/pkg/authverify/token.go +++ b/pkg/authverify/token.go @@ -44,7 +44,7 @@ func CheckAccessV3(ctx context.Context, ownerUserID string) (err error) { if opUserID == ownerUserID { return nil } - return errs.ErrNoPermission.Wrap(utils.GetSelfFuncName()) + return errs.ErrNoPermission.Wrap("ownerUserID", ownerUserID) } func IsAppManagerUid(ctx context.Context) bool { diff --git a/pkg/common/cmd/api.go b/pkg/common/cmd/api.go index db1f488ad..859508ce3 100644 --- a/pkg/common/cmd/api.go +++ b/pkg/common/cmd/api.go @@ -16,33 +16,43 @@ package cmd import ( "github.com/OpenIMSDK/protocol/constant" + "github.com/openimsdk/open-im-server/v3/internal/api" "github.com/spf13/cobra" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) type ApiCmd struct { *RootCmd + initFunc func(config *config.GlobalConfig, port int, promPort int) error } func NewApiCmd() *ApiCmd { - ret := &ApiCmd{NewRootCmd("api")} + ret := &ApiCmd{RootCmd: NewRootCmd("api"), initFunc: api.Start} ret.SetRootCmdPt(ret) - + ret.addPreRun() + ret.addRunE() return ret } -func (a *ApiCmd) AddApi(f func(port int, promPort int) error) { +func (a *ApiCmd) addPreRun() { + a.Command.PreRun = func(cmd *cobra.Command, args []string) { + a.port = a.getPortFlag(cmd) + a.prometheusPort = a.getPrometheusPortFlag(cmd) + } +} + +func (a *ApiCmd) addRunE() { a.Command.RunE = func(cmd *cobra.Command, args []string) error { - return f(a.getPortFlag(cmd), a.getPrometheusPortFlag(cmd)) + return a.initFunc(a.config, a.port, a.prometheusPort) } } func (a *ApiCmd) GetPortFromConfig(portType string) int { if portType == constant.FlagPort { - return config2.Config.Api.OpenImApiPort[0] + return a.config.Api.OpenImApiPort[0] } else if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.ApiPrometheusPort[0] + return a.config.Prometheus.ApiPrometheusPort[0] } return 0 } diff --git a/pkg/common/cmd/cron_task.go b/pkg/common/cmd/cron_task.go index 1b0e796ac..6d56f5f4f 100644 --- a/pkg/common/cmd/cron_task.go +++ b/pkg/common/cmd/cron_task.go @@ -14,25 +14,30 @@ package cmd -import "github.com/spf13/cobra" +import ( + "github.com/openimsdk/open-im-server/v3/internal/tools" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" + "github.com/spf13/cobra" +) type CronTaskCmd struct { *RootCmd + initFunc func(config *config.GlobalConfig) error } func NewCronTaskCmd() *CronTaskCmd { - ret := &CronTaskCmd{NewRootCmd("cronTask", WithCronTaskLogName())} + ret := &CronTaskCmd{RootCmd: NewRootCmd("cronTask", WithCronTaskLogName()), + initFunc: tools.StartTask} ret.SetRootCmdPt(ret) return ret } -func (c *CronTaskCmd) addRunE(f func() error) { +func (c *CronTaskCmd) addRunE() { c.Command.RunE = func(cmd *cobra.Command, args []string) error { - return f() + return c.initFunc(c.config) } } -func (c *CronTaskCmd) Exec(f func() error) error { - c.addRunE(f) +func (c *CronTaskCmd) Exec() error { return c.Execute() } diff --git a/pkg/common/cmd/msg_gateway.go b/pkg/common/cmd/msg_gateway.go index 25fcc1177..872804603 100644 --- a/pkg/common/cmd/msg_gateway.go +++ b/pkg/common/cmd/msg_gateway.go @@ -31,6 +31,7 @@ type MsgGatewayCmd struct { func NewMsgGatewayCmd() *MsgGatewayCmd { ret := &MsgGatewayCmd{NewRootCmd("msgGateway")} + ret.addRunE() ret.SetRootCmdPt(ret) return ret } @@ -52,12 +53,11 @@ func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int { func (m *MsgGatewayCmd) addRunE() { m.Command.RunE = func(cmd *cobra.Command, args []string) error { - return msggateway.RunWsAndServer(m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd)) + return msggateway.RunWsAndServer(m.config, m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd)) } } func (m *MsgGatewayCmd) Exec() error { - m.addRunE() return m.Execute() } diff --git a/pkg/common/cmd/msg_transfer.go b/pkg/common/cmd/msg_transfer.go index e57bab89d..75ef087c1 100644 --- a/pkg/common/cmd/msg_transfer.go +++ b/pkg/common/cmd/msg_transfer.go @@ -20,8 +20,6 @@ import ( "github.com/OpenIMSDK/protocol/constant" "github.com/spf13/cobra" - config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/openimsdk/open-im-server/v3/internal/msgtransfer" ) @@ -31,18 +29,18 @@ type MsgTransferCmd struct { func NewMsgTransferCmd() *MsgTransferCmd { ret := &MsgTransferCmd{NewRootCmd("msgTransfer")} + ret.addRunE() ret.SetRootCmdPt(ret) return ret } func (m *MsgTransferCmd) addRunE() { m.Command.RunE = func(cmd *cobra.Command, args []string) error { - return msgtransfer.StartTransfer(m.getPrometheusPortFlag(cmd)) + return msgtransfer.StartTransfer(m.config, m.getPrometheusPortFlag(cmd)) } } func (m *MsgTransferCmd) Exec() error { - m.addRunE() return m.Execute() } @@ -51,7 +49,7 @@ func (m *MsgTransferCmd) GetPortFromConfig(portType string) int { return 0 } else if portType == constant.FlagPrometheusPort { n := m.getTransferProgressFlagValue() - return config2.Config.Prometheus.MessageTransferPrometheusPort[n] + return m.config.Prometheus.MessageTransferPrometheusPort[n] } return 0 } @@ -61,10 +59,10 @@ func (m *MsgTransferCmd) AddTransferProgressFlag() { } func (m *MsgTransferCmd) getTransferProgressFlagValue() int { - nindex, err := m.Command.Flags().GetInt(constant.FlagTransferProgressIndex) + nIndex, err := m.Command.Flags().GetInt(constant.FlagTransferProgressIndex) if err != nil { - fmt.Println("get transfercmd error,make sure it is k8s env or not") + fmt.Println("get transfer cmd error,make sure it is k8s env or not") return 0 } - return nindex + return nIndex } diff --git a/pkg/common/cmd/root.go b/pkg/common/cmd/root.go index eab4a32bc..651681e5b 100644 --- a/pkg/common/cmd/root.go +++ b/pkg/common/cmd/root.go @@ -36,6 +36,11 @@ type RootCmd struct { port int prometheusPort int cmdItf RootCmdPt + config *config.GlobalConfig +} + +func (rc *RootCmd) Port() int { + return rc.port } type CmdOpts struct { @@ -55,7 +60,7 @@ func WithLogName(logName string) func(*CmdOpts) { } func NewRootCmd(name string, opts ...func(*CmdOpts)) *RootCmd { - rootCmd := &RootCmd{Name: name} + rootCmd := &RootCmd{Name: name, config: config.NewGlobalConfig()} cmd := cobra.Command{ Use: "Start openIM application", Short: fmt.Sprintf(`Start %s `, name), @@ -97,7 +102,7 @@ func (rc *RootCmd) applyOptions(opts ...func(*CmdOpts)) *CmdOpts { } func (rc *RootCmd) initializeLogger(cmdOpts *CmdOpts) error { - logConfig := config.Config.Log + logConfig := rc.config.Log return log.InitFromConfig( @@ -164,7 +169,7 @@ func (r *RootCmd) GetPrometheusPortFlag() int { func (r *RootCmd) getConfFromCmdAndInit(cmdLines *cobra.Command) error { configFolderPath, _ := cmdLines.Flags().GetString(constant.FlagConf) fmt.Println("The directory of the configuration file to start the process:", configFolderPath) - return config2.InitConfig(configFolderPath) + return config2.InitConfig(r.config, configFolderPath) } func (r *RootCmd) Execute() error { diff --git a/pkg/common/cmd/rpc.go b/pkg/common/cmd/rpc.go index ea2a00b07..5199524e7 100644 --- a/pkg/common/cmd/rpc.go +++ b/pkg/common/cmd/rpc.go @@ -16,6 +16,7 @@ package cmd import ( "errors" + "github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/protocol/constant" "github.com/spf13/cobra" @@ -28,89 +29,131 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/startrpc" ) +type rpcInitFuc func(config *config2.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error + type RpcCmd struct { *RootCmd + RpcRegisterName string + initFunc rpcInitFuc } -func NewRpcCmd(name string) *RpcCmd { - ret := &RpcCmd{NewRootCmd(name)} +func NewRpcCmd(name string, initFunc rpcInitFuc) *RpcCmd { + ret := &RpcCmd{RootCmd: NewRootCmd(name), initFunc: initFunc} + ret.addPreRun() + ret.addRunE() ret.SetRootCmdPt(ret) return ret } -func (a *RpcCmd) Exec() error { - a.Command.Run = func(cmd *cobra.Command, args []string) { +func (a *RpcCmd) addPreRun() { + a.Command.PreRun = func(cmd *cobra.Command, args []string) { a.port = a.getPortFlag(cmd) a.prometheusPort = a.getPrometheusPortFlag(cmd) } +} + +func (a *RpcCmd) addRunE() { + a.Command.RunE = func(cmd *cobra.Command, args []string) error { + rpcRegisterName, err := a.GetRpcRegisterNameFromConfig() + if err != nil { + return err + } else { + return a.StartSvr(rpcRegisterName, a.initFunc) + } + } +} + +func (a *RpcCmd) Exec() error { return a.Execute() } -func (a *RpcCmd) StartSvr(name string, rpcFn func(discov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error { +func (a *RpcCmd) StartSvr(name string, rpcFn func(config *config2.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error { if a.GetPortFlag() == 0 { - return errors.New("port is required") + return errs.Wrap(errors.New("port is required")) } - return startrpc.Start(a.GetPortFlag(), name, a.GetPrometheusPortFlag(), rpcFn) + return startrpc.Start(a.GetPortFlag(), name, a.GetPrometheusPortFlag(), a.config, rpcFn) } func (a *RpcCmd) GetPortFromConfig(portType string) int { switch a.Name { case RpcPushServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImPushPort[0] + return a.config.RpcPort.OpenImPushPort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.PushPrometheusPort[0] + return a.config.Prometheus.PushPrometheusPort[0] } case RpcAuthServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImAuthPort[0] + return a.config.RpcPort.OpenImAuthPort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.AuthPrometheusPort[0] + return a.config.Prometheus.AuthPrometheusPort[0] } case RpcConversationServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImConversationPort[0] + return a.config.RpcPort.OpenImConversationPort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.ConversationPrometheusPort[0] + return a.config.Prometheus.ConversationPrometheusPort[0] } case RpcFriendServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImFriendPort[0] + return a.config.RpcPort.OpenImFriendPort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.FriendPrometheusPort[0] + return a.config.Prometheus.FriendPrometheusPort[0] } case RpcGroupServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImGroupPort[0] + return a.config.RpcPort.OpenImGroupPort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.GroupPrometheusPort[0] + return a.config.Prometheus.GroupPrometheusPort[0] } case RpcMsgServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImMessagePort[0] + return a.config.RpcPort.OpenImMessagePort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.MessagePrometheusPort[0] + return a.config.Prometheus.MessagePrometheusPort[0] } case RpcThirdServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImThirdPort[0] + return a.config.RpcPort.OpenImThirdPort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.ThirdPrometheusPort[0] + return a.config.Prometheus.ThirdPrometheusPort[0] } case RpcUserServer: if portType == constant.FlagPort { - return config2.Config.RpcPort.OpenImUserPort[0] + return a.config.RpcPort.OpenImUserPort[0] } if portType == constant.FlagPrometheusPort { - return config2.Config.Prometheus.UserPrometheusPort[0] + return a.config.Prometheus.UserPrometheusPort[0] } } return 0 } + +func (a *RpcCmd) GetRpcRegisterNameFromConfig() (string, error) { + switch a.Name { + case RpcPushServer: + return a.config.RpcRegisterName.OpenImPushName, nil + case RpcAuthServer: + return a.config.RpcRegisterName.OpenImAuthName, nil + case RpcConversationServer: + return a.config.RpcRegisterName.OpenImConversationName, nil + case RpcFriendServer: + return a.config.RpcRegisterName.OpenImFriendName, nil + case RpcGroupServer: + return a.config.RpcRegisterName.OpenImGroupName, nil + case RpcMsgServer: + return a.config.RpcRegisterName.OpenImMsgName, nil + case RpcThirdServer: + return a.config.RpcRegisterName.OpenImThirdName, nil + case RpcUserServer: + return a.config.RpcRegisterName.OpenImUserName, nil + } + return "", errs.Wrap(errors.New("can not get rpc register name"), a.Name) +} diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index 9696e9367..6228bdcf9 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -21,7 +21,7 @@ import ( "gopkg.in/yaml.v3" ) -var Config configStruct +var Config GlobalConfig const ConfKey = "conf" @@ -57,7 +57,7 @@ type MYSQL struct { SlowThreshold int `yaml:"slowThreshold"` } -type configStruct struct { +type GlobalConfig struct { Envs struct { Discovery string `yaml:"discovery"` } @@ -331,6 +331,10 @@ type configStruct struct { Notification notification `yaml:"notification"` } +func NewGlobalConfig() *GlobalConfig { + return &GlobalConfig{} +} + type notification struct { GroupCreated NotificationConf `yaml:"groupCreated"` GroupInfoSet NotificationConf `yaml:"groupInfoSet"` @@ -370,7 +374,7 @@ type notification struct { ConversationSetPrivate NotificationConf `yaml:"conversationSetPrivate"` } -func (c *configStruct) GetServiceNames() []string { +func (c *GlobalConfig) GetServiceNames() []string { return []string{ c.RpcRegisterName.OpenImUserName, c.RpcRegisterName.OpenImFriendName, @@ -384,7 +388,7 @@ func (c *configStruct) GetServiceNames() []string { } } -func (c *configStruct) RegisterConf2Registry(registry discoveryregistry.SvcDiscoveryRegistry) error { +func (c *GlobalConfig) RegisterConf2Registry(registry discoveryregistry.SvcDiscoveryRegistry) error { data, err := yaml.Marshal(c) if err != nil { return err @@ -392,11 +396,11 @@ func (c *configStruct) RegisterConf2Registry(registry discoveryregistry.SvcDisco return registry.RegisterConf2Registry(ConfKey, data) } -func (c *configStruct) GetConfFromRegistry(registry discoveryregistry.SvcDiscoveryRegistry) ([]byte, error) { +func (c *GlobalConfig) GetConfFromRegistry(registry discoveryregistry.SvcDiscoveryRegistry) ([]byte, error) { return registry.GetConfFromRegistry(ConfKey) } -func (c *configStruct) EncodeConfig() []byte { +func (c *GlobalConfig) EncodeConfig() []byte { buf := bytes.NewBuffer(nil) if err := yaml.NewEncoder(buf).Encode(c); err != nil { panic(err) diff --git a/pkg/common/config/parse.go b/pkg/common/config/parse.go index 4037429e3..1ed15e627 100644 --- a/pkg/common/config/parse.go +++ b/pkg/common/config/parse.go @@ -106,7 +106,7 @@ func initConfig(config any, configName, configFolderPath string) error { return nil } -func InitConfig(configFolderPath string) error { +func InitConfig(config *GlobalConfig, configFolderPath string) error { if configFolderPath == "" { envConfigPath := os.Getenv("OPENIMCONFIG") if envConfigPath != "" { @@ -116,9 +116,9 @@ func InitConfig(configFolderPath string) error { } } - if err := initConfig(&Config, FileName, configFolderPath); err != nil { + if err := initConfig(config, FileName, configFolderPath); err != nil { return err } - return initConfig(&Config.Notification, NotificationFileName, configFolderPath) + return initConfig(config.Notification, NotificationFileName, configFolderPath) } diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index c5105ec51..39c3238b8 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/config" "net" "net/http" "os" @@ -34,7 +35,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" + config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -53,36 +54,37 @@ func Start( rpcPort int, rpcRegisterName string, prometheusPort int, - rpcFn func(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error, + config *config2.GlobalConfig, + rpcFn func(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error, options ...grpc.ServerOption, ) error { fmt.Printf("start %s server, port: %d, prometheusPort: %d, OpenIM version: %s\n", - rpcRegisterName, rpcPort, prometheusPort, config.Version) - rpcTcpAddr := net.JoinHostPort(network.GetListenIP(config.Config.Rpc.ListenIP), strconv.Itoa(rpcPort)) + rpcRegisterName, rpcPort, prometheusPort, config2.Version) + rpcTcpAddr := net.JoinHostPort(network.GetListenIP(config.Rpc.ListenIP), strconv.Itoa(rpcPort)) listener, err := net.Listen( "tcp", rpcTcpAddr, ) if err != nil { - return errs.Wrap(err, "rpc start err", rpcTcpAddr) + return errs.Wrap(err, "listen err", rpcTcpAddr) } defer listener.Close() - client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) + client, err := kdisc.NewDiscoveryRegister(config.Envs.Discovery) if err != nil { return errs.Wrap(err) } defer client.Close() client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) - registerIP, err := network.GetRpcRegisterIP(config.Config.Rpc.RegisterIP) + registerIP, err := network.GetRpcRegisterIP(config.Rpc.RegisterIP) if err != nil { return errs.Wrap(err) } var reg *prometheus.Registry var metric *grpcprometheus.ServerMetrics - if config.Config.Prometheus.Enable { + if config.Prometheus.Enable { cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName) reg, metric, _ = prommetrics.NewGrpcPromObj(cusMetrics) options = append(options, mw.GrpcServer(), grpc.StreamInterceptor(metric.StreamServerInterceptor()), @@ -97,7 +99,7 @@ func Start( once.Do(srv.GracefulStop) }() - err = rpcFn(client, srv) + err = rpcFn(config, client, srv) if err != nil { return err } @@ -117,7 +119,7 @@ func Start( httpServer *http.Server ) go func() { - if config.Config.Prometheus.Enable && prometheusPort != 0 { + if config.Prometheus.Enable && prometheusPort != 0 { metric.InitializeMetrics(srv) // Create a HTTP server for prometheus. httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)}