From 060e33762352abcbf508e72ec6ec8fe69b1bdf04 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 2 Aug 2022 09:29:49 +0000 Subject: [PATCH 01/11] fix dataloader factory, test=asr --- paddlespeech/s2t/io/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 831830241..735d29da2 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -389,6 +389,7 @@ class DataLoaderFactory(): config['mini_batch_size'] = args.ngpu config['subsampling_factor'] = 1 config['num_encs'] = 1 + config['shortest_first'] = False elif mode == 'valid': config['manifest'] = config.dev_manifest config['train_mode'] = False From 923b0b873e193e8831988613c6d27e33716369b0 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 2 Aug 2022 09:31:01 +0000 Subject: [PATCH 02/11] fix import kws.exps.mdtc --- paddlespeech/kws/exps/__init__.py | 0 paddlespeech/kws/exps/mdtc/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 paddlespeech/kws/exps/__init__.py create mode 100644 paddlespeech/kws/exps/mdtc/__init__.py diff --git a/paddlespeech/kws/exps/__init__.py b/paddlespeech/kws/exps/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlespeech/kws/exps/mdtc/__init__.py b/paddlespeech/kws/exps/mdtc/__init__.py new file mode 100644 index 000000000..e69de29bb From dca51c5131f0fc2959b508964e36219473d46097 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 3 Aug 2022 03:04:35 +0000 Subject: [PATCH 03/11] fix wenetspeech conf, test=asr --- examples/wenetspeech/asr1/conf/conformer.yaml | 1 + examples/wenetspeech/asr1/local/export.sh | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100755 examples/wenetspeech/asr1/local/export.sh diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index d1ac20b9b..8a44db1e8 100644 --- a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -37,6 +37,7 @@ model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false + init_type: 'kaiming_uniform' # !Warning: need to convergence # https://yaml.org/type/float.html ########################################### diff --git a/examples/wenetspeech/asr1/local/export.sh b/examples/wenetspeech/asr1/local/export.sh new file mode 100755 index 000000000..6b646b469 --- /dev/null +++ b/examples/wenetspeech/asr1/local/export.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: $0 config_path ckpt_prefix jit_model_path" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_path_prefix=$2 +jit_model_export_path=$3 + +python3 -u ${BIN_DIR}/export.py \ +--ngpu ${ngpu} \ +--config ${config_path} \ +--checkpoint_path ${ckpt_path_prefix} \ +--export_path ${jit_model_export_path} + + +if [ $? -ne 0 ]; then + echo "Failed in export!" 
+ exit 1 +fi + + +exit 0 From ac2a9ec9007ab54b479811ffbff0f7aa75c6b391 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 3 Aug 2022 07:48:06 +0000 Subject: [PATCH 04/11] fix version of onnxruntime --- docs/requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index bf1486c5e..77dc609e7 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -18,7 +18,7 @@ librosa==0.8.1 loguru matplotlib nara_wpe -onnxruntime +onnxruntime-gpu==1.10.0 pandas paddlenlp paddlespeech_feat diff --git a/setup.py b/setup.py index 1cc82fa76..6809411f7 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ base = [ "loguru", "matplotlib", "nara_wpe", - "onnxruntime", + "onnxruntime-gpu==1.10.0", "pandas", "paddlenlp", "paddlespeech_feat", From 24ae1c063eb30bfa19e0fe5887cfc9c3a94eda42 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 3 Aug 2022 08:11:34 +0000 Subject: [PATCH 05/11] onnxruntime doesn't support gpu for mac --- docs/requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 77dc609e7..d6e27e226 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -18,7 +18,7 @@ librosa==0.8.1 loguru matplotlib nara_wpe -onnxruntime-gpu==1.10.0 +onnxruntime==1.10.0 pandas paddlenlp paddlespeech_feat diff --git a/setup.py b/setup.py index 6809411f7..c1be724fb 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ base = [ "loguru", "matplotlib", "nara_wpe", - "onnxruntime-gpu==1.10.0", + "onnxruntime==1.10.0", "pandas", "paddlenlp", "paddlespeech_feat", From b9ade18055c999850d7a8ea8e6b91b7b30cd0abf Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 3 Aug 2022 12:10:59 +0000 Subject: [PATCH 06/11] add onnxruntime infer for cli --- examples/aishell3/tts3/run.sh | 2 +- examples/ljspeech/tts3/run.sh | 2 +- examples/vctk/tts3/run.sh | 2 +- paddlespeech/cli/tts/infer.py | 350 +++++++++++++----- paddlespeech/resource/pretrained_models.py | 2 +- paddlespeech/t2s/exps/inference.py | 5 - paddlespeech/t2s/exps/inference_streaming.py | 16 +- paddlespeech/t2s/exps/ort_predict.py | 12 +- paddlespeech/t2s/exps/ort_predict_e2e.py | 91 ++--- .../t2s/exps/ort_predict_streaming.py | 44 ++- paddlespeech/t2s/exps/syn_utils.py | 110 +++--- paddlespeech/t2s/exps/synthesize_e2e.py | 33 +- paddlespeech/t2s/exps/synthesize_streaming.py | 17 +- paddlespeech/t2s/frontend/mix_frontend.py | 8 +- paddlespeech/t2s/frontend/phonectic.py | 9 +- paddlespeech/t2s/frontend/zh_frontend.py | 28 +- 16 files changed, 450 insertions(+), 281 deletions(-) diff --git a/examples/aishell3/tts3/run.sh b/examples/aishell3/tts3/run.sh index 868087a01..24715fee1 100755 --- a/examples/aishell3/tts3/run.sh +++ b/examples/aishell3/tts3/run.sh @@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then fi -# inference with onnxruntime, use fastspeech2 + hifigan by default +# inference with onnxruntime, use fastspeech2 + pwgan by default if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then ./local/ort_predict.sh ${train_output_path} fi diff --git a/examples/ljspeech/tts3/run.sh b/examples/ljspeech/tts3/run.sh index c4a596386..260f06c8b 100755 --- a/examples/ljspeech/tts3/run.sh +++ b/examples/ljspeech/tts3/run.sh @@ -55,7 +55,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_ljspeech fi -# inference with onnxruntime, use fastspeech2 + hifigan by default +# 
inference with onnxruntime, use fastspeech2 + pwgan by default if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then ./local/ort_predict.sh ${train_output_path} fi diff --git a/examples/vctk/tts3/run.sh b/examples/vctk/tts3/run.sh index 3d2a4a947..b45afd7be 100755 --- a/examples/vctk/tts3/run.sh +++ b/examples/vctk/tts3/run.sh @@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then fi -# inference with onnxruntime, use fastspeech2 + hifigan by default +# inference with onnxruntime, use fastspeech2 + pwgan by default if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then ./local/ort_predict.sh ${train_output_path} fi diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 65750e1a8..4d5ddb754 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -29,10 +29,21 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import stats_wrapper +from paddlespeech.resource import CommonTaskResource +from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_frontend -from paddlespeech.t2s.modules.normalizer import ZScore +from paddlespeech.t2s.exps.syn_utils import get_sess +from paddlespeech.t2s.exps.syn_utils import get_voc_inference +from paddlespeech.t2s.exps.syn_utils import run_frontend +from paddlespeech.t2s.utils import str2bool __all__ = ['TTSExecutor'] +ONNX_SUPPORT_SET = { + 'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech', + 'fastspeech2_aishell3', 'fastspeech2_vctk', 'pwgan_csmsc', 'pwgan_ljspeech', + 'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc', 'hifigan_csmsc', + 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk' +} class TTSExecutor(BaseExecutor): @@ -142,6 +153,8 @@ class TTSExecutor(BaseExecutor): default=paddle.get_device(), help='Choose device to execute model inference.') + self.parser.add_argument('--cpu_threads', type=int, default=2) + self.parser.add_argument( '--output', type=str, default='output.wav', help='output file name') self.parser.add_argument( @@ -154,6 +167,11 @@ class TTSExecutor(BaseExecutor): '--verbose', action='store_true', help='Increase logger verbosity of current task.') + self.parser.add_argument( + "--use_onnx", + type=str2bool, + default=False, + help="whether to usen onnxruntime inference.") def _init_from_path( self, @@ -164,7 +182,7 @@ class TTSExecutor(BaseExecutor): phones_dict: Optional[os.PathLike]=None, tones_dict: Optional[os.PathLike]=None, speaker_dict: Optional[os.PathLike]=None, - voc: str='pwgan_csmsc', + voc: str='hifigan_csmsc', voc_config: Optional[os.PathLike]=None, voc_ckpt: Optional[os.PathLike]=None, voc_stat: Optional[os.PathLike]=None, @@ -288,58 +306,111 @@ class TTSExecutor(BaseExecutor): lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict) # acoustic model - odim = self.am_config.n_mels - # model: {model_name}_{dataset} - am_name = am[:am.rindex('_')] - - am_class = self.task_resource.get_model_class(am_name) - am_inference_class = self.task_resource.get_model_class(am_name + - '_inference') - - if am_name == 'fastspeech2': - am = am_class( - idim=vocab_size, - odim=odim, - spk_num=spk_num, - **self.am_config["model"]) - elif am_name == 'speedyspeech': - am = am_class( - vocab_size=vocab_size, - tone_size=tone_size, - **self.am_config["model"]) - elif am_name == 'tacotron2': - am = am_class(idim=vocab_size, odim=odim, **self.am_config["model"]) - - am.set_state_dict(paddle.load(self.am_ckpt)["main_params"]) - am.eval() 
- am_mu, am_std = np.load(self.am_stat) - am_mu = paddle.to_tensor(am_mu) - am_std = paddle.to_tensor(am_std) - am_normalizer = ZScore(am_mu, am_std) - self.am_inference = am_inference_class(am_normalizer, am) - self.am_inference.eval() + self.am_inference = get_am_inference( + am=am, + am_config=self.am_config, + am_ckpt=self.am_ckpt, + am_stat=self.am_stat, + phones_dict=self.phones_dict, + tones_dict=self.tones_dict, + speaker_dict=self.speaker_dict) # vocoder - # model: {model_name}_{dataset} - voc_name = voc[:voc.rindex('_')] - voc_class = self.task_resource.get_model_class(voc_name) - voc_inference_class = self.task_resource.get_model_class(voc_name + - '_inference') - if voc_name != 'wavernn': - voc = voc_class(**self.voc_config["generator_params"]) - voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) - voc.remove_weight_norm() - voc.eval() + self.voc_inference = get_voc_inference( + voc=voc, + voc_config=self.voc_config, + voc_ckpt=self.voc_ckpt, + voc_stat=self.voc_stat) + + def _init_from_path_onnx(self, + am: str='fastspeech2_csmsc', + am_ckpt: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + voc: str='hifigan_csmsc', + voc_ckpt: Optional[os.PathLike]=None, + lang: str='zh', + device: str='cpu', + cpu_threads: int=2, + fs: int=24000): + if hasattr(self, 'am_sess') and hasattr(self, 'voc_sess'): + logger.debug('Models had been initialized.') + return + + # am + if am_ckpt is None or phones_dict is None: + use_pretrained_am = True + else: + use_pretrained_am = False + + am_tag = am + '_onnx' + '-' + lang + self.task_resource.set_task_model( + model_tag=am_tag, + model_type=0, # am + skip_download=not use_pretrained_am, + version=None, # default version + ) + if use_pretrained_am: + self.am_res_path = self.task_resource.res_dir + self.am_ckpt = os.path.join(self.am_res_path, + self.task_resource.res_dict['ckpt'][0]) + # must have phones_dict in acoustic + self.phones_dict = os.path.join( + self.am_res_path, self.task_resource.res_dict['phones_dict']) + self.am_fs = self.task_resource.res_dict['sample_rate'] + logger.debug(self.am_res_path) + logger.debug(self.am_ckpt) + else: + self.am_ckpt = os.path.abspath(am_ckpt[0]) + self.phones_dict = os.path.abspath(phones_dict) + self.am_res_path = os.path.dirname(os.path.abspath(am_ckpt)) + self.am_fs = fs + + # for speedyspeech + self.tones_dict = None + if 'tones_dict' in self.task_resource.res_dict: + self.tones_dict = os.path.join( + self.am_res_path, self.task_resource.res_dict['tones_dict']) + if tones_dict: + self.tones_dict = tones_dict + + # voc + if voc_ckpt is None: + use_pretrained_voc = True + else: + use_pretrained_voc = False + voc_lang = lang + # we must use ljspeech's voc for mix am now! 
+ if lang == 'mix': + voc_lang = 'en' + voc_tag = voc + '_onnx' + '-' + voc_lang + self.task_resource.set_task_model( + model_tag=voc_tag, + model_type=1, # vocoder + skip_download=not use_pretrained_voc, + version=None, # default version + ) + if use_pretrained_voc: + self.voc_res_path = self.task_resource.voc_res_dir + self.voc_ckpt = os.path.join( + self.voc_res_path, self.task_resource.voc_res_dict['ckpt']) + logger.debug(self.voc_res_path) + logger.debug(self.voc_ckpt) else: - voc = voc_class(**self.voc_config["model"]) - voc.set_state_dict(paddle.load(self.voc_ckpt)["main_params"]) - voc.eval() - voc_mu, voc_std = np.load(self.voc_stat) - voc_mu = paddle.to_tensor(voc_mu) - voc_std = paddle.to_tensor(voc_std) - voc_normalizer = ZScore(voc_mu, voc_std) - self.voc_inference = voc_inference_class(voc_normalizer, voc) - self.voc_inference.eval() + self.voc_ckpt = os.path.abspath(voc_ckpt) + self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt)) + + # frontend + self.frontend = get_frontend( + lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict) + + self.am_sess = get_sess( + model_path=self.am_ckpt, device=device, cpu_threads=cpu_threads) + + # vocoder + self.voc_sess = get_sess( + model_path=self.voc_ckpt, device=device, cpu_threads=cpu_threads) def preprocess(self, input: Any, *args, **kwargs): """ @@ -362,40 +433,28 @@ class TTSExecutor(BaseExecutor): """ am_name = am[:am.rindex('_')] am_dataset = am[am.rindex('_') + 1:] - get_tone_ids = False merge_sentences = False - frontend_st = time.time() + get_tone_ids = False if am_name == 'speedyspeech': get_tone_ids = True - if lang == 'zh': - input_ids = self.frontend.get_input_ids( - text, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - if get_tone_ids: - tone_ids = input_ids["tone_ids"] - elif lang == 'en': - input_ids = self.frontend.get_input_ids( - text, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - elif lang == 'mix': - input_ids = self.frontend.get_input_ids( - text, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - logger.error("lang should in {'zh', 'en', 'mix'}!") + frontend_st = time.time() + frontend_dict = run_frontend( + frontend=self.frontend, + text=text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=lang) self.frontend_time = time.time() - frontend_st - self.am_time = 0 self.voc_time = 0 flags = 0 + phone_ids = frontend_dict['phone_ids'] for i in range(len(phone_ids)): am_st = time.time() part_phone_ids = phone_ids[i] # am if am_name == 'speedyspeech': - part_tone_ids = tone_ids[i] + part_tone_ids = frontend_dict['tone_ids'][i] mel = self.am_inference(part_phone_ids, part_tone_ids) # fastspeech2 else: @@ -417,6 +476,62 @@ class TTSExecutor(BaseExecutor): self.voc_time += (time.time() - voc_st) self._outputs['wav'] = wav_all + def infer_onnx(self, + text: str, + lang: str='zh', + am: str='fastspeech2_csmsc', + spk_id: int=0): + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] + merge_sentences = False + get_tone_ids = False + if am_name == 'speedyspeech': + get_tone_ids = True + am_input_feed = {} + frontend_st = time.time() + frontend_dict = run_frontend( + frontend=self.frontend, + text=text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=lang, + to_tensor=False) + self.frontend_time = time.time() - frontend_st + phone_ids = frontend_dict['phone_ids'] + self.am_time = 0 + self.voc_time = 0 + flags = 0 + for i in 
range(len(phone_ids)): + am_st = time.time() + part_phone_ids = phone_ids[i] + if am_name == 'fastspeech2': + am_input_feed.update({'text': part_phone_ids}) + if am_dataset in {"aishell3", "vctk"}: + # NOTE: 'spk_id' should be List[int] rather than int here!! + am_input_feed.update({'spk_id': [spk_id]}) + elif am_name == 'speedyspeech': + part_tone_ids = frontend_dict['tone_ids'][i] + am_input_feed.update({ + 'phones': part_phone_ids, + 'tones': part_tone_ids + }) + mel = self.am_sess.run(output_names=None, input_feed=am_input_feed) + mel = mel[0] + self.am_time += (time.time() - am_st) + # voc + voc_st = time.time() + wav = self.voc_sess.run( + output_names=None, input_feed={'logmel': mel}) + wav = wav[0] + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = np.concatenate([wav_all, wav]) + self.voc_time += (time.time() - voc_st) + + self._outputs['wav'] = wav_all + def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]: """ Output postprocess and return results. @@ -430,6 +545,20 @@ class TTSExecutor(BaseExecutor): output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs) return output + def postprocess_onnx(self, + output: str='output.wav') -> Union[str, os.PathLike]: + """ + Output postprocess and return results. + This method get model output from self._outputs and convert it into human-readable results. + + Returns: + Union[str, os.PathLike]: Human-readable results such as texts and audio files. + """ + output = os.path.abspath(os.path.expanduser(output)) + sf.write(output, self._outputs['wav'], samplerate=self.am_fs) + return output + + # 命令行的入口是这里 def execute(self, argv: List[str]) -> bool: """ Command line entry. @@ -451,6 +580,8 @@ class TTSExecutor(BaseExecutor): lang = args.lang device = args.device spk_id = args.spk_id + use_onnx = args.use_onnx + cpu_threads = args.cpu_threads if not args.verbose: self.disable_task_loggers() @@ -487,7 +618,9 @@ class TTSExecutor(BaseExecutor): # other lang=lang, device=device, - output=output) + output=output, + use_onnx=use_onnx, + cpu_threads=cpu_threads) task_results[id_] = res except Exception as e: has_exceptions = True @@ -501,6 +634,7 @@ class TTSExecutor(BaseExecutor): else: return True + # pyton api 的入口是这里 @stats_wrapper def __call__(self, text: str, @@ -512,33 +646,57 @@ class TTSExecutor(BaseExecutor): phones_dict: Optional[os.PathLike]=None, tones_dict: Optional[os.PathLike]=None, speaker_dict: Optional[os.PathLike]=None, - voc: str='pwgan_csmsc', + voc: str='hifigan_csmsc', voc_config: Optional[os.PathLike]=None, voc_ckpt: Optional[os.PathLike]=None, voc_stat: Optional[os.PathLike]=None, lang: str='zh', device: str=paddle.get_device(), - output: str='output.wav'): + output: str='output.wav', + use_onnx: bool=False, + cpu_threads: int=2): """ Python API to call an executor. 
""" - paddle.set_device(device) - self._init_from_path( - am=am, - am_config=am_config, - am_ckpt=am_ckpt, - am_stat=am_stat, - phones_dict=phones_dict, - tones_dict=tones_dict, - speaker_dict=speaker_dict, - voc=voc, - voc_config=voc_config, - voc_ckpt=voc_ckpt, - voc_stat=voc_stat, - lang=lang) - - self.infer(text=text, lang=lang, am=am, spk_id=spk_id) - - res = self.postprocess(output=output) - - return res + if not use_onnx: + paddle.set_device(device) + self._init_from_path( + am=am, + am_config=am_config, + am_ckpt=am_ckpt, + am_stat=am_stat, + phones_dict=phones_dict, + tones_dict=tones_dict, + speaker_dict=speaker_dict, + voc=voc, + voc_config=voc_config, + voc_ckpt=voc_ckpt, + voc_stat=voc_stat, + lang=lang) + + self.infer(text=text, lang=lang, am=am, spk_id=spk_id) + res = self.postprocess(output=output) + return res + else: + # use onnx + # we use `cpu` for onnxruntime by default + # please see description in https://github.com/PaddlePaddle/PaddleSpeech/pull/2220 + self.task_resource = CommonTaskResource( + task='tts', model_format='onnx') + assert ( + am in ONNX_SUPPORT_SET and voc in ONNX_SUPPORT_SET + ), f'the am and voc you choose, they should be in {ONNX_SUPPORT_SET}' + self._init_from_path_onnx( + am=am, + am_ckpt=am_ckpt, + phones_dict=phones_dict, + tones_dict=tones_dict, + speaker_dict=speaker_dict, + voc=voc, + voc_ckpt=voc_ckpt, + lang=lang, + device=device, + cpu_threads=cpu_threads) + self.infer_onnx(text=text, lang=lang, am=am, spk_id=spk_id) + res = self.postprocess_onnx(output=output) + return res diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index d7df0e48a..43d63925b 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -1149,7 +1149,7 @@ tts_onnx_pretrained_models = { "fastspeech2_vctk_onnx-en": { '1.0': { 'url': - 'hhttps://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip', + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip', 'md5': 'd9c3a9b02204a2070504dd99f5f959bf', 'ckpt': ['fastspeech2_vctk.onnx'], diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index ba951182d..3732e0f40 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -86,11 +86,6 @@ def parse_args(): "--inference_dir", type=str, help="dir to save inference models") parser.add_argument("--output_dir", type=str, help="output dir") # inference - parser.add_argument( - "--use_trt", - type=str2bool, - default=False, - help="Whether to use inference engin TensorRT.", ) parser.add_argument( "--int8", type=str2bool, diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py index 624defc6a..5e2ce89db 100644 --- a/paddlespeech/t2s/exps/inference_streaming.py +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -27,6 +27,7 @@ from paddlespeech.t2s.exps.syn_utils import get_predictor from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output from paddlespeech.t2s.exps.syn_utils import get_voc_output +from paddlespeech.t2s.exps.syn_utils import run_frontend from paddlespeech.t2s.utils import str2bool @@ -175,14 +176,13 @@ def main(): for utt_id, sentence in sentences: with timer() as t: # frontend - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - sentence, - merge_sentences=merge_sentences, - 
get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - else: - print("lang should be 'zh' here!") + frontend_dict = run_frontend( + frontend=frontend, + text=sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=args.lang) + phone_ids = frontend_dict['phone_ids'] phones = phone_ids[0].numpy() # acoustic model orig_hs = get_am_sublayer_output( diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index 2e8596ded..bd89f74d2 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -41,17 +41,17 @@ def ort_predict(args): # am am_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.am + ".onnx", + model_path=str(Path(args.inference_dir) / (args.am + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) # vocoder voc_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.voc + ".onnx", + model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) # am warmup for T in [27, 38, 54]: diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index f33fc4128..b96da9b21 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -22,6 +22,7 @@ from timer import timer from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sess +from paddlespeech.t2s.exps.syn_utils import run_frontend from paddlespeech.t2s.utils import str2bool @@ -42,17 +43,17 @@ def ort_predict(args): fs = 24000 if am_dataset != 'ljspeech' else 22050 am_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.am + ".onnx", + model_path=str(Path(args.inference_dir) / (args.am + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) # vocoder voc_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.voc + ".onnx", + model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) merge_sentences = True @@ -78,7 +79,6 @@ def ort_predict(args): am_input_feed.update({'text': phone_ids}) if am_dataset in {"aishell3", "vctk"}: am_input_feed.update({'spk_id': spk_id}) - elif am_name == 'speedyspeech': phone_ids = np.random.randint(1, 92, size=(T, )) tone_ids = np.random.randint(1, 5, size=(T, )) @@ -93,50 +93,51 @@ def ort_predict(args): N = 0 T = 0 - merge_sentences = True + merge_sentences = False get_tone_ids = False - am_input_feed = {} if am_name == 'speedyspeech': get_tone_ids = True + am_input_feed = {} for utt_id, sentence in sentences: with timer() as t: - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - sentence, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - if get_tone_ids: - tone_ids = input_ids["tone_ids"] - elif args.lang == 'en': - input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in {'zh', 'en'}!") - # merge_sentences=True here, so we only use the first item of phone_ids - phone_ids = phone_ids[0].numpy() - if am_name == 'fastspeech2': - am_input_feed.update({'text': 
phone_ids}) - if am_dataset in {"aishell3", "vctk"}: - am_input_feed.update({'spk_id': spk_id}) - elif am_name == 'speedyspeech': - tone_ids = tone_ids[0].numpy() - am_input_feed.update({'phones': phone_ids, 'tones': tone_ids}) - mel = am_sess.run(output_names=None, input_feed=am_input_feed) - mel = mel[0] - wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) - - N += len(wav[0]) - T += t.elapse - speed = len(wav[0]) / t.elapse - rtf = fs / speed - sf.write( - str(output_dir / (utt_id + ".wav")), - np.array(wav)[0], - samplerate=fs) + frontend_dict = run_frontend( + frontend=frontend, + text=sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=args.lang) + phone_ids = frontend_dict['phone_ids'] + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i].numpy() + if am_name == 'fastspeech2': + am_input_feed.update({'text': part_phone_ids}) + if am_dataset in {"aishell3", "vctk"}: + am_input_feed.update({'spk_id': spk_id}) + elif am_name == 'speedyspeech': + part_tone_ids = frontend_dict['tone_ids'][i].numpy() + am_input_feed.update({ + 'phones': part_phone_ids, + 'tones': part_tone_ids + }) + mel = am_sess.run(output_names=None, input_feed=am_input_feed) + mel = mel[0] + wav = voc_sess.run( + output_names=None, input_feed={'logmel': mel}) + wav = wav[0] + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = np.concatenate([wav_all, wav]) + wav = wav_all + N += len(wav) + T += t.elapse + speed = len(wav) / t.elapse + rtf = fs / speed + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=fs) print( - f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + f"{utt_id}, mel: {mel.shape}, wave: {len(wav)}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." 
) print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py index d5241f1c6..0d07dcf37 100644 --- a/paddlespeech/t2s/exps/ort_predict_streaming.py +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -24,6 +24,7 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sess +from paddlespeech.t2s.exps.syn_utils import run_frontend from paddlespeech.t2s.utils import str2bool @@ -45,29 +46,33 @@ def ort_predict(args): # streaming acoustic model am_encoder_infer_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.am + "_am_encoder_infer" + ".onnx", + model_path=str( + Path(args.inference_dir) / + (args.am + '_am_encoder_infer' + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) am_decoder_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.am + "_am_decoder" + ".onnx", + model_path=str( + Path(args.inference_dir) / (args.am + '_am_decoder' + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) am_postnet_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.am + "_am_postnet" + ".onnx", + model_path=str( + Path(args.inference_dir) / (args.am + '_am_postnet' + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) am_mu, am_std = np.load(args.am_stat) # vocoder voc_sess = get_sess( - model_dir=args.inference_dir, - model_file=args.voc + ".onnx", + model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')), device=args.device, - cpu_threads=args.cpu_threads) + cpu_threads=args.cpu_threads, + use_trt=args.use_trt) # frontend warmup # Loading model cost 0.5+ seconds @@ -102,14 +107,13 @@ def ort_predict(args): for utt_id, sentence in sentences: with timer() as t: - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - sentence, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in be 'zh' here!") + frontend_dict = run_frontend( + frontend=frontend, + text=sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=args.lang) + phone_ids = frontend_dict['phone_ids'] # merge_sentences=True here, so we only use the first item of phone_ids phone_ids = phone_ids[0].numpy() orig_hs = am_encoder_infer_sess.run( diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index bade62aca..2166838cd 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -33,6 +33,8 @@ from paddlespeech.t2s.frontend.mix_frontend import MixFrontend from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.utils.dynamic_import import dynamic_import +# remove [W:onnxruntime: xxx] from ort +ort.set_default_logger_severity(3) model_alias = { # acoustic model @@ -161,13 +163,42 @@ def get_frontend(lang: str='zh', elif lang == 'mix': frontend = MixFrontend( phone_vocab_path=phones_dict, tone_vocab_path=tones_dict) - else: print("wrong lang!") - print("frontend done!") return frontend +def run_frontend(frontend: object, + text: str, + merge_sentences: bool=False, + get_tone_ids: 
bool=False, + lang: str='zh', + to_tensor: bool=True): + outs = dict() + if lang == 'zh': + input_ids = frontend.get_input_ids( + text, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + to_tensor=to_tensor) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + outs.update({'tone_ids': tone_ids}) + elif lang == 'en': + input_ids = frontend.get_input_ids( + text, merge_sentences=merge_sentences, to_tensor=to_tensor) + phone_ids = input_ids["phone_ids"] + elif lang == 'mix': + input_ids = frontend.get_input_ids( + text, merge_sentences=merge_sentences, to_tensor=to_tensor) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en', 'mix'}!") + outs.update({'phone_ids': phone_ids}) + return outs + + # dygraph def get_am_inference(am: str='fastspeech2_csmsc', am_config: CfgNode=None, @@ -180,30 +211,22 @@ def get_am_inference(am: str='fastspeech2_csmsc', with open(phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] vocab_size = len(phn_id) - print("vocab_size:", vocab_size) - tone_size = None if tones_dict is not None: with open(tones_dict, "r") as f: tone_id = [line.strip().split() for line in f.readlines()] tone_size = len(tone_id) - print("tone_size:", tone_size) - spk_num = None if speaker_dict is not None: with open(speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] spk_num = len(spk_id) - print("spk_num:", spk_num) - odim = am_config.n_mels # model: {model_name}_{dataset} am_name = am[:am.rindex('_')] am_dataset = am[am.rindex('_') + 1:] - am_class = dynamic_import(am_name, model_alias) am_inference_class = dynamic_import(am_name + '_inference', model_alias) - if am_name == 'fastspeech2': am = am_class( idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"]) @@ -228,7 +251,6 @@ def get_am_inference(am: str='fastspeech2_csmsc', am_normalizer = ZScore(am_mu, am_std) am_inference = am_inference_class(am_normalizer, am) am_inference.eval() - print("acoustic model done!") if return_am: return am_inference, am else: @@ -260,7 +282,6 @@ def get_voc_inference( voc_normalizer = ZScore(voc_mu, voc_std) voc_inference = voc_inference_class(voc_normalizer, voc) voc_inference.eval() - print("voc done!") return voc_inference @@ -342,9 +363,9 @@ def get_predictor(model_dir: Optional[os.PathLike]=None, def get_am_output( input: str, - am_predictor, - am, - frontend, + am_predictor: paddle.nn.Layer, + am: str, + frontend: object, lang: str='zh', merge_sentences: bool=True, speaker_dict: Optional[os.PathLike]=None, @@ -352,30 +373,23 @@ def get_am_output( am_name = am[:am.rindex('_')] am_dataset = am[am.rindex('_') + 1:] am_input_names = am_predictor.get_input_names() - get_tone_ids = False get_spk_id = False + get_tone_ids = False if am_name == 'speedyspeech': get_tone_ids = True if am_dataset in {"aishell3", "vctk"} and speaker_dict: get_spk_id = True spk_id = np.array([spk_id]) - if lang == 'zh': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - elif lang == 'en': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - elif lang == 'mix': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in {'zh', 'en', 'mix'}!") + + frontend_dict = run_frontend( + frontend=frontend, + text=input, + 
merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=lang) if get_tone_ids: - tone_ids = input_ids["tone_ids"] + tone_ids = frontend_dict['tone_ids'] tones = tone_ids[0].numpy() tones_handle = am_predictor.get_input_handle(am_input_names[1]) tones_handle.reshape(tones.shape) @@ -384,6 +398,7 @@ def get_am_output( spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) spk_id_handle.reshape(spk_id.shape) spk_id_handle.copy_from_cpu(spk_id) + phone_ids = frontend_dict['phone_ids'] phones = phone_ids[0].numpy() phones_handle = am_predictor.get_input_handle(am_input_names[0]) phones_handle.reshape(phones.shape) @@ -432,13 +447,13 @@ def get_streaming_am_output(input: str, lang: str='zh', merge_sentences: bool=True): get_tone_ids = False - if lang == 'zh': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - else: - print("lang should be 'zh' here!") - + frontend_dict = run_frontend( + frontend=frontend, + text=input, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=lang) + phone_ids = frontend_dict['phone_ids'] phones = phone_ids[0].numpy() am_encoder_infer_output = get_am_sublayer_output( am_encoder_infer_predictor, input=phones) @@ -455,26 +470,25 @@ def get_streaming_am_output(input: str, # onnx -def get_sess(model_dir: Optional[os.PathLike]=None, - model_file: Optional[os.PathLike]=None, +def get_sess(model_path: Optional[os.PathLike], device: str='cpu', cpu_threads: int=1, use_trt: bool=False): - - model_dir = str(Path(model_dir) / model_file) sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - - if device == "gpu": + if 'gpu' in device.lower(): + device_id = int(device.split(':')[1]) if len( + device.split(':')) == 2 else 0 # fastspeech2/mb_melgan can't use trt now! 
if use_trt: - providers = ['TensorrtExecutionProvider'] + provider_name = 'TensorrtExecutionProvider' else: - providers = ['CUDAExecutionProvider'] - elif device == "cpu": + provider_name = 'CUDAExecutionProvider' + providers = [(provider_name, {'device_id': device_id})] + elif device.lower() == 'cpu': providers = ['CPUExecutionProvider'] sess_options.intra_op_num_threads = cpu_threads sess = ort.InferenceSession( - model_dir, providers=providers, sess_options=sess_options) + model_path, providers=providers, sess_options=sess_options) return sess diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index ef9543296..5988bb30e 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -25,6 +25,7 @@ from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_voc_inference +from paddlespeech.t2s.exps.syn_utils import run_frontend from paddlespeech.t2s.exps.syn_utils import voc_to_static @@ -49,6 +50,7 @@ def evaluate(args): lang=args.lang, phones_dict=args.phones_dict, tones_dict=args.tones_dict) + print("frontend done!") # acoustic model am_name = args.am[:args.am.rindex('_')] @@ -62,13 +64,14 @@ def evaluate(args): phones_dict=args.phones_dict, tones_dict=args.tones_dict, speaker_dict=args.speaker_dict) - + print("acoustic model done!") # vocoder voc_inference = get_voc_inference( voc=args.voc, voc_config=voc_config, voc_ckpt=args.voc_ckpt, voc_stat=args.voc_stat) + print("voc done!") # whether dygraph to static if args.inference_dir: @@ -78,7 +81,6 @@ def evaluate(args): am=args.am, inference_dir=args.inference_dir, speaker_dict=args.speaker_dict) - # vocoder voc_inference = voc_to_static( voc_inference=voc_inference, @@ -101,24 +103,13 @@ def evaluate(args): T = 0 for utt_id, sentence in sentences: with timer() as t: - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - sentence, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - if get_tone_ids: - tone_ids = input_ids["tone_ids"] - elif args.lang == 'en': - input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - elif args.lang == 'mix': - input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in {'zh', 'en', 'mix'}!") + frontend_dict = run_frontend( + frontend=frontend, + text=sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=args.lang) + phone_ids = frontend_dict['phone_ids'] with paddle.no_grad(): flags = 0 for i in range(len(phone_ids)): @@ -132,7 +123,7 @@ def evaluate(args): else: mel = am_inference(part_phone_ids) elif am_name == 'speedyspeech': - part_tone_ids = tone_ids[i] + part_tone_ids = frontend_dict['tone_ids'][i] if am_dataset in {"aishell3", "vctk"}: spk_id = paddle.to_tensor(args.spk_id) mel = am_inference(part_phone_ids, part_tone_ids, diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index d8b23f1ad..6f86cc2b2 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -30,6 +30,7 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from 
paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import model_alias +from paddlespeech.t2s.exps.syn_utils import run_frontend from paddlespeech.t2s.exps.syn_utils import voc_to_static from paddlespeech.t2s.utils import str2bool from paddlespeech.utils.dynamic_import import dynamic_import @@ -138,15 +139,13 @@ def evaluate(args): for utt_id, sentence in sentences: with timer() as t: - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - sentence, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) - - phone_ids = input_ids["phone_ids"] - else: - print("lang should be 'zh' here!") + frontend_dict = run_frontend( + frontend=frontend, + text=sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=args.lang) + phone_ids = frontend_dict['phone_ids'] # merge_sentences=True here, so we only use the first item of phone_ids phone_ids = phone_ids[0] with paddle.no_grad(): diff --git a/paddlespeech/t2s/frontend/mix_frontend.py b/paddlespeech/t2s/frontend/mix_frontend.py index 6386c871e..5f145098e 100644 --- a/paddlespeech/t2s/frontend/mix_frontend.py +++ b/paddlespeech/t2s/frontend/mix_frontend.py @@ -136,7 +136,8 @@ class MixFrontend(): sentence: str, merge_sentences: bool=True, get_tone_ids: bool=False, - add_sp: bool=True) -> Dict[str, List[paddle.Tensor]]: + add_sp: bool=True, + to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: sentences = self._split(sentence) phones_list = [] @@ -152,11 +153,12 @@ class MixFrontend(): input_ids = self.zh_frontend.get_input_ids( content, merge_sentences=True, - get_tone_ids=get_tone_ids) + get_tone_ids=get_tone_ids, + to_tensor=to_tensor) elif lang == "en": input_ids = self.en_frontend.get_input_ids( - content, merge_sentences=True) + content, merge_sentences=True, to_tensor=to_tensor) phones_seg.append(input_ids["phone_ids"][0]) if add_sp: diff --git a/paddlespeech/t2s/frontend/phonectic.py b/paddlespeech/t2s/frontend/phonectic.py index 8e9f11737..873aa359c 100644 --- a/paddlespeech/t2s/frontend/phonectic.py +++ b/paddlespeech/t2s/frontend/phonectic.py @@ -82,8 +82,10 @@ class English(Phonetics): phone_ids = [self.vocab_phones[item] for item in phonemes] return np.array(phone_ids, np.int64) - def get_input_ids(self, sentence: str, - merge_sentences: bool=False) -> paddle.Tensor: + def get_input_ids(self, + sentence: str, + merge_sentences: bool=False, + to_tensor: bool=True) -> paddle.Tensor: result = {} sentences = self.text_normalizer._split(sentence, lang="en") phones_list = [] @@ -112,7 +114,8 @@ class English(Phonetics): for part_phones_list in phones_list: phone_ids = self._p2id(part_phones_list) - phone_ids = paddle.to_tensor(phone_ids) + if to_tensor: + phone_ids = paddle.to_tensor(phone_ids) temp_phone_ids.append(phone_ids) result["phone_ids"] = temp_phone_ids return result diff --git a/paddlespeech/t2s/frontend/zh_frontend.py b/paddlespeech/t2s/frontend/zh_frontend.py index 143ccbc15..ef8963c08 100644 --- a/paddlespeech/t2s/frontend/zh_frontend.py +++ b/paddlespeech/t2s/frontend/zh_frontend.py @@ -303,15 +303,15 @@ class Frontend(): print("----------------------------") return phonemes - def get_input_ids( - self, - sentence: str, - merge_sentences: bool=True, - get_tone_ids: bool=False, - robot: bool=False, - print_info: bool=False, - add_blank: bool=False, - blank_token: str="") -> Dict[str, List[paddle.Tensor]]: + def get_input_ids(self, + sentence: str, + merge_sentences: bool=True, + get_tone_ids: bool=False, + robot: bool=False, + print_info: 
bool=False, + add_blank: bool=False, + blank_token: str="", + to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: phonemes = self.get_phonemes( sentence, merge_sentences=merge_sentences, @@ -322,20 +322,22 @@ class Frontend(): tones = [] temp_phone_ids = [] temp_tone_ids = [] + for part_phonemes in phonemes: phones, tones = self._get_phone_tone( part_phonemes, get_tone_ids=get_tone_ids) - if add_blank: phones = insert_after_character(phones, blank_token) - if tones: tone_ids = self._t2id(tones) - tone_ids = paddle.to_tensor(tone_ids) + if to_tensor: + tone_ids = paddle.to_tensor(tone_ids) temp_tone_ids.append(tone_ids) if phones: phone_ids = self._p2id(phones) - phone_ids = paddle.to_tensor(phone_ids) + # if use paddle.to_tensor() in onnxruntime, the first time will be too low + if to_tensor: + phone_ids = paddle.to_tensor(phone_ids) temp_phone_ids.append(phone_ids) if temp_tone_ids: result["tone_ids"] = temp_tone_ids From cd662a08e06477f89e07bb0a518c7a27ab94b20c Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 4 Aug 2022 03:18:16 +0000 Subject: [PATCH 07/11] fix for load specified model files --- paddlespeech/cli/tts/infer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 4d5ddb754..c8d5447ec 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -364,7 +364,7 @@ class TTSExecutor(BaseExecutor): else: self.am_ckpt = os.path.abspath(am_ckpt[0]) self.phones_dict = os.path.abspath(phones_dict) - self.am_res_path = os.path.dirname(os.path.abspath(am_ckpt)) + self.am_res_path = os.path.dirname(os.path.abspath(self.am_ckpt)) self.am_fs = fs # for speedyspeech @@ -404,7 +404,6 @@ class TTSExecutor(BaseExecutor): # frontend self.frontend = get_frontend( lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict) - self.am_sess = get_sess( model_path=self.am_ckpt, device=device, cpu_threads=cpu_threads) From c6b25c05f484bbe380f560a12dd42b5256f09b9b Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 4 Aug 2022 06:07:56 +0000 Subject: [PATCH 08/11] change logger.debug to logger.info for streaming asr --- paddlespeech/server/utils/audio_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddlespeech/server/utils/audio_handler.py b/paddlespeech/server/utils/audio_handler.py index 43b73d6eb..add04156d 100644 --- a/paddlespeech/server/utils/audio_handler.py +++ b/paddlespeech/server/utils/audio_handler.py @@ -160,7 +160,7 @@ class ASRWsAudioHandler: separators=(',', ': ')) await ws.send(audio_info) msg = await ws.recv() - logger.debug("client receive msg={}".format(msg)) + logger.info("client receive msg={}".format(msg)) # 3. send chunk audio data to engine for chunk_data in self.read_wave(wavfile_path): @@ -170,7 +170,7 @@ class ASRWsAudioHandler: if self.punc_server and len(msg["result"]) > 0: msg["result"] = self.punc_server.run(msg["result"]) - logger.debug("client receive msg={}".format(msg)) + logger.info("client receive msg={}".format(msg)) # 4. 
we must send finished signal to the server audio_info = json.dumps( @@ -317,7 +317,7 @@ class TTSWsHandler: start_request = json.dumps({"task": "tts", "signal": "start"}) await ws.send(start_request) msg = await ws.recv() - logger.debug(f"client receive msg={msg}") + logger.info(f"client receive msg={msg}") msg = json.loads(msg) session = msg["session"] From 788a3062d0b91094d99922c565b6f2dcf215311b Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 4 Aug 2022 07:05:53 +0000 Subject: [PATCH 09/11] fix onnx am_ckpt from list to item in prtrained_mdoels.py --- paddlespeech/cli/tts/infer.py | 8 ++++---- paddlespeech/resource/pretrained_models.py | 15 ++++++++++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index c8d5447ec..86ff4fe07 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -226,7 +226,7 @@ class TTSExecutor(BaseExecutor): self.am_ckpt = os.path.abspath(am_ckpt) self.am_stat = os.path.abspath(am_stat) self.phones_dict = os.path.abspath(phones_dict) - self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) + self.am_res_path = os.path.dirname(self.am_config) # for speedyspeech self.tones_dict = None @@ -354,7 +354,7 @@ class TTSExecutor(BaseExecutor): if use_pretrained_am: self.am_res_path = self.task_resource.res_dir self.am_ckpt = os.path.join(self.am_res_path, - self.task_resource.res_dict['ckpt'][0]) + self.task_resource.res_dict['ckpt']) # must have phones_dict in acoustic self.phones_dict = os.path.join( self.am_res_path, self.task_resource.res_dict['phones_dict']) @@ -362,9 +362,9 @@ class TTSExecutor(BaseExecutor): logger.debug(self.am_res_path) logger.debug(self.am_ckpt) else: - self.am_ckpt = os.path.abspath(am_ckpt[0]) + self.am_ckpt = os.path.abspath(am_ckpt) self.phones_dict = os.path.abspath(phones_dict) - self.am_res_path = os.path.dirname(os.path.abspath(self.am_ckpt)) + self.am_res_path = os.path.dirname(self.am_ckpt) self.am_fs = fs # for speedyspeech diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 43d63925b..bfe2bc7ec 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -1095,7 +1095,8 @@ tts_onnx_pretrained_models = { 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip', 'md5': '3e9c45af9ef70675fc1968ed5074fc88', - 'ckpt': ['speedyspeech_csmsc.onnx'], + 'ckpt': + 'speedyspeech_csmsc.onnx', 'phones_dict': 'phone_id_map.txt', 'tones_dict': @@ -1111,7 +1112,8 @@ tts_onnx_pretrained_models = { 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip', 'md5': 'fd3ad38d83273ad51f0ea4f4abf3ab4e', - 'ckpt': ['fastspeech2_csmsc.onnx'], + 'ckpt': + 'fastspeech2_csmsc.onnx', 'phones_dict': 'phone_id_map.txt', 'sample_rate': @@ -1124,7 +1126,8 @@ tts_onnx_pretrained_models = { 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_onnx_1.1.0.zip', 'md5': '00754307636a48c972a5f3e65cda3d18', - 'ckpt': ['fastspeech2_ljspeech.onnx'], + 'ckpt': + 'fastspeech2_ljspeech.onnx', 'phones_dict': 'phone_id_map.txt', 'sample_rate': @@ -1137,7 +1140,8 @@ tts_onnx_pretrained_models = { 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_onnx_1.1.0.zip', 'md5': 'a1d6ee21de897ce394f5469e2bb4df0d', - 'ckpt': ['fastspeech2_aishell3.onnx'], + 'ckpt': + 
'fastspeech2_aishell3.onnx', 'phones_dict': 'phone_id_map.txt', 'speaker_dict': @@ -1152,7 +1156,8 @@ tts_onnx_pretrained_models = { 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip', 'md5': 'd9c3a9b02204a2070504dd99f5f959bf', - 'ckpt': ['fastspeech2_vctk.onnx'], + 'ckpt': + 'fastspeech2_vctk.onnx', 'phones_dict': 'phone_id_map.txt', 'speaker_dict': From 8da993bbf8165669c40ec4b1016f8a2b4db6ed95 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 4 Aug 2022 07:40:10 +0000 Subject: [PATCH 10/11] fix fs bug --- paddlespeech/cli/tts/infer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 86ff4fe07..1b02192e1 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -172,6 +172,11 @@ class TTSExecutor(BaseExecutor): type=str2bool, default=False, help="whether to usen onnxruntime inference.") + self.parser.add_argument( + '--fs', + type=int, + default=24000, + help='sample rate for onnx models when use specified model files.') def _init_from_path( self, @@ -581,6 +586,7 @@ class TTSExecutor(BaseExecutor): spk_id = args.spk_id use_onnx = args.use_onnx cpu_threads = args.cpu_threads + fs = args.fs if not args.verbose: self.disable_task_loggers() @@ -619,7 +625,8 @@ class TTSExecutor(BaseExecutor): device=device, output=output, use_onnx=use_onnx, - cpu_threads=cpu_threads) + cpu_threads=cpu_threads, + fs=fs) task_results[id_] = res except Exception as e: has_exceptions = True @@ -653,7 +660,8 @@ class TTSExecutor(BaseExecutor): device: str=paddle.get_device(), output: str='output.wav', use_onnx: bool=False, - cpu_threads: int=2): + cpu_threads: int=2, + fs: int=24000): """ Python API to call an executor. """ @@ -695,7 +703,8 @@ class TTSExecutor(BaseExecutor): voc_ckpt=voc_ckpt, lang=lang, device=device, - cpu_threads=cpu_threads) + cpu_threads=cpu_threads, + fs=fs) self.infer_onnx(text=text, lang=lang, am=am, spk_id=spk_id) res = self.postprocess_onnx(output=output) return res From c3d47441cf1466b545a8eba668ff837647f13126 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 4 Aug 2022 07:52:18 +0000 Subject: [PATCH 11/11] fix fs bug in inference.py (change fixed 24000 to variable for ljspeech) --- paddlespeech/t2s/exps/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3732e0f40..7efbd8aa5 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -182,7 +182,7 @@ def main(): speed = wav.size / t.elapse rtf = fs / speed - sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000) + sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs) print( f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." )
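
A brief usage note on the onnxruntime path added in patches 06-11: the TTS Python API now accepts use_onnx / cpu_threads / fs, matching the new --use_onnx, --cpu_threads and --fs CLI flags. Below is a minimal sketch against the __call__ signature in paddlespeech/cli/tts/infer.py; the input text and output path are only examples, and it assumes the pretrained fastspeech2_csmsc / hifigan_csmsc ONNX models can be fetched through the task resource.

    from paddlespeech.cli.tts.infer import TTSExecutor

    tts = TTSExecutor()
    wav_file = tts(
        text="今天天气十分不错。",       # example input text (zh)
        am="fastspeech2_csmsc",          # acoustic model, must be in ONNX_SUPPORT_SET
        voc="hifigan_csmsc",             # vocoder, must be in ONNX_SUPPORT_SET
        lang="zh",
        use_onnx=True,                   # run am/voc with onnxruntime sessions
        cpu_threads=2,                   # onnxruntime intra-op thread count
        output="output.wav")             # hypothetical output path
    print(wav_file)

With use_onnx=True the executor goes through _init_from_path_onnx / infer_onnx / postprocess_onnx instead of the paddle dygraph path; the same building blocks (get_sess, and run_frontend with to_tensor=False) are exposed in paddlespeech/t2s/exps/syn_utils.py for lower-level use.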