refactor: client in msggateway (#1343)

* refactor: clietn in msggateway

Signed-off-by: rfyiamcool <rfyiamcool@163.com>

* perf: add sync.pool for req object

Signed-off-by: rfyiamcool <rfyiamcool@163.com>

---------

Signed-off-by: rfyiamcool <rfyiamcool@163.com>
pull/1353/head
fengyun.rui 1 year ago committed by GitHub
parent 1aef30dac4
commit 815fa15392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"runtime/debug" "runtime/debug"
"sync" "sync"
"sync/atomic"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
@ -70,7 +71,7 @@ type Client struct {
IsBackground bool `json:"isBackground"` IsBackground bool `json:"isBackground"`
ctx *UserConnContext ctx *UserConnContext
longConnServer LongConnServer longConnServer LongConnServer
closed bool closed atomic.Bool
closedErr error closedErr error
token string token string
} }
@ -102,18 +103,14 @@ func (c *Client) ResetClient(
c.ctx = ctx c.ctx = ctx
c.longConnServer = longConnServer c.longConnServer = longConnServer
c.IsBackground = false c.IsBackground = false
c.closed = false c.closed.Store(false)
c.closedErr = nil c.closedErr = nil
c.token = token c.token = token
} }
func (c *Client) pingHandler(_ string) error { func (c *Client) pingHandler(_ string) error {
c.conn.SetReadDeadline(pongWait) _ = c.conn.SetReadDeadline(pongWait)
err := c.writePongMsg() return c.writePongMsg()
if err != nil {
return err
}
return nil
} }
func (c *Client) readMessage() { func (c *Client) readMessage() {
@ -124,9 +121,11 @@ func (c *Client) readMessage() {
} }
c.close() c.close()
}() }()
c.conn.SetReadLimit(maxMessageSize) c.conn.SetReadLimit(maxMessageSize)
_ = c.conn.SetReadDeadline(pongWait) _ = c.conn.SetReadDeadline(pongWait)
c.conn.SetPingHandler(c.pingHandler) c.conn.SetPingHandler(c.pingHandler)
for { for {
messageType, message, returnErr := c.conn.ReadMessage() messageType, message, returnErr := c.conn.ReadMessage()
if returnErr != nil { if returnErr != nil {
@ -134,11 +133,13 @@ func (c *Client) readMessage() {
c.closedErr = returnErr c.closedErr = returnErr
return return
} }
log.ZDebug(c.ctx, "readMessage", "messageType", messageType) log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
if c.closed { // 连接刚置位已经关闭,但是协程还没退出的场景 if c.closed.Load() { // 连接刚置位已经关闭,但是协程还没退出的场景
c.closedErr = ErrConnClosed c.closedErr = ErrConnClosed
return return
} }
switch messageType { switch messageType {
case MessageBinary: case MessageBinary:
_ = c.conn.SetReadDeadline(pongWait) _ = c.conn.SetReadDeadline(pongWait)
@ -150,9 +151,11 @@ func (c *Client) readMessage() {
case MessageText: case MessageText:
c.closedErr = ErrNotSupportMessageProtocol c.closedErr = ErrNotSupportMessageProtocol
return return
case PingMessage: case PingMessage:
err := c.writePongMsg() err := c.writePongMsg()
log.ZError(c.ctx, "writePongMsg", err) log.ZError(c.ctx, "writePongMsg", err)
case CloseMessage: case CloseMessage:
c.closedErr = ErrClientClosed c.closedErr = ErrClientClosed
return return
@ -163,29 +166,40 @@ func (c *Client) readMessage() {
func (c *Client) handleMessage(message []byte) error { func (c *Client) handleMessage(message []byte) error {
if c.IsCompress { if c.IsCompress {
var decompressErr error var err error
message, decompressErr = c.longConnServer.DeCompress(message) message, err = c.longConnServer.DeCompress(message)
if decompressErr != nil { if err != nil {
return utils.Wrap(decompressErr, "") return utils.Wrap(err, "")
} }
} }
var binaryReq Req
err := c.longConnServer.Decode(message, &binaryReq) var binaryReq = getReq()
defer freeReq(binaryReq)
err := c.longConnServer.Decode(message, binaryReq)
if err != nil { if err != nil {
return utils.Wrap(err, "") return utils.Wrap(err, "")
} }
if err := c.longConnServer.Validate(binaryReq); err != nil { if err := c.longConnServer.Validate(binaryReq); err != nil {
return utils.Wrap(err, "") return utils.Wrap(err, "")
} }
if binaryReq.SendID != c.UserID { if binaryReq.SendID != c.UserID {
return utils.Wrap(errors.New("exception conn userID not same to req userID"), binaryReq.String()) return utils.Wrap(errors.New("exception conn userID not same to req userID"), binaryReq.String())
} }
ctx := mcontext.WithMustInfoCtx( ctx := mcontext.WithMustInfoCtx(
[]string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.PlatformID), c.ctx.GetConnID()}, []string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.PlatformID), c.ctx.GetConnID()},
) )
log.ZDebug(ctx, "gateway req message", "req", binaryReq.String()) log.ZDebug(ctx, "gateway req message", "req", binaryReq.String())
var messageErr error
var resp []byte var (
resp []byte
messageErr error
)
switch binaryReq.ReqIdentifier { switch binaryReq.ReqIdentifier {
case WSGetNewestSeq: case WSGetNewestSeq:
resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq) resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq)
@ -208,23 +222,29 @@ func (c *Client) handleMessage(message []byte) error {
) )
} }
return c.replyMessage(ctx, &binaryReq, messageErr, resp) return c.replyMessage(ctx, binaryReq, messageErr, resp)
} }
func (c *Client) setAppBackgroundStatus(ctx context.Context, req Req) ([]byte, error) { func (c *Client) setAppBackgroundStatus(ctx context.Context, req *Req) ([]byte, error) {
resp, isBackground, messageErr := c.longConnServer.SetUserDeviceBackground(ctx, req) resp, isBackground, messageErr := c.longConnServer.SetUserDeviceBackground(ctx, req)
if messageErr != nil { if messageErr != nil {
return nil, messageErr return nil, messageErr
} }
c.IsBackground = isBackground c.IsBackground = isBackground
// todo callback // todo callback
return resp, nil return resp, nil
} }
func (c *Client) close() { func (c *Client) close() {
if c.closed.Load() {
return
}
c.w.Lock() c.w.Lock()
defer c.w.Unlock() defer c.w.Unlock()
c.closed = true
c.closed.Store(true)
c.conn.Close() c.conn.Close()
c.longConnServer.UnRegister(c) c.longConnServer.UnRegister(c)
} }
@ -244,6 +264,7 @@ func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, re
if err != nil { if err != nil {
log.ZWarn(ctx, "wireBinaryMsg replyMessage", err, "resp", mReply.String()) log.ZWarn(ctx, "wireBinaryMsg replyMessage", err, "resp", mReply.String())
} }
if binaryReq.ReqIdentifier == WsLogoutMsg { if binaryReq.ReqIdentifier == WsLogoutMsg {
return errors.New("user logout") return errors.New("user logout")
} }
@ -280,39 +301,42 @@ func (c *Client) KickOnlineMessage() error {
} }
func (c *Client) writeBinaryMsg(resp Resp) error { func (c *Client) writeBinaryMsg(resp Resp) error {
c.w.Lock() if c.closed.Load() {
defer c.w.Unlock()
if c.closed {
return nil return nil
} }
resultBuf := bufferPool.Get().([]byte)
encodedBuf, err := c.longConnServer.Encode(resp) encodedBuf, err := c.longConnServer.Encode(resp)
if err != nil { if err != nil {
return utils.Wrap(err, "") return utils.Wrap(err, "")
} }
c.w.Lock()
defer c.w.Unlock()
_ = c.conn.SetWriteDeadline(writeWait) _ = c.conn.SetWriteDeadline(writeWait)
if c.IsCompress { if c.IsCompress {
var compressErr error resultBuf, compressErr := c.longConnServer.Compress(encodedBuf)
resultBuf, compressErr = c.longConnServer.Compress(encodedBuf)
if compressErr != nil { if compressErr != nil {
return utils.Wrap(compressErr, "") return utils.Wrap(compressErr, "")
} }
return c.conn.WriteMessage(MessageBinary, resultBuf) return c.conn.WriteMessage(MessageBinary, resultBuf)
} else {
return c.conn.WriteMessage(MessageBinary, encodedBuf)
} }
return c.conn.WriteMessage(MessageBinary, encodedBuf)
} }
func (c *Client) writePongMsg() error { func (c *Client) writePongMsg() error {
c.w.Lock() if c.closed.Load() {
defer c.w.Unlock()
if c.closed {
return nil return nil
} }
c.w.Lock()
defer c.w.Unlock()
err := c.conn.SetWriteDeadline(writeWait) err := c.conn.SetWriteDeadline(writeWait)
if err != nil { if err != nil {
return utils.Wrap(err, "") return utils.Wrap(err, "")
} }
return c.conn.WriteMessage(PongMessage, nil) return c.conn.WriteMessage(PongMessage, nil)
} }

