You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/demos/speech_web/speech_server/src/SpeechBase/tts.py

209 lines
8.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# tts 推理引擎,支持流式与非流式
# 精简化使用
# 用 onnxruntime 进行推理
# 1. 下载对应的模型
# 2. 加载模型
# 3. 端到端推理
# 4. 流式推理
import base64
import math
import logging
import numpy as np
from paddlespeech.server.utils.onnx_infer import get_sess
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.server.utils.util import denorm, get_chunks
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine
class TTS:
def __init__(self, config_path):
self.config = get_config(config_path)['tts_online-onnx']
self.config['voc_block'] = 36
self.engine = TTSEngine()
self.engine.init(self.config)
self.executor = self.engine.executor
#self.engine.warm_up()
# 前端初始化
self.frontend = Frontend(
phone_vocab_path=self.engine.executor.phones_dict,
tone_vocab_path=None)
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
"""
front_pad = min(chunk_id * block, pad)
# first chunk
if chunk_id == 0:
data = data[:block * upsample]
# last chunk
elif chunk_id == chunk_num - 1:
data = data[front_pad * upsample:]
# middle chunk
else:
data = data[front_pad * upsample:(front_pad + block) * upsample]
return data
def offlineTTS(self, text):
get_tone_ids = False
merge_sentences = False
input_ids = self.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
wav_list = []
for i in range(len(phone_ids)):
orig_hs = self.engine.executor.am_encoder_infer_sess.run(
None, input_feed={'text': phone_ids[i].numpy()}
)
hs = orig_hs[0]
am_decoder_output = self.engine.executor.am_decoder_sess.run(
None, input_feed={'xs': hs})
am_postnet_output = self.engine.executor.am_postnet_sess.run(
None,
input_feed={
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
})
am_output_data = am_decoder_output + np.transpose(
am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
mel = denorm(normalized_mel, self.engine.executor.am_mu, self.engine.executor.am_std)
wav = self.engine.executor.voc_sess.run(
output_names=None, input_feed={'logmel': mel})[0]
wav_list.append(wav)
wavs = np.concatenate(wav_list)
return wavs
def streamTTS(self, text):
get_tone_ids = False
merge_sentences = False
# front
input_ids = self.frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i].numpy()
voc_chunk_id = 0
# fastspeech2_csmsc
if self.config.am == "fastspeech2_csmsc_onnx":
# am
mel = self.executor.am_sess.run(
output_names=None, input_feed={'text': part_phone_ids})
mel = mel[0]
# voc streaming
mel_chunks = get_chunks(mel, self.config.voc_block, self.config.voc_pad, "voc")
voc_chunk_num = len(mel_chunks)
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.executor.voc_sess.run(
output_names=None, input_feed={'logmel': mel_chunk})
sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
self.config.voc_block, self.config.voc_pad,
self.config.voc_upsample)
yield self.after_process(sub_wav)
# fastspeech2_cnndecoder_csmsc
elif self.config.am == "fastspeech2_cnndecoder_csmsc_onnx":
# am
orig_hs = self.executor.am_encoder_infer_sess.run(
None, input_feed={'text': part_phone_ids})
orig_hs = orig_hs[0]
# streaming voc chunk info
mel_len = orig_hs.shape[1]
voc_chunk_num = math.ceil(mel_len / self.config.voc_block)
start = 0
end = min(self.config.voc_block + self.config.voc_pad, mel_len)
# streaming am
hss = get_chunks(orig_hs, self.config.am_block, self.config.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
am_decoder_output = self.executor.am_decoder_sess.run(
None, input_feed={'xs': hs})
am_postnet_output = self.executor.am_postnet_sess.run(
None,
input_feed={
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
})
am_output_data = am_decoder_output + np.transpose(
am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
sub_mel = denorm(normalized_mel, self.executor.am_mu,
self.executor.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i,
self.config.am_block, self.config.am_pad, 1)
if i == 0:
mel_streaming = sub_mel
else:
mel_streaming = np.concatenate(
(mel_streaming, sub_mel), axis=0)
# streaming voc
# 当流式AM推理的mel帧数大于流式voc推理的chunk size开始进行流式voc 推理
while (mel_streaming.shape[0] >= end and
voc_chunk_id < voc_chunk_num):
voc_chunk = mel_streaming[start:end, :]
sub_wav = self.executor.voc_sess.run(
output_names=None, input_feed={'logmel': voc_chunk})
sub_wav = self.depadding(
sub_wav[0], voc_chunk_num, voc_chunk_id,
self.config.voc_block, self.config.voc_pad, self.config.voc_upsample)
yield self.after_process(sub_wav)
voc_chunk_id += 1
start = max(
0, voc_chunk_id * self.config.voc_block - self.config.voc_pad)
end = min(
(voc_chunk_id + 1) * self.config.voc_block + self.config.voc_pad,
mel_len)
else:
logging.error(
"Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
)
def streamTTSBytes(self, text):
for wav in self.engine.executor.infer(
text=text,
lang=self.engine.config.lang,
am=self.engine.config.am,
spk_id=0):
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
yield wav_bytes
def after_process(self, wav):
# for tvm
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64
return wav_base64
def streamTTS_TVM(self, text):
# 用 TVM 优化
pass