|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import argparse
|
|
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
@ -126,6 +127,12 @@ def main():
|
|
|
|
|
|
|
|
|
|
paddle.set_device(args.device)
|
|
|
|
|
|
|
|
|
|
# model_suffix
|
|
|
|
|
if os.path.exists(args.am + ".json"):
|
|
|
|
|
model_suffix = ".json"
|
|
|
|
|
else:
|
|
|
|
|
model_suffix = ".pdmodel"
|
|
|
|
|
|
|
|
|
|
# frontend
|
|
|
|
|
frontend = get_frontend(
|
|
|
|
|
lang=args.lang,
|
|
|
|
@ -135,7 +142,7 @@ def main():
|
|
|
|
|
# am_predictor
|
|
|
|
|
am_predictor = get_predictor(
|
|
|
|
|
model_dir=args.inference_dir,
|
|
|
|
|
model_file=args.am + ".pdmodel",
|
|
|
|
|
model_file=args.am + model_suffix,
|
|
|
|
|
params_file=args.am + ".pdiparams",
|
|
|
|
|
device=args.device,
|
|
|
|
|
use_trt=args.use_trt,
|
|
|
|
@ -148,7 +155,7 @@ def main():
|
|
|
|
|
# voc_predictor
|
|
|
|
|
voc_predictor = get_predictor(
|
|
|
|
|
model_dir=args.inference_dir,
|
|
|
|
|
model_file=args.voc + ".pdmodel",
|
|
|
|
|
model_file=args.voc + model_suffix,
|
|
|
|
|
params_file=args.voc + ".pdiparams",
|
|
|
|
|
device=args.device,
|
|
|
|
|
use_trt=args.use_trt,
|
|
|
|
|