diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index 3296d0e1e..5e4d273ed 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -77,7 +77,7 @@ def ort_predict(args): else: phone_ids = np.random.randint(1, 266, size=(T, )) am_input_feed.update({'text': phone_ids}) - if am_dataset in {"aishell3", "vctk", "mix"}: + if am_dataset in {"aishell3", "vctk", "mix", "canton"}: am_input_feed.update({'spk_id': spk_id}) elif am_name == 'speedyspeech': phone_ids = np.random.randint(1, 92, size=(T, )) @@ -112,7 +112,7 @@ def ort_predict(args): part_phone_ids = phone_ids[i].numpy() if am_name == 'fastspeech2': am_input_feed.update({'text': part_phone_ids}) - if am_dataset in {"aishell3", "vctk", "mix"}: + if am_dataset in {"aishell3", "vctk", "mix", "canton"}: am_input_feed.update({'spk_id': spk_id}) elif am_name == 'speedyspeech': part_tone_ids = frontend_dict['tone_ids'][i].numpy()