diff --git a/src/msg_gateway/gate/ws_server.go b/src/msg_gateway/gate/ws_server.go index ffc14680d..73897bb71 100644 --- a/src/msg_gateway/gate/ws_server.go +++ b/src/msg_gateway/gate/ws_server.go @@ -4,25 +4,27 @@ import ( "Open_IM/src/common/config" "Open_IM/src/common/log" "Open_IM/src/utils" - "github.com/gorilla/websocket" "net/http" + "sync" "time" + + "github.com/gorilla/websocket" ) type WServer struct { wsAddr string wsMaxConnNum int wsUpGrader *websocket.Upgrader - wsConnToUser map[*websocket.Conn]string - wsUserToConn map[string]*websocket.Conn + wsConnToUser sync.Map + wsUserToConn sync.Map } func (ws *WServer) onInit(wsPort int) { ip := utils.ServerIP ws.wsAddr = ip + ":" + utils.IntToString(wsPort) ws.wsMaxConnNum = config.Config.LongConnSvr.WebsocketMaxConnNum - ws.wsConnToUser = make(map[*websocket.Conn]string) - ws.wsUserToConn = make(map[string]*websocket.Conn) + ws.wsConnToUser = sync.Map{} + ws.wsUserToConn = sync.Map{} ws.wsUpGrader = &websocket.Upgrader{ HandshakeTimeout: time.Duration(config.Config.LongConnSvr.WebsocketTimeOut) * time.Second, ReadBufferSize: config.Config.LongConnSvr.WebsocketMaxMsgLen, @@ -80,18 +82,20 @@ func (ws *WServer) writeMsg(conn *websocket.Conn, a int, msg []byte) error { func (ws *WServer) addUserConn(uid string, conn *websocket.Conn) { rwLock.Lock() defer rwLock.Unlock() - if oldConn, ok := ws.wsUserToConn[uid]; ok { + if v, ok := ws.wsUserToConn.Load(uid); ok { + oldConn := v.(*websocket.Conn) err := oldConn.Close() - delete(ws.wsConnToUser, oldConn) + ws.wsConnToUser.Delete(oldConn) if err != nil { log.ErrorByKv("close err", "", "uid", uid, "conn", conn) } } else { log.InfoByKv("this user is first login", "", "uid", uid) } - ws.wsConnToUser[conn] = uid - ws.wsUserToConn[uid] = conn - log.WarnByKv("WS Add operation", "", "wsUser added", ws.wsUserToConn, "uid", uid, "online_num", len(ws.wsUserToConn)) + + ws.wsConnToUser.Store(conn, uid) + ws.wsUserToConn.Store(uid, conn) + log.WarnByKv("WS Add operation", "", "wsUser added", ws.wsUserToConn, "uid", uid, "online_num", ws.onlineNum()) } @@ -99,15 +103,16 @@ func (ws *WServer) delUserConn(conn *websocket.Conn) { rwLock.Lock() defer rwLock.Unlock() var uidPlatform string - if uid, ok := ws.wsConnToUser[conn]; ok { + if v, ok := ws.wsConnToUser.Load(conn); ok { + uid := v.(string) uidPlatform = uid - if _, ok = ws.wsUserToConn[uid]; ok { - delete(ws.wsUserToConn, uid) - log.WarnByKv("WS delete operation", "", "wsUser deleted", ws.wsUserToConn, "uid", uid, "online_num", len(ws.wsUserToConn)) + if _, ok := ws.wsUserToConn.Load(uid); ok { + ws.wsUserToConn.Delete(uid) + log.WarnByKv("WS delete operation", "", "wsUser deleted", ws.wsUserToConn, "uid", uid, "online_num", ws.onlineNum()) } else { - log.WarnByKv("uid not exist", "", "wsUser deleted", ws.wsUserToConn, "uid", uid, "online_num", len(ws.wsUserToConn)) + log.WarnByKv("uid not exist", "", "wsUser deleted", ws.wsUserToConn, "uid", uid, "online_num", ws.onlineNum()) } - delete(ws.wsConnToUser, conn) + ws.wsConnToUser.Delete(conn) } err := conn.Close() if err != nil { @@ -119,7 +124,8 @@ func (ws *WServer) delUserConn(conn *websocket.Conn) { func (ws *WServer) getUserConn(uid string) *websocket.Conn { rwLock.RLock() defer rwLock.RUnlock() - if conn, ok := ws.wsUserToConn[uid]; ok { + if v, ok := ws.wsUserToConn.Load(uid); ok { + conn := v.(*websocket.Conn) return conn } return nil @@ -127,8 +133,8 @@ func (ws *WServer) getUserConn(uid string) *websocket.Conn { func (ws *WServer) getUserUid(conn *websocket.Conn) string { rwLock.RLock() defer rwLock.RUnlock() - - if uid, ok := ws.wsConnToUser[conn]; ok { + if v, ok := ws.wsConnToUser.Load(conn); ok { + uid := v.(string) return uid } return "" @@ -154,3 +160,12 @@ func (ws *WServer) headerCheck(w http.ResponseWriter, r *http.Request) bool { } } + +func (ws *WServer) onlineNum() int { + var count int + ws.wsUserToConn.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count +}