|
|
|
@ -129,7 +129,10 @@ def evaluate(args):
|
|
|
|
|
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
|
|
|
|
|
elif am_name == 'speedyspeech':
|
|
|
|
|
am = am_class(
|
|
|
|
|
vocab_size=vocab_size, tone_size=tone_size, **am_config["model"])
|
|
|
|
|
vocab_size=vocab_size,
|
|
|
|
|
tone_size=tone_size,
|
|
|
|
|
spk_num=spk_num,
|
|
|
|
|
**am_config["model"])
|
|
|
|
|
elif am_name == 'tacotron2':
|
|
|
|
|
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
|
|
|
|
|
|
|
|
|
@ -171,10 +174,6 @@ def evaluate(args):
|
|
|
|
|
InputSpec([-1], dtype=paddle.int64),
|
|
|
|
|
InputSpec([1], dtype=paddle.int64)
|
|
|
|
|
])
|
|
|
|
|
paddle.jit.save(am_inference,
|
|
|
|
|
os.path.join(args.inference_dir, args.am))
|
|
|
|
|
am_inference = paddle.jit.load(
|
|
|
|
|
os.path.join(args.inference_dir, args.am))
|
|
|
|
|
else:
|
|
|
|
|
am_inference = jit.to_static(
|
|
|
|
|
am_inference,
|
|
|
|
@ -184,6 +183,16 @@ def evaluate(args):
|
|
|
|
|
am_inference = paddle.jit.load(
|
|
|
|
|
os.path.join(args.inference_dir, args.am))
|
|
|
|
|
elif am_name == 'speedyspeech':
|
|
|
|
|
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
|
|
|
|
|
am_inference = jit.to_static(
|
|
|
|
|
am_inference,
|
|
|
|
|
input_spec=[
|
|
|
|
|
InputSpec([-1], dtype=paddle.int64), # text
|
|
|
|
|
InputSpec([-1], dtype=paddle.int64), # tone
|
|
|
|
|
None, # duration
|
|
|
|
|
InputSpec([-1], dtype=paddle.int64) # spk_id
|
|
|
|
|
])
|
|
|
|
|
else:
|
|
|
|
|
am_inference = jit.to_static(
|
|
|
|
|
am_inference,
|
|
|
|
|
input_spec=[
|
|
|
|
@ -242,6 +251,11 @@ def evaluate(args):
|
|
|
|
|
mel = am_inference(part_phone_ids)
|
|
|
|
|
elif am_name == 'speedyspeech':
|
|
|
|
|
part_tone_ids = tone_ids[i]
|
|
|
|
|
if am_dataset in {"aishell3", "vctk"}:
|
|
|
|
|
spk_id = paddle.to_tensor(args.spk_id)
|
|
|
|
|
mel = am_inference(part_phone_ids, part_tone_ids,
|
|
|
|
|
spk_id)
|
|
|
|
|
else:
|
|
|
|
|
mel = am_inference(part_phone_ids, part_tone_ids)
|
|
|
|
|
elif am_name == 'tacotron2':
|
|
|
|
|
mel = am_inference(part_phone_ids)
|
|
|
|
@ -269,8 +283,9 @@ def main():
|
|
|
|
|
type=str,
|
|
|
|
|
default='fastspeech2_csmsc',
|
|
|
|
|
choices=[
|
|
|
|
|
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
|
|
|
|
|
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc'
|
|
|
|
|
'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
|
|
|
|
|
'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
|
|
|
|
|
'tacotron2_csmsc'
|
|
|
|
|
],
|
|
|
|
|
help='Choose acoustic model type of tts task.')
|
|
|
|
|
parser.add_argument(
|
|
|
|
|