format code

pull/2050/head
Hui Zhang 2 years ago
parent 3cf1f1f0b5
commit c3f762eb29

@@ -1,15 +1,21 @@
 #!/usr/bin/env python3
 import argparse
 
-import onnx
-from onnx import version_converter, helper
+import onnx
+from onnx import version_converter
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(prog=__doc__)
-    parser.add_argument("--model-file", type=str, required=True, help='path/to/the/model.onnx.')
-    parser.add_argument("--save-model", type=str, required=True, help='path/to/saved/model.onnx.')
+    parser.add_argument(
+        "--model-file", type=str, required=True, help='path/to/the/model.onnx.')
+    parser.add_argument(
+        "--save-model",
+        type=str,
+        required=True,
+        help='path/to/saved/model.onnx.')
 
     # Models must be opset 10 or higher to be quantized.
-    parser.add_argument("--target-opset", type=int, default=11, help='target opset version, e.g. 11.')
+    parser.add_argument(
+        "--target-opset", type=int, default=11, help='target opset version, e.g. 11.')
 
     args = parser.parse_args()
@@ -24,7 +30,8 @@ if __name__ == '__main__':
     # A full list of supported adapters can be found here:
     # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
     # Apply the version conversion on the original model
-    converted_model = version_converter.convert_version(original_model, args.target_opset)
+    converted_model = version_converter.convert_version(original_model,
+                                                        args.target_opset)
 
     # print('The model after conversion:\n{}'.format(converted_model))
     onnx.save(converted_model, args.save_model)
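For reference, the upgrade this script performs reduces to the load/convert/save flow sketched below. This is a minimal sketch assuming a local model.onnx and the script's default target opset of 11; both file paths are illustrative, not from the source.

import onnx
from onnx import version_converter

# Load the model, upgrade its opset, and save the result.
original_model = onnx.load('model.onnx')  # illustrative input path
# Models must be opset 10 or higher before post-training quantization,
# hence a target of 11 here.
converted_model = version_converter.convert_version(original_model, 11)
onnx.checker.check_model(converted_model)  # sanity-check the upgraded graph
onnx.save(converted_model, 'model.opset11.onnx')  # illustrative output path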

@@ -494,6 +494,8 @@ class SymbolicShapeInference:
             # contrib ops
-            'Attention', 'BiasGelu', 'EmbedLayerNormalization', 'FastGelu', 'Gelu', 'LayerNormalization', \
+            'Attention', 'BiasGelu', \
+            'EmbedLayerNormalization', \
+            'FastGelu', 'Gelu', 'LayerNormalization', \
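The hunk above lands in a vendored copy of onnxruntime's symbolic shape inference helper, whose contrib-op table names transformer kernels such as Attention and EmbedLayerNormalization. As a hedged sketch of how that class is typically driven, upstream onnxruntime exposes it through a static infer_shapes entry point; the model paths here are hypothetical.

import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load('model.onnx')  # hypothetical path
# auto_merge=True merges conflicting symbolic dimensions instead of
# failing, which helps on graphs using the contrib ops listed above.
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, 'model.shaped.onnx')  # hypothetical path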

@@ -1,13 +1,20 @@
 #!/usr/bin/env python3
 import argparse
 
-import onnx
-from onnxruntime.quantization import quantize_dynamic, QuantType
+from onnxruntime.quantization import quantize_dynamic
+from onnxruntime.quantization import QuantType
 
 
-def quantize_onnx_model(onnx_model_path, quantized_model_path, nodes_to_exclude=[]):
+def quantize_onnx_model(onnx_model_path,
+                        quantized_model_path,
+                        nodes_to_exclude=[]):
     print("Starting quantization...")
 
-    from onnxruntime.quantization import QuantType, quantize_dynamic
-    quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8, nodes_to_exclude=nodes_to_exclude)
+    quantize_dynamic(
+        onnx_model_path,
+        quantized_model_path,
+        weight_type=QuantType.QInt8,
+        nodes_to_exclude=nodes_to_exclude)
 
     print(f"Quantized model saved to: {quantized_model_path}")
@@ -18,26 +25,24 @@ def main():
         "--model-in",
         type=str,
         required=True,
-        help="input ONNX model.",
-    )
+        help="input ONNX model.", )
     parser.add_argument(
         "--model-out",
         type=str,
         required=True,
         default='model.quant.onnx',
-        help="output ONNX model.",
-    )
+        help="output ONNX model.", )
     parser.add_argument(
         "--nodes-to-exclude",
         type=str,
         required=True,
-        help="nodes to exclude, e.g. conv,linear.",
-    )
+        help="nodes to exclude, e.g. conv,linear.", )
 
     args = parser.parse_args()
 
     nodes_to_exclude = args.nodes_to_exclude.split(',')
     quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)
 
 
 if __name__ == "__main__":
     main()
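As a hedged usage sketch, a run of the script above with --model-in model.onnx --model-out model.quant.onnx --nodes-to-exclude Conv_0,Conv_1 boils down to the call below. The paths and node names are illustrative, not from the source.

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    'model.onnx',  # --model-in (illustrative)
    'model.quant.onnx',  # --model-out (illustrative)
    weight_type=QuantType.QInt8,  # weights stored as int8; activations quantized at runtime
    nodes_to_exclude=['Conv_0', 'Conv_1'])  # illustrative node names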