Add end-to-end version of MFA FastSpeech2, test=tts

pull/2693/head
WongLaw 3 years ago
parent 35b0a1bbd9
commit 206d9e5663

@ -25,9 +25,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference \
--rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \
--rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \
--rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml
--use_rhy
fi
# for more GAN Vocoders
@ -49,9 +47,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference \
--rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \
--rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \
--rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml
--use_rhy
fi
# the pretrained models haven't been released yet
@ -73,9 +69,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \
--rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \
--rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml
--use_rhy
# --inference_dir=${train_output_path}/inference
fi
@ -98,9 +92,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference \
--rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \
--rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \
--rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml
--use_rhy
fi
@ -123,7 +115,5 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference \
--rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \
--rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \
--rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml
--use_rhy
fi

@ -1658,3 +1658,17 @@ g2pw_onnx_models = {
},
},
}
# ---------------------------------
# ------------- Rhy_frontend ---------------
# ---------------------------------
# Registry of downloadable rhythm-frontend checkpoints:
# model tag -> version -> {url, md5}. The zip at `url` is fetched and
# verified against `md5` by download_and_decompress in Rhy_predictor.
rhy_frontend_models = {
'rhy_e2e': {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Rhy_e2e/rhy_e2e_pretrain.zip',
'md5':
'2cc5a3fe9ced1e421f0a03929fb0d23c',
},
},
}

@ -164,12 +164,12 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
def get_frontend(lang: str='zh',
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
rhy_tuple=None):
use_rhy=False):
if lang == 'zh':
frontend = Frontend(
phone_vocab_path=phones_dict,
tone_vocab_path=tones_dict,
rhy_tuple=rhy_tuple)
use_rhy=use_rhy)
elif lang == 'en':
frontend = English(phone_vocab_path=phones_dict)
elif lang == 'mix':

@ -45,17 +45,12 @@ def evaluate(args):
sentences = get_sentences(text_file=args.text, lang=args.lang)
if len(args.rhy_prediction_model) > 1:
rhy_tuple = (args.rhy_prediction_model, args.rhy_config, args.rhy_token)
else:
rhy_tuple = None
# frontend
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
rhy_tuple=rhy_tuple)
use_rhy=args.use_rhy)
print("frontend done!")
# acoustic model
@ -246,15 +241,8 @@ def parse_args():
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output_dir", type=str, help="output dir.")
parser.add_argument(
"--rhy_prediction_model",
type=str,
default="",
help="rhy prediction model path.")
parser.add_argument(
"--rhy_token", type=str, help="rhy prediction token path.")
parser.add_argument(
"--rhy_config", type=str, help="rhy prediction config path.")
parser.add_argument(
"--use_rhy", action="store_true", help="run rhythm frontend or not")
args = parser.parse_args()
return args

@ -18,19 +18,28 @@ import yaml
from paddlenlp.transformers import ErnieTokenizer
from yacs.config import CfgNode
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import rhy_frontend_models
from paddlespeech.text.models.ernie_linear import ErnieLinear
from paddlespeech.utils.env import MODEL_HOME
# Maps the `model_type` string from the yaml config to the model class;
# Rhy_predictor instantiates via DefinedClassifier[config["model_type"]].
DefinedClassifier = {
'ErnieLinear': ErnieLinear,
}
# Version key used to select the entry in rhy_frontend_models['rhy_e2e'].
model_version = '1.0'
class Rhy_predictor():
def __init__(self, model_path, config_path, punc_path):
with open(config_path) as f:
def __init__(
self,
model_dir: os.PathLike=MODEL_HOME, ):
uncompress_path = download_and_decompress(
rhy_frontend_models['rhy_e2e'][model_version], model_dir)
with open(os.path.join(uncompress_path, 'default.yaml')) as f:
config = CfgNode(yaml.safe_load(f))
self.punc_list = []
with open(punc_path, 'r') as f:
with open(os.path.join(uncompress_path, 'rhy_token'), 'r') as f:
for line in f:
self.punc_list.append(line.strip())
self.punc_list = [0] + self.punc_list
@ -38,7 +47,8 @@ class Rhy_predictor():
self.model = DefinedClassifier[config["model_type"]](**config["model"])
pretrained_token = config['data_params']['pretrained_token']
self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
state_dict = paddle.load(model_path)
state_dict = paddle.load(
os.path.join(uncompress_path, 'snapshot_iter_153000.pdz'))
self.model.set_state_dict(state_dict["main_params"])
self.model.eval()

@ -84,7 +84,7 @@ class Frontend():
g2p_model="g2pW",
phone_vocab_path=None,
tone_vocab_path=None,
rhy_tuple=None):
use_rhy=False):
self.mix_ssml_processor = MixTextProcessor()
self.tone_modifier = ToneSandhi()
self.text_normalizer = TextNormalizer()
@ -107,9 +107,9 @@ class Frontend():
'': [['lei5']],
'掺和': [['chan1'], ['huo5']]
}
if rhy_tuple is not None:
self.rhy_predictor = Rhy_predictor(rhy_tuple[0], rhy_tuple[1],
rhy_tuple[2])
self.use_rhy = use_rhy
if use_rhy:
self.rhy_predictor = Rhy_predictor()
print("Rhythm predictor loaded.")
# g2p_model can be pypinyin and g2pM and g2pW
self.g2p_model = g2p_model
@ -201,12 +201,12 @@ class Frontend():
segments = sentences
phones_list = []
for seg in segments:
if self.rhy_predictor is not None:
if self.use_rhy:
seg = self.rhy_predictor._clean_text(seg)
phones = []
# Replace all English words in the sentence
seg = re.sub('[a-zA-Z]+', '', seg)
if self.rhy_predictor is not None:
if self.use_rhy:
seg = self.rhy_predictor.get_prediction(seg)
seg_cut = psg.lcut(seg)
initials = []
@ -215,14 +215,14 @@ class Frontend():
# 为了多音词获得更好的效果,这里采用整句预测
if self.g2p_model == "g2pW":
try:
if self.rhy_predictor is not None:
if self.use_rhy:
seg = self.rhy_predictor._clean_text(seg)
pinyins = self.g2pW_model(seg)[0]
except Exception:
# g2pW采用模型采用繁体输入如果有cover不了的简体词采用g2pM预测
print("[%s] not in g2pW dict,use g2pM" % seg)
pinyins = self.g2pM_model(seg, tone=True, char_split=False)
if self.rhy_predictor is not None:
if self.use_rhy:
rhy_text = self.rhy_predictor.get_prediction(seg)
final_py = self.rhy_predictor.pinyin_align(pinyins,
rhy_text)
@ -557,7 +557,7 @@ class Frontend():
merge_sentences=merge_sentences,
print_info=print_info,
robot=robot)
if self.rhy_predictor is not None:
if self.use_rhy:
phonemes = self.del_same_sp(phonemes)
phonemes = self.add_sp_ifno(phonemes)
result = {}

Loading…
Cancel
Save