fix vits dygraph to static

pull/2883/head
TianYuan 3 years ago
parent 7177fd0332
commit d46ca0866a

@ -446,8 +446,16 @@ def am_to_static(am_inference,
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'vits':
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
InputSpec([-1], dtype=paddle.int64),
InputSpec([1], dtype=paddle.int64),
])
else:
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
jit.save(am_inference, os.path.join(inference_dir, am))
am_inference = jit.load(os.path.join(inference_dir, am))
return am_inference

@ -42,6 +42,9 @@ def evaluate(args):
# frontend
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
# acoustic model
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
spk_num = None
if args.speaker_dict is not None:
@ -78,7 +81,7 @@ def evaluate(args):
am=args.am,
inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = False
@ -105,10 +108,12 @@ def evaluate(args):
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
spk_id = None
if spk_num is not None:
if am_dataset in {"aishell3", "vctk"
} and spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id)
# wav = vits_inference(text=part_phone_ids, sids=spk_id)
wav = vits_inference(part_phone_ids)
wav = vits_inference(part_phone_ids, spk_id)
else:
wav = vits_inference(part_phone_ids)
if flags == 0:
wav_all = wav
flags = 1

Loading…
Cancel
Save