add logical to enable pir infer

pull/3921/head
liyulingyue 10 months ago
parent 4bc28d25a3
commit 1512b3b065

@ -127,11 +127,8 @@ def main():
paddle.set_device(args.device)
# model_suffix
if os.path.exists(args.am + ".json"):
model_suffix = ".json"
else:
model_suffix = ".pdmodel"
# set model_suffix
model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel"
# frontend
frontend = get_frontend(

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import paddle
@ -96,13 +97,16 @@ def main():
paddle.set_device(args.device)
# set model_suffix
model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel"
# frontend
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
# am_predictor
am_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
model_file=args.am + model_suffix,
params_file=args.am + ".pdiparams",
device=args.device,
use_trt=args.use_trt,

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import paddle
@ -96,13 +97,16 @@ def main():
paddle.set_device(args.device)
# set model_suffix
model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel"
# frontend
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
# am_predictor
am_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
model_file=args.am + model_suffix,
params_file=args.am + ".pdiparams",
device=args.device,
use_trt=args.use_trt,

Loading…
Cancel
Save