From 77c64ec20f170e1dfd2ea2e2bd90bb0488d556d3 Mon Sep 17 00:00:00 2001 From: rfyiamcool Date: Mon, 6 Nov 2023 11:34:05 +0800 Subject: [PATCH] refactor: clietn in msggateway Signed-off-by: rfyiamcool --- internal/msggateway/client.go | 75 ++++++++++++++++++++++------------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 9c165e4dd..caba9c729 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -20,6 +20,7 @@ import ( "fmt" "runtime/debug" "sync" + "sync/atomic" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" @@ -70,7 +71,7 @@ type Client struct { IsBackground bool `json:"isBackground"` ctx *UserConnContext longConnServer LongConnServer - closed bool + closed atomic.Bool closedErr error token string } @@ -102,18 +103,14 @@ func (c *Client) ResetClient( c.ctx = ctx c.longConnServer = longConnServer c.IsBackground = false - c.closed = false + c.closed.Store(false) c.closedErr = nil c.token = token } func (c *Client) pingHandler(_ string) error { - c.conn.SetReadDeadline(pongWait) - err := c.writePongMsg() - if err != nil { - return err - } - return nil + _ = c.conn.SetReadDeadline(pongWait) + return c.writePongMsg() } func (c *Client) readMessage() { @@ -124,9 +121,11 @@ func (c *Client) readMessage() { } c.close() }() + c.conn.SetReadLimit(maxMessageSize) _ = c.conn.SetReadDeadline(pongWait) c.conn.SetPingHandler(c.pingHandler) + for { messageType, message, returnErr := c.conn.ReadMessage() if returnErr != nil { @@ -134,11 +133,13 @@ func (c *Client) readMessage() { c.closedErr = returnErr return } + log.ZDebug(c.ctx, "readMessage", "messageType", messageType) - if c.closed { // 连接刚置位已经关闭,但是协程还没退出的场景 + if c.closed.Load() { // 连接刚置位已经关闭,但是协程还没退出的场景 c.closedErr = ErrConnClosed return } + switch messageType { case MessageBinary: _ = c.conn.SetReadDeadline(pongWait) @@ -150,9 +151,11 @@ func (c *Client) readMessage() { case MessageText: c.closedErr = ErrNotSupportMessageProtocol return + case PingMessage: err := c.writePongMsg() log.ZError(c.ctx, "writePongMsg", err) + case CloseMessage: c.closedErr = ErrClientClosed return @@ -163,10 +166,10 @@ func (c *Client) readMessage() { func (c *Client) handleMessage(message []byte) error { if c.IsCompress { - var decompressErr error - message, decompressErr = c.longConnServer.DeCompress(message) - if decompressErr != nil { - return utils.Wrap(decompressErr, "") + var err error + message, err = c.longConnServer.DeCompress(message) + if err != nil { + return utils.Wrap(err, "") } } var binaryReq Req @@ -174,18 +177,26 @@ func (c *Client) handleMessage(message []byte) error { if err != nil { return utils.Wrap(err, "") } + if err := c.longConnServer.Validate(binaryReq); err != nil { return utils.Wrap(err, "") } + if binaryReq.SendID != c.UserID { return utils.Wrap(errors.New("exception conn userID not same to req userID"), binaryReq.String()) } + ctx := mcontext.WithMustInfoCtx( []string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.PlatformID), c.ctx.GetConnID()}, ) + log.ZDebug(ctx, "gateway req message", "req", binaryReq.String()) - var messageErr error - var resp []byte + + var ( + resp []byte + messageErr error + ) + switch binaryReq.ReqIdentifier { case WSGetNewestSeq: resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq) @@ -216,15 +227,21 @@ func (c *Client) setAppBackgroundStatus(ctx context.Context, req Req) ([]byte, e if messageErr != nil { return nil, messageErr } + c.IsBackground = isBackground // todo callback return resp, nil } func (c *Client) close() { + if c.closed.Load() { + return + } + c.w.Lock() defer c.w.Unlock() - c.closed = true + + c.closed.Store(true) c.conn.Close() c.longConnServer.UnRegister(c) } @@ -244,6 +261,7 @@ func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, re if err != nil { log.ZWarn(ctx, "wireBinaryMsg replyMessage", err, "resp", mReply.String()) } + if binaryReq.ReqIdentifier == WsLogoutMsg { return errors.New("user logout") } @@ -280,39 +298,42 @@ func (c *Client) KickOnlineMessage() error { } func (c *Client) writeBinaryMsg(resp Resp) error { - c.w.Lock() - defer c.w.Unlock() - if c.closed { + if c.closed.Load() { return nil } - resultBuf := bufferPool.Get().([]byte) encodedBuf, err := c.longConnServer.Encode(resp) if err != nil { return utils.Wrap(err, "") } + + c.w.Lock() + defer c.w.Unlock() + _ = c.conn.SetWriteDeadline(writeWait) if c.IsCompress { - var compressErr error - resultBuf, compressErr = c.longConnServer.Compress(encodedBuf) + resultBuf, compressErr := c.longConnServer.Compress(encodedBuf) if compressErr != nil { return utils.Wrap(compressErr, "") } return c.conn.WriteMessage(MessageBinary, resultBuf) - } else { - return c.conn.WriteMessage(MessageBinary, encodedBuf) } + + return c.conn.WriteMessage(MessageBinary, encodedBuf) } func (c *Client) writePongMsg() error { - c.w.Lock() - defer c.w.Unlock() - if c.closed { + if c.closed.Load() { return nil } + + c.w.Lock() + defer c.w.Unlock() + err := c.conn.SetWriteDeadline(writeWait) if err != nil { return utils.Wrap(err, "") } + return c.conn.WriteMessage(PongMessage, nil) }