code format, test=doc

pull/1713/head
lym0302 3 years ago
parent 00a6236fe2
commit 40dde22fc4

@ -127,33 +127,40 @@ class TTSServerExecutor(TTSExecutor):
self.voc_block = voc_block
self.voc_pad = voc_pad
def get_model_info(self, step, model_name, ckpt, stat):
def get_model_info(self,
field: str,
model_name: str,
ckpt: Optional[os.PathLike],
stat: Optional[os.PathLike]):
"""get model information
Args:
step (string): am or voc
model_name (string): model type, support fastspeech2, higigan, mb_melgan
ckpt (string): ckpt file
stat (string): stat file, including mean and standard deviation
field (str): am or voc
model_name (str): model type, support fastspeech2, higigan, mb_melgan
ckpt (Optional[os.PathLike]): ckpt file
stat (Optional[os.PathLike]): stat file, including mean and standard deviation
Returns:
model, model_mu, model_std
[module]: model module
[Tensor]: mean
[Tensor]: standard deviation
"""
model_class = dynamic_import(model_name, model_alias)
if step == "am":
if field == "am":
odim = self.am_config.n_mels
model = model_class(
idim=self.vocab_size, odim=odim, **self.am_config["model"])
model.set_state_dict(paddle.load(ckpt)["main_params"])
elif step == "voc":
elif field == "voc":
model = model_class(**self.voc_config["generator_params"])
model.set_state_dict(paddle.load(ckpt)["generator_params"])
model.remove_weight_norm()
else:
logger.error("Please set correct step, am or voc")
logger.error("Please set correct field, am or voc")
model.eval()
model_mu, model_std = np.load(stat)
@ -346,7 +353,8 @@ class TTSServerExecutor(TTSExecutor):
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_config.n_shift
flag = 1
# first_flag 用于标记首包
first_flag = 1
get_tone_ids = False
merge_sentences = False
@ -376,7 +384,7 @@ class TTSServerExecutor(TTSExecutor):
if am == "fastspeech2_csmsc":
# am
mel = self.am_inference(part_phone_ids)
if flag == 1:
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
@ -388,11 +396,11 @@ class TTSServerExecutor(TTSExecutor):
sub_wav = self.voc_inference(mel_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
if flag == 1:
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
flag = 0
first_flag = 0
yield sub_wav
@ -427,9 +435,10 @@ class TTSServerExecutor(TTSExecutor):
(mel_streaming, sub_mel), axis=0)
# streaming voc
# 当流式AM推理的mel帧数大于流式voc推理的chunk size开始进行流式voc 推理
while (mel_streaming.shape[0] >= end and
voc_chunk_id < voc_chunk_num):
if flag == 1:
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
@ -439,11 +448,11 @@ class TTSServerExecutor(TTSExecutor):
sub_wav = self.depadding(sub_wav, voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
if flag == 1:
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
flag = 0
first_flag = 0
yield sub_wav
@ -470,7 +479,8 @@ class TTSEngine(BaseEngine):
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
#super(TTSEngine, self).__init__()
super().__init__()
def init(self, config: dict) -> bool:
self.config = config

Loading…
Cancel
Save