From 95d75ad7edf761a19823395ae6ae3c61e5132b8b Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Fri, 30 Jun 2023 14:38:19 +0800 Subject: [PATCH] middleware checker --- pkg/a2r/api2rpc.go | 3 ++- pkg/checker/check.go | 5 +++++ pkg/common/mw/rpc_server_interceptor.go | 6 ++++++ 3 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 pkg/checker/check.go diff --git a/pkg/a2r/api2rpc.go b/pkg/a2r/api2rpc.go index bb1cb1fac..4cfa3ee4c 100644 --- a/pkg/a2r/api2rpc.go +++ b/pkg/a2r/api2rpc.go @@ -2,6 +2,7 @@ package a2r import ( "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/checker" "github.com/OpenIMSDK/Open-IM-Server/pkg/apiresp" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" @@ -21,7 +22,7 @@ func Call[A, B, C any]( apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap()) // 参数错误 return } - if check, ok := any(&req).(interface{ Check() error }); ok { + if check, ok := any(&req).(checker.Checker); ok { if err := check.Check(); err != nil { log.ZWarn(c, "custom check error", err, "req", req) apiresp.GinError(c, errs.ErrArgs.Wrap(err.Error())) // 参数校验失败 diff --git a/pkg/checker/check.go b/pkg/checker/check.go new file mode 100644 index 000000000..f59986ec7 --- /dev/null +++ b/pkg/checker/check.go @@ -0,0 +1,5 @@ +package checker + +type Checker interface { + Check() error +} diff --git a/pkg/common/mw/rpc_server_interceptor.go b/pkg/common/mw/rpc_server_interceptor.go index 6d2df5459..933ab88f1 100644 --- a/pkg/common/mw/rpc_server_interceptor.go +++ b/pkg/common/mw/rpc_server_interceptor.go @@ -3,6 +3,7 @@ package mw import ( "context" "fmt" + "github.com/OpenIMSDK/Open-IM-Server/pkg/checker" "math" "runtime" "strings" @@ -91,6 +92,11 @@ func RpcServerInterceptor(ctx context.Context, req interface{}, info *grpc.Unary return nil, status.New(codes.InvalidArgument, err.Error()).Err() } } + if err := req.(checker.Checker); err != nil { + if err := err.Check(); err != nil { + return nil, status.New(codes.InvalidArgument, err.Error()).Err() + } + } log.ZInfo(ctx, "rpc server req", "funcName", funcName, "req", rpcString(req)) resp, err = handler(ctx, req) if err == nil {