format code

pull/2050/head
Hui Zhang 2 years ago
parent 3cf1f1f0b5
commit c3f762eb29

@ -1,15 +1,21 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import onnx
from onnx import version_converter, helper
import onnx
from onnx import version_converter
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(prog=__doc__) parser = argparse.ArgumentParser(prog=__doc__)
parser.add_argument("--model-file", type=str, required=True, help='path/to/the/model.onnx.') parser.add_argument(
parser.add_argument("--save-model", type=str, required=True, help='path/to/saved/model.onnx.') "--model-file", type=str, required=True, help='path/to/the/model.onnx.')
parser.add_argument(
"--save-model",
type=str,
required=True,
help='path/to/saved/model.onnx.')
# Models must be opset10 or higher to be quantized. # Models must be opset10 or higher to be quantized.
parser.add_argument("--target-opset", type=int, default=11, help='path/to/the/model.onnx.') parser.add_argument(
"--target-opset", type=int, default=11, help='path/to/the/model.onnx.')
args = parser.parse_args() args = parser.parse_args()
@ -24,7 +30,8 @@ if __name__ == '__main__':
# A full list of supported adapters can be found here: # A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21 # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
# Apply the version conversion on the original model # Apply the version conversion on the original model
converted_model = version_converter.convert_version(original_model, args.target_opset) converted_model = version_converter.convert_version(original_model,
args.target_opset)
# print('The model after conversion:\n{}'.format(converted_model)) # print('The model after conversion:\n{}'.format(converted_model))
onnx.save(converted_model, args.save_model) onnx.save(converted_model, args.save_model)

@ -494,6 +494,8 @@ class SymbolicShapeInference:
# contrib ops # contrib ops
'Attention', 'BiasGelu', \ 'Attention', 'BiasGelu', \
'EmbedLayerNormalization', \ 'EmbedLayerNormalization', \
'FastGelu', 'Gelu', 'LayerNormalization', \ 'FastGelu', 'Gelu', 'LayerNormalization', \

@ -1,13 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
def quantize_onnx_model(onnx_model_path, quantized_model_path, nodes_to_exclude=[]): from onnxruntime.quantization import quantize_dynamic
from onnxruntime.quantization import QuantType
def quantize_onnx_model(onnx_model_path,
quantized_model_path,
nodes_to_exclude=[]):
print("Starting quantization...") print("Starting quantization...")
from onnxruntime.quantization import QuantType, quantize_dynamic
quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8, nodes_to_exclude=nodes_to_exclude) quantize_dynamic(
onnx_model_path,
quantized_model_path,
weight_type=QuantType.QInt8,
nodes_to_exclude=nodes_to_exclude)
print(f"Quantized model saved to: {quantized_model_path}") print(f"Quantized model saved to: {quantized_model_path}")
@ -18,26 +25,24 @@ def main():
"--model-in", "--model-in",
type=str, type=str,
required=True, required=True,
help="ONNX model", help="ONNX model", )
)
parser.add_argument( parser.add_argument(
"--model-out", "--model-out",
type=str, type=str,
required=True, required=True,
default='model.quant.onnx', default='model.quant.onnx',
help="ONNX model", help="ONNX model", )
)
parser.add_argument( parser.add_argument(
"--nodes-to-exclude", "--nodes-to-exclude",
type=str, type=str,
required=True, required=True,
help="nodes to exclude. e.g. conv,linear.", help="nodes to exclude. e.g. conv,linear.", )
)
args = parser.parse_args() args = parser.parse_args()
nodes_to_exclude = args.nodes_to_exclude.split(',') nodes_to_exclude = args.nodes_to_exclude.split(',')
quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude) quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Loading…
Cancel
Save