[Hackathon 7th] 修改 inference 兼容 paddle 3.0 (#3963)

* [Fix] inference of paddle 3.0

* [Fix] inference of paddle 3.0

* [Fix] inference of paddle 3.0

* [Fix] inference of paddle 3.0
pull/3968/head
megemini 8 months ago committed by GitHub
parent cb15e382cb
commit 94437c932a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -30,6 +30,7 @@ from paddle.io import DataLoader
from paddle.static import InputSpec from paddle.static import InputSpec
from yacs.config import CfgNode from yacs.config import CfgNode
import paddlespeech.utils
from paddlespeech.t2s.datasets.am_batch_fn import * from paddlespeech.t2s.datasets.am_batch_fn import *
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static
@ -589,8 +590,17 @@ def get_predictor(
"Predict by TensorRT mode: {}, expect device=='gpu', but device == {}". "Predict by TensorRT mode: {}, expect device=='gpu', but device == {}".
format(precision, device)) format(precision, device))
config = inference.Config( # after paddle 3.0, support new inference interface
str(Path(model_dir) / model_file), str(Path(model_dir) / params_file)) if paddlespeech.utils.satisfy_paddle_version('3.0.0-beta'):
model_name = str(model_file).rsplit('.', 1)[0]
assert model_name == str(params_file).rstrip(
'.pdiparams'
), "The prefix of model_file and params_file should be same."
config = inference.Config(model_dir, model_name)
else:
config = inference.Config(
str(Path(model_dir) / model_file),
str(Path(model_dir) / params_file))
if paddle.__version__ <= "2.5.2" and paddle.__version__ != "0.0.0": if paddle.__version__ <= "2.5.2" and paddle.__version__ != "0.0.0":
config.enable_memory_optim() config.enable_memory_optim()
config.switch_ir_optim(True) config.switch_ir_optim(True)

Loading…
Cancel
Save