diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index e7e11f79c..96c5dc896 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -74,6 +74,7 @@ type Client struct { closedErr error token string hbCtx context.Context + hbCancel context.CancelFunc } // ResetClient updates the client's state with new connection and context information. @@ -90,7 +91,7 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer c.closed.Store(false) c.closedErr = nil c.token = ctx.GetToken() - c.hbCtx, _ = context.WithCancel(c.ctx) + c.hbCtx, c.hbCancel = context.WithCancel(c.ctx) } func (c *Client) pingHandler(_ string) error { @@ -101,6 +102,13 @@ func (c *Client) pingHandler(_ string) error { return c.writePongMsg() } +func (c *Client) pongHandler(_ string) error { + if err := c.conn.SetReadDeadline(pongWait); err != nil { + return err + } + return nil +} + // readMessage continuously reads messages from the connection. func (c *Client) readMessage() { defer func() { @@ -113,8 +121,9 @@ func (c *Client) readMessage() { c.conn.SetReadLimit(maxMessageSize) _ = c.conn.SetReadDeadline(pongWait) + c.conn.SetPongHandler(c.pongHandler) c.conn.SetPingHandler(c.pingHandler) - go c.activeHeartbeat(c.hbCtx) + c.activeHeartbeat(c.hbCtx) for { log.ZDebug(c.ctx, "readMessage") @@ -240,7 +249,7 @@ func (c *Client) close() { c.closed.Store(true) c.conn.Close() - <-c.hbCtx.Done() // Close server-initiated heartbeat. + c.hbCancel() // Close server-initiated heartbeat. c.longConnServer.UnRegister(c) } @@ -330,21 +339,23 @@ func (c *Client) writeBinaryMsg(resp Resp) error { // Actively initiate Heartbeat when platform in Web. func (c *Client) activeHeartbeat(ctx context.Context) { if c.PlatformID == constant.WebPlatformID { - log.ZDebug(ctx, "server initiative send heartbeat start.") - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := c.conn.WriteMessage(PingMessage, nil); err != nil { - log.ZError(c.ctx, "send Ping Message error.", err) + go func() { + log.ZDebug(ctx, "server initiative send heartbeat start.") + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := c.conn.WriteMessage(PingMessage, nil); err != nil { + log.ZError(c.ctx, "send Ping Message error.", err) + return + } + case <-c.hbCtx.Done(): return } - case <-c.hbCtx.Done(): - return } - } + }() } }