fix: GetActiveConversation

pull/2664/head
withchao 1 year ago
parent 4cb56a1326
commit 9edc65867c

@ -196,3 +196,7 @@ require (
golang.org/x/crypto v0.27.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)
replace (
github.com/openimsdk/protocol => /Users/chao/Desktop/withchao/protocol
)

@ -0,0 +1,146 @@
package api
import (
"context"
"github.com/gin-gonic/gin"
"github.com/openimsdk/protocol/conversation"
"github.com/openimsdk/protocol/msg"
"github.com/openimsdk/protocol/sdkws"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/grpc"
"sort"
)
const limitGetActiveConversation = 100
type JSSdk struct {
msg msg.MsgClient
conv conversation.ConversationClient
}
func field[A, B, C any](ctx context.Context, fn func(ctx context.Context, req *A, opts ...grpc.CallOption) (*B, error), req *A, get func(*B) C) (C, error) {
resp, err := fn(ctx, req)
if err != nil {
var c C
return c, err
}
return get(resp), nil
}
func (x *JSSdk) GetActiveConversation(ctx *gin.Context) ([]ConversationMsg, error) {
opUserID := mcontext.GetOpUserID(ctx)
conversationIDs, err := field(ctx, x.conv.GetConversationIDs,
&conversation.GetConversationIDsReq{UserID: opUserID}, (*conversation.GetConversationIDsResp).GetConversationIDs)
if err != nil {
return nil, err
}
if len(conversationIDs) == 0 {
return nil, nil
}
activeConversation, err := field(ctx, x.msg.GetActiveConversation,
&msg.GetActiveConversationReq{ConversationIDs: conversationIDs}, (*msg.GetActiveConversationResp).GetConversations)
if err != nil {
return nil, err
}
if len(activeConversation) == 0 {
return nil, nil
}
sortConversations := sortActiveConversations{
Conversation: activeConversation,
}
if len(activeConversation) > 1 {
// todo get pinned conversation ids
}
sort.Sort(&sortConversations)
sortList := sortConversations.Top(limitGetActiveConversation)
conversations, err := field(ctx, x.conv.GetConversations,
&conversation.GetConversationsReq{ConversationIDs: datautil.Slice(sortList, func(c *msg.ActiveConversation) string {
return c.ConversationID
})}, (*conversation.GetConversationsResp).GetConversations)
if err != nil {
return nil, err
}
readSeq, err := field(ctx, x.msg.GetHasReadSeqs,
&msg.GetHasReadSeqsReq{UserID: opUserID, ConversationIDs: conversationIDs}, (*msg.SeqsInfoResp).GetMaxSeqs)
if err != nil {
return nil, err
}
msgs, err := field(ctx, x.msg.GetSeqMessage,
&msg.GetSeqMessageReq{
UserID: opUserID,
Conversations: datautil.Slice(sortList, func(c *msg.ActiveConversation) *msg.ConversationSeqs {
return &msg.ConversationSeqs{
ConversationID: c.ConversationID,
Seqs: []int64{c.MaxSeq},
}
}),
}, (*msg.GetSeqMessageResp).GetMsgs)
if err != nil {
return nil, err
}
conversationMap := datautil.SliceToMap(conversations, func(c *conversation.Conversation) string {
return c.ConversationID
})
resp := make([]ConversationMsg, 0, len(sortList))
for _, c := range sortList {
conv, ok := conversationMap[c.ConversationID]
if !ok {
continue
}
msgList, ok := msgs[c.ConversationID]
if ok {
continue
}
var lastMsg *sdkws.MsgData
if len(msgList.Msgs) > 0 {
lastMsg = msgList.Msgs[0]
}
resp = append(resp, ConversationMsg{
Conversation: conv,
LastMsg: lastMsg,
MaxSeq: c.MaxSeq,
MaxSeqTime: c.LastTime,
ReadSeq: readSeq[c.ConversationID],
})
}
return resp, nil
}
type ConversationMsg struct {
Conversation *conversation.Conversation `json:"conversation"`
LastMsg *sdkws.MsgData `json:"lastMsg"`
ReadSeq int64 `json:"readSeq"`
MaxSeq int64 `json:"maxSeq"`
MaxSeqTime int64 `json:"maxSeqTime"`
}
type sortActiveConversations struct {
Conversation []*msg.ActiveConversation
PinnedConversationIDs map[string]struct{}
}
func (s sortActiveConversations) Top(limit int) []*msg.ActiveConversation {
if limit > 0 && len(s.Conversation) > limit {
return s.Conversation[:limit]
}
return s.Conversation
}
func (s sortActiveConversations) Len() int {
return len(s.Conversation)
}
func (s sortActiveConversations) Less(i, j int) bool {
iv, jv := s.Conversation[i], s.Conversation[j]
_, ip := s.PinnedConversationIDs[iv.ConversationID]
_, jp := s.PinnedConversationIDs[jv.ConversationID]
if ip != jp {
return ip
}
return iv.LastTime > jv.LastTime
}
func (s sortActiveConversations) Swap(i, j int) {
s.Conversation[i], s.Conversation[j] = s.Conversation[j], s.Conversation[i]
}

