add onnxruntime infer for cli

pull/2222/head
TianYuan 2 years ago
parent 070a08f2be
commit b9ade18055

@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
fi fi
# inference with onnxruntime, use fastspeech2 + hifigan by default # inference with onnxruntime, use fastspeech2 + pwgan by default
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
./local/ort_predict.sh ${train_output_path} ./local/ort_predict.sh ${train_output_path}
fi fi

@ -55,7 +55,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_ljspeech # ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_ljspeech
fi fi
# inference with onnxruntime, use fastspeech2 + hifigan by default # inference with onnxruntime, use fastspeech2 + pwgan by default
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
./local/ort_predict.sh ${train_output_path} ./local/ort_predict.sh ${train_output_path}
fi fi

@ -54,7 +54,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
fi fi
# inference with onnxruntime, use fastspeech2 + hifigan by default # inference with onnxruntime, use fastspeech2 + pwgan by default
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
./local/ort_predict.sh ${train_output_path} ./local/ort_predict.sh ${train_output_path}
fi fi

@ -29,10 +29,21 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..log import logger from ..log import logger
from ..utils import stats_wrapper from ..utils import stats_wrapper
from paddlespeech.resource import CommonTaskResource
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.exps.syn_utils import get_sess
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.utils import str2bool
__all__ = ['TTSExecutor'] __all__ = ['TTSExecutor']
ONNX_SUPPORT_SET = {
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
'fastspeech2_aishell3', 'fastspeech2_vctk', 'pwgan_csmsc', 'pwgan_ljspeech',
'pwgan_aishell3', 'pwgan_vctk', 'mb_melgan_csmsc', 'hifigan_csmsc',
'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk'
}
class TTSExecutor(BaseExecutor): class TTSExecutor(BaseExecutor):
@ -142,6 +153,8 @@ class TTSExecutor(BaseExecutor):
default=paddle.get_device(), default=paddle.get_device(),
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
self.parser.add_argument('--cpu_threads', type=int, default=2)
self.parser.add_argument( self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name') '--output', type=str, default='output.wav', help='output file name')
self.parser.add_argument( self.parser.add_argument(
@ -154,6 +167,11 @@ class TTSExecutor(BaseExecutor):
'--verbose', '--verbose',
action='store_true', action='store_true',
help='Increase logger verbosity of current task.') help='Increase logger verbosity of current task.')
self.parser.add_argument(
"--use_onnx",
type=str2bool,
default=False,
help="whether to usen onnxruntime inference.")
def _init_from_path( def _init_from_path(
self, self,
@ -164,7 +182,7 @@ class TTSExecutor(BaseExecutor):
phones_dict: Optional[os.PathLike]=None, phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None, tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, speaker_dict: Optional[os.PathLike]=None,
voc: str='pwgan_csmsc', voc: str='hifigan_csmsc',
voc_config: Optional[os.PathLike]=None, voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None, voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None, voc_stat: Optional[os.PathLike]=None,
@ -288,58 +306,111 @@ class TTSExecutor(BaseExecutor):
lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict) lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict)
# acoustic model # acoustic model
odim = self.am_config.n_mels self.am_inference = get_am_inference(
# model: {model_name}_{dataset} am=am,
am_name = am[:am.rindex('_')] am_config=self.am_config,
am_ckpt=self.am_ckpt,
am_stat=self.am_stat,
phones_dict=self.phones_dict,
tones_dict=self.tones_dict,
speaker_dict=self.speaker_dict)
am_class = self.task_resource.get_model_class(am_name) # vocoder
am_inference_class = self.task_resource.get_model_class(am_name + self.voc_inference = get_voc_inference(
'_inference') voc=voc,
voc_config=self.voc_config,
voc_ckpt=self.voc_ckpt,
voc_stat=self.voc_stat)
if am_name == 'fastspeech2': def _init_from_path_onnx(self,
am = am_class( am: str='fastspeech2_csmsc',
idim=vocab_size, am_ckpt: Optional[os.PathLike]=None,
odim=odim, phones_dict: Optional[os.PathLike]=None,
spk_num=spk_num, tones_dict: Optional[os.PathLike]=None,
**self.am_config["model"]) speaker_dict: Optional[os.PathLike]=None,
elif am_name == 'speedyspeech': voc: str='hifigan_csmsc',
am = am_class( voc_ckpt: Optional[os.PathLike]=None,
vocab_size=vocab_size, lang: str='zh',
tone_size=tone_size, device: str='cpu',
**self.am_config["model"]) cpu_threads: int=2,
elif am_name == 'tacotron2': fs: int=24000):
am = am_class(idim=vocab_size, odim=odim, **self.am_config["model"]) if hasattr(self, 'am_sess') and hasattr(self, 'voc_sess'):
logger.debug('Models had been initialized.')
am.set_state_dict(paddle.load(self.am_ckpt)["main_params"]) return
am.eval()
am_mu, am_std = np.load(self.am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std)
self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval()
# vocoder # am
# model: {model_name}_{dataset} if am_ckpt is None or phones_dict is None:
voc_name = voc[:voc.rindex('_')] use_pretrained_am = True
voc_class = self.task_resource.get_model_class(voc_name) else:
voc_inference_class = self.task_resource.get_model_class(voc_name + use_pretrained_am = False
'_inference')
if voc_name != 'wavernn': am_tag = am + '_onnx' + '-' + lang
voc = voc_class(**self.voc_config["generator_params"]) self.task_resource.set_task_model(
voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) model_tag=am_tag,
voc.remove_weight_norm() model_type=0, # am
voc.eval() skip_download=not use_pretrained_am,
version=None, # default version
)
if use_pretrained_am:
self.am_res_path = self.task_resource.res_dir
self.am_ckpt = os.path.join(self.am_res_path,
self.task_resource.res_dict['ckpt'][0])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
self.am_res_path, self.task_resource.res_dict['phones_dict'])
self.am_fs = self.task_resource.res_dict['sample_rate']
logger.debug(self.am_res_path)
logger.debug(self.am_ckpt)
else: else:
voc = voc_class(**self.voc_config["model"]) self.am_ckpt = os.path.abspath(am_ckpt[0])
voc.set_state_dict(paddle.load(self.voc_ckpt)["main_params"]) self.phones_dict = os.path.abspath(phones_dict)
voc.eval() self.am_res_path = os.path.dirname(os.path.abspath(am_ckpt))
voc_mu, voc_std = np.load(self.voc_stat) self.am_fs = fs
voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std) # for speedyspeech
voc_normalizer = ZScore(voc_mu, voc_std) self.tones_dict = None
self.voc_inference = voc_inference_class(voc_normalizer, voc) if 'tones_dict' in self.task_resource.res_dict:
self.voc_inference.eval() self.tones_dict = os.path.join(
self.am_res_path, self.task_resource.res_dict['tones_dict'])
if tones_dict:
self.tones_dict = tones_dict
# voc
if voc_ckpt is None:
use_pretrained_voc = True
else:
use_pretrained_voc = False
voc_lang = lang
# we must use ljspeech's voc for mix am now!
if lang == 'mix':
voc_lang = 'en'
voc_tag = voc + '_onnx' + '-' + voc_lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
skip_download=not use_pretrained_voc,
version=None, # default version
)
if use_pretrained_voc:
self.voc_res_path = self.task_resource.voc_res_dir
self.voc_ckpt = os.path.join(
self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
logger.debug(self.voc_res_path)
logger.debug(self.voc_ckpt)
else:
self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
# frontend
self.frontend = get_frontend(
lang=lang, phones_dict=self.phones_dict, tones_dict=self.tones_dict)
self.am_sess = get_sess(
model_path=self.am_ckpt, device=device, cpu_threads=cpu_threads)
# vocoder
self.voc_sess = get_sess(
model_path=self.voc_ckpt, device=device, cpu_threads=cpu_threads)
def preprocess(self, input: Any, *args, **kwargs): def preprocess(self, input: Any, *args, **kwargs):
""" """
@ -362,40 +433,28 @@ class TTSExecutor(BaseExecutor):
""" """
am_name = am[:am.rindex('_')] am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:] am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
merge_sentences = False merge_sentences = False
frontend_st = time.time() get_tone_ids = False
if am_name == 'speedyspeech': if am_name == 'speedyspeech':
get_tone_ids = True get_tone_ids = True
if lang == 'zh': frontend_st = time.time()
input_ids = self.frontend.get_input_ids( frontend_dict = run_frontend(
text, frontend=self.frontend,
text=text,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids) get_tone_ids=get_tone_ids,
phone_ids = input_ids["phone_ids"] lang=lang)
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
elif lang == 'mix':
input_ids = self.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
logger.error("lang should in {'zh', 'en', 'mix'}!")
self.frontend_time = time.time() - frontend_st self.frontend_time = time.time() - frontend_st
self.am_time = 0 self.am_time = 0
self.voc_time = 0 self.voc_time = 0
flags = 0 flags = 0
phone_ids = frontend_dict['phone_ids']
for i in range(len(phone_ids)): for i in range(len(phone_ids)):
am_st = time.time() am_st = time.time()
part_phone_ids = phone_ids[i] part_phone_ids = phone_ids[i]
# am # am
if am_name == 'speedyspeech': if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i] part_tone_ids = frontend_dict['tone_ids'][i]
mel = self.am_inference(part_phone_ids, part_tone_ids) mel = self.am_inference(part_phone_ids, part_tone_ids)
# fastspeech2 # fastspeech2
else: else:
@ -417,6 +476,62 @@ class TTSExecutor(BaseExecutor):
self.voc_time += (time.time() - voc_st) self.voc_time += (time.time() - voc_st)
self._outputs['wav'] = wav_all self._outputs['wav'] = wav_all
def infer_onnx(self,
text: str,
lang: str='zh',
am: str='fastspeech2_csmsc',
spk_id: int=0):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
merge_sentences = False
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
am_input_feed = {}
frontend_st = time.time()
frontend_dict = run_frontend(
frontend=self.frontend,
text=text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
lang=lang,
to_tensor=False)
self.frontend_time = time.time() - frontend_st
phone_ids = frontend_dict['phone_ids']
self.am_time = 0
self.voc_time = 0
flags = 0
for i in range(len(phone_ids)):
am_st = time.time()
part_phone_ids = phone_ids[i]
if am_name == 'fastspeech2':
am_input_feed.update({'text': part_phone_ids})
if am_dataset in {"aishell3", "vctk"}:
# NOTE: 'spk_id' should be List[int] rather than int here!!
am_input_feed.update({'spk_id': [spk_id]})
elif am_name == 'speedyspeech':
part_tone_ids = frontend_dict['tone_ids'][i]
am_input_feed.update({
'phones': part_phone_ids,
'tones': part_tone_ids
})
mel = self.am_sess.run(output_names=None, input_feed=am_input_feed)
mel = mel[0]
self.am_time += (time.time() - am_st)
# voc
voc_st = time.time()
wav = self.voc_sess.run(
output_names=None, input_feed={'logmel': mel})
wav = wav[0]
if flags == 0:
wav_all = wav
flags = 1
else:
wav_all = np.concatenate([wav_all, wav])
self.voc_time += (time.time() - voc_st)
self._outputs['wav'] = wav_all
def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]: def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
""" """
Output postprocess and return results. Output postprocess and return results.
@ -430,6 +545,20 @@ class TTSExecutor(BaseExecutor):
output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs) output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs)
return output return output
def postprocess_onnx(self,
output: str='output.wav') -> Union[str, os.PathLike]:
"""
Output postprocess and return results.
This method get model output from self._outputs and convert it into human-readable results.
Returns:
Union[str, os.PathLike]: Human-readable results such as texts and audio files.
"""
output = os.path.abspath(os.path.expanduser(output))
sf.write(output, self._outputs['wav'], samplerate=self.am_fs)
return output
# 命令行的入口是这里
def execute(self, argv: List[str]) -> bool: def execute(self, argv: List[str]) -> bool:
""" """
Command line entry. Command line entry.
@ -451,6 +580,8 @@ class TTSExecutor(BaseExecutor):
lang = args.lang lang = args.lang
device = args.device device = args.device
spk_id = args.spk_id spk_id = args.spk_id
use_onnx = args.use_onnx
cpu_threads = args.cpu_threads
if not args.verbose: if not args.verbose:
self.disable_task_loggers() self.disable_task_loggers()
@ -487,7 +618,9 @@ class TTSExecutor(BaseExecutor):
# other # other
lang=lang, lang=lang,
device=device, device=device,
output=output) output=output,
use_onnx=use_onnx,
cpu_threads=cpu_threads)
task_results[id_] = res task_results[id_] = res
except Exception as e: except Exception as e:
has_exceptions = True has_exceptions = True
@ -501,6 +634,7 @@ class TTSExecutor(BaseExecutor):
else: else:
return True return True
# pyton api 的入口是这里
@stats_wrapper @stats_wrapper
def __call__(self, def __call__(self,
text: str, text: str,
@ -512,16 +646,19 @@ class TTSExecutor(BaseExecutor):
phones_dict: Optional[os.PathLike]=None, phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None, tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, speaker_dict: Optional[os.PathLike]=None,
voc: str='pwgan_csmsc', voc: str='hifigan_csmsc',
voc_config: Optional[os.PathLike]=None, voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None, voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None, voc_stat: Optional[os.PathLike]=None,
lang: str='zh', lang: str='zh',
device: str=paddle.get_device(), device: str=paddle.get_device(),
output: str='output.wav'): output: str='output.wav',
use_onnx: bool=False,
cpu_threads: int=2):
""" """
Python API to call an executor. Python API to call an executor.
""" """
if not use_onnx:
paddle.set_device(device) paddle.set_device(device)
self._init_from_path( self._init_from_path(
am=am, am=am,
@ -538,7 +675,28 @@ class TTSExecutor(BaseExecutor):
lang=lang) lang=lang)
self.infer(text=text, lang=lang, am=am, spk_id=spk_id) self.infer(text=text, lang=lang, am=am, spk_id=spk_id)
res = self.postprocess(output=output) res = self.postprocess(output=output)
return res
else:
# use onnx
# we use `cpu` for onnxruntime by default
# please see description in https://github.com/PaddlePaddle/PaddleSpeech/pull/2220
self.task_resource = CommonTaskResource(
task='tts', model_format='onnx')
assert (
am in ONNX_SUPPORT_SET and voc in ONNX_SUPPORT_SET
), f'the am and voc you choose, they should be in {ONNX_SUPPORT_SET}'
self._init_from_path_onnx(
am=am,
am_ckpt=am_ckpt,
phones_dict=phones_dict,
tones_dict=tones_dict,
speaker_dict=speaker_dict,
voc=voc,
voc_ckpt=voc_ckpt,
lang=lang,
device=device,
cpu_threads=cpu_threads)
self.infer_onnx(text=text, lang=lang, am=am, spk_id=spk_id)
res = self.postprocess_onnx(output=output)
return res return res

