diff --git a/paddlespeech/server/utils/onnx_infer.py b/paddlespeech/server/utils/onnx_infer.py index 23d83c735..25802f627 100644 --- a/paddlespeech/server/utils/onnx_infer.py +++ b/paddlespeech/server/utils/onnx_infer.py @@ -30,7 +30,9 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None): # "gpu:0" providers = ['CPUExecutionProvider'] if "gpu" in sess_conf.get("device", ""): - providers = ['CUDAExecutionProvider'] + device_id = int(sess_conf["device"].split(":")[1]) + providers = [('CUDAExecutionProvider', {'device_id': device_id})] + # fastspeech2/mb_melgan can't use trt now! if sess_conf.get("use_trt", 0): providers = ['TensorrtExecutionProvider']