From 5402d33519e94d08b8afec76b03d6385c90e3f60 Mon Sep 17 00:00:00 2001 From: skiffer-git <44203734@qq.com> Date: Thu, 2 Feb 2023 16:39:29 +0800 Subject: [PATCH] Error code standardization --- go.sum | 2 - internal/rpc/friend/friend.go | 97 ++++++++++++------------ internal/rpc/friend/other.go | 14 ---- pkg/common/db/controller/friend.go | 60 ++++++++++++--- pkg/common/db/controller/user.go | 50 +++++++++--- pkg/common/db/relation/friend_model_k.go | 22 +++++- pkg/common/db/relation/user_model_k.go | 5 +- 7 files changed, 161 insertions(+), 89 deletions(-) delete mode 100644 internal/rpc/friend/other.go diff --git a/go.sum b/go.sum index 0c4dff748..6da44054e 100644 --- a/go.sum +++ b/go.sum @@ -391,8 +391,6 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/OpenIMSDK/getcdv3 v1.0.3 h1:3/j92MuDPFhAJYBy/ht0qnybdaaCezefF1pMTClvvq4= -github.com/OpenIMSDK/getcdv3 v1.0.3/go.mod h1:ZvsBwAjOZZr7HBF3SytJaHCltuOfBKbM1vLSCjut7kw= github.com/OpenIMSDK/getcdv3 v1.0.4 h1:wKpLcp1gbLbh+fa7b5iCL4fTBLm87hB0+p0ZQMg9tK8= github.com/OpenIMSDK/getcdv3 v1.0.4/go.mod h1:ZvsBwAjOZZr7HBF3SytJaHCltuOfBKbM1vLSCjut7kw= github.com/OpenIMSDK/open_log v1.0.0 h1:ZQ908aWgPqfHOfkQ/oFSV20AZdRwPw+sZjC/sAPd5cA= diff --git a/internal/rpc/friend/friend.go b/internal/rpc/friend/friend.go index 8caa41a76..3a57642f3 100644 --- a/internal/rpc/friend/friend.go +++ b/internal/rpc/friend/friend.go @@ -7,14 +7,15 @@ import ( "Open_IM/pkg/common/constant" "Open_IM/pkg/common/db/controller" "Open_IM/pkg/common/db/relation" + "Open_IM/pkg/common/db/table" "Open_IM/pkg/common/log" "Open_IM/pkg/common/middleware" promePkg "Open_IM/pkg/common/prometheus" "Open_IM/pkg/common/token_verify" "Open_IM/pkg/common/tracelog" - "Open_IM/pkg/getcdv3" pbFriend "Open_IM/pkg/proto/friend" sdkws "Open_IM/pkg/proto/sdk_ws" + pbUser "Open_IM/pkg/proto/user" "Open_IM/pkg/utils" "context" grpcPrometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -23,6 +24,7 @@ import ( "strings" "Open_IM/internal/common/check" + "github.com/OpenIMSDK/getcdv3" "google.golang.org/grpc" ) @@ -33,6 +35,8 @@ type friendServer struct { etcdAddr []string controller.FriendInterface controller.BlackInterface + + userRpc pbUser.UserClient } func NewFriendServer(port int) *friendServer { @@ -43,10 +47,33 @@ func NewFriendServer(port int) *friendServer { etcdSchema: config.Config.Etcd.EtcdSchema, etcdAddr: config.Config.Etcd.EtcdAddr, } + ttl := 10 + etcdClient, err := getcdv3.NewEtcdConn(config.Config.Etcd.EtcdSchema, strings.Join(f.etcdAddr, ","), config.Config.RpcRegisterIP, config.Config.Etcd.UserName, config.Config.Etcd.Password, port, ttl) + if err != nil { + panic("NewEtcdConn failed" + err.Error()) + } + err = etcdClient.RegisterEtcd("", f.rpcRegisterName) + if err != nil { + panic("NewEtcdConn failed" + err.Error()) + } + + etcdClient.SetDefaultEtcdConfig(config.Config.RpcRegisterName.OpenImUserName, config.Config.RpcPort.OpenImUserPort) + conn := etcdClient.GetConn("", config.Config.RpcRegisterName.OpenImUserName) + f.userRpc = pbUser.NewUserClient(conn) + //mysql init var mysql relation.Mysql var model relation.FriendGorm - err := mysql.InitConn().AutoMigrateModel(&relation.FriendModel{}) + err = mysql.InitConn().AutoMigrateModel(&table.FriendModel{}) + if err != nil { + panic("db init err:" + err.Error()) + } + err = mysql.InitConn().AutoMigrateModel(&table.FriendRequestModel{}) + if err != nil { + panic("db init err:" + err.Error()) + } + + err = mysql.InitConn().AutoMigrateModel(&table.BlackModel{}) if err != nil { panic("db init err:" + err.Error()) } @@ -93,21 +120,7 @@ func (s *friendServer) Run() { } srv := grpc.NewServer(grpcOpts...) defer srv.GracefulStop() - //User friend related services register to etcd pbFriend.RegisterFriendServer(srv, s) - rpcRegisterIP := config.Config.RpcRegisterIP - if config.Config.RpcRegisterIP == "" { - rpcRegisterIP, err = utils.GetLocalIP() - if err != nil { - log.Error("", "GetLocalIP failed ", err.Error()) - } - } - log.NewInfo("", "rpcRegisterIP", rpcRegisterIP) - err = getcdv3.RegisterEtcd(s.etcdSchema, strings.Join(s.etcdAddr, ","), rpcRegisterIP, s.rpcPort, s.rpcRegisterName, 10, "") - if err != nil { - log.NewError("0", "RegisterEtcd failed ", err.Error(), s.etcdSchema, strings.Join(s.etcdAddr, ","), rpcRegisterIP, s.rpcPort, s.rpcRegisterName) - panic(utils.Wrap(err, "register friend module rpc to etcd err")) - } err = srv.Serve(listener) if err != nil { log.NewError("0", "Serve failed ", err.Error(), listener) @@ -153,9 +166,9 @@ func (s *friendServer) ImportFriends(ctx context.Context, req *pbFriend.ImportFr return nil, err } - var friends []*relation.Friend + var friends []*table.FriendModel for _, userID := range utils.RemoveDuplicateElement(req.FriendUserIDs) { - friends = append(friends, &relation.Friend{OwnerUserID: userID, FriendUserID: req.OwnerUserID, AddSource: constant.BecomeFriendByImport, OperatorUserID: tracelog.GetOpUserID(ctx)}) + friends = append(friends, &table.FriendModel{OwnerUserID: userID, FriendUserID: req.OwnerUserID, AddSource: constant.BecomeFriendByImport, OperatorUserID: tracelog.GetOpUserID(ctx)}) } if len(friends) > 0 { if err := s.FriendInterface.BecomeFriend(ctx, friends); err != nil { @@ -171,7 +184,7 @@ func (s *friendServer) RespondFriendApply(ctx context.Context, req *pbFriend.Res if err := check.Access(ctx, req.ToUserID); err != nil { return nil, err } - friendRequest := controller.FriendRequest{FromUserID: req.FromUserID, ToUserID: req.ToUserID, HandleMsg: req.HandleMsg, HandleResult: req.HandleResult} + friendRequest := table.FriendRequestModel{FromUserID: req.FromUserID, ToUserID: req.ToUserID, HandleMsg: req.HandleMsg, HandleResult: req.HandleResult} if req.HandleResult == constant.FriendResponseAgree { err := s.AgreeFriendRequest(ctx, &friendRequest) if err != nil { @@ -220,7 +233,7 @@ func (s *friendServer) GetFriends(ctx context.Context, req *pbFriend.GetFriendsR if err := check.Access(ctx, req.UserID); err != nil { return nil, err } - friends, err := s.FriendInterface.FindOwnerFriends(ctx, req.UserID, req.Pagination.PageNumber, req.Pagination.ShowNumber) + friends, total, err := s.FriendInterface.FindOwnerFriends(ctx, req.UserID, req.Pagination.PageNumber, req.Pagination.ShowNumber) if err != nil { return nil, err } @@ -236,14 +249,11 @@ func (s *friendServer) GetFriends(ctx context.Context, req *pbFriend.GetFriendsR for i, user := range users { userMap[user.UserID] = users[i] } - for _, friendUser := range friends { - - friendUserInfo, err := (convert.NewDBFriend(friendUser)).Convert() - if err != nil { - return nil, err - } - resp.FriendsInfo = append(resp.FriendsInfo, friendUserInfo) + resp.FriendsInfo, err = (*convert.DBFriend)(nil).DB2PB(friends) + if err != nil { + return nil, err } + resp.Total = int32(total) return resp, nil } @@ -253,17 +263,15 @@ func (s *friendServer) GetToFriendsApply(ctx context.Context, req *pbFriend.GetT if err := check.Access(ctx, req.UserID); err != nil { return nil, err } - friendRequests, err := s.FriendInterface.FindFriendRequestToMe(ctx, req.UserID, req.Pagination.PageNumber, req.Pagination.ShowNumber) + friendRequests, total, err := s.FriendInterface.FindFriendRequestToMe(ctx, req.UserID, req.Pagination.PageNumber, req.Pagination.ShowNumber) if err != nil { return nil, err } - for _, v := range friendRequests { - fUser, err := convert.NewDBFriendRequest(v).Convert() - if err != nil { - return nil, err - } - resp.FriendRequests = append(resp.FriendRequests, fUser) + resp.FriendRequests, err = (*convert.DBFriendRequest)(nil).DB2PB(friendRequests) + if err != nil { + return nil, err } + resp.Total = int32(total) return resp, nil } @@ -273,17 +281,15 @@ func (s *friendServer) GetFromFriendsApply(ctx context.Context, req *pbFriend.Ge if err := check.Access(ctx, req.UserID); err != nil { return nil, err } - friendRequests, err := s.FriendInterface.FindFriendRequestFromMe(ctx, req.UserID, req.Pagination.PageNumber, req.Pagination.ShowNumber) + friendRequests, total, err := s.FriendInterface.FindFriendRequestFromMe(ctx, req.UserID, req.Pagination.PageNumber, req.Pagination.ShowNumber) if err != nil { return nil, err } - for _, v := range friendRequests { - fUser, err := convert.NewDBFriendRequest(v).Convert() - if err != nil { - return nil, err - } - resp.FriendRequests = append(resp.FriendRequests, fUser) + resp.FriendRequests, err = (*convert.DBFriendRequest)(nil).DB2PB(friendRequests) + if err != nil { + return nil, err } + resp.Total = int32(total) return resp, nil } @@ -304,12 +310,9 @@ func (s *friendServer) GetFriendsInfo(ctx context.Context, req *pbFriend.GetFrie if err != nil { return nil, err } - for _, v := range friends { - fUser, err := convert.NewDBFriend(v).Convert() - if err != nil { - return nil, err - } - resp.FriendsInfo = append(resp.FriendsInfo, fUser) + resp.FriendsInfo, err = (*convert.DBFriend)(nil).DB2PB(friends) + if err != nil { + return nil, err } return &resp, nil } diff --git a/internal/rpc/friend/other.go b/internal/rpc/friend/other.go deleted file mode 100644 index 052caa171..000000000 --- a/internal/rpc/friend/other.go +++ /dev/null @@ -1,14 +0,0 @@ -package friend - -import ( - server_api_params "Open_IM/pkg/proto/sdk_ws" - "context" - "errors" -) - -func GetPublicUserInfoBatch(ctx context.Context, userIDs []string) ([]*server_api_params.PublicUserInfo, error) { - if len(userIDs) == 0 { - return []*server_api_params.PublicUserInfo{}, nil - } - return nil, errors.New("TODO:GetUserInfo") -} diff --git a/pkg/common/db/controller/friend.go b/pkg/common/db/controller/friend.go index 13bfd16b3..e61fe7054 100644 --- a/pkg/common/db/controller/friend.go +++ b/pkg/common/db/controller/friend.go @@ -1,6 +1,7 @@ package controller import ( + "Open_IM/pkg/common/constant" "Open_IM/pkg/common/db/relation" "Open_IM/pkg/common/db/table" "context" @@ -9,11 +10,11 @@ import ( type FriendInterface interface { // CheckIn 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true) - CheckIn(ctx context.Context, user1, user2 string) (err error, inUser1Friends bool, inUser2Friends bool) + CheckIn(ctx context.Context, user1, user2 string) (inUser1Friends bool, inUser2Friends bool, err error) // AddFriendRequest 增加或者更新好友申请 AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) // BecomeFriend 先判断是否在好友表,如果在则不插入 - BecomeFriend(ctx context.Context, friends []*table.FriendModel) (err error) + BecomeFriend(ctx context.Context, friends []*table.FriendModel, revFriends []*table.FriendModel) (err error) // RefuseFriendRequest 拒绝好友申请 RefuseFriendRequest(ctx context.Context, friendRequest *table.FriendRequestModel) (err error) // AgreeFriendRequest 同意好友申请 @@ -23,14 +24,14 @@ type FriendInterface interface { // UpdateRemark 更新好友备注 UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) // FindOwnerFriends 获取ownerUserID的好友列表 - FindOwnerFriends(ctx context.Context, ownerUserID string, pageNumber, showNumber int32) (friends []*table.FriendModel, err error) + FindOwnerFriends(ctx context.Context, ownerUserID string, pageNumber, showNumber int32) (friends []*table.FriendModel, total int64, err error) // FindInWhoseFriends friendUserID在哪些人的好友列表中 - FindInWhoseFriends(ctx context.Context, friendUserID string, pageNumber, showNumber int32) (friends []*table.FriendModel, err error) + FindInWhoseFriends(ctx context.Context, friendUserID string, pageNumber, showNumber int32) (friends []*table.FriendModel, total int64, err error) // FindFriendRequestFromMe 获取我发出去的好友申请 - FindFriendRequestFromMe(ctx context.Context, userID string, pageNumber, showNumber int32) (friends []*table.FriendRequestModel, err error) + FindFriendRequestFromMe(ctx context.Context, userID string, pageNumber, showNumber int32) (friends []*table.FriendRequestModel, total int64, err error) // FindFriendRequestToMe 获取我收到的的好友申请 - FindFriendRequestToMe(ctx context.Context, userID string, pageNumber, showNumber int32) (friends []*table.FriendRequestModel, err error) - // FindFriends 获取某人指定好友的信息 + FindFriendRequestToMe(ctx context.Context, userID string, pageNumber, showNumber int32) (friends []*table.FriendRequestModel, total int64, err error) + // FindFriends 获取某人指定好友的信息 如果有一个不存在也返回错误 FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*table.FriendModel, err error) } @@ -92,7 +93,7 @@ func (f *FriendController) FindFriends(ctx context.Context, ownerUserID string, type FriendDatabaseInterface interface { // CheckIn 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true) - CheckIn(ctx context.Context, user1, user2 string) (err error, inUser1Friends bool, inUser2Friends bool) + CheckIn(ctx context.Context, user1, user2 string) (inUser1Friends bool, inUser2Friends bool, err error) // AddFriendRequest 增加或者更新好友申请 AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) // BecomeFriend 先判断是否在好友表,如果在则不插入 @@ -127,15 +128,44 @@ func NewFriendDatabase(db *gorm.DB) *FriendDatabase { } // CheckIn 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true) -func (f *FriendDatabase) CheckIn(ctx context.Context, user1, user2 string) (err error, inUser1Friends bool, inUser2Friends bool) { +func (f *FriendDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Friends bool, inUser2Friends bool, err error) { + friends, err := f.friend.FindUserState(ctx, userID1, userID2) + for _, v := range friends { + if v.OwnerUserID == userID1 && v.FriendUserID == userID2 { + inUser1Friends = true + } + if v.OwnerUserID == userID2 && v.FriendUserID == userID1 { + inUser2Friends = true + } + } + return } // AddFriendRequest 增加或者更新好友申请 func (f *FriendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) { + } // BecomeFriend 先判断是否在好友表,如果在则不插入 -func (f *FriendDatabase) BecomeFriend(ctx context.Context, friends []*table.FriendModel) (err error) { +func (f *FriendDatabase) BecomeFriend(ctx context.Context, ownerUserID string, friends []*table.FriendModel) (err error) { + return f.friend.DB.Transaction(func(tx *gorm.DB) error { + //先find 找出重复的 去掉重复的 + friendUserIDs := make([]string, 0, len(friends)) + for _, v := range friends { + friendUserIDs = append(friendUserIDs, v.FriendUserID) + } + fs1, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs, tx) + if err != nil { + return err + } + fs2, err := f.friend.FindReversalFriends(ctx, ownerUserID, friendUserIDs, tx) + if err != nil { + return err + } + + return nil + }) + } // RefuseFriendRequest 拒绝好友申请 @@ -170,6 +200,14 @@ func (f *FriendDatabase) FindFriendRequestFromMe(ctx context.Context, userID str func (f *FriendDatabase) FindFriendRequestToMe(ctx context.Context, userID string, pageNumber, showNumber int32) (friends []*table.FriendRequestModel, err error) { } -// FindFriends 获取某人指定好友的信息 +// FindFriends 获取某人指定好友的信息 如果有一个不存在也返回错误 func (f *FriendDatabase) FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*table.FriendModel, err error) { + friends, err = f.friend.Find(ctx, ownerUserID, friendUserIDs) + if err != nil { + return + } + if len(friends) != len(friendUserIDs) { + err = constant.ErrRecordNotFound.Wrap() + } + return } diff --git a/pkg/common/db/controller/user.go b/pkg/common/db/controller/user.go index 0fefe77c6..1f904e8bf 100644 --- a/pkg/common/db/controller/user.go +++ b/pkg/common/db/controller/user.go @@ -1,6 +1,7 @@ package controller import ( + "Open_IM/pkg/common/constant" "Open_IM/pkg/common/db/relation" "Open_IM/pkg/common/db/table" "context" @@ -10,11 +11,17 @@ import ( type UserInterface interface { //获取指定用户的信息 如果有记录未找到 也返回错误 Find(ctx context.Context, userIDs []string) (users []*table.UserModel, err error) + //插入 Create(ctx context.Context, users []*table.UserModel) error + //更新 Update(ctx context.Context, users []*table.UserModel) (err error) + //更新带零值的 UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) + //通过名字搜索 GetByName(ctx context.Context, userName string, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) + //通过名字和id搜索 GetByNameAndID(ctx context.Context, content string, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) + //获取,如果没找到,不不返回错误 Get(ctx context.Context, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) //userIDs是否存在 只要有一个存在就为true IsExist(ctx context.Context, userIDs []string) (exist bool, err error) @@ -30,24 +37,29 @@ func (u *UserController) Find(ctx context.Context, userIDs []string) (users []*t func (u *UserController) Create(ctx context.Context, users []*table.UserModel) error { return u.database.Create(ctx, users) } -func (u *UserController) Take(ctx context.Context, userID string) (user *table.UserModel, err error) { - return u.database.Take(ctx, userID) -} + func (u *UserController) Update(ctx context.Context, users []*table.UserModel) (err error) { return u.database.Update(ctx, users) } func (u *UserController) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) { return u.database.UpdateByMap(ctx, userID, args) } + func (u *UserController) GetByName(ctx context.Context, userName string, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) { return u.database.GetByName(ctx, userName, showNumber, pageNumber) } + func (u *UserController) GetByNameAndID(ctx context.Context, content string, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) { return u.database.GetByNameAndID(ctx, content, showNumber, pageNumber) } + func (u *UserController) Get(ctx context.Context, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) { return u.database.Get(ctx, showNumber, pageNumber) } + +func (u *UserController) IsExist(ctx context.Context, userIDs []string) (exist bool, err error) { + return u.IsExist(ctx, userIDs) +} func NewUserController(db *gorm.DB) *UserController { controller := &UserController{database: newUserDatabase(db)} return controller @@ -56,12 +68,12 @@ func NewUserController(db *gorm.DB) *UserController { type UserDatabaseInterface interface { Find(ctx context.Context, userIDs []string) (users []*table.UserModel, err error) Create(ctx context.Context, users []*table.UserModel) error - Take(ctx context.Context, userID string) (user *table.UserModel, err error) Update(ctx context.Context, users []*table.UserModel) (err error) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) GetByName(ctx context.Context, userName string, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) GetByNameAndID(ctx context.Context, content string, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) Get(ctx context.Context, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) + IsExist(ctx context.Context, userIDs []string) (exist bool, err error) } type UserDatabase struct { @@ -76,16 +88,22 @@ func newUserDatabase(db *gorm.DB) *UserDatabase { return database } +// 获取指定用户的信息 如果有记录未找到 也返回错误 func (u *UserDatabase) Find(ctx context.Context, userIDs []string) (users []*table.UserModel, err error) { - return u.sqlDB.Find(ctx, userIDs) + users, err = u.sqlDB.Find(ctx, userIDs) + if err != nil { + return + } + if len(users) != len(userIDs) { + err = constant.ErrRecordNotFound.Wrap() + } + return } -func (u *UserDatabase) Create(ctx context.Context, users []*table.UserModel) error { +func (u *UserDatabase) Create(ctx context.Context, users []*table.UserModel) (err error) { return u.sqlDB.Create(ctx, users) } -func (u *UserDatabase) Take(ctx context.Context, userID string) (user *table.UserModel, err error) { - return u.sqlDB.Take(ctx, userID) -} + func (u *UserDatabase) Update(ctx context.Context, users []*table.UserModel) (err error) { return u.sqlDB.Update(ctx, users) } @@ -98,6 +116,20 @@ func (u *UserDatabase) GetByName(ctx context.Context, userName string, showNumbe func (u *UserDatabase) GetByNameAndID(ctx context.Context, content string, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) { return u.sqlDB.GetByNameAndID(ctx, content, showNumber, pageNumber) } + +// 获取,如果没找到,不返回错误 func (u *UserDatabase) Get(ctx context.Context, showNumber, pageNumber int32) (users []*table.UserModel, count int64, err error) { return u.sqlDB.Get(ctx, showNumber, pageNumber) } + +// userIDs是否存在 只要有一个存在就为true +func (u *UserDatabase) IsExist(ctx context.Context, userIDs []string) (exist bool, err error) { + users, err := u.sqlDB.Find(ctx, userIDs) + if err != nil { + return + } + if len(users) > 0 { + return true, nil + } + return false, nil +} diff --git a/pkg/common/db/relation/friend_model_k.go b/pkg/common/db/relation/friend_model_k.go index b270fef99..05e889841 100644 --- a/pkg/common/db/relation/friend_model_k.go +++ b/pkg/common/db/relation/friend_model_k.go @@ -30,11 +30,11 @@ type FriendUser struct { Nickname string `gorm:"column:name;size:255"` } -func (f *FriendGorm) Create(ctx context.Context, friends []*table.FriendModel) (err error) { +func (f *FriendGorm) Create(ctx context.Context, friends []*table.FriendModel, tx ...*gorm.DB) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetSelfFuncName(), err, "friends", friends) }() - return utils.Wrap(f.DB.Model(&table.FriendModel{}).Create(&friends).Error, "") + return utils.Wrap(getDBConn(f.DB, tx).Model(&table.FriendModel{}).Create(&friends).Error, "") } func (f *FriendGorm) Delete(ctx context.Context, ownerUserID string, friendUserIDs string) (err error) { @@ -52,7 +52,7 @@ func (f *FriendGorm) UpdateByMap(ctx context.Context, ownerUserID string, args m return utils.Wrap(f.DB.Model(&table.FriendModel{}).Where("owner_user_id = ?", ownerUserID).Updates(args).Error, "") } -func (f *FriendGorm) Update(ctx context.Context, friends []*table.FriendModel) (err error) { +func (f *FriendGorm) Update(ctx context.Context, friends []*table.FriendModel, tx ...*gorm.DB) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetSelfFuncName(), err, "friends", friends) }() @@ -92,3 +92,19 @@ func (f *FriendGorm) FindUserState(ctx context.Context, userID1, userID2 string) }() return friends, utils.Wrap(f.DB.Model(&table.FriendModel{}).Where("(owner_user_id = ? and friend_user_id = ?) or (owner_user_id = ? and friend_user_id = ?)", userID1, userID2, userID2, userID1).Find(&friends).Error, "") } + +// 获取 owner的好友列表 +func (f *FriendGorm) FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, tx ...*gorm.DB) (friends []*table.FriendModel, err error) { + defer func() { + tracelog.SetCtxDebug(ctx, utils.GetSelfFuncName(), err, "friendUserIDs", friendUserIDs, "friends", friends) + }() + return friends, utils.Wrap(getDBConn(f.DB, tx).Where("owner_user_id = ? AND friend_user_id in (?)", ownerUserID, friendUserIDs).Find(&friends).Error, "") +} + +// 获取哪些人添加了friendUserID +func (f *FriendGorm) FindReversalFriends(ctx context.Context, friendUserID string, ownerUserIDs []string, tx ...*gorm.DB) (friends []*table.FriendModel, err error) { + defer func() { + tracelog.SetCtxDebug(ctx, utils.GetSelfFuncName(), err, "friendUserID", friendUserID, "friends", friends) + }() + return friends, utils.Wrap(getDBConn(f.DB, tx).Where("friend_user_id = ? AND owner_user_id in (?)", friendUserID, ownerUserIDs).Find(&friends).Error, "") +} diff --git a/pkg/common/db/relation/user_model_k.go b/pkg/common/db/relation/user_model_k.go index ce36d44fa..37dfb4e64 100644 --- a/pkg/common/db/relation/user_model_k.go +++ b/pkg/common/db/relation/user_model_k.go @@ -19,12 +19,11 @@ func NewUserGorm(db *gorm.DB) *UserGorm { return &user } -func (u *UserGorm) Create(ctx context.Context, users []*table.UserModel) (err error) { +func (u *UserGorm) Create(ctx context.Context, users []*table.UserModel, tx ...*gorm.DB) (err error) { defer func() { tracelog.SetCtxDebug(ctx, utils.GetFuncName(1), err, "users", users) }() - err = utils.Wrap(u.DB.Model(&table.UserModel{}).Create(&users).Error, "") - return err + return utils.Wrap(getDBConn(u.DB, tx).Model(&table.UserModel{}).Create(&users).Error, "") } func (u *UserGorm) UpdateByMap(ctx context.Context, userID string, args map[string]interface{}) (err error) {