diff --git a/go.mod b/go.mod index 3609369a5..88163374d 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/mitchellh/mapstructure v1.5.0 github.com/openimsdk/protocol v0.0.73-alpha.17 - github.com/openimsdk/tools v0.0.50-alpha.103 + github.com/openimsdk/tools v0.0.50-alpha.105 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.18.0 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index 4eba14c66..2f8eeca7c 100644 --- a/go.sum +++ b/go.sum @@ -349,8 +349,8 @@ github.com/openimsdk/gomake v0.0.15-alpha.11 h1:PQudYDRESYeYlUYrrLLJhYIlUPO5x7FA github.com/openimsdk/gomake v0.0.15-alpha.11/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= github.com/openimsdk/protocol v0.0.73-alpha.17 h1:ddo0QMns1GVwAmrPIPlAQ7uKmThAYLnOt+CIOgLsJyE= github.com/openimsdk/protocol v0.0.73-alpha.17/go.mod h1:WF7EuE55vQvpyUAzDXcqg+B+446xQyEba0X35lTINmw= -github.com/openimsdk/tools v0.0.50-alpha.103 h1:jYvI86cWiVu8a8iw1panw+pwIiStuUHF76h3fxA6ESI= -github.com/openimsdk/tools v0.0.50-alpha.103/go.mod h1:qCExFBqXpQBMzZck3XGIFwivBayAn2KNqB3WAd++IJw= +github.com/openimsdk/tools v0.0.50-alpha.105 h1:axuCvKXhxY2RGLhpMMFNgBtE0B65T2Sr1JDW3UD9nBs= +github.com/openimsdk/tools v0.0.50-alpha.105/go.mod h1:x9i/e+WJFW4tocy6RNJQ9NofQiP3KJ1Y576/06TqOG4= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= diff --git a/internal/msggateway/ws_server.go b/internal/msggateway/ws_server.go index ec5cd0ab8..d490cc8b9 100644 --- a/internal/msggateway/ws_server.go +++ b/internal/msggateway/ws_server.go @@ -13,6 +13,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/webhook" "github.com/openimsdk/open-im-server/v3/pkg/rpccache" pbAuth "github.com/openimsdk/protocol/auth" + "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/mcontext" "github.com/go-playground/validator/v10" @@ -64,6 +65,8 @@ type WsServer struct { webhookClient *webhook.Client userClient *rpcli.UserClient authClient *rpcli.AuthClient + + ready atomic.Bool } type kickHandler struct { @@ -93,6 +96,8 @@ func (ws *WsServer) SetDiscoveryRegistry(ctx context.Context, disCov discovery.C ws.authClient = rpcli.NewAuthClient(authConn) ws.MessageHandler = NewGrpcHandler(ws.validate, rpcli.NewMsgClient(msgConn), rpcli.NewPushMsgServiceClient(pushConn)) ws.disCov = disCov + + ws.ready.Store(true) return nil } @@ -457,6 +462,11 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { // Create a new connection context connContext := newContext(w, r) + if !ws.ready.Load() { + httpError(connContext, errs.New("ws server not ready")) + return + } + // Check if the current number of online user connections exceeds the maximum limit if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum { // If it exceeds the maximum connection number, return an error via HTTP and stop processing @@ -473,6 +483,11 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) { return } + if ws.authClient == nil { + httpError(connContext, errs.New("auth client is not initialized")) + return + } + // Call the authentication client to parse the Token obtained from the context resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken()) if err != nil {