|
|
|
|
@ -13,6 +13,7 @@ import (
|
|
|
|
|
"github.com/openimsdk/open-im-server/v3/pkg/common/webhook"
|
|
|
|
|
"github.com/openimsdk/open-im-server/v3/pkg/rpccache"
|
|
|
|
|
pbAuth "github.com/openimsdk/protocol/auth"
|
|
|
|
|
"github.com/openimsdk/tools/errs"
|
|
|
|
|
"github.com/openimsdk/tools/mcontext"
|
|
|
|
|
|
|
|
|
|
"github.com/go-playground/validator/v10"
|
|
|
|
|
@ -64,6 +65,8 @@ type WsServer struct {
|
|
|
|
|
webhookClient *webhook.Client
|
|
|
|
|
userClient *rpcli.UserClient
|
|
|
|
|
authClient *rpcli.AuthClient
|
|
|
|
|
|
|
|
|
|
ready atomic.Bool
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type kickHandler struct {
|
|
|
|
|
@ -93,6 +96,8 @@ func (ws *WsServer) SetDiscoveryRegistry(ctx context.Context, disCov discovery.C
|
|
|
|
|
ws.authClient = rpcli.NewAuthClient(authConn)
|
|
|
|
|
ws.MessageHandler = NewGrpcHandler(ws.validate, rpcli.NewMsgClient(msgConn), rpcli.NewPushMsgServiceClient(pushConn))
|
|
|
|
|
ws.disCov = disCov
|
|
|
|
|
|
|
|
|
|
ws.ready.Store(true)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -457,6 +462,11 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
// Create a new connection context
|
|
|
|
|
connContext := newContext(w, r)
|
|
|
|
|
|
|
|
|
|
if !ws.ready.Load() {
|
|
|
|
|
httpError(connContext, errs.New("ws server not ready"))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check if the current number of online user connections exceeds the maximum limit
|
|
|
|
|
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
|
|
|
|
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
|
|
|
|
|
@ -473,6 +483,11 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ws.authClient == nil {
|
|
|
|
|
httpError(connContext, errs.New("auth client is not initialized"))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Call the authentication client to parse the Token obtained from the context
|
|
|
|
|
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
|
|
|
|
if err != nil {
|
|
|
|
|
|