diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 7b9f4bc0e..74b874e95 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -16,7 +16,6 @@ package msggateway import ( "context" - "encoding/json" "fmt" "sync" "sync/atomic" @@ -31,7 +30,6 @@ import ( "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" "github.com/openimsdk/tools/mcontext" - "github.com/openimsdk/tools/utils/stringutil" ) var ( @@ -64,7 +62,7 @@ type PingPongHandler func(string) error type Client struct { w *sync.Mutex - conn LongConn + conn ClientConn PlatformID int `json:"platformID"` IsCompress bool `json:"isCompress"` UserID string `json:"userID"` @@ -84,10 +82,10 @@ type Client struct { } // ResetClient updates the client's state with new connection and context information. -func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) { +func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) { c.w = new(sync.Mutex) c.conn = conn - c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID()) + c.PlatformID = ctx.GetPlatformID() c.IsCompress = ctx.GetCompression() c.IsBackground = ctx.GetBackground() c.UserID = ctx.GetUserID() @@ -112,22 +110,6 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer c.subUserIDs = make(map[string]struct{}) } -func (c *Client) pingHandler(appData string) error { - if err := c.conn.SetReadDeadline(pongWait); err != nil { - return err - } - - log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData) - return c.writePongMsg(appData) -} - -func (c *Client) pongHandler(_ string) error { - if err := c.conn.SetReadDeadline(pongWait); err != nil { - return err - } - return nil -} - // readMessage continuously reads messages from the connection. func (c *Client) readMessage() { defer func() { @@ -138,52 +120,25 @@ func (c *Client) readMessage() { c.close() }() - c.conn.SetReadLimit(maxMessageSize) - _ = c.conn.SetReadDeadline(pongWait) - c.conn.SetPongHandler(c.pongHandler) - c.conn.SetPingHandler(c.pingHandler) - c.activeHeartbeat(c.hbCtx) - for { log.ZDebug(c.ctx, "readMessage") - messageType, message, returnErr := c.conn.ReadMessage() + message, returnErr := c.conn.ReadMessage() if returnErr != nil { - log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType) + log.ZWarn(c.ctx, "readMessage", returnErr) c.closedErr = returnErr return } - log.ZDebug(c.ctx, "readMessage", "messageType", messageType) if c.closed.Load() { // The scenario where the connection has just been closed, but the coroutine has not exited c.closedErr = ErrConnClosed return } - switch messageType { - case MessageBinary: - _ = c.conn.SetReadDeadline(pongWait) - parseDataErr := c.handleMessage(message) - if parseDataErr != nil { - c.closedErr = parseDataErr - return - } - case MessageText: - _ = c.conn.SetReadDeadline(pongWait) - parseDataErr := c.handlerTextMessage(message) - if parseDataErr != nil { - c.closedErr = parseDataErr - return - } - case PingMessage: - err := c.writePongMsg("") - log.ZError(c.ctx, "writePongMsg", err) - - case CloseMessage: - c.closedErr = ErrClientClosed + parseDataErr := c.handleMessage(message) + if parseDataErr != nil { + c.closedErr = parseDataErr return - - default: } } } @@ -358,109 +313,13 @@ func (c *Client) writeBinaryMsg(resp Resp) error { c.w.Lock() defer c.w.Unlock() - err = c.conn.SetWriteDeadline(writeWait) - if err != nil { - return err - } - if c.IsCompress { resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf) if compressErr != nil { return compressErr } - return c.conn.WriteMessage(MessageBinary, resultBuf) - } - - return c.conn.WriteMessage(MessageBinary, encodedBuf) -} - -// Actively initiate Heartbeat when platform in Web. -func (c *Client) activeHeartbeat(ctx context.Context) { - if c.PlatformID == constant.WebPlatformID { - go func() { - defer func() { - if r := recover(); r != nil { - log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r)) - } - }() - log.ZDebug(ctx, "server initiative send heartbeat start.") - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := c.writePingMsg(); err != nil { - log.ZWarn(c.ctx, "send Ping Message error.", err) - return - } - case <-c.hbCtx.Done(): - return - } - } - }() - } -} -func (c *Client) writePingMsg() error { - if c.closed.Load() { - return nil + return c.conn.WriteMessage(resultBuf) } - c.w.Lock() - defer c.w.Unlock() - - err := c.conn.SetWriteDeadline(writeWait) - if err != nil { - return err - } - - return c.conn.WriteMessage(PingMessage, nil) -} - -func (c *Client) writePongMsg(appData string) error { - log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData) - if c.closed.Load() { - log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr) - return nil - } - - c.w.Lock() - defer c.w.Unlock() - - err := c.conn.SetWriteDeadline(writeWait) - if err != nil { - log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData) - return errs.Wrap(err) - } - err = c.conn.WriteMessage(PongMessage, []byte(appData)) - if err != nil { - log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage) - } - - return errs.Wrap(err) -} - -func (c *Client) handlerTextMessage(b []byte) error { - var msg TextMessage - if err := json.Unmarshal(b, &msg); err != nil { - return err - } - switch msg.Type { - case TextPong: - return nil - case TextPing: - msg.Type = TextPong - msgData, err := json.Marshal(msg) - if err != nil { - return err - } - c.w.Lock() - defer c.w.Unlock() - if err := c.conn.SetWriteDeadline(writeWait); err != nil { - return err - } - return c.conn.WriteMessage(MessageText, msgData) - default: - return fmt.Errorf("not support message type %s", msg.Type) - } + return c.conn.WriteMessage(encodedBuf) } diff --git a/internal/msggateway/client_conn.go b/internal/msggateway/client_conn.go new file mode 100644 index 000000000..15a0d8c07 --- /dev/null +++ b/internal/msggateway/client_conn.go @@ -0,0 +1,229 @@ +package msggateway + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/openimsdk/tools/log" +) + +var ErrWriteFull = fmt.Errorf("websocket write buffer full,close connection") + +type ClientConn interface { + ReadMessage() ([]byte, error) + WriteMessage(message []byte) error + Close() error +} + +type websocketMessage struct { + MessageType int + Data []byte +} + +func NewWebSocketClientConn(conn *websocket.Conn, readLimit int64, readTimeout time.Duration, pingInterval time.Duration) ClientConn { + c := &websocketClientConn{ + readTimeout: readTimeout, + conn: conn, + writer: make(chan *websocketMessage, 256), + done: make(chan struct{}), + } + if readLimit > 0 { + c.conn.SetReadLimit(readLimit) + } + c.conn.SetPingHandler(c.pingHandler) + c.conn.SetPongHandler(c.pongHandler) + + go c.loopSend() + if pingInterval > 0 { + go c.doPing(pingInterval) + } + return c +} + +type websocketClientConn struct { + readTimeout time.Duration + conn *websocket.Conn + writer chan *websocketMessage + done chan struct{} + err atomic.Pointer[error] +} + +func (c *websocketClientConn) ReadMessage() ([]byte, error) { + buf, err := c.readMessage() + if err != nil { + return nil, c.closeBy(fmt.Errorf("read message %w", err)) + } + return buf, nil +} + +func (c *websocketClientConn) WriteMessage(message []byte) error { + return c.writeMessage(websocket.BinaryMessage, message) +} + +func (c *websocketClientConn) Close() error { + return c.closeBy(fmt.Errorf("websocket connection closed")) +} + +func (c *websocketClientConn) closeBy(err error) error { + if !c.err.CompareAndSwap(nil, &err) { + return *c.err.Load() + } + close(c.done) + log.ZWarn(context.Background(), "websocket connection closed", err, "remoteAddr", c.conn.RemoteAddr(), + "chan length", len(c.writer)) + return err +} + +func (c *websocketClientConn) writeMessage(messageType int, data []byte) error { + if errPtr := c.err.Load(); errPtr != nil { + return *errPtr + } + select { + case c.writer <- &websocketMessage{MessageType: messageType, Data: data}: + return nil + default: + return c.closeBy(ErrWriteFull) + } +} + +func (c *websocketClientConn) loopSend() { + defer func() { + _ = c.conn.Close() + }() + var err error + for { + select { + case <-c.done: + for { + select { + case msg := <-c.writer: + switch msg.MessageType { + case websocket.TextMessage, websocket.BinaryMessage: + err = c.conn.WriteMessage(msg.MessageType, msg.Data) + default: + err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{}) + } + if err != nil { + _ = c.closeBy(err) + return + } + default: + return + } + } + case msg := <-c.writer: + switch msg.MessageType { + case websocket.TextMessage, websocket.BinaryMessage: + err = c.conn.WriteMessage(msg.MessageType, msg.Data) + default: + err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{}) + } + if err != nil { + _ = c.closeBy(err) + return + } + } + } +} + +func (c *websocketClientConn) setReadDeadline() error { + deadline := time.Now().Add(c.readTimeout) + return c.conn.SetReadDeadline(deadline) +} + +func (c *websocketClientConn) readMessage() ([]byte, error) { + for { + if err := c.setReadDeadline(); err != nil { + return nil, err + } + messageType, buf, err := c.conn.ReadMessage() + if err != nil { + return nil, err + } + switch messageType { + case websocket.BinaryMessage: + return buf, nil + case websocket.TextMessage: + if err := c.onReadTextMessage(buf); err != nil { + return nil, err + } + case websocket.PingMessage: + if err := c.pingHandler(string(buf)); err != nil { + return nil, err + } + case websocket.PongMessage: + if err := c.pongHandler(string(buf)); err != nil { + return nil, err + } + case websocket.CloseMessage: + if len(buf) == 0 { + return nil, errors.New("websocket connection closed by peer") + } + return nil, fmt.Errorf("websocket connection closed by peer, data %s", string(buf)) + default: + return nil, fmt.Errorf("unknown websocket message type %d", messageType) + } + } +} + +func (c *websocketClientConn) onReadTextMessage(buf []byte) error { + var msg struct { + Type string `json:"type"` + Body json.RawMessage `json:"body"` + } + if err := json.Unmarshal(buf, &msg); err != nil { + return err + } + switch msg.Type { + case TextPong: + return nil + case TextPing: + msg.Type = TextPong + msgData, err := json.Marshal(msg) + if err != nil { + return err + } + return c.writeMessage(websocket.TextMessage, msgData) + default: + return fmt.Errorf("not support text message type %s", msg.Type) + } +} + +func (c *websocketClientConn) pingHandler(appData string) error { + log.ZDebug(context.Background(), "ping handler recv ping", "remoteAddr", c.conn.RemoteAddr(), "appData", appData) + if err := c.setReadDeadline(); err != nil { + return err + } + err := c.conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second*1)) + if err != nil { + log.ZWarn(context.Background(), "ping handler write pong error", err, "remoteAddr", c.conn.RemoteAddr(), "appData", appData) + } + log.ZDebug(context.Background(), "ping handler write pong success", "remoteAddr", c.conn.RemoteAddr(), "appData", appData) + return nil +} + +func (c *websocketClientConn) pongHandler(string) error { + return nil +} + +func (c *websocketClientConn) doPing(d time.Duration) { + ticker := time.NewTicker(d) + defer ticker.Stop() + for { + select { + case <-c.done: + return + case <-ticker.C: + if err := c.writeMessage(websocket.PingMessage, nil); err != nil { + _ = c.closeBy(fmt.Errorf("send ping %w", err)) + return + } + } + } +} diff --git a/internal/msggateway/context.go b/internal/msggateway/context.go index 37b5a7cdc..9fa28e667 100644 --- a/internal/msggateway/context.go +++ b/internal/msggateway/context.go @@ -15,6 +15,8 @@ package msggateway import ( + "encoding/base64" + "encoding/json" "net/http" "net/url" "strconv" @@ -24,10 +26,21 @@ import ( "github.com/openimsdk/protocol/constant" "github.com/openimsdk/tools/utils/encrypt" - "github.com/openimsdk/tools/utils/stringutil" "github.com/openimsdk/tools/utils/timeutil" ) +type UserConnContextInfo struct { + Token string `json:"token"` + UserID string `json:"userID"` + PlatformID int `json:"platformID"` + OperationID string `json:"operationID"` + Compression string `json:"compression"` + SDKType string `json:"sdkType"` + SendResponse bool `json:"sendResponse"` + Background bool `json:"background"` + SDKVersion string `json:"sdkVersion"` +} + type UserConnContext struct { RespWriter http.ResponseWriter Req *http.Request @@ -35,6 +48,7 @@ type UserConnContext struct { Method string RemoteAddr string ConnID string + info *UserConnContextInfo } func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) { @@ -58,9 +72,11 @@ func (c *UserConnContext) Value(key any) any { case constant.ConnID: return c.GetConnID() case constant.OpUserPlatform: - return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID())) + return c.GetPlatformID() case constant.RemoteAddr: return c.RemoteAddr + case SDKVersion: + return c.info.SDKVersion default: return "" } @@ -83,28 +99,90 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont func newTempContext() *UserConnContext { return &UserConnContext{ - Req: &http.Request{URL: &url.URL{}}, + Req: &http.Request{URL: &url.URL{}}, + info: &UserConnContextInfo{}, } } -func (c *UserConnContext) GetRemoteAddr() string { - return c.RemoteAddr +func (c *UserConnContext) ParseEssentialArgs() error { + query := c.Req.URL.Query() + if data := query.Get("v"); data != "" { + return c.parseByJson(data) + } else { + return c.parseByQuery(query, c.Req.Header) + } +} + +func (c *UserConnContext) parseByQuery(query url.Values, header http.Header) error { + info := UserConnContextInfo{ + Token: query.Get(Token), + UserID: query.Get(WsUserID), + OperationID: query.Get(OperationID), + Compression: query.Get(Compression), + SDKType: query.Get(SDKType), + SDKVersion: query.Get(SDKVersion), + } + platformID, err := strconv.Atoi(query.Get(PlatformID)) + if err != nil { + return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int") + } + info.PlatformID = platformID + if val := query.Get(SendResponse); val != "" { + ok, err := strconv.ParseBool(val) + if err != nil { + return servererrs.ErrConnArgsErr.WrapMsg("isMsgResp is not bool") + } + info.SendResponse = ok + } + if info.Compression == "" { + info.Compression = header.Get(Compression) + } + background, err := strconv.ParseBool(query.Get(BackgroundStatus)) + if err != nil { + return err + } + info.Background = background + return c.checkInfo(&info) } -func (c *UserConnContext) Query(key string) (string, bool) { - var value string - if value = c.Req.URL.Query().Get(key); value == "" { - return value, false +func (c *UserConnContext) parseByJson(data string) error { + reqInfo, err := base64.RawURLEncoding.DecodeString(data) + if err != nil { + return servererrs.ErrConnArgsErr.WrapMsg("data is not base64") + } + var info UserConnContextInfo + if err := json.Unmarshal(reqInfo, &info); err != nil { + return servererrs.ErrConnArgsErr.WrapMsg("data is not json", "info", err.Error()) } - return value, true + return c.checkInfo(&info) } -func (c *UserConnContext) GetHeader(key string) (string, bool) { - var value string - if value = c.Req.Header.Get(key); value == "" { - return value, false +func (c *UserConnContext) checkInfo(info *UserConnContextInfo) error { + if info.OperationID == "" { + return servererrs.ErrConnArgsErr.WrapMsg("operationID is empty") + } + if info.Token == "" { + return servererrs.ErrConnArgsErr.WrapMsg("token is empty") + } + if info.UserID == "" { + return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty") + } + if _, ok := constant.PlatformID2Name[info.PlatformID]; !ok { + return servererrs.ErrConnArgsErr.WrapMsg("platformID is invalid") + } + switch info.SDKType { + case "": + info.SDKType = GoSDK + case GoSDK, JsSDK: + default: + return servererrs.ErrConnArgsErr.WrapMsg("sdkType is invalid") } - return value, true + c.info = info + return nil +} + +func (c *UserConnContext) GetRemoteAddr() string { + return c.RemoteAddr } func (c *UserConnContext) SetHeader(key, value string) { @@ -120,97 +198,76 @@ func (c *UserConnContext) GetConnID() string { } func (c *UserConnContext) GetUserID() string { - return c.Req.URL.Query().Get(WsUserID) + if c == nil || c.info == nil { + return "" + } + return c.info.UserID } -func (c *UserConnContext) GetPlatformID() string { - return c.Req.URL.Query().Get(PlatformID) +func (c *UserConnContext) GetPlatformID() int { + if c == nil || c.info == nil { + return 0 + } + return c.info.PlatformID } func (c *UserConnContext) GetOperationID() string { - return c.Req.URL.Query().Get(OperationID) + if c == nil || c.info == nil { + return "" + } + return c.info.OperationID } func (c *UserConnContext) SetOperationID(operationID string) { - values := c.Req.URL.Query() - values.Set(OperationID, operationID) - c.Req.URL.RawQuery = values.Encode() + if c.info == nil { + c.info = &UserConnContextInfo{} + } + c.info.OperationID = operationID } func (c *UserConnContext) GetToken() string { - return c.Req.URL.Query().Get(Token) + if c == nil || c.info == nil { + return "" + } + return c.info.Token } -func (c *UserConnContext) GetSDKVersion() string { - return c.Req.URL.Query().Get(SDKVersion) +func (c *UserConnContext) GetCompression() bool { + return c != nil && c.info != nil && c.info.Compression == GzipCompressionProtocol } -func (c *UserConnContext) GetCompression() bool { - compression, exists := c.Query(Compression) - if exists && compression == GzipCompressionProtocol { - return true - } else { - compression, exists := c.GetHeader(Compression) - if exists && compression == GzipCompressionProtocol { - return true - } +func (c *UserConnContext) GetSDKType() string { + if c == nil || c.info == nil { + return GoSDK + } + switch c.info.SDKType { + case "", GoSDK: + return GoSDK + case JsSDK: + return JsSDK + default: + return "" } - return false } -func (c *UserConnContext) GetSDKType() string { - sdkType := c.Req.URL.Query().Get(SDKType) - if sdkType == "" { - sdkType = GoSDK +func (c *UserConnContext) GetSDKVersion() string { + if c == nil || c.info == nil { + return "" } - return sdkType + return c.info.SDKVersion } func (c *UserConnContext) ShouldSendResp() bool { - errResp, exists := c.Query(SendResponse) - if exists { - b, err := strconv.ParseBool(errResp) - if err != nil { - return false - } else { - return b - } - } - return false + return c != nil && c.info != nil && c.info.SendResponse } func (c *UserConnContext) SetToken(token string) { - c.Req.URL.RawQuery = Token + "=" + token + if c.info == nil { + c.info = &UserConnContextInfo{} + } + c.info.Token = token } func (c *UserConnContext) GetBackground() bool { - b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus)) - if err != nil { - return false - } - return b -} -func (c *UserConnContext) ParseEssentialArgs() error { - _, exists := c.Query(Token) - if !exists { - return servererrs.ErrConnArgsErr.WrapMsg("token is empty") - } - _, exists = c.Query(WsUserID) - if !exists { - return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty") - } - platformIDStr, exists := c.Query(PlatformID) - if !exists { - return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty") - } - _, err := strconv.Atoi(platformIDStr) - if err != nil { - return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int") - } - switch sdkType, _ := c.Query(SDKType); sdkType { - case "", GoSDK, JsSDK: - default: - return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js") - } - return nil + return c != nil && c.info != nil && c.info.Background } diff --git a/internal/msggateway/long_conn.go b/internal/msggateway/long_conn.go deleted file mode 100644 index c1b3e27c9..000000000 --- a/internal/msggateway/long_conn.go +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright © 2023 OpenIM. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package msggateway - -import ( - "encoding/json" - "net/http" - "time" - - "github.com/openimsdk/tools/apiresp" - - "github.com/gorilla/websocket" - "github.com/openimsdk/tools/errs" -) - -type LongConn interface { - // Close this connection - Close() error - // WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1). - WriteMessage(messageType int, message []byte) error - // ReadMessage Read message from connection. - ReadMessage() (int, []byte, error) - // SetReadDeadline sets the read deadline on the underlying network connection, - // after a read has timed out, will return an error. - SetReadDeadline(timeout time.Duration) error - // SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error. - SetWriteDeadline(timeout time.Duration) error - // Dial Try to dial a connection,url must set auth args,header can control compress data - Dial(urlStr string, requestHeader http.Header) (*http.Response, error) - // IsNil Whether the connection of the current long connection is nil - IsNil() bool - // SetConnNil Set the connection of the current long connection to nil - SetConnNil() - // SetReadLimit sets the maximum size for a message read from the peer.bytes - SetReadLimit(limit int64) - SetPongHandler(handler PingPongHandler) - SetPingHandler(handler PingPongHandler) - // GenerateLongConn Check the connection of the current and when it was sent are the same - GenerateLongConn(w http.ResponseWriter, r *http.Request) error -} -type GWebSocket struct { - protocolType int - conn *websocket.Conn - handshakeTimeout time.Duration - writeBufferSize int -} - -func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket { - return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs} -} - -func (d *GWebSocket) Close() error { - return d.conn.Close() -} - -func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error { - upgrader := &websocket.Upgrader{ - HandshakeTimeout: d.handshakeTimeout, - CheckOrigin: func(r *http.Request) bool { return true }, - } - if d.writeBufferSize > 0 { // default is 4kb. - upgrader.WriteBufferSize = d.writeBufferSize - } - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - // The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade - return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed") - } - d.conn = conn - return nil -} - -func (d *GWebSocket) WriteMessage(messageType int, message []byte) error { - // d.setSendConn(d.conn) - return d.conn.WriteMessage(messageType, message) -} - -// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) { -// d.sendConn = sendConn -//} - -func (d *GWebSocket) ReadMessage() (int, []byte, error) { - return d.conn.ReadMessage() -} - -func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error { - return d.conn.SetReadDeadline(time.Now().Add(timeout)) -} - -func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error { - if timeout <= 0 { - return errs.New("timeout must be greater than 0") - } - - // TODO SetWriteDeadline Future add error handling - if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { - return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed") - } - return nil -} - -func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { - conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader) - if err != nil { - return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr) - } - d.conn = conn - return httpResp, nil -} - -func (d *GWebSocket) IsNil() bool { - return d.conn == nil - // - // if d.conn != nil { - // return false - // } - // return true -} - -func (d *GWebSocket) SetConnNil() { - d.conn = nil -} - -func (d *GWebSocket) SetReadLimit(limit int64) { - d.conn.SetReadLimit(limit) -} - -func (d *GWebSocket) SetPongHandler(handler PingPongHandler) { - d.conn.SetPongHandler(handler) -} - -func (d *GWebSocket) SetPingHandler(handler PingPongHandler) { - d.conn.SetPingHandler(handler) -} - -func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error { - if err := d.GenerateLongConn(w, r); err != nil { - return err - } - data, err := json.Marshal(apiresp.ParseError(err)) - if err != nil { - _ = d.Close() - return errs.WrapMsg(err, "json marshal failed") - } - - if err := d.WriteMessage(MessageText, data); err != nil { - _ = d.Close() - return errs.WrapMsg(err, "WriteMessage failed") - } - _ = d.Close() - return nil -} - -func (d *GWebSocket) RespondWithSuccess() error { - data, err := json.Marshal(apiresp.ParseError(nil)) - if err != nil { - _ = d.Close() - return errs.WrapMsg(err, "json marshal failed") - } - - if err := d.WriteMessage(MessageText, data); err != nil { - _ = d.Close() - return errs.WrapMsg(err, "WriteMessage failed") - } - return nil -} diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index d490cc8b9..0f7e1f8e6 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -2,18 +2,20 @@ package msggateway import ( "context" + "encoding/json" "fmt" "net/http" "sync" "sync/atomic" "time" + "github.com/gorilla/websocket" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" + "github.com/openimsdk/tools/apiresp" "github.com/openimsdk/open-im-server/v3/pkg/common/webhook" "github.com/openimsdk/open-im-server/v3/pkg/rpccache" pbAuth "github.com/openimsdk/protocol/auth" - "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/mcontext" "github.com/go-playground/validator/v10" @@ -23,10 +25,11 @@ import ( "github.com/openimsdk/protocol/msggateway" "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/log" - "github.com/openimsdk/tools/utils/stringutil" "golang.org/x/sync/errgroup" ) +var wsSuccessResponse, _ = json.Marshal(&apiresp.ApiResponse{}) + type LongConnServer interface { Run(ctx context.Context) error wsHandler(w http.ResponseWriter, r *http.Request) @@ -43,6 +46,7 @@ type LongConnServer interface { } type WsServer struct { + websocket *websocket.Upgrader msgGatewayConfig *Config port int wsMaxConnNum int64 @@ -136,9 +140,13 @@ func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer { o(&config) } //userRpcClient := rpcclient.NewUserRpcClient(client, config.Discovery.RpcService.User, config.Share.IMAdminUser) - + upgrader := &websocket.Upgrader{ + HandshakeTimeout: config.handshakeTimeout, + CheckOrigin: func(r *http.Request) bool { return true }, + } v := validator.New() return &WsServer{ + websocket: upgrader, msgGatewayConfig: msgGatewayConfig, port: config.port, wsMaxConnNum: config.maxConnNum, @@ -260,8 +268,7 @@ func (ws *WsServer) registerClient(client *Client) { ) oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID) - log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID, - "sdkVersion", client.SDKVersion) + log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID) if !userOK { ws.clients.Set(client.UserID, client) @@ -448,7 +455,7 @@ func (ws *WsServer) unregisterClient(client *Client) { // validateRespWithRequest checks if the response matches the expected userID and platformID. func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error { userID := ctx.GetUserID() - platformID := stringutil.StringToInt32(ctx.GetPlatformID()) + platformID := int32(ctx.GetPlatformID()) if resp.UserID != userID { return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID)) } @@ -458,19 +465,37 @@ func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.P return nil } +func (ws *WsServer) handlerError(ctx *UserConnContext, w http.ResponseWriter, r *http.Request, err error) { + if !ctx.ShouldSendResp() { + httpError(ctx, err) + return + } + // the browser cannot get the response of upgrade failure + data, err := json.Marshal(apiresp.ParseError(err)) + if err != nil { + log.ZError(ctx, "json marshal failed", err) + return + } + conn, upgradeErr := ws.websocket.Upgrade(w, r, nil) + if upgradeErr != nil { + log.ZWarn(ctx, "websocket upgrade failed", upgradeErr, "respErr", err, "resp", string(data)) + return + } + defer conn.Close() + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + log.ZWarn(ctx, "WriteMessage failed", err, "respErr", err, "resp", string(data)) + return + } +} + func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { // Create a new connection context connContext := newContext(w, r) - if !ws.ready.Load() { - httpError(connContext, errs.New("ws server not ready")) - return - } - // Check if the current number of online user connections exceeds the maximum limit if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum { // If it exceeds the maximum connection number, return an error via HTTP and stop processing - httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")) + ws.handlerError(connContext, w, r, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit")) return } @@ -478,31 +503,14 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { err := connContext.ParseEssentialArgs() if err != nil { // If there's an error during parsing, return an error via HTTP and stop processing - - httpError(connContext, err) - return - } - - if ws.authClient == nil { - httpError(connContext, errs.New("auth client is not initialized")) + ws.handlerError(connContext, w, r, err) return } // Call the authentication client to parse the Token obtained from the context resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken()) if err != nil { - // If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag - shouldSendError := connContext.ShouldSendResp() - if shouldSendError { - // Create a WebSocket connection object and attempt to send the error message via WebSocket - wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) - if err := wsLongConn.RespondWithError(err, w, r); err == nil { - // If the error message is successfully sent via WebSocket, stop processing - return - } - } - // If sending via WebSocket is not required or fails, return the error via HTTP and stop processing - httpError(connContext, err) + ws.handlerError(connContext, w, r, err) return } @@ -510,32 +518,30 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { err = ws.validateRespWithRequest(connContext, resp) if err != nil { // If validation fails, return an error via HTTP and stop processing - httpError(connContext, err) + ws.handlerError(connContext, w, r, err) return } - - log.ZDebug(connContext, "new conn", "token", connContext.GetToken()) - // Create a WebSocket long connection object - wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize) - if err := wsLongConn.GenerateLongConn(w, r); err != nil { - //If the creation of the long connection fails, the error is handled internally during the handshake process. - log.ZWarn(connContext, "long connection fails", err) + conn, err := ws.websocket.Upgrade(w, r, nil) + if err != nil { + log.ZWarn(connContext, "websocket upgrade failed", err) return - } else { - // Check if a normal response should be sent via WebSocket - shouldSendSuccessResp := connContext.ShouldSendResp() - if shouldSendSuccessResp { - // Attempt to send a success message through WebSocket - if err := wsLongConn.RespondWithSuccess(); err != nil { - // If the success message is successfully sent, end further processing - return - } + } + if connContext.ShouldSendResp() { + if err := conn.WriteMessage(websocket.TextMessage, wsSuccessResponse); err != nil { + log.ZWarn(connContext, "WriteMessage first response", err) + return } } - // Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection - client := ws.clientPool.Get().(*Client) - client.ResetClient(connContext, wsLongConn, ws) + log.ZDebug(connContext, "new conn", "token", connContext.GetToken()) + + var pingInterval time.Duration + if connContext.GetPlatformID() == constant.WebPlatformID { + pingInterval = pingPeriod + } + + client := new(Client) + client.ResetClient(connContext, NewWebSocketClientConn(conn, maxMessageSize, pongWait, pingInterval), ws) // Register the client with the server and start message processing ws.registerChan <- client