fix ort_predict_e2e.py

pull/2990/head
JiehangXie 3 years ago
parent 38208346c8
commit 57cb6329e2

@ -77,7 +77,7 @@ def ort_predict(args):
else:
phone_ids = np.random.randint(1, 266, size=(T, ))
am_input_feed.update({'text': phone_ids})
if am_dataset in {"aishell3", "vctk", "mix"}:
if am_dataset in {"aishell3", "vctk", "mix", "canton"}:
am_input_feed.update({'spk_id': spk_id})
elif am_name == 'speedyspeech':
phone_ids = np.random.randint(1, 92, size=(T, ))
@ -112,7 +112,7 @@ def ort_predict(args):
part_phone_ids = phone_ids[i].numpy()
if am_name == 'fastspeech2':
am_input_feed.update({'text': part_phone_ids})
if am_dataset in {"aishell3", "vctk", "mix"}:
if am_dataset in {"aishell3", "vctk", "mix", "canton"}:
am_input_feed.update({'spk_id': spk_id})
elif am_name == 'speedyspeech':
part_tone_ids = frontend_dict['tone_ids'][i].numpy()

Loading…
Cancel
Save