Merge branch 'OpenIMSDK:main' into main

pull/620/head
pluto 2 years ago committed by GitHub
commit 9ac6a9ed8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1 +0,0 @@
go build -o

@ -100,7 +100,7 @@ services:
openim_server: openim_server:
image: ghcr.io/openimsdk/openim-server:v3.0.1 image: ghcr.io/openimsdk/openim-server:v3.0.0
container_name: openim-server container_name: openim-server
volumes: volumes:
- ./logs:/Open-IM-Server/logs - ./logs:/Open-IM-Server/logs
@ -123,7 +123,7 @@ services:
max-file: "2" max-file: "2"
openim_chat: openim_chat:
image: openim/openim_chat:v1.1.0 image: ghcr.io/openimsdk/openim-chat:v1.0.0
container_name: openim_chat container_name: openim_chat
restart: always restart: always
depends_on: depends_on:

@ -21,7 +21,12 @@ if ! command -v docker >/dev/null 2>&1; then
fi fi
# Start Docker services using docker-compose # Start Docker services using docker-compose
docker-compose up -d if command -v docker-compose &> /dev/null
then
docker-compose up -d
else
docker compose up -d
fi
# Move back to the 'scripts' folder # Move back to the 'scripts' folder
cd scripts cd scripts

@ -214,6 +214,7 @@ func (m *MessageApi) SendMessage(c *gin.Context) {
if err != nil { if err != nil {
log.ZError(c, "decodeData failed", err) log.ZError(c, "decodeData failed", err)
apiresp.GinError(c, err) apiresp.GinError(c, err)
return
} }
sendMsgReq.MsgData.RecvID = req.RecvID sendMsgReq.MsgData.RecvID = req.RecvID
var status int var status int
@ -260,6 +261,7 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) {
if err != nil { if err != nil {
log.ZError(c, "GetAllUserIDs failed", err) log.ZError(c, "GetAllUserIDs failed", err)
apiresp.GinError(c, err) apiresp.GinError(c, err)
return
} }
if len(recvIDsPart) < showNumber { if len(recvIDsPart) < showNumber {
recvIDs = append(recvIDs, recvIDsPart...) recvIDs = append(recvIDs, recvIDsPart...)
@ -275,6 +277,7 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) {
if err != nil { if err != nil {
log.ZError(c, "decodeData failed", err) log.ZError(c, "decodeData failed", err)
apiresp.GinError(c, err) apiresp.GinError(c, err)
return
} }
for _, recvID := range recvIDs { for _, recvID := range recvIDs {
sendMsgReq.MsgData.RecvID = recvID sendMsgReq.MsgData.RecvID = recvID

@ -25,9 +25,7 @@ import (
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/prome" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/prome"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
"github.com/OpenIMSDK/Open-IM-Server/pkg/errs"
"github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msggateway" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/msggateway"
"github.com/OpenIMSDK/Open-IM-Server/pkg/startrpc" "github.com/OpenIMSDK/Open-IM-Server/pkg/startrpc"
"github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils"
@ -84,9 +82,6 @@ func (s *Server) GetUsersOnlineStatus(
ctx context.Context, ctx context.Context,
req *msggateway.GetUsersOnlineStatusReq, req *msggateway.GetUsersOnlineStatusReq,
) (*msggateway.GetUsersOnlineStatusResp, error) { ) (*msggateway.GetUsersOnlineStatusResp, error) {
if !tokenverify.IsAppManagerUid(ctx) {
return nil, errs.ErrNoPermission.Wrap("only app manager")
}
var resp msggateway.GetUsersOnlineStatusResp var resp msggateway.GetUsersOnlineStatusResp
for _, userID := range req.UserIDs { for _, userID := range req.UserIDs {
clients, ok := s.LongConnServer.GetUserAllCons(userID) clients, ok := s.LongConnServer.GetUserAllCons(userID)
@ -181,13 +176,12 @@ func (s *Server) KickUserOffline(
if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok { if clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID)); ok {
for _, client := range clients { for _, client := range clients {
log.ZDebug(ctx, "kick user offline", "userID", v, "platformID", req.PlatformID, "client", client) log.ZDebug(ctx, "kick user offline", "userID", v, "platformID", req.PlatformID, "client", client)
err := client.KickOnlineMessage() if err := client.longConnServer.KickUserConn(client); err != nil {
if err != nil { log.ZWarn(ctx, "kick user offline failed", err, "userID", v, "platformID", req.PlatformID)
return nil, err
} }
} }
} else { } else {
log.ZWarn(ctx, "conn not exist", nil, "userID", v, "platformID", req.PlatformID) log.ZInfo(ctx, "conn not exist", "userID", v, "platformID", req.PlatformID)
} }
} }
return &msggateway.KickUserOfflineResp{}, nil return &msggateway.KickUserOfflineResp{}, nil

@ -47,6 +47,7 @@ type LongConnServer interface {
Validate(s interface{}) error Validate(s interface{}) error
SetCacheHandler(cache cache.MsgModel) SetCacheHandler(cache cache.MsgModel)
SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry) SetDiscoveryRegistry(client discoveryregistry.SvcDiscoveryRegistry)
KickUserConn(client *Client) error
UnRegister(c *Client) UnRegister(c *Client)
Compressor Compressor
Encoder Encoder
@ -145,7 +146,7 @@ func (ws *WsServer) Run() error {
case client = <-ws.unregisterChan: case client = <-ws.unregisterChan:
ws.unregisterClient(client) ws.unregisterClient(client)
case onlineInfo := <-ws.kickHandlerChan: case onlineInfo := <-ws.kickHandlerChan:
ws.multiTerminalLoginChecker(onlineInfo) ws.multiTerminalLoginChecker(onlineInfo.clientOK, onlineInfo.oldClients, onlineInfo.newClient)
} }
} }
}() }()
@ -207,80 +208,77 @@ func getRemoteAdders(client []*Client) string {
return ret return ret
} }
func (ws *WsServer) multiTerminalLoginChecker(info *kickHandler) { func (ws *WsServer) KickUserConn(client *Client) error {
ws.clients.deleteClients(client.UserID, []*Client{client})
return client.KickOnlineMessage()
}
func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) {
switch config.Config.MultiLoginPolicy { switch config.Config.MultiLoginPolicy {
case constant.DefalutNotKick: case constant.DefalutNotKick:
case constant.PCAndOther: case constant.PCAndOther:
if constant.PlatformIDToClass(info.newClient.PlatformID) == constant.TerminalPC { if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC {
return return
} }
fallthrough fallthrough
case constant.AllLoginButSameTermKick: case constant.AllLoginButSameTermKick:
if info.clientOK { if clientOK {
ws.clients.deleteClients(info.newClient.UserID, info.oldClients) ws.clients.deleteClients(newClient.UserID, oldClients)
for _, c := range info.oldClients { for _, c := range oldClients {
err := c.KickOnlineMessage() err := c.KickOnlineMessage()
if err != nil { if err != nil {
log.ZWarn(c.ctx, "KickOnlineMessage", err) log.ZWarn(c.ctx, "KickOnlineMessage", err)
} }
} }
m, err := ws.cache.GetTokensWithoutError( m, err := ws.cache.GetTokensWithoutError(
info.newClient.ctx, newClient.ctx,
info.newClient.UserID, newClient.UserID,
info.newClient.PlatformID, newClient.PlatformID,
) )
if err != nil && err != redis.Nil { if err != nil && err != redis.Nil {
log.ZWarn( log.ZWarn(
info.newClient.ctx, newClient.ctx,
"get token from redis err", "get token from redis err",
err, err,
"userID", "userID",
info.newClient.UserID, newClient.UserID,
"platformID", "platformID",
info.newClient.PlatformID, newClient.PlatformID,
) )
return return
} }
if m == nil { if m == nil {
log.ZWarn( log.ZWarn(
info.newClient.ctx, newClient.ctx,
"m is nil", "m is nil",
errors.New("m is nil"), errors.New("m is nil"),
"userID", "userID",
info.newClient.UserID, newClient.UserID,
"platformID", "platformID",
info.newClient.PlatformID, newClient.PlatformID,
) )
return return
} }
log.ZDebug( log.ZDebug(
info.newClient.ctx, newClient.ctx,
"get token from redis", "get token from redis",
"userID", "userID",
info.newClient.UserID, newClient.UserID,
"platformID", "platformID",
info.newClient.PlatformID, newClient.PlatformID,
"tokenMap", "tokenMap",
m, m,
) )
for k := range m { for k := range m {
if k != info.newClient.ctx.GetToken() { if k != newClient.ctx.GetToken() {
m[k] = constant.KickedToken m[k] = constant.KickedToken
} }
} }
log.ZDebug(info.newClient.ctx, "set token map is ", "token map", m, "userID", info.newClient.UserID) log.ZDebug(newClient.ctx, "set token map is ", "token map", m, "userID", newClient.UserID)
err = ws.cache.SetTokenMapByUidPid(info.newClient.ctx, info.newClient.UserID, info.newClient.PlatformID, m) err = ws.cache.SetTokenMapByUidPid(newClient.ctx, newClient.UserID, newClient.PlatformID, m)
if err != nil { if err != nil {
log.ZWarn( log.ZWarn(newClient.ctx, "SetTokenMapByUidPid err", err, "userID", newClient.UserID, "platformID", newClient.PlatformID)
info.newClient.ctx,
"SetTokenMapByUidPid err",
err,
"userID",
info.newClient.UserID,
"platformID",
info.newClient.PlatformID,
)
return return
} }
} }

@ -23,6 +23,7 @@ import (
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/cache"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/controller" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/controller"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/log"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mcontext"
"github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/tokenverify"
"github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry" "github.com/OpenIMSDK/Open-IM-Server/pkg/discoveryregistry"
@ -129,11 +130,16 @@ func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID
if err != nil { if err != nil {
return err return err
} }
for _, v := range conns {
log.ZDebug(ctx, "forceKickOff", "conn", v.(*grpc.ClientConn).Target())
}
for _, v := range conns { for _, v := range conns {
client := msggateway.NewMsgGatewayClient(v) client := msggateway.NewMsgGatewayClient(v)
kickReq := &msggateway.KickUserOfflineReq{KickUserIDList: []string{userID}, PlatformID: platformID} kickReq := &msggateway.KickUserOfflineReq{KickUserIDList: []string{userID}, PlatformID: platformID}
_, err := client.KickUserOffline(ctx, kickReq) _, err := client.KickUserOffline(ctx, kickReq)
return utils.Wrap(err, "") if err != nil {
log.ZError(ctx, "forceKickOff", err, "kickReq", kickReq)
}
} }
return nil return nil
} }

@ -108,11 +108,11 @@ func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sd
func (m *msgServer) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) (resp *msg.SearchMessageResp, err error) { func (m *msgServer) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) (resp *msg.SearchMessageResp, err error) {
var chatLogs []*sdkws.MsgData var chatLogs []*sdkws.MsgData
var total int32
resp = &msg.SearchMessageResp{} resp = &msg.SearchMessageResp{}
if chatLogs, err = m.MsgDatabase.SearchMessage(ctx, req); err != nil { if total, chatLogs, err = m.MsgDatabase.SearchMessage(ctx, req); err != nil {
return nil, err return nil, err
} }
var num int
for _, chatLog := range chatLogs { for _, chatLog := range chatLogs {
pbChatLog := &msg.ChatLog{} pbChatLog := &msg.ChatLog{}
utils.CopyStructFields(pbChatLog, chatLog) utils.CopyStructFields(pbChatLog, chatLog)
@ -146,9 +146,8 @@ func (m *msgServer) SearchMessage(ctx context.Context, req *msg.SearchMessageReq
pbChatLog.GroupType = group.GroupType pbChatLog.GroupType = group.GroupType
} }
resp.ChatLogs = append(resp.ChatLogs, pbChatLog) resp.ChatLogs = append(resp.ChatLogs, pbChatLog)
num++
} }
resp.ChatLogsNum = int32(num) resp.ChatLogsNum = total
return resp, nil return resp, nil
} }

