From 94437c932a8fe0be8f229c5dd4bd233fb7513c1f Mon Sep 17 00:00:00 2001
From: megemini
Date: Wed, 15 Jan 2025 14:39:56 +0800
Subject: [PATCH] =?UTF-8?q?[Hackathon=207th]=20=E4=BF=AE=E6=94=B9=20infere?=
 =?UTF-8?q?nce=20=E5=85=BC=E5=AE=B9=20paddle=203.0=20(#3963)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [Fix] inference of paddle 3.0

* [Fix] inference of paddle 3.0

* [Fix] inference of paddle 3.0

* [Fix] inference of paddle 3.0

---
 paddlespeech/t2s/exps/syn_utils.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index d29dd8110..acfaa012d 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -30,6 +30,7 @@ from paddle.io import DataLoader
 from paddle.static import InputSpec
 from yacs.config import CfgNode
 
+import paddlespeech.utils
 from paddlespeech.t2s.datasets.am_batch_fn import *
 from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static
@@ -589,8 +590,17 @@ def get_predictor(
             "Predict by TensorRT mode: {}, expect device=='gpu', but device == {}".
             format(precision, device))
 
-    config = inference.Config(
-        str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
+    # after paddle 3.0, support new inference interface
+    if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
+        model_name = str(model_file).rsplit('.', 1)[0]
+        # rsplit, not rstrip: rstrip('.pdiparams') strips a character SET
+        assert model_name == str(params_file).rsplit('.', 1)[0], \
+            "The prefix of model_file and params_file should be same."
+        config = inference.Config(model_dir, model_name)
+    else:
+        config = inference.Config(
+            str(Path(model_dir) / model_file),
+            str(Path(model_dir) / params_file))
     if paddle.__version__ <= "2.5.2" and paddle.__version__ != "0.0.0":
         config.enable_memory_optim()
         config.switch_ir_optim(True)