diff --git a/cmd/openim-push/main.go b/cmd/openim-push/main.go index e0539fa52..bd31ffdef 100644 --- a/cmd/openim-push/main.go +++ b/cmd/openim-push/main.go @@ -26,7 +26,7 @@ func main() { pushCmd.AddPortFlag() pushCmd.AddPrometheusPortFlag() if err := pushCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := pushCmd.StartSvr(config.Config.RpcRegisterName.OpenImPushName, push.Start); err != nil { util.ExitWithError(err) diff --git a/cmd/openim-rpc/openim-rpc-auth/main.go b/cmd/openim-rpc/openim-rpc-auth/main.go index b526c3b86..992a2b432 100644 --- a/cmd/openim-rpc/openim-rpc-auth/main.go +++ b/cmd/openim-rpc/openim-rpc-auth/main.go @@ -26,7 +26,7 @@ func main() { authCmd.AddPortFlag() authCmd.AddPrometheusPortFlag() if err := authCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := authCmd.StartSvr(config.Config.RpcRegisterName.OpenImAuthName, auth.Start); err != nil { util.ExitWithError(err) diff --git a/cmd/openim-rpc/openim-rpc-conversation/main.go b/cmd/openim-rpc/openim-rpc-conversation/main.go index bde191c51..10fe0b46c 100644 --- a/cmd/openim-rpc/openim-rpc-conversation/main.go +++ b/cmd/openim-rpc/openim-rpc-conversation/main.go @@ -26,7 +26,7 @@ func main() { rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImConversationName, conversation.Start); err != nil { util.ExitWithError(err) diff --git a/cmd/openim-rpc/openim-rpc-friend/main.go b/cmd/openim-rpc/openim-rpc-friend/main.go index 8eeb9c8e1..63de23293 100644 --- a/cmd/openim-rpc/openim-rpc-friend/main.go +++ b/cmd/openim-rpc/openim-rpc-friend/main.go @@ -26,7 +26,7 @@ func main() { rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImFriendName, friend.Start); err != nil { util.ExitWithError(err) diff --git a/cmd/openim-rpc/openim-rpc-group/main.go b/cmd/openim-rpc/openim-rpc-group/main.go index a5842ffd1..c0780acab 100644 --- a/cmd/openim-rpc/openim-rpc-group/main.go +++ b/cmd/openim-rpc/openim-rpc-group/main.go @@ -26,7 +26,7 @@ func main() { rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImGroupName, group.Start); err != nil { util.ExitWithError(err) diff --git a/cmd/openim-rpc/openim-rpc-msg/main.go b/cmd/openim-rpc/openim-rpc-msg/main.go index b3895a502..62bdff0a5 100644 --- a/cmd/openim-rpc/openim-rpc-msg/main.go +++ b/cmd/openim-rpc/openim-rpc-msg/main.go @@ -26,7 +26,7 @@ func main() { rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImMsgName, msg.Start); err != nil { util.ExitWithError(err) diff --git a/cmd/openim-rpc/openim-rpc-third/main.go b/cmd/openim-rpc/openim-rpc-third/main.go index 8f390bb6a..c2893a398 100644 --- a/cmd/openim-rpc/openim-rpc-third/main.go +++ b/cmd/openim-rpc/openim-rpc-third/main.go @@ -26,7 +26,7 @@ func main() { rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImThirdName, third.Start); err != nil { util.ExitWithError(err) diff --git a/cmd/openim-rpc/openim-rpc-user/main.go b/cmd/openim-rpc/openim-rpc-user/main.go index 6994ea2b1..f7948bda0 100644 --- a/cmd/openim-rpc/openim-rpc-user/main.go +++ b/cmd/openim-rpc/openim-rpc-user/main.go @@ -26,7 +26,7 @@ func main() { rpcCmd.AddPortFlag() rpcCmd.AddPrometheusPortFlag() if err := rpcCmd.Exec(); err != nil { - panic(err.Error()) + util.ExitWithError(err) } if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImUserName, user.Start); err != nil { util.ExitWithError(err) diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 9a4005e6c..06efea12f 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -91,13 +91,7 @@ type Client struct { // } // ResetClient updates the client's state with new connection and context information. -func (c *Client) ResetClient( - ctx *UserConnContext, - conn LongConn, - isBackground, isCompress bool, - longConnServer LongConnServer, - token string, -) { +func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isBackground, isCompress bool, longConnServer LongConnServer, token string) { c.w = new(sync.Mutex) c.conn = conn c.PlatformID = utils.StringToInt(ctx.GetPlatformID()) @@ -112,9 +106,11 @@ func (c *Client) ResetClient( c.token = token } -// pingHandler handles ping messages and sends pong responses. func (c *Client) pingHandler(_ string) error { - _ = c.conn.SetReadDeadline(pongWait) + if err := c.conn.SetReadDeadline(pongWait); err != nil { + return err + } + return c.writePongMsg() } @@ -141,7 +137,8 @@ func (c *Client) readMessage() { } log.ZDebug(c.ctx, "readMessage", "messageType", messageType) - if c.closed.Load() { // 连接刚置位已经关闭,但是协程还没退出的场景 + if c.closed.Load() { + // The scenario where the connection has just been closed, but the coroutine has not exited c.closedErr = ErrConnClosed return } @@ -185,11 +182,11 @@ func (c *Client) handleMessage(message []byte) error { err := c.longConnServer.Decode(message, binaryReq) if err != nil { - return errs.Wrap(err) + return err } if err := c.longConnServer.Validate(binaryReq); err != nil { - return errs.Wrap(err) + return err } if binaryReq.SendID != c.UserID { @@ -239,7 +236,7 @@ func (c *Client) setAppBackgroundStatus(ctx context.Context, req *Req) ([]byte, } c.IsBackground = isBackground - // todo callback + // TODO: callback return resp, nil } @@ -273,7 +270,7 @@ func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, re } if binaryReq.ReqIdentifier == WsLogoutMsg { - return errors.New("user logout") + return errs.Wrap(errors.New("user logout")) } return nil } @@ -316,17 +313,21 @@ func (c *Client) writeBinaryMsg(resp Resp) error { encodedBuf, err := c.longConnServer.Encode(resp) if err != nil { - return errs.Wrap(err) + return err } c.w.Lock() defer c.w.Unlock() - _ = c.conn.SetWriteDeadline(writeWait) + err = c.conn.SetWriteDeadline(writeWait) + if err != nil { + return err + } + if c.IsCompress { resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf) if compressErr != nil { - return errs.Wrap(compressErr) + return compressErr } return c.conn.WriteMessage(MessageBinary, resultBuf) } @@ -344,7 +345,7 @@ func (c *Client) writePongMsg() error { err := c.conn.SetWriteDeadline(writeWait) if err != nil { - return errs.Wrap(err) + return err } return c.conn.WriteMessage(PongMessage, nil) diff --git a/internal/msggateway/compressor.go b/internal/msggateway/compressor.go index d4789536e..140aac4d8 100644 --- a/internal/msggateway/compressor.go +++ b/internal/msggateway/compressor.go @@ -17,7 +17,6 @@ package msggateway import ( "bytes" "compress/gzip" - "errors" "io" "sync" @@ -46,12 +45,15 @@ func NewGzipCompressor() *GzipCompressor { func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) { gzipBuffer := bytes.Buffer{} gz := gzip.NewWriter(&gzipBuffer) + if _, err := gz.Write(rawData); err != nil { - return nil, errs.Wrap(err) + return nil, errs.Wrap(err, "GzipCompressor.Compress: writing to gzip writer failed") } + if err := gz.Close(); err != nil { - return nil, errs.Wrap(err) + return nil, errs.Wrap(err, "GzipCompressor.Compress: closing gzip writer failed") } + return gzipBuffer.Bytes(), nil } @@ -63,10 +65,10 @@ func (g *GzipCompressor) CompressWithPool(rawData []byte) ([]byte, error) { gz.Reset(&gzipBuffer) if _, err := gz.Write(rawData); err != nil { - return nil, errs.Wrap(err) + return nil, errs.Wrap(err, "GzipCompressor.CompressWithPool: error writing data") } if err := gz.Close(); err != nil { - return nil, errs.Wrap(err) + return nil, errs.Wrap(err, "GzipCompressor.CompressWithPool: error closing gzip writer") } return gzipBuffer.Bytes(), nil } @@ -75,32 +77,36 @@ func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) { buff := bytes.NewBuffer(compressedData) reader, err := gzip.NewReader(buff) if err != nil { - return nil, errs.Wrap(err, "NewReader failed") + return nil, errs.Wrap(err, "GzipCompressor.DeCompress: NewReader creation failed") } - compressedData, err = io.ReadAll(reader) + decompressedData, err := io.ReadAll(reader) if err != nil { - return nil, errs.Wrap(err, "ReadAll failed") + return nil, errs.Wrap(err, "GzipCompressor.DeCompress: reading from gzip reader failed") } - _ = reader.Close() - return compressedData, nil + if err = reader.Close(); err != nil { + // Even if closing the reader fails, we've successfully read the data, + // so we return the decompressed data and an error indicating the close failure. + return decompressedData, errs.Wrap(err, "GzipCompressor.DeCompress: closing gzip reader failed") + } + return decompressedData, nil } func (g *GzipCompressor) DecompressWithPool(compressedData []byte) ([]byte, error) { reader := gzipReaderPool.Get().(*gzip.Reader) - if reader == nil { - return nil, errs.Wrap(errors.New("NewReader failed")) - } defer gzipReaderPool.Put(reader) err := reader.Reset(bytes.NewReader(compressedData)) if err != nil { - return nil, errs.Wrap(err, "NewReader failed") + return nil, errs.Wrap(err, "GzipCompressor.DecompressWithPool: resetting gzip reader failed") } - compressedData, err = io.ReadAll(reader) + decompressedData, err := io.ReadAll(reader) if err != nil { - return nil, errs.Wrap(err, "ReadAll failed") + return nil, errs.Wrap(err, "GzipCompressor.DecompressWithPool: reading from pooled gzip reader failed") + } + if err = reader.Close(); err != nil { + // Similar to DeCompress, return the data and error for close failure. + return decompressedData, errs.Wrap(err, "GzipCompressor.DecompressWithPool: closing pooled gzip reader failed") } - _ = reader.Close() - return compressedData, nil + return decompressedData, nil } diff --git a/internal/msggateway/encoder.go b/internal/msggateway/encoder.go index 69a899591..cd2c50d96 100644 --- a/internal/msggateway/encoder.go +++ b/internal/msggateway/encoder.go @@ -37,7 +37,7 @@ func (g *GobEncoder) Encode(data any) ([]byte, error) { enc := gob.NewEncoder(&buff) err := enc.Encode(data) if err != nil { - return nil, err + return nil, errs.Wrap(err, "GobEncoder.Encode failed") } return buff.Bytes(), nil } @@ -47,7 +47,7 @@ func (g *GobEncoder) Decode(encodeData []byte, decodeData any) error { dec := gob.NewDecoder(buff) err := dec.Decode(decodeData) if err != nil { - return errs.Wrap(err) + return errs.Wrap(err, "GobEncoder.Decode failed") } return nil } diff --git a/internal/msggateway/long_conn.go b/internal/msggateway/long_conn.go index a4251a50f..d475aa52b 100644 --- a/internal/msggateway/long_conn.go +++ b/internal/msggateway/long_conn.go @@ -15,9 +15,11 @@ package msggateway import ( + "errors" "net/http" "time" + "github.com/OpenIMSDK/tools/errs" "github.com/gorilla/websocket" ) @@ -96,7 +98,16 @@ func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error { } func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error { - return d.conn.SetWriteDeadline(time.Now().Add(timeout)) + // TODO add error + if timeout <= 0 { + return errs.Wrap(errors.New("timeout must be greater than 0")) + } + + // TODO SetWriteDeadline Future add error handling + if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { + return errs.Wrap(err, "GWebSocket.SetWriteDeadline failed") + } + return nil } func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { diff --git a/internal/msggateway/n_ws_server.go b/internal/msggateway/n_ws_server.go index b734dee6d..0620fa5b9 100644 --- a/internal/msggateway/n_ws_server.go +++ b/internal/msggateway/n_ws_server.go @@ -88,6 +88,7 @@ type WsServer struct { Encoder MessageHandler } + type kickHandler struct { clientOK bool oldClients []*Client @@ -129,7 +130,9 @@ func (ws *WsServer) UnRegister(c *Client) { } func (ws *WsServer) Validate(s any) error { - //?question? + if s == nil { + return errs.Wrap(errors.New("input cannot be nil")) + } return nil } diff --git a/pkg/common/db/cache/init_redis.go b/pkg/common/db/cache/init_redis.go index 3cec73be5..a41a4f460 100644 --- a/pkg/common/db/cache/init_redis.go +++ b/pkg/common/db/cache/init_redis.go @@ -49,7 +49,7 @@ func NewRedis() (redis.UniversalClient, error) { overrideConfigFromEnv() if len(config.Config.Redis.Address) == 0 { - return nil, errs.Wrap(errors.New("redis address is empty")) + return nil, errs.Wrap(errors.New("redis address is empty"), "Redis configuration error") } specialerror.AddReplace(redis.Nil, errs.ErrRecordNotFound) var rdb redis.UniversalClient @@ -65,9 +65,9 @@ func NewRedis() (redis.UniversalClient, error) { rdb = redis.NewClient(&redis.Options{ Addr: config.Config.Redis.Address[0], Username: config.Config.Redis.Username, - Password: config.Config.Redis.Password, - DB: 0, // use default DB - PoolSize: 100, // connection pool size + Password: config.Config.Redis.Password, // no password set + DB: 0, // use default DB + PoolSize: 100, // connection pool size MaxRetries: maxRetry, }) } @@ -77,9 +77,9 @@ func NewRedis() (redis.UniversalClient, error) { defer cancel() err = rdb.Ping(ctx).Err() if err != nil { - uriFormat := "address:%s, username:%s, password:%s, clusterMode:%t, enablePipeline:%t" - errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.Password, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline) - return nil, errs.Wrap(err, errMsg) + uriFormat := "address:%v, username:%s, clusterMode:%t, enablePipeline:%t" + errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline) + return nil, errs.Wrap(err, "Redis connection failed: %s", errMsg) } redisClient = rdb return rdb, err @@ -98,9 +98,11 @@ func overrideConfigFromEnv() { config.Config.Redis.Address = strings.Split(envAddr, ",") } } + if envUser := os.Getenv("REDIS_USERNAME"); envUser != "" { config.Config.Redis.Username = envUser } + if envPass := os.Getenv("REDIS_PASSWORD"); envPass != "" { config.Config.Redis.Password = envPass }