diff --git a/examples/csmsc/tts3_rhy_e2e/local/synthesize_e2e.sh b/examples/csmsc/tts3_rhy_e2e/local/synthesize_e2e.sh index 8d82f6063..4eadf291a 100755 --- a/examples/csmsc/tts3_rhy_e2e/local/synthesize_e2e.sh +++ b/examples/csmsc/tts3_rhy_e2e/local/synthesize_e2e.sh @@ -25,9 +25,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --output_dir=${train_output_path}/test_e2e \ --phones_dict=dump/phone_id_map.txt \ --inference_dir=${train_output_path}/inference \ - --rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \ - --rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \ - --rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml + --use_rhy fi # for more GAN Vocoders @@ -49,9 +47,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --output_dir=${train_output_path}/test_e2e \ --phones_dict=dump/phone_id_map.txt \ --inference_dir=${train_output_path}/inference \ - --rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \ - --rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \ - --rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml + --use_rhy fi # the pretrained models haven't release now @@ -73,9 +69,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then --text=${BIN_DIR}/../sentences.txt \ --output_dir=${train_output_path}/test_e2e \ --phones_dict=dump/phone_id_map.txt \ - --rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \ - --rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \ - --rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml + --use_rhy # --inference_dir=${train_output_path}/inference fi @@ -98,9 +92,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --output_dir=${train_output_path}/test_e2e \ --phones_dict=dump/phone_id_map.txt \ --inference_dir=${train_output_path}/inference \ - --rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \ 
- --rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \ - --rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml + --use_rhy fi @@ -123,7 +115,5 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --output_dir=${train_output_path}/test_e2e \ --phones_dict=dump/phone_id_map.txt \ --inference_dir=${train_output_path}/inference \ - --rhy_prediction_model=${MAIN_ROOT}/examples/other/rhy/exp/default/snapshot_iter_2600.pdz \ - --rhy_token=${MAIN_ROOT}/examples/other/rhy/data/rhy_token \ - --rhy_config=${MAIN_ROOT}/examples/other/rhy/conf/default.yaml + --use_rhy fi diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 067246749..11d59f795 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -1658,3 +1658,17 @@ g2pw_onnx_models = { }, }, } + +# --------------------------------- +# ------------- Rhy_frontend --------------- +# --------------------------------- +rhy_frontend_models = { + 'rhy_e2e': { + '1.0': { + 'url': + 'https://paddlespeech.bj.bcebos.com/Rhy_e2e/rhy_e2e_pretrain.zip', + 'md5': + '2cc5a3fe9ced1e421f0a03929fb0d23c', + }, + }, +} diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index 7c0a53904..e8e551fee 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -164,12 +164,12 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]], def get_frontend(lang: str='zh', phones_dict: Optional[os.PathLike]=None, tones_dict: Optional[os.PathLike]=None, - rhy_tuple=None): + use_rhy=False): if lang == 'zh': frontend = Frontend( phone_vocab_path=phones_dict, tone_vocab_path=tones_dict, - rhy_tuple=rhy_tuple) + use_rhy=use_rhy) elif lang == 'en': frontend = English(phone_vocab_path=phones_dict) elif lang == 'mix': diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 2397c1152..df1c58d5a 100644 --- 
a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -45,17 +45,12 @@ def evaluate(args): sentences = get_sentences(text_file=args.text, lang=args.lang) - if len(args.rhy_prediction_model) > 1: - rhy_tuple = (args.rhy_prediction_model, args.rhy_config, args.rhy_token) - else: - rhy_tuple = None - # frontend frontend = get_frontend( lang=args.lang, phones_dict=args.phones_dict, tones_dict=args.tones_dict, - rhy_tuple=rhy_tuple) + use_rhy=args.use_rhy) print("frontend done!") # acoustic model @@ -246,15 +241,8 @@ def parse_args(): type=str, help="text to synthesize, a 'utt_id sentence' pair per line.") parser.add_argument("--output_dir", type=str, help="output dir.") - parser.add_argument( - "--rhy_prediction_model", - type=str, - default="", - help="rhy prediction model path.") - parser.add_argument( - "--rhy_token", type=str, help="rhy prediction token path.") - parser.add_argument( - "--rhy_config", type=str, help="rhy prediction config path.") + parser.add_argument( + "--use_rhy", action="store_true", help="run rhythm frontend or not") args = parser.parse_args() return args diff --git a/paddlespeech/t2s/frontend/rhy_prediction/rhy_predictor.py b/paddlespeech/t2s/frontend/rhy_prediction/rhy_predictor.py index 06978175d..22afe98b2 100644 --- a/paddlespeech/t2s/frontend/rhy_prediction/rhy_predictor.py +++ b/paddlespeech/t2s/frontend/rhy_prediction/rhy_predictor.py @@ -18,19 +18,28 @@ import yaml from paddlenlp.transformers import ErnieTokenizer from yacs.config import CfgNode +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.resource.pretrained_models import rhy_frontend_models from paddlespeech.text.models.ernie_linear import ErnieLinear +from paddlespeech.utils.env import MODEL_HOME DefinedClassifier = { 'ErnieLinear': ErnieLinear, } +model_version = '1.0' + class Rhy_predictor(): - def __init__(self, model_path, config_path, punc_path): - with open(config_path) as f: + def __init__( + self,
model_dir: os.PathLike=MODEL_HOME, ): + uncompress_path = download_and_decompress( + rhy_frontend_models['rhy_e2e'][model_version], model_dir) + with open(os.path.join(uncompress_path, 'default.yaml')) as f: config = CfgNode(yaml.safe_load(f)) self.punc_list = [] - with open(punc_path, 'r') as f: + with open(os.path.join(uncompress_path, 'rhy_token'), 'r') as f: for line in f: self.punc_list.append(line.strip()) self.punc_list = [0] + self.punc_list @@ -38,7 +47,8 @@ class Rhy_predictor(): self.model = DefinedClassifier[config["model_type"]](**config["model"]) pretrained_token = config['data_params']['pretrained_token'] self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token) - state_dict = paddle.load(model_path) + state_dict = paddle.load( + os.path.join(uncompress_path, 'snapshot_iter_153000.pdz')) self.model.set_state_dict(state_dict["main_params"]) self.model.eval() diff --git a/paddlespeech/t2s/frontend/zh_frontend.py b/paddlespeech/t2s/frontend/zh_frontend.py index a308759fd..740f618e1 100644 --- a/paddlespeech/t2s/frontend/zh_frontend.py +++ b/paddlespeech/t2s/frontend/zh_frontend.py @@ -84,7 +84,7 @@ class Frontend(): g2p_model="g2pW", phone_vocab_path=None, tone_vocab_path=None, - rhy_tuple=None): + use_rhy=False): self.mix_ssml_processor = MixTextProcessor() self.tone_modifier = ToneSandhi() self.text_normalizer = TextNormalizer() @@ -107,9 +107,9 @@ class Frontend(): '嘞': [['lei5']], '掺和': [['chan1'], ['huo5']] } - if rhy_tuple is not None: - self.rhy_predictor = Rhy_predictor(rhy_tuple[0], rhy_tuple[1], - rhy_tuple[2]) + self.use_rhy = use_rhy + if use_rhy: + self.rhy_predictor = Rhy_predictor() print("Rhythm predictor loaded.") # g2p_model can be pypinyin and g2pM and g2pW self.g2p_model = g2p_model @@ -201,12 +201,12 @@ class Frontend(): segments = sentences phones_list = [] for seg in segments: - if self.rhy_predictor is not None: + if self.use_rhy: seg = self.rhy_predictor._clean_text(seg) phones = [] # Replace all English words in the 
sentence seg = re.sub('[a-zA-Z]+', '', seg) - if self.rhy_predictor is not None: + if self.use_rhy: seg = self.rhy_predictor.get_prediction(seg) seg_cut = psg.lcut(seg) initials = [] @@ -215,14 +215,14 @@ class Frontend(): # 为了多音词获得更好的效果,这里采用整句预测 if self.g2p_model == "g2pW": try: - if self.rhy_predictor is not None: + if self.use_rhy: seg = self.rhy_predictor._clean_text(seg) pinyins = self.g2pW_model(seg)[0] except Exception: # g2pW采用模型采用繁体输入,如果有cover不了的简体词,采用g2pM预测 print("[%s] not in g2pW dict,use g2pM" % seg) pinyins = self.g2pM_model(seg, tone=True, char_split=False) - if self.rhy_predictor is not None: + if self.use_rhy: rhy_text = self.rhy_predictor.get_prediction(seg) final_py = self.rhy_predictor.pinyin_align(pinyins, rhy_text) @@ -557,7 +557,7 @@ class Frontend(): merge_sentences=merge_sentences, print_info=print_info, robot=robot) - if self.rhy_predictor is not None: + if self.use_rhy: phonemes = self.del_same_sp(phonemes) phonemes = self.add_sp_ifno(phonemes) result = {}