Merge pull request #1370 from jerryuhoo/fix_multispk

[tts] Add speedyspeech multi-speaker support for synthesize_e2e.py
TianYuan committed 3 years ago via GitHub
commit 490762801f

@@ -129,7 +129,10 @@ def evaluate(args):
             idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
     elif am_name == 'speedyspeech':
         am = am_class(
-            vocab_size=vocab_size, tone_size=tone_size, **am_config["model"])
+            vocab_size=vocab_size,
+            tone_size=tone_size,
+            spk_num=spk_num,
+            **am_config["model"])
     elif am_name == 'tacotron2':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
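
Note: the new spk_num keyword is read from the --speaker_dict argument handled earlier in evaluate(). Below is a minimal sketch of how spk_num is typically derived, assuming the speaker dictionary is a plain-text file with one "<speaker> <id>" pair per line (the speaker_id_map.txt convention used in PaddleSpeech recipes); for single-speaker configs such as speedyspeech_csmsc it stays None, so the extra keyword presumably leaves the model unchanged.

    # Sketch only: deriving spk_num from the speaker dictionary.
    # The "<speaker> <id>"-per-line format is an assumption based on the
    # speaker_id_map.txt files produced by PaddleSpeech recipes.
    def get_spk_num(speaker_dict_path):
        if speaker_dict_path is None:
            return None  # single-speaker model: presumably no speaker embedding
        with open(speaker_dict_path, "rt", encoding="utf-8") as f:
            speakers = [line.strip().split() for line in f if line.strip()]
        return len(speakers)

    # illustrative usage inside evaluate():
    #     spk_num = get_spk_num(args.speaker_dict)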
@@ -171,10 +174,6 @@ def evaluate(args):
                         InputSpec([-1], dtype=paddle.int64),
                         InputSpec([1], dtype=paddle.int64)
                     ])
-                paddle.jit.save(am_inference,
-                                os.path.join(args.inference_dir, args.am))
-                am_inference = paddle.jit.load(
-                    os.path.join(args.inference_dir, args.am))
             else:
                 am_inference = jit.to_static(
                     am_inference,
@@ -184,6 +183,16 @@ def evaluate(args):
             am_inference = paddle.jit.load(
                 os.path.join(args.inference_dir, args.am))
         elif am_name == 'speedyspeech':
+            if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+                am_inference = jit.to_static(
+                    am_inference,
+                    input_spec=[
+                        InputSpec([-1], dtype=paddle.int64),  # text
+                        InputSpec([-1], dtype=paddle.int64),  # tone
+                        None,  # duration
+                        InputSpec([-1], dtype=paddle.int64)  # spk_id
+                    ])
+            else:
                 am_inference = jit.to_static(
                     am_inference,
                     input_spec=[
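
For context on this block: jit.to_static traces the model against fixed input specs, and the surrounding code then serializes the result with paddle.jit.save and reloads it with paddle.jit.load. The sketch below walks through that round-trip with a toy stand-in layer rather than the real SpeedySpeechInference; the toy takes text, tone and an optional spk_id, while the real spec above additionally passes None so that the duration argument presumably keeps its default at synthesis time.

    import os

    import paddle
    from paddle import jit
    from paddle.static import InputSpec

    class ToyAM(paddle.nn.Layer):
        """Stand-in acoustic model: token/tone/speaker ids -> fake 80-dim frames."""

        def __init__(self, vocab_size=10, tone_size=5, spk_num=3, odim=80):
            super().__init__()
            self.text_emb = paddle.nn.Embedding(vocab_size, odim)
            self.tone_emb = paddle.nn.Embedding(tone_size, odim)
            self.spk_emb = paddle.nn.Embedding(spk_num, odim)

        def forward(self, text, tone, spk_id=None):
            h = self.text_emb(text) + self.tone_emb(tone)
            if spk_id is not None:
                # one speaker vector broadcast over all frames
                h = h + self.spk_emb(spk_id)
            return h

    am_inference = ToyAM()
    am_inference.eval()
    # trace -> save -> reload, same pattern as synthesize_e2e.py
    am_inference = jit.to_static(
        am_inference,
        input_spec=[
            InputSpec([-1], dtype=paddle.int64),  # text
            InputSpec([-1], dtype=paddle.int64),  # tone
            InputSpec([-1], dtype=paddle.int64),  # spk_id
        ])
    os.makedirs("exp/inference", exist_ok=True)
    paddle.jit.save(am_inference, os.path.join("exp/inference", "toy_am"))
    am_inference = paddle.jit.load(os.path.join("exp/inference", "toy_am"))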
@@ -242,6 +251,11 @@ def evaluate(args):
                         mel = am_inference(part_phone_ids)
                 elif am_name == 'speedyspeech':
                     part_tone_ids = tone_ids[i]
+                    if am_dataset in {"aishell3", "vctk"}:
+                        spk_id = paddle.to_tensor(args.spk_id)
+                        mel = am_inference(part_phone_ids, part_tone_ids,
+                                           spk_id)
+                    else:
                         mel = am_inference(part_phone_ids, part_tone_ids)
                 elif am_name == 'tacotron2':
                     mel = am_inference(part_phone_ids)
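
At synthesis time the only new input is the speaker id tensor. A small sketch of the dtypes involved follows; the values are illustrative stand-ins for what the frontend produces, and the actual am_inference calls are left as comments because they need a trained checkpoint.

    import paddle

    # stand-ins for values prepared earlier in evaluate()
    part_phone_ids = paddle.to_tensor([1, 2, 3], dtype=paddle.int64)  # 1-D phone ids
    part_tone_ids = paddle.to_tensor([0, 1, 0], dtype=paddle.int64)   # 1-D tone ids
    spk_id = paddle.to_tensor(0)  # args.spk_id: integer index into the speaker table

    # multi-speaker speedyspeech (aishell3 / vctk):
    #     mel = am_inference(part_phone_ids, part_tone_ids, spk_id)
    # single-speaker speedyspeech (csmsc):
    #     mel = am_inference(part_phone_ids, part_tone_ids)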
@@ -269,8 +283,9 @@ def main():
         type=str,
         default='fastspeech2_csmsc',
         choices=[
-            'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
-            'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc'
+            'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
+            'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
+            'tacotron2_csmsc'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
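
The new code paths also rely on --speaker_dict and --spk_id being defined elsewhere in main(); those definitions are outside this diff, so the sketch below is only an assumption about their shape (the argument names come from args.speaker_dict / args.spk_id above, while types, defaults and help strings are illustrative).

    import argparse

    parser = argparse.ArgumentParser()  # stand-in for the parser in main()
    parser.add_argument(
        "--speaker_dict", type=str, default=None,
        help="speaker id map file for multi-speaker acoustic model.")
    parser.add_argument(
        "--spk_id", type=int, default=0,
        help="speaker id to use with a multi-speaker acoustic model.")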
