You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
4.2 KiB
110 lines
4.2 KiB
#!/usr/bin/env python3 -W ignore::DeprecationWarning
|
|
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#2-%E4%BF%AE%E6%94%B9paddle%E6%A8%A1%E5%9E%8B%E8%BE%93%E5%85%A5shape
|
|
import argparse
|
|
|
|
|
|
def process_old_ops_desc(program):
|
|
"""set matmul op head_number attr to 1 is not exist.
|
|
|
|
Args:
|
|
program (_type_): _description_
|
|
"""
|
|
for i in range(len(program.blocks[0].ops)):
|
|
if program.blocks[0].ops[i].type == "matmul":
|
|
if not program.blocks[0].ops[i].has_attr("head_number"):
|
|
program.blocks[0].ops[i]._set_attr("head_number", 1)
|
|
|
|
|
|
def infer_shape(program, input_shape_dict):
|
|
# 2002002
|
|
model_version = program.desc._version()
|
|
# 2.2.2
|
|
paddle_version = paddle.__version__
|
|
major_ver = model_version // 1000000
|
|
minor_ver = (model_version - major_ver * 1000000) // 1000
|
|
patch_ver = model_version - major_ver * 1000000 - minor_ver * 1000
|
|
model_version = "{}.{}.{}".format(major_ver, minor_ver, patch_ver)
|
|
if model_version != paddle_version:
|
|
print(
|
|
f"[WARNING] The model is saved by paddlepaddle v{model_version}, but now your paddlepaddle is version of {paddle_version}, this difference may cause error, it is recommend you reinstall a same version of paddlepaddle for this model"
|
|
)
|
|
|
|
OP_WITHOUT_KERNEL_SET = {
|
|
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
|
|
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
|
|
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
|
|
'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
|
|
'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
|
|
'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
|
|
'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
|
|
'copy_cross_scope'
|
|
}
|
|
|
|
for k, v in input_shape_dict.items():
|
|
program.blocks[0].var(k).desc.set_shape(v)
|
|
|
|
for i in range(len(program.blocks)):
|
|
for j in range(len(program.blocks[0].ops)):
|
|
# for ops
|
|
if program.blocks[i].ops[j].type in OP_WITHOUT_KERNEL_SET:
|
|
print(f"not infer: {program.blocks[i].ops[j].type} op")
|
|
continue
|
|
print(f"infer: {program.blocks[i].ops[j].type} op")
|
|
program.blocks[i].ops[j].desc.infer_shape(program.blocks[i].desc)
|
|
|
|
|
|
def parse_arguments():
|
|
# python pd_infer_shape.py --model_dir data/exp/deepspeech2_online/checkpoints \
|
|
# --model_filename avg_1.jit.pdmodel\
|
|
# --params_filename avg_1.jit.pdiparams \
|
|
# --save_dir . \
|
|
# --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
'--model_dir',
|
|
required=True,
|
|
help='Path of directory saved the input model.')
|
|
parser.add_argument(
|
|
'--model_filename', required=True, help='model.pdmodel.')
|
|
parser.add_argument(
|
|
'--params_filename', required=True, help='model.pdiparams.')
|
|
parser.add_argument(
|
|
'--save_dir',
|
|
required=True,
|
|
help='directory to save the exported model.')
|
|
parser.add_argument(
|
|
'--input_shape_dict', required=True, help="The new shape information.")
|
|
return parser.parse_args()
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = parse_arguments()
|
|
|
|
import paddle
|
|
paddle.enable_static()
|
|
import paddle.fluid as fluid
|
|
|
|
input_shape_dict_str = args.input_shape_dict
|
|
input_shape_dict = eval(input_shape_dict_str)
|
|
|
|
print("Start to load paddle model...")
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
|
|
prog, ipts, outs = fluid.io.load_inference_model(
|
|
args.model_dir,
|
|
exe,
|
|
model_filename=args.model_filename,
|
|
params_filename=args.params_filename)
|
|
|
|
process_old_ops_desc(prog)
|
|
infer_shape(prog, input_shape_dict)
|
|
|
|
fluid.io.save_inference_model(
|
|
args.save_dir,
|
|
ipts,
|
|
outs,
|
|
exe,
|
|
prog,
|
|
model_filename=args.model_filename,
|
|
params_filename=args.params_filename)
|