|
|
|
@ -30,7 +30,9 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
|
|
|
|
|
# "gpu:0"
|
|
|
|
|
providers = ['CPUExecutionProvider']
|
|
|
|
|
if "gpu" in sess_conf.get("device", ""):
|
|
|
|
|
providers = ['CUDAExecutionProvider']
|
|
|
|
|
device_id = int(sess_conf["device"].split(":")[1])
|
|
|
|
|
providers = [('CUDAExecutionProvider', {'device_id': device_id})]
|
|
|
|
|
|
|
|
|
|
# fastspeech2/mb_melgan can't use trt now!
|
|
|
|
|
if sess_conf.get("use_trt", 0):
|
|
|
|
|
providers = ['TensorrtExecutionProvider']
|
|
|
|
|