diff --git a/cmd/api/main.go b/cmd/api/main.go index 03ca933a3..ad7d5a6f4 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -4,6 +4,7 @@ import ( "OpenIM/internal/api" "OpenIM/pkg/common/cmd" "OpenIM/pkg/common/config" + "OpenIM/pkg/common/db/cache" "OpenIM/pkg/common/log" "context" "fmt" @@ -27,12 +28,16 @@ func run(port int) error { if port == 0 { port = config.Config.Api.GinPort[0] } + cache, err := cache.NewRedis() + if err != nil { + return err + } zk, err := openKeeper.NewClient(config.Config.Zookeeper.ZkAddr, config.Config.Zookeeper.Schema, 10, config.Config.Zookeeper.UserName, config.Config.Zookeeper.Password) if err != nil { return err } log.NewPrivateLog(constant.LogFileName) - router := api.NewGinRouter(zk) + router := api.NewGinRouter(zk, cache) var address string if config.Config.Api.ListenIP != "" { address = net.JoinHostPort(config.Config.Api.ListenIP, strconv.Itoa(port)) diff --git a/internal/api/a2r/api2rpc.go b/internal/api/a2r/api2rpc.go index 56acd6908..a1da7ab26 100644 --- a/internal/api/a2r/api2rpc.go +++ b/internal/api/a2r/api2rpc.go @@ -16,13 +16,13 @@ func Call[A, B, C any]( ) { var req A if err := c.BindJSON(&req); err != nil { - log.ZWarn(c, "gin bind json error", err, req) + log.ZWarn(c, "gin bind json error", err, "req", req) apiresp.GinError(c, errs.ErrArgs.Wrap(err.Error())) // 参数错误 return } if check, ok := any(&req).(interface{ Check() error }); ok { if err := check.Check(); err != nil { - log.ZWarn(c, "custom check error", err, req) + log.ZWarn(c, "custom check error", err, "req", req) apiresp.GinError(c, errs.ErrArgs.Wrap(err.Error())) // 参数校验失败 return } diff --git a/internal/api/route.go b/internal/api/route.go index e2bf10fa9..875a80817 100644 --- a/internal/api/route.go +++ b/internal/api/route.go @@ -7,20 +7,22 @@ import ( "OpenIM/pkg/common/prome" "OpenIM/pkg/discoveryregistry" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "io" "os" ) -func NewGinRouter(zk discoveryregistry.SvcDiscoveryRegistry) *gin.Engine { +func NewGinRouter(zk discoveryregistry.SvcDiscoveryRegistry, cache redis.UniversalClient) *gin.Engine { gin.SetMode(gin.ReleaseMode) f, _ := os.Create("../logs/api.log") gin.DefaultWriter = io.MultiWriter(f) // gin.SetMode(gin.DebugMode) r := gin.New() log.Info("load config: ", config.Config) - r.Use(gin.Recovery(), mw.CorsHandler(), mw.GinParseOperationID()) + r.Use(gin.Recovery(), mw.CorsHandler(), mw.GinParseOperationID(), mw.GinParseToken(cache)) + if config.Config.Prometheus.Enable { prome.NewApiRequestCounter() prome.NewApiRequestFailedCounter() diff --git a/pkg/common/constant/constant.go b/pkg/common/constant/constant.go index bfb66ce47..e553be29b 100644 --- a/pkg/common/constant/constant.go +++ b/pkg/common/constant/constant.go @@ -273,6 +273,8 @@ const ( const OperationID = "operationID" const OpUserID = "opUserID" const ConnID = "connID" +const OpUserIDPlatformID = "platformID" +const Token = "token" const ( UnreliableNotification = 1 diff --git a/pkg/common/mw/gin.go b/pkg/common/mw/gin.go index 21008facf..463ac64af 100644 --- a/pkg/common/mw/gin.go +++ b/pkg/common/mw/gin.go @@ -1,10 +1,17 @@ package mw import ( + "OpenIM/internal/apiresp" + "OpenIM/pkg/common/config" "OpenIM/pkg/common/constant" + "OpenIM/pkg/common/db/cache" + "OpenIM/pkg/common/db/controller" + "OpenIM/pkg/common/tokenverify" + "OpenIM/pkg/errs" "bytes" "encoding/json" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "io" "net/http" ) @@ -61,3 +68,51 @@ func GinParseOperationID() gin.HandlerFunc { c.Next() } } +func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { + dataBase := controller.NewAuthDatabase(cache.NewCacheModel(rdb), config.Config.TokenPolicy.AccessSecret, config.Config.TokenPolicy.AccessExpire) + return func(c *gin.Context) { + switch c.Request.Method { + case http.MethodPost: + token := c.Request.Header.Get(constant.Token) + if token == "" { + apiresp.GinError(c, errs.ErrArgs.Wrap()) + c.Abort() + return + } + claims, err := tokenverify.GetClaimFromToken(token) + if err != nil { + apiresp.GinError(c, errs.ErrTokenUnknown.Wrap()) + c.Abort() + return + } + m, err := dataBase.GetTokensWithoutError(c, claims.UID, claims.Platform) + if err != nil { + apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) + c.Abort() + return + } + if len(m) == 0 { + apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) + c.Abort() + return + } + if v, ok := m[token]; ok { + switch v { + case constant.NormalToken: + case constant.KickedToken: + apiresp.GinError(c, errs.ErrTokenKicked.Wrap()) + c.Abort() + return + default: + apiresp.GinError(c, errs.ErrTokenUnknown.Wrap()) + c.Abort() + return + } + } + c.Set(constant.OpUserIDPlatformID, constant.PlatformNameToID(claims.Platform)) + c.Set(constant.OpUserID, claims.UID) + c.Next() + } + + } +}