From 40dde22fc48f41cffdace68847ccbeb00cc1cef4 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Tue, 19 Apr 2022 12:59:48 +0800 Subject: [PATCH] code format, test=doc --- .../server/engine/tts/online/tts_engine.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py index 8e76225dc..a84644e70 100644 --- a/paddlespeech/server/engine/tts/online/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -127,33 +127,40 @@ class TTSServerExecutor(TTSExecutor): self.voc_block = voc_block self.voc_pad = voc_pad - def get_model_info(self, step, model_name, ckpt, stat): + def get_model_info(self, + field: str, + model_name: str, + ckpt: Optional[os.PathLike], + stat: Optional[os.PathLike]): """get model information Args: - step (string): am or voc - model_name (string): model type, support fastspeech2, higigan, mb_melgan - ckpt (string): ckpt file - stat (string): stat file, including mean and standard deviation + field (str): am or voc + model_name (str): model type, support fastspeech2, higigan, mb_melgan + ckpt (Optional[os.PathLike]): ckpt file + stat (Optional[os.PathLike]): stat file, including mean and standard deviation Returns: - model, model_mu, model_std + [module]: model module + [Tensor]: mean + [Tensor]: standard deviation """ + model_class = dynamic_import(model_name, model_alias) - if step == "am": + if field == "am": odim = self.am_config.n_mels model = model_class( idim=self.vocab_size, odim=odim, **self.am_config["model"]) model.set_state_dict(paddle.load(ckpt)["main_params"]) - elif step == "voc": + elif field == "voc": model = model_class(**self.voc_config["generator_params"]) model.set_state_dict(paddle.load(ckpt)["generator_params"]) model.remove_weight_norm() else: - logger.error("Please set correct step, am or voc") + logger.error("Please set correct field, am or voc") model.eval() model_mu, model_std = np.load(stat) @@ -346,7 +353,8 @@ class TTSServerExecutor(TTSExecutor): voc_block = self.voc_block voc_pad = self.voc_pad voc_upsample = self.voc_config.n_shift - flag = 1 + # first_flag 用于标记首包 + first_flag = 1 get_tone_ids = False merge_sentences = False @@ -376,7 +384,7 @@ class TTSServerExecutor(TTSExecutor): if am == "fastspeech2_csmsc": # am mel = self.am_inference(part_phone_ids) - if flag == 1: + if first_flag == 1: first_am_et = time.time() self.first_am_infer = first_am_et - frontend_et @@ -388,11 +396,11 @@ class TTSServerExecutor(TTSExecutor): sub_wav = self.voc_inference(mel_chunk) sub_wav = self.depadding(sub_wav, voc_chunk_num, i, voc_block, voc_pad, voc_upsample) - if flag == 1: + if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et self.first_response_time = first_voc_et - frontend_st - flag = 0 + first_flag = 0 yield sub_wav @@ -427,9 +435,10 @@ class TTSServerExecutor(TTSExecutor): (mel_streaming, sub_mel), axis=0) # streaming voc + # 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理 while (mel_streaming.shape[0] >= end and voc_chunk_id < voc_chunk_num): - if flag == 1: + if first_flag == 1: first_am_et = time.time() self.first_am_infer = first_am_et - frontend_et voc_chunk = mel_streaming[start:end, :] @@ -439,11 +448,11 @@ class TTSServerExecutor(TTSExecutor): sub_wav = self.depadding(sub_wav, voc_chunk_num, voc_chunk_id, voc_block, voc_pad, voc_upsample) - if flag == 1: + if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et self.first_response_time = first_voc_et - frontend_st - flag = 0 + first_flag = 0 yield sub_wav @@ -470,7 +479,8 @@ class TTSEngine(BaseEngine): def __init__(self, name=None): """Initialize TTS server engine """ - super(TTSEngine, self).__init__() + #super(TTSEngine, self).__init__() + super().__init__() def init(self, config: dict) -> bool: self.config = config