@ -16,6 +16,7 @@ package msggateway
import ( import (
"context" "context"
"sync"
"github.com/OpenIMSDK/protocol/push" "github.com/OpenIMSDK/protocol/push"
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
@ -49,6 +50,27 @@ func (r *Req) String() string {
return utils.StructToJsonString(tReq) return utils.StructToJsonString(tReq)
} }
var reqPool = sync.Pool{
New: func() any {
return new(Req)
},
}
func getReq() *Req {
req := reqPool.Get().(*Req)
req.Data = nil
req.MsgIncr = ""
req.OperationID = ""
req.ReqIdentifier = 0
req.SendID = ""
req.Token = ""
return req
}
func freeReq(req *Req) {
reqPool.Put(req)
}
type Resp struct { type Resp struct {
ReqIdentifier int32 `json:"reqIdentifier"` ReqIdentifier int32 `json:"reqIdentifier"`
MsgIncr string `json:"msgIncr"` MsgIncr string `json:"msgIncr"`
@ -69,12 +91,12 @@ func (r *Resp) String() string {
} }
type MessageHandler interface { type MessageHandler interface {
GetSeq(context context.Context, data Req) ([]byte, error) GetSeq(context context.Context, data *Req) ([]byte, error)
SendMessage(context context.Context, data Req) ([]byte, error) SendMessage(context context.Context, data *Req) ([]byte, error)
SendSignalMessage(context context.Context, data Req) ([]byte, error) SendSignalMessage(context context.Context, data *Req) ([]byte, error)
PullMessageBySeqList(context context.Context, data Req) ([]byte, error) PullMessageBySeqList(context context.Context, data *Req) ([]byte, error)
UserLogout(context context.Context, data Req) ([]byte, error) UserLogout(context context.Context, data *Req) ([]byte, error)
SetUserDeviceBackground(context context.Context, data Req) ([]byte, bool, error) SetUserDeviceBackground(context context.Context, data *Req) ([]byte, bool, error)
} }
var _ MessageHandler = (*GrpcHandler)(nil) var _ MessageHandler = (*GrpcHandler)(nil)
@ -94,7 +116,7 @@ func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDi
} }
} }
func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) { func (g GrpcHandler) GetSeq(context context.Context, data *Req) ([]byte, error) {
req := sdkws.GetMaxSeqReq{} req := sdkws.GetMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil { if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, err return nil, err
@ -113,7 +135,7 @@ func (g GrpcHandler) GetSeq(context context.Context, data Req) ([]byte, error) {
return c, nil return c, nil
} }
func (g GrpcHandler) SendMessage(context context.Context, data Req) ([]byte, error) { func (g GrpcHandler) SendMessage(context context.Context, data *Req) ([]byte, error) {
msgData := sdkws.MsgData{} msgData := sdkws.MsgData{}
if err := proto.Unmarshal(data.Data, &msgData); err != nil { if err := proto.Unmarshal(data.Data, &msgData); err != nil {
return nil, err return nil, err
@ -133,7 +155,7 @@ func (g GrpcHandler) SendMessage(context context.Context, data Req) ([]byte, err
return c, nil return c, nil
} }
func (g GrpcHandler) SendSignalMessage(context context.Context, data Req) ([]byte, error) { func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]byte, error) {
resp, err := g.msgRpcClient.SendMsg(context, nil) resp, err := g.msgRpcClient.SendMsg(context, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -145,7 +167,7 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data Req) ([]byt
return c, nil return c, nil
} }
func (g GrpcHandler) PullMessageBySeqList(context context.Context, data Req) ([]byte, error) { func (g GrpcHandler) PullMessageBySeqList(context context.Context, data *Req) ([]byte, error) {
req := sdkws.PullMessageBySeqsReq{} req := sdkws.PullMessageBySeqsReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil { if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, err return nil, err
@ -164,7 +186,7 @@ func (g GrpcHandler) PullMessageBySeqList(context context.Context, data Req) ([]
return c, nil return c, nil
} }
func (g GrpcHandler) UserLogout(context context.Context, data Req) ([]byte, error) { func (g GrpcHandler) UserLogout(context context.Context, data *Req) ([]byte, error) {
req := push.DelUserPushTokenReq{} req := push.DelUserPushTokenReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil { if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, err return nil, err
@ -180,7 +202,7 @@ func (g GrpcHandler) UserLogout(context context.Context, data Req) ([]byte, erro
return c, nil return c, nil
} }
func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data Req) ([]byte, bool, error) { func (g GrpcHandler) SetUserDeviceBackground(_ context.Context, data *Req) ([]byte, bool, error) {
req := sdkws.SetAppBackgroundStatusReq{} req := sdkws.SetAppBackgroundStatusReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil { if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, false, err return nil, false, err

Loading…
Cancel
Save