[TTS] Support set device id for tts prediction, test=tts (#3019)

pull/3028/head
MistEO 3 years ago committed by GitHub
parent 817263fd30
commit 319c805968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -490,6 +490,7 @@ def get_predictor(
device: str='cpu', device: str='cpu',
# for gpu # for gpu
use_trt: bool=False, use_trt: bool=False,
device_id: int=0,
# for trt # for trt
use_dynamic_shape: bool=True, use_dynamic_shape: bool=True,
min_subgraph_size: int=5, min_subgraph_size: int=5,
@ -505,6 +506,7 @@ def get_predictor(
params_file (os.PathLike): name of params_file. params_file (os.PathLike): name of params_file.
device (str): Choose the device you want to run, it can be: cpu/gpu, default is cpu. device (str): Choose the device you want to run, it can be: cpu/gpu, default is cpu.
use_trt (bool): whether to use TensorRT or not in GPU. use_trt (bool): whether to use TensorRT or not in GPU.
device_id (int): Choose your device id, only valid when the device is gpu, default 0.
use_dynamic_shape (bool): use dynamic shape or not in TensorRT. use_dynamic_shape (bool): use dynamic shape or not in TensorRT.
use_mkldnn (bool): whether to use MKLDNN or not in CPU. use_mkldnn (bool): whether to use MKLDNN or not in CPU.
cpu_threads (int): num of thread when use CPU. cpu_threads (int): num of thread when use CPU.
@ -521,7 +523,7 @@ def get_predictor(
config.enable_memory_optim() config.enable_memory_optim()
config.switch_ir_optim(True) config.switch_ir_optim(True)
if device == "gpu": if device == "gpu":
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, device_id)
else: else:
config.disable_gpu() config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads) config.set_cpu_math_library_num_threads(cpu_threads)

Loading…
Cancel
Save