|
|
@ -127,33 +127,40 @@ class TTSServerExecutor(TTSExecutor):
|
|
|
|
self.voc_block = voc_block
|
|
|
|
self.voc_block = voc_block
|
|
|
|
self.voc_pad = voc_pad
|
|
|
|
self.voc_pad = voc_pad
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_info(self, step, model_name, ckpt, stat):
|
|
|
|
def get_model_info(self,
|
|
|
|
|
|
|
|
field: str,
|
|
|
|
|
|
|
|
model_name: str,
|
|
|
|
|
|
|
|
ckpt: Optional[os.PathLike],
|
|
|
|
|
|
|
|
stat: Optional[os.PathLike]):
|
|
|
|
"""get model information
|
|
|
|
"""get model information
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
step (string): am or voc
|
|
|
|
field (str): am or voc
|
|
|
|
model_name (string): model type, support fastspeech2, higigan, mb_melgan
|
|
|
|
model_name (str): model type, support fastspeech2, higigan, mb_melgan
|
|
|
|
ckpt (string): ckpt file
|
|
|
|
ckpt (Optional[os.PathLike]): ckpt file
|
|
|
|
stat (string): stat file, including mean and standard deviation
|
|
|
|
stat (Optional[os.PathLike]): stat file, including mean and standard deviation
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
model, model_mu, model_std
|
|
|
|
[module]: model module
|
|
|
|
|
|
|
|
[Tensor]: mean
|
|
|
|
|
|
|
|
[Tensor]: standard deviation
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
model_class = dynamic_import(model_name, model_alias)
|
|
|
|
model_class = dynamic_import(model_name, model_alias)
|
|
|
|
|
|
|
|
|
|
|
|
if step == "am":
|
|
|
|
if field == "am":
|
|
|
|
odim = self.am_config.n_mels
|
|
|
|
odim = self.am_config.n_mels
|
|
|
|
model = model_class(
|
|
|
|
model = model_class(
|
|
|
|
idim=self.vocab_size, odim=odim, **self.am_config["model"])
|
|
|
|
idim=self.vocab_size, odim=odim, **self.am_config["model"])
|
|
|
|
model.set_state_dict(paddle.load(ckpt)["main_params"])
|
|
|
|
model.set_state_dict(paddle.load(ckpt)["main_params"])
|
|
|
|
|
|
|
|
|
|
|
|
elif step == "voc":
|
|
|
|
elif field == "voc":
|
|
|
|
model = model_class(**self.voc_config["generator_params"])
|
|
|
|
model = model_class(**self.voc_config["generator_params"])
|
|
|
|
model.set_state_dict(paddle.load(ckpt)["generator_params"])
|
|
|
|
model.set_state_dict(paddle.load(ckpt)["generator_params"])
|
|
|
|
model.remove_weight_norm()
|
|
|
|
model.remove_weight_norm()
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
logger.error("Please set correct step, am or voc")
|
|
|
|
logger.error("Please set correct field, am or voc")
|
|
|
|
|
|
|
|
|
|
|
|
model.eval()
|
|
|
|
model.eval()
|
|
|
|
model_mu, model_std = np.load(stat)
|
|
|
|
model_mu, model_std = np.load(stat)
|
|
|
@ -346,7 +353,8 @@ class TTSServerExecutor(TTSExecutor):
|
|
|
|
voc_block = self.voc_block
|
|
|
|
voc_block = self.voc_block
|
|
|
|
voc_pad = self.voc_pad
|
|
|
|
voc_pad = self.voc_pad
|
|
|
|
voc_upsample = self.voc_config.n_shift
|
|
|
|
voc_upsample = self.voc_config.n_shift
|
|
|
|
flag = 1
|
|
|
|
# first_flag 用于标记首包
|
|
|
|
|
|
|
|
first_flag = 1
|
|
|
|
|
|
|
|
|
|
|
|
get_tone_ids = False
|
|
|
|
get_tone_ids = False
|
|
|
|
merge_sentences = False
|
|
|
|
merge_sentences = False
|
|
|
@ -376,7 +384,7 @@ class TTSServerExecutor(TTSExecutor):
|
|
|
|
if am == "fastspeech2_csmsc":
|
|
|
|
if am == "fastspeech2_csmsc":
|
|
|
|
# am
|
|
|
|
# am
|
|
|
|
mel = self.am_inference(part_phone_ids)
|
|
|
|
mel = self.am_inference(part_phone_ids)
|
|
|
|
if flag == 1:
|
|
|
|
if first_flag == 1:
|
|
|
|
first_am_et = time.time()
|
|
|
|
first_am_et = time.time()
|
|
|
|
self.first_am_infer = first_am_et - frontend_et
|
|
|
|
self.first_am_infer = first_am_et - frontend_et
|
|
|
|
|
|
|
|
|
|
|
@ -388,11 +396,11 @@ class TTSServerExecutor(TTSExecutor):
|
|
|
|
sub_wav = self.voc_inference(mel_chunk)
|
|
|
|
sub_wav = self.voc_inference(mel_chunk)
|
|
|
|
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
|
|
|
|
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
|
|
|
|
voc_block, voc_pad, voc_upsample)
|
|
|
|
voc_block, voc_pad, voc_upsample)
|
|
|
|
if flag == 1:
|
|
|
|
if first_flag == 1:
|
|
|
|
first_voc_et = time.time()
|
|
|
|
first_voc_et = time.time()
|
|
|
|
self.first_voc_infer = first_voc_et - first_am_et
|
|
|
|
self.first_voc_infer = first_voc_et - first_am_et
|
|
|
|
self.first_response_time = first_voc_et - frontend_st
|
|
|
|
self.first_response_time = first_voc_et - frontend_st
|
|
|
|
flag = 0
|
|
|
|
first_flag = 0
|
|
|
|
|
|
|
|
|
|
|
|
yield sub_wav
|
|
|
|
yield sub_wav
|
|
|
|
|
|
|
|
|
|
|
@ -427,9 +435,10 @@ class TTSServerExecutor(TTSExecutor):
|
|
|
|
(mel_streaming, sub_mel), axis=0)
|
|
|
|
(mel_streaming, sub_mel), axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
# streaming voc
|
|
|
|
# streaming voc
|
|
|
|
|
|
|
|
# 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理
|
|
|
|
while (mel_streaming.shape[0] >= end and
|
|
|
|
while (mel_streaming.shape[0] >= end and
|
|
|
|
voc_chunk_id < voc_chunk_num):
|
|
|
|
voc_chunk_id < voc_chunk_num):
|
|
|
|
if flag == 1:
|
|
|
|
if first_flag == 1:
|
|
|
|
first_am_et = time.time()
|
|
|
|
first_am_et = time.time()
|
|
|
|
self.first_am_infer = first_am_et - frontend_et
|
|
|
|
self.first_am_infer = first_am_et - frontend_et
|
|
|
|
voc_chunk = mel_streaming[start:end, :]
|
|
|
|
voc_chunk = mel_streaming[start:end, :]
|
|
|
@ -439,11 +448,11 @@ class TTSServerExecutor(TTSExecutor):
|
|
|
|
sub_wav = self.depadding(sub_wav, voc_chunk_num,
|
|
|
|
sub_wav = self.depadding(sub_wav, voc_chunk_num,
|
|
|
|
voc_chunk_id, voc_block,
|
|
|
|
voc_chunk_id, voc_block,
|
|
|
|
voc_pad, voc_upsample)
|
|
|
|
voc_pad, voc_upsample)
|
|
|
|
if flag == 1:
|
|
|
|
if first_flag == 1:
|
|
|
|
first_voc_et = time.time()
|
|
|
|
first_voc_et = time.time()
|
|
|
|
self.first_voc_infer = first_voc_et - first_am_et
|
|
|
|
self.first_voc_infer = first_voc_et - first_am_et
|
|
|
|
self.first_response_time = first_voc_et - frontend_st
|
|
|
|
self.first_response_time = first_voc_et - frontend_st
|
|
|
|
flag = 0
|
|
|
|
first_flag = 0
|
|
|
|
|
|
|
|
|
|
|
|
yield sub_wav
|
|
|
|
yield sub_wav
|
|
|
|
|
|
|
|
|
|
|
@ -470,7 +479,8 @@ class TTSEngine(BaseEngine):
|
|
|
|
def __init__(self, name=None):
|
|
|
|
def __init__(self, name=None):
|
|
|
|
"""Initialize TTS server engine
|
|
|
|
"""Initialize TTS server engine
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
super(TTSEngine, self).__init__()
|
|
|
|
#super(TTSEngine, self).__init__()
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
def init(self, config: dict) -> bool:
|
|
|
|
def init(self, config: dict) -> bool:
|
|
|
|
self.config = config
|
|
|
|
self.config = config
|
|
|
|