diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3edc4b63b..99d147958 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -126,6 +127,12 @@ def main(): paddle.set_device(args.device) + # model_suffix + if os.path.exists(args.am + ".json"): + model_suffix = ".json" + else: + model_suffix = ".pdmodel" + # frontend frontend = get_frontend( lang=args.lang, @@ -135,7 +142,7 @@ def main(): # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, @@ -148,7 +155,7 @@ def main(): # voc_predictor voc_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.voc + ".pdmodel", + model_file=args.voc + model_suffix, params_file=args.voc + ".pdiparams", device=args.device, use_trt=args.use_trt,