You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
1.5 KiB
46 lines
1.5 KiB
#!/usr/bin/env python3
|
|
import argparse
|
|
|
|
import onnxruntime as ort
|
|
|
|
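# Offline ONNX graph optimizer built on onnxruntime: loads a model, applies the chosen graph optimization level, and saves the optimized model.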
|
|
# https://onnxruntime.ai/docs/performance/graph-optimizations.html
|
|
# https://onnxruntime.ai/docs/api/python/api_summary.html#api
|
|
|
|
|
|
def parse_arguments():
    """Parse the command-line options of the optimizer script.

    Options:
        --model_in:  required, path of the ONNX model to read.
        --opt_level: optional int in {0, 1, 2}; defaults to 0.
        --model_out: required, path where the optimized model is written.
        --debug:     optional boolean flag; ``--debug`` alone enables it,
                     ``--debug <value>`` is also accepted.

    Returns:
        argparse.Namespace with attributes ``model_in``, ``opt_level``,
        ``model_out`` and ``debug``.
    """

    def _to_bool(text):
        # Only explicit "false" spellings turn the flag off; every other
        # value keeps it on, so values that previously enabled debug output
        # still do, while '--debug False' no longer yields a truthy string.
        return text.strip().lower() not in ('', '0', 'false', 'no', 'off')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_in', required=True, type=str, help='Path to onnx model.')
    # Not marked required: a required option never falls back to its default,
    # which made the declared default=0 unreachable.
    parser.add_argument(
        '--opt_level',
        type=int,
        default=0,
        choices=[0, 1, 2],
        help='Graph optimization level: 0=basic, 1=extended, 2=all '
        '(default: 0).')
    parser.add_argument(
        '--model_out', required=True, help='path to save the optimized model.')
    # nargs='?' lets a bare '--debug' work (const=True) while still accepting
    # an explicit value such as '--debug True' / '--debug False'.
    parser.add_argument(
        '--debug',
        nargs='?',
        const=True,
        default=False,
        type=_to_bool,
        help='output debug info.')
    return parser.parse_args()
|
|
|
|
|
|
if __name__ == '__main__':
    # Entry point: read the CLI options, configure an onnxruntime session
    # with the requested graph optimization level, and create the session
    # with an output path set for the optimized model.
    cli = parse_arguments()

    options = ort.SessionOptions()

    # Map the numeric CLI level onto onnxruntime's optimization enum; any
    # level not listed here falls through to ORT_ENABLE_ALL.
    print(f"opt level: {cli.opt_level}")
    levels = {
        0: ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
        1: ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
    }
    options.graph_optimization_level = levels.get(
        cli.opt_level, ort.GraphOptimizationLevel.ORT_ENABLE_ALL)

    # Setting the output path enables model serialization after graph
    # optimization.
    options.optimized_model_filepath = cli.model_out

    session = ort.InferenceSession(cli.model_in, options)
|