package mw import ( "context" "fmt" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/log" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/mw/specialerror" "github.com/OpenIMSDK/Open-IM-Server/pkg/errs" "github.com/OpenIMSDK/Open-IM-Server/pkg/proto/wrapperspb" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "math" "runtime/debug" ) const OperationID = "operationID" const OpUserID = "opUserID" func rpcString(v interface{}) string { if s, ok := v.(interface{ String() string }); ok { return s.String() } return fmt.Sprintf("%+v", v) } func rpcServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { var operationID string defer func() { if r := recover(); r != nil { log.ZError(ctx, "rpc panic", nil, "FullMethod", info.FullMethod, "type:", fmt.Sprintf("%T", r), "panic:", r, string(debug.Stack())) } }() funcName := info.FullMethod md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, status.New(codes.InvalidArgument, "missing metadata").Err() } if opts := md.Get(OperationID); len(opts) != 1 || opts[0] == "" { return nil, status.New(codes.InvalidArgument, "operationID error").Err() } else { operationID = opts[0] } var opUserID string if opts := md.Get(OpUserID); len(opts) == 1 { opUserID = opts[0] } ctx = context.WithValue(ctx, OperationID, operationID) ctx = context.WithValue(ctx, OpUserID, opUserID) log.ZInfo(ctx, "rpc server req", "funcName", funcName, "req", rpcString(req)) resp, err = handler(ctx, req) if err == nil { log.ZInfo(ctx, "rpc server resp", "funcName", funcName, "resp", rpcString(resp)) return resp, nil } unwrap := errs.Unwrap(err) codeErr := specialerror.ErrCode(unwrap) if codeErr == nil { log.ZError(ctx, "rpc InternalServer error", err, "req", req) codeErr = errs.ErrInternalServer } code := codeErr.Code() if code <= 0 || code > math.MaxUint32 { log.ZError(ctx, "rpc UnknownError", err, "rpc UnknownCode:", code) code = errs.ServerInternalError } grpcStatus := status.New(codes.Code(code), codeErr.Msg()) if unwrap != err { stack := fmt.Sprintf("%+v", err) if details, err := grpcStatus.WithDetails(wrapperspb.String(stack)); err == nil { grpcStatus = details } } log.ZError(ctx, "rpc server resp", err, "funcName", funcName) return nil, grpcStatus.Err() } func GrpcServer() grpc.ServerOption { return grpc.UnaryInterceptor(rpcServerInterceptor) }