You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Open-IM-Server/internal/rpc/auth/auth.go

344 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"context"
"errors"
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/redisutil"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
"github.com/openimsdk/open-im-server/v3/pkg/authverify"
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
pbauth "github.com/openimsdk/protocol/auth"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/protocol/msggateway"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/tokenverify"
"google.golang.org/grpc"
)
// authServer implements the pbauth Auth gRPC service. Its methods issue and
// validate tokens, mark tokens as kicked, and ask the message gateways to
// drop a user's connections.
type authServer struct {
pbauth.UnimplementedAuthServer
// blacklistDB is queried for a user's global blacklist status.
blacklistDB controller.UserGlobalBlackDatabase
// authDatabase creates tokens and reads/writes the per-user, per-platform token map.
authDatabase controller.AuthDatabase
// RegisterCenter is used to obtain connections to the message gateway instances.
RegisterCenter discovery.SvcDiscoveryRegistry
// config is the configuration that was passed to Start.
config *Config
// userClient calls the user RPC service (CheckUser, GetUserInfo).
userClient *rpcli.UserClient
}
// Config bundles the configuration sections read by the auth RPC service.
type Config struct {
// RpcConfig is the auth service's own section; its TokenPolicy.Expire drives token lifetime.
RpcConfig config.Auth
// RedisConfig is used by Start to build the Redis client behind the token cache.
RedisConfig config.Redis
// MongodbConfig is used by Start to build the MongoDB client behind the global blacklist store.
MongodbConfig config.Mongo
// Share supplies the token secret, multi-login policy, admin user IDs and RPC register names.
Share config.Share
// Discovery is not referenced by the code visible in this file.
// NOTE(review): presumably service-discovery settings consumed elsewhere — confirm.
Discovery config.Discovery
}
// Start builds the auth service's dependencies and registers it on server.
//
// In order, it creates the Redis client, the MongoDB client, the global
// blacklist Mongo store and a connection to the user RPC service obtained
// from client. The first step that fails aborts startup and its error is
// returned unchanged. On success an authServer is registered on the given
// gRPC server and nil is returned.
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server *grpc.Server) error {
	redisClient, err := redisutil.NewRedisClient(ctx, config.RedisConfig.Build())
	if err != nil {
		return err
	}

	mongoClient, err := mongoutil.NewMongoDB(ctx, config.MongodbConfig.Build())
	if err != nil {
		return err
	}

	blackMongo, err := mgo.NewUserGlobalBlackMongo(mongoClient.GetDB())
	if err != nil {
		return err
	}

	userConn, err := client.GetConn(ctx, config.Share.RpcRegisterName.User)
	if err != nil {
		return err
	}

	// Constructors run in the same order as before: token store, blacklist
	// store, then the user client.
	expire := config.RpcConfig.TokenPolicy.Expire
	tokenDB := controller.NewAuthDatabase(
		redis2.NewTokenCacheModel(redisClient, expire),
		config.Share.Secret,
		expire,
		config.Share.MultiLogin,
		config.Share.IMAdminUserID,
	)
	blackDB := controller.NewUserGlobalBlackDatabase(blackMongo)
	users := rpcli.NewUserClient(userConn)

	srv := &authServer{
		RegisterCenter: client,
		authDatabase:   tokenDB,
		config:         config,
		blacklistDB:    blackDB,
		userClient:     users,
	}
	pbauth.RegisterAuthServer(server, srv)
	return nil
}
// GetAdminToken issues an admin-platform token for a configured IM admin user.
//
// The request is rejected with ErrNoPermission when req.Secret does not match
// the shared secret, and with ErrArgs when req.UserID is not one of the
// configured admin user IDs. The user is then checked through the user
// service before the token is created. On success the login counter is
// incremented and the response carries the token together with its lifetime
// in seconds (TokenPolicy.Expire * 24 * 60 * 60).
func (s *authServer) GetAdminToken(ctx context.Context, req *pbauth.GetAdminTokenReq) (*pbauth.GetAdminTokenResp, error) {
	share := &s.config.Share

	switch {
	case req.Secret != share.Secret:
		return nil, errs.ErrNoPermission.WrapMsg("secret invalid")
	case !datautil.Contain(req.UserID, share.IMAdminUserID...):
		return nil, errs.ErrArgs.WrapMsg("userID is error.", "userID", req.UserID, "adminUserID", share.IMAdminUserID)
	}

	err := s.userClient.CheckUser(ctx, []string{req.UserID})
	if err != nil {
		return nil, err
	}

	adminToken, err := s.authDatabase.CreateToken(ctx, req.UserID, int(constant.AdminPlatformID))
	if err != nil {
		return nil, err
	}
	prommetrics.UserLoginCounter.Inc()

	return &pbauth.GetAdminTokenResp{
		Token:             adminToken,
		ExpireTimeSeconds: s.config.RpcConfig.TokenPolicy.Expire * 24 * 60 * 60,
	}, nil
}
// GetUserToken issues a token for a regular user on the requested platform.
//
// The caller must pass authverify.CheckAdmin. The request is rejected when:
//   - req.PlatformID is the admin platform (ErrNoPermission),
//   - req.UserID is a manager/admin user ID (ErrNoPermission),
//   - the user's AppMangerLevel is at or above AppNotificationAdmin (ErrArgs),
//   - the user's global blacklist status is UserStatusBlacklist
//     (ErrUserBlocked); in that case the user is also kicked from all
//     non-admin platforms on a best-effort basis.
//
// On success the response carries the token and its lifetime in seconds
// (TokenPolicy.Expire * 24 * 60 * 60).
func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) {
	if err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID); err != nil {
		return nil, err
	}
	if req.PlatformID == constant.AdminPlatformID {
		return nil, errs.ErrNoPermission.WrapMsg("platformID invalid. platformID must not be adminPlatformID")
	}
	if authverify.IsManagerUserID(req.UserID, s.config.Share.IMAdminUserID) {
		return nil, errs.ErrNoPermission.WrapMsg("don't get Admin token")
	}

	user, err := s.userClient.GetUserInfo(ctx, req.UserID)
	if err != nil {
		return nil, err
	}
	if user.AppMangerLevel >= constant.AppNotificationAdmin {
		return nil, errs.ErrArgs.WrapMsg("app account can`t get token")
	}

	// Only blacklisted users (status=2) are denied login. Frozen users
	// (status=1) may still obtain a token and are blocked only at the
	// message send/receive layer.
	status, statusErr := s.blacklistDB.GetStatus(ctx, req.UserID)
	if statusErr != nil {
		// Keep the previous fail-open behavior (a failed lookup does not block
		// token issuance), but make the failure visible instead of dropping it.
		log.ZWarn(ctx, "GetUserToken blacklist status lookup failed", statusErr, "userID", req.UserID)
	}
	if status == model.UserStatusBlacklist {
		if kickErr := s.forceKickOffAllPlatforms(ctx, req.UserID); kickErr != nil {
			log.ZWarn(ctx, "GetUserToken forceKickOffAllPlatforms failed", kickErr, "userID", req.UserID)
		}
		detail := "user is in global blacklist, userID=" + req.UserID
		log.ZWarn(ctx, "GetUserToken is blocked", errors.New(detail), "userID", req.UserID, "status", status)
		return nil, servererrs.ErrUserBlocked.WithDetail(detail)
	}

	token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID))
	if err != nil {
		return nil, err
	}
	return &pbauth.GetUserTokenResp{
		Token:             token,
		ExpireTimeSeconds: s.config.RpcConfig.TokenPolicy.Expire * 24 * 60 * 60,
	}, nil
}
// parseToken verifies tokensString against the shared secret and the stored
// token state, returning its claims when the token is usable.
//
// Tokens belonging to manager/admin user IDs are accepted as soon as the
// signature check passes. For every other user:
//   - a global blacklist status of UserStatusBlacklist yields ErrUserBlocked
//     and triggers a best-effort kick from all non-admin platforms;
//   - the token must be present in the stored token map for the claim's user
//     and platform, otherwise ErrTokenNotExist is returned;
//   - a stored state of KickedToken yields ErrTokenKicked, and any state other
//     than NormalToken yields ErrTokenUnknown.
func (s *authServer) parseToken(ctx context.Context, tokensString string) (claims *tokenverify.Claims, err error) {
	claims, err = tokenverify.GetClaimFromToken(tokensString, authverify.Secret(s.config.Share.Secret))
	if err != nil {
		return nil, err
	}
	if authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) {
		return claims, nil
	}

	// Non-admin users are checked against the global blacklist: only
	// status=2 (blacklist) is rejected; status=1 (frozen) still passes
	// token validation.
	status, statusErr := s.blacklistDB.GetStatus(ctx, claims.UserID)
	if statusErr != nil {
		// Keep the previous fail-open behavior (a failed lookup does not
		// reject the token), but surface the failure instead of dropping it.
		log.ZWarn(ctx, "parseToken blacklist status lookup failed", statusErr, "userID", claims.UserID)
	}
	if status == model.UserStatusBlacklist {
		if kickErr := s.forceKickOffAllPlatforms(ctx, claims.UserID); kickErr != nil {
			log.ZWarn(ctx, "parseToken forceKickOffAllPlatforms failed", kickErr, "userID", claims.UserID)
		}
		detail := "user is in global blacklist, userID=" + claims.UserID
		log.ZWarn(ctx, "parseToken is blocked", errors.New(detail), "userID", claims.UserID, "status", status)
		return nil, servererrs.ErrUserBlocked.WithDetail(detail)
	}

	tokens, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
	if err != nil {
		return nil, err
	}

	// A lookup on an empty (or nil) map simply misses, so no separate
	// emptiness check is needed before it.
	state, found := tokens[tokensString]
	if !found {
		return nil, servererrs.ErrTokenNotExist.Wrap()
	}
	switch state {
	case constant.NormalToken:
		return claims, nil
	case constant.KickedToken:
		return nil, servererrs.ErrTokenKicked.Wrap()
	default:
		return nil, errs.Wrap(errs.ErrTokenUnknown)
	}
}
// ParseToken is the RPC wrapper around parseToken. It validates req.Token and,
// when valid, reports the user ID, platform ID and expiry (Unix seconds) taken
// from the token's claims. Any validation error is returned unchanged.
func (s *authServer) ParseToken(ctx context.Context, req *pbauth.ParseTokenReq) (resp *pbauth.ParseTokenResp, err error) {
	parsed, err := s.parseToken(ctx, req.Token)
	if err != nil {
		return nil, err
	}
	return &pbauth.ParseTokenResp{
		UserID:            parsed.UserID,
		PlatformID:        int32(parsed.PlatformID),
		ExpireTimeSeconds: parsed.ExpiresAt.Unix(),
	}, nil
}
// ForceLogout kicks req.UserID off req.PlatformID. The caller must pass
// authverify.CheckAdmin; the kick itself is delegated to forceKickOff and any
// error from either step is returned unchanged.
func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) {
	err := authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID)
	if err == nil {
		err = s.forceKickOff(ctx, req.UserID, req.PlatformID)
	}
	if err != nil {
		return nil, err
	}
	return &pbauth.ForceLogoutResp{}, nil
}
// forceKickOff disconnects userID on platformID and marks all of that
// user/platform's stored tokens as kicked.
//
// Every message gateway connection returned by the registry is asked to kick
// the user offline; a failure on an individual gateway is logged and does not
// abort the operation. The stored token map is then loaded (a redis.Nil error
// is treated as "no tokens"), every entry is set to KickedToken, and the map
// is written back in a single call. Nothing is written when there are no
// tokens.
//
// It returns an error if the gateway connections cannot be obtained or if
// reading or writing the token map fails.
func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID int32) error {
	conns, err := s.RegisterCenter.GetConns(ctx, s.config.Share.RpcRegisterName.MessageGateway)
	if err != nil {
		return err
	}
	for _, conn := range conns {
		log.ZDebug(ctx, "forceKickOff", "userID", userID, "platformID", platformID)
		kickReq := &msggateway.KickUserOfflineReq{KickUserIDList: []string{userID}, PlatformID: platformID}
		if _, kickErr := msggateway.NewMsgGatewayClient(conn).KickUserOffline(ctx, kickReq); kickErr != nil {
			// Best effort: one unreachable gateway must not stop the others
			// or prevent the tokens from being invalidated below.
			log.ZError(ctx, "forceKickOff", kickErr, "kickReq", kickReq)
		}
	}

	tokens, err := s.authDatabase.GetTokensWithoutError(ctx, userID, int(platformID))
	if err != nil && !errors.Is(err, redis.Nil) {
		return err
	}
	if len(tokens) == 0 {
		return nil
	}

	// Mark everything first, then persist once. Writing inside the loop would
	// rewrite the whole map once per token and store partially-kicked states.
	for token := range tokens {
		tokens[token] = constant.KickedToken
	}
	log.ZDebug(ctx, "set token map is ", "token map", tokens, "userID", userID)
	return s.authDatabase.SetTokenMapByUidPid(ctx, userID, int(platformID), tokens)
}
// forceKickOffAllPlatforms kicks userID off every platform listed in
// constant.PlatformID2Name except the admin platform.
//
// A failure on one platform does not stop the remaining platforms from being
// processed. Map iteration order is random, so stopping at the first error
// would leave an arbitrary subset of platforms untouched. The first error
// encountered is returned; later ones are logged. Iteration stops early only
// when ctx has been cancelled or has expired.
func (s *authServer) forceKickOffAllPlatforms(ctx context.Context, userID string) error {
	var firstErr error
	for platformID := range constant.PlatformID2Name {
		if int32(platformID) == constant.AdminPlatformID {
			continue
		}
		err := s.forceKickOff(ctx, userID, int32(platformID))
		if err == nil {
			continue
		}
		if firstErr == nil {
			firstErr = err
		} else {
			// Only the first error is returned; record the others here so
			// they are not lost.
			log.ZWarn(ctx, "forceKickOffAllPlatforms additional failure", err, "userID", userID, "platformID", platformID)
		}
		if ctx.Err() != nil {
			// No point in trying further platforms with a dead context.
			break
		}
	}
	return firstErr
}
// InvalidateToken marks every stored token of req.UserID on req.PlatformID as
// kicked, except the token named by req.GetPreservedToken().
//
// A redis.Nil error from the token lookup is tolerated. If the user has no
// stored tokens for that platform (nil or empty map) an error is returned and
// nothing is written. Otherwise the updated map is written back in one call.
//
// NOTE(review): no caller authorization check is performed in this method;
// confirm that it is only reachable from trusted callers.
func (s *authServer) InvalidateToken(ctx context.Context, req *pbauth.InvalidateTokenReq) (*pbauth.InvalidateTokenResp, error) {
	m, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, int(req.PlatformID))
	if err != nil && !errors.Is(err, redis.Nil) {
		return nil, err
	}
	// Test emptiness rather than nil-ness: an empty but non-nil map carries no
	// tokens either and must not be written back.
	if len(m) == 0 {
		return nil, errs.New("token map is empty").Wrap()
	}
	log.ZDebug(ctx, "get token from redis", "userID", req.UserID, "platformID",
		req.PlatformID, "tokenMap", m)

	preserved := req.GetPreservedToken()
	for k := range m {
		if k != preserved {
			m[k] = constant.KickedToken
		}
	}
	log.ZDebug(ctx, "set token map is ", "token map", m, "userID",
		req.UserID, "token", preserved)

	if err := s.authDatabase.SetTokenMapByUidPid(ctx, req.UserID, int(req.PlatformID), m); err != nil {
		return nil, err
	}
	return &pbauth.InvalidateTokenResp{}, nil
}
// KickTokens forwards req.Tokens to the auth database's
// BatchSetTokenMapByUidPid and returns an empty response on success. Any
// error from the database call is returned unchanged.
//
// NOTE(review): no caller authorization check is performed in this method;
// confirm that it is only reachable from trusted callers.
func (s *authServer) KickTokens(ctx context.Context, req *pbauth.KickTokensReq) (*pbauth.KickTokensResp, error) {
	err := s.authDatabase.BatchSetTokenMapByUidPid(ctx, req.Tokens)
	if err != nil {
		return nil, err
	}
	return &pbauth.KickTokensResp{}, nil
}
// GetActiveDevices returns all platforms that have at least one valid
// (non-kicked) token for the user. The admin platform is never reported.
// Only the user themselves or an admin can call this
// (enforced through authverify.CheckAccessV3).
//
// Each platform appears at most once. The order of the returned devices is
// unspecified because platforms are visited by ranging over a map. A
// redis.Nil error from a platform's token lookup is treated as "no tokens on
// that platform", matching forceKickOff and InvalidateToken; any other lookup
// error aborts the call.
func (s *authServer) GetActiveDevices(ctx context.Context, req *pbauth.GetActiveDevicesReq) (*pbauth.GetActiveDevicesResp, error) {
	if err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID); err != nil {
		return nil, err
	}

	var devices []*pbauth.DeviceInfo
	for platformID, platformName := range constant.PlatformID2Name {
		if int32(platformID) == constant.AdminPlatformID {
			continue
		}
		tokens, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, platformID)
		if err != nil && !errors.Is(err, redis.Nil) {
			return nil, err
		}
		for _, state := range tokens {
			if state != constant.NormalToken {
				continue
			}
			devices = append(devices, &pbauth.DeviceInfo{
				PlatformID:   int32(platformID),
				PlatformName: platformName,
			})
			// One live token is enough to report the platform as active.
			break
		}
	}
	return &pbauth.GetActiveDevicesResp{Devices: devices}, nil
}
// KickDevice kicks the given user's device on req.PlatformID offline.
// Only the user themselves or an admin can call this
// (enforced through authverify.CheckAccessV3).
//
// The admin platform cannot be kicked, and platform IDs that are not present
// in constant.PlatformID2Name are rejected; both cases return ErrArgs. The
// kick itself is delegated to forceKickOff.
func (s *authServer) KickDevice(ctx context.Context, req *pbauth.KickDeviceReq) (*pbauth.KickDeviceResp, error) {
	err := authverify.CheckAccessV3(ctx, req.UserID, s.config.Share.IMAdminUserID)
	if err != nil {
		return nil, err
	}

	target := req.PlatformID
	if target == constant.AdminPlatformID {
		return nil, errs.ErrArgs.WrapMsg("cannot kick admin platform")
	}
	_, known := constant.PlatformID2Name[int(target)]
	if !known {
		return nil, errs.ErrArgs.WrapMsg("invalid platformID", "platformID", target)
	}

	err = s.forceKickOff(ctx, req.UserID, target)
	if err != nil {
		return nil, err
	}
	return &pbauth.KickDeviceResp{}, nil
}