You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/speechx/examples/ds2_ol/onnx/local/ort_opt.py

46 lines
1.5 KiB

#!/usr/bin/env python3
import argparse
import onnxruntime as ort
# onnxruntime optimizer.
# https://onnxruntime.ai/docs/performance/graph-optimizations.html
# https://onnxruntime.ai/docs/api/python/api_summary.html#api
def parse_arguments(argv=None):
    """Parse the command-line options of the ONNX Runtime optimizer script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``model_in`` (str),
        ``opt_level`` (int, one of 0/1/2), ``model_out`` (str) and
        ``debug`` (bool).
    """

    def _parse_bool(text):
        """Turn a command-line token such as 'true' or '0' into a bool."""
        lowered = text.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            f'expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_in', required=True, type=str, help='Path to onnx model.')
    # The option is required, so no default is declared: it could never
    # take effect.
    parser.add_argument(
        '--opt_level',
        required=True,
        type=int,
        choices=[0, 1, 2],
        help='Graph optimization level: 0 = basic, 1 = extended, 2 = all.')
    parser.add_argument(
        '--model_out', required=True, help='path to save the optimized model.')
    # Without a converter argparse would keep the raw string, making
    # "--debug False" truthy. nargs='?' additionally lets a bare "--debug"
    # switch the flag on.
    parser.add_argument(
        '--debug',
        nargs='?',
        const=True,
        default=False,
        type=_parse_bool,
        help='output debug info.')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: load the input model with the requested graph
    # optimization level and have ONNX Runtime save the optimized graph.
    args = parse_arguments()

    sess_options = ort.SessionOptions()

    print(f"opt level: {args.opt_level}")

    # Map the numeric CLI level onto the ONNX Runtime enum; any level other
    # than 0 or 1 selects ORT_ENABLE_ALL.
    level_table = {
        0: ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        1: ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
    }
    sess_options.graph_optimization_level = level_table.get(
        args.opt_level, ort.GraphOptimizationLevel.ORT_ENABLE_ALL)

    # Setting this path enables model serialization after graph optimization.
    sess_options.optimized_model_filepath = args.model_out

    # The session object itself is not used afterwards; it is created for
    # the optimization pass it runs on the input model.
    session = ort.InferenceSession(args.model_in, sess_options)