Merge pull request #2116 from zh794390558/onnx
[speechx]remove fluid tools for onnx exportpull/2120/head
commit
803fec21aa
@ -1,111 +0,0 @@
|
||||
#!/usr/bin/env python3 -W ignore::DeprecationWarning
|
||||
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#2-%E4%BF%AE%E6%94%B9paddle%E6%A8%A1%E5%9E%8B%E8%BE%93%E5%85%A5shape
|
||||
import argparse
|
||||
|
||||
# paddle inference shape
|
||||
|
||||
|
||||
def process_old_ops_desc(program):
|
||||
"""set matmul op head_number attr to 1 is not exist.
|
||||
|
||||
Args:
|
||||
program (_type_): _description_
|
||||
"""
|
||||
for i in range(len(program.blocks[0].ops)):
|
||||
if program.blocks[0].ops[i].type == "matmul":
|
||||
if not program.blocks[0].ops[i].has_attr("head_number"):
|
||||
program.blocks[0].ops[i]._set_attr("head_number", 1)
|
||||
|
||||
|
||||
def infer_shape(program, input_shape_dict):
|
||||
# 2002002
|
||||
model_version = program.desc._version()
|
||||
# 2.2.2
|
||||
paddle_version = paddle.__version__
|
||||
major_ver = model_version // 1000000
|
||||
minor_ver = (model_version - major_ver * 1000000) // 1000
|
||||
patch_ver = model_version - major_ver * 1000000 - minor_ver * 1000
|
||||
model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver)
|
||||
if model_version != paddle_version:
|
||||
print(
|
||||
f"[WARNING] The model is saved by paddlepaddle v{model_version}, but now your paddlepaddle is version of {paddle_version}, this difference may cause error, it is recommend you reinstall a same version of paddlepaddle for this model"
|
||||
)
|
||||
|
||||
OP_WITHOUT_KERNEL_SET = {
|
||||
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
|
||||
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
|
||||
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
|
||||
'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
|
||||
'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
|
||||
'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
|
||||
'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
|
||||
'copy_cross_scope'
|
||||
}
|
||||
|
||||
for k, v in input_shape_dict.items():
|
||||
program.blocks[0].var(k).desc.set_shape(v)
|
||||
|
||||
for i in range(len(program.blocks)):
|
||||
for j in range(len(program.blocks[0].ops)):
|
||||
# for ops
|
||||
if program.blocks[i].ops[j].type in OP_WITHOUT_KERNEL_SET:
|
||||
print(f"not infer: {program.blocks[i].ops[j].type} op")
|
||||
continue
|
||||
print(f"infer: {program.blocks[i].ops[j].type} op")
|
||||
program.blocks[i].ops[j].desc.infer_shape(program.blocks[i].desc)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
# python pd_infer_shape.py --model_dir data/exp/deepspeech2_online/checkpoints \
|
||||
# --model_filename avg_1.jit.pdmodel\
|
||||
# --params_filename avg_1.jit.pdiparams \
|
||||
# --save_dir . \
|
||||
# --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--model_dir',
|
||||
required=True,
|
||||
help='Path of directory saved the input model.')
|
||||
parser.add_argument(
|
||||
'--model_filename', required=True, help='model.pdmodel.')
|
||||
parser.add_argument(
|
||||
'--params_filename', required=True, help='model.pdiparams.')
|
||||
parser.add_argument(
|
||||
'--save_dir',
|
||||
required=True,
|
||||
help='directory to save the exported model.')
|
||||
parser.add_argument(
|
||||
'--input_shape_dict', required=True, help="The new shape information.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
|
||||
import paddle
|
||||
paddle.enable_static()
|
||||
import paddle.fluid as fluid
|
||||
|
||||
input_shape_dict_str = args.input_shape_dict
|
||||
input_shape_dict = eval(input_shape_dict_str)
|
||||
|
||||
print("Start to load paddle model...")
|
||||
exe = fluid.Executor(fluid.CPUPlace())
|
||||
|
||||
prog, ipts, outs = fluid.io.load_inference_model(
|
||||
args.model_dir,
|
||||
exe,
|
||||
model_filename=args.model_filename,
|
||||
params_filename=args.params_filename)
|
||||
|
||||
process_old_ops_desc(prog)
|
||||
infer_shape(prog, input_shape_dict)
|
||||
|
||||
fluid.io.save_inference_model(
|
||||
args.save_dir,
|
||||
ipts,
|
||||
outs,
|
||||
exe,
|
||||
prog,
|
||||
model_filename=args.model_filename,
|
||||
params_filename=args.params_filename)
|
@ -1,158 +0,0 @@
|
||||
#!/usr/bin/env python3 -W ignore::DeprecationWarning
|
||||
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#1-%E8%A3%81%E5%89%AApaddle%E6%A8%A1%E5%9E%8B
|
||||
import argparse
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
# paddle prune model.
|
||||
|
||||
|
||||
def prepend_feed_ops(program,
|
||||
feed_target_names: List[str],
|
||||
feed_holder_name='feed'):
|
||||
import paddle.fluid.core as core
|
||||
if len(feed_target_names) == 0:
|
||||
return
|
||||
|
||||
global_block = program.global_block()
|
||||
feed_var = global_block.create_var(
|
||||
name=feed_holder_name,
|
||||
type=core.VarDesc.VarType.FEED_MINIBATCH,
|
||||
persistable=True, )
|
||||
|
||||
for i, name in enumerate(feed_target_names, 0):
|
||||
if not global_block.has_var(name):
|
||||
print(
|
||||
f"The input[{i}]: '{name}' doesn't exist in pruned inference program, which will be ignored in new saved model."
|
||||
)
|
||||
continue
|
||||
|
||||
out = global_block.var(name)
|
||||
global_block._prepend_op(
|
||||
type='feed',
|
||||
inputs={'X': [feed_var]},
|
||||
outputs={'Out': [out]},
|
||||
attrs={'col': i}, )
|
||||
|
||||
|
||||
def append_fetch_ops(program,
|
||||
fetch_target_names: List[str],
|
||||
fetch_holder_name='fetch'):
|
||||
"""in the place, we will add the fetch op
|
||||
|
||||
Args:
|
||||
program (_type_): inference program
|
||||
fetch_target_names (List[str]): target names
|
||||
fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
|
||||
"""
|
||||
import paddle.fluid.core as core
|
||||
global_block = program.global_block()
|
||||
fetch_var = global_block.create_var(
|
||||
name=fetch_holder_name,
|
||||
type=core.VarDesc.VarType.FETCH_LIST,
|
||||
persistable=True, )
|
||||
|
||||
print(f"the len of fetch_target_names: {len(fetch_target_names)}")
|
||||
|
||||
for i, name in enumerate(fetch_target_names):
|
||||
global_block.append_op(
|
||||
type='fetch',
|
||||
inputs={'X': [name]},
|
||||
outputs={'Out': [fetch_var]},
|
||||
attrs={'col': i}, )
|
||||
|
||||
|
||||
def insert_fetch(program,
|
||||
fetch_target_names: List[str],
|
||||
fetch_holder_name='fetch'):
|
||||
"""in the place, we will add the fetch op
|
||||
|
||||
Args:
|
||||
program (_type_): inference program
|
||||
fetch_target_names (List[str]): target names
|
||||
fetch_holder_name (str, optional): fetch op name. Defaults to 'fetch'.
|
||||
"""
|
||||
global_block = program.global_block()
|
||||
|
||||
# remove fetch
|
||||
need_to_remove_op_index = []
|
||||
for i, op in enumerate(global_block.ops):
|
||||
if op.type == 'fetch':
|
||||
need_to_remove_op_index.append(i)
|
||||
|
||||
for index in reversed(need_to_remove_op_index):
|
||||
global_block._remove_op(index)
|
||||
|
||||
program.desc.flush()
|
||||
|
||||
# append new fetch
|
||||
append_fetch_ops(program, fetch_target_names, fetch_holder_name)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--model_dir',
|
||||
required=True,
|
||||
help='Path of directory saved the input model.')
|
||||
parser.add_argument(
|
||||
'--model_filename', required=True, help='model.pdmodel.')
|
||||
parser.add_argument(
|
||||
'--params_filename', required=True, help='model.pdiparams.')
|
||||
parser.add_argument(
|
||||
'--output_names',
|
||||
required=True,
|
||||
help='The outputs of model. sep by comma')
|
||||
parser.add_argument(
|
||||
'--save_dir',
|
||||
required=True,
|
||||
help='directory to save the exported model.')
|
||||
parser.add_argument('--debug', default=False, help='output debug info.')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
|
||||
args.output_names = args.output_names.split(",")
|
||||
|
||||
if len(set(args.output_names)) < len(args.output_names):
|
||||
print(
|
||||
f"[ERROR] There's dumplicate name in --output_names {args.output_names}, which is not allowed."
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
import paddle
|
||||
paddle.enable_static()
|
||||
# hack prepend_feed_ops
|
||||
paddle.fluid.io.prepend_feed_ops = prepend_feed_ops
|
||||
|
||||
import paddle.fluid as fluid
|
||||
|
||||
print("start to load paddle model")
|
||||
exe = fluid.Executor(fluid.CPUPlace())
|
||||
prog, ipts, outs = fluid.io.load_inference_model(
|
||||
args.model_dir,
|
||||
exe,
|
||||
model_filename=args.model_filename,
|
||||
params_filename=args.params_filename)
|
||||
|
||||
print("start to load insert fetch op")
|
||||
new_outputs = []
|
||||
insert_fetch(prog, args.output_names)
|
||||
for out_name in args.output_names:
|
||||
new_outputs.append(prog.global_block().var(out_name))
|
||||
|
||||
# not equal to paddle.static.save_inference_model
|
||||
fluid.io.save_inference_model(
|
||||
args.save_dir,
|
||||
ipts,
|
||||
new_outputs,
|
||||
exe,
|
||||
prog,
|
||||
model_filename=args.model_filename,
|
||||
params_filename=args.params_filename)
|
||||
|
||||
if args.debug:
|
||||
for op in prog.global_block().ops:
|
||||
print(op)
|
@ -1,23 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
if [ $# != 5 ]; then
|
||||
# local/prune.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 $PWD
|
||||
echo "usage: $0 model_dir model_filename param_filename outputs_names save_dir"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
dir=$1
|
||||
model=$2
|
||||
param=$3
|
||||
outputs=$4
|
||||
save_dir=$5
|
||||
|
||||
|
||||
python local/pd_prune_model.py \
|
||||
--model_dir $dir \
|
||||
--model_filename $model \
|
||||
--params_filename $param \
|
||||
--output_names $outputs \
|
||||
--save_dir $save_dir
|
Loading…
Reference in new issue