@ -41,11 +41,11 @@ func StartCronTask() error {
panic(err) panic(err)
} }
log.ZInfo(context.Background(), "start msgDestruct cron task", "cron config", config.Config.MsgDestructTime) log.ZInfo(context.Background(), "start msgDestruct cron task", "cron config", config.Config.MsgDestructTime)
_, err = c.AddFunc(config.Config.MsgDestructTime, msgTool.ConversationsDestructMsgs) // _, err = c.AddFunc(config.Config.MsgDestructTime, msgTool.ConversationsDestructMsgs)
if err != nil { // if err != nil {
fmt.Println("start conversationsDestructMsgs cron failed", err.Error(), config.Config.ChatRecordsClearTime) // fmt.Println("start conversationsDestructMsgs cron failed", err.Error(), config.Config.ChatRecordsClearTime)
panic(err) // panic(err)
} // }
c.Start() c.Start()
wg.Wait() wg.Wait()
return nil return nil

@ -54,7 +54,7 @@ type ConversationCache interface {
// get one conversation from msgCache // get one conversation from msgCache
GetConversation(ctx context.Context, ownerUserID, conversationID string) (*relationTb.ConversationModel, error) GetConversation(ctx context.Context, ownerUserID, conversationID string) (*relationTb.ConversationModel, error)
DelConvsersations(ownerUserID string, conversationIDs ...string) ConversationCache DelConversations(ownerUserID string, conversationIDs ...string) ConversationCache
DelUsersConversation(conversationID string, ownerUserIDs ...string) ConversationCache DelUsersConversation(conversationID string, ownerUserIDs ...string) ConversationCache
// get one conversation from msgCache // get one conversation from msgCache
GetConversations( GetConversations(
@ -225,9 +225,9 @@ func (c *ConversationRedisCache) GetConversation(
) )
} }
func (c *ConversationRedisCache) DelConvsersations(ownerUserID string, convsersationIDs ...string) ConversationCache { func (c *ConversationRedisCache) DelConversations(ownerUserID string, conversationIDs ...string) ConversationCache {
var keys []string var keys []string
for _, conversationID := range convsersationIDs { for _, conversationID := range conversationIDs {
keys = append(keys, c.getConversationKey(ownerUserID, conversationID)) keys = append(keys, c.getConversationKey(ownerUserID, conversationID))
} }
cache := c.NewCache() cache := c.NewCache()

@ -104,7 +104,7 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
if err != nil { if err != nil {
return err return err
} }
cache = cache.DelConversationIDs(NotUserIDs...).DelUserConversationIDsHash(NotUserIDs...).DelConvsersations(conversation.ConversationID, NotUserIDs...) cache = cache.DelConversationIDs(NotUserIDs...).DelUserConversationIDsHash(NotUserIDs...).DelConversations(conversation.ConversationID, NotUserIDs...)
} }
return nil return nil
}); err != nil { }); err != nil {
@ -128,7 +128,7 @@ func (c *conversationDatabase) CreateConversation(ctx context.Context, conversat
var userIDs []string var userIDs []string
cache := c.cache.NewCache() cache := c.cache.NewCache()
for _, conversation := range conversations { for _, conversation := range conversations {
cache = cache.DelConvsersations(conversation.OwnerUserID, conversation.ConversationID) cache = cache.DelConversations(conversation.OwnerUserID, conversation.ConversationID)
userIDs = append(userIDs, conversation.OwnerUserID) userIDs = append(userIDs, conversation.OwnerUserID)
} }
return cache.DelConversationIDs(userIDs...).DelUserConversationIDsHash(userIDs...).ExecDel(ctx) return cache.DelConversationIDs(userIDs...).DelUserConversationIDsHash(userIDs...).ExecDel(ctx)
@ -190,7 +190,7 @@ func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUs
var conversationIDs []string var conversationIDs []string
for _, conversation := range conversations { for _, conversation := range conversations {
conversationIDs = append(conversationIDs, conversation.ConversationID) conversationIDs = append(conversationIDs, conversation.ConversationID)
cache = cache.DelConvsersations(conversation.OwnerUserID, conversation.ConversationID) cache = cache.DelConversations(conversation.OwnerUserID, conversation.ConversationID)
} }
conversationTx := c.conversationDB.NewTx(tx) conversationTx := c.conversationDB.NewTx(tx)
existConversations, err := conversationTx.Find(ctx, ownerUserID, conversationIDs) existConversations, err := conversationTx.Find(ctx, ownerUserID, conversationIDs)
@ -247,7 +247,7 @@ func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context,
for _, v := range notExistUserIDs { for _, v := range notExistUserIDs {
conversation := relationTb.ConversationModel{ConversationType: constant.SuperGroupChatType, GroupID: groupID, OwnerUserID: v, ConversationID: conversationID} conversation := relationTb.ConversationModel{ConversationType: constant.SuperGroupChatType, GroupID: groupID, OwnerUserID: v, ConversationID: conversationID}
conversations = append(conversations, &conversation) conversations = append(conversations, &conversation)
cache = cache.DelConvsersations(v, conversationID) cache = cache.DelConversations(v, conversationID)
} }
cache = cache.DelConversationIDs(notExistUserIDs...).DelUserConversationIDsHash(notExistUserIDs...) cache = cache.DelConversationIDs(notExistUserIDs...).DelUserConversationIDsHash(notExistUserIDs...)
if len(conversations) > 0 { if len(conversations) > 0 {
@ -261,7 +261,7 @@ func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context,
return err return err
} }
for _, v := range existConversationUserIDs { for _, v := range existConversationUserIDs {
cache = cache.DelConvsersations(v, conversationID) cache = cache.DelConversations(v, conversationID)
} }
return nil return nil
}); err != nil { }); err != nil {

@ -92,7 +92,7 @@ type CommonMsgDatabase interface {
GetConversationMinMaxSeqInMongoAndCache(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo, minSeqCache, maxSeqCache int64, err error) GetConversationMinMaxSeqInMongoAndCache(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo, minSeqCache, maxSeqCache int64, err error)
SetSendMsgStatus(ctx context.Context, id string, status int32) error SetSendMsgStatus(ctx context.Context, id string, status int32) error
GetSendMsgStatus(ctx context.Context, id string) (int32, error) GetSendMsgStatus(ctx context.Context, id string) (int32, error)
SearchMessage(ctx context.Context, req *pbMsg.SearchMessageReq) (msgData []*sdkws.MsgData, err error) SearchMessage(ctx context.Context, req *pbMsg.SearchMessageReq) (total int32, msgData []*sdkws.MsgData, err error)
// to mq // to mq
MsgToMQ(ctx context.Context, key string, msg2mq *sdkws.MsgData) error MsgToMQ(ctx context.Context, key string, msg2mq *sdkws.MsgData) error
@ -940,14 +940,14 @@ func (db *commonMsgDatabase) RangeGroupSendCount(
return db.msgDocDatabase.RangeGroupSendCount(ctx, start, end, ase, pageNumber, showNumber) return db.msgDocDatabase.RangeGroupSendCount(ctx, start, end, ase, pageNumber, showNumber)
} }
func (db *commonMsgDatabase) SearchMessage(ctx context.Context, req *pbMsg.SearchMessageReq) (msgData []*sdkws.MsgData, err error) { func (db *commonMsgDatabase) SearchMessage(ctx context.Context, req *pbMsg.SearchMessageReq) (total int32, msgData []*sdkws.MsgData, err error) {
var totalMsgs []*sdkws.MsgData var totalMsgs []*sdkws.MsgData
msgs, err := db.msgDocDatabase.SearchMessage(ctx, req) total, msgs, err := db.msgDocDatabase.SearchMessage(ctx, req)
if err != nil { if err != nil {
return nil, err return 0, nil, err
} }
for _, msg := range msgs { for _, msg := range msgs {
totalMsgs = append(totalMsgs, convert.MsgDB2Pb(msg.Msg)) totalMsgs = append(totalMsgs, convert.MsgDB2Pb(msg.Msg))
} }
return totalMsgs, nil return total, totalMsgs, nil
} }

@ -110,7 +110,7 @@ type MsgDocModelInterface interface {
GetMsgDocModelByIndex(ctx context.Context, conversationID string, index, sort int64) (*MsgDocModel, error) GetMsgDocModelByIndex(ctx context.Context, conversationID string, index, sort int64) (*MsgDocModel, error)
DeleteMsgsInOneDocByIndex(ctx context.Context, docID string, indexes []int) error DeleteMsgsInOneDocByIndex(ctx context.Context, docID string, indexes []int) error
MarkSingleChatMsgsAsRead(ctx context.Context, userID string, docID string, indexes []int64) error MarkSingleChatMsgsAsRead(ctx context.Context, userID string, docID string, indexes []int64) error
SearchMessage(ctx context.Context, req *msg.SearchMessageReq) ([]*MsgInfoModel, error) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) (int32, []*MsgInfoModel, error)
RangeUserSendCount( RangeUserSendCount(
ctx context.Context, ctx context.Context,
start time.Time, start time.Time,

@ -1067,20 +1067,20 @@ func (m *MsgMongoDriver) RangeGroupSendCount(
return result[0].MsgCount, result[0].UserCount, groups, dateCount, nil return result[0].MsgCount, result[0].UserCount, groups, dateCount, nil
} }
func (m *MsgMongoDriver) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) ([]*table.MsgInfoModel, error) { func (m *MsgMongoDriver) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) (int32, []*table.MsgInfoModel, error) {
msgs, err := m.searchMessage(ctx, req) total, msgs, err := m.searchMessage(ctx, req)
if err != nil { if err != nil {
return nil, err return 0, nil, err
} }
for _, msg1 := range msgs { for _, msg1 := range msgs {
if msg1.IsRead { if msg1.IsRead {
msg1.Msg.IsRead = true msg1.Msg.IsRead = true
} }
} }
return msgs, nil return total, msgs, nil
} }
func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessageReq) ([]*table.MsgInfoModel, error) { func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessageReq) (int32, []*table.MsgInfoModel, error) {
var pipe mongo.Pipeline var pipe mongo.Pipeline
condition := bson.A{} condition := bson.A{}
if req.SendTime != "" { if req.SendTime != "" {
@ -1153,16 +1153,16 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
} }
cursor, err := m.MsgCollection.Aggregate(ctx, pipe) cursor, err := m.MsgCollection.Aggregate(ctx, pipe)
if err != nil { if err != nil {
return nil, err return 0, nil, err
} }
var msgsDocs []table.MsgDocModel var msgsDocs []table.MsgDocModel
err = cursor.All(ctx, &msgsDocs) err = cursor.All(ctx, &msgsDocs)
if err != nil { if err != nil {
return nil, err return 0, nil, err
} }
if len(msgsDocs) == 0 { if len(msgsDocs) == 0 {
return nil, errs.Wrap(mongo.ErrNoDocuments) return 0, nil, errs.Wrap(mongo.ErrNoDocuments)
} }
msgs := make([]*table.MsgInfoModel, 0) msgs := make([]*table.MsgInfoModel, 0)
for index := range msgsDocs { for index := range msgsDocs {
@ -1187,14 +1187,14 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
} }
data, err := json.Marshal(&revokeContent) data, err := json.Marshal(&revokeContent)
if err != nil { if err != nil {
return nil, err return 0, nil, err
} }
elem := sdkws.NotificationElem{ elem := sdkws.NotificationElem{
Detail: string(data), Detail: string(data),
} }
content, err := json.Marshal(&elem) content, err := json.Marshal(&elem)
if err != nil { if err != nil {
return nil, err return 0, nil, err
} }
msg.Msg.ContentType = constant.MsgRevokeNotification msg.Msg.ContentType = constant.MsgRevokeNotification
msg.Msg.Content = string(content) msg.Msg.Content = string(content)
@ -1209,5 +1209,5 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
} else { } else {
msgs = msgs[start:] msgs = msgs[start:]
} }
return msgs, nil return n, msgs, nil
} }

@ -126,6 +126,11 @@ func (x *MarkMsgsAsReadReq) Check() error {
if x.UserID == "" { if x.UserID == "" {
return errs.ErrArgs.Wrap("userID is empty") return errs.ErrArgs.Wrap("userID is empty")
} }
for _, seq := range x.Seqs {
if seq == 0 {
return errs.ErrArgs.Wrap("seqs has 0 value is invalid")
}
}
return nil return nil
} }
@ -139,6 +144,11 @@ func (x *MarkConversationAsReadReq) Check() error {
if x.HasReadSeq < 1 { if x.HasReadSeq < 1 {
return errs.ErrArgs.Wrap("hasReadSeq is invalid") return errs.ErrArgs.Wrap("hasReadSeq is invalid")
} }
for _, seq := range x.Seqs {
if seq == 0 {
return errs.ErrArgs.Wrap("seqs has 0 value is invalid")
}
}
return nil return nil
} }

Loading…
Cancel
Save