diff --git a/go.mod b/go.mod index 14bd24ad5..5cc1f9ad3 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/openimsdk/protocol v0.0.72-alpha.29 + github.com/openimsdk/protocol v0.0.72-alpha.30 github.com/openimsdk/tools v0.0.50-alpha.12 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.18.0 diff --git a/go.sum b/go.sum index aeab03055..1095134dd 100644 --- a/go.sum +++ b/go.sum @@ -319,8 +319,8 @@ github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= github.com/onsi/gomega v1.25.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM= github.com/openimsdk/gomake v0.0.14-alpha.5 h1:VY9c5x515lTfmdhhPjMvR3BBRrRquAUCFsz7t7vbv7Y= github.com/openimsdk/gomake v0.0.14-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= -github.com/openimsdk/protocol v0.0.72-alpha.29 h1:z6Bm57IW/HNxTAJmqYjhVaLRUJLVIK0EH7G7HBzbwdc= -github.com/openimsdk/protocol v0.0.72-alpha.29/go.mod h1:OZQA9FR55lseYoN2Ql1XAHYKHJGu7OMNkUbuekrKCM8= +github.com/openimsdk/protocol v0.0.72-alpha.30 h1:LBIqDzD55cSQy3wX8fgSa3blz8+Cv54ae96/qUMINwM= +github.com/openimsdk/protocol v0.0.72-alpha.30/go.mod h1:OZQA9FR55lseYoN2Ql1XAHYKHJGu7OMNkUbuekrKCM8= github.com/openimsdk/tools v0.0.50-alpha.12 h1:rV3BxgqN+F79vZvdoQ+97Eob8ScsRVEM8D+Wrcl23uo= github.com/openimsdk/tools v0.0.50-alpha.12/go.mod h1:h1cYmfyaVtgFbKmb1Cfsl8XwUOMTt8ubVUQrdGtsUh4= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= diff --git a/internal/api/jssdk/jssdk.go b/internal/api/jssdk/jssdk.go new file mode 100644 index 000000000..7f136c74c --- /dev/null +++ b/internal/api/jssdk/jssdk.go @@ -0,0 +1,204 @@ +package jssdk + +import ( + "github.com/gin-gonic/gin" + "github.com/openimsdk/protocol/conversation" + "github.com/openimsdk/protocol/msg" + "github.com/openimsdk/protocol/sdkws" + "github.com/openimsdk/tools/a2r" + "github.com/openimsdk/tools/mcontext" + "github.com/openimsdk/tools/utils/datautil" + "sort" +) + +const ( + maxGetActiveConversation = 500 + defaultGetActiveConversation = 100 +) + +func NewJSSdkApi(msg msg.MsgClient, conv conversation.ConversationClient) *JSSdk { + return &JSSdk{ + msg: msg, + conv: conv, + } +} + +type JSSdk struct { + msg msg.MsgClient + conv conversation.ConversationClient +} + +func (x *JSSdk) GetActiveConversations(c *gin.Context) { + call(c, x.getActiveConversations) +} + +func (x *JSSdk) GetConversations(c *gin.Context) { + call(c, x.getConversations) +} + +func (x *JSSdk) getActiveConversations(ctx *gin.Context) (*ConversationsResp, error) { + req, err := a2r.ParseRequest[ActiveConversationsReq](ctx) + if err != nil { + return nil, err + } + if req.Count <= 0 || req.Count > maxGetActiveConversation { + req.Count = defaultGetActiveConversation + } + opUserID := mcontext.GetOpUserID(ctx) + conversationIDs, err := field(ctx, x.conv.GetConversationIDs, + &conversation.GetConversationIDsReq{UserID: opUserID}, (*conversation.GetConversationIDsResp).GetConversationIDs) + if err != nil { + return nil, err + } + if len(conversationIDs) == 0 { + return &ConversationsResp{}, nil + } + readSeq, err := field(ctx, x.msg.GetHasReadSeqs, + &msg.GetHasReadSeqsReq{UserID: opUserID, ConversationIDs: conversationIDs}, (*msg.SeqsInfoResp).GetMaxSeqs) + if err != nil { + return nil, err + } + activeConversation, err := field(ctx, x.msg.GetActiveConversation, + &msg.GetActiveConversationReq{ConversationIDs: conversationIDs}, (*msg.GetActiveConversationResp).GetConversations) + if err != nil { + return nil, err + } + if len(activeConversation) == 0 { + return &ConversationsResp{}, nil + } + sortConversations := sortActiveConversations{ + Conversation: activeConversation, + } + if len(activeConversation) > 1 { + pinnedConversationIDs, err := field(ctx, x.conv.GetPinnedConversationIDs, + &conversation.GetPinnedConversationIDsReq{UserID: opUserID}, (*conversation.GetPinnedConversationIDsResp).GetConversationIDs) + if err != nil { + return nil, err + } + sortConversations.PinnedConversationIDs = datautil.SliceSet(pinnedConversationIDs) + } + sort.Sort(&sortConversations) + sortList := sortConversations.Top(req.Count) + conversations, err := field(ctx, x.conv.GetConversations, + &conversation.GetConversationsReq{ + OwnerUserID: opUserID, + ConversationIDs: datautil.Slice(sortList, func(c *msg.ActiveConversation) string { + return c.ConversationID + })}, (*conversation.GetConversationsResp).GetConversations) + if err != nil { + return nil, err + } + msgs, err := field(ctx, x.msg.GetSeqMessage, + &msg.GetSeqMessageReq{ + UserID: opUserID, + Conversations: datautil.Slice(sortList, func(c *msg.ActiveConversation) *msg.ConversationSeqs { + return &msg.ConversationSeqs{ + ConversationID: c.ConversationID, + Seqs: []int64{c.MaxSeq}, + } + }), + }, (*msg.GetSeqMessageResp).GetMsgs) + if err != nil { + return nil, err + } + conversationMap := datautil.SliceToMap(conversations, func(c *conversation.Conversation) string { + return c.ConversationID + }) + resp := make([]ConversationMsg, 0, len(sortList)) + for _, c := range sortList { + conv, ok := conversationMap[c.ConversationID] + if !ok { + continue + } + var lastMsg *sdkws.MsgData + if msgList, ok := msgs[c.ConversationID]; ok && len(msgList.Msgs) > 0 { + lastMsg = msgList.Msgs[0] + } + resp = append(resp, ConversationMsg{ + Conversation: conv, + LastMsg: lastMsg, + MaxSeq: c.MaxSeq, + ReadSeq: readSeq[c.ConversationID], + }) + } + var unreadCount int64 + for _, c := range activeConversation { + count := c.MaxSeq - readSeq[c.ConversationID] + if count > 0 { + unreadCount += count + } + } + return &ConversationsResp{ + Conversations: resp, + UnreadCount: unreadCount, + }, nil +} + +func (x *JSSdk) getConversations(ctx *gin.Context) (*ConversationsResp, error) { + req, err := a2r.ParseRequest[conversation.GetConversationsReq](ctx) + if err != nil { + return nil, err + } + req.OwnerUserID = mcontext.GetOpUserID(ctx) + conversations, err := field(ctx, x.conv.GetConversations, req, (*conversation.GetConversationsResp).GetConversations) + if err != nil { + return nil, err + } + if len(conversations) == 0 { + return &ConversationsResp{}, nil + } + req.ConversationIDs = datautil.Slice(conversations, func(c *conversation.Conversation) string { + return c.ConversationID + }) + maxSeqs, err := field(ctx, x.msg.GetMaxSeqs, + &msg.GetMaxSeqsReq{ConversationIDs: req.ConversationIDs}, (*msg.SeqsInfoResp).GetMaxSeqs) + if err != nil { + return nil, err + } + readSeqs, err := field(ctx, x.msg.GetHasReadSeqs, + &msg.GetHasReadSeqsReq{UserID: req.OwnerUserID, ConversationIDs: req.ConversationIDs}, (*msg.SeqsInfoResp).GetMaxSeqs) + if err != nil { + return nil, err + } + conversationSeqs := make([]*msg.ConversationSeqs, 0, len(conversations)) + for _, c := range conversations { + if seq := maxSeqs[c.ConversationID]; seq > 0 { + conversationSeqs = append(conversationSeqs, &msg.ConversationSeqs{ + ConversationID: c.ConversationID, + Seqs: []int64{seq}, + }) + } + } + var msgs map[string]*sdkws.PullMsgs + if len(conversationSeqs) > 0 { + msgs, err = field(ctx, x.msg.GetSeqMessage, + &msg.GetSeqMessageReq{UserID: req.OwnerUserID, Conversations: conversationSeqs}, (*msg.GetSeqMessageResp).GetMsgs) + if err != nil { + return nil, err + } + } + resp := make([]ConversationMsg, 0, len(conversations)) + for _, c := range conversations { + var lastMsg *sdkws.MsgData + if msgList, ok := msgs[c.ConversationID]; ok && len(msgList.Msgs) > 0 { + lastMsg = msgList.Msgs[0] + } + resp = append(resp, ConversationMsg{ + Conversation: c, + LastMsg: lastMsg, + MaxSeq: maxSeqs[c.ConversationID], + ReadSeq: readSeqs[c.ConversationID], + }) + } + var unreadCount int64 + for conversationID, maxSeq := range maxSeqs { + count := maxSeq - readSeqs[conversationID] + if count > 0 { + unreadCount += count + } + } + return &ConversationsResp{ + Conversations: resp, + UnreadCount: unreadCount, + }, nil +} diff --git a/internal/api/jssdk/sort.go b/internal/api/jssdk/sort.go new file mode 100644 index 000000000..f5fd04148 --- /dev/null +++ b/internal/api/jssdk/sort.go @@ -0,0 +1,33 @@ +package jssdk + +import "github.com/openimsdk/protocol/msg" + +type sortActiveConversations struct { + Conversation []*msg.ActiveConversation + PinnedConversationIDs map[string]struct{} +} + +func (s sortActiveConversations) Top(limit int) []*msg.ActiveConversation { + if limit > 0 && len(s.Conversation) > limit { + return s.Conversation[:limit] + } + return s.Conversation +} + +func (s sortActiveConversations) Len() int { + return len(s.Conversation) +} + +func (s sortActiveConversations) Less(i, j int) bool { + iv, jv := s.Conversation[i], s.Conversation[j] + _, ip := s.PinnedConversationIDs[iv.ConversationID] + _, jp := s.PinnedConversationIDs[jv.ConversationID] + if ip != jp { + return ip + } + return iv.LastTime > jv.LastTime +} + +func (s sortActiveConversations) Swap(i, j int) { + s.Conversation[i], s.Conversation[j] = s.Conversation[j], s.Conversation[i] +} diff --git a/internal/api/jssdk/stu.go b/internal/api/jssdk/stu.go new file mode 100644 index 000000000..2f63975b3 --- /dev/null +++ b/internal/api/jssdk/stu.go @@ -0,0 +1,22 @@ +package jssdk + +import ( + "github.com/openimsdk/protocol/conversation" + "github.com/openimsdk/protocol/sdkws" +) + +type ActiveConversationsReq struct { + Count int `json:"count"` +} + +type ConversationMsg struct { + Conversation *conversation.Conversation `json:"conversation"` + LastMsg *sdkws.MsgData `json:"lastMsg"` + MaxSeq int64 `json:"maxSeq"` + ReadSeq int64 `json:"readSeq"` +} + +type ConversationsResp struct { + UnreadCount int64 `json:"unreadCount"` + Conversations []ConversationMsg `json:"conversations"` +} diff --git a/internal/api/jssdk/tools.go b/internal/api/jssdk/tools.go new file mode 100644 index 000000000..c57457d9f --- /dev/null +++ b/internal/api/jssdk/tools.go @@ -0,0 +1,26 @@ +package jssdk + +import ( + "context" + "github.com/gin-gonic/gin" + "github.com/openimsdk/tools/apiresp" + "google.golang.org/grpc" +) + +func field[A, B, C any](ctx context.Context, fn func(ctx context.Context, req *A, opts ...grpc.CallOption) (*B, error), req *A, get func(*B) C) (C, error) { + resp, err := fn(ctx, req) + if err != nil { + var c C + return c, err + } + return get(resp), nil +} + +func call[R any](c *gin.Context, fn func(ctx *gin.Context) (R, error)) { + resp, err := fn(c) + if err != nil { + apiresp.GinError(c, err) + return + } + apiresp.GinSuccess(c, resp) +} diff --git a/internal/api/jssdk_test.go b/internal/api/jssdk_test.go new file mode 100644 index 000000000..472ca56b5 --- /dev/null +++ b/internal/api/jssdk_test.go @@ -0,0 +1,37 @@ +package api + +import ( + "github.com/openimsdk/protocol/msg" + "sort" + "testing" +) + +func TestName(t *testing.T) { + val := sortActiveConversations{ + Conversation: []*msg.ActiveConversation{ + { + ConversationID: "100", + LastTime: 100, + }, + { + ConversationID: "200", + LastTime: 200, + }, + { + ConversationID: "300", + LastTime: 300, + }, + { + ConversationID: "400", + LastTime: 400, + }, + }, + //PinnedConversationIDs: map[string]struct{}{ + // "100": {}, + // "300": {}, + //}, + } + sort.Sort(&val) + t.Log(val) + +} diff --git a/internal/api/router.go b/internal/api/router.go index 91e45340e..8d2a688f4 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -2,6 +2,7 @@ package api import ( "fmt" + "github.com/openimsdk/open-im-server/v3/internal/api/jssdk" "github.com/gin-contrib/gzip" @@ -75,6 +76,7 @@ func newGinRouter(disCov discovery.SvcDiscoveryRegistry, config *Config) *gin.En r.Use(prommetricsGin(), gin.Recovery(), mw.CorsHandler(), mw.GinParseOperationID(), GinParseToken(authRpc)) u := NewUserApi(*userRpc) m := NewMessageApi(messageRpc, userRpc, config.Share.IMAdminUserID) + j := jssdk.NewJSSdkApi(messageRpc.Client, conversationRpc.Client) userRouterGroup := r.Group("/user") { userRouterGroup.POST("/user_register", u.UserRegister) @@ -244,6 +246,11 @@ func newGinRouter(disCov discovery.SvcDiscoveryRegistry, config *Config) *gin.En statisticsGroup.POST("/group/create", g.GroupCreateCount) statisticsGroup.POST("/group/active", m.GetActiveGroup) } + + jssdk := r.Group("/jssdk") + jssdk.POST("/get_conversations", j.GetConversations) + jssdk.POST("/get_active_conversations", j.GetActiveConversations) + return r } diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index dd296e481..1b70c81fc 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -66,6 +66,11 @@ type groupServer struct { webhookClient *webhook.Client } +func (s *groupServer) GetSpecifiedUserGroupRequestInfo(ctx context.Context, req *pbgroup.GetSpecifiedUserGroupRequestInfoReq) (*pbgroup.GetSpecifiedUserGroupRequestInfoResp, error) { + //TODO implement me + panic("implement me") +} + type Config struct { RpcConfig config.Group RedisConfig config.Redis diff --git a/internal/rpc/msg/as_read.go b/internal/rpc/msg/as_read.go index bfba4824f..03f35b42d 100644 --- a/internal/rpc/msg/as_read.go +++ b/internal/rpc/msg/as_read.go @@ -55,7 +55,7 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m conversationMaxSeqMap[conversation.ConversationID] = conversation.MaxSeq } } - maxSeqs, err := m.MsgDatabase.GetMaxSeqs(ctx, conversationIDs) + maxSeqs, err := m.MsgDatabase.GetMaxSeqsWithTime(ctx, conversationIDs) if err != nil { return nil, err } @@ -63,7 +63,8 @@ func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *m for conversationID, maxSeq := range maxSeqs { resp.Seqs[conversationID] = &msg.Seqs{ HasReadSeq: hasReadSeqs[conversationID], - MaxSeq: maxSeq, + MaxSeq: maxSeq.Seq, + MaxSeqTime: maxSeq.Time, } if v, ok := conversationMaxSeqMap[conversationID]; ok { resp.Seqs[conversationID].MaxSeq = v diff --git a/internal/rpc/msg/seq.go b/internal/rpc/msg/seq.go index 4d9eb6db9..5d40160de 100644 --- a/internal/rpc/msg/seq.go +++ b/internal/rpc/msg/seq.go @@ -16,10 +16,10 @@ package msg import ( "context" + pbmsg "github.com/openimsdk/protocol/msg" "github.com/openimsdk/tools/errs" "github.com/redis/go-redis/v9" - - pbmsg "github.com/openimsdk/protocol/msg" + "sort" ) func (m *msgServer) GetConversationMaxSeq(ctx context.Context, req *pbmsg.GetConversationMaxSeqReq) (*pbmsg.GetConversationMaxSeqResp, error) { @@ -62,3 +62,25 @@ func (m *msgServer) SetUserConversationsMinSeq(ctx context.Context, req *pbmsg.S } return &pbmsg.SetUserConversationsMinSeqResp{}, nil } + +func (m *msgServer) GetActiveConversation(ctx context.Context, req *pbmsg.GetActiveConversationReq) (*pbmsg.GetActiveConversationResp, error) { + res, err := m.MsgDatabase.GetCacheMaxSeqWithTime(ctx, req.ConversationIDs) + if err != nil { + return nil, err + } + conversations := make([]*pbmsg.ActiveConversation, 0, len(res)) + for conversationID, val := range res { + conversations = append(conversations, &pbmsg.ActiveConversation{ + MaxSeq: val.Seq, + LastTime: val.Time, + ConversationID: conversationID, + }) + } + if req.Limit > 0 { + sort.Sort(activeConversations(conversations)) + if len(conversations) > int(req.Limit) { + conversations = conversations[:req.Limit] + } + } + return &pbmsg.GetActiveConversationResp{Conversations: conversations}, nil +} diff --git a/internal/rpc/msg/utils.go b/internal/rpc/msg/utils.go index 69b4d0bf6..e3490848c 100644 --- a/internal/rpc/msg/utils.go +++ b/internal/rpc/msg/utils.go @@ -15,6 +15,7 @@ package msg import ( + "github.com/openimsdk/protocol/msg" "github.com/openimsdk/tools/errs" "github.com/redis/go-redis/v9" "go.mongodb.org/mongo-driver/mongo" @@ -28,3 +29,63 @@ func IsNotFound(err error) bool { return false } } + +type activeConversations []*msg.ActiveConversation + +func (s activeConversations) Len() int { + return len(s) +} + +func (s activeConversations) Less(i, j int) bool { + return s[i].LastTime > s[j].LastTime +} + +func (s activeConversations) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +//type seqTime struct { +// ConversationID string +// Seq int64 +// Time int64 +// Unread int64 +// Pinned bool +//} +// +//func (s seqTime) String() string { +// return fmt.Sprintf("", s.Time, s.Unread, s.Pinned) +//} +// +//type seqTimes []seqTime +// +//func (s seqTimes) Len() int { +// return len(s) +//} +// +//// Less sticky priority, unread priority, time descending +//func (s seqTimes) Less(i, j int) bool { +// iv, jv := s[i], s[j] +// if iv.Pinned && (!jv.Pinned) { +// return true +// } +// if jv.Pinned && (!iv.Pinned) { +// return false +// } +// if iv.Unread > 0 && jv.Unread == 0 { +// return true +// } +// if jv.Unread > 0 && iv.Unread == 0 { +// return false +// } +// return iv.Time > jv.Time +//} +// +//func (s seqTimes) Swap(i, j int) { +// s[i], s[j] = s[j], s[i] +//} +// +//type conversationStatus struct { +// ConversationID string +// Pinned bool +// Recv bool +//} diff --git a/pkg/common/storage/cache/redis/seq_conversation.go b/pkg/common/storage/cache/redis/seq_conversation.go index 7fe849193..71705cef7 100644 --- a/pkg/common/storage/cache/redis/seq_conversation.go +++ b/pkg/common/storage/cache/redis/seq_conversation.go @@ -12,6 +12,7 @@ import ( "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" "github.com/redis/go-redis/v9" + "strconv" "time" ) @@ -57,6 +58,14 @@ func (s *seqConversationCacheRedis) getSingleMaxSeq(ctx context.Context, convers return map[string]int64{conversationID: seq}, nil } +func (s *seqConversationCacheRedis) getSingleMaxSeqWithTime(ctx context.Context, conversationID string) (map[string]database.SeqTime, error) { + seq, err := s.GetMaxSeqWithTime(ctx, conversationID) + if err != nil { + return nil, err + } + return map[string]database.SeqTime{conversationID: seq}, nil +} + func (s *seqConversationCacheRedis) batchGetMaxSeq(ctx context.Context, keys []string, keyConversationID map[string]string, seqs map[string]int64) error { result := make([]*redis.StringCmd, len(keys)) pipe := s.rdb.Pipeline() @@ -88,6 +97,46 @@ func (s *seqConversationCacheRedis) batchGetMaxSeq(ctx context.Context, keys []s return nil } +func (s *seqConversationCacheRedis) batchGetMaxSeqWithTime(ctx context.Context, keys []string, keyConversationID map[string]string, seqs map[string]database.SeqTime) error { + result := make([]*redis.SliceCmd, len(keys)) + pipe := s.rdb.Pipeline() + for i, key := range keys { + result[i] = pipe.HMGet(ctx, key, "CURR", "TIME") + } + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return errs.Wrap(err) + } + var notFoundKey []string + for i, r := range result { + val, err := r.Result() + if len(val) != 2 { + return errs.WrapMsg(err, "batchGetMaxSeqWithTime invalid result", "key", keys[i], "res", val) + } + if val[0] == nil { + notFoundKey = append(notFoundKey, keys[i]) + continue + } + seq, err := s.parseInt64(val[0]) + if err != nil { + return err + } + mill, err := s.parseInt64(val[1]) + if err != nil { + return err + } + seqs[keyConversationID[keys[i]]] = database.SeqTime{Seq: seq, Time: mill} + } + for _, key := range notFoundKey { + conversationID := keyConversationID[key] + seq, err := s.GetMaxSeqWithTime(ctx, conversationID) + if err != nil { + return err + } + seqs[conversationID] = seq + } + return nil +} + func (s *seqConversationCacheRedis) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) { switch len(conversationIDs) { case 0: @@ -121,11 +170,44 @@ func (s *seqConversationCacheRedis) GetMaxSeqs(ctx context.Context, conversation return seqs, nil } +func (s *seqConversationCacheRedis) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) { + switch len(conversationIDs) { + case 0: + return map[string]database.SeqTime{}, nil + case 1: + return s.getSingleMaxSeqWithTime(ctx, conversationIDs[0]) + } + keys := make([]string, 0, len(conversationIDs)) + keyConversationID := make(map[string]string, len(conversationIDs)) + for _, conversationID := range conversationIDs { + key := s.getSeqMallocKey(conversationID) + if _, ok := keyConversationID[key]; ok { + continue + } + keys = append(keys, key) + keyConversationID[key] = conversationID + } + if len(keys) == 1 { + return s.getSingleMaxSeqWithTime(ctx, conversationIDs[0]) + } + slotKeys, err := groupKeysBySlot(ctx, s.rdb, keys) + if err != nil { + return nil, err + } + seqs := make(map[string]database.SeqTime, len(conversationIDs)) + for _, keys := range slotKeys { + if err := s.batchGetMaxSeqWithTime(ctx, keys, keyConversationID, seqs); err != nil { + return nil, err + } + } + return seqs, nil +} + func (s *seqConversationCacheRedis) getSeqMallocKey(conversationID string) string { return cachekey.GetMallocSeqKey(conversationID) } -func (s *seqConversationCacheRedis) setSeq(ctx context.Context, key string, owner int64, currSeq int64, lastSeq int64) (int64, error) { +func (s *seqConversationCacheRedis) setSeq(ctx context.Context, key string, owner int64, currSeq int64, lastSeq int64, mill int64) (int64, error) { if lastSeq < currSeq { return 0, errs.New("lastSeq must be greater than currSeq") } @@ -138,8 +220,9 @@ local lockValue = ARGV[1] local dataSecond = ARGV[2] local curr_seq = tonumber(ARGV[3]) local last_seq = tonumber(ARGV[4]) +local mallocTime = ARGV[5] if redis.call("EXISTS", key) == 0 then - redis.call("HSET", key, "CURR", curr_seq, "LAST", last_seq) + redis.call("HSET", key, "CURR", curr_seq, "LAST", last_seq, "TIME", mallocTime) redis.call("EXPIRE", key, dataSecond) return 1 end @@ -147,11 +230,11 @@ if redis.call("HGET", key, "LOCK") ~= lockValue then return 2 end redis.call("HDEL", key, "LOCK") -redis.call("HSET", key, "CURR", curr_seq, "LAST", last_seq) +redis.call("HSET", key, "CURR", curr_seq, "LAST", last_seq, "TIME", mallocTime) redis.call("EXPIRE", key, dataSecond) return 0 ` - result, err := s.rdb.Eval(ctx, script, []string{key}, owner, int64(s.dataTime/time.Second), currSeq, lastSeq).Int64() + result, err := s.rdb.Eval(ctx, script, []string{key}, owner, int64(s.dataTime/time.Second), currSeq, lastSeq, mill).Int64() if err != nil { return 0, errs.Wrap(err) } @@ -169,6 +252,7 @@ local key = KEYS[1] local size = tonumber(ARGV[1]) local lockSecond = ARGV[2] local dataSecond = ARGV[3] +local mallocTime = ARGV[4] local result = {} if redis.call("EXISTS", key) == 0 then local lockValue = math.random(0, 999999999) @@ -176,6 +260,7 @@ if redis.call("EXISTS", key) == 0 then redis.call("EXPIRE", key, lockSecond) table.insert(result, 1) table.insert(result, lockValue) + table.insert(result, mallocTime) return result end if redis.call("HEXISTS", key, "LOCK") == 1 then @@ -189,6 +274,12 @@ if size == 0 then table.insert(result, 0) table.insert(result, curr_seq) table.insert(result, last_seq) + local setTime = redis.call("HGET", key, "TIME") + if setTime then + table.insert(result, setTime) + else + table.insert(result, 0) + end return result end local max_seq = curr_seq + size @@ -196,21 +287,25 @@ if max_seq > last_seq then local lockValue = math.random(0, 999999999) redis.call("HSET", key, "LOCK", lockValue) redis.call("HSET", key, "CURR", last_seq) + redis.call("HSET", key, "TIME", mallocTime) redis.call("EXPIRE", key, lockSecond) table.insert(result, 3) table.insert(result, curr_seq) table.insert(result, last_seq) table.insert(result, lockValue) + table.insert(result, mallocTime) return result end redis.call("HSET", key, "CURR", max_seq) +redis.call("HSET", key, "TIME", ARGV[4]) redis.call("EXPIRE", key, dataSecond) table.insert(result, 0) table.insert(result, curr_seq) table.insert(result, last_seq) +table.insert(result, mallocTime) return result ` - result, err := s.rdb.Eval(ctx, script, []string{key}, size, int64(s.lockTime/time.Second), int64(s.dataTime/time.Second)).Int64Slice() + result, err := s.rdb.Eval(ctx, script, []string{key}, size, int64(s.lockTime/time.Second), int64(s.dataTime/time.Second), time.Now().UnixMilli()).Int64Slice() if err != nil { return nil, errs.Wrap(err) } @@ -228,9 +323,9 @@ func (s *seqConversationCacheRedis) wait(ctx context.Context) error { } } -func (s *seqConversationCacheRedis) setSeqRetry(ctx context.Context, key string, owner int64, currSeq int64, lastSeq int64) { +func (s *seqConversationCacheRedis) setSeqRetry(ctx context.Context, key string, owner int64, currSeq int64, lastSeq int64, mill int64) { for i := 0; i < 10; i++ { - state, err := s.setSeq(ctx, key, owner, currSeq, lastSeq) + state, err := s.setSeq(ctx, key, owner, currSeq, lastSeq, mill) if err != nil { log.ZError(ctx, "set seq cache failed", err, "key", key, "owner", owner, "currSeq", currSeq, "lastSeq", lastSeq, "count", i+1) if err := s.wait(ctx); err != nil { @@ -267,60 +362,74 @@ func (s *seqConversationCacheRedis) getMallocSize(conversationID string, size in } func (s *seqConversationCacheRedis) Malloc(ctx context.Context, conversationID string, size int64) (int64, error) { + seq, _, err := s.mallocTime(ctx, conversationID, size) + return seq, err +} + +func (s *seqConversationCacheRedis) mallocTime(ctx context.Context, conversationID string, size int64) (int64, int64, error) { if size < 0 { - return 0, errs.New("size must be greater than 0") + return 0, 0, errs.New("size must be greater than 0") } key := s.getSeqMallocKey(conversationID) for i := 0; i < 10; i++ { states, err := s.malloc(ctx, key, size) if err != nil { - return 0, err + return 0, 0, err } switch states[0] { case 0: // success - return states[1], nil + return states[1], states[3], nil case 1: // not found mallocSize := s.getMallocSize(conversationID, size) seq, err := s.mgo.Malloc(ctx, conversationID, mallocSize) if err != nil { - return 0, err + return 0, 0, err } - s.setSeqRetry(ctx, key, states[1], seq+size, seq+mallocSize) - return seq, nil + s.setSeqRetry(ctx, key, states[1], seq+size, seq+mallocSize, states[2]) + return seq, 0, nil case 2: // locked if err := s.wait(ctx); err != nil { - return 0, err + return 0, 0, err } continue case 3: // exceeded cache max value currSeq := states[1] lastSeq := states[2] + mill := states[4] mallocSize := s.getMallocSize(conversationID, size) seq, err := s.mgo.Malloc(ctx, conversationID, mallocSize) if err != nil { - return 0, err + return 0, 0, err } if lastSeq == seq { - s.setSeqRetry(ctx, key, states[3], currSeq+size, seq+mallocSize) - return currSeq, nil + s.setSeqRetry(ctx, key, states[3], currSeq+size, seq+mallocSize, mill) + return currSeq, states[4], nil } else { log.ZWarn(ctx, "malloc seq not equal cache last seq", nil, "conversationID", conversationID, "currSeq", currSeq, "lastSeq", lastSeq, "mallocSeq", seq) - s.setSeqRetry(ctx, key, states[3], seq+size, seq+mallocSize) - return seq, nil + s.setSeqRetry(ctx, key, states[3], seq+size, seq+mallocSize, mill) + return seq, mill, nil } default: log.ZError(ctx, "malloc seq unknown state", nil, "state", states[0], "conversationID", conversationID, "size", size) - return 0, errs.New(fmt.Sprintf("unknown state: %d", states[0])) + return 0, 0, errs.New(fmt.Sprintf("unknown state: %d", states[0])) } } log.ZError(ctx, "malloc seq retrying still failed", nil, "conversationID", conversationID, "size", size) - return 0, errs.New("malloc seq waiting for lock timeout", "conversationID", conversationID, "size", size) + return 0, 0, errs.New("malloc seq waiting for lock timeout", "conversationID", conversationID, "size", size) } func (s *seqConversationCacheRedis) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) { return s.Malloc(ctx, conversationID, 0) } +func (s *seqConversationCacheRedis) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) { + seq, mill, err := s.mallocTime(ctx, conversationID, 0) + if err != nil { + return database.SeqTime{}, err + } + return database.SeqTime{Seq: seq, Time: mill}, nil +} + func (s *seqConversationCacheRedis) SetMinSeqs(ctx context.Context, seqs map[string]int64) error { keys := make([]string, 0, len(seqs)) for conversationID, seq := range seqs { @@ -331,3 +440,80 @@ func (s *seqConversationCacheRedis) SetMinSeqs(ctx context.Context, seqs map[str } return DeleteCacheBySlot(ctx, s.rocks, keys) } + +// GetCacheMaxSeqWithTime only get the existing cache, if there is no cache, no cache will be generated +func (s *seqConversationCacheRedis) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) { + if len(conversationIDs) == 0 { + return map[string]database.SeqTime{}, nil + } + key2conversationID := make(map[string]string) + keys := make([]string, 0, len(conversationIDs)) + for _, conversationID := range conversationIDs { + key := s.getSeqMallocKey(conversationID) + if _, ok := key2conversationID[key]; ok { + continue + } + key2conversationID[key] = conversationID + keys = append(keys, key) + } + slotKeys, err := groupKeysBySlot(ctx, s.rdb, keys) + if err != nil { + return nil, err + } + res := make(map[string]database.SeqTime) + for _, keys := range slotKeys { + if len(keys) == 0 { + continue + } + pipe := s.rdb.Pipeline() + cmds := make([]*redis.SliceCmd, 0, len(keys)) + for _, key := range keys { + cmds = append(cmds, pipe.HMGet(ctx, key, "CURR", "TIME")) + } + if _, err := pipe.Exec(ctx); err != nil { + return nil, errs.Wrap(err) + } + for i, cmd := range cmds { + val, err := cmd.Result() + if err != nil { + return nil, err + } + if len(val) != 2 { + return nil, errs.WrapMsg(err, "GetCacheMaxSeqWithTime invalid result", "key", keys[i], "res", val) + } + if val[0] == nil { + continue + } + seq, err := s.parseInt64(val[0]) + if err != nil { + return nil, err + } + mill, err := s.parseInt64(val[1]) + if err != nil { + return nil, err + } + conversationID := key2conversationID[keys[i]] + res[conversationID] = database.SeqTime{Seq: seq, Time: mill} + } + } + return res, nil +} + +func (s *seqConversationCacheRedis) parseInt64(val any) (int64, error) { + switch v := val.(type) { + case nil: + return 0, nil + case int: + return int64(v), nil + case int64: + return v, nil + case string: + res, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, errs.WrapMsg(err, "invalid string not int64", "value", v) + } + return res, nil + default: + return 0, errs.New("invalid result not int64", "resType", fmt.Sprintf("%T", v), "value", v) + } +} diff --git a/pkg/common/storage/cache/redis/seq_conversation_test.go b/pkg/common/storage/cache/redis/seq_conversation_test.go index 1a40624b8..d8bfdfbfb 100644 --- a/pkg/common/storage/cache/redis/seq_conversation_test.go +++ b/pkg/common/storage/cache/redis/seq_conversation_test.go @@ -14,7 +14,7 @@ import ( ) func newTestSeq() *seqConversationCacheRedis { - mgocli, err := mongo.Connect(context.Background(), options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.48:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)) + mgocli, err := mongo.Connect(context.Background(), options.Client().ApplyURI("mongodb://openIM:openIM123@127.0.0.1:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)) if err != nil { panic(err) } @@ -23,7 +23,7 @@ func newTestSeq() *seqConversationCacheRedis { panic(err) } opt := &redis.Options{ - Addr: "172.16.8.48:16379", + Addr: "127.0.0.1:16379", Password: "openIM123", DB: 1, } @@ -107,3 +107,37 @@ func TestMinSeq(t *testing.T) { ts := newTestSeq() t.Log(ts.GetMinSeq(context.Background(), "10000000")) } + +func TestMalloc(t *testing.T) { + ts := newTestSeq() + t.Log(ts.mallocTime(context.Background(), "10000000", 100)) +} + +func TestHMGET(t *testing.T) { + ts := newTestSeq() + res, err := ts.GetCacheMaxSeqWithTime(context.Background(), []string{"10000000", "123456"}) + if err != nil { + panic(err) + } + t.Log(res) +} + +func TestGetMaxSeqWithTime(t *testing.T) { + ts := newTestSeq() + t.Log(ts.GetMaxSeqWithTime(context.Background(), "10000000")) +} + +func TestGetMaxSeqWithTime1(t *testing.T) { + ts := newTestSeq() + t.Log(ts.GetMaxSeqsWithTime(context.Background(), []string{"10000000", "12345", "111"})) +} + +// +//func TestHMGET(t *testing.T) { +// ts := newTestSeq() +// res, err := ts.rdb.HMGet(context.Background(), "MALLOC_SEQ:1", "CURR", "TIME1").Result() +// if err != nil { +// panic(err) +// } +// t.Log(res) +//} diff --git a/pkg/common/storage/cache/seq_conversation.go b/pkg/common/storage/cache/seq_conversation.go index 2c893a5e8..f35d7bf52 100644 --- a/pkg/common/storage/cache/seq_conversation.go +++ b/pkg/common/storage/cache/seq_conversation.go @@ -1,6 +1,9 @@ package cache -import "context" +import ( + "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" +) type SeqConversationCache interface { Malloc(ctx context.Context, conversationID string, size int64) (int64, error) @@ -9,4 +12,7 @@ type SeqConversationCache interface { GetMinSeq(ctx context.Context, conversationID string) (int64, error) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) SetMinSeqs(ctx context.Context, seqs map[string]int64) error + GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) + GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) + GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) } diff --git a/pkg/common/storage/controller/msg.go b/pkg/common/storage/controller/msg.go index 16a7b1c9b..d579069b6 100644 --- a/pkg/common/storage/controller/msg.go +++ b/pkg/common/storage/controller/msg.go @@ -74,6 +74,10 @@ type CommonMsgDatabase interface { GetHasReadSeq(ctx context.Context, userID string, conversationID string) (int64, error) UserSetHasReadSeqs(ctx context.Context, userID string, hasReadSeqs map[string]int64) error + GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) + GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) + GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) + //GetMongoMaxAndMinSeq(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo int64, err error) //GetConversationMinMaxSeqInMongoAndCache(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo, minSeqCache, maxSeqCache int64, err error) SetSendMsgStatus(ctx context.Context, id string, status int32) error @@ -866,3 +870,16 @@ func (db *commonMsgDatabase) setMinSeq(ctx context.Context, conversationID strin func (db *commonMsgDatabase) GetDocIDs(ctx context.Context) ([]string, error) { return db.msgDocDatabase.GetDocIDs(ctx) } + +func (db *commonMsgDatabase) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) { + return db.seqConversation.GetCacheMaxSeqWithTime(ctx, conversationIDs) +} + +func (db *commonMsgDatabase) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) { + return db.seqConversation.GetMaxSeqWithTime(ctx, conversationID) +} + +func (db *commonMsgDatabase) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) { + // todo: only the time in the redis cache will be taken, not the message time + return db.seqConversation.GetMaxSeqsWithTime(ctx, conversationIDs) +} diff --git a/pkg/common/storage/database/seq.go b/pkg/common/storage/database/seq.go index cf93b795f..a97ca2d1f 100644 --- a/pkg/common/storage/database/seq.go +++ b/pkg/common/storage/database/seq.go @@ -2,6 +2,11 @@ package database import "context" +type SeqTime struct { + Seq int64 + Time int64 +} + type SeqConversation interface { Malloc(ctx context.Context, conversationID string, size int64) (int64, error) GetMaxSeq(ctx context.Context, conversationID string) (int64, error)