add mkldnn and trt config for paddleInference (#2748)

pull/2754/head
TianYuan 2 years ago committed by GitHub
parent b358eb5c99
commit 979bbd9dcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -88,20 +88,27 @@ def parse_args():
parser.add_argument("--output_dir", type=str, help="output dir")
# inference
parser.add_argument(
"--int8",
"--use_trt",
type=str2bool,
default=False,
help="Whether to use int8 inference.", )
help="whether to use TensorRT or not in GPU", )
parser.add_argument(
"--fp16",
"--use_mkldnn",
type=str2bool,
default=False,
help="Whether to use float16 inference.", )
help="whether to use MKLDNN or not in CPU.", )
parser.add_argument(
"--precision",
type=str,
default='fp32',
choices=['fp32', 'fp16', 'bf16', 'int8'],
help="mode of running")
parser.add_argument(
"--device",
default="gpu",
choices=["gpu", "cpu"],
help="Device selected for inference.", )
parser.add_argument('--cpu_threads', type=int, default=1)
args, _ = parser.parse_known_args()
return args
@ -124,7 +131,11 @@ def main():
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
params_file=args.am + ".pdiparams",
device=args.device)
device=args.device,
use_trt=args.use_trt,
use_mkldnn=args.use_mkldnn,
cpu_threads=args.cpu_threads,
precision=args.precision)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
@ -133,7 +144,11 @@ def main():
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
device=args.device,
use_trt=args.use_trt,
use_mkldnn=args.use_mkldnn,
cpu_threads=args.cpu_threads,
precision=args.precision)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@ -464,18 +464,97 @@ def voc_to_static(voc_inference,
# inference
def get_predictor(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
params_file: Optional[os.PathLike]=None,
device: str='cpu'):
def get_predictor(
model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
params_file: Optional[os.PathLike]=None,
device: str='cpu',
# for gpu
use_trt: bool=False,
# for trt
use_dynamic_shape: bool=True,
min_subgraph_size: int=5,
# for cpu
cpu_threads: int=1,
use_mkldnn: bool=False,
# for trt or mkldnn
precision: int="fp32"):
"""
Args:
model_dir (os.PathLike): root path of model.pdmodel and model.pdiparams.
model_file (os.PathLike): name of model_file.
params_file (os.PathLike): name of params_file.
device (str): Choose the device you want to run, it can be: cpu/gpu, default is cpu.
use_trt (bool): whether to use TensorRT or not in GPU.
use_dynamic_shape (bool): use dynamic shape or not in TensorRT.
use_mkldnn (bool): whether to use MKLDNN or not in CPU.
cpu_threads (int): num of thread when use CPU.
precision (str): mode of running (fp32/fp16/bf16/int8).
"""
rerun_flag = False
if device != "gpu" and use_trt:
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='gpu', but device == {}".
format(precision, device))
config = inference.Config(
str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
config.enable_memory_optim()
config.switch_ir_optim(True)
if device == "gpu":
config.enable_use_gpu(100, 0)
elif device == "cpu":
else:
config.disable_gpu()
config.enable_memory_optim()
config.set_cpu_math_library_num_threads(cpu_threads)
if use_mkldnn:
# fp32
config.enable_mkldnn()
if precision == "int8":
config.enable_mkldnn_int8({
"conv2d_transpose", "conv2d", "depthwise_conv2d", "pool2d",
"transpose2", "elementwise_mul"
})
# config.enable_mkldnn_int8()
elif precision in {"fp16", "bf16"}:
config.enable_mkldnn_bfloat16()
print("MKLDNN with {}".format(precision))
if use_trt:
if precision == "bf16":
print("paddle trt does not support bf16, switching to fp16.")
precision = "fp16"
precision_map = {
"int8": inference.Config.Precision.Int8,
"fp32": inference.Config.Precision.Float32,
"fp16": inference.Config.Precision.Half,
}
assert precision in precision_map.keys()
pdtxt_name = model_file.split(".")[0] + "_" + precision + ".txt"
if use_dynamic_shape:
dynamic_shape_file = os.path.join(model_dir, pdtxt_name)
if os.path.exists(dynamic_shape_file):
config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file,
True)
# for fastspeech2
config.exp_disable_tensorrt_ops(["reshape2"])
print("trt set dynamic shape done!")
else:
# In order to avoid memory overflow when collecting dynamic shapes, it is changed to use CPU.
config.disable_gpu()
config.set_cpu_math_library_num_threads(10)
config.collect_shape_range_info(dynamic_shape_file)
print("Start collect dynamic shape...")
rerun_flag = True
if not rerun_flag:
print("Tensor RT with {}".format(precision))
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[precision],
use_static=True,
use_calib_mode=False, )
predictor = inference.create_predictor(config)
return predictor

Loading…
Cancel
Save