From 4bc28d25a3bce122da79aca912f51d4aa2949288 Mon Sep 17 00:00:00 2001 From: liyulingyue <852433440@qq.com> Date: Sun, 1 Dec 2024 23:19:39 +0800 Subject: [PATCH] add logical to enable pir infer --- paddlespeech/t2s/exps/inference.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3edc4b63b..99d147958 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -126,6 +127,12 @@ def main(): paddle.set_device(args.device) + # model_suffix + if os.path.exists(args.am + ".json"): + model_suffix = ".json" + else: + model_suffix = ".pdmodel" + # frontend frontend = get_frontend( lang=args.lang, @@ -135,7 +142,7 @@ def main(): # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, @@ -148,7 +155,7 @@ def main(): # voc_predictor voc_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.voc + ".pdmodel", + model_file=args.voc + model_suffix, params_file=args.voc + ".pdiparams", device=args.device, use_trt=args.use_trt,