code format, test=doc

pull/1713/head
lym0302 3 years ago
parent 00a6236fe2
commit 40dde22fc4

@@ -127,33 +127,40 @@ class TTSServerExecutor(TTSExecutor):
         self.voc_block = voc_block
         self.voc_pad = voc_pad
 
-    def get_model_info(self, step, model_name, ckpt, stat):
+    def get_model_info(self,
+                       field: str,
+                       model_name: str,
+                       ckpt: Optional[os.PathLike],
+                       stat: Optional[os.PathLike]):
         """get model information
         Args:
-            step (string): am or voc
-            model_name (string): model type, support fastspeech2, hifigan, mb_melgan
-            ckpt (string): ckpt file
-            stat (string): stat file, including mean and standard deviation
+            field (str): am or voc
+            model_name (str): model type, support fastspeech2, hifigan, mb_melgan
+            ckpt (Optional[os.PathLike]): ckpt file
+            stat (Optional[os.PathLike]): stat file, including mean and standard deviation
         Returns:
-            model, model_mu, model_std
+            [module]: model module
+            [Tensor]: mean
+            [Tensor]: standard deviation
         """
         model_class = dynamic_import(model_name, model_alias)
-        if step == "am":
+        if field == "am":
             odim = self.am_config.n_mels
             model = model_class(
                 idim=self.vocab_size, odim=odim, **self.am_config["model"])
             model.set_state_dict(paddle.load(ckpt)["main_params"])
-        elif step == "voc":
+        elif field == "voc":
             model = model_class(**self.voc_config["generator_params"])
             model.set_state_dict(paddle.load(ckpt)["generator_params"])
             model.remove_weight_norm()
         else:
-            logger.error("Please set correct step, am or voc")
+            logger.error("Please set correct field, am or voc")
         model.eval()
         model_mu, model_std = np.load(stat)
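The new Returns section documents three values: the model module plus the mean and standard deviation unpacked from the `stat` file. A minimal, self-contained sketch of that stat-file layout (the file name and the 80-dim mel size below are illustrative assumptions, not from this commit):

import numpy as np

# fabricate a stat file the way a training pipeline would: mean and std
# stacked on axis 0 in a single .npy file (80 dims is a typical n_mels)
mu = np.zeros(80, dtype="float32")
std = np.ones(80, dtype="float32")
np.save("fastspeech2_stats.npy", np.stack([mu, std]))

# unpack it exactly as get_model_info() does with `model_mu, model_std = np.load(stat)`
model_mu, model_std = np.load("fastspeech2_stats.npy")
print(model_mu.shape, model_std.shape)  # (80,) (80,)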
@@ -346,7 +353,8 @@ class TTSServerExecutor(TTSExecutor):
         voc_block = self.voc_block
         voc_pad = self.voc_pad
         voc_upsample = self.voc_config.n_shift
-        flag = 1
+        # first_flag is used to mark the first packet
+        first_flag = 1
         get_tone_ids = False
         merge_sentences = False
@@ -376,7 +384,7 @@ class TTSServerExecutor(TTSExecutor):
             if am == "fastspeech2_csmsc":
                 # am
                 mel = self.am_inference(part_phone_ids)
-                if flag == 1:
+                if first_flag == 1:
                     first_am_et = time.time()
                     self.first_am_infer = first_am_et - frontend_et
@@ -388,11 +396,11 @@ class TTSServerExecutor(TTSExecutor):
                     sub_wav = self.voc_inference(mel_chunk)
                     sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
                                              voc_block, voc_pad, voc_upsample)
-                    if flag == 1:
+                    if first_flag == 1:
                         first_voc_et = time.time()
                         self.first_voc_infer = first_voc_et - first_am_et
                         self.first_response_time = first_voc_et - frontend_st
-                        flag = 0
+                        first_flag = 0
                     yield sub_wav
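The flag renamed in these hunks gates the first-packet timing inside the streaming generator: it stays 1 only until the first waveform chunk is yielded, so the surrounding timers record first-response latency. A self-contained sketch of the pattern (the toy `stream_chunks` generator and its fake workload are assumptions for illustration):

import time


def stream_chunks(chunks):
    # first_flag is used to mark the first packet, as in the commit
    first_flag = 1
    frontend_st = time.time()
    for chunk in chunks:
        sub_wav = [x * 2 for x in chunk]  # stand-in for am + voc inference
        if first_flag == 1:
            # first-response latency: time from request start to first packet
            first_response_time = time.time() - frontend_st
            print(f"first packet after {first_response_time:.6f} s")
            first_flag = 0
        yield sub_wav


for wav in stream_chunks([[1, 2], [3, 4]]):
    pass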
@@ -427,9 +435,10 @@ class TTSServerExecutor(TTSExecutor):
                             (mel_streaming, sub_mel), axis=0)
                         # streaming voc
+                        # start streaming voc inference once the number of mel frames from streaming am inference exceeds the chunk size of streaming voc inference
                         while (mel_streaming.shape[0] >= end and
                                voc_chunk_id < voc_chunk_num):
-                            if flag == 1:
+                            if first_flag == 1:
                                 first_am_et = time.time()
                                 self.first_am_infer = first_am_et - frontend_et
                             voc_chunk = mel_streaming[start:end, :]
@@ -439,11 +448,11 @@ class TTSServerExecutor(TTSExecutor):
                             sub_wav = self.depadding(sub_wav, voc_chunk_num,
                                                      voc_chunk_id, voc_block,
                                                      voc_pad, voc_upsample)
-                            if flag == 1:
+                            if first_flag == 1:
                                 first_voc_et = time.time()
                                 self.first_voc_infer = first_voc_et - first_am_et
                                 self.first_response_time = first_voc_et - frontend_st
-                                flag = 0
+                                first_flag = 0
                             yield sub_wav
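The two hunks above sit inside the streaming-vocoder loop: mel frames produced by the streaming acoustic model accumulate in `mel_streaming`, and a vocoder chunk of `voc_block + voc_pad` frames is consumed whenever enough frames are available. A self-contained sketch of that producer/consumer flow (the chunk indexing and the random mel stand-in are simplifying assumptions, not the project's exact logic):

import numpy as np

voc_block, voc_pad, n_mels = 36, 14, 80
mel_streaming = np.zeros((0, n_mels), dtype="float32")
voc_chunk_id = 0
start, end = 0, voc_block + voc_pad

for _ in range(5):  # pretend the streaming am yields 20 mel frames per step
    sub_mel = np.random.rand(20, n_mels).astype("float32")
    mel_streaming = np.concatenate((mel_streaming, sub_mel), axis=0)
    # streaming voc: run as soon as enough frames for one padded chunk exist
    while mel_streaming.shape[0] >= end:
        voc_chunk = mel_streaming[start:end, :]
        # sub_wav = voc_inference(voc_chunk); depadding(...) would follow here
        print("vocoder chunk", voc_chunk_id, voc_chunk.shape)
        voc_chunk_id += 1
        start = max(0, voc_chunk_id * voc_block - voc_pad)  # assumed indexing
        end = (voc_chunk_id + 1) * voc_block + voc_pad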
@@ -470,7 +479,8 @@ class TTSEngine(BaseEngine):
     def __init__(self, name=None):
         """Initialize TTS server engine
         """
-        super(TTSEngine, self).__init__()
+        #super(TTSEngine, self).__init__()
+        super().__init__()
 
     def init(self, config: dict) -> bool:
         self.config = config
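The last hunk swaps the explicit `super(TTSEngine, self).__init__()` for the zero-argument form. In Python 3 both resolve to the same base-class initializer, so the change is purely cosmetic; a self-contained sketch with toy classes (not the real engine):

class BaseEngine:
    def __init__(self):
        print("BaseEngine.__init__")


class TTSEngine(BaseEngine):
    def __init__(self, name=None):
        # old style: super(TTSEngine, self).__init__()
        super().__init__()  # Python 3 zero-argument form, same behaviour


TTSEngine()  # prints "BaseEngine.__init__"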
