From 30d956d49045f30b66e40284fdd7d23bd94efd15 Mon Sep 17 00:00:00 2001
From: megemini
Date: Mon, 6 Jan 2025 18:54:15 +0800
Subject: [PATCH] [Fix] inference of paddle 3.0

---
 paddlespeech/t2s/exps/syn_utils.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index d29dd8110..751845912 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -30,6 +30,7 @@ from paddle.io import DataLoader
 from paddle.static import InputSpec
 from yacs.config import CfgNode
 
+import paddlespeech.utils
 from paddlespeech.t2s.datasets.am_batch_fn import *
 from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static
@@ -589,8 +590,16 @@ def get_predictor(
             "Predict by TensorRT mode: {}, expect device=='gpu', but device == {}".
             format(precision, device))
 
-    config = inference.Config(
-        str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
+    # after paddle 3.0, support new inference interface
+    if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
+        model_name = model_file.split('.')[0]
+        assert model_name == params_file.split('.')[
+            0], "The prefix of model_file and params_file should be same."
+        config = inference.Config(model_dir, model_name)
+    else:
+        config = inference.Config(
+            str(Path(model_dir) / model_file),
+            str(Path(model_dir) / params_file))
     if paddle.__version__ <= "2.5.2" and paddle.__version__ != "0.0.0":
         config.enable_memory_optim()
         config.switch_ir_optim(True)
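
Note: the patch imports `paddlespeech.utils.satisfy_paddle_version` but does not show its definition. The sketch below is a hypothetical stand-in, not paddlespeech's actual implementation: the name `satisfy_min_paddle_version`, the use of `packaging.version`, and the treatment of the develop build (`"0.0.0"`) are assumptions. It only illustrates how a minimum-version gate like the one used above to pick the new `inference.Config(model_dir, prefix)` call style can be written.

```python
# Sketch only: hypothetical stand-in for a version gate such as
# paddlespeech.utils.satisfy_paddle_version; not the library's real code.
import paddle
from packaging.version import Version


def satisfy_min_paddle_version(min_version: str) -> bool:
    """Return True if the installed paddle is at least `min_version`.

    Assumption: the develop build reports "0.0.0" (as the existing
    `paddle.__version__ != "0.0.0"` check in syn_utils.py suggests) and is
    treated here as new enough; the real helper may behave differently.
    """
    installed = paddle.__version__
    if installed == "0.0.0":
        return True
    # packaging normalizes pre-release tags, so "3.0.0-beta" parses as 3.0.0b0,
    # which sorts below the final 3.0.0 release but above any 2.x version.
    return Version(installed) >= Version(min_version)


if __name__ == "__main__":
    # e.g. prints False on paddle 2.5.2 and True on paddle 3.0.0.
    print(satisfy_min_paddle_version("3.0.0-beta"))
```

With such a gate, callers on Paddle >= 3.0 pass the export directory plus the shared file prefix to `inference.Config`, while older installs keep passing the full `.pdmodel` and `.pdiparams` paths, exactly as the two branches in the hunk above do.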