|
|
|
@ -29,10 +29,21 @@ from yacs.config import CfgNode
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
|
from ..log import logger
|
|
|
|
|
from ..utils import stats_wrapper
|
|
|
|
|
from paddlespeech.resource import CommonTaskResource
|
|
|
|
|
from paddlespeech.t2s.exps.syn_utils import get_am_inference
|
|
|
|
|
from paddlespeech.t2s.exps.syn_utils import get_frontend
|
|
|
|
|
from paddlespeech.t2s.modules.normalizer import ZScore
|
|
|
|
|
from paddlespeech.t2s.exps.syn_utils import get_sess
|
|
|
|
|
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
|
|
|
|
|
from paddlespeech.t2s.exps.syn_utils import run_frontend
|
|
|
|
|
from paddlespeech.t2s.utils import str2bool
|
|
|
|
|
|
|
|
|
|
__all__ = ['TTSExecutor']
|
|
|
|
|
ONNX_SUPPORT_SET = {
|
|
|
|
|
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
|
|
|
|
|
'fastspeech2_aishell3', 'fastspeech2_vctk', 'pwgan_csmsc', 'pwgan_ljspeech',
|
|
|
|
|
'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc', 'hifigan_csmsc',
|
|
|
|
|
'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk'
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TTSExecutor(BaseExecutor):
|
|
|
|
@ -142,6 +153,8 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
default=paddle.get_device(),
|
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
|
|
|
|
|
|
self.parser.add_argument('--cpu_threads', type=int, default=2)
|
|
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
'--output', type=str, default='output.wav', help='output file name')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
@ -154,6 +167,11 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
'--verbose',
|
|
|
|
|
action='store_true',
|
|
|
|
|
help='Increase logger verbosity of current task.')
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
"--use_onnx",
|
|
|
|
|
type=str2bool,
|
|
|
|
|
default=False,
|
|
|
|
|
help="whether to usen onnxruntime inference.")
|
|
|
|
|
|
|
|
|
|
def _init_from_path(
|
|
|
|
|
self,
|
|
|
|
@ -164,7 +182,7 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
phones_dict: Optional[os.PathLike]=None,
|
|
|
|
|
tones_dict: Optional[os.PathLike]=None,
|
|
|
|
|
speaker_dict: Optional[os.PathLike]=None,
|
|
|
|
|
voc: str='pwgan_csmsc',
|
|
|
|
|
voc: str='hifigan_csmsc',
|
|
|
|
|
voc_config: Optional[os.PathLike]=None,
|
|
|
|
|
voc_ckpt: Optional[os.PathLike]=None,
|
|
|
|
|
voc_stat: Optional[os.PathLike]=None,
|
|
|
|
@ -288,58 +306,111 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict)
|
|
|
|
|
|
|
|
|
|
# acoustic model
|
|
|
|
|
odim = self.am_config.n_mels
|
|
|
|
|
# model: {model_name}_{dataset}
|
|
|
|
|
am_name = am[:am.rindex('_')]
|
|
|
|
|
|
|
|
|
|
am_class = self.task_resource.get_model_class(am_name)
|
|
|
|
|
am_inference_class = self.task_resource.get_model_class(am_name +
|
|
|
|
|
'_inference')
|
|
|
|
|
|
|
|
|
|
if am_name == 'fastspeech2':
|
|
|
|
|
am = am_class(
|
|
|
|
|
idim=vocab_size,
|
|
|
|
|
odim=odim,
|
|
|
|
|
spk_num=spk_num,
|
|
|
|
|
**self.am_config["model"])
|
|
|
|
|
elif am_name == 'speedyspeech':
|
|
|
|
|
am = am_class(
|
|
|
|
|
vocab_size=vocab_size,
|
|
|
|
|
tone_size=tone_size,
|
|
|
|
|
**self.am_config["model"])
|
|
|
|
|
elif am_name == 'tacotron2':
|
|
|
|
|
am = am_class(idim=vocab_size, odim=odim, **self.am_config["model"])
|
|
|
|
|
|
|
|
|
|
am.set_state_dict(paddle.load(self.am_ckpt)["main_params"])
|
|
|
|
|
am.eval()
|
|
|
|
|
am_mu, am_std = np.load(self.am_stat)
|
|
|
|
|
am_mu = paddle.to_tensor(am_mu)
|
|
|
|
|
am_std = paddle.to_tensor(am_std)
|
|
|
|
|
am_normalizer = ZScore(am_mu, am_std)
|
|
|
|
|
self.am_inference = am_inference_class(am_normalizer, am)
|
|
|
|
|
self.am_inference.eval()
|
|
|
|
|
self.am_inference = get_am_inference(
|
|
|
|
|
am=am,
|
|
|
|
|
am_config=self.am_config,
|
|
|
|
|
am_ckpt=self.am_ckpt,
|
|
|
|
|
am_stat=self.am_stat,
|
|
|
|
|
phones_dict=self.phones_dict,
|
|
|
|
|
tones_dict=self.tones_dict,
|
|
|
|
|
speaker_dict=self.speaker_dict)
|
|
|
|
|
|
|
|
|
|
# vocoder
|
|
|
|
|
# model: {model_name}_{dataset}
|
|
|
|
|
voc_name = voc[:voc.rindex('_')]
|
|
|
|
|
voc_class = self.task_resource.get_model_class(voc_name)
|
|
|
|
|
voc_inference_class = self.task_resource.get_model_class(voc_name +
|
|
|
|
|
'_inference')
|
|
|
|
|
if voc_name != 'wavernn':
|
|
|
|
|
voc = voc_class(**self.voc_config["generator_params"])
|
|
|
|
|
voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
|
|
|
|
|
voc.remove_weight_norm()
|
|
|
|
|
voc.eval()
|
|
|
|
|
self.voc_inference = get_voc_inference(
|
|
|
|
|
voc=voc,
|
|
|
|
|
voc_config=self.voc_config,
|
|
|
|
|
voc_ckpt=self.voc_ckpt,
|
|
|
|
|
voc_stat=self.voc_stat)
|
|
|
|
|
|
|
|
|
|
def _init_from_path_onnx(self,
|
|
|
|
|
am: str='fastspeech2_csmsc',
|
|
|
|
|
am_ckpt: Optional[os.PathLike]=None,
|
|
|
|
|
phones_dict: Optional[os.PathLike]=None,
|
|
|
|
|
tones_dict: Optional[os.PathLike]=None,
|
|
|
|
|
speaker_dict: Optional[os.PathLike]=None,
|
|
|
|
|
voc: str='hifigan_csmsc',
|
|
|
|
|
voc_ckpt: Optional[os.PathLike]=None,
|
|
|
|
|
lang: str='zh',
|
|
|
|
|
device: str='cpu',
|
|
|
|
|
cpu_threads: int=2,
|
|
|
|
|
fs: int=24000):
|
|
|
|
|
if hasattr(self, 'am_sess') and hasattr(self, 'voc_sess'):
|
|
|
|
|
logger.debug('Models had been initialized.')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# am
|
|
|
|
|
if am_ckpt is None or phones_dict is None:
|
|
|
|
|
use_pretrained_am = True
|
|
|
|
|
else:
|
|
|
|
|
use_pretrained_am = False
|
|
|
|
|
|
|
|
|
|
am_tag = am + '_onnx' + '-' + lang
|
|
|
|
|
self.task_resource.set_task_model(
|
|
|
|
|
model_tag=am_tag,
|
|
|
|
|
model_type=0, # am
|
|
|
|
|
skip_download=not use_pretrained_am,
|
|
|
|
|
version=None, # default version
|
|
|
|
|
)
|
|
|
|
|
if use_pretrained_am:
|
|
|
|
|
self.am_res_path = self.task_resource.res_dir
|
|
|
|
|
self.am_ckpt = os.path.join(self.am_res_path,
|
|
|
|
|
self.task_resource.res_dict['ckpt'][0])
|
|
|
|
|
# must have phones_dict in acoustic
|
|
|
|
|
self.phones_dict = os.path.join(
|
|
|
|
|
self.am_res_path, self.task_resource.res_dict['phones_dict'])
|
|
|
|
|
self.am_fs = self.task_resource.res_dict['sample_rate']
|
|
|
|
|
logger.debug(self.am_res_path)
|
|
|
|
|
logger.debug(self.am_ckpt)
|
|
|
|
|
else:
|
|
|
|
|
self.am_ckpt = os.path.abspath(am_ckpt[0])
|
|
|
|
|
self.phones_dict = os.path.abspath(phones_dict)
|
|
|
|
|
self.am_res_path = os.path.dirname(os.path.abspath(am_ckpt))
|
|
|
|
|
self.am_fs = fs
|
|
|
|
|
|
|
|
|
|
# for speedyspeech
|
|
|
|
|
self.tones_dict = None
|
|
|
|
|
if 'tones_dict' in self.task_resource.res_dict:
|
|
|
|
|
self.tones_dict = os.path.join(
|
|
|
|
|
self.am_res_path, self.task_resource.res_dict['tones_dict'])
|
|
|
|
|
if tones_dict:
|
|
|
|
|
self.tones_dict = tones_dict
|
|
|
|
|
|
|
|
|
|
# voc
|
|
|
|
|
if voc_ckpt is None:
|
|
|
|
|
use_pretrained_voc = True
|
|
|
|
|
else:
|
|
|
|
|
use_pretrained_voc = False
|
|
|
|
|
voc_lang = lang
|
|
|
|
|
# we must use ljspeech's voc for mix am now!
|
|
|
|
|
if lang == 'mix':
|
|
|
|
|
voc_lang = 'en'
|
|
|
|
|
voc_tag = voc + '_onnx' + '-' + voc_lang
|
|
|
|
|
self.task_resource.set_task_model(
|
|
|
|
|
model_tag=voc_tag,
|
|
|
|
|
model_type=1, # vocoder
|
|
|
|
|
skip_download=not use_pretrained_voc,
|
|
|
|
|
version=None, # default version
|
|
|
|
|
)
|
|
|
|
|
if use_pretrained_voc:
|
|
|
|
|
self.voc_res_path = self.task_resource.voc_res_dir
|
|
|
|
|
self.voc_ckpt = os.path.join(
|
|
|
|
|
self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
|
|
|
|
|
logger.debug(self.voc_res_path)
|
|
|
|
|
logger.debug(self.voc_ckpt)
|
|
|
|
|
else:
|
|
|
|
|
voc = voc_class(**self.voc_config["model"])
|
|
|
|
|
voc.set_state_dict(paddle.load(self.voc_ckpt)["main_params"])
|
|
|
|
|
voc.eval()
|
|
|
|
|
voc_mu, voc_std = np.load(self.voc_stat)
|
|
|
|
|
voc_mu = paddle.to_tensor(voc_mu)
|
|
|
|
|
voc_std = paddle.to_tensor(voc_std)
|
|
|
|
|
voc_normalizer = ZScore(voc_mu, voc_std)
|
|
|
|
|
self.voc_inference = voc_inference_class(voc_normalizer, voc)
|
|
|
|
|
self.voc_inference.eval()
|
|
|
|
|
self.voc_ckpt = os.path.abspath(voc_ckpt)
|
|
|
|
|
self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
|
|
|
|
|
|
|
|
|
|
# frontend
|
|
|
|
|
self.frontend = get_frontend(
|
|
|
|
|
lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict)
|
|
|
|
|
|
|
|
|
|
self.am_sess = get_sess(
|
|
|
|
|
model_path=self.am_ckpt, device=device, cpu_threads=cpu_threads)
|
|
|
|
|
|
|
|
|
|
# vocoder
|
|
|
|
|
self.voc_sess = get_sess(
|
|
|
|
|
model_path=self.voc_ckpt, device=device, cpu_threads=cpu_threads)
|
|
|
|
|
|
|
|
|
|
def preprocess(self, input: Any, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
@ -362,40 +433,28 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
"""
|
|
|
|
|
am_name = am[:am.rindex('_')]
|
|
|
|
|
am_dataset = am[am.rindex('_') + 1:]
|
|
|
|
|
get_tone_ids = False
|
|
|
|
|
merge_sentences = False
|
|
|
|
|
frontend_st = time.time()
|
|
|
|
|
get_tone_ids = False
|
|
|
|
|
if am_name == 'speedyspeech':
|
|
|
|
|
get_tone_ids = True
|
|
|
|
|
if lang == 'zh':
|
|
|
|
|
input_ids = self.frontend.get_input_ids(
|
|
|
|
|
text,
|
|
|
|
|
merge_sentences=merge_sentences,
|
|
|
|
|
get_tone_ids=get_tone_ids)
|
|
|
|
|
phone_ids = input_ids["phone_ids"]
|
|
|
|
|
if get_tone_ids:
|
|
|
|
|
tone_ids = input_ids["tone_ids"]
|
|
|
|
|
elif lang == 'en':
|
|
|
|
|
input_ids = self.frontend.get_input_ids(
|
|
|
|
|
text, merge_sentences=merge_sentences)
|
|
|
|
|
phone_ids = input_ids["phone_ids"]
|
|
|
|
|
elif lang == 'mix':
|
|
|
|
|
input_ids = self.frontend.get_input_ids(
|
|
|
|
|
text, merge_sentences=merge_sentences)
|
|
|
|
|
phone_ids = input_ids["phone_ids"]
|
|
|
|
|
else:
|
|
|
|
|
logger.error("lang should in {'zh', 'en', 'mix'}!")
|
|
|
|
|
frontend_st = time.time()
|
|
|
|
|
frontend_dict = run_frontend(
|
|
|
|
|
frontend=self.frontend,
|
|
|
|
|
text=text,
|
|
|
|
|
merge_sentences=merge_sentences,
|
|
|
|
|
get_tone_ids=get_tone_ids,
|
|
|
|
|
lang=lang)
|
|
|
|
|
self.frontend_time = time.time() - frontend_st
|
|
|
|
|
|
|
|
|
|
self.am_time = 0
|
|
|
|
|
self.voc_time = 0
|
|
|
|
|
flags = 0
|
|
|
|
|
phone_ids = frontend_dict['phone_ids']
|
|
|
|
|
for i in range(len(phone_ids)):
|
|
|
|
|
am_st = time.time()
|
|
|
|
|
part_phone_ids = phone_ids[i]
|
|
|
|
|
# am
|
|
|
|
|
if am_name == 'speedyspeech':
|
|
|
|
|
part_tone_ids = tone_ids[i]
|
|
|
|
|
part_tone_ids = frontend_dict['tone_ids'][i]
|
|
|
|
|
mel = self.am_inference(part_phone_ids, part_tone_ids)
|
|
|
|
|
# fastspeech2
|
|
|
|
|
else:
|
|
|
|
@ -417,6 +476,62 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
self.voc_time += (time.time() - voc_st)
|
|
|
|
|
self._outputs['wav'] = wav_all
|
|
|
|
|
|
|
|
|
|
def infer_onnx(self,
|
|
|
|
|
text: str,
|
|
|
|
|
lang: str='zh',
|
|
|
|
|
am: str='fastspeech2_csmsc',
|
|
|
|
|
spk_id: int=0):
|
|
|
|
|
am_name = am[:am.rindex('_')]
|
|
|
|
|
am_dataset = am[am.rindex('_') + 1:]
|
|
|
|
|
merge_sentences = False
|
|
|
|
|
get_tone_ids = False
|
|
|
|
|
if am_name == 'speedyspeech':
|
|
|
|
|
get_tone_ids = True
|
|
|
|
|
am_input_feed = {}
|
|
|
|
|
frontend_st = time.time()
|
|
|
|
|
frontend_dict = run_frontend(
|
|
|
|
|
frontend=self.frontend,
|
|
|
|
|
text=text,
|
|
|
|
|
merge_sentences=merge_sentences,
|
|
|
|
|
get_tone_ids=get_tone_ids,
|
|
|
|
|
lang=lang,
|
|
|
|
|
to_tensor=False)
|
|
|
|
|
self.frontend_time = time.time() - frontend_st
|
|
|
|
|
phone_ids = frontend_dict['phone_ids']
|
|
|
|
|
self.am_time = 0
|
|
|
|
|
self.voc_time = 0
|
|
|
|
|
flags = 0
|
|
|
|
|
for i in range(len(phone_ids)):
|
|
|
|
|
am_st = time.time()
|
|
|
|
|
part_phone_ids = phone_ids[i]
|
|
|
|
|
if am_name == 'fastspeech2':
|
|
|
|
|
am_input_feed.update({'text': part_phone_ids})
|
|
|
|
|
if am_dataset in {"aishell3", "vctk"}:
|
|
|
|
|
# NOTE: 'spk_id' should be List[int] rather than int here!!
|
|
|
|
|
am_input_feed.update({'spk_id': [spk_id]})
|
|
|
|
|
elif am_name == 'speedyspeech':
|
|
|
|
|
part_tone_ids = frontend_dict['tone_ids'][i]
|
|
|
|
|
am_input_feed.update({
|
|
|
|
|
'phones': part_phone_ids,
|
|
|
|
|
'tones': part_tone_ids
|
|
|
|
|
})
|
|
|
|
|
mel = self.am_sess.run(output_names=None, input_feed=am_input_feed)
|
|
|
|
|
mel = mel[0]
|
|
|
|
|
self.am_time += (time.time() - am_st)
|
|
|
|
|
# voc
|
|
|
|
|
voc_st = time.time()
|
|
|
|
|
wav = self.voc_sess.run(
|
|
|
|
|
output_names=None, input_feed={'logmel': mel})
|
|
|
|
|
wav = wav[0]
|
|
|
|
|
if flags == 0:
|
|
|
|
|
wav_all = wav
|
|
|
|
|
flags = 1
|
|
|
|
|
else:
|
|
|
|
|
wav_all = np.concatenate([wav_all, wav])
|
|
|
|
|
self.voc_time += (time.time() - voc_st)
|
|
|
|
|
|
|
|
|
|
self._outputs['wav'] = wav_all
|
|
|
|
|
|
|
|
|
|
def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
|
|
|
|
|
"""
|
|
|
|
|
Output postprocess and return results.
|
|
|
|
@ -430,6 +545,20 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
def postprocess_onnx(self,
|
|
|
|
|
output: str='output.wav') -> Union[str, os.PathLike]:
|
|
|
|
|
"""
|
|
|
|
|
Output postprocess and return results.
|
|
|
|
|
This method get model output from self._outputs and convert it into human-readable results.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Union[str, os.PathLike]: Human-readable results such as texts and audio files.
|
|
|
|
|
"""
|
|
|
|
|
output = os.path.abspath(os.path.expanduser(output))
|
|
|
|
|
sf.write(output, self._outputs['wav'], samplerate=self.am_fs)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
# 命令行的入口是这里
|
|
|
|
|
def execute(self, argv: List[str]) -> bool:
|
|
|
|
|
"""
|
|
|
|
|
Command line entry.
|
|
|
|
@ -451,6 +580,8 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
lang = args.lang
|
|
|
|
|
device = args.device
|
|
|
|
|
spk_id = args.spk_id
|
|
|
|
|
use_onnx = args.use_onnx
|
|
|
|
|
cpu_threads = args.cpu_threads
|
|
|
|
|
|
|
|
|
|
if not args.verbose:
|
|
|
|
|
self.disable_task_loggers()
|
|
|
|
@ -487,7 +618,9 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
# other
|
|
|
|
|
lang=lang,
|
|
|
|
|
device=device,
|
|
|
|
|
output=output)
|
|
|
|
|
output=output,
|
|
|
|
|
use_onnx=use_onnx,
|
|
|
|
|
cpu_threads=cpu_threads)
|
|
|
|
|
task_results[id_] = res
|
|
|
|
|
except Exception as e:
|
|
|
|
|
has_exceptions = True
|
|
|
|
@ -501,6 +634,7 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
else:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
# pyton api 的入口是这里
|
|
|
|
|
@stats_wrapper
|
|
|
|
|
def __call__(self,
|
|
|
|
|
text: str,
|
|
|
|
@ -512,33 +646,57 @@ class TTSExecutor(BaseExecutor):
|
|
|
|
|
phones_dict: Optional[os.PathLike]=None,
|
|
|
|
|
tones_dict: Optional[os.PathLike]=None,
|
|
|
|
|
speaker_dict: Optional[os.PathLike]=None,
|
|
|
|
|
voc: str='pwgan_csmsc',
|
|
|
|
|
voc: str='hifigan_csmsc',
|
|
|
|
|
voc_config: Optional[os.PathLike]=None,
|
|
|
|
|
voc_ckpt: Optional[os.PathLike]=None,
|
|
|
|
|
voc_stat: Optional[os.PathLike]=None,
|
|
|
|
|
lang: str='zh',
|
|
|
|
|
device: str=paddle.get_device(),
|
|
|
|
|
output: str='output.wav'):
|
|
|
|
|
output: str='output.wav',
|
|
|
|
|
use_onnx: bool=False,
|
|
|
|
|
cpu_threads: int=2):
|
|
|
|
|
"""
|
|
|
|
|
Python API to call an executor.
|
|
|
|
|
"""
|
|
|
|
|
paddle.set_device(device)
|
|
|
|
|
self._init_from_path(
|
|
|
|
|
am=am,
|
|
|
|
|
am_config=am_config,
|
|
|
|
|
am_ckpt=am_ckpt,
|
|
|
|
|
am_stat=am_stat,
|
|
|
|
|
phones_dict=phones_dict,
|
|
|
|
|
tones_dict=tones_dict,
|
|
|
|
|
speaker_dict=speaker_dict,
|
|
|
|
|
voc=voc,
|
|
|
|
|
voc_config=voc_config,
|
|
|
|
|
voc_ckpt=voc_ckpt,
|
|
|
|
|
voc_stat=voc_stat,
|
|
|
|
|
lang=lang)
|
|
|
|
|
|
|
|
|
|
self.infer(text=text, lang=lang, am=am, spk_id=spk_id)
|
|
|
|
|
|
|
|
|
|
res = self.postprocess(output=output)
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
if not use_onnx:
|
|
|
|
|
paddle.set_device(device)
|
|
|
|
|
self._init_from_path(
|
|
|
|
|
am=am,
|
|
|
|
|
am_config=am_config,
|
|
|
|
|
am_ckpt=am_ckpt,
|
|
|
|
|
am_stat=am_stat,
|
|
|
|
|
phones_dict=phones_dict,
|
|
|
|
|
tones_dict=tones_dict,
|
|
|
|
|
speaker_dict=speaker_dict,
|
|
|
|
|
voc=voc,
|
|
|
|
|
voc_config=voc_config,
|
|
|
|
|
voc_ckpt=voc_ckpt,
|
|
|
|
|
voc_stat=voc_stat,
|
|
|
|
|
lang=lang)
|
|
|
|
|
|
|
|
|
|
self.infer(text=text, lang=lang, am=am, spk_id=spk_id)
|
|
|
|
|
res = self.postprocess(output=output)
|
|
|
|
|
return res
|
|
|
|
|
else:
|
|
|
|
|
# use onnx
|
|
|
|
|
# we use `cpu` for onnxruntime by default
|
|
|
|
|
# please see description in https://github.com/PaddlePaddle/PaddleSpeech/pull/2220
|
|
|
|
|
self.task_resource = CommonTaskResource(
|
|
|
|
|
task='tts', model_format='onnx')
|
|
|
|
|
assert (
|
|
|
|
|
am in ONNX_SUPPORT_SET and voc in ONNX_SUPPORT_SET
|
|
|
|
|
), f'the am and voc you choose, they should be in {ONNX_SUPPORT_SET}'
|
|
|
|
|
self._init_from_path_onnx(
|
|
|
|
|
am=am,
|
|
|
|
|
am_ckpt=am_ckpt,
|
|
|
|
|
phones_dict=phones_dict,
|
|
|
|
|
tones_dict=tones_dict,
|
|
|
|
|
speaker_dict=speaker_dict,
|
|
|
|
|
voc=voc,
|
|
|
|
|
voc_ckpt=voc_ckpt,
|
|
|
|
|
lang=lang,
|
|
|
|
|
device=device,
|
|
|
|
|
cpu_threads=cpu_threads)
|
|
|
|
|
self.infer_onnx(text=text, lang=lang, am=am, spk_id=spk_id)
|
|
|
|
|
res = self.postprocess_onnx(output=output)
|
|
|
|
|
return res
|
|
|
|
|