[TTS] Support set device id for tts prediction, test=tts (#3019)

pull/3028/head
MistEO 3 years ago committed by GitHub
parent 817263fd30
commit 319c805968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -490,6 +490,7 @@ def get_predictor(
device: str='cpu',
# for gpu
use_trt: bool=False,
device_id: int=0,
# for trt
use_dynamic_shape: bool=True,
min_subgraph_size: int=5,
@ -505,6 +506,7 @@ def get_predictor(
params_file (os.PathLike): name of params_file.
device (str): Choose the device you want to run, it can be: cpu/gpu, default is cpu.
use_trt (bool): whether to use TensorRT or not in GPU.
device_id (int): Choose your device id, only valid when the device is gpu, default 0.
use_dynamic_shape (bool): use dynamic shape or not in TensorRT.
use_mkldnn (bool): whether to use MKLDNN or not in CPU.
cpu_threads (int): num of thread when use CPU.
@ -521,7 +523,7 @@ def get_predictor(
config.enable_memory_optim()
config.switch_ir_optim(True)
if device == "gpu":
config.enable_use_gpu(100, 0)
config.enable_use_gpu(100, device_id)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)

Loading…
Cancel
Save