pull/2221/head
wangcanlong 3 years ago
parent a7f81afb11
commit 31ec825020

@@ -19,6 +19,7 @@ loguru
 matplotlib
 nara_wpe
 onnxruntime==1.10.0
+opencc
 pandas
 paddlenlp
 paddlespeech_feat
@@ -36,6 +37,7 @@ soundfile~=0.10
 textgrid
 timer
 tqdm
+transformers==3.4.0
 typeguard
 visualdl
 webrtcvad
@@ -50,5 +52,3 @@ keyboard
 uvicorn
 pattern_singleton
 braceexpand
-opencc
-transformers==3.4.0

@@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi
-# inference with onnxruntime, use fastspeech2 + hifigan by default
+# inference with onnxruntime, use fastspeech2 + pwgan by default
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     ./local/ort_predict.sh ${train_output_path}
 fi

@@ -55,7 +55,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_ljspeech
 fi
-# inference with onnxruntime, use fastspeech2 + hifigan by default
+# inference with onnxruntime, use fastspeech2 + pwgan by default
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     ./local/ort_predict.sh ${train_output_path}
 fi

@@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi
-# inference with onnxruntime, use fastspeech2 + hifigan by default
+# inference with onnxruntime, use fastspeech2 + pwgan by default
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     ./local/ort_predict.sh ${train_output_path}
 fi

@@ -37,6 +37,7 @@ model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
+    init_type: 'kaiming_uniform' # !Warning: needed for convergence
 
 # https://yaml.org/type/float.html
 ###########################################

@@ -0,0 +1,28 @@
+#!/bin/bash
+
+if [ $# != 3 ];then
+    echo "usage: $0 config_path ckpt_prefix jit_model_path"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+ckpt_path_prefix=$2
+jit_model_export_path=$3
+
+python3 -u ${BIN_DIR}/export.py \
+--ngpu ${ngpu} \
+--config ${config_path} \
+--checkpoint_path ${ckpt_path_prefix} \
+--export_path ${jit_model_export_path}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in export!"
+    exit 1
+fi
+
+exit 0

@@ -29,10 +29,21 @@ from yacs.config import CfgNode
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import stats_wrapper
+from paddlespeech.resource import CommonTaskResource
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
 from paddlespeech.t2s.exps.syn_utils import get_frontend
-from paddlespeech.t2s.modules.normalizer import ZScore
+from paddlespeech.t2s.exps.syn_utils import get_sess
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+from paddlespeech.t2s.exps.syn_utils import run_frontend
+from paddlespeech.t2s.utils import str2bool
 
 __all__ = ['TTSExecutor']
 
+ONNX_SUPPORT_SET = {
+    'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
+    'fastspeech2_aishell3', 'fastspeech2_vctk', 'pwgan_csmsc', 'pwgan_ljspeech',
+    'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc', 'hifigan_csmsc',
+    'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk'
+}
+
 
 class TTSExecutor(BaseExecutor):
@@ -142,6 +153,8 @@ class TTSExecutor(BaseExecutor):
             default=paddle.get_device(),
             help='Choose device to execute model inference.')
+        self.parser.add_argument('--cpu_threads', type=int, default=2)
         self.parser.add_argument(
             '--output', type=str, default='output.wav', help='output file name')
         self.parser.add_argument(
@@ -154,6 +167,16 @@ class TTSExecutor(BaseExecutor):
             '--verbose',
             action='store_true',
             help='Increase logger verbosity of current task.')
+        self.parser.add_argument(
+            "--use_onnx",
+            type=str2bool,
+            default=False,
+            help="whether to use onnxruntime inference.")
+        self.parser.add_argument(
+            '--fs',
+            type=int,
+            default=24000,
+            help='sample rate for onnx models when using specified model files.')
 
     def _init_from_path(
             self,
@@ -164,7 +187,7 @@ class TTSExecutor(BaseExecutor):
             phones_dict: Optional[os.PathLike]=None,
             tones_dict: Optional[os.PathLike]=None,
             speaker_dict: Optional[os.PathLike]=None,
-            voc: str='pwgan_csmsc',
+            voc: str='hifigan_csmsc',
             voc_config: Optional[os.PathLike]=None,
             voc_ckpt: Optional[os.PathLike]=None,
             voc_stat: Optional[os.PathLike]=None,
@@ -208,7 +231,7 @@ class TTSExecutor(BaseExecutor):
             self.am_ckpt = os.path.abspath(am_ckpt)
             self.am_stat = os.path.abspath(am_stat)
             self.phones_dict = os.path.abspath(phones_dict)
-            self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
+            self.am_res_path = os.path.dirname(self.am_config)
 
         # for speedyspeech
         self.tones_dict = None
@@ -288,58 +311,110 @@ class TTSExecutor(BaseExecutor):
             lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict)
 
         # acoustic model
-        odim = self.am_config.n_mels
-        # model: {model_name}_{dataset}
-        am_name = am[:am.rindex('_')]
-
-        am_class = self.task_resource.get_model_class(am_name)
-        am_inference_class = self.task_resource.get_model_class(am_name +
-                                                                '_inference')
-
-        if am_name == 'fastspeech2':
-            am = am_class(
-                idim=vocab_size,
-                odim=odim,
-                spk_num=spk_num,
-                **self.am_config["model"])
-        elif am_name == 'speedyspeech':
-            am = am_class(
-                vocab_size=vocab_size,
-                tone_size=tone_size,
-                **self.am_config["model"])
-        elif am_name == 'tacotron2':
-            am = am_class(idim=vocab_size, odim=odim, **self.am_config["model"])
-
-        am.set_state_dict(paddle.load(self.am_ckpt)["main_params"])
-        am.eval()
-        am_mu, am_std = np.load(self.am_stat)
-        am_mu = paddle.to_tensor(am_mu)
-        am_std = paddle.to_tensor(am_std)
-        am_normalizer = ZScore(am_mu, am_std)
-        self.am_inference = am_inference_class(am_normalizer, am)
-        self.am_inference.eval()
+        self.am_inference = get_am_inference(
+            am=am,
+            am_config=self.am_config,
+            am_ckpt=self.am_ckpt,
+            am_stat=self.am_stat,
+            phones_dict=self.phones_dict,
+            tones_dict=self.tones_dict,
+            speaker_dict=self.speaker_dict)
 
         # vocoder
-        # model: {model_name}_{dataset}
-        voc_name = voc[:voc.rindex('_')]
-        voc_class = self.task_resource.get_model_class(voc_name)
-        voc_inference_class = self.task_resource.get_model_class(voc_name +
-                                                                 '_inference')
-        if voc_name != 'wavernn':
-            voc = voc_class(**self.voc_config["generator_params"])
-            voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
-            voc.remove_weight_norm()
-            voc.eval()
-        else:
-            voc = voc_class(**self.voc_config["model"])
-            voc.set_state_dict(paddle.load(self.voc_ckpt)["main_params"])
-            voc.eval()
-        voc_mu, voc_std = np.load(self.voc_stat)
-        voc_mu = paddle.to_tensor(voc_mu)
-        voc_std = paddle.to_tensor(voc_std)
-        voc_normalizer = ZScore(voc_mu, voc_std)
-        self.voc_inference = voc_inference_class(voc_normalizer, voc)
-        self.voc_inference.eval()
+        self.voc_inference = get_voc_inference(
+            voc=voc,
+            voc_config=self.voc_config,
+            voc_ckpt=self.voc_ckpt,
+            voc_stat=self.voc_stat)
+
+    def _init_from_path_onnx(self,
+                             am: str='fastspeech2_csmsc',
+                             am_ckpt: Optional[os.PathLike]=None,
+                             phones_dict: Optional[os.PathLike]=None,
+                             tones_dict: Optional[os.PathLike]=None,
+                             speaker_dict: Optional[os.PathLike]=None,
+                             voc: str='hifigan_csmsc',
+                             voc_ckpt: Optional[os.PathLike]=None,
+                             lang: str='zh',
+                             device: str='cpu',
+                             cpu_threads: int=2,
+                             fs: int=24000):
+        if hasattr(self, 'am_sess') and hasattr(self, 'voc_sess'):
+            logger.debug('Models had been initialized.')
+            return
+        # am
+        if am_ckpt is None or phones_dict is None:
+            use_pretrained_am = True
+        else:
+            use_pretrained_am = False
+
+        am_tag = am + '_onnx' + '-' + lang
+        self.task_resource.set_task_model(
+            model_tag=am_tag,
+            model_type=0,  # am
+            skip_download=not use_pretrained_am,
+            version=None,  # default version
+        )
+        if use_pretrained_am:
+            self.am_res_path = self.task_resource.res_dir
+            self.am_ckpt = os.path.join(self.am_res_path,
+                                        self.task_resource.res_dict['ckpt'])
+            # must have phones_dict in acoustic
+            self.phones_dict = os.path.join(
+                self.am_res_path, self.task_resource.res_dict['phones_dict'])
+            self.am_fs = self.task_resource.res_dict['sample_rate']
+            logger.debug(self.am_res_path)
+            logger.debug(self.am_ckpt)
+        else:
+            self.am_ckpt = os.path.abspath(am_ckpt)
+            self.phones_dict = os.path.abspath(phones_dict)
+            self.am_res_path = os.path.dirname(self.am_ckpt)
+            self.am_fs = fs
+
+        # for speedyspeech
+        self.tones_dict = None
+        if 'tones_dict' in self.task_resource.res_dict:
+            self.tones_dict = os.path.join(
+                self.am_res_path, self.task_resource.res_dict['tones_dict'])
+            if tones_dict:
+                self.tones_dict = tones_dict
+
+        # voc
+        if voc_ckpt is None:
+            use_pretrained_voc = True
+        else:
+            use_pretrained_voc = False
+        voc_lang = lang
+        # we must use ljspeech's voc for mix am now!
+        if lang == 'mix':
+            voc_lang = 'en'
+        voc_tag = voc + '_onnx' + '-' + voc_lang
+        self.task_resource.set_task_model(
+            model_tag=voc_tag,
+            model_type=1,  # vocoder
+            skip_download=not use_pretrained_voc,
+            version=None,  # default version
+        )
+        if use_pretrained_voc:
+            self.voc_res_path = self.task_resource.voc_res_dir
+            self.voc_ckpt = os.path.join(
+                self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
+            logger.debug(self.voc_res_path)
+            logger.debug(self.voc_ckpt)
+        else:
+            self.voc_ckpt = os.path.abspath(voc_ckpt)
+            self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
+
+        # frontend
+        self.frontend = get_frontend(
+            lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict)
+        self.am_sess = get_sess(
+            model_path=self.am_ckpt, device=device, cpu_threads=cpu_threads)
+        # vocoder
+        self.voc_sess = get_sess(
+            model_path=self.voc_ckpt, device=device, cpu_threads=cpu_threads)
 
     def preprocess(self, input: Any, *args, **kwargs):
         """
@@ -362,40 +437,28 @@ class TTSExecutor(BaseExecutor):
         """
         am_name = am[:am.rindex('_')]
         am_dataset = am[am.rindex('_') + 1:]
-        get_tone_ids = False
         merge_sentences = False
-        frontend_st = time.time()
+        get_tone_ids = False
         if am_name == 'speedyspeech':
             get_tone_ids = True
-        if lang == 'zh':
-            input_ids = self.frontend.get_input_ids(
-                text,
-                merge_sentences=merge_sentences,
-                get_tone_ids=get_tone_ids)
-            phone_ids = input_ids["phone_ids"]
-            if get_tone_ids:
-                tone_ids = input_ids["tone_ids"]
-        elif lang == 'en':
-            input_ids = self.frontend.get_input_ids(
-                text, merge_sentences=merge_sentences)
-            phone_ids = input_ids["phone_ids"]
-        elif lang == 'mix':
-            input_ids = self.frontend.get_input_ids(
-                text, merge_sentences=merge_sentences)
-            phone_ids = input_ids["phone_ids"]
-        else:
-            logger.error("lang should in {'zh', 'en', 'mix'}!")
+        frontend_st = time.time()
+        frontend_dict = run_frontend(
+            frontend=self.frontend,
+            text=text,
+            merge_sentences=merge_sentences,
+            get_tone_ids=get_tone_ids,
+            lang=lang)
         self.frontend_time = time.time() - frontend_st
         self.am_time = 0
         self.voc_time = 0
         flags = 0
+        phone_ids = frontend_dict['phone_ids']
         for i in range(len(phone_ids)):
             am_st = time.time()
             part_phone_ids = phone_ids[i]
             # am
             if am_name == 'speedyspeech':
-                part_tone_ids = tone_ids[i]
+                part_tone_ids = frontend_dict['tone_ids'][i]
                 mel = self.am_inference(part_phone_ids, part_tone_ids)
             # fastspeech2
             else:
@@ -417,6 +480,62 @@ class TTSExecutor(BaseExecutor):
             self.voc_time += (time.time() - voc_st)
         self._outputs['wav'] = wav_all
 
+    def infer_onnx(self,
+                   text: str,
+                   lang: str='zh',
+                   am: str='fastspeech2_csmsc',
+                   spk_id: int=0):
+        am_name = am[:am.rindex('_')]
+        am_dataset = am[am.rindex('_') + 1:]
+        merge_sentences = False
+        get_tone_ids = False
+        if am_name == 'speedyspeech':
+            get_tone_ids = True
+        am_input_feed = {}
+        frontend_st = time.time()
+        frontend_dict = run_frontend(
+            frontend=self.frontend,
+            text=text,
+            merge_sentences=merge_sentences,
+            get_tone_ids=get_tone_ids,
+            lang=lang,
+            to_tensor=False)
+        self.frontend_time = time.time() - frontend_st
+        phone_ids = frontend_dict['phone_ids']
+        self.am_time = 0
+        self.voc_time = 0
+        flags = 0
+        for i in range(len(phone_ids)):
+            am_st = time.time()
+            part_phone_ids = phone_ids[i]
+            if am_name == 'fastspeech2':
+                am_input_feed.update({'text': part_phone_ids})
+                if am_dataset in {"aishell3", "vctk"}:
+                    # NOTE: 'spk_id' should be List[int] rather than int here!!
+                    am_input_feed.update({'spk_id': [spk_id]})
+            elif am_name == 'speedyspeech':
+                part_tone_ids = frontend_dict['tone_ids'][i]
+                am_input_feed.update({
+                    'phones': part_phone_ids,
+                    'tones': part_tone_ids
+                })
+            mel = self.am_sess.run(output_names=None, input_feed=am_input_feed)
+            mel = mel[0]
+            self.am_time += (time.time() - am_st)
+            # voc
+            voc_st = time.time()
+            wav = self.voc_sess.run(
+                output_names=None, input_feed={'logmel': mel})
+            wav = wav[0]
+            if flags == 0:
+                wav_all = wav
+                flags = 1
+            else:
+                wav_all = np.concatenate([wav_all, wav])
+            self.voc_time += (time.time() - voc_st)
+        self._outputs['wav'] = wav_all
+
     def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
         """
         Output postprocess and return results.
@@ -430,6 +549,20 @@ class TTSExecutor(BaseExecutor):
             output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs)
         return output
 
+    def postprocess_onnx(self,
+                         output: str='output.wav') -> Union[str, os.PathLike]:
+        """
+        Output postprocess and return results.
+        This method gets the model output from self._outputs and converts it into human-readable results.
+
+        Returns:
+            Union[str, os.PathLike]: Human-readable results such as texts and audio files.
+        """
+        output = os.path.abspath(os.path.expanduser(output))
+        sf.write(output, self._outputs['wav'], samplerate=self.am_fs)
+        return output
+
+    # the command-line entry point is here
     def execute(self, argv: List[str]) -> bool:
         """
         Command line entry.
@@ -451,6 +584,9 @@ class TTSExecutor(BaseExecutor):
         lang = args.lang
         device = args.device
         spk_id = args.spk_id
+        use_onnx = args.use_onnx
+        cpu_threads = args.cpu_threads
+        fs = args.fs
 
         if not args.verbose:
             self.disable_task_loggers()
@@ -487,7 +623,10 @@ class TTSExecutor(BaseExecutor):
                     # other
                     lang=lang,
                     device=device,
-                    output=output)
+                    output=output,
+                    use_onnx=use_onnx,
+                    cpu_threads=cpu_threads,
+                    fs=fs)
                 task_results[id_] = res
             except Exception as e:
                 has_exceptions = True
@@ -501,6 +640,7 @@ class TTSExecutor(BaseExecutor):
         else:
             return True
 
+    # the Python API entry point is here
     @stats_wrapper
     def __call__(self,
                  text: str,
@@ -512,16 +652,20 @@ class TTSExecutor(BaseExecutor):
                 phones_dict: Optional[os.PathLike]=None,
                 tones_dict: Optional[os.PathLike]=None,
                 speaker_dict: Optional[os.PathLike]=None,
-                voc: str='pwgan_csmsc',
+                voc: str='hifigan_csmsc',
                 voc_config: Optional[os.PathLike]=None,
                 voc_ckpt: Optional[os.PathLike]=None,
                 voc_stat: Optional[os.PathLike]=None,
                 lang: str='zh',
                 device: str=paddle.get_device(),
-                output: str='output.wav'):
+                output: str='output.wav',
+                use_onnx: bool=False,
+                cpu_threads: int=2,
+                fs: int=24000):
        """
        Python API to call an executor.
        """
+        if not use_onnx:
            paddle.set_device(device)
            self._init_from_path(
                am=am,
@@ -538,7 +682,29 @@ class TTSExecutor(BaseExecutor):
                lang=lang)
            self.infer(text=text, lang=lang, am=am, spk_id=spk_id)
            res = self.postprocess(output=output)
+            return res
+        else:
+            # use onnx
+            # we use `cpu` for onnxruntime by default
+            # please see description in https://github.com/PaddlePaddle/PaddleSpeech/pull/2220
+            self.task_resource = CommonTaskResource(
+                task='tts', model_format='onnx')
+            assert (
+                am in ONNX_SUPPORT_SET and voc in ONNX_SUPPORT_SET
+            ), f'the am and voc you choose should be in {ONNX_SUPPORT_SET}'
+            self._init_from_path_onnx(
+                am=am,
+                am_ckpt=am_ckpt,
+                phones_dict=phones_dict,
+                tones_dict=tones_dict,
+                speaker_dict=speaker_dict,
+                voc=voc,
+                voc_ckpt=voc_ckpt,
+                lang=lang,
+                device=device,
+                cpu_threads=cpu_threads,
+                fs=fs)
+            self.infer_onnx(text=text, lang=lang, am=am, spk_id=spk_id)
+            res = self.postprocess_onnx(output=output)
            return res
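With these changes the Python API can switch between the dygraph path and the onnxruntime path from a single call. Below is a minimal, hedged sketch of how the new `use_onnx` flag is intended to be used; the input sentence and output file name are placeholders chosen for illustration, not values taken from this commit:

from paddlespeech.cli.tts.infer import TTSExecutor

tts = TTSExecutor()
# use_onnx=True takes the ONNX branch added above: CommonTaskResource(task='tts',
# model_format='onnx'), _init_from_path_onnx(), infer_onnx() and postprocess_onnx().
tts(
    text="你好,欢迎使用语音合成。",   # placeholder sentence
    am='fastspeech2_csmsc',            # must be in ONNX_SUPPORT_SET
    voc='hifigan_csmsc',               # must be in ONNX_SUPPORT_SET
    lang='zh',
    device='cpu',                      # onnxruntime runs on CPU by default here
    use_onnx=True,
    cpu_threads=2,
    output='output.wav')               # placeholder output path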

@@ -1095,7 +1095,8 @@ tts_onnx_pretrained_models = {
             'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip',
             'md5':
             '3e9c45af9ef70675fc1968ed5074fc88',
-            'ckpt': ['speedyspeech_csmsc.onnx'],
+            'ckpt':
+            'speedyspeech_csmsc.onnx',
             'phones_dict':
             'phone_id_map.txt',
             'tones_dict':
@@ -1111,7 +1112,8 @@ tts_onnx_pretrained_models = {
             'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip',
             'md5':
             'fd3ad38d83273ad51f0ea4f4abf3ab4e',
-            'ckpt': ['fastspeech2_csmsc.onnx'],
+            'ckpt':
+            'fastspeech2_csmsc.onnx',
             'phones_dict':
             'phone_id_map.txt',
             'sample_rate':
@@ -1124,7 +1126,8 @@ tts_onnx_pretrained_models = {
             'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_ljspeech_onnx_1.1.0.zip',
             'md5':
             '00754307636a48c972a5f3e65cda3d18',
-            'ckpt': ['fastspeech2_ljspeech.onnx'],
+            'ckpt':
+            'fastspeech2_ljspeech.onnx',
             'phones_dict':
             'phone_id_map.txt',
             'sample_rate':
@@ -1137,7 +1140,8 @@ tts_onnx_pretrained_models = {
             'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_onnx_1.1.0.zip',
             'md5':
             'a1d6ee21de897ce394f5469e2bb4df0d',
-            'ckpt': ['fastspeech2_aishell3.onnx'],
+            'ckpt':
+            'fastspeech2_aishell3.onnx',
             'phones_dict':
             'phone_id_map.txt',
             'speaker_dict':
@@ -1149,10 +1153,11 @@ tts_onnx_pretrained_models = {
     "fastspeech2_vctk_onnx-en": {
         '1.0': {
             'url':
-            'hhttps://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip',
+            'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip',
             'md5':
             'd9c3a9b02204a2070504dd99f5f959bf',
-            'ckpt': ['fastspeech2_vctk.onnx'],
+            'ckpt':
+            'fastspeech2_vctk.onnx',
             'phones_dict':
             'phone_id_map.txt',
             'speaker_dict':

@@ -389,6 +389,7 @@ class DataLoaderFactory():
             config['mini_batch_size'] = args.ngpu
             config['subsampling_factor'] = 1
             config['num_encs'] = 1
+            config['shortest_first'] = False
         elif mode == 'valid':
             config['manifest'] = config.dev_manifest
             config['train_mode'] = False

@@ -160,7 +160,7 @@ class ASRWsAudioHandler:
                 separators=(',', ': '))
             await ws.send(audio_info)
             msg = await ws.recv()
-            logger.debug("client receive msg={}".format(msg))
+            logger.info("client receive msg={}".format(msg))
 
             # 3. send chunk audio data to engine
             for chunk_data in self.read_wave(wavfile_path):
@@ -170,7 +170,7 @@ class ASRWsAudioHandler:
                 if self.punc_server and len(msg["result"]) > 0:
                     msg["result"] = self.punc_server.run(msg["result"])
-                logger.debug("client receive msg={}".format(msg))
+                logger.info("client receive msg={}".format(msg))
 
             # 4. we must send finished signal to the server
             audio_info = json.dumps(
@@ -317,7 +317,7 @@ class TTSWsHandler:
             start_request = json.dumps({"task": "tts", "signal": "start"})
             await ws.send(start_request)
             msg = await ws.recv()
-            logger.debug(f"client receive msg={msg}")
+            logger.info(f"client receive msg={msg}")
             msg = json.loads(msg)
             session = msg["session"]

@@ -86,11 +86,6 @@ def parse_args():
         "--inference_dir", type=str, help="dir to save inference models")
     parser.add_argument("--output_dir", type=str, help="output dir")
     # inference
-    parser.add_argument(
-        "--use_trt",
-        type=str2bool,
-        default=False,
-        help="Whether to use inference engin TensorRT.", )
     parser.add_argument(
         "--int8",
         type=str2bool,
@@ -187,7 +182,7 @@ def main():
             speed = wav.size / t.elapse
             rtf = fs / speed
-            sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
+            sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs)
             print(
                 f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
             )

@@ -27,6 +27,7 @@ from paddlespeech.t2s.exps.syn_utils import get_predictor
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
 from paddlespeech.t2s.exps.syn_utils import get_voc_output
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.utils import str2bool
 
@@ -175,14 +176,13 @@ def main():
     for utt_id, sentence in sentences:
         with timer() as t:
             # frontend
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should be 'zh' here!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             phones = phone_ids[0].numpy()
             # acoustic model
             orig_hs = get_am_sublayer_output(

@@ -41,17 +41,17 @@ def ort_predict(args):
     # am
     am_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.am + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
 
     # vocoder
     voc_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.voc + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
 
     # am warmup
     for T in [27, 38, 54]:

@@ -22,6 +22,7 @@ from timer import timer
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_sess
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.utils import str2bool
 
@@ -42,17 +43,17 @@ def ort_predict(args):
     fs = 24000 if am_dataset != 'ljspeech' else 22050
 
     am_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.am + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
 
     # vocoder
     voc_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.voc + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
        device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
 
     merge_sentences = True
 
@@ -78,7 +79,6 @@ def ort_predict(args):
             am_input_feed.update({'text': phone_ids})
             if am_dataset in {"aishell3", "vctk"}:
                 am_input_feed.update({'spk_id': spk_id})
-
         elif am_name == 'speedyspeech':
             phone_ids = np.random.randint(1, 92, size=(T, ))
             tone_ids = np.random.randint(1, 5, size=(T, ))
@@ -93,50 +93,51 @@ def ort_predict(args):
     N = 0
     T = 0
-    merge_sentences = True
+    merge_sentences = False
     get_tone_ids = False
+    am_input_feed = {}
     if am_name == 'speedyspeech':
         get_tone_ids = True
-    am_input_feed = {}
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-                if get_tone_ids:
-                    tone_ids = input_ids["tone_ids"]
-            elif args.lang == 'en':
-                input_ids = frontend.get_input_ids(
-                    sentence, merge_sentences=merge_sentences)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should in {'zh', 'en'}!")
-            # merge_sentences=True here, so we only use the first item of phone_ids
-            phone_ids = phone_ids[0].numpy()
-            if am_name == 'fastspeech2':
-                am_input_feed.update({'text': phone_ids})
-                if am_dataset in {"aishell3", "vctk"}:
-                    am_input_feed.update({'spk_id': spk_id})
-            elif am_name == 'speedyspeech':
-                tone_ids = tone_ids[0].numpy()
-                am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
-            mel = am_sess.run(output_names=None, input_feed=am_input_feed)
-            mel = mel[0]
-            wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})
-
-            N += len(wav[0])
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
+            flags = 0
+            for i in range(len(phone_ids)):
+                part_phone_ids = phone_ids[i].numpy()
+                if am_name == 'fastspeech2':
+                    am_input_feed.update({'text': part_phone_ids})
+                    if am_dataset in {"aishell3", "vctk"}:
+                        am_input_feed.update({'spk_id': spk_id})
+                elif am_name == 'speedyspeech':
+                    part_tone_ids = frontend_dict['tone_ids'][i].numpy()
+                    am_input_feed.update({
+                        'phones': part_phone_ids,
+                        'tones': part_tone_ids
+                    })
+                mel = am_sess.run(output_names=None, input_feed=am_input_feed)
+                mel = mel[0]
+                wav = voc_sess.run(
+                    output_names=None, input_feed={'logmel': mel})
+                wav = wav[0]
+                if flags == 0:
+                    wav_all = wav
+                    flags = 1
+                else:
+                    wav_all = np.concatenate([wav_all, wav])
+            wav = wav_all
+            N += len(wav)
             T += t.elapse
-            speed = len(wav[0]) / t.elapse
+            speed = len(wav) / t.elapse
             rtf = fs / speed
-            sf.write(
-                str(output_dir / (utt_id + ".wav")),
-                np.array(wav)[0],
-                samplerate=fs)
+            sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=fs)
             print(
-                f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+                f"{utt_id}, mel: {mel.shape}, wave: {len(wav)}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
             )
     print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")

@@ -24,6 +24,7 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_sess
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.utils import str2bool
 
@@ -45,29 +46,33 @@ def ort_predict(args):
     # streaming acoustic model
     am_encoder_infer_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + "_am_encoder_infer" + ".onnx",
+        model_path=str(
+            Path(args.inference_dir) /
+            (args.am + '_am_encoder_infer' + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
 
     am_decoder_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + "_am_decoder" + ".onnx",
+        model_path=str(
+            Path(args.inference_dir) / (args.am + '_am_decoder' + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
 
     am_postnet_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.am + "_am_postnet" + ".onnx",
+        model_path=str(
+            Path(args.inference_dir) / (args.am + '_am_postnet' + '.onnx')),
         device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
     am_mu, am_std = np.load(args.am_stat)
 
     # vocoder
     voc_sess = get_sess(
-        model_dir=args.inference_dir,
-        model_file=args.voc + ".onnx",
+        model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
        device=args.device,
-        cpu_threads=args.cpu_threads)
+        cpu_threads=args.cpu_threads,
+        use_trt=args.use_trt)
 
     # frontend warmup
     # Loading model cost 0.5+ seconds
@@ -102,14 +107,13 @@ def ort_predict(args):
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should be 'zh' here!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             # merge_sentences=True here, so we only use the first item of phone_ids
             phone_ids = phone_ids[0].numpy()
             orig_hs = am_encoder_infer_sess.run(

@@ -33,6 +33,8 @@ from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 from paddlespeech.t2s.modules.normalizer import ZScore
 from paddlespeech.utils.dynamic_import import dynamic_import
 
+# remove [W:onnxruntime: xxx] from ort
+ort.set_default_logger_severity(3)
 
 model_alias = {
     # acoustic model
@@ -161,13 +163,42 @@ def get_frontend(lang: str='zh',
     elif lang == 'mix':
         frontend = MixFrontend(
             phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
     else:
         print("wrong lang!")
-    print("frontend done!")
     return frontend
 
 
+def run_frontend(frontend: object,
+                 text: str,
+                 merge_sentences: bool=False,
+                 get_tone_ids: bool=False,
+                 lang: str='zh',
+                 to_tensor: bool=True):
+    outs = dict()
+    if lang == 'zh':
+        input_ids = frontend.get_input_ids(
+            text,
+            merge_sentences=merge_sentences,
+            get_tone_ids=get_tone_ids,
+            to_tensor=to_tensor)
+        phone_ids = input_ids["phone_ids"]
+        if get_tone_ids:
+            tone_ids = input_ids["tone_ids"]
+            outs.update({'tone_ids': tone_ids})
+    elif lang == 'en':
+        input_ids = frontend.get_input_ids(
+            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
+        phone_ids = input_ids["phone_ids"]
+    elif lang == 'mix':
+        input_ids = frontend.get_input_ids(
+            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
+        phone_ids = input_ids["phone_ids"]
+    else:
+        print("lang should in {'zh', 'en', 'mix'}!")
+    outs.update({'phone_ids': phone_ids})
+    return outs
+
+
 # dygraph
 def get_am_inference(am: str='fastspeech2_csmsc',
                      am_config: CfgNode=None,
@@ -180,30 +211,22 @@ def get_am_inference(am: str='fastspeech2_csmsc',
     with open(phones_dict, "r") as f:
         phn_id = [line.strip().split() for line in f.readlines()]
     vocab_size = len(phn_id)
-    print("vocab_size:", vocab_size)
 
     tone_size = None
     if tones_dict is not None:
         with open(tones_dict, "r") as f:
             tone_id = [line.strip().split() for line in f.readlines()]
         tone_size = len(tone_id)
-        print("tone_size:", tone_size)
 
     spk_num = None
     if speaker_dict is not None:
         with open(speaker_dict, 'rt') as f:
             spk_id = [line.strip().split() for line in f.readlines()]
         spk_num = len(spk_id)
-        print("spk_num:", spk_num)
 
     odim = am_config.n_mels
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     am_class = dynamic_import(am_name, model_alias)
     am_inference_class = dynamic_import(am_name + '_inference', model_alias)
 
     if am_name == 'fastspeech2':
         am = am_class(
             idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
@@ -228,7 +251,6 @@ def get_am_inference(am: str='fastspeech2_csmsc',
     am_normalizer = ZScore(am_mu, am_std)
     am_inference = am_inference_class(am_normalizer, am)
     am_inference.eval()
-    print("acoustic model done!")
     if return_am:
         return am_inference, am
     else:
@@ -260,7 +282,6 @@ def get_voc_inference(
     voc_normalizer = ZScore(voc_mu, voc_std)
     voc_inference = voc_inference_class(voc_normalizer, voc)
     voc_inference.eval()
-    print("voc done!")
     return voc_inference
@@ -342,9 +363,9 @@ def get_predictor(model_dir: Optional[os.PathLike]=None,
 
 def get_am_output(
         input: str,
-        am_predictor,
-        am,
-        frontend,
+        am_predictor: paddle.nn.Layer,
+        am: str,
+        frontend: object,
         lang: str='zh',
         merge_sentences: bool=True,
         speaker_dict: Optional[os.PathLike]=None,
@@ -352,30 +373,23 @@ def get_am_output(
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     am_input_names = am_predictor.get_input_names()
-    get_tone_ids = False
     get_spk_id = False
+    get_tone_ids = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
     if am_dataset in {"aishell3", "vctk"} and speaker_dict:
         get_spk_id = True
         spk_id = np.array([spk_id])
-    if lang == 'zh':
-        input_ids = frontend.get_input_ids(
-            input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
-        phone_ids = input_ids["phone_ids"]
-    elif lang == 'en':
-        input_ids = frontend.get_input_ids(
-            input, merge_sentences=merge_sentences)
-        phone_ids = input_ids["phone_ids"]
-    elif lang == 'mix':
-        input_ids = frontend.get_input_ids(
-            input, merge_sentences=merge_sentences)
-        phone_ids = input_ids["phone_ids"]
-    else:
-        print("lang should in {'zh', 'en', 'mix'}!")
+    frontend_dict = run_frontend(
+        frontend=frontend,
+        text=input,
+        merge_sentences=merge_sentences,
+        get_tone_ids=get_tone_ids,
+        lang=lang)
     if get_tone_ids:
-        tone_ids = input_ids["tone_ids"]
+        tone_ids = frontend_dict['tone_ids']
         tones = tone_ids[0].numpy()
         tones_handle = am_predictor.get_input_handle(am_input_names[1])
         tones_handle.reshape(tones.shape)
@@ -384,6 +398,7 @@ def get_am_output(
         spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
         spk_id_handle.reshape(spk_id.shape)
         spk_id_handle.copy_from_cpu(spk_id)
+    phone_ids = frontend_dict['phone_ids']
     phones = phone_ids[0].numpy()
     phones_handle = am_predictor.get_input_handle(am_input_names[0])
     phones_handle.reshape(phones.shape)
@@ -432,13 +447,13 @@ def get_streaming_am_output(input: str,
                             lang: str='zh',
                             merge_sentences: bool=True):
     get_tone_ids = False
-    if lang == 'zh':
-        input_ids = frontend.get_input_ids(
-            input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
-        phone_ids = input_ids["phone_ids"]
-    else:
-        print("lang should be 'zh' here!")
+    frontend_dict = run_frontend(
+        frontend=frontend,
+        text=input,
+        merge_sentences=merge_sentences,
+        get_tone_ids=get_tone_ids,
+        lang=lang)
+    phone_ids = frontend_dict['phone_ids']
     phones = phone_ids[0].numpy()
     am_encoder_infer_output = get_am_sublayer_output(
         am_encoder_infer_predictor, input=phones)
@@ -455,26 +470,25 @@ def get_streaming_am_output(input: str,
 # onnx
-def get_sess(model_dir: Optional[os.PathLike]=None,
-             model_file: Optional[os.PathLike]=None,
+def get_sess(model_path: Optional[os.PathLike],
              device: str='cpu',
             cpu_threads: int=1,
             use_trt: bool=False):
-    model_dir = str(Path(model_dir) / model_file)
     sess_options = ort.SessionOptions()
     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
     sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
 
-    if device == "gpu":
+    if 'gpu' in device.lower():
+        device_id = int(device.split(':')[1]) if len(
+            device.split(':')) == 2 else 0
         # fastspeech2/mb_melgan can't use trt now!
         if use_trt:
-            providers = ['TensorrtExecutionProvider']
+            provider_name = 'TensorrtExecutionProvider'
         else:
-            providers = ['CUDAExecutionProvider']
-    elif device == "cpu":
+            provider_name = 'CUDAExecutionProvider'
+        providers = [(provider_name, {'device_id': device_id})]
+    elif device.lower() == 'cpu':
         providers = ['CPUExecutionProvider']
     sess_options.intra_op_num_threads = cpu_threads
     sess = ort.InferenceSession(
-        model_dir, providers=providers, sess_options=sess_options)
+        model_path, providers=providers, sess_options=sess_options)
     return sess
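run_frontend and the path-based get_sess are the two helpers now shared by the CLI and the ort_predict scripts. A rough usage sketch follows; it assumes an exported acoustic model already exists at the placeholder path 'inference_onnx/fastspeech2_csmsc.onnx', that 'phone_id_map.txt' is the matching phone dictionary, and that the Chinese frontend updated in this PR accepts the to_tensor argument:

from paddlespeech.t2s.exps.syn_utils import get_frontend, get_sess, run_frontend

frontend = get_frontend(lang='zh', phones_dict='phone_id_map.txt')  # placeholder dict path
# to_tensor=False keeps phone ids as numpy arrays, which is what onnxruntime expects
outs = run_frontend(frontend=frontend, text="你好", merge_sentences=False,
                    get_tone_ids=False, lang='zh', to_tensor=False)
am_sess = get_sess(model_path='inference_onnx/fastspeech2_csmsc.onnx',  # placeholder path
                   device='cpu', cpu_threads=2)
# 'text' is the fastspeech2 input name used by the ONNX feeds in this commit
mel = am_sess.run(output_names=None,
                  input_feed={'text': outs['phone_ids'][0]})[0]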

@@ -25,6 +25,7 @@ from paddlespeech.t2s.exps.syn_utils import get_am_inference
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.exps.syn_utils import voc_to_static
 
@@ -49,6 +50,7 @@ def evaluate(args):
         lang=args.lang,
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict)
+    print("frontend done!")
 
     # acoustic model
     am_name = args.am[:args.am.rindex('_')]
@@ -62,13 +64,14 @@ def evaluate(args):
         phones_dict=args.phones_dict,
         tones_dict=args.tones_dict,
         speaker_dict=args.speaker_dict)
+    print("acoustic model done!")
 
     # vocoder
     voc_inference = get_voc_inference(
         voc=args.voc,
         voc_config=voc_config,
         voc_ckpt=args.voc_ckpt,
         voc_stat=args.voc_stat)
+    print("voc done!")
 
     # whether dygraph to static
     if args.inference_dir:
@@ -78,7 +81,6 @@ def evaluate(args):
             am=args.am,
             inference_dir=args.inference_dir,
             speaker_dict=args.speaker_dict)
-
         # vocoder
         voc_inference = voc_to_static(
             voc_inference=voc_inference,
@@ -101,24 +103,13 @@ def evaluate(args):
     T = 0
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-                if get_tone_ids:
-                    tone_ids = input_ids["tone_ids"]
-            elif args.lang == 'en':
-                input_ids = frontend.get_input_ids(
-                    sentence, merge_sentences=merge_sentences)
-                phone_ids = input_ids["phone_ids"]
-            elif args.lang == 'mix':
-                input_ids = frontend.get_input_ids(
-                    sentence, merge_sentences=merge_sentences)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should in {'zh', 'en', 'mix'}!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             with paddle.no_grad():
                 flags = 0
                 for i in range(len(phone_ids)):
@@ -132,7 +123,7 @@ def evaluate(args):
                         else:
                             mel = am_inference(part_phone_ids)
                     elif am_name == 'speedyspeech':
-                        part_tone_ids = tone_ids[i]
+                        part_tone_ids = frontend_dict['tone_ids'][i]
                         if am_dataset in {"aishell3", "vctk"}:
                             spk_id = paddle.to_tensor(args.spk_id)
                             mel = am_inference(part_phone_ids, part_tone_ids,

@@ -30,6 +30,7 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.exps.syn_utils import get_voc_inference
 from paddlespeech.t2s.exps.syn_utils import model_alias
+from paddlespeech.t2s.exps.syn_utils import run_frontend
 from paddlespeech.t2s.exps.syn_utils import voc_to_static
 from paddlespeech.t2s.utils import str2bool
 from paddlespeech.utils.dynamic_import import dynamic_import
@@ -138,15 +139,13 @@ def evaluate(args):
     for utt_id, sentence in sentences:
         with timer() as t:
-            if args.lang == 'zh':
-                input_ids = frontend.get_input_ids(
-                    sentence,
-                    merge_sentences=merge_sentences,
-                    get_tone_ids=get_tone_ids)
-                phone_ids = input_ids["phone_ids"]
-            else:
-                print("lang should be 'zh' here!")
+            frontend_dict = run_frontend(
+                frontend=frontend,
+                text=sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids,
+                lang=args.lang)
+            phone_ids = frontend_dict['phone_ids']
             # merge_sentences=True here, so we only use the first item of phone_ids
             phone_ids = phone_ids[0]
             with paddle.no_grad():

@@ -1,18 +1,11 @@
 import os
 import json
-import requests
-import zipfile
 import onnxruntime
 import numpy as np
-from io import BytesIO
-import shutil
-from transformers import BertTokenizer
-try:
-    from opencc import OpenCC
-except:
-    pass
+from opencc import OpenCC
+from transformers import BertTokenizer
 
 from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data, prepare_onnx_input, get_phoneme_labels, get_char_phoneme_labels
 from paddlespeech.t2s.frontend.g2pw.utils import load_config

@@ -136,7 +136,8 @@ class MixFrontend():
                       sentence: str,
                       merge_sentences: bool=True,
                       get_tone_ids: bool=False,
-                      add_sp: bool=True) -> Dict[str, List[paddle.Tensor]]:
+                      add_sp: bool=True,
+                      to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
 
         sentences = self._split(sentence)
         phones_list = []
@@ -152,11 +153,12 @@ class MixFrontend():
                     input_ids = self.zh_frontend.get_input_ids(
                         content,
                         merge_sentences=True,
-                        get_tone_ids=get_tone_ids)
+                        get_tone_ids=get_tone_ids,
+                        to_tensor=to_tensor)
                 elif lang == "en":
                     input_ids = self.en_frontend.get_input_ids(
-                        content, merge_sentences=True)
+                        content, merge_sentences=True, to_tensor=to_tensor)
                 phones_seg.append(input_ids["phone_ids"][0])
                 if add_sp:

@@ -82,8 +82,10 @@ class English(Phonetics):
         phone_ids = [self.vocab_phones[item] for item in phonemes]
         return np.array(phone_ids, np.int64)
 
-    def get_input_ids(self, sentence: str,
-                      merge_sentences: bool=False) -> paddle.Tensor:
+    def get_input_ids(self,
+                      sentence: str,
+                      merge_sentences: bool=False,
+                      to_tensor: bool=True) -> paddle.Tensor:
         result = {}
         sentences = self.text_normalizer._split(sentence, lang="en")
         phones_list = []
@@ -112,6 +114,7 @@ class English(Phonetics):
         for part_phones_list in phones_list:
             phone_ids = self._p2id(part_phones_list)
+            if to_tensor:
                 phone_ids = paddle.to_tensor(phone_ids)
             temp_phone_ids.append(phone_ids)
         result["phone_ids"] = temp_phone_ids
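The to_tensor switch added to the frontends is what lets the ONNX path above avoid paddle.Tensor entirely. A hedged sketch of the intended behaviour, assuming English is re-exported from paddlespeech.t2s.frontend and that the example vocab path and sentence are placeholders:

from paddlespeech.t2s.frontend import English  # assumed import location of this class

frontend = English(phone_vocab_path='phone_id_map.txt')  # placeholder vocab path
ids = frontend.get_input_ids("Hello world.", merge_sentences=False, to_tensor=False)
# with to_tensor=False each entry stays a numpy int64 array instead of a paddle.Tensor,
# so it can be fed straight into an onnxruntime InferenceSession
print(type(ids["phone_ids"][0]))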
