From 63d2cc7ae9626a3eef9a228cb487e4b2f9dbf071 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Tue, 3 Dec 2024 05:58:36 +0000 Subject: [PATCH] t2s infernece compatible with PIR api --- paddlespeech/t2s/exps/inference.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3edc4b63b..85cf928c4 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -120,6 +121,15 @@ def parse_args(): return args +def get_model_suffix(inference_dir, model_name): + if os.path.exists(os.path.join(inference_dir, model_name + ".pdmodel")): + return ".pdmodel" + elif os.path.exists(os.path.join(inference_dir, model_name + ".json")): + return ".json" + else: + raise ValueError("model file not found!") + + # only inference for models trained with csmsc now def main(): args = parse_args() @@ -133,9 +143,10 @@ def main(): tones_dict=args.tones_dict) # am_predictor + suffix = get_model_suffix(args.inference_dir, args.am) am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, @@ -146,9 +157,10 @@ def main(): am_dataset = args.am[args.am.rindex('_') + 1:] # voc_predictor + suffix = get_model_suffix(args.inference_dir, args.voc) voc_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.voc + ".pdmodel", + model_file=args.voc + suffix, params_file=args.voc + ".pdiparams", device=args.device, use_trt=args.use_trt,