You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Open-IM-Server/internal/msggateway/n_ws_server.go

212 lines
5.3 KiB

package msggateway
import (
"errors"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
"github.com/OpenIMSDK/Open-IM-Server/pkg/rpcclient"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
"github.com/go-playground/validator/v10"
)
// LongConnServer is the contract of the long-connection gateway server.
// WsServer in this file provides every method listed here and embeds the
// Compressor, Encoder and MessageHandler interfaces.
type LongConnServer interface {
// Run starts serving and returns the error that stopped the server.
Run() error
// wsHandler is the HTTP entry point for new connection requests.
wsHandler(w http.ResponseWriter, r *http.Request)
// GetUserAllCons looks up the connections of userID; the bool is the
// lookup result reported by the underlying user map.
GetUserAllCons(userID string) ([]*Client, bool)
// GetUserPlatformCons looks up the connections of userID on platform.
// In registerClient the two bools are read as "user is known" and
// "a connection on that platform exists", in that order.
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
// Validate checks s and reports a validation error, if any.
Validate(s interface{}) error
// SetMessageHandler installs the handler backed by msgRpcClient.
SetMessageHandler(msgRpcClient *rpcclient.MsgClient)
// UnRegister requests removal of c from the server.
UnRegister(c *Client)
Compressor
Encoder
MessageHandler
}
// bufferPool hands out reusable byte slices. A slice created by the pool's
// factory is 1024 bytes long; callers get it back as an interface value and
// must type-assert it to []byte.
var bufferPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, 1024)
		return buf
	},
}
// WsServer is the gateway server that accepts long connections over HTTP,
// tracks the connected clients per user and platform, and keeps running
// counters of online users and connections.
type WsServer struct {
// port is the TCP port Run listens on; NewWsServer rejects values below 1024.
port int
// wsMaxConnNum is the connection limit checked by wsHandler against
// onlineUserConnNum before a new connection is accepted.
wsMaxConnNum int64
// registerChan carries newly accepted clients to the dispatch goroutine
// started by Run (buffered, capacity 1000).
registerChan chan *Client
// unregisterChan carries clients to be removed to the same dispatch
// goroutine (buffered, capacity 1000).
unregisterChan chan *Client
// clients indexes the connected clients by user ID and platform.
clients *UserMap
// clientPool recycles *Client values: wsHandler takes one per connection
// and unregisterClient puts it back.
clientPool sync.Pool
// onlineUserNum counts users with at least one registered connection;
// it is modified with atomic.AddInt64.
onlineUserNum int64
// onlineUserConnNum counts registered connections; it is modified with
// atomic.AddInt64.
onlineUserConnNum int64
// handshakeTimeout is passed to newGWebSocket for each new connection.
handshakeTimeout time.Duration
// hubServer is not referenced by the code in this section.
hubServer *Server
// validate is the validator instance handed to NewGrpcHandler.
validate *validator.Validate
// Compressor and Encoder are set by NewWsServer (gzip and gob
// implementations); MessageHandler is set by SetMessageHandler.
Compressor
Encoder
MessageHandler
}
// SetMessageHandler replaces the server's MessageHandler with one built by
// NewGrpcHandler from the server's validator and rpcClient.
func (ws *WsServer) SetMessageHandler(rpcClient *rpcclient.MsgClient) {
ws.MessageHandler = NewGrpcHandler(ws.validate, rpcClient)
}
// UnRegister queues c for removal. The dispatch goroutine started by Run
// receives it and calls unregisterClient. The send blocks while the
// unregister channel's buffer is full.
func (ws *WsServer) UnRegister(c *Client) {
ws.unregisterChan <- c
}
// Validate currently accepts every value: it always returns nil and does not
// consult ws.validate.
// NOTE(review): this looks like a placeholder — confirm whether it is meant
// to delegate to the validator held in ws.validate.
func (ws *WsServer) Validate(s interface{}) error {
return nil
}
// GetUserAllCons returns whatever the client map reports for userID: the
// user's connections and the map's lookup flag.
func (ws *WsServer) GetUserAllCons(userID string) ([]*Client, bool) {
	conns, ok := ws.clients.GetAll(userID)
	return conns, ok
}
// GetUserPlatformCons returns whatever the client map reports for the
// (userID, platform) pair: the matching connections followed by the map's two
// lookup flags (read by registerClient as "user known" and "platform
// connection present").
func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) {
	conns, userOK, platformOK := ws.clients.Get(userID, platform)
	return conns, userOK, platformOK
}
// NewWsServer builds a WsServer from the given options.
//
// Each option is applied in order to a zero-valued configuration. A configured
// port below 1024 is rejected with an error. On success the returned server has
// buffered register/unregister channels (capacity 1000), an empty client map,
// a fresh validator, a client pool, and gzip/gob implementations for
// compression and encoding. The MessageHandler is left unset; install it with
// SetMessageHandler.
func NewWsServer(opts ...Option) (*WsServer, error) {
	var cfg configs
	for _, apply := range opts {
		apply(&cfg)
	}

	if cfg.port < 1024 {
		return nil, errors.New("port not allow to listen")
	}

	structValidator := validator.New()
	srv := &WsServer{
		port:             cfg.port,
		wsMaxConnNum:     cfg.maxConnNum,
		handshakeTimeout: cfg.handshakeTimeout,
		clientPool: sync.Pool{
			New: func() interface{} {
				return new(Client)
			},
		},
		registerChan:   make(chan *Client, 1000),
		unregisterChan: make(chan *Client, 1000),
		validate:       structValidator,
		clients:        newUserMap(),
		Compressor:     NewGzipCompressor(),
		Encoder:        NewGobEncoder(),
	}
	return srv, nil
}
// Run starts the server and blocks until the HTTP listener stops, returning
// the listener's error.
//
// Before listening it launches a single dispatch goroutine that serialises all
// register and unregister requests, and it installs wsHandler for "/" on the
// default HTTP mux. The dispatch goroutine has no stop condition.
func (ws *WsServer) Run() error {
	go func() {
		for {
			select {
			case joining := <-ws.registerChan:
				ws.registerClient(joining)
			case leaving := <-ws.unregisterChan:
				ws.unregisterClient(leaving)
			}
		}
	}()

	http.HandleFunc("/", ws.wsHandler)
	// http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {})
	addr := ":" + utils.IntToString(ws.port)
	return http.ListenAndServe(addr, nil) // start listening
}
// registerClient records client in the client map and updates the online
// counters. It is called from the dispatch goroutine started by Run.
//
// The map is queried before the client is stored, so the counters reflect the
// state prior to this registration:
//   - unknown user: both the user counter and the connection counter go up;
//   - known user that already has a connection on this platform: the counters
//     are left alone and the pre-existing connections are passed to
//     multiTerminalLoginChecker;
//   - known user, new platform: only the connection counter goes up.
func (ws *WsServer) registerClient(client *Client) {
	existing, userKnown, platformKnown := ws.clients.Get(client.userID, client.platformID)

	// Every case stores the new client first.
	ws.clients.Set(client.userID, client)

	switch {
	case !userKnown:
		atomic.AddInt64(&ws.onlineUserNum, 1)
		atomic.AddInt64(&ws.onlineUserConnNum, 1)
	case platformKnown:
		// A connection on the same platform already exists.
		ws.multiTerminalLoginChecker(existing)
	default:
		atomic.AddInt64(&ws.onlineUserConnNum, 1)
	}

	log.ZInfo(client.ctx, "user online", "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum)
}
// multiTerminalLoginChecker is called by registerClient with the connections
// that already existed for a user on the same platform. It is currently an
// empty stub and does nothing with them.
func (ws *WsServer) multiTerminalLoginChecker(client []*Client) {
}
// unregisterClient removes client from the client map and updates the online
// counters. It is called from the dispatch goroutine started by Run.
//
// The connection counter always drops by one; the user counter drops only when
// the map's delete reports true. The client is handed back to the client pool
// when the function returns.
func (ws *WsServer) unregisterClient(client *Client) {
	defer ws.clientPool.Put(client)

	if userRemoved := ws.clients.delete(client.userID, client.platformID); userRemoved {
		atomic.AddInt64(&ws.onlineUserNum, -1)
	}
	atomic.AddInt64(&ws.onlineUserConnNum, -1)

	log.ZInfo(client.ctx, "user offline", "close reason", client.closedErr, "online user Num", ws.onlineUserNum, "online user conn Num", ws.onlineUserConnNum)
}
// wsHandler is the HTTP entry point for new long connections.
//
// It performs, in order:
//  1. a capacity check against wsMaxConnNum;
//  2. extraction of the required Token, WsUserID and PlatformID query
//     arguments (any missing one yields ErrConnArgsErr);
//  3. token verification;
//  4. creation of the long connection from the request;
//  5. detection of gzip compression, which may be requested through either the
//     Compression query argument or the Compression header;
//  6. registration of a pooled Client and start of its read loop in a new
//     goroutine.
//
// Every failure is reported to the caller through httpError and aborts the
// request.
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
	connCtx := newContext(w, r)

	// The counter is written with atomic.AddInt64 from the dispatch goroutine
	// while this handler runs on per-request goroutines, so it must be read
	// atomically as well to avoid a data race.
	if atomic.LoadInt64(&ws.onlineUserConnNum) >= ws.wsMaxConnNum {
		httpError(connCtx, errs.ErrConnOverMaxNumLimit)
		return
	}

	token, ok := connCtx.Query(Token)
	if !ok {
		httpError(connCtx, errs.ErrConnArgsErr)
		return
	}
	userID, ok := connCtx.Query(WsUserID)
	if !ok {
		httpError(connCtx, errs.ErrConnArgsErr)
		return
	}
	platformID, ok := connCtx.Query(PlatformID)
	if !ok {
		httpError(connCtx, errs.ErrConnArgsErr)
		return
	}

	if err := tokenverify.WsVerifyToken(token, userID, platformID); err != nil {
		httpError(connCtx, err)
		return
	}

	wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout)
	if err := wsLongConn.GenerateLongConn(w, r); err != nil {
		httpError(connCtx, err)
		return
	}

	// Compression is enabled when either source asks for gzip; the two
	// sources are independent, so neither can switch it back off.
	compression := false
	if proto, found := connCtx.Query(Compression); found && proto == GzipCompressionProtocol {
		compression = true
	}
	if proto, found := connCtx.GetHeader(Compression); found && proto == GzipCompressionProtocol {
		compression = true
	}

	client := ws.clientPool.Get().(*Client)
	client.ResetClient(connCtx, wsLongConn, compression, ws)
	ws.registerChan <- client
	go client.readMessage()
}