diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index 14c320c42..aeba0a24a 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -19,6 +19,7 @@ import ( "time" "github.com/OpenIMSDK/tools/utils" + "golang.org/x/sync/errgroup" "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) @@ -43,12 +44,22 @@ func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { if err != nil { return err } + hubServer := NewServer(rpcPort, prometheusPort, longServer) - go func() { - err := hubServer.Start() + + wg := errgroup.Group{} + wg.Go(func() error { + err = hubServer.Start() if err != nil { - panic(utils.Wrap1(err)) + return utils.Wrap1(err) } - }() - return hubServer.LongConnServer.Run() + return err + }) + + wg.Go(func() error { + return hubServer.LongConnServer.Run() + }) + + err = wg.Wait() + return err } diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index a249ff70f..99a7a4805 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -18,9 +18,12 @@ import ( "context" "errors" "net/http" + "os" + "os/signal" "strconv" "sync" "sync/atomic" + "syscall" "time" "github.com/go-playground/validator/v10" @@ -156,10 +159,22 @@ func NewWsServer(opts ...Option) (*WsServer, error) { } func (ws *WsServer) Run() error { - var client *Client - go func() { + var ( + client *Client + wg errgroup.Group + + sigs = make(chan os.Signal, 1) + done = make(chan struct{}, 1) + ) + + server := http.Server{Addr: ":" + utils.IntToString(ws.port), Handler: nil} + + wg.Go(func() error { for { select { + case <-done: + return nil + case client = <-ws.registerChan: ws.registerClient(client) case client = <-ws.unregisterChan: @@ -168,10 +183,34 @@ func (ws *WsServer) Run() error { ws.multiTerminalLoginChecker(onlineInfo.clientOK, onlineInfo.oldClients, onlineInfo.newClient) } } + }) + + wg.Go(func() error { + http.HandleFunc("/", ws.wsHandler) + return server.ListenAndServe() + }) + + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + <-sigs + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + // graceful exit operation for server + _ = server.Shutdown(ctx) + _ = wg.Wait() + close(done) }() - http.HandleFunc("/", ws.wsHandler) - // http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {}) - return http.ListenAndServe(":"+utils.IntToString(ws.port), nil) // Start listening + + select { + case <-done: + return nil + + case <-time.After(15 * time.Second): + return utils.Wrap1(errors.New("timeout exit")) + } + } var concurrentRequest = 3 diff --git a/pkg/common/cmd/msg_gateway.go b/pkg/common/cmd/msg_gateway.go index 7f0abb771..25fcc1177 100644 --- a/pkg/common/cmd/msg_gateway.go +++ b/pkg/common/cmd/msg_gateway.go @@ -17,12 +17,12 @@ package cmd import ( "log" - "github.com/openimsdk/open-im-server/v3/internal/msggateway" - v3config "github.com/openimsdk/open-im-server/v3/pkg/common/config" - "github.com/spf13/cobra" "github.com/OpenIMSDK/protocol/constant" + + "github.com/openimsdk/open-im-server/v3/internal/msggateway" + v3config "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) type MsgGatewayCmd struct { @@ -60,14 +60,19 @@ func (m *MsgGatewayCmd) Exec() error { m.addRunE() return m.Execute() } + func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int { - if portType == constant.FlagWsPort { + switch portType { + case constant.FlagWsPort: return v3config.Config.LongConnSvr.OpenImWsPort[0] - } else if portType == constant.FlagPort { + + case constant.FlagPort: return v3config.Config.LongConnSvr.OpenImMessageGatewayPort[0] - } else if portType == constant.FlagPrometheusPort { + + case constant.FlagPrometheusPort: return v3config.Config.Prometheus.MessageGatewayPrometheusPort[0] - } else { + + default: return 0 } } diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index d5e31701e..01076bbbb 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -15,14 +15,21 @@ package startrpc import ( + "errors" "fmt" "log" "net" "net/http" + "os" + "os/signal" "strconv" + "sync" + "syscall" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "golang.org/x/sync/errgroup" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" @@ -56,31 +63,37 @@ func Start( if err != nil { return err } + defer listener.Close() client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) if err != nil { return utils.Wrap1(err) } + defer client.Close() client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials())) registerIP, err := network.GetRpcRegisterIP(config.Config.Rpc.RegisterIP) if err != nil { return err } + var reg *prometheus.Registry var metric *grpcprometheus.ServerMetrics - // ctx 中间件 if config.Config.Prometheus.Enable { - ////////////////////////// cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName) - reg, metric, err = prommetrics.NewGrpcPromObj(cusMetrics) + reg, metric, _ = prommetrics.NewGrpcPromObj(cusMetrics) options = append(options, mw.GrpcServer(), grpc.StreamInterceptor(metric.StreamServerInterceptor()), grpc.UnaryInterceptor(metric.UnaryServerInterceptor())) } else { options = append(options, mw.GrpcServer()) } + srv := grpc.NewServer(options...) - defer srv.GracefulStop() + once := sync.Once{} + defer func() { + once.Do(srv.GracefulStop) + }() + err = rpcFn(client, srv) if err != nil { return utils.Wrap1(err) @@ -94,7 +107,10 @@ func Start( if err != nil { return utils.Wrap1(err) } - go func() { + + var wg errgroup.Group + + wg.Go(func() error { if config.Config.Prometheus.Enable && prometheusPort != 0 { metric.InitializeMetrics(srv) // Create a HTTP server for prometheus. @@ -103,7 +119,34 @@ func Start( log.Fatal("Unable to start a http server.") } } + return nil + }) + + wg.Go(func() error { + return utils.Wrap1(srv.Serve(listener)) + }) + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + <-sigs + + var ( + done = make(chan struct{}, 1) + gerr error + ) + + go func() { + once.Do(srv.GracefulStop) + gerr = wg.Wait() + close(done) }() - return utils.Wrap1(srv.Serve(listener)) + select { + case <-done: + return gerr + + case <-time.After(15 * time.Second): + return utils.Wrap1(errors.New("timeout exit")) + } + }