parent
28c1794b9b
commit
6477b6f3e6
@ -0,0 +1,127 @@
|
|||||||
|
#!/usr/bin/env python3 -W ignore::DeprecationWarning
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--model',
|
||||||
|
required=True,
|
||||||
|
help='Path of directory saved the input model.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--output_names',
|
||||||
|
required=True,
|
||||||
|
nargs='+',
|
||||||
|
help='The outputs of pruned model.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--save_file', required=True, help='Path to save the new onnx model.')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_arguments()
|
||||||
|
|
||||||
|
if len(set(args.output_names)) < len(args.output_names):
|
||||||
|
print(
|
||||||
|
"[ERROR] There's dumplicate name in --output_names, which is not allowed."
|
||||||
|
)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
model = onnx.load(args.model)
|
||||||
|
|
||||||
|
# collect all node outputs and graph output
|
||||||
|
output_tensor_names = set()
|
||||||
|
for node in model.graph.node:
|
||||||
|
for out in node.output:
|
||||||
|
# may contain model output
|
||||||
|
output_tensor_names.add(out)
|
||||||
|
|
||||||
|
# for out in model.graph.output:
|
||||||
|
# output_tensor_names.add(out.name)
|
||||||
|
|
||||||
|
for output_name in args.output_names:
|
||||||
|
if output_name not in output_tensor_names:
|
||||||
|
print(
|
||||||
|
"[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
|
||||||
|
format(output_name))
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
output_node_indices = set() # has output names
|
||||||
|
output_to_node = dict() # all node outputs
|
||||||
|
for i, node in enumerate(model.graph.node):
|
||||||
|
for out in node.output:
|
||||||
|
output_to_node[out] = i
|
||||||
|
if out in args.output_names:
|
||||||
|
output_node_indices.add(i)
|
||||||
|
|
||||||
|
# from outputs find all the ancestors
|
||||||
|
reserved_node_indices = copy.deepcopy(
|
||||||
|
output_node_indices) # nodes need to keep
|
||||||
|
reserved_inputs = set() # model input to keep
|
||||||
|
new_output_node_indices = copy.deepcopy(output_node_indices)
|
||||||
|
|
||||||
|
while True and len(new_output_node_indices) > 0:
|
||||||
|
output_node_indices = copy.deepcopy(new_output_node_indices)
|
||||||
|
|
||||||
|
new_output_node_indices = set()
|
||||||
|
|
||||||
|
for out_node_idx in output_node_indices:
|
||||||
|
# backtrace to parenet
|
||||||
|
for ipt in model.graph.node[out_node_idx].input:
|
||||||
|
if ipt in output_to_node:
|
||||||
|
reserved_node_indices.add(output_to_node[ipt])
|
||||||
|
new_output_node_indices.add(output_to_node[ipt])
|
||||||
|
else:
|
||||||
|
reserved_inputs.add(ipt)
|
||||||
|
|
||||||
|
num_inputs = len(model.graph.input)
|
||||||
|
num_outputs = len(model.graph.output)
|
||||||
|
num_nodes = len(model.graph.node)
|
||||||
|
print(
|
||||||
|
f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes"
|
||||||
|
)
|
||||||
|
print(f"{len(reserved_node_indices)} node to keep.")
|
||||||
|
|
||||||
|
# del node not to keep
|
||||||
|
for idx in range(num_nodes - 1, -1, -1):
|
||||||
|
if idx not in reserved_node_indices:
|
||||||
|
del model.graph.node[idx]
|
||||||
|
|
||||||
|
# del graph input not to keep
|
||||||
|
for idx in range(num_inputs - 1, -1, -1):
|
||||||
|
if model.graph.input[idx].name not in reserved_inputs:
|
||||||
|
del model.graph.input[idx]
|
||||||
|
|
||||||
|
# del old graph outputs
|
||||||
|
for i in range(num_outputs):
|
||||||
|
del model.graph.output[0]
|
||||||
|
|
||||||
|
# new graph output as user input
|
||||||
|
for out in args.output_names:
|
||||||
|
model.graph.output.extend([onnx.ValueInfoProto(name=out)])
|
||||||
|
|
||||||
|
# infer shape
|
||||||
|
try:
|
||||||
|
from onnx_infer_shape import SymbolicShapeInference
|
||||||
|
model = SymbolicShapeInference.infer_shapes(
|
||||||
|
model,
|
||||||
|
int_max=2**31 - 1,
|
||||||
|
auto_merge=True,
|
||||||
|
guess_output_rank=False,
|
||||||
|
verbose=1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"skip infer shape step: {e}")
|
||||||
|
|
||||||
|
# check onnx model
|
||||||
|
onnx.checker.check_model(model)
|
||||||
|
# save onnx model
|
||||||
|
onnx.save(model, args.save_file)
|
||||||
|
print("[Finished] The new model saved in {}.".format(args.save_file))
|
||||||
|
print("[DEBUG INFO] The inputs of new model: {}".format(
|
||||||
|
[x.name for x in model.graph.input]))
|
||||||
|
print("[DEBUG INFO] The outputs of new model: {}".format(
|
||||||
|
[x.name for x in model.graph.output]))
|
@ -0,0 +1,110 @@
|
|||||||
|
#!/usr/bin/env python3 -W ignore::DeprecationWarning
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--model',
|
||||||
|
required=True,
|
||||||
|
help='Path of directory saved the input model.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--origin_names',
|
||||||
|
required=True,
|
||||||
|
nargs='+',
|
||||||
|
help='The original name you want to modify.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--new_names',
|
||||||
|
required=True,
|
||||||
|
nargs='+',
|
||||||
|
help='The new name you want change to, the number of new_names should be same with the number of origin_names'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--save_file', required=True, help='Path to save the new onnx model.')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_arguments()
|
||||||
|
|
||||||
|
if len(set(args.origin_names)) < len(args.origin_names):
|
||||||
|
print(
|
||||||
|
"[ERROR] There's dumplicate name in --origin_names, which is not allowed."
|
||||||
|
)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
if len(set(args.new_names)) < len(args.new_names):
|
||||||
|
print(
|
||||||
|
"[ERROR] There's dumplicate name in --new_names, which is not allowed."
|
||||||
|
)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
if len(args.new_names) != len(args.origin_names):
|
||||||
|
print(
|
||||||
|
"[ERROR] Number of --new_names must be same with the number of --origin_names."
|
||||||
|
)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
model = onnx.load(args.model)
|
||||||
|
|
||||||
|
# collect input and all node output
|
||||||
|
output_tensor_names = set()
|
||||||
|
for ipt in model.graph.input:
|
||||||
|
output_tensor_names.add(ipt.name)
|
||||||
|
|
||||||
|
for node in model.graph.node:
|
||||||
|
for out in node.output:
|
||||||
|
output_tensor_names.add(out)
|
||||||
|
|
||||||
|
for origin_name in args.origin_names:
|
||||||
|
if origin_name not in output_tensor_names:
|
||||||
|
print(
|
||||||
|
f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
|
||||||
|
)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
for new_name in args.new_names:
|
||||||
|
if new_name in output_tensor_names:
|
||||||
|
print(
|
||||||
|
"[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed."
|
||||||
|
)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
# rename graph input
|
||||||
|
for i, ipt in enumerate(model.graph.input):
|
||||||
|
if ipt.name in args.origin_names:
|
||||||
|
idx = args.origin_names.index(ipt.name)
|
||||||
|
model.graph.input[i].name = args.new_names[idx]
|
||||||
|
|
||||||
|
# rename node input and output
|
||||||
|
for i, node in enumerate(model.graph.node):
|
||||||
|
for j, ipt in enumerate(node.input):
|
||||||
|
if ipt in args.origin_names:
|
||||||
|
idx = args.origin_names.index(ipt)
|
||||||
|
model.graph.node[i].input[j] = args.new_names[idx]
|
||||||
|
|
||||||
|
for j, out in enumerate(node.output):
|
||||||
|
if out in args.origin_names:
|
||||||
|
idx = args.origin_names.index(out)
|
||||||
|
model.graph.node[i].output[j] = args.new_names[idx]
|
||||||
|
|
||||||
|
# rename graph output
|
||||||
|
for i, out in enumerate(model.graph.output):
|
||||||
|
if out.name in args.origin_names:
|
||||||
|
idx = args.origin_names.index(out.name)
|
||||||
|
model.graph.output[i].name = args.new_names[idx]
|
||||||
|
|
||||||
|
# check onnx model
|
||||||
|
onnx.checker.check_model(model)
|
||||||
|
|
||||||
|
# save model
|
||||||
|
onnx.save(model, args.save_file)
|
||||||
|
|
||||||
|
print("[Finished] The new model saved in {}.".format(args.save_file))
|
||||||
|
print("[DEBUG INFO] The inputs of new model: {}".format(
|
||||||
|
[x.name for x in model.graph.input]))
|
||||||
|
print("[DEBUG INFO] The outputs of new model: {}".format(
|
||||||
|
[x.name for x in model.graph.output]))
|
Loading…
Reference in new issue