Add end-to-end version of MFA FastSpeech2, test=tts

pull/2693/head
WongLaw 3 years ago
parent ca7e12150d
commit 5251152d9b

@@ -1668,7 +1668,7 @@ rhy_frontend_models = {
'url':
'https://paddlespeech.bj.bcebos.com/Rhy_e2e/rhy_e2e_pretrain.zip',
'md5':
-'2cc5a3fe9ced1e421f0a03929fb0d23c',
+'d36566b835977ea05ffbd9c0210c8e3c',
},
},
}

@@ -241,7 +241,7 @@ def parse_args():
type=str,
help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output_dir", type=str, help="output dir.")
-parse.add_argument(
+parser.add_argument(
"--use_rhy", action="store_true", help="run rhythm frontend or not")
args = parser.parse_args()

@@ -37,7 +37,7 @@ class Rhy_predictor():
model_dir: os.PathLike=MODEL_HOME, ):
uncompress_path = download_and_decompress(
rhy_frontend_models['rhy_e2e'][model_version], model_dir)
-with open(os.path.join(uncompress_path, 'default.yaml')) as f:
+with open(os.path.join(uncompress_path, 'rhy_default.yaml')) as f:
config = CfgNode(yaml.safe_load(f))
self.punc_list = []
with open(os.path.join(uncompress_path, 'rhy_token'), 'r') as f:
@@ -45,11 +45,11 @@ class Rhy_predictor():
self.punc_list.append(line.strip())
self.punc_list = [0] + self.punc_list
self.make_rhy_dict()
-self.model = DefinedClassifier[config["model_type"]](**config["model"])
+self.model = DefinedClassifier["ErnieLinear"](**config["model"])
pretrained_token = config['data_params']['pretrained_token']
self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
state_dict = paddle.load(
-os.path.join(uncompress_path, 'snapshot_iter_153000.pdz'))
+os.path.join(uncompress_path, 'snapshot_iter_2600.pdz'))
self.model.set_state_dict(state_dict["main_params"])
self.model.eval()

Loading…
Cancel
Save