diff --git a/go.mod b/go.mod index e3edae0ec..051beb403 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/openimsdk/protocol v0.0.72-alpha.48 + github.com/openimsdk/protocol v0.0.72-alpha.54 github.com/openimsdk/tools v0.0.50-alpha.16 github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.18.0 @@ -197,5 +197,3 @@ require ( golang.org/x/crypto v0.27.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) - -replace github.com/openimsdk/protocol => /Users/chao/Desktop/withchao/protocol \ No newline at end of file diff --git a/go.sum b/go.sum index 48a42a251..816c82094 100644 --- a/go.sum +++ b/go.sum @@ -319,6 +319,8 @@ github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= github.com/onsi/gomega v1.25.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM= github.com/openimsdk/gomake v0.0.14-alpha.5 h1:VY9c5x515lTfmdhhPjMvR3BBRrRquAUCFsz7t7vbv7Y= github.com/openimsdk/gomake v0.0.14-alpha.5/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= +github.com/openimsdk/protocol v0.0.72-alpha.54 h1:opato7N4QjjRq/SHD54bDSVBpOEEDp1VLWVk5Os2A9s= +github.com/openimsdk/protocol v0.0.72-alpha.54/go.mod h1:OZQA9FR55lseYoN2Ql1XAHYKHJGu7OMNkUbuekrKCM8= github.com/openimsdk/tools v0.0.50-alpha.16 h1:bC1AQvJMuOHtZm8LZRvN8L5mH1Ws2VYdL+TLTs1iGSc= github.com/openimsdk/tools v0.0.50-alpha.16/go.mod h1:h1cYmfyaVtgFbKmb1Cfsl8XwUOMTt8ubVUQrdGtsUh4= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= diff --git a/internal/api/msg.go b/internal/api/msg.go index e99bd4255..ce94b5f4f 100644 --- a/internal/api/msg.go +++ b/internal/api/msg.go @@ -375,3 +375,11 @@ func (m *MessageApi) SearchMsg(c *gin.Context) { func (m *MessageApi) GetServerTime(c *gin.Context) { a2r.Call(msg.MsgClient.GetServerTime, m.Client, c) } + +func (m *MessageApi) GetStreamMsg(c *gin.Context) { + a2r.Call(msg.MsgClient.GetStreamMsg, m.Client, c) +} + +func (m *MessageApi) AppendStreamMsg(c *gin.Context) { + a2r.Call(msg.MsgClient.AppendStreamMsg, m.Client, c) +} diff --git a/internal/api/router.go b/internal/api/router.go index 17c998912..3c4976d58 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -229,6 +229,8 @@ func newGinRouter(disCov discovery.SvcDiscoveryRegistry, config *Config) *gin.En msgGroup.POST("/batch_send_msg", m.BatchSendMsg) msgGroup.POST("/check_msg_is_send_success", m.CheckMsgIsSendSuccess) msgGroup.POST("/get_server_time", m.GetServerTime) + msgGroup.POST("/get_stream_msg", m.GetStreamMsg) + msgGroup.POST("/append_stream_msg", m.AppendStreamMsg) } // Conversation conversationGroup := r.Group("/conversation") diff --git a/internal/rpc/msg/send.go b/internal/rpc/msg/send.go index 2c3f8c0a3..4762f24de 100644 --- a/internal/rpc/msg/send.go +++ b/internal/rpc/msg/send.go @@ -34,6 +34,11 @@ import ( func (m *msgServer) SendMsg(ctx context.Context, req *pbmsg.SendMsgReq) (*pbmsg.SendMsgResp, error) { if req.MsgData != nil { m.encapsulateMsgData(req.MsgData) + if req.MsgData.ContentType == constant.Stream { + if err := m.handlerStreamMsg(ctx, req.MsgData); err != nil { + return nil, err + } + } switch req.MsgData.SessionType { case constant.SingleChatType: return m.sendMsgSingleChat(ctx, req) diff --git a/internal/rpc/msg/server.go b/internal/rpc/msg/server.go index a8628383a..bf8781747 100644 --- a/internal/rpc/msg/server.go +++ b/internal/rpc/msg/server.go @@ -102,6 +102,10 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg if err != nil { return err } + streamMsg, err := mgo.NewStreamMsgMongo(mgocli.GetDB()) + if err != nil { + return err + } seqUserCache := redis.NewSeqUserCacheRedis(rdb, seqUser) msgDatabase, err := controller.NewCommonMsgDatabase(msgDocModel, msgModel, seqUserCache, seqConversationCache, &config.KafkaConfig) if err != nil { @@ -110,6 +114,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg s := &msgServer{ Conversation: &conversationClient, MsgDatabase: msgDatabase, + StreamMsgDatabase: controller.NewStreamMsgDatabase(streamMsg), RegisterCenter: client, UserLocalCache: rpccache.NewUserLocalCache(userRpcClient, &config.LocalCacheConfig, rdb), GroupLocalCache: rpccache.NewGroupLocalCache(groupRpcClient, &config.LocalCacheConfig, rdb), diff --git a/internal/rpc/msg/stream_msg.go b/internal/rpc/msg/stream_msg.go index f959ebc79..5db2aad48 100644 --- a/internal/rpc/msg/stream_msg.go +++ b/internal/rpc/msg/stream_msg.go @@ -31,8 +31,11 @@ func (m *msgServer) getStreamMsg(ctx context.Context, clientMsgID string) (*mode if err != nil { return nil, err } - if !res.End && res.DeadlineTime.Before(time.Now()) { + now := time.Now() + if !res.End && res.DeadlineTime.Before(now) { res.End = true + res.DeadlineTime = now + _ = m.StreamMsgDatabase.AppendStreamMsg(ctx, res.ClientMsgID, 0, nil, true, now) } return res, nil } @@ -64,7 +67,8 @@ func (m *msgServer) AppendStreamMsg(ctx context.Context, req *msg.AppendStreamMs if len(req.Packets) == 0 && res.End == req.End { return &msg.AppendStreamMsgResp{}, nil } - if err := m.StreamMsgDatabase.AppendStreamMsg(ctx, req.ClientMsgID, int(req.StartIndex), req.Packets, req.End); err != nil { + deadlineTime := time.Now().Add(StreamDeadlineTime) + if err := m.StreamMsgDatabase.AppendStreamMsg(ctx, req.ClientMsgID, int(req.StartIndex), req.Packets, req.End, deadlineTime); err != nil { return nil, err } conversation, err := m.Conversation.GetConversation(ctx, res.UserID, res.ConversationID) @@ -72,10 +76,11 @@ func (m *msgServer) AppendStreamMsg(ctx context.Context, req *msg.AppendStreamMs return nil, err } tips := &sdkws.StreamMsgTips{ - ClientMsgID: res.ClientMsgID, - StartIndex: req.StartIndex, - Packets: req.Packets, - End: req.End, + ConversationID: res.ConversationID, + ClientMsgID: res.ClientMsgID, + StartIndex: req.StartIndex, + Packets: req.Packets, + End: req.End, } var ( recvID string diff --git a/pkg/common/storage/controller/stream_msg.go b/pkg/common/storage/controller/stream_msg.go index ca4402fb9..3409ccd93 100644 --- a/pkg/common/storage/controller/stream_msg.go +++ b/pkg/common/storage/controller/stream_msg.go @@ -2,11 +2,33 @@ package controller import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "time" ) type StreamMsgDatabase interface { CreateStreamMsg(ctx context.Context, model *model.StreamMsg) error - AppendStreamMsg(ctx context.Context, clientMsgID string, startIndex int, packets []string, end bool) error + AppendStreamMsg(ctx context.Context, clientMsgID string, startIndex int, packets []string, end bool, deadlineTime time.Time) error GetStreamMsg(ctx context.Context, clientMsgID string) (*model.StreamMsg, error) } + +func NewStreamMsgDatabase(db database.StreamMsg) StreamMsgDatabase { + return &streamMsgDatabase{db: db} +} + +type streamMsgDatabase struct { + db database.StreamMsg +} + +func (m *streamMsgDatabase) CreateStreamMsg(ctx context.Context, model *model.StreamMsg) error { + return m.db.CreateStreamMsg(ctx, model) +} + +func (m *streamMsgDatabase) AppendStreamMsg(ctx context.Context, clientMsgID string, startIndex int, packets []string, end bool, deadlineTime time.Time) error { + return m.db.AppendStreamMsg(ctx, clientMsgID, startIndex, packets, end, deadlineTime) +} + +func (m *streamMsgDatabase) GetStreamMsg(ctx context.Context, clientMsgID string) (*model.StreamMsg, error) { + return m.db.GetStreamMsg(ctx, clientMsgID) +} diff --git a/pkg/common/storage/database/mgo/stream_msg.go b/pkg/common/storage/database/mgo/stream_msg.go new file mode 100644 index 000000000..c57798daa --- /dev/null +++ b/pkg/common/storage/database/mgo/stream_msg.go @@ -0,0 +1,60 @@ +package mgo + +import ( + "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/db/mongoutil" + "github.com/openimsdk/tools/errs" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "time" +) + +func NewStreamMsgMongo(db *mongo.Database) (*StreamMsgMongo, error) { + coll := db.Collection(database.StreamMsgName) + _, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{ + {Key: "client_msg_id", Value: 1}, + }, + Options: options.Index().SetUnique(true), + }) + if err != nil { + return nil, errs.Wrap(err) + } + return &StreamMsgMongo{coll: coll}, nil +} + +type StreamMsgMongo struct { + coll *mongo.Collection +} + +func (m *StreamMsgMongo) CreateStreamMsg(ctx context.Context, val *model.StreamMsg) error { + if val.Packets == nil { + val.Packets = []string{} + } + return mongoutil.InsertMany(ctx, m.coll, []*model.StreamMsg{val}) +} + +func (m *StreamMsgMongo) AppendStreamMsg(ctx context.Context, clientMsgID string, startIndex int, packets []string, end bool, deadlineTime time.Time) error { + update := bson.M{ + "$set": bson.M{ + "end": end, + "deadline_time": deadlineTime, + }, + } + if len(packets) > 0 { + update["$push"] = bson.M{ + "packets": bson.M{ + "$each": packets, + "$position": startIndex, + }, + } + } + return mongoutil.UpdateOne(ctx, m.coll, bson.M{"client_msg_id": clientMsgID, "end": false}, update, true) +} + +func (m *StreamMsgMongo) GetStreamMsg(ctx context.Context, clientMsgID string) (*model.StreamMsg, error) { + return mongoutil.FindOne[*model.StreamMsg](ctx, m.coll, bson.M{"client_msg_id": clientMsgID}) +} diff --git a/pkg/common/storage/database/name.go b/pkg/common/storage/database/name.go index 748bd844d..9742f933f 100644 --- a/pkg/common/storage/database/name.go +++ b/pkg/common/storage/database/name.go @@ -17,4 +17,5 @@ const ( UserName = "user" SeqConversationName = "seq" SeqUserName = "seq_user" + StreamMsgName = "stream_msg" ) diff --git a/pkg/common/storage/database/stream_msg.go b/pkg/common/storage/database/stream_msg.go new file mode 100644 index 000000000..e83fffbaa --- /dev/null +++ b/pkg/common/storage/database/stream_msg.go @@ -0,0 +1,13 @@ +package database + +import ( + "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "time" +) + +type StreamMsg interface { + CreateStreamMsg(ctx context.Context, model *model.StreamMsg) error + AppendStreamMsg(ctx context.Context, clientMsgID string, startIndex int, packets []string, end bool, deadlineTime time.Time) error + GetStreamMsg(ctx context.Context, clientMsgID string) (*model.StreamMsg, error) +} diff --git a/tools/streammsg/main.go b/tools/streammsg/main.go new file mode 100644 index 000000000..bb567e233 --- /dev/null +++ b/tools/streammsg/main.go @@ -0,0 +1,161 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/openimsdk/open-im-server/v3/pkg/apistruct" + cbapi "github.com/openimsdk/open-im-server/v3/pkg/callbackstruct" + "github.com/openimsdk/protocol/auth" + "github.com/openimsdk/protocol/constant" + "github.com/openimsdk/protocol/msg" + "github.com/openimsdk/tools/apiresp" + "github.com/openimsdk/tools/errs" + "io" + "net/http" + "strings" + "time" +) + +const ( + getAdminToken = "/auth/get_admin_token" + sendMsgApi = "/msg/send_msg" + appendStreamMsg = "/msg/append_stream_msg" +) + +var ( + ApiAddr = "http://127.0.0.1:10002" + Token string +) + +func ApiCall[R any](api string, req any) (*R, error) { + data, err := json.Marshal(req) + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + request, err := http.NewRequestWithContext(ctx, http.MethodPost, ApiAddr+api, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + if Token != "" { + request.Header.Set("token", Token) + } + request.Header.Set(constant.OperationID, uuid.New().String()) + response, err := http.DefaultClient.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + var resp R + apiResponse := apiresp.ApiResponse{ + Data: &resp, + } + if err := json.NewDecoder(response.Body).Decode(&apiResponse); err != nil { + return nil, err + } + if apiResponse.ErrCode != 0 { + return nil, errs.NewCodeError(apiResponse.ErrCode, apiResponse.ErrMsg) + } + return &resp, nil +} + +func main() { + resp, err := ApiCall[auth.GetAdminTokenResp](getAdminToken, &auth.GetAdminTokenReq{ + Secret: "openIM123", + UserID: "imAdmin", + }) + if err != nil { + fmt.Println("get admin token failed", err) + return + } + Token = resp.Token + g := gin.Default() + g.POST("/callbackExample/callbackAfterSendSingleMsgCommand", toGin(handlerUserMsg)) + if err := g.Run(":10006"); err != nil { + panic(err) + } +} + +func toGin[R any](fn func(c *gin.Context, req *R) error) gin.HandlerFunc { + return func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + fmt.Printf("HTTP %s %s %s\n", c.Request.Method, c.Request.URL, body) + var req R + if err := json.Unmarshal(body, &req); err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + if err := fn(c, &req); err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + c.String(http.StatusOK, "{}") + } +} + +func handlerUserMsg(c *gin.Context, req *cbapi.CallbackAfterSendSingleMsgReq) error { + if req.ContentType != constant.Text { + return nil + } + if !strings.Contains(req.Content, "stream") { + return nil + } + apiReq := apistruct.SendMsgReq{ + RecvID: req.SendID, + SendMsg: apistruct.SendMsg{ + SendID: req.RecvID, + SenderNickname: "xxx", + SenderFaceURL: "", + SenderPlatformID: constant.AdminPlatformID, + ContentType: constant.Stream, + SessionType: req.SessionType, + SendTime: time.Now().UnixMilli(), + Content: map[string]any{ + "type": "xxx", + "content": "server test stream msg", + }, + }, + } + go func() { + if err := doPushStreamMsg(&apiReq); err != nil { + fmt.Println("doPushStreamMsg failed", err) + return + } + fmt.Println("doPushStreamMsg success") + }() + return nil +} + +func doPushStreamMsg(sendReq *apistruct.SendMsgReq) error { + resp, err := ApiCall[msg.SendMsgResp](sendMsgApi, sendReq) + if err != nil { + return err + } + const num = 5 + for i := 1; i <= num; i++ { + _, err := ApiCall[msg.AppendStreamMsgResp](appendStreamMsg, &msg.AppendStreamMsgReq{ + ClientMsgID: resp.ClientMsgID, + StartIndex: int64(i - 1), + Packets: []string{ + fmt.Sprintf("stream_msg_packet_%03d", i), + }, + End: i == num, + }) + if err != nil { + fmt.Println("append stream msg failed", "clientMsgID", resp.ClientMsgID, "index", fmt.Sprintf("%d/%d", i, num), "error", err) + return err + } + fmt.Println("append stream msg success", "clientMsgID", resp.ClientMsgID, "index", fmt.Sprintf("%d/%d", i, num)) + time.Sleep(time.Second * 10) + } + return nil +}