diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py
new file mode 100644
index 000000000..a5148edda
--- /dev/null
+++ b/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+"""Prune an ONNX model down to the sub-graph that produces the given outputs.
+
+Every node that is not an ancestor of one of ``--output_names`` is removed,
+graph inputs that no kept node consumes are dropped, and the requested
+tensors become the new graph outputs.
+"""
+import argparse
+import sys
+import warnings
+
+# A shebang cannot portably carry "-W ignore::DeprecationWarning": Linux hands
+# everything after the interpreter path to env as a single argument. Install
+# the filter here instead, before onnx is imported.
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+import onnx  # noqa: E402
+
+
+def parse_arguments():
+    """Parse the command line options of the pruning script."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--model',
+        required=True,
+        help='Path of directory saved the input model.')
+    parser.add_argument(
+        '--output_names',
+        required=True,
+        nargs='+',
+        help='The outputs of pruned model.')
+    parser.add_argument(
+        '--save_file', required=True, help='Path to save the new onnx model.')
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_arguments()
+
+    if len(set(args.output_names)) < len(args.output_names):
+        print(
+            "[ERROR] There's dumplicate name in --output_names, which is not allowed."
+        )
+        sys.exit(-1)
+
+    model = onnx.load(args.model)
+
+    # Map every tensor produced by a node to the index of that node.
+    output_to_node = dict()
+    for i, node in enumerate(model.graph.node):
+        for out in node.output:
+            output_to_node[out] = i
+
+    for output_name in args.output_names:
+        if output_name not in output_to_node:
+            print(
+                "[ERROR] Cannot find output tensor name '{}' in onnx model graph.".
+                format(output_name))
+            sys.exit(-1)
+
+    # Walk backwards from the requested outputs and keep every ancestor.
+    # A node is queued only the first time it is reached, so ancestors shared
+    # by several paths are expanded once.
+    reserved_node_indices = {output_to_node[name] for name in args.output_names}
+    reserved_inputs = set()  # consumed tensors that no node produces
+    pending = list(reserved_node_indices)
+    while pending:
+        node_idx = pending.pop()
+        for ipt in model.graph.node[node_idx].input:
+            if ipt not in output_to_node:
+                reserved_inputs.add(ipt)
+                continue
+            parent_idx = output_to_node[ipt]
+            if parent_idx not in reserved_node_indices:
+                reserved_node_indices.add(parent_idx)
+                pending.append(parent_idx)
+
+    num_inputs = len(model.graph.input)
+    num_outputs = len(model.graph.output)
+    num_nodes = len(model.graph.node)
+    print(
+        f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes"
+    )
+    print(f"{len(reserved_node_indices)} node to keep.")
+
+    # Delete back to front so the indices still to be visited stay valid.
+    for idx in range(num_nodes - 1, -1, -1):
+        if idx not in reserved_node_indices:
+            del model.graph.node[idx]
+
+    # Drop graph inputs that no kept node consumes.
+    for idx in range(num_inputs - 1, -1, -1):
+        if model.graph.input[idx].name not in reserved_inputs:
+            del model.graph.input[idx]
+
+    # Remove the old graph outputs ...
+    for _ in range(num_outputs):
+        del model.graph.output[0]
+
+    # ... and expose the tensors requested by the user instead.
+    for out in args.output_names:
+        model.graph.output.extend([onnx.ValueInfoProto(name=out)])
+
+    # Shape inference is best effort: if the helper cannot be imported or
+    # raises, the reason is printed and the script carries on without it.
+    try:
+        from onnx_infer_shape import SymbolicShapeInference
+        model = SymbolicShapeInference.infer_shapes(
+            model,
+            int_max=2**31 - 1,
+            auto_merge=True,
+            guess_output_rank=False,
+            verbose=1)
+    except Exception as e:
+        print(f"skip infer shape step: {e}")
+
+    # check onnx model
+    onnx.checker.check_model(model)
+    # save onnx model
+    onnx.save(model, args.save_file)
+    print("[Finished] The new model saved in {}.".format(args.save_file))
+    print("[DEBUG INFO] The inputs of new model: {}".format(
+        [x.name for x in model.graph.input]))
+    print("[DEBUG INFO] The outputs of new model: {}".format(
+        [x.name for x in model.graph.output]))
diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py
new file mode 100755
index 000000000..f508c0a35
--- /dev/null
+++ b/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+"""Rename tensors of an ONNX model.
+
+Each name in ``--origin_names`` is replaced by the name at the same position
+in ``--new_names`` wherever it appears as a graph input, a node input, a node
+output or a graph output.
+"""
+import argparse
+import sys
+import warnings
+
+# A shebang cannot portably carry "-W ignore::DeprecationWarning": Linux hands
+# everything after the interpreter path to env as a single argument. Install
+# the filter here instead, before onnx is imported.
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+import onnx  # noqa: E402
+
+
+def parse_arguments():
+    """Parse the command line options of the renaming script."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--model',
+        required=True,
+        help='Path of directory saved the input model.')
+    parser.add_argument(
+        '--origin_names',
+        required=True,
+        nargs='+',
+        help='The original name you want to modify.')
+    parser.add_argument(
+        '--new_names',
+        required=True,
+        nargs='+',
+        help='The new name you want change to, the number of new_names should be same with the number of origin_names'
+    )
+    parser.add_argument(
+        '--save_file', required=True, help='Path to save the new onnx model.')
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_arguments()
+
+    if len(set(args.origin_names)) < len(args.origin_names):
+        print(
+            "[ERROR] There's dumplicate name in --origin_names, which is not allowed."
+        )
+        sys.exit(-1)
+
+    if len(set(args.new_names)) < len(args.new_names):
+        print(
+            "[ERROR] There's dumplicate name in --new_names, which is not allowed."
+        )
+        sys.exit(-1)
+
+    if len(args.new_names) != len(args.origin_names):
+        print(
+            "[ERROR] Number of --new_names must be same with the number of --origin_names."
+        )
+        sys.exit(-1)
+
+    model = onnx.load(args.model)
+
+    # Collect the names of all graph inputs and node outputs.
+    output_tensor_names = set()
+    for ipt in model.graph.input:
+        output_tensor_names.add(ipt.name)
+
+    for node in model.graph.node:
+        for out in node.output:
+            output_tensor_names.add(out)
+
+    for origin_name in args.origin_names:
+        if origin_name not in output_tensor_names:
+            print(
+                f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph."
+            )
+            sys.exit(-1)
+
+    for new_name in args.new_names:
+        if new_name in output_tensor_names:
+            print(
+                f"[ERROR] The defined new_name '{new_name}' is already exist in the onnx model, which is not allowed."
+            )
+            sys.exit(-1)
+
+    # origin name -> new name. Both lists were checked above to be free of
+    # duplicates and of equal length, so the mapping is one-to-one.
+    rename_map = dict(zip(args.origin_names, args.new_names))
+
+    # rename graph input
+    for ipt in model.graph.input:
+        if ipt.name in rename_map:
+            ipt.name = rename_map[ipt.name]
+
+    # rename node input and output
+    for node in model.graph.node:
+        for j, ipt in enumerate(node.input):
+            if ipt in rename_map:
+                node.input[j] = rename_map[ipt]
+
+        for j, out in enumerate(node.output):
+            if out in rename_map:
+                node.output[j] = rename_map[out]
+
+    # rename graph output
+    for out in model.graph.output:
+        if out.name in rename_map:
+            out.name = rename_map[out.name]
+
+    # check onnx model
+    onnx.checker.check_model(model)
+
+    # save model
+    onnx.save(model, args.save_file)
+
+    print("[Finished] The new model saved in {}.".format(args.save_file))
+    print("[DEBUG INFO] The inputs of new model: {}".format(
+        [x.name for x in model.graph.input]))
+    print("[DEBUG INFO] The outputs of new model: {}".format(
+        [x.name for x in model.graph.output]))
diff --git a/speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/pd_infer_shape.py
old mode 100644
new mode 100755
diff --git a/speechx/examples/ds2_ol/onnx/local/prune.sh b/speechx/examples/ds2_ol/onnx/local/prune.sh
index ee5f6b5f2..64636bccf 100755
--- a/speechx/examples/ds2_ol/onnx/local/prune.sh
+++ b/speechx/examples/ds2_ol/onnx/local/prune.sh
@@ -3,6 +3,7 @@
 set -e
 
 if [ $# != 5 ]; then
+    # local/prune.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 $PWD
     echo "usage: $0 model_dir model_filename param_filename outputs_names save_dir"
     exit 1
 fi
diff --git a/speechx/examples/ds2_ol/onnx/local/tonnx.sh b/speechx/examples/ds2_ol/onnx/local/tonnx.sh
index a57b84f61..58f0d736b 100755
--- a/speechx/examples/ds2_ol/onnx/local/tonnx.sh
+++ b/speechx/examples/ds2_ol/onnx/local/tonnx.sh
@@ -1,6 +1,7 @@
 #!/bin/bash
 
 if [ $# != 4 ];then
+    # local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx
     echo "usage: $0 model_dir model_name param_name onnx_output_name"
     exit 1
 fi
@@ -11,6 +12,7 @@ param=$3
 output=$4
 
 pip install paddle2onnx
+pip install onnx
 
 # https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2
 paddle2onnx --model_dir $dir \
diff --git a/speechx/examples/ds2_ol/onnx/run.sh b/speechx/examples/ds2_ol/onnx/run.sh
index 37d4f7f7a..a9f7681c6 100755
--- a/speechx/examples/ds2_ol/onnx/run.sh
+++ b/speechx/examples/ds2_ol/onnx/run.sh
@@ -10,6 +10,9 @@ stop_stage=100
 . utils/parse_options.sh
 
 data=data
+exp=exp
+
+mkdir -p $data $exp
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
     test -f $data/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz || wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz -P $data
@@ -25,21 +28,24 @@ param=avg_1.jit.pdiparams
 
 output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then
-    mkdir -p $data/prune
+    mkdir -p $exp/prune
 
     # prune model deps on output_names.
-    ./local/prune.sh $dir $model $param $output_names $data/prune
+    ./local/prune.sh $dir $model $param $output_names $exp/prune
 fi
 
 input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ];then
-    mkdir -p $data/shape
+    mkdir -p $exp/shape
 
     python3 local/pd_infer_shape.py \
         --model_dir $dir \
         --model_filename $model \
         --params_filename $param \
-        --save_dir $data/shape \
-        --input_shape_dict=${input_shape_dict}
+        --save_dir $exp/shape \
+        --input_shape_dict="${input_shape_dict}"
 fi
 
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then
+    ./local/tonnx.sh $dir $model $param $exp/model.onnx
+fi