diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index 07ff1cc72..ed37887b9 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -446,8 +446,16 @@ def am_to_static(am_inference,
         am_inference = jit.to_static(
             am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
     elif am_name == 'vits':
-        am_inference = jit.to_static(
-            am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64),
+                ])
+        else:
+            am_inference = jit.to_static(
+                am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
     jit.save(am_inference, os.path.join(inference_dir, am))
     am_inference = jit.load(os.path.join(inference_dir, am))
     return am_inference
diff --git a/paddlespeech/t2s/exps/vits/synthesize_e2e.py b/paddlespeech/t2s/exps/vits/synthesize_e2e.py
index eb3cad034..dac459a5b 100644
--- a/paddlespeech/t2s/exps/vits/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/vits/synthesize_e2e.py
@@ -42,6 +42,9 @@ def evaluate(args):
 
     # frontend
     frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
+    # acoustic model
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
 
     spk_num = None
     if args.speaker_dict is not None:
@@ -78,7 +81,7 @@
             am=args.am,
             inference_dir=args.inference_dir,
             speaker_dict=args.speaker_dict)
-    
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     merge_sentences = False
@@ -105,10 +108,12 @@ def evaluate(args):
                 for i in range(len(phone_ids)):
                     part_phone_ids = phone_ids[i]
                     spk_id = None
-                    if spk_num is not None:
+                    if am_dataset in {"aishell3", "vctk"
+                                      } and spk_num is not None:
                         spk_id = paddle.to_tensor(args.spk_id)
-                        # wav = vits_inference(text=part_phone_ids, sids=spk_id)
-                        wav = vits_inference(part_phone_ids)
+                        wav = vits_inference(part_phone_ids, spk_id)
+                    else:
+                        wav = vits_inference(part_phone_ids)
                     if flags == 0:
                         wav_all = wav
                         flags = 1
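
For reviewers, a minimal usage sketch of the exported static graph after this change. The export directory, model tag, and dummy tensors below are hypothetical; only the call signatures mirror the `InputSpec` lists added in `am_to_static`: a multi-speaker export (aishell3/vctk with a speaker dict) takes phone ids plus a shape-`[1]` speaker id, while a single-speaker export keeps the original one-tensor signature.

```python
# Hypothetical paths/values for illustration; only the call signatures are
# taken from this patch (syn_utils.am_to_static / vits/synthesize_e2e.py).
import os

import paddle
from paddle import jit

inference_dir = "exp/vits_aishell3/inference"  # hypothetical export directory
am = "vits_aishell3"                           # hypothetical --am value

# am_to_static() saved the static graph with jit.save(); load it back.
vits_inference = jit.load(os.path.join(inference_dir, am))

phone_ids = paddle.to_tensor([12, 34, 56, 7], dtype="int64")  # dummy phone ids
spk_id = paddle.to_tensor([0], dtype="int64")                 # dummy speaker id, shape [1]

# Multi-speaker export: two inputs, matching InputSpec([-1]) for phones
# and InputSpec([1]) for the speaker id.
wav = vits_inference(phone_ids, spk_id)

# Single-speaker export keeps the original one-input signature:
# wav = vits_inference(phone_ids)
print(wav.shape)
```

Keeping the single-input branch untouched means existing single-speaker VITS exports and their callers are unaffected; only the aishell3/vctk multi-speaker path gains the extra speaker-id input.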