From 482227552efb5b2b30473a54cfb9d9a016212cad Mon Sep 17 00:00:00 2001 From: luhaoling <2198702716@qq.com> Date: Mon, 5 Feb 2024 10:02:43 +0800 Subject: [PATCH] fix: reconstruct exit gracefully --- cmd/openim-api/main.go | 26 ++++++--- internal/msggateway/init.go | 21 ++----- internal/msggateway/n_ws_server.go | 67 ++++++++++------------ pkg/common/startrpc/start.go | 89 +++++++++++++++++------------- 4 files changed, 103 insertions(+), 100 deletions(-) diff --git a/cmd/openim-api/main.go b/cmd/openim-api/main.go index a45bcbdd8..695782950 100644 --- a/cmd/openim-api/main.go +++ b/cmd/openim-api/main.go @@ -89,26 +89,34 @@ func run(port int, proPort int) error { } else { address = net.JoinHostPort("0.0.0.0", strconv.Itoa(port)) } - + var ( + netDone = make(chan struct{}, 1) + netErr error + ) server := http.Server{Addr: address, Handler: router} + go func() { err = server.ListenAndServe() if err != nil && err != http.ErrServerClosed { - os.Exit(1) + netErr = errs.Wrap(err, "api start err: ", server.Addr) + close(netDone) } }() sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - <-sigs + signal.Notify(sigs, syscall.SIGUSR1) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - - // graceful shutdown operation. - if err := server.Shutdown(ctx); err != nil { - return err + select { + case <-sigs: + print("receive process terminal SIGUSR1 exit") + err := server.Shutdown(ctx) + if err != nil { + return errs.Wrap(err, "shutdown err") + } + case <-netDone: + return netErr } - return nil } diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index aeba0a24a..321407f7e 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -18,9 +18,6 @@ import ( "fmt" "time" - "github.com/OpenIMSDK/tools/utils" - "golang.org/x/sync/errgroup" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" ) @@ -46,20 +43,12 @@ func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { } hubServer := NewServer(rpcPort, prometheusPort, longServer) - - wg := errgroup.Group{} - wg.Go(func() error { + netDone := make(chan error) + go func() { err = hubServer.Start() if err != nil { - return utils.Wrap1(err) + netDone <- err } - return err - }) - - wg.Go(func() error { - return hubServer.LongConnServer.Run() - }) - - err = wg.Wait() - return err + }() + return hubServer.LongConnServer.Run(netDone) } diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index 01d92b92a..738474338 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -20,12 +20,9 @@ import ( "errors" "fmt" "net/http" - "os" - "os/signal" "strconv" "sync" "sync/atomic" - "syscall" "time" "github.com/OpenIMSDK/tools/apiresp" @@ -49,7 +46,7 @@ import ( ) type LongConnServer interface { - Run() error + Run(done chan error) error wsHandler(w http.ResponseWriter, r *http.Request) GetUserAllCons(userID string) ([]*Client, bool) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) @@ -169,23 +166,20 @@ func NewWsServer(opts ...Option) (*WsServer, error) { }, nil } -func (ws *WsServer) Run() error { +func (ws *WsServer) Run(done chan error) error { var ( - client *Client - wg errgroup.Group - - sigs = make(chan os.Signal, 1) - done = make(chan struct{}, 1) + client *Client + netErr error + shutdownDone = make(chan struct{}, 1) ) server := http.Server{Addr: ":" + utils.IntToString(ws.port), Handler: nil} - wg.Go(func() error { + go func() { for { select { - case <-done: - return nil - + case <-shutdownDone: + return case client = <-ws.registerChan: ws.registerClient(client) case client = <-ws.unregisterChan: @@ -194,33 +188,32 @@ func (ws *WsServer) Run() error { ws.multiTerminalLoginChecker(onlineInfo.clientOK, onlineInfo.oldClients, onlineInfo.newClient) } } - }) - - wg.Go(func() error { - http.HandleFunc("/", ws.wsHandler) - return server.ListenAndServe() - }) - - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - <-sigs - + }() + netDone := make(chan struct{}, 1) go func() { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - // graceful exit operation for server - _ = server.Shutdown(ctx) - _ = wg.Wait() - close(done) + http.HandleFunc("/", ws.wsHandler) + err := server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + netErr = errs.Wrap(err, "ws start err: ", server.Addr) + close(netDone) + } }() - + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + var err error select { - case <-done: - return nil - - case <-time.After(15 * time.Second): - return utils.Wrap1(errors.New("timeout exit")) + case err = <-done: + sErr := server.Shutdown(ctx) + if sErr != nil { + return errs.Wrap(sErr, "shutdown err") + } + close(shutdownDone) + if err != nil { + return err + } + case <-netDone: } + return netErr } diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index b586efd89..0c0dedae7 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -15,9 +15,10 @@ package startrpc import ( + "context" "errors" "fmt" - "log" + "github.com/OpenIMSDK/tools/errs" "net" "net/http" "os" @@ -27,14 +28,10 @@ import ( "syscall" "time" - "github.com/OpenIMSDK/tools/errs" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "golang.org/x/sync/errgroup" - "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "google.golang.org/grpc" @@ -57,12 +54,13 @@ func Start( ) error { fmt.Printf("start %s server, port: %d, prometheusPort: %d, OpenIM version: %s\n", rpcRegisterName, rpcPort, prometheusPort, config.Version) + rpcTcpAddr := net.JoinHostPort(network.GetListenIP(config.Config.Rpc.ListenIP), strconv.Itoa(rpcPort)) listener, err := net.Listen( "tcp", - net.JoinHostPort(network.GetListenIP(config.Config.Rpc.ListenIP), strconv.Itoa(rpcPort)), + rpcTcpAddr, ) if err != nil { - return errs.Wrap(err, network.GetListenIP(config.Config.Rpc.ListenIP), strconv.Itoa(rpcPort)) + return errs.Wrap(err, rpcTcpAddr) } defer listener.Close() @@ -109,48 +107,63 @@ func Start( return errs.Wrap(err) } - var wg errgroup.Group - - wg.Go(func() error { + var ( + netDone = make(chan struct{}, 1) + netErr error + httpServer *http.Server + ) + go func() { if config.Config.Prometheus.Enable && prometheusPort != 0 { metric.InitializeMetrics(srv) // Create a HTTP server for prometheus. - httpServer := &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)} - if err := httpServer.ListenAndServe(); err != nil { - log.Fatal("Unable to start a http server. ", err.Error(), "PrometheusPort:", prometheusPort) + httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)} + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + netErr = errs.Wrap(err, "prometheus start err: ", httpServer.Addr) + close(netDone) } } - return nil - }) + }() - wg.Go(func() error { - return errs.Wrap(srv.Serve(listener)) - }) + go func() { + err := srv.Serve(listener) + if err != nil { + netErr = errs.Wrap(err, "rpc start err: ", rpcTcpAddr) + close(netDone) + } + }() sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - - log.Println("23333333333:", <-sigs) - - <-sigs - - var ( - done = make(chan struct{}, 1) - gerr error - ) + signal.Notify(sigs, syscall.SIGUSR1) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + select { + case <-sigs: + print("receive process terminal SIGUSR1 exit") + if err := gracefulStopWithCtx(ctx, srv.GracefulStop); err != nil { + return err + } + ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + err := httpServer.Shutdown(ctx) + if err != nil { + return errs.Wrap(err, "shutdown err") + } + case <-netDone: + return netErr + } + return nil +} +func gracefulStopWithCtx(ctx context.Context, f func()) error { + done := make(chan struct{}, 1) go func() { - once.Do(srv.GracefulStop) - gerr = wg.Wait() + f() close(done) }() - select { + case <-ctx.Done(): + return errs.Wrap(errors.New("timeout"), "ctx graceful stop") case <-done: - return gerr - - case <-time.After(15 * time.Second): - return errs.Wrap(errors.New("timeout exit")) + return nil } - }