|
|
|
@ -7,7 +7,8 @@
|
|
|
|
|
# 4. Streaming inference
|
|
|
|
|
|
|
|
|
|
import base64
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
import logging
|
|
|
|
|
import numpy as np
|
|
|
|
|
from paddlespeech.server.utils.onnx_infer import get_sess
|
|
|
|
|
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
|
|
|
@ -17,14 +18,14 @@ from paddlespeech.server.utils.config import get_config
|
|
|
|
|
|
|
|
|
|
from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TTS:
|
|
|
|
|
def __init__(self, config_path):
|
|
|
|
|
self.config = get_config(config_path)['tts_online-onnx']
|
|
|
|
|
self.config['voc_block'] = 36
|
|
|
|
|
self.engine = TTSEngine()
|
|
|
|
|
self.engine = TTSEngine()
|
|
|
|
|
self.engine.init(self.config)
|
|
|
|
|
self.engine.warm_up()
|
|
|
|
|
self.executor = self.engine.executor
|
|
|
|
|
#self.engine.warm_up()
|
|
|
|
|
|
|
|
|
|
# Frontend initialization
|
|
|
|
|
self.frontend = Frontend(
|
|
|
|
@ -81,8 +82,105 @@ class TTS:
|
|
|
|
|
return wavs
|
|
|
|
|
|
|
|
|
|
def streamTTS(self, text):
    """Stream-synthesize *text*, yielding audio chunk by chunk.

    The text is converted to phone ids by the frontend, then each sentence
    is synthesized with the streaming acoustic-model / vocoder ONNX
    sessions. Chunks are depadded and post-processed before being yielded.

    Args:
        text (str): Input sentence(s) to synthesize.

    Yields:
        Post-processed sub-waveform chunks (output of ``self.after_process``;
        presumably base64-encoded audio — TODO confirm against after_process).

    NOTE(review): the stale pre-refactor body that delegated to
    ``self.engine.run`` was removed — it duplicated synthesis work and the
    chunked implementation below is the intended one.
    """
    get_tone_ids = False
    # Process sentences one by one so audio can start streaming early.
    merge_sentences = False

    # front: text -> phone ids
    input_ids = self.frontend.get_input_ids(
        text,
        merge_sentences=merge_sentences,
        get_tone_ids=get_tone_ids)
    phone_ids = input_ids["phone_ids"]

    for sent_idx in range(len(phone_ids)):
        part_phone_ids = phone_ids[sent_idx].numpy()
        voc_chunk_id = 0

        # fastspeech2_csmsc: non-streaming AM, streaming vocoder only
        if self.config.am == "fastspeech2_csmsc_onnx":
            # am: full-sentence mel in one shot
            mel = self.executor.am_sess.run(
                output_names=None, input_feed={'text': part_phone_ids})
            mel = mel[0]

            # voc streaming: chunk the mel, run the vocoder per chunk
            mel_chunks = get_chunks(mel, self.config.voc_block,
                                    self.config.voc_pad, "voc")
            voc_chunk_num = len(mel_chunks)
            for chunk_idx, mel_chunk in enumerate(mel_chunks):
                sub_wav = self.executor.voc_sess.run(
                    output_names=None, input_feed={'logmel': mel_chunk})
                # Strip the overlap padding added by get_chunks.
                sub_wav = self.depadding(sub_wav[0], voc_chunk_num, chunk_idx,
                                         self.config.voc_block,
                                         self.config.voc_pad,
                                         self.config.voc_upsample)

                yield self.after_process(sub_wav)

        # fastspeech2_cnndecoder_csmsc: streaming AM and streaming vocoder
        elif self.config.am == "fastspeech2_cnndecoder_csmsc_onnx":
            # am encoder: hidden states for the whole sentence
            orig_hs = self.executor.am_encoder_infer_sess.run(
                None, input_feed={'text': part_phone_ids})
            orig_hs = orig_hs[0]

            # streaming voc chunk info
            mel_len = orig_hs.shape[1]
            voc_chunk_num = math.ceil(mel_len / self.config.voc_block)
            start = 0
            end = min(self.config.voc_block + self.config.voc_pad, mel_len)

            # streaming am: decode hidden-state chunks into mel
            hss = get_chunks(orig_hs, self.config.am_block,
                             self.config.am_pad, "am")
            am_chunk_num = len(hss)
            for am_idx, hs in enumerate(hss):
                am_decoder_output = self.executor.am_decoder_sess.run(
                    None, input_feed={'xs': hs})
                am_postnet_output = self.executor.am_postnet_sess.run(
                    None,
                    input_feed={
                        'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
                    })
                # Residual connection: decoder output + postnet refinement.
                am_output_data = am_decoder_output + np.transpose(
                    am_postnet_output[0], (0, 2, 1))
                normalized_mel = am_output_data[0][0]

                sub_mel = denorm(normalized_mel, self.executor.am_mu,
                                 self.executor.am_std)
                sub_mel = self.depadding(sub_mel, am_chunk_num, am_idx,
                                         self.config.am_block,
                                         self.config.am_pad, 1)

                # Accumulate the mel produced so far.
                if am_idx == 0:
                    mel_streaming = sub_mel
                else:
                    mel_streaming = np.concatenate(
                        (mel_streaming, sub_mel), axis=0)

                # streaming voc
                # Once the accumulated mel covers the next vocoder window,
                # run streaming vocoder inference on it.
                while (mel_streaming.shape[0] >= end and
                       voc_chunk_id < voc_chunk_num):
                    voc_chunk = mel_streaming[start:end, :]

                    sub_wav = self.executor.voc_sess.run(
                        output_names=None, input_feed={'logmel': voc_chunk})
                    sub_wav = self.depadding(
                        sub_wav[0], voc_chunk_num, voc_chunk_id,
                        self.config.voc_block, self.config.voc_pad,
                        self.config.voc_upsample)

                    yield self.after_process(sub_wav)

                    # Advance the sliding window (with pad overlap).
                    voc_chunk_id += 1
                    start = max(
                        0,
                        voc_chunk_id * self.config.voc_block -
                        self.config.voc_pad)
                    end = min(
                        (voc_chunk_id + 1) * self.config.voc_block +
                        self.config.voc_pad, mel_len)

        else:
            logging.error(
                "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def streamTTSBytes(self, text):
|
|
|
|
|
for wav in self.engine.executor.infer(
|
|
|
|
@ -106,16 +204,6 @@ class TTS:
|
|
|
|
|
# Optimize with TVM
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo driver: stream-synthesize a sentence, then do an offline pass.
    sentence = "啊哈哈哈哈哈哈啊哈哈哈哈哈哈啊哈哈哈哈哈哈啊哈哈哈哈哈哈啊哈哈哈哈哈哈"
    conf_file = "../../PaddleSpeech/demos/streaming_tts_server/conf/tts_online_application.yaml"
    synthesizer = TTS(conf_file)

    # Streaming path: report the size of each yielded chunk.
    for chunk in synthesizer.streamTTS(sentence):
        print("sub_wav_base64: ", len(chunk))

    # Offline path: synthesize the whole sentence in one call.
    full_wav = synthesizer.offlineTTS(sentence)
    print(full_wav)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|