specify id, test=doc

pull/2129/head
lym0302 2 years ago
parent d66d6a05c7
commit 3d5ed00c60

@ -30,7 +30,9 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
# "gpu:0"
providers = ['CPUExecutionProvider']
if "gpu" in sess_conf.get("device", ""):
providers = ['CUDAExecutionProvider']
device_id = int(sess_conf["device"].split(":")[1])
providers = [('CUDAExecutionProvider', {'device_id': device_id})]
# fastspeech2/mb_melgan can't use trt now!
if sess_conf.get("use_trt", 0):
providers = ['TensorrtExecutionProvider']

Loading…
Cancel
Save