[Fix] inference of paddle 3.0

pull/3963/head
megemini 8 months ago
parent 7d26f93d2c
commit 30d956d490

@@ -30,6 +30,7 @@ from paddle.io import DataLoader
 from paddle.static import InputSpec
 from yacs.config import CfgNode
+import paddlespeech.utils
 from paddlespeech.t2s.datasets.am_batch_fn import *
 from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static
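The new module-level import exists only to reach `paddlespeech.utils.satisfy_paddle_version`, which the second hunk below calls to pick the right `inference.Config` constructor. The helper itself is not shown in this commit, so the following is only an illustrative sketch of what such a check can look like, assuming dev builds report `0.0.0` (the same convention the `paddle.__version__ != "0.0.0"` guard below relies on, since that guard is a plain string comparison) and that release strings parse as dotted integers with an optional pre-release suffix such as `3.0.0-beta`. The `parse` helper here is hypothetical, not the library's code.

```python
# Illustrative sketch only; the real satisfy_paddle_version lives in
# paddlespeech.utils and may differ in details.
import re

import paddle


def satisfy_paddle_version(target: str) -> bool:
    """Return True if the installed paddle is at least `target`."""
    installed = paddle.__version__
    # Dev/nightly builds report "0.0.0"; treat them as newer than any
    # release, matching the `!= "0.0.0"` guard used in the diff below.
    if installed == "0.0.0":
        return True

    def parse(version: str) -> tuple:
        # "3.0.0-beta" or "3.0.0b2" -> (3, 0, 0); pre-release tags dropped.
        return tuple(
            int(re.match(r"\d+", part).group())
            for part in version.split(".")[:3])

    return parse(installed) >= parse(target)
```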
@@ -589,8 +590,16 @@ def get_predictor(
                 "Predict by TensorRT mode: {}, expect device=='gpu', but device == {}".
                 format(precision, device))
+    # after paddle 3.0, support new inference interface
+    if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
+        model_name = model_file.split('.')[0]
+        assert model_name == params_file.split('.')[
+            0], "The prefix of model_file and params_file should be same."
+        config = inference.Config(model_dir, model_name)
+    else:
-    config = inference.Config(
-        str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
+        config = inference.Config(
+            str(Path(model_dir) / model_file),
+            str(Path(model_dir) / params_file))
     if paddle.__version__ <= "2.5.2" and paddle.__version__ != "0.0.0":
         config.enable_memory_optim()
     config.switch_ir_optim(True)
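For context, here is how the branch plays out end to end. This is a usage sketch with hypothetical file names, not code from the repository; `inference.Config(model_dir, prefix)` is the constructor the new branch uses, the two-path `inference.Config(model_path, params_path)` form is the old one, and `inference.create_predictor` is paddle's standard entry point for turning a config into a predictor.

```python
from pathlib import Path

import paddlespeech.utils
from paddle import inference

# Hypothetical export layout; the real names come from the CLI arguments
# that get_predictor receives.
model_dir = "exp/fastspeech2/inference"
model_file = "fastspeech2.pdmodel"
params_file = "fastspeech2.pdiparams"

if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
    # paddle >= 3.0: pass the directory plus the shared file-name prefix.
    config = inference.Config(model_dir, model_file.split('.')[0])
else:
    # older paddle: pass explicit model/params file paths.
    config = inference.Config(
        str(Path(model_dir) / model_file),
        str(Path(model_dir) / params_file))

predictor = inference.create_predictor(config)
```

The assert in the diff exists because the prefix form locates both files from one name, so `model_file` and `params_file` must share a prefix for the paddle 3.0 path to load the same pair the older path did.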
