@ -29,8 +29,7 @@ from yacs.config import CfgNode
from . . executor import BaseExecutor
from . . log import logger
from . . utils import stats_wrapper
from paddlespeech . t2s . frontend import English
from paddlespeech . t2s . frontend . zh_frontend import Frontend
from paddlespeech . t2s . exps . syn_utils import get_frontend
from paddlespeech . t2s . modules . normalizer import ZScore
__all__ = [ ' TTSExecutor ' ]
@ -54,6 +53,7 @@ class TTSExecutor(BaseExecutor):
' fastspeech2_ljspeech ' ,
' fastspeech2_aishell3 ' ,
' fastspeech2_vctk ' ,
' fastspeech2_mix ' ,
' tacotron2_csmsc ' ,
' tacotron2_ljspeech ' ,
] ,
@ -98,7 +98,7 @@ class TTSExecutor(BaseExecutor):
self . parser . add_argument (
' --voc ' ,
type = str ,
default = ' pw gan_csmsc' ,
default = ' hifi gan_csmsc' ,
choices = [
' pwgan_csmsc ' ,
' pwgan_ljspeech ' ,
@ -135,7 +135,7 @@ class TTSExecutor(BaseExecutor):
' --lang ' ,
type = str ,
default = ' zh ' ,
help = ' Choose model language. zh or en ' )
help = ' Choose model language. zh or en or mix ' )
self . parser . add_argument (
' --device ' ,
type = str ,
@ -231,8 +231,11 @@ class TTSExecutor(BaseExecutor):
use_pretrained_voc = True
else :
use_pretrained_voc = False
voc_tag = voc + ' - ' + lang
voc_lang = lang
# we must use ljspeech's voc for mix am now!
if lang == ' mix ' :
voc_lang = ' en '
voc_tag = voc + ' - ' + voc_lang
self . task_resource . set_task_model (
model_tag = voc_tag ,
model_type = 1 , # vocoder
@ -281,13 +284,8 @@ class TTSExecutor(BaseExecutor):
spk_num = len ( spk_id )
# frontend
if lang == ' zh ' :
self . frontend = Frontend (
phone_vocab_path = self . phones_dict ,
tone_vocab_path = self . tones_dict )
elif lang == ' en ' :
self . frontend = English ( phone_vocab_path = self . phones_dict )
self . frontend = get_frontend (
lang = lang , phones_dict = self . phones_dict , tones_dict = self . tones_dict )
# acoustic model
odim = self . am_config . n_mels
@ -381,8 +379,12 @@ class TTSExecutor(BaseExecutor):
input_ids = self . frontend . get_input_ids (
text , merge_sentences = merge_sentences )
phone_ids = input_ids [ " phone_ids " ]
elif lang == ' mix ' :
input_ids = self . frontend . get_input_ids (
text , merge_sentences = merge_sentences )
phone_ids = input_ids [ " phone_ids " ]
else :
logger . error ( " lang should in { ' zh ' , ' en ' }! " )
logger . error ( " lang should in { ' zh ' , ' en ' , ' mix ' }!" )
self . frontend_time = time . time ( ) - frontend_st
self . am_time = 0
@ -398,7 +400,7 @@ class TTSExecutor(BaseExecutor):
# fastspeech2
else :
# multi speaker
if am_dataset in { " aishell3 " , " vctk " } :
if am_dataset in { ' aishell3 ' , ' vctk ' , ' mix ' } :
mel = self . am_inference (
part_phone_ids , spk_id = paddle . to_tensor ( spk_id ) )
else :