From c842343ab73cc3a67981525bc4ea8c6ebebf21fd Mon Sep 17 00:00:00 2001 From: "Xinwei Xiong (cubxxw)" <3293172751nss@gmail.com> Date: Thu, 22 Feb 2024 21:38:27 +0800 Subject: [PATCH] fix: fix openim api err code --- .github/workflows/e2e-test.yml | 1 + cmd/openim-api/main.go | 2 +- internal/api/custom_validator.go | 11 ++-- internal/api/route.go | 92 +++++++++++++++----------------- internal/api/user.go | 60 ++++++++++----------- pkg/common/cmd/root.go | 1 + 6 files changed, 79 insertions(+), 88 deletions(-) diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 83303887e..b5f901d25 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -105,6 +105,7 @@ jobs: - name: Exec OpenIM E2E Test run: | + sudo make test-e2e echo "" >> ./tmp/test.md echo "## OpenIM E2E Test" >> ./tmp/test.md echo "
Command Output for OpenIM E2E Test" >> ./tmp/test.md diff --git a/cmd/openim-api/main.go b/cmd/openim-api/main.go index 41b030c4b..675fbf3a6 100644 --- a/cmd/openim-api/main.go +++ b/cmd/openim-api/main.go @@ -122,7 +122,7 @@ func run(port int, proPort int) error { util.SIGTERMExit() err := server.Shutdown(ctx) if err != nil { - return errs.Wrap(err, "shutdown err") + return errs.Wrap(err, "api shutdown err") } case <-netDone: close(netDone) diff --git a/internal/api/custom_validator.go b/internal/api/custom_validator.go index 8c5890501..1df4169e4 100644 --- a/internal/api/custom_validator.go +++ b/internal/api/custom_validator.go @@ -20,19 +20,16 @@ import ( "github.com/OpenIMSDK/protocol/constant" ) +// RequiredIf validates if the specified field is required based on the session type. func RequiredIf(fl validator.FieldLevel) bool { sessionType := fl.Parent().FieldByName("SessionType").Int() + switch sessionType { case constant.SingleChatType, constant.NotificationChatType: - if fl.FieldName() == "RecvID" { - return fl.Field().String() != "" - } + return fl.FieldName() != "RecvID" || fl.Field().String() != "" case constant.GroupChatType, constant.SuperGroupChatType: - if fl.FieldName() == "GroupID" { - return fl.Field().String() != "" - } + return fl.FieldName() != "GroupID" || fl.Field().String() != "" default: return true } - return true } diff --git a/internal/api/route.go b/internal/api/route.go index 24ed5f6bb..9288f43ff 100644 --- a/internal/api/route.go +++ b/internal/api/route.go @@ -225,6 +225,7 @@ func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.Unive return r } +// GinParseToken is a middleware that parses the token in the request header and verifies it. func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { dataBase := controller.NewAuthDatabase( cache.NewMsgCacheModel(rdb), @@ -232,57 +233,52 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { config.Config.TokenPolicy.Expire, ) return func(c *gin.Context) { - switch c.Request.Method { - case http.MethodPost: - token := c.Request.Header.Get(constant.Token) - if token == "" { - log.ZWarn(c, "header get token error", errs.ErrArgs.Wrap("header must have token")) - apiresp.GinError(c, errs.ErrArgs.Wrap("header must have token")) - c.Abort() - return - } - claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret()) - if err != nil { - log.ZWarn(c, "jwt get token error", errs.ErrTokenUnknown.Wrap()) - apiresp.GinError(c, errs.ErrTokenUnknown.Wrap()) - c.Abort() - return - } - m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID) - if err != nil { - log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap()) - apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) - c.Abort() - return - } - if len(m) == 0 { - log.ZWarn(c, "cache do not exist token error", errs.ErrTokenNotExist.Wrap()) - apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) - c.Abort() + if c.Request.Method != http.MethodPost { + c.Next() + return + } + + token := c.Request.Header.Get(constant.Token) + if token == "" { + handleGinError(c, "header get token error", errs.ErrArgs, "header must have token") + return + } + + claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret()) + if err != nil { + handleGinError(c, "jwt get token error", errs.ErrTokenUnknown, "") + return + } + + m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID) + if err != nil || len(m) == 0 { + handleGinError(c, "cache get token error", errs.ErrTokenNotExist, "") + return + } + + if v, ok := m[token]; ok { + if v == constant.KickedToken { + handleGinError(c, "cache kicked token error", errs.ErrTokenKicked, "") return - } - if v, ok := m[token]; ok { - switch v { - case constant.NormalToken: - case constant.KickedToken: - log.ZWarn(c, "cache kicked token error", errs.ErrTokenKicked.Wrap()) - apiresp.GinError(c, errs.ErrTokenKicked.Wrap()) - c.Abort() - return - default: - log.ZWarn(c, "cache unknown token error", errs.ErrTokenUnknown.Wrap()) - apiresp.GinError(c, errs.ErrTokenUnknown.Wrap()) - c.Abort() - return - } - } else { - apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) - c.Abort() + } else if v != constant.NormalToken { + handleGinError(c, "cache unknown token error", errs.ErrTokenUnknown, "") return } - c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID)) - c.Set(constant.OpUserID, claims.UserID) - c.Next() + } else { + handleGinError(c, "token does not exist error", errs.ErrTokenNotExist, "") + return } + + c.Set(constant.OpUserPlatform, constant.PlatformIDToName(claims.PlatformID)) + c.Set(constant.OpUserID, claims.UserID) + c.Next() } } + +// handleGinError logs and returns an error response through Gin context. +func handleGinError(c *gin.Context, logMessage string, errType errs.CodeError, detail string) { + wrappedErr := errType.Wrap(detail) + log.ZInfo(c, logMessage, wrappedErr) + apiresp.GinError(c, wrappedErr) + c.Abort() +} diff --git a/internal/api/user.go b/internal/api/user.go index e7bbd4bfb..e181d5b32 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -64,61 +64,57 @@ func (u *UserApi) GetUsers(c *gin.Context) { a2r.Call(user.UserClient.GetPaginationUsers, u.Client, c) } -// GetUsersOnlineStatus Get user online status. +// GetUsersOnlineStatus retrieves the online status of users. func (u *UserApi) GetUsersOnlineStatus(c *gin.Context) { var req msggateway.GetUsersOnlineStatusReq if err := c.BindJSON(&req); err != nil { - apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap()) + apiresp.GinError(c, err) return } + conns, err := u.Discov.GetConns(c, config.Config.RpcRegisterName.OpenImMessageGatewayName) if err != nil { apiresp.GinError(c, err) return } - var wsResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult - var respResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult - flag := false - - // Online push message - for _, v := range conns { - msgClient := msggateway.NewMsgGatewayClient(v) + wsResult := make([]*msggateway.GetUsersOnlineStatusResp_SuccessResult, 0) + for _, conn := range conns { + msgClient := msggateway.NewMsgGatewayClient(conn) reply, err := msgClient.GetUsersOnlineStatus(c, &req) if err != nil { - log.ZWarn(c, "GetUsersOnlineStatus rpc err", err) - - parseError := apiresp.ParseError(err) - if parseError.ErrCode == errs.NoPermissionError { - apiresp.GinError(c, err) + log.ZInfo(c, "GetUsersOnlineStatus rpc error", err) + if apiresp.ParseError(err).ErrCode == errs.NoPermissionError { + apiresp.GinError(c, errs.Wrap(err)) return } - } else { - wsResult = append(wsResult, reply.SuccessResult...) + continue } + wsResult = append(wsResult, reply.SuccessResult...) } - // Traversing the userIDs in the api request body - for _, v1 := range req.UserIDs { - flag = false - res := new(msggateway.GetUsersOnlineStatusResp_SuccessResult) - // Iterate through the online results fetched from various gateways - for _, v2 := range wsResult { - // If matches the above description on the line, and vice versa - if v2.UserID == v1 { - flag = true - res.UserID = v1 + + respResult := compileResults(req.UserIDs, wsResult) + apiresp.GinSuccess(c, respResult) +} + +// compileResults aggregates online status results for the provided userIDs. +func compileResults(userIDs []string, wsResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult) []*msggateway.GetUsersOnlineStatusResp_SuccessResult { + respResult := make([]*msggateway.GetUsersOnlineStatusResp_SuccessResult, 0, len(userIDs)) + for _, userID := range userIDs { + res := &msggateway.GetUsersOnlineStatusResp_SuccessResult{ + UserID: userID, + Status: constant.OfflineStatus, // Default to offline + } + for _, result := range wsResult { + if result.UserID == userID { res.Status = constant.OnlineStatus - res.DetailPlatformStatus = append(res.DetailPlatformStatus, v2.DetailPlatformStatus...) + res.DetailPlatformStatus = append(res.DetailPlatformStatus, result.DetailPlatformStatus...) break } } - if !flag { - res.UserID = v1 - res.Status = constant.OfflineStatus - } respResult = append(respResult, res) } - apiresp.GinSuccess(c, respResult) + return respResult } func (u *UserApi) UserRegisterCount(c *gin.Context) { diff --git a/pkg/common/cmd/root.go b/pkg/common/cmd/root.go index 52a4c97e9..a0a07a005 100644 --- a/pkg/common/cmd/root.go +++ b/pkg/common/cmd/root.go @@ -142,6 +142,7 @@ func (r *RootCmd) getPortFlag(cmd *cobra.Command) int { return port } +// GetPortFlag returns the port flag func (r *RootCmd) GetPortFlag() int { return r.port }