#!/usr/bin/env python3
# https://github.com/jiangjiajun/PaddleUtils/blob/main/paddle/README.md#2-%E4%BF%AE%E6%94%B9paddle%E6%A8%A1%E5%9E%8B%E8%BE%93%E5%85%A5shape
"""Re-run shape inference on a saved Paddle inference model with new input shapes.

Loads a static-graph inference model, overrides the shapes of the named input
variables, re-runs each op's ``infer_shape`` so the new shapes propagate
through the program, and saves the result to ``--save_dir`` under the same
model/params filenames.
"""
import argparse
import ast
import warnings

# Op types that are skipped during shape inference (presumably ops without a
# compute kernel, as the name suggests -- feed/fetch, control flow, comm ops).
OP_WITHOUT_KERNEL_SET = frozenset({
    'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
    'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
    'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
    'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
    'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
    'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
    'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
    'copy_cross_scope',
})


def process_old_ops_desc(program):
    """Set ``head_number`` to 1 on ``matmul`` ops in block 0 that lack it.

    Ops that already carry the attribute are left untouched.
    NOTE(review): presumably older saved models omit this attribute and the
    current ``matmul`` shape inference expects it -- confirm.

    Args:
        program: the loaded static ``Program``; modified in place.
    """
    for op in program.blocks[0].ops:
        if op.type == "matmul" and not op.has_attr("head_number"):
            op._set_attr("head_number", 1)


def model_version_to_str(version):
    """Decode an integer program version (e.g. 2002002) into "2.2.2"."""
    major, rest = divmod(version, 1000000)
    minor, patch = divmod(rest, 1000)
    return f"{major}.{minor}.{patch}"


def infer_shape(program, input_shape_dict):
    """Override input shapes and re-run shape inference over every block.

    Args:
        program: the loaded static ``Program``; modified in place.
        input_shape_dict: maps input variable names (looked up in block 0) to
            their new shapes, e.g. ``{'audio_chunk': [1, -1, 161]}``.
    """
    # Imported here so the function works even when this file is imported as
    # a module rather than run as a script.
    import paddle

    model_version = model_version_to_str(program.desc._version())
    paddle_version = paddle.__version__
    if model_version != paddle_version:
        print(
            f"[WARNING] The model is saved by paddlepaddle v{model_version}, but now your paddlepaddle is version of {paddle_version}, this difference may cause error, it is recommend you reinstall a same version of paddlepaddle for this model"
        )

    for name, shape in input_shape_dict.items():
        program.blocks[0].var(name).desc.set_shape(shape)

    for block in program.blocks:
        # Walk this block's own ops and infer against this block's desc.
        for op in block.ops:
            if op.type in OP_WITHOUT_KERNEL_SET:
                print(f"not infer: {op.type} op")
                continue
            print(f"infer: {op.type} op")
            op.desc.infer_shape(block.desc)


def parse_shape_dict(text):
    """Parse the ``--input_shape_dict`` argument into a dict.

    ``ast.literal_eval`` is used instead of ``eval`` so only Python literals
    are accepted and no code from the command line is executed.

    Raises:
        ValueError: if ``text`` is not a dict literal (``ast.literal_eval``
            may also raise ``ValueError``/``SyntaxError`` for malformed text).
    """
    value = ast.literal_eval(text)
    if not isinstance(value, dict):
        raise ValueError(
            f"--input_shape_dict must be a dict literal, got {type(value).__name__}")
    return value


def parse_arguments():
    """Parse command-line options.

    Example:
        python pd_infer_shape.py --model_dir data/exp/deepspeech2_online/checkpoints \\
            --model_filename avg_1.jit.pdmodel \\
            --params_filename avg_1.jit.pdiparams \\
            --save_dir . \\
            --input_shape_dict="{'audio_chunk':[1,-1,161], 'audio_chunk_lens':[1], 'chunk_state_c_box':[5, 1, 1024], 'chunk_state_h_box':[5,1,1024]}"
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_dir',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--model_filename', required=True, help='model.pdmodel.')
    parser.add_argument(
        '--params_filename', required=True, help='model.pdiparams.')
    parser.add_argument(
        '--save_dir',
        required=True,
        help='directory to save the exported model.')
    parser.add_argument(
        '--input_shape_dict', required=True, help="The new shape information.")
    return parser.parse_args()


def main():
    """Load the model, re-infer shapes with the new inputs, and save it."""
    args = parse_arguments()
    # Same effect as `python -W ignore::DeprecationWarning`; done in code
    # because a shebang cannot portably pass extra interpreter flags via env.
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    import paddle
    paddle.enable_static()
    import paddle.fluid as fluid

    input_shape_dict = parse_shape_dict(args.input_shape_dict)

    print("Start to load paddle model...")
    exe = fluid.Executor(fluid.CPUPlace())
    prog, ipts, outs = fluid.io.load_inference_model(
        args.model_dir,
        exe,
        model_filename=args.model_filename,
        params_filename=args.params_filename)

    process_old_ops_desc(prog)
    infer_shape(prog, input_shape_dict)

    fluid.io.save_inference_model(
        args.save_dir,
        ipts,
        outs,
        exe,
        prog,
        model_filename=args.model_filename,
        params_filename=args.params_filename)


if __name__ == '__main__':
    main()