@ -1149,7 +1149,7 @@ tts_onnx_pretrained_models = {
"fastspeech2_vctk_onnx-en": { "fastspeech2_vctk_onnx-en": {
'1.0': { '1.0': {
'url': 'url':
'hhttps://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip', 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_onnx_1.1.0.zip',
'md5': 'md5':
'd9c3a9b02204a2070504dd99f5f959bf', 'd9c3a9b02204a2070504dd99f5f959bf',
'ckpt': ['fastspeech2_vctk.onnx'], 'ckpt': ['fastspeech2_vctk.onnx'],

@ -86,11 +86,6 @@ def parse_args():
"--inference_dir", type=str, help="dir to save inference models") "--inference_dir", type=str, help="dir to save inference models")
parser.add_argument("--output_dir", type=str, help="output dir") parser.add_argument("--output_dir", type=str, help="output dir")
# inference # inference
parser.add_argument(
"--use_trt",
type=str2bool,
default=False,
help="Whether to use inference engin TensorRT.", )
parser.add_argument( parser.add_argument(
"--int8", "--int8",
type=str2bool, type=str2bool,

@ -27,6 +27,7 @@ from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
from paddlespeech.t2s.exps.syn_utils import get_voc_output from paddlespeech.t2s.exps.syn_utils import get_voc_output
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.utils import str2bool
@ -175,14 +176,13 @@ def main():
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
# frontend # frontend
if args.lang == 'zh': frontend_dict = run_frontend(
input_ids = frontend.get_input_ids( frontend=frontend,
sentence, text=sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids) get_tone_ids=get_tone_ids,
phone_ids = input_ids["phone_ids"] lang=args.lang)
else: phone_ids = frontend_dict['phone_ids']
print("lang should be 'zh' here!")
phones = phone_ids[0].numpy() phones = phone_ids[0].numpy()
# acoustic model # acoustic model
orig_hs = get_am_sublayer_output( orig_hs = get_am_sublayer_output(

@ -41,17 +41,17 @@ def ort_predict(args):
# am # am
am_sess = get_sess( am_sess = get_sess(
model_dir=args.inference_dir, model_path=str(Path(args.inference_dir) / (args.am + '.onnx')),
model_file=args.am + ".onnx",
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
# vocoder # vocoder
voc_sess = get_sess( voc_sess = get_sess(
model_dir=args.inference_dir, model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
model_file=args.voc + ".onnx",
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
# am warmup # am warmup
for T in [27, 38, 54]: for T in [27, 38, 54]:

@ -22,6 +22,7 @@ from timer import timer
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_sess from paddlespeech.t2s.exps.syn_utils import get_sess
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.utils import str2bool
@ -42,17 +43,17 @@ def ort_predict(args):
fs = 24000 if am_dataset != 'ljspeech' else 22050 fs = 24000 if am_dataset != 'ljspeech' else 22050
am_sess = get_sess( am_sess = get_sess(
model_dir=args.inference_dir, model_path=str(Path(args.inference_dir) / (args.am + '.onnx')),
model_file=args.am + ".onnx",
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
# vocoder # vocoder
voc_sess = get_sess( voc_sess = get_sess(
model_dir=args.inference_dir, model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
model_file=args.voc + ".onnx",
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
merge_sentences = True merge_sentences = True
@ -78,7 +79,6 @@ def ort_predict(args):
am_input_feed.update({'text': phone_ids}) am_input_feed.update({'text': phone_ids})
if am_dataset in {"aishell3", "vctk"}: if am_dataset in {"aishell3", "vctk"}:
am_input_feed.update({'spk_id': spk_id}) am_input_feed.update({'spk_id': spk_id})
elif am_name == 'speedyspeech': elif am_name == 'speedyspeech':
phone_ids = np.random.randint(1, 92, size=(T, )) phone_ids = np.random.randint(1, 92, size=(T, ))
tone_ids = np.random.randint(1, 5, size=(T, )) tone_ids = np.random.randint(1, 5, size=(T, ))
@ -93,50 +93,51 @@ def ort_predict(args):
N = 0 N = 0
T = 0 T = 0
merge_sentences = True merge_sentences = False
get_tone_ids = False get_tone_ids = False
am_input_feed = {}
if am_name == 'speedyspeech': if am_name == 'speedyspeech':
get_tone_ids = True get_tone_ids = True
am_input_feed = {}
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
if args.lang == 'zh': frontend_dict = run_frontend(
input_ids = frontend.get_input_ids( frontend=frontend,
sentence, text=sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids) get_tone_ids=get_tone_ids,
phone_ids = input_ids["phone_ids"] lang=args.lang)
if get_tone_ids: phone_ids = frontend_dict['phone_ids']
tone_ids = input_ids["tone_ids"] flags = 0
elif args.lang == 'en': for i in range(len(phone_ids)):
input_ids = frontend.get_input_ids( part_phone_ids = phone_ids[i].numpy()
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
# merge_sentences=True here, so we only use the first item of phone_ids
phone_ids = phone_ids[0].numpy()
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
am_input_feed.update({'text': phone_ids}) am_input_feed.update({'text': part_phone_ids})
if am_dataset in {"aishell3", "vctk"}: if am_dataset in {"aishell3", "vctk"}:
am_input_feed.update({'spk_id': spk_id}) am_input_feed.update({'spk_id': spk_id})
elif am_name == 'speedyspeech': elif am_name == 'speedyspeech':
tone_ids = tone_ids[0].numpy() part_tone_ids = frontend_dict['tone_ids'][i].numpy()
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids}) am_input_feed.update({
'phones': part_phone_ids,
'tones': part_tone_ids
})
mel = am_sess.run(output_names=None, input_feed=am_input_feed) mel = am_sess.run(output_names=None, input_feed=am_input_feed)
mel = mel[0] mel = mel[0]
wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) wav = voc_sess.run(
output_names=None, input_feed={'logmel': mel})
N += len(wav[0]) wav = wav[0]
if flags == 0:
wav_all = wav
flags = 1
else:
wav_all = np.concatenate([wav_all, wav])
wav = wav_all
N += len(wav)
T += t.elapse T += t.elapse
speed = len(wav[0]) / t.elapse speed = len(wav) / t.elapse
rtf = fs / speed rtf = fs / speed
sf.write( sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=fs)
str(output_dir / (utt_id + ".wav")),
np.array(wav)[0],
samplerate=fs)
print( print(
f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." f"{utt_id}, mel: {mel.shape}, wave: {len(wav)}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
) )
print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")

@ -24,6 +24,7 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_sess from paddlespeech.t2s.exps.syn_utils import get_sess
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.utils import str2bool
@ -45,29 +46,33 @@ def ort_predict(args):
# streaming acoustic model # streaming acoustic model
am_encoder_infer_sess = get_sess( am_encoder_infer_sess = get_sess(
model_dir=args.inference_dir, model_path=str(
model_file=args.am + "_am_encoder_infer" + ".onnx", Path(args.inference_dir) /
(args.am + '_am_encoder_infer' + '.onnx')),
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
am_decoder_sess = get_sess( am_decoder_sess = get_sess(
model_dir=args.inference_dir, model_path=str(
model_file=args.am + "_am_decoder" + ".onnx", Path(args.inference_dir) / (args.am + '_am_decoder' + '.onnx')),
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
am_postnet_sess = get_sess( am_postnet_sess = get_sess(
model_dir=args.inference_dir, model_path=str(
model_file=args.am + "_am_postnet" + ".onnx", Path(args.inference_dir) / (args.am + '_am_postnet' + '.onnx')),
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
am_mu, am_std = np.load(args.am_stat) am_mu, am_std = np.load(args.am_stat)
# vocoder # vocoder
voc_sess = get_sess( voc_sess = get_sess(
model_dir=args.inference_dir, model_path=str(Path(args.inference_dir) / (args.voc + '.onnx')),
model_file=args.voc + ".onnx",
device=args.device, device=args.device,
cpu_threads=args.cpu_threads) cpu_threads=args.cpu_threads,
use_trt=args.use_trt)
# frontend warmup # frontend warmup
# Loading model cost 0.5+ seconds # Loading model cost 0.5+ seconds
@ -102,14 +107,13 @@ def ort_predict(args):
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
if args.lang == 'zh': frontend_dict = run_frontend(
input_ids = frontend.get_input_ids( frontend=frontend,
sentence, text=sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids) get_tone_ids=get_tone_ids,
phone_ids = input_ids["phone_ids"] lang=args.lang)
else: phone_ids = frontend_dict['phone_ids']
print("lang should in be 'zh' here!")
# merge_sentences=True here, so we only use the first item of phone_ids # merge_sentences=True here, so we only use the first item of phone_ids
phone_ids = phone_ids[0].numpy() phone_ids = phone_ids[0].numpy()
orig_hs = am_encoder_infer_sess.run( orig_hs = am_encoder_infer_sess.run(

@ -33,6 +33,8 @@ from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import from paddlespeech.utils.dynamic_import import dynamic_import
# remove [W:onnxruntime: xxx] from ort
ort.set_default_logger_severity(3)
model_alias = { model_alias = {
# acoustic model # acoustic model
@ -161,13 +163,42 @@ def get_frontend(lang: str='zh',
elif lang == 'mix': elif lang == 'mix':
frontend = MixFrontend( frontend = MixFrontend(
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict) phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
else: else:
print("wrong lang!") print("wrong lang!")
print("frontend done!")
return frontend return frontend
def run_frontend(frontend: object,
text: str,
merge_sentences: bool=False,
get_tone_ids: bool=False,
lang: str='zh',
to_tensor: bool=True):
outs = dict()
if lang == 'zh':
input_ids = frontend.get_input_ids(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
outs.update({'tone_ids': tone_ids})
elif lang == 'en':
input_ids = frontend.get_input_ids(
text, merge_sentences=merge_sentences, to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
elif lang == 'mix':
input_ids = frontend.get_input_ids(
text, merge_sentences=merge_sentences, to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en', 'mix'}!")
outs.update({'phone_ids': phone_ids})
return outs
# dygraph # dygraph
def get_am_inference(am: str='fastspeech2_csmsc', def get_am_inference(am: str='fastspeech2_csmsc',
am_config: CfgNode=None, am_config: CfgNode=None,
@ -180,30 +211,22 @@ def get_am_inference(am: str='fastspeech2_csmsc',
with open(phones_dict, "r") as f: with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id) vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
tone_size = None tone_size = None
if tones_dict is not None: if tones_dict is not None:
with open(tones_dict, "r") as f: with open(tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()] tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id) tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None spk_num = None
if speaker_dict is not None: if speaker_dict is not None:
with open(speaker_dict, 'rt') as f: with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()] spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id) spk_num = len(spk_id)
print("spk_num:", spk_num)
odim = am_config.n_mels odim = am_config.n_mels
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
am_name = am[:am.rindex('_')] am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:] am_dataset = am[am.rindex('_') + 1:]
am_class = dynamic_import(am_name, model_alias) am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias) am_inference_class = dynamic_import(am_name + '_inference', model_alias)
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
am = am_class( am = am_class(
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"]) idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
@ -228,7 +251,6 @@ def get_am_inference(am: str='fastspeech2_csmsc',
am_normalizer = ZScore(am_mu, am_std) am_normalizer = ZScore(am_mu, am_std)
am_inference = am_inference_class(am_normalizer, am) am_inference = am_inference_class(am_normalizer, am)
am_inference.eval() am_inference.eval()
print("acoustic model done!")
if return_am: if return_am:
return am_inference, am return am_inference, am
else: else:
@ -260,7 +282,6 @@ def get_voc_inference(
voc_normalizer = ZScore(voc_mu, voc_std) voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference = voc_inference_class(voc_normalizer, voc) voc_inference = voc_inference_class(voc_normalizer, voc)
voc_inference.eval() voc_inference.eval()
print("voc done!")
return voc_inference return voc_inference
@ -342,9 +363,9 @@ def get_predictor(model_dir: Optional[os.PathLike]=None,
def get_am_output( def get_am_output(
input: str, input: str,
am_predictor, am_predictor: paddle.nn.Layer,
am, am: str,
frontend, frontend: object,
lang: str='zh', lang: str='zh',
merge_sentences: bool=True, merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None, speaker_dict: Optional[os.PathLike]=None,
@ -352,30 +373,23 @@ def get_am_output(
am_name = am[:am.rindex('_')] am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:] am_dataset = am[am.rindex('_') + 1:]
am_input_names = am_predictor.get_input_names() am_input_names = am_predictor.get_input_names()
get_tone_ids = False
get_spk_id = False get_spk_id = False
get_tone_ids = False
if am_name == 'speedyspeech': if am_name == 'speedyspeech':
get_tone_ids = True get_tone_ids = True
if am_dataset in {"aishell3", "vctk"} and speaker_dict: if am_dataset in {"aishell3", "vctk"} and speaker_dict:
get_spk_id = True get_spk_id = True
spk_id = np.array([spk_id]) spk_id = np.array([spk_id])
if lang == 'zh':
input_ids = frontend.get_input_ids( frontend_dict = run_frontend(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) frontend=frontend,
phone_ids = input_ids["phone_ids"] text=input,
elif lang == 'en': merge_sentences=merge_sentences,
input_ids = frontend.get_input_ids( get_tone_ids=get_tone_ids,
input, merge_sentences=merge_sentences) lang=lang)
phone_ids = input_ids["phone_ids"]
elif lang == 'mix':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en', 'mix'}!")
if get_tone_ids: if get_tone_ids:
tone_ids = input_ids["tone_ids"] tone_ids = frontend_dict['tone_ids']
tones = tone_ids[0].numpy() tones = tone_ids[0].numpy()
tones_handle = am_predictor.get_input_handle(am_input_names[1]) tones_handle = am_predictor.get_input_handle(am_input_names[1])
tones_handle.reshape(tones.shape) tones_handle.reshape(tones.shape)
@ -384,6 +398,7 @@ def get_am_output(
spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
spk_id_handle.reshape(spk_id.shape) spk_id_handle.reshape(spk_id.shape)
spk_id_handle.copy_from_cpu(spk_id) spk_id_handle.copy_from_cpu(spk_id)
phone_ids = frontend_dict['phone_ids']
phones = phone_ids[0].numpy() phones = phone_ids[0].numpy()
phones_handle = am_predictor.get_input_handle(am_input_names[0]) phones_handle = am_predictor.get_input_handle(am_input_names[0])
phones_handle.reshape(phones.shape) phones_handle.reshape(phones.shape)
@ -432,13 +447,13 @@ def get_streaming_am_output(input: str,
lang: str='zh', lang: str='zh',
merge_sentences: bool=True): merge_sentences: bool=True):
get_tone_ids = False get_tone_ids = False
if lang == 'zh': frontend_dict = run_frontend(
input_ids = frontend.get_input_ids( frontend=frontend,
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) text=input,
phone_ids = input_ids["phone_ids"] merge_sentences=merge_sentences,
else: get_tone_ids=get_tone_ids,
print("lang should be 'zh' here!") lang=lang)
phone_ids = frontend_dict['phone_ids']
phones = phone_ids[0].numpy() phones = phone_ids[0].numpy()
am_encoder_infer_output = get_am_sublayer_output( am_encoder_infer_output = get_am_sublayer_output(
am_encoder_infer_predictor, input=phones) am_encoder_infer_predictor, input=phones)
@ -455,26 +470,25 @@ def get_streaming_am_output(input: str,
# onnx # onnx
def get_sess(model_dir: Optional[os.PathLike]=None, def get_sess(model_path: Optional[os.PathLike],
model_file: Optional[os.PathLike]=None,
device: str='cpu', device: str='cpu',
cpu_threads: int=1, cpu_threads: int=1,
use_trt: bool=False): use_trt: bool=False):
model_dir = str(Path(model_dir) / model_file)
sess_options = ort.SessionOptions() sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if 'gpu' in device.lower():
if device == "gpu": device_id = int(device.split(':')[1]) if len(
device.split(':')) == 2 else 0
# fastspeech2/mb_melgan can't use trt now! # fastspeech2/mb_melgan can't use trt now!
if use_trt: if use_trt:
providers = ['TensorrtExecutionProvider'] provider_name = 'TensorrtExecutionProvider'
else: else:
providers = ['CUDAExecutionProvider'] provider_name = 'CUDAExecutionProvider'
elif device == "cpu": providers = [(provider_name, {'device_id': device_id})]
elif device.lower() == 'cpu':
providers = ['CPUExecutionProvider'] providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = cpu_threads sess_options.intra_op_num_threads = cpu_threads
sess = ort.InferenceSession( sess = ort.InferenceSession(
model_dir, providers=providers, sess_options=sess_options) model_path, providers=providers, sess_options=sess_options)
return sess return sess

@ -25,6 +25,7 @@ from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.exps.syn_utils import voc_to_static from paddlespeech.t2s.exps.syn_utils import voc_to_static
@ -49,6 +50,7 @@ def evaluate(args):
lang=args.lang, lang=args.lang,
phones_dict=args.phones_dict, phones_dict=args.phones_dict,
tones_dict=args.tones_dict) tones_dict=args.tones_dict)
print("frontend done!")
# acoustic model # acoustic model
am_name = args.am[:args.am.rindex('_')] am_name = args.am[:args.am.rindex('_')]
@ -62,13 +64,14 @@ def evaluate(args):
phones_dict=args.phones_dict, phones_dict=args.phones_dict,
tones_dict=args.tones_dict, tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict) speaker_dict=args.speaker_dict)
print("acoustic model done!")
# vocoder # vocoder
voc_inference = get_voc_inference( voc_inference = get_voc_inference(
voc=args.voc, voc=args.voc,
voc_config=voc_config, voc_config=voc_config,
voc_ckpt=args.voc_ckpt, voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat) voc_stat=args.voc_stat)
print("voc done!")
# whether dygraph to static # whether dygraph to static
if args.inference_dir: if args.inference_dir:
@ -78,7 +81,6 @@ def evaluate(args):
am=args.am, am=args.am,
inference_dir=args.inference_dir, inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict) speaker_dict=args.speaker_dict)
# vocoder # vocoder
voc_inference = voc_to_static( voc_inference = voc_to_static(
voc_inference=voc_inference, voc_inference=voc_inference,
@ -101,24 +103,13 @@ def evaluate(args):
T = 0 T = 0
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
if args.lang == 'zh': frontend_dict = run_frontend(
input_ids = frontend.get_input_ids( frontend=frontend,
sentence, text=sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids) get_tone_ids=get_tone_ids,
phone_ids = input_ids["phone_ids"] lang=args.lang)
if get_tone_ids: phone_ids = frontend_dict['phone_ids']
tone_ids = input_ids["tone_ids"]
elif args.lang == 'en':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
elif args.lang == 'mix':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en', 'mix'}!")
with paddle.no_grad(): with paddle.no_grad():
flags = 0 flags = 0
for i in range(len(phone_ids)): for i in range(len(phone_ids)):
@ -132,7 +123,7 @@ def evaluate(args):
else: else:
mel = am_inference(part_phone_ids) mel = am_inference(part_phone_ids)
elif am_name == 'speedyspeech': elif am_name == 'speedyspeech':
part_tone_ids = tone_ids[i] part_tone_ids = frontend_dict['tone_ids'][i]
if am_dataset in {"aishell3", "vctk"}: if am_dataset in {"aishell3", "vctk"}:
spk_id = paddle.to_tensor(args.spk_id) spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, part_tone_ids, mel = am_inference(part_phone_ids, part_tone_ids,

@ -30,6 +30,7 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import model_alias from paddlespeech.t2s.exps.syn_utils import model_alias
from paddlespeech.t2s.exps.syn_utils import run_frontend
from paddlespeech.t2s.exps.syn_utils import voc_to_static from paddlespeech.t2s.exps.syn_utils import voc_to_static
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.utils import str2bool
from paddlespeech.utils.dynamic_import import dynamic_import from paddlespeech.utils.dynamic_import import dynamic_import
@ -138,15 +139,13 @@ def evaluate(args):
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
if args.lang == 'zh': frontend_dict = run_frontend(
input_ids = frontend.get_input_ids( frontend=frontend,
sentence, text=sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids) get_tone_ids=get_tone_ids,
lang=args.lang)
phone_ids = input_ids["phone_ids"] phone_ids = frontend_dict['phone_ids']
else:
print("lang should be 'zh' here!")
# merge_sentences=True here, so we only use the first item of phone_ids # merge_sentences=True here, so we only use the first item of phone_ids
phone_ids = phone_ids[0] phone_ids = phone_ids[0]
with paddle.no_grad(): with paddle.no_grad():

@ -136,7 +136,8 @@ class MixFrontend():
sentence: str, sentence: str,
merge_sentences: bool=True, merge_sentences: bool=True,
get_tone_ids: bool=False, get_tone_ids: bool=False,
add_sp: bool=True) -> Dict[str, List[paddle.Tensor]]: add_sp: bool=True,
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
sentences = self._split(sentence) sentences = self._split(sentence)
phones_list = [] phones_list = []
@ -152,11 +153,12 @@ class MixFrontend():
input_ids = self.zh_frontend.get_input_ids( input_ids = self.zh_frontend.get_input_ids(
content, content,
merge_sentences=True, merge_sentences=True,
get_tone_ids=get_tone_ids) get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
elif lang == "en": elif lang == "en":
input_ids = self.en_frontend.get_input_ids( input_ids = self.en_frontend.get_input_ids(
content, merge_sentences=True) content, merge_sentences=True, to_tensor=to_tensor)
phones_seg.append(input_ids["phone_ids"][0]) phones_seg.append(input_ids["phone_ids"][0])
if add_sp: if add_sp:

@ -82,8 +82,10 @@ class English(Phonetics):
phone_ids = [self.vocab_phones[item] for item in phonemes] phone_ids = [self.vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64) return np.array(phone_ids, np.int64)
def get_input_ids(self, sentence: str, def get_input_ids(self,
merge_sentences: bool=False) -> paddle.Tensor: sentence: str,
merge_sentences: bool=False,
to_tensor: bool=True) -> paddle.Tensor:
result = {} result = {}
sentences = self.text_normalizer._split(sentence, lang="en") sentences = self.text_normalizer._split(sentence, lang="en")
phones_list = [] phones_list = []
@ -112,6 +114,7 @@ class English(Phonetics):
for part_phones_list in phones_list: for part_phones_list in phones_list:
phone_ids = self._p2id(part_phones_list) phone_ids = self._p2id(part_phones_list)
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids) phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids) temp_phone_ids.append(phone_ids)
result["phone_ids"] = temp_phone_ids result["phone_ids"] = temp_phone_ids

@ -303,15 +303,15 @@ class Frontend():
print("----------------------------") print("----------------------------")
return phonemes return phonemes
def get_input_ids( def get_input_ids(self,
self,
sentence: str, sentence: str,
merge_sentences: bool=True, merge_sentences: bool=True,
get_tone_ids: bool=False, get_tone_ids: bool=False,
robot: bool=False, robot: bool=False,
print_info: bool=False, print_info: bool=False,
add_blank: bool=False, add_blank: bool=False,
blank_token: str="<pad>") -> Dict[str, List[paddle.Tensor]]: blank_token: str="<pad>",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes( phonemes = self.get_phonemes(
sentence, sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
@ -322,19 +322,21 @@ class Frontend():
tones = [] tones = []
temp_phone_ids = [] temp_phone_ids = []
temp_tone_ids = [] temp_tone_ids = []
for part_phonemes in phonemes: for part_phonemes in phonemes:
phones, tones = self._get_phone_tone( phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids) part_phonemes, get_tone_ids=get_tone_ids)
if add_blank: if add_blank:
phones = insert_after_character(phones, blank_token) phones = insert_after_character(phones, blank_token)
if tones: if tones:
tone_ids = self._t2id(tones) tone_ids = self._t2id(tones)
if to_tensor:
tone_ids = paddle.to_tensor(tone_ids) tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids) temp_tone_ids.append(tone_ids)
if phones: if phones:
phone_ids = self._p2id(phones) phone_ids = self._p2id(phones)
# if use paddle.to_tensor() in onnxruntime, the first time will be too low
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids) phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids) temp_phone_ids.append(phone_ids)
if temp_tone_ids: if temp_tone_ids:

Loading…
Cancel
Save