Merge pull request #1113 from yt605155624/tts_cli

[cli] update  am_name in tts cli
pull/1115/head
TianYuan 3 years ago committed by GitHub
commit 9066cfff88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -405,8 +405,6 @@ class TTSExecutor(BaseExecutor):
with open(self.voc_config) as f:
self.voc_config = CfgNode(yaml.safe_load(f))
# Enter the path of model root
with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
@ -501,10 +499,10 @@ class TTSExecutor(BaseExecutor):
"""
Model inference and result stored in self.output.
"""
model_name = am[:am.rindex('_')]
dataset = am[am.rindex('_') + 1:]
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
if 'speedyspeech' in model_name:
if am_name == 'speedyspeech':
get_tone_ids = True
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
@ -521,15 +519,14 @@ class TTSExecutor(BaseExecutor):
print("lang should in {'zh', 'en'}!")
# am
if 'speedyspeech' in model_name:
if am_name == 'speedyspeech':
mel = self.am_inference(phone_ids, tone_ids)
# fastspeech2
else:
# multi speaker
if dataset in {"aishell3", "vctk"}:
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.am_inference(phone_ids)

Loading…
Cancel
Save