From 1512b3b06598f8b6e593bf7f4a7ae79d8c77c067 Mon Sep 17 00:00:00 2001 From: liyulingyue <852433440@qq.com> Date: Sun, 1 Dec 2024 23:39:54 +0800 Subject: [PATCH] add logical to enable pir infer --- paddlespeech/t2s/exps/inference.py | 7 ++----- paddlespeech/t2s/exps/jets/inference.py | 6 +++++- paddlespeech/t2s/exps/vits/inference.py | 6 +++++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 99d147958..80ecdb170 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -127,11 +127,8 @@ def main(): paddle.set_device(args.device) - # model_suffix - if os.path.exists(args.am + ".json"): - model_suffix = ".json" - else: - model_suffix = ".pdmodel" + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" # frontend frontend = get_frontend( diff --git a/paddlespeech/t2s/exps/jets/inference.py b/paddlespeech/t2s/exps/jets/inference.py index 4f6882eda..d83510a94 100644 --- a/paddlespeech/t2s/exps/jets/inference.py +++ b/paddlespeech/t2s/exps/jets/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -96,13 +97,16 @@ def main(): paddle.set_device(args.device) + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" + # frontend frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, diff --git a/paddlespeech/t2s/exps/vits/inference.py b/paddlespeech/t2s/exps/vits/inference.py index 08c1ac566..ba9e2e8da 100644 --- a/paddlespeech/t2s/exps/vits/inference.py +++ b/paddlespeech/t2s/exps/vits/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -96,13 +97,16 @@ def main(): paddle.set_device(args.device) + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" + # frontend frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt,