fix npu inference

pull/3804/head
warrentdrew 1 year ago
parent 7b5218edf2
commit 13ea0dae56

@ -112,7 +112,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--device", "--device",
default="gpu", default="gpu",
choices=["gpu", "cpu", "xpu"], choices=["gpu", "cpu", "xpu", "npu"],
help="Device selected for inference.", ) help="Device selected for inference.", )
parser.add_argument('--cpu_threads', type=int, default=1) parser.add_argument('--cpu_threads', type=int, default=1)

@ -591,6 +591,7 @@ def get_predictor(
config = inference.Config( config = inference.Config(
str(Path(model_dir) / model_file), str(Path(model_dir) / params_file)) str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
if device != "npu":
config.enable_memory_optim() config.enable_memory_optim()
config.switch_ir_optim(True) config.switch_ir_optim(True)
if device == "gpu": if device == "gpu":

Loading…
Cancel
Save