diff --git a/auto/api/v1/core.go b/auto/api/v1/core.go index 159e938f..3489568e 100644 --- a/auto/api/v1/core.go +++ b/auto/api/v1/core.go @@ -31,7 +31,6 @@ type Core interface { SendUserWhisper(*web.SendWhisperReq) mir.Error ReadMessage(*web.ReadMessageReq) mir.Error GetMessages(*web.GetMessagesReq) (*web.GetMessagesResp, mir.Error) - GetUnreadMsgCount(*web.GetUnreadMsgCountReq) (*web.GetUnreadMsgCountResp, mir.Error) GetUserInfo(*web.UserInfoReq) (*web.UserInfoResp, mir.Error) SyncSearchIndex(*web.SyncSearchIndexReq) mir.Error @@ -229,20 +228,6 @@ func RegisterCoreServant(e *gin.Engine, s Core) { resp, err := s.GetMessages(req) s.Render(c, resp, err) }) - router.Handle("GET", "/user/msgcount/unread", func(c *gin.Context) { - select { - case <-c.Request.Context().Done(): - return - default: - } - req := new(web.GetUnreadMsgCountReq) - if err := s.Bind(c, req); err != nil { - s.Render(c, nil, err) - return - } - resp, err := s.GetUnreadMsgCount(req) - s.Render(c, resp, err) - }) router.Handle("GET", "/user/info", func(c *gin.Context) { select { case <-c.Request.Context().Done(): @@ -332,10 +317,6 @@ func (UnimplementedCoreServant) GetMessages(req *web.GetMessagesReq) (*web.GetMe return nil, mir.Errorln(http.StatusNotImplemented, http.StatusText(http.StatusNotImplemented)) } -func (UnimplementedCoreServant) GetUnreadMsgCount(req *web.GetUnreadMsgCountReq) (*web.GetUnreadMsgCountResp, mir.Error) { - return nil, mir.Errorln(http.StatusNotImplemented, http.StatusText(http.StatusNotImplemented)) -} - func (UnimplementedCoreServant) GetUserInfo(req *web.UserInfoReq) (*web.UserInfoResp, mir.Error) { return nil, mir.Errorln(http.StatusNotImplemented, http.StatusText(http.StatusNotImplemented)) } diff --git a/auto/api/v1/relax.go b/auto/api/v1/relax.go new file mode 100644 index 00000000..c1e64ab4 --- /dev/null +++ b/auto/api/v1/relax.go @@ -0,0 +1,82 @@ +// Code generated by go-mir. DO NOT EDIT. +// versions: +// - mir v4.0.0 + +package v1 + +import ( + "net/http" + + "github.com/alimy/mir/v4" + "github.com/gin-gonic/gin" + "github.com/rocboss/paopao-ce/internal/model/web" +) + +type Relax interface { + _default_ + + // Chain provide handlers chain for gin + Chain() gin.HandlersChain + + GetUnreadMsgCount(*web.GetUnreadMsgCountReq) (*web.GetUnreadMsgCountResp, mir.Error) + + mustEmbedUnimplementedRelaxServant() +} + +type RelaxChain interface { + ChainGetUnreadMsgCount() gin.HandlersChain + + mustEmbedUnimplementedRelaxChain() +} + +// RegisterRelaxServant register Relax servant to gin +func RegisterRelaxServant(e *gin.Engine, s Relax, m ...RelaxChain) { + var cc RelaxChain + if len(m) > 0 { + cc = m[0] + } else { + cc = &UnimplementedRelaxChain{} + } + router := e.Group("v1") + // use chain for router + middlewares := s.Chain() + router.Use(middlewares...) + + // register routes info to router + router.Handle("GET", "/user/msgcount/unread", append(cc.ChainGetUnreadMsgCount(), func(c *gin.Context) { + select { + case <-c.Request.Context().Done(): + return + default: + } + req := new(web.GetUnreadMsgCountReq) + if err := s.Bind(c, req); err != nil { + s.Render(c, nil, err) + return + } + resp, err := s.GetUnreadMsgCount(req) + s.Render(c, resp, err) + })...) +} + +// UnimplementedRelaxServant can be embedded to have forward compatible implementations. +type UnimplementedRelaxServant struct{} + +func (UnimplementedRelaxServant) Chain() gin.HandlersChain { + return nil +} + +func (UnimplementedRelaxServant) GetUnreadMsgCount(req *web.GetUnreadMsgCountReq) (*web.GetUnreadMsgCountResp, mir.Error) { + return nil, mir.Errorln(http.StatusNotImplemented, http.StatusText(http.StatusNotImplemented)) +} + +func (UnimplementedRelaxServant) mustEmbedUnimplementedRelaxServant() {} + +// UnimplementedRelaxChain can be embedded to have forward compatible implementations. +type UnimplementedRelaxChain struct{} + +func (b *UnimplementedRelaxChain) ChainGetUnreadMsgCount() gin.HandlersChain { + return nil +} + +func (b *UnimplementedRelaxChain) mustEmbedUnimplementedRelaxChain() {} diff --git a/go.mod b/go.mod index 91e4a0c2..5be25f88 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/go-resty/resty/v2 v2.7.0 github.com/goccy/go-json v0.10.2 github.com/gofrs/uuid/v5 v5.0.0 - github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang-migrate/migrate/v4 v4.15.2 github.com/huaweicloud/huaweicloud-sdk-go-obs v3.23.4+incompatible github.com/jackc/pgx/v5 v5.4.2 @@ -77,6 +77,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/cel-go v0.17.1 // indirect diff --git a/go.sum b/go.sum index a8ca9086..b4ddf67e 100644 --- a/go.sum +++ b/go.sum @@ -592,6 +592,8 @@ github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzw github.com/golang-jwt/jwt/v4 v4.1.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= +github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-migrate/migrate/v4 v4.15.2 h1:vU+M05vs6jWHKDdmE1Ecwj0BznygFc4QsdRe2E/L7kc= github.com/golang-migrate/migrate/v4 v4.15.2/go.mod h1:f2toGLkYqD3JH+Todi4aZ2ZdbeUNx4sIwiOK96rE9Lw= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= diff --git a/internal/model/web/core.go b/internal/model/web/core.go index 78c04bb3..4b01e9c8 100644 --- a/internal/model/web/core.go +++ b/internal/model/web/core.go @@ -40,14 +40,6 @@ type UserInfoResp struct { Followings int64 `json:"followings"` } -type GetUnreadMsgCountReq struct { - SimpleInfo `json:"-" binding:"-"` -} - -type GetUnreadMsgCountResp struct { - Count int64 `json:"count"` -} - type GetMessagesReq BasePageReq type GetMessagesResp base.PageResp diff --git a/internal/model/web/relax.go b/internal/model/web/relax.go new file mode 100644 index 00000000..432f86c5 --- /dev/null +++ b/internal/model/web/relax.go @@ -0,0 +1,13 @@ +// Copyright 2023 ROC. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package web + +type GetUnreadMsgCountReq struct { + SimpleInfo `json:"-" binding:"-"` +} + +type GetUnreadMsgCountResp struct { + Count int64 `json:"count"` +} diff --git a/internal/servants/chain/jwt.go b/internal/servants/chain/jwt.go index ae96d550..04ccd58e 100644 --- a/internal/servants/chain/jwt.go +++ b/internal/servants/chain/jwt.go @@ -5,10 +5,11 @@ package chain import ( + "errors" "strings" "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/pkg/app" "github.com/rocboss/paopao-ce/pkg/xerror" @@ -51,10 +52,53 @@ func JWT() gin.HandlerFunc { ecode = xerror.UnauthorizedAuthNotExist } } else { - switch err.(*jwt.ValidationError).Errors { - case jwt.ValidationErrorExpired: + if errors.Is(err, jwt.ErrTokenExpired) { ecode = xerror.UnauthorizedTokenTimeout - default: + } else { + ecode = xerror.UnauthorizedTokenError + } + } + } else { + ecode = xerror.InvalidParams + } + if ecode != xerror.Success { + response := app.NewResponse(c) + response.ToErrorResponse(ecode) + c.Abort() + return + } + c.Next() + } +} + +func JwtSurely() gin.HandlerFunc { + return func(c *gin.Context) { + var ( + token string + ecode = xerror.Success + ) + if s, exist := c.GetQuery("token"); exist { + token = s + } else { + token = c.GetHeader("Authorization") + // 验证前端传过来的token格式,不为空,开头为Bearer + if token == "" || !strings.HasPrefix(token, "Bearer ") { + response := app.NewResponse(c) + response.ToErrorResponse(xerror.UnauthorizedTokenError) + c.Abort() + return + } + // 验证通过,提取有效部分(除去Bearer) + token = token[7:] + } + if token != "" { + if claims, err := app.ParseToken(token); err == nil { + c.Set("UID", claims.UID) + c.Set("USERNAME", claims.Username) + } else { + if errors.Is(err, jwt.ErrTokenExpired) { + ecode = xerror.UnauthorizedTokenTimeout + } else { ecode = xerror.UnauthorizedTokenError } } diff --git a/internal/servants/web/core.go b/internal/servants/web/core.go index de092a60..19267198 100644 --- a/internal/servants/web/core.go +++ b/internal/servants/web/core.go @@ -82,16 +82,6 @@ func (s *coreSrv) GetUserInfo(req *web.UserInfoReq) (*web.UserInfoResp, mir.Erro return resp, nil } -func (s *coreSrv) GetUnreadMsgCount(req *web.GetUnreadMsgCountReq) (*web.GetUnreadMsgCountResp, mir.Error) { - count, err := s.Ds.GetUnreadCount(req.Uid) - if err != nil { - return nil, xerror.ServerError - } - return &web.GetUnreadMsgCountResp{ - Count: count, - }, nil -} - func (s *coreSrv) GetMessages(req *web.GetMessagesReq) (*web.GetMessagesResp, mir.Error) { conditions := &ms.ConditionsT{ "receiver_user_id": req.UserId, diff --git a/internal/servants/web/relax.go b/internal/servants/web/relax.go new file mode 100644 index 00000000..ddd7f078 --- /dev/null +++ b/internal/servants/web/relax.go @@ -0,0 +1,53 @@ +// Copyright 2023 ROC. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package web + +import ( + "github.com/alimy/mir/v4" + "github.com/gin-gonic/gin" + api "github.com/rocboss/paopao-ce/auto/api/v1" + "github.com/rocboss/paopao-ce/internal/model/web" + "github.com/rocboss/paopao-ce/internal/servants/base" + "github.com/rocboss/paopao-ce/internal/servants/chain" + "github.com/rocboss/paopao-ce/pkg/xerror" +) + +var ( + _ api.Relax = (*relaxSrv)(nil) + _ api.RelaxChain = (*relaxChain)(nil) +) + +type relaxSrv struct { + api.UnimplementedRelaxServant + *base.DaoServant +} + +type relaxChain struct { + api.UnimplementedRelaxChain +} + +func (s *relaxSrv) GetUnreadMsgCount(req *web.GetUnreadMsgCountReq) (*web.GetUnreadMsgCountResp, mir.Error) { + count, err := s.Ds.GetUnreadCount(req.Uid) + if err != nil { + return nil, xerror.ServerError + } + return &web.GetUnreadMsgCountResp{ + Count: count, + }, nil +} + +func (*relaxChain) ChainGetUnreadMsgCount() gin.HandlersChain { + return gin.HandlersChain{chain.JwtSurely()} +} + +func newRelaxSrv(s *base.DaoServant) api.Relax { + return &relaxSrv{ + DaoServant: s, + } +} + +func newRelaxChain() api.RelaxChain { + return &relaxChain{} +} diff --git a/internal/servants/web/web.go b/internal/servants/web/web.go index 408b3b08..a508d02a 100644 --- a/internal/servants/web/web.go +++ b/internal/servants/web/web.go @@ -34,6 +34,7 @@ func RouteWeb(e *gin.Engine) { api.RegisterPubServant(e, newPubSrv(ds)) api.RegisterFollowshipServant(e, newFollowshipSrv(ds)) api.RegisterFriendshipServant(e, newFriendshipSrv(ds)) + api.RegisterRelaxServant(e, newRelaxSrv(ds), newRelaxChain()) // regster servants if needed by configure cfg.Be("Alipay", func() { client := conf.MustAlipayClient() diff --git a/mirc/web/v1/core.go b/mirc/web/v1/core.go index f12c7018..0ec4a68c 100644 --- a/mirc/web/v1/core.go +++ b/mirc/web/v1/core.go @@ -21,9 +21,6 @@ type Core struct { // GetUserInfo 获取当前用户信息 GetUserInfo func(Get, web.UserInfoReq) web.UserInfoResp `mir:"/user/info"` - // GetUnreadMsgCount 获取当前用户未读消息数量 - GetUnreadMsgCount func(Get, web.GetUnreadMsgCountReq) web.GetUnreadMsgCountResp `mir:"/user/msgcount/unread"` - // GetMessages 获取消息列表 GetMessages func(Get, web.GetMessagesReq) web.GetMessagesResp `mir:"/user/messages"` diff --git a/mirc/web/v1/relax.go b/mirc/web/v1/relax.go new file mode 100644 index 00000000..3bdfdb91 --- /dev/null +++ b/mirc/web/v1/relax.go @@ -0,0 +1,20 @@ +package v1 + +import ( + . "github.com/alimy/mir/v4" + . "github.com/alimy/mir/v4/engine" + "github.com/rocboss/paopao-ce/internal/model/web" +) + +func init() { + Entry[Relax]() +} + +// Relax 放宽授权的服务 +type Relax struct { + Chain `mir:"-"` + Group `mir:"v1"` + + // GetUnreadMsgCount 获取当前用户未读消息数量 + GetUnreadMsgCount func(Get, Chain, web.GetUnreadMsgCountReq) web.GetUnreadMsgCountResp `mir:"/user/msgcount/unread"` +} diff --git a/pkg/app/jwt.go b/pkg/app/jwt.go index a861faa2..28c9f439 100644 --- a/pkg/app/jwt.go +++ b/pkg/app/jwt.go @@ -7,7 +7,7 @@ package app import ( "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/rocboss/paopao-ce/internal/conf" "github.com/rocboss/paopao-ce/internal/core/ms" ) @@ -38,18 +38,15 @@ func GenerateToken(User *ms.User) (string, error) { return token, err } -func ParseToken(token string) (*Claims, error) { - tokenClaims, err := jwt.ParseWithClaims(token, &Claims{}, func(token *jwt.Token) (any, error) { +func ParseToken(token string) (res *Claims, err error) { + var tokenClaims *jwt.Token + tokenClaims, err = jwt.ParseWithClaims(token, &Claims{}, func(_ *jwt.Token) (any, error) { return GetJWTSecret(), nil }) - if err != nil { - return nil, err + if err == nil && tokenClaims != nil && tokenClaims.Valid { + res, _ = tokenClaims.Claims.(*Claims) + } else { + err = jwt.ErrTokenNotValidYet } - if tokenClaims != nil { - if claims, ok := tokenClaims.Claims.(*Claims); ok && tokenClaims.Valid { - return claims, nil - } - } - - return nil, err + return }