add logical to enable pir infer

pull/3921/head
liyulingyue 10 months ago
parent 4015676a42
commit 4bc28d25a3

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
from pathlib import Path from pathlib import Path
import paddle import paddle
@ -126,6 +127,12 @@ def main():
paddle.set_device(args.device) paddle.set_device(args.device)
# model_suffix
if os.path.exists(args.am + ".json"):
model_suffix = ".json"
else:
model_suffix = ".pdmodel"
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,
@ -135,7 +142,7 @@ def main():
# am_predictor # am_predictor
am_predictor = get_predictor( am_predictor = get_predictor(
model_dir=args.inference_dir, model_dir=args.inference_dir,
model_file=args.am + ".pdmodel", model_file=args.am + model_suffix,
params_file=args.am + ".pdiparams", params_file=args.am + ".pdiparams",
device=args.device, device=args.device,
use_trt=args.use_trt, use_trt=args.use_trt,
@ -148,7 +155,7 @@ def main():
# voc_predictor # voc_predictor
voc_predictor = get_predictor( voc_predictor = get_predictor(
model_dir=args.inference_dir, model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel", model_file=args.voc + model_suffix,
params_file=args.voc + ".pdiparams", params_file=args.voc + ".pdiparams",
device=args.device, device=args.device,
use_trt=args.use_trt, use_trt=args.use_trt,

Loading…
Cancel
Save