@ -0,0 +1,37 @@
package api
import (
"github.com/openimsdk/protocol/msg"
"sort"
"testing"
)
func TestName(t *testing.T) {
val := sortActiveConversations{
Conversation: []*msg.ActiveConversation{
{
ConversationID: "100",
LastTime: 100,
},
{
ConversationID: "200",
LastTime: 200,
},
{
ConversationID: "300",
LastTime: 300,
},
{
ConversationID: "400",
LastTime: 400,
},
},
//PinnedConversationIDs: map[string]struct{}{
// "100": {},
// "300": {},
//},
}
sort.Sort(&val)
t.Log(val)
}

@ -55,7 +55,7 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m
conversationMaxSeqMap[conversation.ConversationID] = conversation.MaxSeq
}
}
maxSeqs, err := m.MsgDatabase.GetMaxSeqs(ctx, conversationIDs)
maxSeqs, err := m.MsgDatabase.GetMaxSeqsWithTime(ctx, conversationIDs)
if err != nil {
return nil, err
}
@ -63,7 +63,8 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m
for conversationID, maxSeq := range maxSeqs {
resp.Seqs[conversationID] = &msg.Seqs{
HasReadSeq: hasReadSeqs[conversationID],
MaxSeq: maxSeq,
MaxSeq: maxSeq.Seq,
MaxSeqTime: maxSeq.Time,
}
if v, ok := conversationMaxSeqMap[conversationID]; ok {
resp.Seqs[conversationID].MaxSeq = v

@ -16,10 +16,10 @@ package msg
import (
"context"
pbmsg "github.com/openimsdk/protocol/msg"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
pbmsg "github.com/openimsdk/protocol/msg"
"sort"
)
func (m *msgServer) GetConversationMaxSeq(ctx context.Context, req *pbmsg.GetConversationMaxSeqReq) (*pbmsg.GetConversationMaxSeqResp, error) {
@ -62,3 +62,25 @@ func (m *msgServer) SetUserConversationsMinSeq(ctx context.Context, req *pbmsg.S
}
return &pbmsg.SetUserConversationsMinSeqResp{}, nil
}
func (m *msgServer) GetActiveConversation(ctx context.Context, req *pbmsg.GetActiveConversationReq) (*pbmsg.GetActiveConversationResp, error) {
res, err := m.MsgDatabase.GetCacheMaxSeqWithTime(ctx, req.ConversationIDs)
if err != nil {
return nil, err
}
conversations := make([]*pbmsg.ActiveConversation, 0, len(res))
for conversationID, val := range res {
conversations = append(conversations, &pbmsg.ActiveConversation{
MaxSeq: val.Seq,
LastTime: val.Time,
ConversationID: conversationID,
})
}
if req.Limit > 0 {
sort.Sort(activeConversations(conversations))
if len(conversations) > int(req.Limit) {
conversations = conversations[:req.Limit]
}
}
return &pbmsg.GetActiveConversationResp{Conversations: conversations}, nil
}

@ -15,6 +15,7 @@
package msg
import (
"github.com/openimsdk/protocol/msg"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/mongo"
@ -28,3 +29,63 @@ func IsNotFound(err error) bool {
return false
}
}
type activeConversations []*msg.ActiveConversation
func (s activeConversations) Len() int {
return len(s)
}
func (s activeConversations) Less(i, j int) bool {
return s[i].LastTime > s[j].LastTime
}
func (s activeConversations) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
//type seqTime struct {
// ConversationID string
// Seq int64
// Time int64
// Unread int64
// Pinned bool
//}
//
//func (s seqTime) String() string {
// return fmt.Sprintf("<Time_%d,Unread_%d,Pinned_%t>", s.Time, s.Unread, s.Pinned)
//}
//
//type seqTimes []seqTime
//
//func (s seqTimes) Len() int {
// return len(s)
//}
//
//// Less sticky priority, unread priority, time descending
//func (s seqTimes) Less(i, j int) bool {
// iv, jv := s[i], s[j]
// if iv.Pinned && (!jv.Pinned) {
// return true
// }
// if jv.Pinned && (!iv.Pinned) {
// return false
// }
// if iv.Unread > 0 && jv.Unread == 0 {
// return true
// }
// if jv.Unread > 0 && iv.Unread == 0 {
// return false
// }
// return iv.Time > jv.Time
//}
//
//func (s seqTimes) Swap(i, j int) {
// s[i], s[j] = s[j], s[i]
//}
//
//type conversationStatus struct {
// ConversationID string
// Pinned bool
// Recv bool
//}

@ -0,0 +1,36 @@
package msg
//func TestName(t *testing.T) {
// arr := seqTimes{
// {
// Time: 100,
// Pinned: true,
// Unread: 1,
// },
// {
// Time: 200,
// Pinned: true,
// Unread: 10,
// },
// {
// Time: 300,
// Pinned: false,
// Unread: 10,
// },
// {
// Time: 100,
// Pinned: false,
// Unread: 0,
// },
// {
// Time: 400,
// Pinned: true,
// Unread: 0,
// },
// }
// rand.Shuffle(len(arr), func(i, j int) {
// arr[i], arr[j] = arr[j], arr[i]
// })
// sort.Sort(arr)
// fmt.Println(arr)
//}

@ -12,6 +12,7 @@ import (
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
"strconv"
"time"
)
@ -57,6 +58,14 @@ func (s *seqConversationCacheRedis) getSingleMaxSeq(ctx context.Context, convers
return map[string]int64{conversationID: seq}, nil
}
func (s *seqConversationCacheRedis) getSingleMaxSeqWithTime(ctx context.Context, conversationID string) (map[string]database.SeqTime, error) {
seq, err := s.GetMaxSeqWithTime(ctx, conversationID)
if err != nil {
return nil, err
}
return map[string]database.SeqTime{conversationID: seq}, nil
}
func (s *seqConversationCacheRedis) batchGetMaxSeq(ctx context.Context, keys []string, keyConversationID map[string]string, seqs map[string]int64) error {
result := make([]*redis.StringCmd, len(keys))
pipe := s.rdb.Pipeline()
@ -88,6 +97,46 @@ func (s *seqConversationCacheRedis) batchGetMaxSeq(ctx context.Context, keys []s
return nil
}
func (s *seqConversationCacheRedis) batchGetMaxSeqWithTime(ctx context.Context, keys []string, keyConversationID map[string]string, seqs map[string]database.SeqTime) error {
result := make([]*redis.SliceCmd, len(keys))
pipe := s.rdb.Pipeline()
for i, key := range keys {
result[i] = pipe.HMGet(ctx, key, "CURR", "TIME")
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return errs.Wrap(err)
}
var notFoundKey []string
for i, r := range result {
val, err := r.Result()
if len(val) != 2 {
return errs.WrapMsg(err, "batchGetMaxSeqWithTime invalid result", "key", keys[i], "res", val)
}
if val[0] == nil {
notFoundKey = append(notFoundKey, keys[i])
continue
}
seq, err := s.parseInt64(val[0])
if err != nil {
return err
}
mill, err := s.parseInt64(val[1])
if err != nil {
return err
}
seqs[keyConversationID[keys[i]]] = database.SeqTime{Seq: seq, Time: mill}
}
for _, key := range notFoundKey {
conversationID := keyConversationID[key]
seq, err := s.GetMaxSeqWithTime(ctx, conversationID)
if err != nil {
return err
}
seqs[conversationID] = seq
}
return nil
}
func (s *seqConversationCacheRedis) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) {
switch len(conversationIDs) {
case 0:
@ -121,6 +170,39 @@ func (s *seqConversationCacheRedis) GetMaxSeqs(ctx context.Context, conversation
return seqs, nil
}
func (s *seqConversationCacheRedis) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
switch len(conversationIDs) {
case 0:
return map[string]database.SeqTime{}, nil
case 1:
return s.getSingleMaxSeqWithTime(ctx, conversationIDs[0])
}
keys := make([]string, 0, len(conversationIDs))
keyConversationID := make(map[string]string, len(conversationIDs))
for _, conversationID := range conversationIDs {
key := s.getSeqMallocKey(conversationID)
if _, ok := keyConversationID[key]; ok {
continue
}
keys = append(keys, key)
keyConversationID[key] = conversationID
}
if len(keys) == 1 {
return s.getSingleMaxSeqWithTime(ctx, conversationIDs[0])
}
slotKeys, err := groupKeysBySlot(ctx, s.rdb, keys)
if err != nil {
return nil, err
}
seqs := make(map[string]database.SeqTime, len(conversationIDs))
for _, keys := range slotKeys {
if err := s.batchGetMaxSeqWithTime(ctx, keys, keyConversationID, seqs); err != nil {
return nil, err
}
}
return seqs, nil
}
func (s *seqConversationCacheRedis) getSeqMallocKey(conversationID string) string {
return cachekey.GetMallocSeqKey(conversationID)
}
@ -318,11 +400,11 @@ func (s *seqConversationCacheRedis) mallocTime(ctx context.Context, conversation
}
if lastSeq == seq {
s.setSeqRetry(ctx, key, states[3], currSeq+size, seq+mallocSize)
return currSeq, 0, nil
return currSeq, states[4], nil
} else {
log.ZWarn(ctx, "malloc seq not equal cache last seq", nil, "conversationID", conversationID, "currSeq", currSeq, "lastSeq", lastSeq, "mallocSeq", seq)
s.setSeqRetry(ctx, key, states[3], seq+size, seq+mallocSize)
return seq, 0, nil
return seq, states[4], nil
}
default:
log.ZError(ctx, "malloc seq unknown state", nil, "state", states[0], "conversationID", conversationID, "size", size)
@ -337,6 +419,14 @@ func (s *seqConversationCacheRedis) GetMaxSeq(ctx context.Context, conversationI
return s.Malloc(ctx, conversationID, 0)
}
func (s *seqConversationCacheRedis) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) {
seq, mill, err := s.mallocTime(ctx, conversationID, 0)
if err != nil {
return database.SeqTime{}, err
}
return database.SeqTime{Seq: seq, Time: mill}, nil
}
func (s *seqConversationCacheRedis) SetMinSeqs(ctx context.Context, seqs map[string]int64) error {
keys := make([]string, 0, len(seqs))
for conversationID, seq := range seqs {
@ -348,6 +438,79 @@ func (s *seqConversationCacheRedis) SetMinSeqs(ctx context.Context, seqs map[str
return DeleteCacheBySlot(ctx, s.rocks, keys)
}
func (s *seqConversationCacheRedis) GetMaxSeqWithTime(ctx context.Context, conversationID string) (int64, int64, error) {
return s.mallocTime(ctx, conversationID, 0)
// GetCacheMaxSeqWithTime only get the existing cache, if there is no cache, no cache will be generated
func (s *seqConversationCacheRedis) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
if len(conversationIDs) == 0 {
return map[string]database.SeqTime{}, nil
}
key2conversationID := make(map[string]string)
keys := make([]string, 0, len(conversationIDs))
for _, conversationID := range conversationIDs {
key := s.getSeqMallocKey(conversationID)
if _, ok := key2conversationID[key]; ok {
continue
}
key2conversationID[key] = conversationID
keys = append(keys, key)
}
slotKeys, err := groupKeysBySlot(ctx, s.rdb, keys)
if err != nil {
return nil, err
}
res := make(map[string]database.SeqTime)
for _, keys := range slotKeys {
if len(keys) == 0 {
continue
}
pipe := s.rdb.Pipeline()
cmds := make([]*redis.SliceCmd, 0, len(keys))
for _, key := range keys {
cmds = append(cmds, pipe.HMGet(ctx, key, "CURR", "TIME"))
}
if _, err := pipe.Exec(ctx); err != nil {
return nil, errs.Wrap(err)
}
for i, cmd := range cmds {
val, err := cmd.Result()
if err != nil {
return nil, err
}
if len(val) != 2 {
return nil, errs.WrapMsg(err, "GetCacheMaxSeqWithTime invalid result", "key", keys[i], "res", val)
}
if val[0] == nil {
continue
}
seq, err := s.parseInt64(val[0])
if err != nil {
return nil, err
}
mill, err := s.parseInt64(val[1])
if err != nil {
return nil, err
}
conversationID := key2conversationID[keys[i]]
res[conversationID] = database.SeqTime{Seq: seq, Time: mill}
}
}
return res, nil
}
func (s *seqConversationCacheRedis) parseInt64(val any) (int64, error) {
switch v := val.(type) {
case nil:
return 0, nil
case int:
return int64(v), nil
case int64:
return v, nil
case string:
res, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, errs.WrapMsg(err, "invalid string not int64", "value", v)
}
return res, nil
default:
return 0, errs.New("invalid result not int64", "resType", fmt.Sprintf("%T", v), "value", v)
}
}

@ -110,10 +110,34 @@ func TestMinSeq(t *testing.T) {
func TestMalloc(t *testing.T) {
ts := newTestSeq()
t.Log(ts.Malloc(context.Background(), "10000000", 100))
t.Log(ts.mallocTime(context.Background(), "10000000", 100))
}
func TestHMGET(t *testing.T) {
ts := newTestSeq()
res, err := ts.GetCacheMaxSeqWithTime(context.Background(), []string{"10000000", "123456"})
if err != nil {
panic(err)
}
t.Log(res)
}
func TestGetMaxSeqWithTime(t *testing.T) {
ts := newTestSeq()
t.Log(ts.GetMaxSeqWithTime(context.Background(), "10000000"))
}
func TestGetMaxSeqWithTime1(t *testing.T) {
ts := newTestSeq()
t.Log(ts.GetMaxSeqsWithTime(context.Background(), []string{"10000000", "12345", "111"}))
}
//
//func TestHMGET(t *testing.T) {
// ts := newTestSeq()
// res, err := ts.rdb.HMGet(context.Background(), "MALLOC_SEQ:1", "CURR", "TIME1").Result()
// if err != nil {
// panic(err)
// }
// t.Log(res)
//}

@ -1,6 +1,9 @@
package cache
import "context"
import (
"context"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
)
type SeqConversationCache interface {
Malloc(ctx context.Context, conversationID string, size int64) (int64, error)
@ -9,4 +12,7 @@ type SeqConversationCache interface {
GetMinSeq(ctx context.Context, conversationID string) (int64, error)
GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error)
SetMinSeqs(ctx context.Context, seqs map[string]int64) error
GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error)
}

@ -74,6 +74,10 @@ type CommonMsgDatabase interface {
GetHasReadSeq(ctx context.Context, userID string, conversationID string) (int64, error)
UserSetHasReadSeqs(ctx context.Context, userID string, hasReadSeqs map[string]int64) error
GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error)
GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
//GetMongoMaxAndMinSeq(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo int64, err error)
//GetConversationMinMaxSeqInMongoAndCache(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo, minSeqCache, maxSeqCache int64, err error)
SetSendMsgStatus(ctx context.Context, id string, status int32) error
@ -866,3 +870,15 @@ func (db *commonMsgDatabase) setMinSeq(ctx context.Context, conversationID strin
func (db *commonMsgDatabase) GetDocIDs(ctx context.Context) ([]string, error) {
return db.msgDocDatabase.GetDocIDs(ctx)
}
func (db *commonMsgDatabase) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
return db.seqConversation.GetCacheMaxSeqWithTime(ctx, conversationIDs)
}
func (db *commonMsgDatabase) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) {
return db.seqConversation.GetMaxSeqWithTime(ctx, conversationID)
}
func (db *commonMsgDatabase) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
return db.seqConversation.GetMaxSeqsWithTime(ctx, conversationIDs)
}

@ -2,6 +2,11 @@ package database
import "context"
type SeqTime struct {
Seq int64
Time int64
}
type SeqConversation interface {
Malloc(ctx context.Context, conversationID string, size int64) (int64, error)
GetMaxSeq(ctx context.Context, conversationID string) (int64, error)

Loading…
Cancel
Save